diff --git a/README.md b/README.md index 0a04e27..8d59cd0 100644 --- a/README.md +++ b/README.md @@ -192,7 +192,7 @@ The above command pulled the existing data out of Route53 and placed the results | Provider | Requirements | Record Support | Dynamic | Notes | |--|--|--|--|--| -| [AzureProvider](/octodns/provider/azuredns.py) | azure-identity, azure-mgmt-dns | A, AAAA, CAA, CNAME, MX, NS, PTR, SRV, TXT | No | | +| [AzureProvider](/octodns/provider/azuredns.py) | azure-identity, azure-mgmt-dns, azure-mgmt-trafficmanager | A, AAAA, CAA, CNAME, MX, NS, PTR, SRV, TXT | Yes (CNAMEs only) | | | [Akamai](/octodns/provider/edgedns.py) | edgegrid-python | A, AAAA, CNAME, MX, NAPTR, NS, PTR, SPF, SRV, SSHFP, TXT | No | | | [CloudflareProvider](/octodns/provider/cloudflare.py) | | A, AAAA, ALIAS, CAA, CNAME, LOC, MX, NS, PTR, SPF, SRV, TXT | No | CAA tags restricted | | [ConstellixProvider](/octodns/provider/constellix.py) | | A, AAAA, ALIAS (ANAME), CAA, CNAME, MX, NS, PTR, SPF, SRV, TXT | No | CAA tags restricted | diff --git a/octodns/provider/azuredns.py b/octodns/provider/azuredns.py index 83cfeb0..4d3dcc5 100644 --- a/octodns/provider/azuredns.py +++ b/octodns/provider/azuredns.py @@ -5,18 +5,28 @@ from __future__ import absolute_import, division, print_function, \ unicode_literals +from collections import defaultdict + from azure.identity import ClientSecretCredential +from azure.common.credentials import ServicePrincipalCredentials from azure.mgmt.dns import DnsManagementClient +from azure.mgmt.trafficmanager import TrafficManagerManagementClient from azure.mgmt.dns.models import ARecord, AaaaRecord, CaaRecord, \ CnameRecord, MxRecord, SrvRecord, NsRecord, PtrRecord, TxtRecord, Zone +from azure.mgmt.trafficmanager.models import Profile, DnsConfig, \ + MonitorConfig, Endpoint, MonitorConfigCustomHeadersItem import logging from functools import reduce -from ..record import Record +from ..record import Record, Update, GeoCodes from .base import BaseProvider +class AzureException(Exception): + pass + + def escape_semicolon(s): assert s return s.replace(';', '\\;') @@ -67,7 +77,8 @@ class _AzureRecord(object): 'TXT': TxtRecord } - def __init__(self, resource_group, record, delete=False): + def __init__(self, resource_group, record, delete=False, + traffic_manager=None): '''Constructor for _AzureRecord. Notes on Azure records: An Azure record set has the form @@ -94,9 +105,11 @@ class _AzureRecord(object): self.log = logging.getLogger('AzureRecord') self.resource_group = resource_group - self.zone_name = record.zone.name[:len(record.zone.name) - 1] + self.zone_name = record.zone.name[:-1] self.relative_record_set_name = record.name or '@' self.record_type = record._type + self._record = record + self.traffic_manager = traffic_manager if delete: return @@ -104,11 +117,11 @@ class _AzureRecord(object): # Refer to function docstring for key_name and class_name. key_name = '{}_records'.format(self.record_type).lower() if record._type == 'CNAME': - key_name = key_name[:len(key_name) - 1] + key_name = key_name[:-1] azure_class = self.TYPE_MAP[self.record_type] - self.params = getattr(self, '_params_for_{}'.format(record._type)) - self.params = self.params(record.data, key_name, azure_class) + params_for = getattr(self, '_params_for_{}'.format(record._type)) + self.params = params_for(record.data, key_name, azure_class) self.params['ttl'] = record.ttl def _params_for_A(self, data, key_name, azure_class): @@ -139,6 +152,9 @@ class _AzureRecord(object): return {key_name: params} def _params_for_CNAME(self, data, key_name, azure_class): + if self._record.dynamic and self.traffic_manager: + return {'target_resource': self.traffic_manager} + return {key_name: azure_class(cname=data['value'])} def _params_for_MX(self, data, key_name, azure_class): @@ -227,25 +243,6 @@ class _AzureRecord(object): (parse_dict(self.params) == parse_dict(b.params)) & \ (self.relative_record_set_name == b.relative_record_set_name) - def __str__(self): - '''String representation of an _AzureRecord. - :type return: str - ''' - string = 'Zone: {}; '.format(self.zone_name) - string += 'Name: {}; '.format(self.relative_record_set_name) - string += 'Type: {}; '.format(self.record_type) - if not hasattr(self, 'params'): - return string - string += 'Ttl: {}; '.format(self.params['ttl']) - for char in self.params: - if char != 'ttl': - try: - for rec in self.params[char]: - string += 'Record: {}; '.format(rec.__dict__) - except: - string += 'Record: {}; '.format(self.params[char].__dict__) - return string - def _check_endswith_dot(string): return string if string.endswith('.') else string + '.' @@ -259,14 +256,88 @@ def _parse_azure_type(string): :type return: str ''' - return string.split('/')[len(string.split('/')) - 1] - + return string.split('/')[-1] + + +def _traffic_manager_suffix(record): + return record.fqdn[:-1].replace('.', '-') + + +def _get_monitor(record): + monitor = MonitorConfig( + protocol=record.healthcheck_protocol, + port=record.healthcheck_port, + path=record.healthcheck_path, + ) + host = record.healthcheck_host + if host: + monitor.custom_headers = [MonitorConfigCustomHeadersItem( + name='Host', value=host + )] + return monitor + + +def _profile_is_match(have, desired): + if have is None or desired is None: + return False + + # compare basic attributes + if have.name != desired.name or \ + have.traffic_routing_method != desired.traffic_routing_method or \ + have.dns_config.ttl != desired.dns_config.ttl or \ + len(have.endpoints) != len(desired.endpoints): + return False + + # compare monitoring configuration + monitor_have = have.monitor_config + monitor_desired = desired.monitor_config + if monitor_have.protocol != monitor_desired.protocol or \ + monitor_have.port != monitor_desired.port or \ + monitor_have.path != monitor_desired.path or \ + monitor_have.custom_headers != monitor_desired.custom_headers: + return False + + # compare endpoints + method = have.traffic_routing_method + if method == 'Priority': + have_endpoints = sorted(have.endpoints, key=lambda e: e.priority) + desired_endpoints = sorted(desired.endpoints, + key=lambda e: e.priority) + elif method == 'Weighted': + have_endpoints = sorted(have.endpoints, key=lambda e: e.target) + desired_endpoints = sorted(desired.endpoints, key=lambda e: e.target) + else: + have_endpoints = have.endpoints + desired_endpoints = desired.endpoints + endpoints = zip(have_endpoints, desired_endpoints) + for have_endpoint, desired_endpoint in endpoints: + if have_endpoint.name != desired_endpoint.name or \ + have_endpoint.type != desired_endpoint.type: + return False + target_type = have_endpoint.type.split('/')[-1] + if target_type == 'externalEndpoints': + # compare value, weight, priority + if have_endpoint.target != desired_endpoint.target: + return False + if method == 'Weighted' and \ + have_endpoint.weight != desired_endpoint.weight: + return False + elif target_type == 'nestedEndpoints': + # compare targets + if have_endpoint.target_resource_id != \ + desired_endpoint.target_resource_id: + return False + # compare geos + if method == 'Geographic': + have_geos = sorted(have_endpoint.geo_mapping) + desired_geos = sorted(desired_endpoint.geo_mapping) + if have_geos != desired_geos: + return False + else: + # unexpected, give up + return False -def _check_for_alias(azrecord): - if (azrecord.target_resource.id and not azrecord.a_records and not - azrecord.cname_record): - return True - return False + return True class AzureProvider(BaseProvider): @@ -318,7 +389,7 @@ class AzureProvider(BaseProvider): possible to also hard-code into the config file: eg, resource_group. ''' SUPPORTS_GEO = False - SUPPORTS_DYNAMIC = False + SUPPORTS_DYNAMIC = True SUPPORTS = set(('A', 'AAAA', 'CAA', 'CNAME', 'MX', 'NS', 'PTR', 'SRV', 'TXT')) @@ -336,24 +407,45 @@ class AzureProvider(BaseProvider): self._dns_client_directory_id = directory_id self._dns_client_subscription_id = sub_id self.__dns_client = None + self.__tm_client = None self._resource_group = resource_group self._azure_zones = set() + self._traffic_managers = dict() @property def _dns_client(self): if self.__dns_client is None: - credential = ClientSecretCredential( - client_id=self._dns_client_client_id, - client_secret=self._dns_client_key, - tenant_id=self._dns_client_directory_id - ) + # Azure's logger spits out a lot of debug messages at 'INFO' + # level, override it by re-assigning `info` method to `debug` + # (ugly hack until I find a better way) + logger_name = 'azure.core.pipeline.policies.http_logging_policy' + logger = logging.getLogger(logger_name) + logger.info = logger.debug self.__dns_client = DnsManagementClient( - credential=credential, - subscription_id=self._dns_client_subscription_id + credential=ClientSecretCredential( + client_id=self._dns_client_client_id, + client_secret=self._dns_client_key, + tenant_id=self._dns_client_directory_id, + logger=logger, + ), + subscription_id=self._dns_client_subscription_id, ) return self.__dns_client + @property + def _tm_client(self): + if self.__tm_client is None: + self.__tm_client = TrafficManagerManagementClient( + ServicePrincipalCredentials( + self._dns_client_client_id, + secret=self._dns_client_key, + tenant=self._dns_client_directory_id, + ), + self._dns_client_subscription_id, + ) + return self.__tm_client + def _populate_zones(self): self.log.debug('azure_zones: loading') list_zones = self._dns_client.zones.list_by_resource_group @@ -388,6 +480,42 @@ class AzureProvider(BaseProvider): # Else return nothing (aka false) return + def _populate_traffic_managers(self): + self.log.debug('traffic managers: loading') + list_profiles = self._tm_client.profiles.list_by_resource_group + for profile in list_profiles(self._resource_group): + self._traffic_managers[profile.id] = profile + # link nested profiles in advance for convenience + for _, profile in self._traffic_managers.items(): + self._populate_nested_profiles(profile) + + def _populate_nested_profiles(self, profile): + for ep in profile.endpoints: + target_id = ep.target_resource_id + if target_id and target_id in self._traffic_managers: + target = self._traffic_managers[target_id] + ep.target_resource = self._populate_nested_profiles(target) + return profile + + def _get_tm_profile_by_id(self, resource_id): + if not self._traffic_managers: + self._populate_traffic_managers() + return self._traffic_managers.get(resource_id) + + def _profile_name_to_id(self, name): + return '/subscriptions/' + self._dns_client_subscription_id + \ + '/resourceGroups/' + self._resource_group + \ + '/providers/Microsoft.Network/trafficManagerProfiles/' + \ + name + + def _get_tm_profile_by_name(self, name): + profile_id = self._profile_name_to_id(name) + return self._get_tm_profile_by_id(profile_id) + + def _get_tm_for_dynamic_record(self, record): + name = _traffic_manager_suffix(record) + return self._get_tm_profile_by_name(name) + def populate(self, zone, target=False, lenient=False): '''Required function of manager.py to collect records from zone. @@ -417,40 +545,35 @@ class AzureProvider(BaseProvider): exists = False before = len(zone.records) - zone_name = zone.name[:len(zone.name) - 1] + zone_name = zone.name[:-1] self._populate_zones() - self._check_zone(zone_name) - _records = [] records = self._dns_client.record_sets.list_by_dns_zone if self._check_zone(zone_name): exists = True for azrecord in records(self._resource_group, zone_name): - if _parse_azure_type(azrecord.type) in self.SUPPORTS: - _records.append(azrecord) - for azrecord in _records: - record_name = azrecord.name if azrecord.name != '@' else '' typ = _parse_azure_type(azrecord.type) + if typ not in self.SUPPORTS: + continue - if typ in ['A', 'CNAME']: - if _check_for_alias(azrecord): - self.log.debug( - 'Skipping - ALIAS. zone=%s record=%s, type=%s', - zone_name, record_name, typ) # pragma: no cover - continue # pragma: no cover - - data = getattr(self, '_data_for_{}'.format(typ)) - data = data(azrecord) - data['type'] = typ - data['ttl'] = azrecord.ttl - record = Record.new(zone, record_name, data, source=self) - + record = self._populate_record(zone, azrecord, lenient) zone.add_record(record, lenient=lenient) self.log.info('populate: found %s records, exists=%s', len(zone.records) - before, exists) return exists + def _populate_record(self, zone, azrecord, lenient=False): + record_name = azrecord.name if azrecord.name != '@' else '' + typ = _parse_azure_type(azrecord.type) + + data_for = getattr(self, '_data_for_{}'.format(typ)) + data = data_for(azrecord) + data['type'] = typ + data['ttl'] = azrecord.ttl + return Record.new(zone, record_name, data, source=self, + lenient=lenient) + def _data_for_A(self, azrecord): return {'values': [ar.ipv4_address for ar in azrecord.a_records]} @@ -470,6 +593,9 @@ class AzureProvider(BaseProvider): :type return: dict ''' + if azrecord.cname_record is None and azrecord.target_resource.id: + return self._data_for_dynamic(azrecord) + return {'value': _check_endswith_dot(azrecord.cname_record.cname)} def _data_for_MX(self, azrecord): @@ -495,6 +621,322 @@ class AzureProvider(BaseProvider): ar.value)) for ar in azrecord.txt_records]} + def _data_for_dynamic(self, azrecord): + default = set() + pools = defaultdict(lambda: {'fallback': None, 'values': []}) + rules = [] + + # top level geo profile + geo_profile = self._get_tm_profile_by_id(azrecord.target_resource.id) + for geo_ep in geo_profile.endpoints: + rule = {} + + # resolve list of regions + geo_map = list(geo_ep.geo_mapping) + if geo_map != ['WORLD']: + if 'GEO-ME' in geo_map: + # Azure treats Middle East as a separate group, but + # its part of Asia in octoDNS, so we need to remove GEO-ME + # if GEO-AS is also in the list + # Throw exception otherwise, it should not happen if the + # profile was generated by octoDNS + if 'GEO-AS' not in geo_map: + msg = '_data_for_dynamic: Profile={}: '.format( + geo_profile.name) + msg += 'Middle East (GEO-ME) is not supported by ' + \ + 'octoDNS. It needs to be either paired ' + \ + 'with Asia (GEO-AS) or expanded into ' + \ + 'individual list of countries.' + raise AzureException(msg) + geo_map.remove('GEO-ME') + geos = rule.setdefault('geos', []) + for code in geo_map: + if code.startswith('GEO-'): + geos.append(code[len('GEO-'):]) + elif '-' in code: + country, province = code.split('-', 1) + country = GeoCodes.country_to_code(country) + geos.append('{}-{}'.format(country, province)) + else: + geos.append(GeoCodes.country_to_code(code)) + + # second level priority profile + pool = None + rule_endpoints = geo_ep.target_resource.endpoints + rule_endpoints = sorted(rule_endpoints, key=lambda e: e.priority) + for rule_ep in rule_endpoints: + pool_name = rule_ep.name + + # third (and last) level weighted RR profile + # these should be leaf node profiles with no further nesting + pool_profile = rule_ep.target_resource + + # last/default pool + if pool_name == '--default--': + for pool_ep in pool_profile.endpoints: + default.add(pool_ep.target) + # this should be the last one, so let's break here + break + + # set first priority endpoint as the rule's primary pool + if 'pool' not in rule: + rule['pool'] = pool_name + + if pool: + # set current pool as fallback of the previous pool + pool['fallback'] = pool_name + + pool = pools[pool_name] + for pool_ep in pool_profile.endpoints: + val = pool_ep.target + value_dict = { + 'value': _check_endswith_dot(val), + 'weight': pool_ep.weight, + } + if value_dict not in pool['values']: + pool['values'].append(value_dict) + + if 'pool' not in rule or not default: + # this will happen if the priority profile does not have + # enough endpoints + msg = 'Expected at least 2 endpoints in {}, got {}'.format( + geo_ep.target_resource.name, len(rule_endpoints) + ) + raise AzureException(msg) + rules.append(rule) + + # Order and convert to a list + default = sorted(default) + + data = { + 'dynamic': { + 'pools': pools, + 'rules': rules, + }, + 'value': _check_endswith_dot(default[0]), + } + + return data + + def _extra_changes(self, existing, desired, changes): + changed = set() + + # Abort if there are non-CNAME dynamic records + for change in changes: + record = change.record + changed.add(record) + typ = record._type + dynamic = getattr(record, 'dynamic', False) + if dynamic and typ != 'CNAME': + msg = '{}: Dynamic records in Azure must be of type CNAME' + msg = msg.format(record.fqdn) + raise AzureException(msg) + + log = self.log.info + extra = [] + for record in desired.records: + if not getattr(record, 'dynamic', False): + # Already changed, or not dynamic, no need to check it + continue + + # let's walk through and show what will be changed even if + # the record is already be in list of changes + added = (record in changed) + + active = set() + profiles = self._generate_traffic_managers(record) + for profile in profiles: + name = profile.name + active.add(name) + existing_profile = self._get_tm_profile_by_name(name) + if not _profile_is_match(existing_profile, profile): + log('_extra_changes: Profile name=%s will be synced', + name) + if not added: + extra.append(Update(record, record)) + added = True + + existing_profiles = self._find_traffic_managers(record) + for name in existing_profiles - active: + log('_extra_changes: Profile name=%s will be destroyed', name) + if not added: + extra.append(Update(record, record)) + added = True + + return extra + + def _generate_tm_profile(self, name, routing, endpoints, record): + # set appropriate endpoint types + endpoint_type_prefix = 'Microsoft.Network/trafficManagerProfiles/' + for ep in endpoints: + if ep.target_resource_id: + ep.type = endpoint_type_prefix + 'nestedEndpoints' + elif ep.target: + ep.type = endpoint_type_prefix + 'externalEndpoints' + else: + msg = ('_generate_tm_profile: Invalid endpoint {} ' + + 'in profile {}, needs to have either target or ' + + 'target_resource_id').format(ep.name, name) + raise AzureException(msg) + + # build and return + return Profile( + id=self._profile_name_to_id(name), + name=name, + traffic_routing_method=routing, + dns_config=DnsConfig( + relative_name=name, + ttl=record.ttl, + ), + monitor_config=_get_monitor(record), + endpoints=endpoints, + location='global', + ) + + def _generate_traffic_managers(self, record): + traffic_managers = [] + pools = record.dynamic.pools + + tm_suffix = _traffic_manager_suffix(record) + profile = self._generate_tm_profile + + # construct the default pool that will be used at the end of + # all rules + target = record.value[:-1] + default_endpoints = [Endpoint( + name=target, + target=target, + weight=1, + )] + default_profile_name = 'default--{}'.format(tm_suffix) + default_profile = profile(default_profile_name, 'Weighted', + default_endpoints, record) + traffic_managers.append(default_profile) + + geo_endpoints = [] + + for rule in record.dynamic.rules: + pool_name = rule.data['pool'] + rule_endpoints = [] + priority = 1 + + while pool_name: + # iterate until we reach end of fallback chain + pool = pools[pool_name].data + profile_name = 'pool-{}--{}'.format(pool_name, tm_suffix) + endpoints = [] + for val in pool['values']: + target = val['value'] + # strip trailing dot from CNAME value + target = target[:-1] + endpoints.append(Endpoint( + name=target, + target=target, + weight=val.get('weight', 1), + )) + pool_profile = profile(profile_name, 'Weighted', endpoints, + record) + traffic_managers.append(pool_profile) + + # append pool to endpoint list of fallback rule profile + rule_endpoints.append(Endpoint( + name=pool_name, + target_resource_id=pool_profile.id, + priority=priority, + )) + + priority += 1 + pool_name = pool.get('fallback') + + # append default profile to the end + rule_endpoints.append(Endpoint( + name='--default--', + target_resource_id=default_profile.id, + priority=priority, + )) + # create rule profile with fallback chain + rule_profile_name = 'rule-{}--{}'.format(rule.data['pool'], + tm_suffix) + rule_profile = profile(rule_profile_name, 'Priority', + rule_endpoints, record) + traffic_managers.append(rule_profile) + + # append rule profile to top-level geo profile + rule_geos = rule.data.get('geos', []) + geos = [] + if len(rule_geos) > 0: + for geo in rule_geos: + if '-' in geo: + geos.append(geo.split('-', 1)[-1]) + else: + geos.append('GEO-{}'.format(geo)) + if geo == 'AS': + # Middle East is part of Asia in octoDNS, but + # Azure treats it as a separate "group", so let's + # add it in the list of geo mappings. We will drop + # it when we later parse the list of regions. + geos.append('GEO-ME') + else: + geos.append('WORLD') + geo_endpoints.append(Endpoint( + name='rule-{}'.format(rule.data['pool']), + target_resource_id=rule_profile.id, + geo_mapping=geos, + )) + + geo_profile = profile(tm_suffix, 'Geographic', geo_endpoints, record) + traffic_managers.append(geo_profile) + + return traffic_managers + + def _sync_traffic_managers(self, record): + desired_profiles = self._generate_traffic_managers(record) + seen = set() + + tm_sync = self._tm_client.profiles.create_or_update + populate = self._populate_nested_profiles + + for desired in desired_profiles: + name = desired.name + if name in seen: + continue + + existing = self._get_tm_profile_by_name(name) + if not _profile_is_match(existing, desired): + self.log.info( + '_sync_traffic_managers: Syncing profile=%s', name) + profile = tm_sync(self._resource_group, name, desired) + self._traffic_managers[profile.id] = populate(profile) + else: + self.log.debug( + '_sync_traffic_managers: Skipping profile=%s: up to date', + name) + seen.add(name) + + return seen + + def _find_traffic_managers(self, record): + tm_suffix = _traffic_manager_suffix(record) + + profiles = set() + for profile_id in self._traffic_managers: + # match existing profiles with record's suffix + name = profile_id.split('/')[-1] + if name == tm_suffix or \ + name.endswith('--{}'.format(tm_suffix)): + profiles.add(name) + + return profiles + + def _traffic_managers_gc(self, record, active_profiles): + existing_profiles = self._find_traffic_managers(record) + + # delete unused profiles + for profile_name in existing_profiles - active_profiles: + self.log.info('_traffic_managers_gc: Deleting profile=%s', + profile_name) + self._tm_client.profiles.delete(self._resource_group, profile_name) + def _apply_Create(self, change): '''A record from change must be created. @@ -503,7 +945,15 @@ class AzureProvider(BaseProvider): :type return: void ''' - ar = _AzureRecord(self._resource_group, change.new) + record = change.new + + dynamic = getattr(record, 'dynamic', False) + if dynamic: + self._sync_traffic_managers(record) + + profile = self._get_tm_for_dynamic_record(record) + ar = _AzureRecord(self._resource_group, record, + traffic_manager=profile) create = self._dns_client.record_sets.create_or_update create(resource_group_name=ar.resource_group, @@ -512,17 +962,71 @@ class AzureProvider(BaseProvider): record_type=ar.record_type, parameters=ar.params) - self.log.debug('* Success Create/Update: {}'.format(ar)) + self.log.debug('* Success Create: {}'.format(record)) - _apply_Update = _apply_Create + def _apply_Update(self, change): + '''A record from change must be created. + + :param change: a change object + :type change: octodns.record.Change + + :type return: void + ''' + existing = change.existing + new = change.new + existing_is_dynamic = getattr(existing, 'dynamic', False) + new_is_dynamic = getattr(new, 'dynamic', False) + + update_record = True + + if new_is_dynamic: + active = self._sync_traffic_managers(new) + # only TTL is configured in record, everything else goes inside + # traffic managers, so no need to update if TTL is unchanged + # and existing record is already aliased to its traffic manager + if existing.ttl == new.ttl and existing_is_dynamic: + update_record = False + + if update_record: + profile = self._get_tm_for_dynamic_record(new) + ar = _AzureRecord(self._resource_group, new, + traffic_manager=profile) + update = self._dns_client.record_sets.create_or_update + + update(resource_group_name=ar.resource_group, + zone_name=ar.zone_name, + relative_record_set_name=ar.relative_record_set_name, + record_type=ar.record_type, + parameters=ar.params) + + if new_is_dynamic: + # let's cleanup unused traffic managers + self._traffic_managers_gc(new, active) + elif existing_is_dynamic: + # cleanup traffic managers when a dynamic record gets + # changed to a simple record + self._traffic_managers_gc(existing, set()) + + self.log.debug('* Success Update: {}'.format(new)) def _apply_Delete(self, change): - ar = _AzureRecord(self._resource_group, change.existing, delete=True) + '''A record from change must be deleted. + + :param change: a change object + :type change: octodns.record.Change + + :type return: void + ''' + record = change.record + ar = _AzureRecord(self._resource_group, record, delete=True) delete = self._dns_client.record_sets.delete delete(self._resource_group, ar.zone_name, ar.relative_record_set_name, ar.record_type) + if getattr(record, 'dynamic', False): + self._traffic_managers_gc(record, set()) + self.log.debug('* Success Delete: {}'.format(ar)) def _apply(self, plan): diff --git a/requirements.txt b/requirements.txt index 54bd2a3..143ba67 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ PyYaml==5.4 azure-common==1.1.27 azure-identity==1.5.0 azure-mgmt-dns==8.0.0 +azure-mgmt-trafficmanager==0.51.0 boto3==1.15.9 botocore==1.18.9 dnspython==1.16.0 diff --git a/tests/test_octodns_provider_azuredns.py b/tests/test_octodns_provider_azuredns.py index d9b2b5a..5655719 100644 --- a/tests/test_octodns_provider_azuredns.py +++ b/tests/test_octodns_provider_azuredns.py @@ -5,19 +5,23 @@ from __future__ import absolute_import, division, print_function, \ unicode_literals -from octodns.record import Create, Delete, Record +from octodns.record import Create, Update, Delete, Record from octodns.provider.azuredns import _AzureRecord, AzureProvider, \ - _check_endswith_dot, _parse_azure_type, _check_for_alias + _check_endswith_dot, _parse_azure_type, _traffic_manager_suffix, \ + _get_monitor, _profile_is_match, AzureException from octodns.zone import Zone from octodns.provider.base import Plan from azure.mgmt.dns.models import ARecord, AaaaRecord, CaaRecord, \ CnameRecord, MxRecord, SrvRecord, NsRecord, PtrRecord, TxtRecord, \ RecordSet, SoaRecord, SubResource, Zone as AzureZone +from azure.mgmt.trafficmanager.models import Profile, DnsConfig, \ + MonitorConfig, Endpoint, MonitorConfigCustomHeadersItem from msrestazure.azure_exceptions import CloudError +from six import text_type from unittest import TestCase -from mock import Mock, patch +from mock import Mock, patch, call zone = Zone(name='unit.tests.', sub_zones=[]) @@ -343,6 +347,43 @@ class Test_AzureRecord(TestCase): assert(azure_records[i]._equals(octo)) +class Test_DynamicAzureRecord(TestCase): + def test_azure_record(self): + tm_profile = Profile() + data = { + 'ttl': 60, + 'type': 'CNAME', + 'value': 'default.unit.tests.', + 'dynamic': { + 'pools': { + 'one': { + 'values': [ + {'value': 'one.unit.tests.', 'weight': 1} + ], + 'fallback': 'two', + }, + 'two': { + 'values': [ + {'value': 'two.unit.tests.', 'weight': 1} + ], + }, + }, + 'rules': [ + {'geos': ['AF'], 'pool': 'one'}, + {'pool': 'two'}, + ], + } + } + octo_record = Record.new(zone, 'foo', data) + azure_record = _AzureRecord('TestAzure', octo_record, + traffic_manager=tm_profile) + self.assertEqual(azure_record.zone_name, zone.name[:-1]) + self.assertEqual(azure_record.relative_record_set_name, 'foo') + self.assertEqual(azure_record.record_type, 'CNAME') + self.assertEqual(azure_record.params['ttl'], 60) + self.assertEqual(azure_record.params['target_resource'], tm_profile) + + class Test_ParseAzureType(TestCase): def test_parse_azure_type(self): for expected, test in [['A', 'Microsoft.Network/dnszones/A'], @@ -361,40 +402,369 @@ class Test_CheckEndswithDot(TestCase): self.assertEquals(expected, _check_endswith_dot(test)) -class Test_CheckAzureAlias(TestCase): - def test_check_for_alias(self): - alias_record = type('C', (object,), {}) - alias_record.target_resource = type('C', (object,), {}) - alias_record.target_resource.id = "/subscriptions/x/resourceGroups/y/z" - alias_record.a_records = None - alias_record.cname_record = None +class Test_TrafficManagerSuffix(TestCase): + def test_traffic_manager_suffix(self): + test = Record.new(zone, 'foo', data={ + 'ttl': 60, 'type': 'CNAME', 'value': 'default.unit.tests.', + }) + self.assertEqual(_traffic_manager_suffix(test), 'foo-unit-tests') + + +class Test_GetMonitor(TestCase): + def test_get_monitor(self): + record = Record.new(zone, 'foo', data={ + 'type': 'CNAME', 'ttl': 60, 'value': 'default.unit.tests.', + 'octodns': { + 'healthcheck': { + 'path': '/_ping', + 'port': 4443, + 'protocol': 'HTTPS', + } + }, + }) + + monitor = _get_monitor(record) + self.assertEqual(monitor.protocol, 'HTTPS') + self.assertEqual(monitor.port, 4443) + self.assertEqual(monitor.path, '/_ping') + headers = monitor.custom_headers + self.assertIsInstance(headers, list) + self.assertEquals(len(headers), 1) + headers = headers[0] + self.assertEqual(headers.name, 'Host') + self.assertEqual(headers.value, record.healthcheck_host) + + # test TCP monitor + record._octodns['healthcheck']['protocol'] = 'TCP' + monitor = _get_monitor(record) + self.assertEqual(monitor.protocol, 'TCP') + self.assertIsNone(monitor.custom_headers) + + +class Test_ProfileIsMatch(TestCase): + def test_profile_is_match(self): + is_match = _profile_is_match + + self.assertFalse(is_match(None, Profile())) + + # Profile object builder with default property values that can be + # overridden for testing below + def profile( + name = 'foo-unit-tests', + ttl = 60, + method = 'Geographic', + monitor_proto = 'HTTPS', + monitor_port = 4443, + monitor_path = '/_ping', + endpoints = 1, + endpoint_name = 'name', + endpoint_type = 'profile/nestedEndpoints', + target = 'target.unit.tests', + target_id = 'resource/id', + geos = ['GEO-AF'], + weight = 1, + priority = 1, + ): + dns = DnsConfig(ttl=ttl) + return Profile( + name=name, traffic_routing_method=method, dns_config=dns, + monitor_config=MonitorConfig( + protocol=monitor_proto, + port=monitor_port, + path=monitor_path, + ), + endpoints=[Endpoint( + name=endpoint_name, + type=endpoint_type, + target=target, + target_resource_id=target_id, + geo_mapping=geos, + weight=weight, + priority=priority, + )] + [Endpoint()] * (endpoints - 1), + ) + + self.assertTrue(is_match(profile(), profile())) + + self.assertFalse(is_match(profile(), profile(name='two'))) + self.assertFalse(is_match(profile(), profile(endpoints=2))) + self.assertFalse(is_match(profile(), profile(monitor_proto='HTTP'))) + self.assertFalse(is_match(profile(), profile(endpoint_name='a'))) + self.assertFalse(is_match(profile(), profile(endpoint_type='b'))) + self.assertFalse( + is_match(profile(endpoint_type='b'), profile(endpoint_type='b')) + ) + self.assertFalse(is_match(profile(), profile(target_id='rsrc/id2'))) + self.assertFalse(is_match(profile(), profile(geos=['IN']))) + + def wprofile(**kwargs): + kwargs['method'] = 'Weighted' + kwargs['endpoint_type'] = 'profile/externalEndpoints' + return profile(**kwargs) - self.assertEquals(_check_for_alias(alias_record), True) + self.assertFalse(is_match(wprofile(), wprofile(target='bar.unit'))) + self.assertFalse(is_match(wprofile(), wprofile(weight=3))) class TestAzureDnsProvider(TestCase): def _provider(self): return self._get_provider('mock_spc', 'mock_dns_client') + @patch('octodns.provider.azuredns.TrafficManagerManagementClient') @patch('octodns.provider.azuredns.DnsManagementClient') @patch('octodns.provider.azuredns.ClientSecretCredential') - def _get_provider(self, mock_spc, mock_dns_client): + @patch('octodns.provider.azuredns.ServicePrincipalCredentials') + def _get_provider(self, mock_spc, mock_css, mock_dns_client, + mock_tm_client): '''Returns a mock AzureProvider object to use in testing. :param mock_spc: placeholder :type mock_spc: str :param mock_dns_client: placeholder :type mock_dns_client: str + :param mock_tm_client: placeholder + :type mock_tm_client: str :type return: AzureProvider ''' provider = AzureProvider('mock_id', 'mock_client', 'mock_key', 'mock_directory', 'mock_sub', 'mock_rg' ) + # Fetch the client to force it to load the creds provider._dns_client + + # set critical functions to return properly + tm_list = provider._tm_client.profiles.list_by_resource_group + tm_list.return_value = [] + tm_sync = provider._tm_client.profiles.create_or_update + + def side_effect(rg, name, profile): + return profile + + tm_sync.side_effect = side_effect + return provider + def _get_dynamic_record(self, zone): + return Record.new(zone, 'foo', data={ + 'type': 'CNAME', + 'ttl': 60, + 'value': 'default.unit.tests.', + 'dynamic': { + 'pools': { + 'one': { + 'values': [ + {'value': 'one.unit.tests.', 'weight': 11}, + ], + 'fallback': 'two', + }, + 'two': { + 'values': [ + {'value': 'two1.unit.tests.', 'weight': 3}, + {'value': 'two2.unit.tests.', 'weight': 4}, + ], + 'fallback': 'three', + }, + 'three': { + 'values': [ + {'value': 'three.unit.tests.', 'weight': 13}, + ], + }, + }, + 'rules': [ + {'geos': ['AF', 'EU-DE', 'NA-US-CA'], 'pool': 'one'}, + {'pool': 'three'}, + ], + }, + 'octodns': { + 'healthcheck': { + 'path': '/_ping', + 'port': 4443, + 'protocol': 'HTTPS', + } + }, + }) + + def _get_tm_profiles(self, provider): + sub = provider._dns_client_subscription_id + rg = provider._resource_group + base_id = '/subscriptions/' + sub + \ + '/resourceGroups/' + rg + \ + '/providers/Microsoft.Network/trafficManagerProfiles/' + suffix = 'foo-unit-tests' + id_format = base_id + '{}--' + suffix + name_format = '{}--' + suffix + + dns = DnsConfig(ttl=60) + header = MonitorConfigCustomHeadersItem(name='Host', + value='foo.unit.tests') + monitor = MonitorConfig(protocol='HTTPS', port=4443, path='/_ping', + custom_headers=[header]) + external = 'Microsoft.Network/trafficManagerProfiles/externalEndpoints' + nested = 'Microsoft.Network/trafficManagerProfiles/nestedEndpoints' + + return [ + Profile( + id=id_format.format('default'), + name=name_format.format('default'), + traffic_routing_method='Weighted', + dns_config=dns, + monitor_config=monitor, + endpoints=[ + Endpoint( + name='default.unit.tests', + type=external, + target='default.unit.tests', + weight=1, + ), + ], + ), + Profile( + id=id_format.format('pool-one'), + name=name_format.format('pool-one'), + traffic_routing_method='Weighted', + dns_config=dns, + monitor_config=monitor, + endpoints=[ + Endpoint( + name='one.unit.tests', + type=external, + target='one.unit.tests', + weight=11, + ), + ], + ), + Profile( + id=id_format.format('pool-two'), + name=name_format.format('pool-two'), + traffic_routing_method='Weighted', + dns_config=dns, + monitor_config=monitor, + endpoints=[ + Endpoint( + name='two1.unit.tests', + type=external, + target='two1.unit.tests', + weight=3, + ), + Endpoint( + name='two2.unit.tests', + type=external, + target='two2.unit.tests', + weight=4, + ), + ], + ), + Profile( + id=id_format.format('pool-three'), + name=name_format.format('pool-three'), + traffic_routing_method='Weighted', + dns_config=dns, + monitor_config=monitor, + endpoints=[ + Endpoint( + name='three.unit.tests', + type=external, + target='three.unit.tests', + weight=13, + ), + ], + ), + Profile( + id=id_format.format('rule-one'), + name=name_format.format('rule-one'), + traffic_routing_method='Priority', + dns_config=dns, + monitor_config=monitor, + endpoints=[ + Endpoint( + name='one', + type=nested, + target_resource_id=id_format.format('pool-one'), + priority=1, + ), + Endpoint( + name='two', + type=nested, + target_resource_id=id_format.format('pool-two'), + priority=2, + ), + Endpoint( + name='three', + type=nested, + target_resource_id=id_format.format('pool-three'), + priority=3, + ), + Endpoint( + name='--default--', + type=nested, + target_resource_id=id_format.format('default'), + priority=4, + ), + ], + ), + Profile( + id=id_format.format('rule-three'), + name=name_format.format('rule-three'), + traffic_routing_method='Priority', + dns_config=dns, + monitor_config=monitor, + endpoints=[ + Endpoint( + name='three', + type=nested, + target_resource_id=id_format.format('pool-three'), + priority=1, + ), + Endpoint( + name='--default--', + type=nested, + target_resource_id=id_format.format('default'), + priority=2, + ), + ], + ), + Profile( + id=base_id + suffix, + name=suffix, + traffic_routing_method='Geographic', + dns_config=dns, + monitor_config=monitor, + endpoints=[ + Endpoint( + geo_mapping=['GEO-AF', 'DE', 'US-CA'], + name='rule-one', + type=nested, + target_resource_id=id_format.format('rule-one'), + ), + Endpoint( + geo_mapping=['WORLD'], + name='rule-three', + type=nested, + target_resource_id=id_format.format('rule-three'), + ), + ], + ), + ] + + def _get_dynamic_package(self): + '''Convenience function to setup a sample dynamic record. + ''' + provider = self._get_provider() + + # setup traffic manager profiles + tm_list = provider._tm_client.profiles.list_by_resource_group + tm_list.return_value = self._get_tm_profiles(provider) + + # setup zone with dynamic record + zone = Zone(name='unit.tests.', sub_zones=[]) + record = self._get_dynamic_record(zone) + zone.add_record(record) + + # return everything + return provider, zone, record + def test_populate_records(self): provider = self._get_provider() @@ -510,6 +880,121 @@ class TestAzureDnsProvider(TestCase): self.assertEquals(len(zone.records), 17) self.assertTrue(exists) + def test_populate_dynamic(self): + # Middle east without Asia raises exception + provider, zone, record = self._get_dynamic_package() + tm_suffix = _traffic_manager_suffix(record) + tm_id = provider._profile_name_to_id + tm_list = provider._tm_client.profiles.list_by_resource_group + rule_name = 'rule-one--{}'.format(tm_suffix) + nested = 'Microsoft.Network/trafficManagerProfiles/nestedEndpoints' + tm_list.return_value = [ + Profile( + id=tm_id(tm_suffix), + name=tm_suffix, + traffic_routing_method='Geographic', + endpoints=[ + Endpoint( + geo_mapping=['GEO-ME'], + ), + ], + ), + ] + azrecord = RecordSet( + ttl=60, + target_resource=SubResource(id=tm_id(tm_suffix)), + ) + azrecord.name = record.name or '@' + azrecord.type = 'Microsoft.Network/dnszones/{}'.format(record._type) + with self.assertRaises(AzureException) as ctx: + provider._populate_record(zone, azrecord) + self.assertTrue(text_type(ctx).startswith( + 'Middle East (GEO-ME) is not supported' + )) + + # empty priority profile raises exception + provider, zone, record = self._get_dynamic_package() + tm_list = provider._tm_client.profiles.list_by_resource_group + rule_name = 'rule-one--{}'.format(tm_suffix) + nested = 'Microsoft.Network/trafficManagerProfiles/nestedEndpoints' + tm_list.return_value = [ + Profile( + id=tm_id(rule_name), + name=rule_name, + traffic_routing_method='Priority', + endpoints=[], + ), + Profile( + id=tm_id(tm_suffix), + name=tm_suffix, + traffic_routing_method='Geographic', + endpoints=[ + Endpoint( + geo_mapping=['WORLD'], + name='rule-one', + type=nested, + target_resource_id=tm_id(rule_name), + ), + ], + ), + ] + with self.assertRaises(AzureException) as ctx: + provider._populate_record(zone, azrecord) + self.assertTrue(text_type(ctx).startswith( + 'Expected at least 2 endpoints' + )) + + # valid set of profiles produce expected dynamic record + provider, zone, record = self._get_dynamic_package() + root_profile_id = provider._profile_name_to_id( + _traffic_manager_suffix(record) + ) + azrecord = RecordSet( + ttl=60, + target_resource=SubResource(id=root_profile_id), + ) + azrecord.name = record.name or '@' + azrecord.type = 'Microsoft.Network/dnszones/{}'.format(record._type) + + record = provider._populate_record(zone, azrecord) + self.assertEqual(record.name, 'foo') + self.assertEqual(record.ttl, 60) + self.assertEqual(record.value, 'default.unit.tests.') + self.assertEqual(record.dynamic._data(), { + 'pools': { + 'one': { + 'values': [ + {'value': 'one.unit.tests.', 'weight': 11}, + ], + 'fallback': 'two', + }, + 'two': { + 'values': [ + {'value': 'two1.unit.tests.', 'weight': 3}, + {'value': 'two2.unit.tests.', 'weight': 4}, + ], + 'fallback': 'three', + }, + 'three': { + 'values': [ + {'value': 'three.unit.tests.', 'weight': 13}, + ], + 'fallback': None, + }, + }, + 'rules': [ + {'geos': ['AF', 'EU-DE', 'NA-US-CA'], 'pool': 'one'}, + {'pool': 'three'}, + ], + }) + + # valid profiles with Middle East test case + geo_profile = provider._get_tm_for_dynamic_record(record) + geo_profile.endpoints[0].geo_mapping.extend(['GEO-ME', 'GEO-AS']) + record = provider._populate_record(zone, azrecord) + self.assertIn('AS', record.dynamic.rules[0].data['geos']) + self.assertNotIn('ME', record.dynamic.rules[0].data['geos']) + def test_populate_zone(self): provider = self._get_provider() @@ -541,20 +1026,388 @@ class TestAzureDnsProvider(TestCase): None ) + def test_extra_changes(self): + provider, existing, record = self._get_dynamic_package() + + # test simple records produce no extra changes + desired = Zone(name=existing.name, sub_zones=[]) + desired.add_record(Record.new(desired, 'simple', data={ + 'type': record._type, + 'ttl': record.ttl, + 'value': record.value, + })) + extra = provider._extra_changes(desired, desired, []) + self.assertEqual(len(extra), 0) + + # test an unchanged dynamic record produces no extra changes + desired.add_record(record) + extra = provider._extra_changes(existing, desired, []) + self.assertEqual(len(extra), 0) + + # test unused TM produces the extra change for clean up + sample_profile = self._get_tm_profiles(provider)[0] + tm_id = provider._profile_name_to_id + root_profile_name = _traffic_manager_suffix(record) + extra_profile = Profile( + id=tm_id('random--{}'.format(root_profile_name)), + name='random--{}'.format(root_profile_name), + traffic_routing_method='Weighted', + dns_config=sample_profile.dns_config, + monitor_config=sample_profile.monitor_config, + endpoints=sample_profile.endpoints, + ) + tm_list = provider._tm_client.profiles.list_by_resource_group + tm_list.return_value.append(extra_profile) + provider._populate_traffic_managers() + extra = provider._extra_changes(existing, desired, []) + self.assertEqual(len(extra), 1) + extra = extra[0] + self.assertIsInstance(extra, Update) + self.assertEqual(extra.new, record) + desired._remove_record(record) + tm_list.return_value.pop() + + # test new dynamic record does not produce an extra change for it + new_dynamic = Record.new(desired, record.name + '2', data={ + 'type': record._type, + 'ttl': record.ttl, + 'value': record.value, + 'dynamic': record.dynamic._data(), + 'octodns': record._octodns, + }) + # test change in healthcheck by using a different port number + update_dynamic = Record.new(desired, record.name, data={ + 'type': record._type, + 'ttl': record.ttl, + 'value': record.value, + 'dynamic': record.dynamic._data(), + 'octodns': { + 'healthcheck': { + 'path': '/_ping', + 'port': 443, + 'protocol': 'HTTPS', + }, + }, + }) + desired.add_record(new_dynamic) + desired.add_record(update_dynamic) + changes = [Create(new_dynamic)] + extra = provider._extra_changes(existing, desired, changes) + # implicitly asserts that new_dynamic was not added to extra changes + # as it was already in the `changes` list + self.assertEqual(len(extra), 1) + extra = extra[0] + self.assertIsInstance(extra, Update) + self.assertEqual(extra.new, update_dynamic) + + # test non-CNAME dynamic record throws exception + a_dynamic = Record.new(desired, record.name + '3', data={ + 'type': 'A', + 'ttl': record.ttl, + 'values': ['1.1.1.1'], + 'dynamic': { + 'pools': { + 'one': {'values': [{'value': '2.2.2.2'}]}, + }, + 'rules': [ + {'pool': 'one'}, + ], + }, + }) + desired.add_record(a_dynamic) + changes.append(Create(a_dynamic)) + with self.assertRaises(AzureException): + provider._extra_changes(existing, desired, changes) + + def test_generate_tm_profile(self): + provider, zone, record = self._get_dynamic_package() + profile_gen = provider._generate_tm_profile + + name = 'foobar' + routing = 'Priority' + endpoints = [ + Endpoint(target='one.unit.tests'), + Endpoint(target_resource_id='/s/1/rg/foo/tm/foobar2'), + Endpoint(name='invalid'), + ] + + # invalid endpoint raises exception + with self.assertRaises(AzureException): + profile_gen(name, routing, endpoints, record) + + # regular test + endpoints.pop() + profile = profile_gen(name, routing, endpoints, record) + + # implicitly tests _profile_name_to_id + sub = provider._dns_client_subscription_id + rg = provider._resource_group + expected_id = '/subscriptions/' + sub + \ + '/resourceGroups/' + rg + \ + '/providers/Microsoft.Network/trafficManagerProfiles/' + name + self.assertEqual(profile.id, expected_id) + self.assertEqual(profile.name, name) + self.assertEqual(profile.traffic_routing_method, routing) + self.assertEqual(profile.dns_config.ttl, record.ttl) + self.assertEqual(len(profile.endpoints), len(endpoints)) + + self.assertEqual( + profile.endpoints[0].type, + 'Microsoft.Network/trafficManagerProfiles/externalEndpoints' + ) + self.assertEqual( + profile.endpoints[1].type, + 'Microsoft.Network/trafficManagerProfiles/nestedEndpoints' + ) + + def test_generate_traffic_managers(self): + provider, zone, record = self._get_dynamic_package() + profiles = provider._generate_traffic_managers(record) + deduped = [] + seen = set() + for profile in profiles: + if profile.name not in seen: + deduped.append(profile) + seen.add(profile.name) + + # check that every profile is a match with what we expect + expected_profiles = self._get_tm_profiles(provider) + self.assertEqual(len(expected_profiles), len(deduped)) + for have, expected in zip(deduped, expected_profiles): + self.assertTrue(_profile_is_match(have, expected)) + + # check Asia/Middle East test case + record.dynamic._data()['rules'][0]['geos'].append('AS') + profiles = provider._generate_traffic_managers(record) + geo_profile_name = _traffic_manager_suffix(record) + geo_profile = next( + profile + for profile in profiles + if profile.name == geo_profile_name + ) + self.assertIn('GEO-ME', geo_profile.endpoints[0].geo_mapping) + self.assertIn('GEO-AS', geo_profile.endpoints[0].geo_mapping) + + def test_sync_traffic_managers(self): + provider, zone, record = self._get_dynamic_package() + provider._populate_traffic_managers() + + tm_sync = provider._tm_client.profiles.create_or_update + + suffix = 'foo-unit-tests' + expected_seen = { + suffix, 'default--{}'.format(suffix), + 'rule-one--{}'.format(suffix), 'rule-three--{}'.format(suffix), + 'pool-one--{}'.format(suffix), 'pool-two--{}'.format(suffix), + 'pool-three--{}'.format(suffix), + } + + # test no change + seen = provider._sync_traffic_managers(record) + self.assertEqual(seen, expected_seen) + tm_sync.assert_not_called() + + # test that changing weight causes update API call + dynamic = record.dynamic._data() + dynamic['pools']['one']['values'][0]['weight'] = 14 + data = { + 'type': 'CNAME', + 'ttl': record.ttl, + 'value': record.value, + 'dynamic': dynamic, + 'octodns': record._octodns, + } + new_record = Record.new(zone, record.name, data) + tm_sync.reset_mock() + seen2 = provider._sync_traffic_managers(new_record) + self.assertEqual(seen2, expected_seen) + tm_sync.assert_called_once() + + # test that new profile was successfully inserted in cache + new_profile = provider._get_tm_profile_by_name( + 'pool-one--{}'.format(suffix) + ) + self.assertEqual(new_profile.endpoints[0].weight, 14) + + def test_find_traffic_managers(self): + provider, zone, record = self._get_dynamic_package() + + # insert a non-matching profile + sample_profile = self._get_tm_profiles(provider)[0] + # dummy record for generating suffix + record2 = Record.new(zone, record.name + '2', data={ + 'type': record._type, + 'ttl': record.ttl, + 'value': record.value, + }) + suffix2 = _traffic_manager_suffix(record2) + tm_id = provider._profile_name_to_id + extra_profile = Profile( + id=tm_id('random--{}'.format(suffix2)), + name='random--{}'.format(suffix2), + traffic_routing_method='Weighted', + dns_config=sample_profile.dns_config, + monitor_config=sample_profile.monitor_config, + endpoints=sample_profile.endpoints, + ) + tm_list = provider._tm_client.profiles.list_by_resource_group + tm_list.return_value.append(extra_profile) + provider._populate_traffic_managers() + + # implicitly asserts that non-matching profile is not included + suffix = _traffic_manager_suffix(record) + self.assertEqual(provider._find_traffic_managers(record), { + suffix, 'default--{}'.format(suffix), + 'rule-one--{}'.format(suffix), 'rule-three--{}'.format(suffix), + 'pool-one--{}'.format(suffix), 'pool-two--{}'.format(suffix), + 'pool-three--{}'.format(suffix), + }) + + def test_traffic_manager_gc(self): + provider, zone, record = self._get_dynamic_package() + provider._populate_traffic_managers() + + profiles = provider._find_traffic_managers(record) + profile_delete_mock = provider._tm_client.profiles.delete + + provider._traffic_managers_gc(record, profiles) + profile_delete_mock.assert_not_called() + + profile_delete_mock.reset_mock() + remove = list(profiles)[3] + profiles.discard(remove) + + provider._traffic_managers_gc(record, profiles) + profile_delete_mock.assert_has_calls( + [call(provider._resource_group, remove)] + ) + def test_apply(self): provider = self._get_provider() - changes = [] - deletes = [] - for i in octo_records: - changes.append(Create(i)) - deletes.append(Delete(i)) + half = int(len(octo_records) / 2) + changes = [Create(r) for r in octo_records[:half]] + \ + [Update(r, r) for r in octo_records[half:]] + deletes = [Delete(r) for r in octo_records] self.assertEquals(19, provider.apply(Plan(None, zone, changes, True))) self.assertEquals(19, provider.apply(Plan(zone, zone, deletes, True))) + def test_apply_create_dynamic(self): + provider = self._get_provider() + + tm_list = provider._tm_client.profiles.list_by_resource_group + tm_list.return_value = [] + + tm_sync = provider._tm_client.profiles.create_or_update + + zone = Zone(name='unit.tests.', sub_zones=[]) + record = self._get_dynamic_record(zone) + + profiles = self._get_tm_profiles(provider) + + provider._apply_Create(Create(record)) + # create was called as many times as number of profiles required for + # the dynamic record + self.assertEqual(tm_sync.call_count, len(profiles)) + + create = provider._dns_client.record_sets.create_or_update + create.assert_called_once() + + def test_apply_update_dynamic(self): + # existing is simple, new is dynamic + provider = self._get_provider() + tm_list = provider._tm_client.profiles.list_by_resource_group + tm_list.return_value = [] + profiles = self._get_tm_profiles(provider) + dynamic_record = self._get_dynamic_record(zone) + simple_record = Record.new(zone, dynamic_record.name, data={ + 'type': 'CNAME', + 'ttl': 3600, + 'value': 'cname.unit.tests.', + }) + change = Update(simple_record, dynamic_record) + provider._apply_Update(change) + tm_sync, dns_update, tm_delete = ( + provider._tm_client.profiles.create_or_update, + provider._dns_client.record_sets.create_or_update, + provider._tm_client.profiles.delete + ) + self.assertEqual(tm_sync.call_count, len(profiles)) + dns_update.assert_called_once() + tm_delete.assert_not_called() + + # existing is dynamic, new is simple + provider, existing, dynamic_record = self._get_dynamic_package() + profiles = self._get_tm_profiles(provider) + change = Update(dynamic_record, simple_record) + provider._apply_Update(change) + tm_sync, dns_update, tm_delete = ( + provider._tm_client.profiles.create_or_update, + provider._dns_client.record_sets.create_or_update, + provider._tm_client.profiles.delete + ) + tm_sync.assert_not_called() + dns_update.assert_called_once() + self.assertEqual(tm_delete.call_count, len(profiles)) + + # both are dynamic, healthcheck port is changed + provider, existing, dynamic_record = self._get_dynamic_package() + profiles = self._get_tm_profiles(provider) + dynamic_record2 = self._get_dynamic_record(existing) + dynamic_record2._octodns['healthcheck']['port'] += 1 + change = Update(dynamic_record, dynamic_record2) + provider._apply_Update(change) + tm_sync, dns_update, tm_delete = ( + provider._tm_client.profiles.create_or_update, + provider._dns_client.record_sets.create_or_update, + provider._tm_client.profiles.delete + ) + self.assertEqual(tm_sync.call_count, len(profiles)) + dns_update.assert_not_called() + tm_delete.assert_not_called() + + # both are dynamic, extra profile should be deleted + provider, existing, dynamic_record = self._get_dynamic_package() + sample_profile = self._get_tm_profiles(provider)[0] + tm_id = provider._profile_name_to_id + root_profile_name = _traffic_manager_suffix(dynamic_record) + extra_profile = Profile( + id=tm_id('random--{}'.format(root_profile_name)), + name='random--{}'.format(root_profile_name), + traffic_routing_method='Weighted', + dns_config=sample_profile.dns_config, + monitor_config=sample_profile.monitor_config, + endpoints=sample_profile.endpoints, + ) + tm_list = provider._tm_client.profiles.list_by_resource_group + tm_list.return_value.append(extra_profile) + change = Update(dynamic_record, dynamic_record) + provider._apply_Update(change) + tm_sync, dns_update, tm_delete = ( + provider._tm_client.profiles.create_or_update, + provider._dns_client.record_sets.create_or_update, + provider._tm_client.profiles.delete + ) + tm_sync.assert_not_called() + dns_update.assert_not_called() + tm_delete.assert_called_once() + + def test_apply_delete_dynamic(self): + provider, existing, record = self._get_dynamic_package() + provider._populate_traffic_managers() + profiles = self._get_tm_profiles(provider) + change = Delete(record) + provider._apply_Delete(change) + dns_delete, tm_delete = ( + provider._dns_client.record_sets.delete, + provider._tm_client.profiles.delete + ) + dns_delete.assert_called_once() + self.assertEqual(tm_delete.call_count, len(profiles)) + def test_create_zone(self): provider = self._get_provider()