From 11ddb20005676a95b5367946ffcea8eb8d3437ae Mon Sep 17 00:00:00 2001 From: Ross McFarland Date: Sat, 12 Aug 2023 11:18:33 -0700 Subject: [PATCH] Refactory YamlProvider and SplitYamlProvider into a unified class --- octodns/provider/yaml.py | 315 +++++++++++++++------------- tests/test_octodns_provider_yaml.py | 12 +- 2 files changed, 176 insertions(+), 151 deletions(-) diff --git a/octodns/provider/yaml.py b/octodns/provider/yaml.py index 1d495a9..e0a727a 100644 --- a/octodns/provider/yaml.py +++ b/octodns/provider/yaml.py @@ -28,8 +28,35 @@ class YamlProvider(BaseProvider): # (optional, default True) enforce_order: true # Whether duplicate records should replace rather than error - # (optiona, default False) + # (optional, default False) populate_should_replace: false + # The filename used to load split style zones, False means disabled. + # When enabled the provider will search for zone records split across + # multiple YAML files in a directory with the zone name. + # See "Split Details" below for more information + # (optional, default False, . is the recommended best practice when + # enabling) + split_extension: false + + Split Details + ------------- + + All files are stored in a subdirectory matching the name of the zone + (including the trailing .) of the directory config. It is a recommended + best practice that the files be named RECORD.yaml, but all files are + sourced and processed as if they were a single large file. + + A full directory structure for the zone github.com. managed under directory + "zones/" would be: + + zones/ + github.com./ + .yaml + www.yaml + ... + + Overriding Values + ----------------- Overriding values can be accomplished using multiple yaml providers in the `sources` list where subsequent providers have `populate_should_replace` @@ -98,7 +125,6 @@ class YamlProvider(BaseProvider): You can then sync our records eternally with `--config-file=external.yaml` and internally (with the custom overrides) with `--config-file=internal.yaml` - ''' SUPPORTS_GEO = True @@ -107,6 +133,10 @@ class YamlProvider(BaseProvider): SUPPORTS_DYNAMIC_SUBNETS = True SUPPORTS_MULTIVALUE_PTR = True + # Any record name added to this set will be included in the catch-all file, + # instead of a file matching the record name. + CATCHALL_RECORD_NAMES = ('*', '') + def __init__( self, id, @@ -115,19 +145,25 @@ class YamlProvider(BaseProvider): enforce_order=True, populate_should_replace=False, supports_root_ns=True, + split_extension=False, + split_only=False, + split_catchall=False, *args, **kwargs, ): klass = self.__class__.__name__ self.log = logging.getLogger(f'{klass}[{id}]') self.log.debug( - '__init__: id=%s, directory=%s, default_ttl=%d, ' - 'enforce_order=%d, populate_should_replace=%d', + '__init__: id=%s, directory=%s, default_ttl=%d, enforce_order=%d, populate_should_replace=%s, supports_root_ns=%s, split_extension=%s, split_only=%s, split_catchall=%s', id, directory, default_ttl, enforce_order, populate_should_replace, + supports_root_ns, + split_extension, + split_only, + split_catchall, ) super().__init__(id, *args, **kwargs) self.directory = directory @@ -135,12 +171,15 @@ class YamlProvider(BaseProvider): self.enforce_order = enforce_order self.populate_should_replace = populate_should_replace self.supports_root_ns = supports_root_ns + self.split_extension = split_extension + self.split_only = split_only + self.split_catchall = split_catchall def copy(self): - args = dict(self.__dict__) - args['id'] = f'{args["id"]}-copy' - del args['log'] - return self.__class__(**args) + kwargs = dict(self.__dict__) + kwargs['id'] = f'{kwargs["id"]}-copy' + del kwargs['log'] + return YamlProvider(**kwargs) @property def SUPPORTS(self): @@ -162,6 +201,37 @@ class YamlProvider(BaseProvider): def SUPPORTS_ROOT_NS(self): return self.supports_root_ns + def list_zones(self): + self.log.debug('list_zones:') + zones = set() + + # TODO: don't allow both utf8 and idna versions of the same zone + extension = self.split_extension + if extension: + self.log.debug('list_zones: looking for split zones') + # look for split + # we want to leave the . + trim = len(extension) - 1 + for dirname in listdir(self.directory): + if not dirname.endswith(extension) or not isdir( + join(self.directory, dirname) + ): + continue + zones.add(dirname[:-trim]) + + if not self.split_only: + self.log.debug('list_zones: looking for zone files') + for filename in listdir(self.directory): + if ( + not filename.endswith('.yaml') + or filename.count('.') < 2 + or not isfile(join(self.directory, filename)) + ): + continue + zones.add(filename[:-4]) + + return sorted(zones) + def _populate_from_file(self, filename, zone, lenient): with open(filename, 'r') as fh: yaml_data = safe_load(fh, enforce_order=self.enforce_order) @@ -184,18 +254,6 @@ class YamlProvider(BaseProvider): '_populate_from_file: successfully loaded "%s"', filename ) - def get_filenames(self, zone): - return ( - join(self.directory, f'{zone.decoded_name}yaml'), - join(self.directory, f'{zone.name}yaml'), - ) - - def list_zones(self): - for filename in listdir(self.directory): - if not filename.endswith('.yaml') or filename.count('.') < 2: - continue - yield filename[:-4] - def populate(self, zone, target=False, lenient=False): self.log.debug( 'populate: name=%s, target=%s, lenient=%s', @@ -210,23 +268,52 @@ class YamlProvider(BaseProvider): return False before = len(zone.records) - utf8_filename, idna_filename = self.get_filenames(zone) - # we prefer utf8 - if isfile(utf8_filename): - if utf8_filename != idna_filename and isfile(idna_filename): - raise ProviderException( - f'Both UTF-8 "{utf8_filename}" and IDNA "{idna_filename}" exist for {zone.decoded_name}' - ) - filename = utf8_filename - else: - self.log.warning( - 'populate: "%s" does not exist, falling back to try idna version "%s"', - utf8_filename, - idna_filename, - ) - filename = idna_filename - self._populate_from_file(filename, zone, lenient) + sources = [] + + zone_name_utf8 = zone.name[:-1] + zone_name_idna = zone.decoded_name[:-1] + + directory = None + split_extension = self.split_extension + if split_extension: + utf8 = join(self.directory, f'{zone_name_utf8}{split_extension}') + idna = join(self.directory, f'{zone_name_idna}{split_extension}') + directory = None + if isdir(utf8): + if utf8 != idna and isdir(idna): + raise ProviderException( + f'Both UTF-8 "{utf8}" and IDNA "{idna}" exist for {zone.decoded_name}' + ) + directory = utf8 + else: + directory = idna + + for filename in listdir(directory): + if filename.endswith('.yaml'): + sources.append(join(directory, filename)) + + if not self.split_only: + utf8 = join(self.directory, f'{zone_name_utf8}.yaml') + idna = join(self.directory, f'{zone_name_idna}.yaml') + if isfile(utf8): + if utf8 != idna and isfile(idna): + raise ProviderException( + f'Both UTF-8 "{utf8}" and IDNA "{idna}" exist for {zone.decoded_name}' + ) + sources.append(utf8) + else: + sources.append(idna) + + if len(sources) == 0: + # TODO: what if we don't have any files + pass + + # determinstically order our sources + sources.sort() + + for source in sources: + self._populate_from_file(source, zone, lenient) self.log.info( 'populate: found %s records, exists=False', @@ -264,123 +351,65 @@ class YamlProvider(BaseProvider): data[k] = data[k][0] if not isdir(self.directory): + self.log.debug('_apply: creating directory=%s', self.directory) makedirs(self.directory) - self._do_apply(desired, data) - - def _do_apply(self, desired, data): - filename = join(self.directory, f'{desired.decoded_name}yaml') - self.log.debug('_apply: writing filename=%s', filename) - with open(filename, 'w') as fh: - safe_dump(dict(data), fh, allow_unicode=True) + if self.split_extension: + # we're going to do split files + decoded_name = desired.decoded_name[:-1] + directory = join( + self.directory, f'{decoded_name}{self.split_extension}' + ) + if not isdir(directory): + self.log.debug('_apply: creating split directory=%s', directory) + makedirs(directory) + + catchall = {} + for record, config in data.items(): + if self.split_catchall and record in self.CATCHALL_RECORD_NAMES: + catchall[record] = config + continue + filename = join(directory, f'{record}.yaml') + self.log.debug('_apply: writing filename=%s', filename) + + with open(filename, 'w') as fh: + record_data = {record: config} + safe_dump(record_data, fh) + + if catchall: + # Scrub the trailing . to make filenames more sane. + filename = join(directory, f'${decoded_name}.yaml') + self.log.debug( + '_apply: writing catchall filename=%s', filename + ) + with open(filename, 'w') as fh: + safe_dump(catchall, fh) -def _list_all_yaml_files(directory): - yaml_files = set() - for f in listdir(directory): - filename = join(directory, f) - if f.endswith('.yaml') and isfile(filename): - yaml_files.add(filename) - return list(yaml_files) + else: + # single large file + filename = join(self.directory, f'{desired.decoded_name}yaml') + self.log.debug('_apply: writing filename=%s', filename) + with open(filename, 'w') as fh: + safe_dump(dict(data), fh, allow_unicode=True) class SplitYamlProvider(YamlProvider): ''' - Core provider for records configured in multiple YAML files on disk. - - Behaves mostly similarly to YamlConfig, but interacts with multiple YAML - files, instead of a single monolitic one. All files are stored in a - subdirectory matching the name of the zone (including the trailing .) of - the directory config. The files are named RECORD.yaml, except for any - record which cannot be represented easily as a file; these are stored in - the catchall file, which is a YAML file the zone name, prepended with '$'. - For example, a zone, 'github.com.' would have a catch-all file named - '$github.com.yaml'. + DEPRECATED: Use YamlProvider with the split_extension parameter instead. - A full directory structure for the zone github.com. managed under directory - "zones/" would be: - - zones/ - github.com./ - $github.com.yaml - www.yaml - ... - - config: - class: octodns.provider.yaml.SplitYamlProvider - # The location of yaml config files (required) - directory: ./config - # The ttl to use for records when not specified in the data - # (optional, default 3600) - default_ttl: 3600 - # Whether or not to enforce sorting order on the yaml config - # (optional, default True) - enforce_order: True + TO BE REMOVED: 2.0 ''' - # Any record name added to this set will be included in the catch-all file, - # instead of a file matching the record name. - CATCHALL_RECORD_NAMES = ('*', '') - - def __init__(self, id, directory, extension='.', *args, **kwargs): - super().__init__(id, directory, *args, **kwargs) - self.extension = extension - - def _zone_directory(self, zone): - filename = f'{zone.name[:-1]}{self.extension}' - return join(self.directory, filename) - - def list_zones(self): - n = len(self.extension) - 1 - for filename in listdir(self.directory): - if not filename.endswith(self.extension): - continue - yield filename[:-n] - - def populate(self, zone, target=False, lenient=False): - self.log.debug( - 'populate: name=%s, target=%s, lenient=%s', - zone.name, - target, - lenient, + def __init__(self, id, directory, *args, extension='.', **kwargs): + kwargs.update( + { + 'split_extension': extension, + 'split_only': True, + 'split_catchall': True, + } ) - - if target: - # When acting as a target we ignore any existing records so that we - # create a completely new copy - return False - - before = len(zone.records) - yaml_filenames = _list_all_yaml_files(self._zone_directory(zone)) - self.log.info('populate: found %s YAML files', len(yaml_filenames)) - for yaml_filename in yaml_filenames: - self._populate_from_file(yaml_filename, zone, lenient) - - self.log.info( - 'populate: found %s records, exists=False', - len(zone.records) - before, + super().__init__(id, directory, *args, **kwargs) + self.log.warning( + '__init__: DEPRECATED use YamlProvider with split_extension and optionally split_only instead, will go away in v2.0' ) - return False - - def _do_apply(self, desired, data): - zone_dir = self._zone_directory(desired) - if not isdir(zone_dir): - makedirs(zone_dir) - - catchall = dict() - for record, config in data.items(): - if record in self.CATCHALL_RECORD_NAMES: - catchall[record] = config - continue - filename = join(zone_dir, f'{record}.yaml') - self.log.debug('_apply: writing filename=%s', filename) - with open(filename, 'w') as fh: - record_data = {record: config} - safe_dump(record_data, fh) - if catchall: - # Scrub the trailing . to make filenames more sane. - dname = desired.name[:-1] - filename = join(zone_dir, f'${dname}.yaml') - self.log.debug('_apply: writing catchall filename=%s', filename) - with open(filename, 'w') as fh: - safe_dump(catchall, fh) diff --git a/tests/test_octodns_provider_yaml.py b/tests/test_octodns_provider_yaml.py index 1cf017a..8f52901 100644 --- a/tests/test_octodns_provider_yaml.py +++ b/tests/test_octodns_provider_yaml.py @@ -3,7 +3,7 @@ # from os import makedirs -from os.path import basename, dirname, isdir, isfile, join +from os.path import dirname, isdir, isfile, join from unittest import TestCase from helpers import TemporaryDirectory @@ -13,11 +13,7 @@ from yaml.constructor import ConstructorError from octodns.idna import idna_encode from octodns.provider import ProviderException from octodns.provider.base import Plan -from octodns.provider.yaml import ( - SplitYamlProvider, - YamlProvider, - _list_all_yaml_files, -) +from octodns.provider.yaml import SplitYamlProvider, YamlProvider from octodns.record import Create, NsValue, Record, ValuesMixin from octodns.zone import SubzoneRecordException, Zone @@ -327,7 +323,7 @@ class TestSplitYamlProvider(TestCase): # This isn't great, but given the variable nature of the temp dir # names, it's necessary. - d = list(basename(f) for f in _list_all_yaml_files(directory)) + d = [join(directory, f) for f in yaml_files] self.assertEqual(len(yaml_files), len(d)) def test_zone_directory(self): @@ -573,7 +569,7 @@ class TestSplitYamlProvider(TestCase): ) copy = source.copy() self.assertEqual(source.directory, copy.directory) - self.assertEqual(source.extension, copy.extension) + self.assertEqual(source.split_extension, copy.split_extension) self.assertEqual(source.default_ttl, copy.default_ttl) self.assertEqual(source.enforce_order, copy.enforce_order) self.assertEqual(