diff --git a/.gitignore b/.gitignore index 842a688..c45a684 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,5 @@ nosetests.xml octodns.egg-info/ output/ tmp/ +build/ +config/ diff --git a/README.md b/README.md index 0e63e51..05cf979 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,7 @@ The above command pulled the existing data out of Route53 and placed the results | Provider | Record Support | GeoDNS Support | Notes | |--|--|--|--| +| [AzureProvider](/octodns/provider/azuredns.py) | A, AAAA, CNAME, MX, NS, PTR, SRV, TXT | No | | | [CloudflareProvider](/octodns/provider/cloudflare.py) | A, AAAA, CNAME, MX, NS, SPF, TXT | No | | | [DnsimpleProvider](/octodns/provider/dnsimple.py) | All | No | | | [DynProvider](/octodns/provider/dyn.py) | All | Yes | | diff --git a/octodns/manager.py b/octodns/manager.py index e6fe253..8439eb6 100644 --- a/octodns/manager.py +++ b/octodns/manager.py @@ -6,7 +6,7 @@ from __future__ import absolute_import, division, print_function, \ unicode_literals from StringIO import StringIO -from concurrent.futures import Future, ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor from importlib import import_module from os import environ import logging @@ -38,6 +38,17 @@ class _AggregateTarget(object): return True +class MakeThreadFuture(object): + + def __init__(self, func, args, kwargs): + self.func = func + self.args = args + self.kwargs = kwargs + + def result(self): + return self.func(*self.args, **self.kwargs) + + class MainThreadExecutor(object): ''' Dummy executor that runs things on the main thread during the involcation @@ -48,13 +59,7 @@ class MainThreadExecutor(object): ''' def submit(self, func, *args, **kwargs): - future = Future() - try: - future.set_result(func(*args, **kwargs)) - except Exception as e: - # TODO: get right stacktrace here - future.set_exception(e) - return future + return MakeThreadFuture(func, args, kwargs) class Manager(object): diff --git a/octodns/provider/azuredns.py b/octodns/provider/azuredns.py new file mode 100644 index 0000000..4433b1e --- /dev/null +++ b/octodns/provider/azuredns.py @@ -0,0 +1,436 @@ +# +# +# + +from __future__ import absolute_import, division, print_function, \ + unicode_literals + +from azure.common.credentials import ServicePrincipalCredentials +from azure.mgmt.dns import DnsManagementClient +from msrestazure.azure_exceptions import CloudError + +from azure.mgmt.dns.models import ARecord, AaaaRecord, CnameRecord, MxRecord, \ + SrvRecord, NsRecord, PtrRecord, TxtRecord, Zone + +import logging +from functools import reduce +from ..record import Record +from .base import BaseProvider + + +class _AzureRecord(object): + '''Wrapper for OctoDNS record for AzureProvider to make dns_client calls. + + azuredns.py: + class: octodns.provider.azuredns._AzureRecord + An _AzureRecord is easily accessible to Azure DNS Management library + functions and is used to wrap all relevant data to create a record in + Azure. + ''' + TYPE_MAP = { + 'A': ARecord, + 'AAAA': AaaaRecord, + 'CNAME': CnameRecord, + 'MX': MxRecord, + 'SRV': SrvRecord, + 'NS': NsRecord, + 'PTR': PtrRecord, + 'TXT': TxtRecord + } + + def __init__(self, resource_group, record, delete=False): + '''Contructor for _AzureRecord. + + Notes on Azure records: An Azure record set has the form + RecordSet(name=<...>, type=<...>, arecords=[...], aaaa_records, ..) + When constructing an azure record as done in self._apply_Create, + the argument parameters for an A record would be + parameters={'ttl': , 'arecords': [ARecord(),]}. + As another example for CNAME record: + parameters={'ttl': , 'cname_record': CnameRecord()}. + + Below, key_name and class_name are the dictionary key and Azure + Record class respectively. + + :param resource_group: The name of resource group in Azure + :type resource_group: str + :param record: An OctoDNS record + :type record: ..record.Record + :param delete: If true, omit data parsing; not needed to delete + :type delete: bool + + :type return: _AzureRecord + ''' + self.resource_group = resource_group + self.zone_name = record.zone.name[:len(record.zone.name) - 1] + self.relative_record_set_name = record.name or '@' + self.record_type = record._type + + if delete: + return + + # Refer to function docstring for key_name and class_name. + format_u_s = '' if record._type == 'A' else '_' + key_name = '{}{}records'.format(self.record_type, format_u_s).lower() + if record._type == 'CNAME': + key_name = key_name[:len(key_name) - 1] + azure_class = self.TYPE_MAP[self.record_type] + + self.params = getattr(self, '_params_for_{}'.format(record._type)) + self.params = self.params(record.data, key_name, azure_class) + self.params['ttl'] = record.ttl + + def _params(self, data, key_name, azure_class): + if 'values' in data: + return {key_name: [azure_class(v) for v in data['values']]} + else: # Else there is a singular data point keyed by 'value'. + return {key_name: [azure_class(data['value'])]} + + _params_for_A = _params + _params_for_AAAA = _params + _params_for_NS = _params + _params_for_PTR = _params + _params_for_TXT = _params + + def _params_for_CNAME(self, data, key_name, azure_class): + return {key_name: azure_class(data['value'])} + + def _params_for_MX(self, data, key_name, azure_class): + params = [] + if 'values' in data: + for vals in data['values']: + params.append(azure_class(vals['preference'], + vals['exchange'])) + else: # Else there is a singular data point keyed by 'value'. + params.append(azure_class(data['value']['preference'], + data['value']['exchange'])) + return {key_name: params} + + def _params_for_SRV(self, data, key_name, azure_class): + params = [] + if 'values' in data: + for vals in data['values']: + params.append(azure_class(vals['priority'], + vals['weight'], + vals['port'], + vals['target'])) + else: # Else there is a singular data point keyed by 'value'. + params.append(azure_class(data['value']['priority'], + data['value']['weight'], + data['value']['port'], + data['value']['target'])) + return {key_name: params} + + def _equals(self, b): + '''Checks whether two records are equal by comparing all fields. + :param b: Another _AzureRecord object + :type b: _AzureRecord + + :type return: bool + ''' + def parse_dict(params): + vals = [] + for char in params: + if char != 'ttl': + list_records = params[char] + try: + for record in list_records: + vals.append(record.__dict__) + except: + vals.append(list_records.__dict__) + vals.sort() + return vals + + return (self.resource_group == b.resource_group) & \ + (self.zone_name == b.zone_name) & \ + (self.record_type == b.record_type) & \ + (self.params['ttl'] == b.params['ttl']) & \ + (parse_dict(self.params) == parse_dict(b.params)) & \ + (self.relative_record_set_name == b.relative_record_set_name) + + def __str__(self): + '''String representation of an _AzureRecord. + :type return: str + ''' + string = 'Zone: {}; '.format(self.zone_name) + string += 'Name: {}; '.format(self.relative_record_set_name) + string += 'Type: {}; '.format(self.record_type) + if not hasattr(self, 'params'): + return string + string += 'Ttl: {}; '.format(self.params['ttl']) + for char in self.params: + if char != 'ttl': + try: + for rec in self.params[char]: + string += 'Record: {}; '.format(rec.__dict__) + except: + string += 'Record: {}; '.format(self.params[char].__dict__) + return string + + +def _check_endswith_dot(string): + return string if string.endswith('.') else string + '.' + + +def _parse_azure_type(string): + '''Converts string representing an Azure RecordSet type to usual type. + + :param string: the Azure type. eg: + :type string: str + + :type return: str + ''' + return string.split('/')[len(string.split('/')) - 1] + + +class AzureProvider(BaseProvider): + ''' + Azure DNS Provider + + azuredns.py: + class: octodns.provider.azuredns.AzureProvider + # Current support of authentication of access to Azure services only + # includes using a Service Principal: + # https://docs.microsoft.com/en-us/azure/azure-resource-manager/ + # resource-group-create-service-principal-portal + # The Azure Active Directory Application ID (aka client ID): + client_id: + # Authentication Key Value: (note this should be secret) + key: + # Directory ID (aka tenant ID): + directory_id: + # Subscription ID: + sub_id: + # Resource Group name: + resource_group: + # All are required to authenticate. + + Example config file with variables: + " + --- + providers: + config: + class: octodns.provider.yaml.YamlProvider + directory: ./config (example path to directory of zone files) + azuredns: + class: octodns.provider.azuredns.AzureProvider + client_id: env/AZURE_APPLICATION_ID + key: env/AZURE_AUTHENICATION_KEY + directory_id: env/AZURE_DIRECTORY_ID + sub_id: env/AZURE_SUBSCRIPTION_ID + resource_group: 'TestResource1' + + zones: + example.com.: + sources: + - config + targets: + - azuredns + " + The first four variables above can be hidden in environment variables + and octoDNS will automatically search for them in the shell. It is + possible to also hard-code into the config file: eg, resource_group. + ''' + SUPPORTS_GEO = False + SUPPORTS = set(('A', 'AAAA', 'CNAME', 'MX', 'NS', 'PTR', 'SRV', 'TXT')) + + def __init__(self, id, client_id, key, directory_id, sub_id, + resource_group, *args, **kwargs): + self.log = logging.getLogger('AzureProvider[{}]'.format(id)) + self.log.debug('__init__: id=%s, client_id=%s, ' + 'key=***, directory_id:%s', id, client_id, directory_id) + super(AzureProvider, self).__init__(id, *args, **kwargs) + + credentials = ServicePrincipalCredentials( + client_id, secret=key, tenant=directory_id + ) + self._dns_client = DnsManagementClient(credentials, sub_id) + self._resource_group = resource_group + self._azure_zones = set() + + def _populate_zones(self): + self.log.debug('azure_zones: loading') + list_zones = self._dns_client.zones.list_by_resource_group + for zone in list_zones(self._resource_group): + self._azure_zones.add(zone.name) + + def _check_zone(self, name, create=False): + '''Checks whether a zone specified in a source exist in Azure server. + + Note that Azure zones omit end '.' eg: contoso.com vs contoso.com. + Returns the name if it exists. + + :param name: Name of a zone to checks + :type name: str + :param create: If True, creates the zone of that name. + :type create: bool + + :type return: str or None + ''' + self.log.debug('_check_zone: name=%s', name) + try: + if name in self._azure_zones: + return name + self._dns_client.zones.get(self._resource_group, name) + self._azure_zones.add(name) + return name + except CloudError as err: + msg = 'The Resource \'Microsoft.Network/dnszones/{}\''.format(name) + msg += ' under resource group \'{}\''.format(self._resource_group) + msg += ' was not found.' + if msg == err.message: + # Then the only error is that the zone doesn't currently exist + if create: + self.log.debug('_check_zone:no matching zone; creating %s', + name) + create_zone = self._dns_client.zones.create_or_update + create_zone(self._resource_group, name, Zone('global')) + return name + else: + return + raise + + def populate(self, zone, target=False, lenient=False): + '''Required function of manager.py to collect records from zone. + + Special notes for Azure. + Azure zone names omit final '.' + Azure root records names are represented by '@'. OctoDNS uses '' + Azure records created through online interface may have null values + (eg, no IP address for A record). + Azure online interface allows constructing records with null values + which are destroyed by _apply. + + Specific quirks such as these are responsible for any non-obvious + parsing in this function and the functions '_params_for_*'. + + :param zone: A dns zone + :type zone: octodns.zone.Zone + :param target: Checks if Azure is source or target of config. + Currently only supports as a target. Unused. + :type target: bool + :param lenient: Unused. Check octodns.manager for usage. + :type lenient: bool + + :type return: void + ''' + self.log.debug('populate: name=%s', zone.name) + before = len(zone.records) + + zone_name = zone.name[:len(zone.name) - 1] + self._populate_zones() + self._check_zone(zone_name) + + _records = set() + records = self._dns_client.record_sets.list_by_dns_zone + if self._check_zone(zone_name): + for azrecord in records(self._resource_group, zone_name): + if _parse_azure_type(azrecord.type) in self.SUPPORTS: + _records.add(azrecord) + for azrecord in _records: + record_name = azrecord.name if azrecord.name != '@' else '' + typ = _parse_azure_type(azrecord.type) + data = getattr(self, '_data_for_{}'.format(typ)) + data = data(azrecord) + data['type'] = typ + data['ttl'] = azrecord.ttl + record = Record.new(zone, record_name, data, source=self) + zone.add_record(record) + + self.log.info('populate: found %s records', len(zone.records) - before) + + def _data_for_A(self, azrecord): + return {'values': [ar.ipv4_address for ar in azrecord.arecords]} + + def _data_for_AAAA(self, azrecord): + return {'values': [ar.ipv6_address for ar in azrecord.aaaa_records]} + + def _data_for_CNAME(self, azrecord): + '''Parsing data from Azure DNS Client record call + :param azrecord: a return of a call to list azure records + :type azrecord: azure.mgmt.dns.models.RecordSet + + :type return: dict + + CNAME and PTR both use the catch block to catch possible empty + records. Refer to population comment. + ''' + try: + return {'value': _check_endswith_dot(azrecord.cname_record.cname)} + except: + return {'value': '.'} + + def _data_for_MX(self, azrecord): + return {'values': [{'preference': ar.preference, + 'exchange': ar.exchange} + for ar in azrecord.mx_records]} + + def _data_for_NS(self, azrecord): + vals = [ar.nsdname for ar in azrecord.ns_records] + return {'values': [_check_endswith_dot(val) for val in vals]} + + def _data_for_PTR(self, azrecord): + try: + ptrdname = azrecord.ptr_records[0].ptrdname + return {'value': _check_endswith_dot(ptrdname)} + except: + return {'value': '.'} + + def _data_for_SRV(self, azrecord): + return {'values': [{'priority': ar.priority, 'weight': ar.weight, + 'port': ar.port, 'target': ar.target} + for ar in azrecord.srv_records]} + + def _data_for_TXT(self, azrecord): + return {'values': [reduce((lambda a, b: a + b), ar.value) + for ar in azrecord.txt_records]} + + def _apply_Create(self, change): + '''A record from change must be created. + + :param change: a change object + :type change: octodns.record.Change + + :type return: void + ''' + ar = _AzureRecord(self._resource_group, change.new) + create = self._dns_client.record_sets.create_or_update + + create(resource_group_name=ar.resource_group, + zone_name=ar.zone_name, + relative_record_set_name=ar.relative_record_set_name, + record_type=ar.record_type, + parameters=ar.params) + + self.log.debug('* Success Create/Update: {}'.format(ar)) + + _apply_Update = _apply_Create + + def _apply_Delete(self, change): + ar = _AzureRecord(self._resource_group, change.existing, delete=True) + delete = self._dns_client.record_sets.delete + + delete(self._resource_group, ar.zone_name, ar.relative_record_set_name, + ar.record_type) + + self.log.debug('* Success Delete: {}'.format(ar)) + + def _apply(self, plan): + '''Required function of manager.py to actually apply a record change. + + :param plan: Contains the zones and changes to be made + :type plan: octodns.provider.base.Plan + + :type return: void + ''' + desired = plan.desired + changes = plan.changes + self.log.debug('_apply: zone=%s, len(changes)=%d', desired.name, + len(changes)) + + azure_zone_name = desired.name[:len(desired.name) - 1] + self._check_zone(azure_zone_name, create=True) + + for change in changes: + class_name = change.__class__.__name__ + getattr(self, '_apply_{}'.format(class_name))(change) diff --git a/octodns/provider/base.py b/octodns/provider/base.py index 2fd4349..d2561ba 100644 --- a/octodns/provider/base.py +++ b/octodns/provider/base.py @@ -56,19 +56,19 @@ class Plan(object): delete_pcent = self.change_counts['Delete'] / existing_record_count if update_pcent > self.MAX_SAFE_UPDATE_PCENT: - raise UnsafePlan('Too many updates, %s is over %s percent' - '(%s/%s)', - update_pcent, - self.MAX_SAFE_UPDATE_PCENT * 100, - self.change_counts['Update'], - existing_record_count) + raise UnsafePlan('Too many updates, {} is over {} percent' + '({}/{})'.format( + update_pcent, + self.MAX_SAFE_UPDATE_PCENT * 100, + self.change_counts['Update'], + existing_record_count)) if delete_pcent > self.MAX_SAFE_DELETE_PCENT: - raise UnsafePlan('Too many deletes, %s is over %s percent' - '(%s/%s)', - delete_pcent, - self.MAX_SAFE_DELETE_PCENT * 100, - self.change_counts['Delete'], - existing_record_count) + raise UnsafePlan('Too many deletes, {} is over {} percent' + '({}/{})'.format( + delete_pcent, + self.MAX_SAFE_DELETE_PCENT * 100, + self.change_counts['Delete'], + existing_record_count)) def __repr__(self): return 'Creates={}, Updates={}, Deletes={}, Existing Records={}' \ diff --git a/octodns/provider/ns1.py b/octodns/provider/ns1.py index 65db64c..7757812 100644 --- a/octodns/provider/ns1.py +++ b/octodns/provider/ns1.py @@ -42,8 +42,16 @@ class Ns1Provider(BaseProvider): } _data_for_AAAA = _data_for_A - _data_for_SPF = _data_for_A - _data_for_TXT = _data_for_A + + def _data_for_SPF(self, _type, record): + values = [v.replace(';', '\;') for v in record['short_answers']] + return { + 'ttl': record['ttl'], + 'type': _type, + 'values': values + } + + _data_for_TXT = _data_for_SPF def _data_for_CNAME(self, _type, record): return { @@ -141,8 +149,15 @@ class Ns1Provider(BaseProvider): _params_for_AAAA = _params_for_A _params_for_NS = _params_for_A - _params_for_SPF = _params_for_A - _params_for_TXT = _params_for_A + + def _params_for_SPF(self, record): + # NS1 seems to be the only provider that doesn't want things escaped in + # values so we have to strip them here and add them when going the + # other way + values = [v.replace('\;', ';') for v in record.values] + return {'answers': values, 'ttl': record.ttl} + + _params_for_TXT = _params_for_SPF def _params_for_CNAME(self, record): return {'answers': [record.value], 'ttl': record.ttl} diff --git a/octodns/source/base.py b/octodns/source/base.py index 2e2c5c2..4ace09f 100644 --- a/octodns/source/base.py +++ b/octodns/source/base.py @@ -20,7 +20,7 @@ class BaseSource(object): raise NotImplementedError('Abstract base class, SUPPORTS ' 'property missing') - def populate(self, zone, target=False): + def populate(self, zone, target=False, lenient=False): ''' Loads all zones the provider knows about diff --git a/requirements.txt b/requirements.txt index 93a8521..5d8089a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,8 @@ # These are known good versions. You're free to use others and things will # likely work, but no promises are made, especilly if you go older. PyYaml==3.12 +azure-mgmt-dns==1.0.1 +azure-common==1.1.6 boto3==1.4.4 botocore==1.5.4 dnspython==1.15.0 @@ -10,9 +12,10 @@ futures==3.0.5 incf.countryutils==1.0 ipaddress==1.0.18 jmespath==0.9.0 +msrestazure==0.4.10 natsort==5.0.3 nsone==0.9.14 python-dateutil==2.6.0 requests==2.13.0 s3transfer==0.1.10 -six==1.10.0 +six==1.10.0 \ No newline at end of file diff --git a/tests/test_octodns_provider_azuredns.py b/tests/test_octodns_provider_azuredns.py new file mode 100644 index 0000000..59cf551 --- /dev/null +++ b/tests/test_octodns_provider_azuredns.py @@ -0,0 +1,356 @@ +# +# +# + +from __future__ import absolute_import, division, print_function, \ + unicode_literals + +from octodns.record import Create, Delete, Record +from octodns.provider.azuredns import _AzureRecord, AzureProvider, \ + _check_endswith_dot, _parse_azure_type +from octodns.zone import Zone +from octodns.provider.base import Plan + +from azure.mgmt.dns.models import ARecord, AaaaRecord, CnameRecord, MxRecord, \ + SrvRecord, NsRecord, PtrRecord, TxtRecord, RecordSet, SoaRecord, \ + Zone as AzureZone +from msrestazure.azure_exceptions import CloudError + +from unittest import TestCase +from mock import Mock, patch + + +zone = Zone(name='unit.tests.', sub_zones=[]) +octo_records = [] +octo_records.append(Record.new(zone, '', { + 'ttl': 0, + 'type': 'A', + 'values': ['1.2.3.4', '10.10.10.10']})) +octo_records.append(Record.new(zone, 'a', { + 'ttl': 1, + 'type': 'A', + 'values': ['1.2.3.4', '1.1.1.1']})) +octo_records.append(Record.new(zone, 'aa', { + 'ttl': 9001, + 'type': 'A', + 'values': ['1.2.4.3']})) +octo_records.append(Record.new(zone, 'aaa', { + 'ttl': 2, + 'type': 'A', + 'values': ['1.1.1.3']})) +octo_records.append(Record.new(zone, 'cname', { + 'ttl': 3, + 'type': 'CNAME', + 'value': 'a.unit.tests.'})) +octo_records.append(Record.new(zone, 'mx1', { + 'ttl': 3, + 'type': 'MX', + 'values': [{ + 'priority': 10, + 'value': 'mx1.unit.tests.', + }, { + 'priority': 20, + 'value': 'mx2.unit.tests.', + }]})) +octo_records.append(Record.new(zone, 'mx2', { + 'ttl': 3, + 'type': 'MX', + 'values': [{ + 'priority': 10, + 'value': 'mx1.unit.tests.', + }]})) +octo_records.append(Record.new(zone, '', { + 'ttl': 4, + 'type': 'NS', + 'values': ['ns1.unit.tests.', 'ns2.unit.tests.']})) +octo_records.append(Record.new(zone, 'foo', { + 'ttl': 5, + 'type': 'NS', + 'value': 'ns1.unit.tests.'})) +octo_records.append(Record.new(zone, '_srv._tcp', { + 'ttl': 6, + 'type': 'SRV', + 'values': [{ + 'priority': 10, + 'weight': 20, + 'port': 30, + 'target': 'foo-1.unit.tests.', + }, { + 'priority': 12, + 'weight': 30, + 'port': 30, + 'target': 'foo-2.unit.tests.', + }]})) +octo_records.append(Record.new(zone, '_srv2._tcp', { + 'ttl': 7, + 'type': 'SRV', + 'values': [{ + 'priority': 12, + 'weight': 17, + 'port': 1, + 'target': 'srvfoo.unit.tests.', + }]})) + +azure_records = [] +_base0 = _AzureRecord('TestAzure', octo_records[0]) +_base0.zone_name = 'unit.tests' +_base0.relative_record_set_name = '@' +_base0.record_type = 'A' +_base0.params['ttl'] = 0 +_base0.params['arecords'] = [ARecord('1.2.3.4'), ARecord('10.10.10.10')] +azure_records.append(_base0) + +_base1 = _AzureRecord('TestAzure', octo_records[1]) +_base1.zone_name = 'unit.tests' +_base1.relative_record_set_name = 'a' +_base1.record_type = 'A' +_base1.params['ttl'] = 1 +_base1.params['arecords'] = [ARecord('1.2.3.4'), ARecord('1.1.1.1')] +azure_records.append(_base1) + +_base2 = _AzureRecord('TestAzure', octo_records[2]) +_base2.zone_name = 'unit.tests' +_base2.relative_record_set_name = 'aa' +_base2.record_type = 'A' +_base2.params['ttl'] = 9001 +_base2.params['arecords'] = ARecord('1.2.4.3') +azure_records.append(_base2) + +_base3 = _AzureRecord('TestAzure', octo_records[3]) +_base3.zone_name = 'unit.tests' +_base3.relative_record_set_name = 'aaa' +_base3.record_type = 'A' +_base3.params['ttl'] = 2 +_base3.params['arecords'] = ARecord('1.1.1.3') +azure_records.append(_base3) + +_base4 = _AzureRecord('TestAzure', octo_records[4]) +_base4.zone_name = 'unit.tests' +_base4.relative_record_set_name = 'cname' +_base4.record_type = 'CNAME' +_base4.params['ttl'] = 3 +_base4.params['cname_record'] = CnameRecord('a.unit.tests.') +azure_records.append(_base4) + +_base5 = _AzureRecord('TestAzure', octo_records[5]) +_base5.zone_name = 'unit.tests' +_base5.relative_record_set_name = 'mx1' +_base5.record_type = 'MX' +_base5.params['ttl'] = 3 +_base5.params['mx_records'] = [MxRecord(10, 'mx1.unit.tests.'), + MxRecord(20, 'mx2.unit.tests.')] +azure_records.append(_base5) + +_base6 = _AzureRecord('TestAzure', octo_records[6]) +_base6.zone_name = 'unit.tests' +_base6.relative_record_set_name = 'mx2' +_base6.record_type = 'MX' +_base6.params['ttl'] = 3 +_base6.params['mx_records'] = [MxRecord(10, 'mx1.unit.tests.')] +azure_records.append(_base6) + +_base7 = _AzureRecord('TestAzure', octo_records[7]) +_base7.zone_name = 'unit.tests' +_base7.relative_record_set_name = '@' +_base7.record_type = 'NS' +_base7.params['ttl'] = 4 +_base7.params['ns_records'] = [NsRecord('ns1.unit.tests.'), + NsRecord('ns2.unit.tests.')] +azure_records.append(_base7) + +_base8 = _AzureRecord('TestAzure', octo_records[8]) +_base8.zone_name = 'unit.tests' +_base8.relative_record_set_name = 'foo' +_base8.record_type = 'NS' +_base8.params['ttl'] = 5 +_base8.params['ns_records'] = [NsRecord('ns1.unit.tests.')] +azure_records.append(_base8) + +_base9 = _AzureRecord('TestAzure', octo_records[9]) +_base9.zone_name = 'unit.tests' +_base9.relative_record_set_name = '_srv._tcp' +_base9.record_type = 'SRV' +_base9.params['ttl'] = 6 +_base9.params['srv_records'] = [SrvRecord(10, 20, 30, 'foo-1.unit.tests.'), + SrvRecord(12, 30, 30, 'foo-2.unit.tests.')] +azure_records.append(_base9) + +_base10 = _AzureRecord('TestAzure', octo_records[10]) +_base10.zone_name = 'unit.tests' +_base10.relative_record_set_name = '_srv2._tcp' +_base10.record_type = 'SRV' +_base10.params['ttl'] = 7 +_base10.params['srv_records'] = [SrvRecord(12, 17, 1, 'srvfoo.unit.tests.')] +azure_records.append(_base10) + + +class Test_AzureRecord(TestCase): + def test_azure_record(self): + assert(len(azure_records) == len(octo_records)) + for i in range(len(azure_records)): + octo = _AzureRecord('TestAzure', octo_records[i]) + assert(azure_records[i]._equals(octo)) + string = str(azure_records[i]) + assert(('Ttl: ' in string)) + + +class Test_ParseAzureType(TestCase): + def test_parse_azure_type(self): + for expected, test in [['A', 'Microsoft.Network/dnszones/A'], + ['AAAA', 'Microsoft.Network/dnszones/AAAA'], + ['NS', 'Microsoft.Network/dnszones/NS'], + ['MX', 'Microsoft.Network/dnszones/MX']]: + self.assertEquals(expected, _parse_azure_type(test)) + + +class Test_CheckEndswithDot(TestCase): + def test_check_endswith_dot(self): + for expected, test in [['a.', 'a'], + ['a.', 'a.'], + ['foo.bar.', 'foo.bar.'], + ['foo.bar.', 'foo.bar']]: + self.assertEquals(expected, _check_endswith_dot(test)) + + +class TestAzureDnsProvider(TestCase): + def _provider(self): + return self._get_provider('mock_spc', 'mock_dns_client') + + @patch('octodns.provider.azuredns.DnsManagementClient') + @patch('octodns.provider.azuredns.ServicePrincipalCredentials') + def _get_provider(self, mock_spc, mock_dns_client): + '''Returns a mock AzureProvider object to use in testing. + + :param mock_spc: placeholder + :type mock_spc: str + :param mock_dns_client: placeholder + :type mock_dns_client: str + + :type return: AzureProvider + ''' + return AzureProvider('mock_id', 'mock_client', 'mock_key', + 'mock_directory', 'mock_sub', 'mock_rg') + + def test_populate_records(self): + provider = self._get_provider() + + rs = [] + rs.append(RecordSet(name='a1', ttl=0, type='A', + arecords=[ARecord('1.1.1.1')])) + rs.append(RecordSet(name='a2', ttl=1, type='A', + arecords=[ARecord('1.1.1.1'), + ARecord('2.2.2.2')])) + rs.append(RecordSet(name='aaaa1', ttl=2, type='AAAA', + aaaa_records=[AaaaRecord('1:1ec:1::1')])) + rs.append(RecordSet(name='aaaa2', ttl=3, type='AAAA', + aaaa_records=[AaaaRecord('1:1ec:1::1'), + AaaaRecord('1:1ec:1::2')])) + rs.append(RecordSet(name='cname1', ttl=4, type='CNAME', + cname_record=CnameRecord('cname.unit.test.'))) + rs.append(RecordSet(name='cname2', ttl=5, type='CNAME', + cname_record=None)) + rs.append(RecordSet(name='mx1', ttl=6, type='MX', + mx_records=[MxRecord(10, 'mx1.unit.test.')])) + rs.append(RecordSet(name='mx2', ttl=7, type='MX', + mx_records=[MxRecord(10, 'mx1.unit.test.'), + MxRecord(11, 'mx2.unit.test.')])) + rs.append(RecordSet(name='ns1', ttl=8, type='NS', + ns_records=[NsRecord('ns1.unit.test.')])) + rs.append(RecordSet(name='ns2', ttl=9, type='NS', + ns_records=[NsRecord('ns1.unit.test.'), + NsRecord('ns2.unit.test.')])) + rs.append(RecordSet(name='ptr1', ttl=10, type='PTR', + ptr_records=[PtrRecord('ptr1.unit.test.')])) + rs.append(RecordSet(name='ptr2', ttl=11, type='PTR', + ptr_records=[PtrRecord(None)])) + rs.append(RecordSet(name='_srv1._tcp', ttl=12, type='SRV', + srv_records=[SrvRecord(1, 2, 3, '1unit.tests.')])) + rs.append(RecordSet(name='_srv2._tcp', ttl=13, type='SRV', + srv_records=[SrvRecord(1, 2, 3, '1unit.tests.'), + SrvRecord(4, 5, 6, '2unit.tests.')])) + rs.append(RecordSet(name='txt1', ttl=14, type='TXT', + txt_records=[TxtRecord('sample text1')])) + rs.append(RecordSet(name='txt2', ttl=15, type='TXT', + txt_records=[TxtRecord('sample text1'), + TxtRecord('sample text2')])) + rs.append(RecordSet(name='', ttl=16, type='SOA', + soa_record=[SoaRecord()])) + + record_list = provider._dns_client.record_sets.list_by_dns_zone + record_list.return_value = rs + + provider.populate(zone) + + self.assertEquals(len(zone.records), 16) + + def test_populate_zone(self): + provider = self._get_provider() + + zone_list = provider._dns_client.zones.list_by_resource_group + zone_list.return_value = [AzureZone(location='global'), + AzureZone(location='global')] + + provider._populate_zones() + + self.assertEquals(len(provider._azure_zones), 1) + + def test_bad_zone_response(self): + provider = self._get_provider() + + _get = provider._dns_client.zones.get + _get.side_effect = CloudError(Mock(status=404), 'Azure Error') + trip = False + try: + provider._check_zone('unit.test', create=False) + except CloudError: + trip = True + self.assertEquals(trip, True) + + def test_apply(self): + provider = self._get_provider() + + changes = [] + deletes = [] + for i in octo_records: + changes.append(Create(i)) + deletes.append(Delete(i)) + + self.assertEquals(11, provider.apply(Plan(None, zone, changes))) + self.assertEquals(11, provider.apply(Plan(zone, zone, deletes))) + + def test_create_zone(self): + provider = self._get_provider() + + changes = [] + for i in octo_records: + changes.append(Create(i)) + desired = Zone('unit2.test.', []) + + err_msg = 'The Resource \'Microsoft.Network/dnszones/unit2.test\' ' + err_msg += 'under resource group \'mock_rg\' was not found.' + _get = provider._dns_client.zones.get + _get.side_effect = CloudError(Mock(status=404), err_msg) + + self.assertEquals(11, provider.apply(Plan(None, desired, changes))) + + def test_check_zone_no_create(self): + provider = self._get_provider() + + rs = [] + rs.append(RecordSet(name='a1', ttl=0, type='A', + arecords=[ARecord('1.1.1.1')])) + rs.append(RecordSet(name='a2', ttl=1, type='A', + arecords=[ARecord('1.1.1.1'), + ARecord('2.2.2.2')])) + + record_list = provider._dns_client.record_sets.list_by_dns_zone + record_list.return_value = rs + + err_msg = 'The Resource \'Microsoft.Network/dnszones/unit3.test\' ' + err_msg += 'under resource group \'mock_rg\' was not found.' + _get = provider._dns_client.zones.get + _get.side_effect = CloudError(Mock(status=404), err_msg) + + provider.populate(Zone('unit3.test.', [])) + + self.assertEquals(len(zone.records), 0) diff --git a/tests/test_octodns_provider_base.py b/tests/test_octodns_provider_base.py index bd134bc..e44adc0 100644 --- a/tests/test_octodns_provider_base.py +++ b/tests/test_octodns_provider_base.py @@ -214,9 +214,11 @@ class TestBaseProvider(TestCase): for i in range(int(Plan.MIN_EXISTING_RECORDS * Plan.MAX_SAFE_UPDATE_PCENT) + 1)] - with self.assertRaises(UnsafePlan): + with self.assertRaises(UnsafePlan) as ctx: Plan(zone, zone, changes).raise_if_unsafe() + self.assertTrue('Too many updates' in ctx.exception.message) + def test_safe_updates_min_existing_pcent(self): # MAX_SAFE_UPDATE_PCENT is safe when more # than MIN_EXISTING_RECORDS exist @@ -260,9 +262,11 @@ class TestBaseProvider(TestCase): for i in range(int(Plan.MIN_EXISTING_RECORDS * Plan.MAX_SAFE_DELETE_PCENT) + 1)] - with self.assertRaises(UnsafePlan): + with self.assertRaises(UnsafePlan) as ctx: Plan(zone, zone, changes).raise_if_unsafe() + self.assertTrue('Too many deletes' in ctx.exception.message) + def test_safe_deletes_min_existing_pcent(self): # MAX_SAFE_DELETE_PCENT is safe when more # than MIN_EXISTING_RECORDS exist diff --git a/tests/test_octodns_provider_ns1.py b/tests/test_octodns_provider_ns1.py index ce1353b..4df56b3 100644 --- a/tests/test_octodns_provider_ns1.py +++ b/tests/test_octodns_provider_ns1.py @@ -278,3 +278,37 @@ class TestNs1Provider(TestCase): call.update(answers=[u'1.2.3.4'], ttl=32), call.delete() ]) + + def test_escaping(self): + provider = Ns1Provider('test', 'api-key') + + record = { + 'ttl': 31, + 'short_answers': ['foo; bar baz; blip'] + } + self.assertEquals(['foo\; bar baz\; blip'], + provider._data_for_SPF('SPF', record)['values']) + + record = { + 'ttl': 31, + 'short_answers': ['no', 'foo; bar baz; blip', 'yes'] + } + self.assertEquals(['no', 'foo\; bar baz\; blip', 'yes'], + provider._data_for_TXT('TXT', record)['values']) + + zone = Zone('unit.tests.', []) + record = Record.new(zone, 'spf', { + 'ttl': 34, + 'type': 'SPF', + 'value': 'foo\; bar baz\; blip' + }) + self.assertEquals(['foo; bar baz; blip'], + provider._params_for_SPF(record)['answers']) + + record = Record.new(zone, 'txt', { + 'ttl': 35, + 'type': 'TXT', + 'value': 'foo\; bar baz\; blip' + }) + self.assertEquals(['foo; bar baz; blip'], + provider._params_for_TXT(record)['answers'])