From 5b8498a550681857477aaaddd6a83f1792857759 Mon Sep 17 00:00:00 2001 From: Ross McFarland Date: Sat, 12 Aug 2023 19:31:38 -0700 Subject: [PATCH] Refactory yaml source logic out to make it easily testable --- octodns/provider/yaml.py | 68 +++++++++++----------- tests/test_octodns_provider_yaml.py | 90 ++++++++++++++++++++++------- 2 files changed, 103 insertions(+), 55 deletions(-) diff --git a/octodns/provider/yaml.py b/octodns/provider/yaml.py index e0a727a..2579414 100644 --- a/octodns/provider/yaml.py +++ b/octodns/provider/yaml.py @@ -205,11 +205,9 @@ class YamlProvider(BaseProvider): self.log.debug('list_zones:') zones = set() - # TODO: don't allow both utf8 and idna versions of the same zone extension = self.split_extension if extension: self.log.debug('list_zones: looking for split zones') - # look for split # we want to leave the . trim = len(extension) - 1 for dirname in listdir(self.directory): @@ -228,10 +226,41 @@ class YamlProvider(BaseProvider): or not isfile(join(self.directory, filename)) ): continue + # trim off the yaml, leave the . zones.add(filename[:-4]) return sorted(zones) + def _split_sources(self, zone): + ext = self.split_extension + utf8 = join(self.directory, f'{zone.decoded_name[:-1]}{ext}') + idna = join(self.directory, f'{zone.name[:-1]}{ext}') + directory = None + if isdir(utf8): + if utf8 != idna and isdir(idna): + raise ProviderException( + f'Both UTF-8 "{utf8}" and IDNA "{idna}" exist for {zone.decoded_name}' + ) + directory = utf8 + else: + directory = idna + + for filename in listdir(directory): + if filename.endswith('.yaml'): + yield join(directory, filename) + + def _zone_sources(self, zone): + utf8 = join(self.directory, f'{zone.decoded_name}yaml') + idna = join(self.directory, f'{zone.name}yaml') + if isfile(utf8): + if utf8 != idna and isfile(idna): + raise ProviderException( + f'Both UTF-8 "{utf8}" and IDNA "{idna}" exist for {zone.decoded_name}' + ) + return utf8 + + return idna + def _populate_from_file(self, filename, zone, lenient): with open(filename, 'r') as fh: yaml_data = safe_load(fh, enforce_order=self.enforce_order) @@ -271,43 +300,12 @@ class YamlProvider(BaseProvider): sources = [] - zone_name_utf8 = zone.name[:-1] - zone_name_idna = zone.decoded_name[:-1] - - directory = None split_extension = self.split_extension if split_extension: - utf8 = join(self.directory, f'{zone_name_utf8}{split_extension}') - idna = join(self.directory, f'{zone_name_idna}{split_extension}') - directory = None - if isdir(utf8): - if utf8 != idna and isdir(idna): - raise ProviderException( - f'Both UTF-8 "{utf8}" and IDNA "{idna}" exist for {zone.decoded_name}' - ) - directory = utf8 - else: - directory = idna - - for filename in listdir(directory): - if filename.endswith('.yaml'): - sources.append(join(directory, filename)) + sources.extend(self._split_sources(zone)) if not self.split_only: - utf8 = join(self.directory, f'{zone_name_utf8}.yaml') - idna = join(self.directory, f'{zone_name_idna}.yaml') - if isfile(utf8): - if utf8 != idna and isfile(idna): - raise ProviderException( - f'Both UTF-8 "{utf8}" and IDNA "{idna}" exist for {zone.decoded_name}' - ) - sources.append(utf8) - else: - sources.append(idna) - - if len(sources) == 0: - # TODO: what if we don't have any files - pass + sources.append(self._zone_sources(zone)) # determinstically order our sources sources.sort() diff --git a/tests/test_octodns_provider_yaml.py b/tests/test_octodns_provider_yaml.py index 8f52901..89b4bb1 100644 --- a/tests/test_octodns_provider_yaml.py +++ b/tests/test_octodns_provider_yaml.py @@ -2,8 +2,9 @@ # # -from os import makedirs +from os import makedirs, remove from os.path import dirname, isdir, isfile, join +from shutil import rmtree from unittest import TestCase from helpers import TemporaryDirectory @@ -12,12 +13,16 @@ from yaml.constructor import ConstructorError from octodns.idna import idna_encode from octodns.provider import ProviderException -from octodns.provider.base import Plan from octodns.provider.yaml import SplitYamlProvider, YamlProvider from octodns.record import Create, NsValue, Record, ValuesMixin from octodns.zone import SubzoneRecordException, Zone +def touch(filename): + with open(filename, 'w'): + pass + + class TestYamlProvider(TestCase): def test_provider(self): source = YamlProvider('test', join(dirname(__file__), 'config')) @@ -326,29 +331,74 @@ class TestSplitYamlProvider(TestCase): d = [join(directory, f) for f in yaml_files] self.assertEqual(len(yaml_files), len(d)) - def test_zone_directory(self): - source = SplitYamlProvider( - 'test', join(dirname(__file__), 'config/split'), extension='.tst' - ) + def test_split_sources(self): + with TemporaryDirectory() as td: + directory = join(td.dirname) - zone = Zone('unit.tests.', []) + provider = YamlProvider('test', directory, split_extension='.') - self.assertEqual( - join(dirname(__file__), 'config/split', 'unit.tests.tst'), - source._zone_directory(zone), - ) + zone = Zone('déjà.vu.', []) + zone_utf8 = join(directory, f'{zone.decoded_name}') + zone_idna = join(directory, f'{zone.name}') - def test_apply_handles_existing_zone_directory(self): - with TemporaryDirectory() as td: - provider = SplitYamlProvider( - 'test', join(td.dirname, 'config'), extension='.tst' + filenames = ( + '*.yaml', + '.yaml', + 'www.yaml', + f'${zone.decoded_name}yaml', ) - makedirs(join(td.dirname, 'config', 'does.exist.tst')) - zone = Zone('does.exist.', []) - self.assertTrue(isdir(provider._zone_directory(zone))) - provider.apply(Plan(None, zone, [], True)) - self.assertTrue(isdir(provider._zone_directory(zone))) + # create the utf8 zone dir + makedirs(zone_utf8) + # nothing in it so we should get nothing back + self.assertEqual([], list(provider._split_sources(zone))) + # create some record files + for filename in filenames: + touch(join(zone_utf8, filename)) + # make sure we see them + expected = [join(zone_utf8, f) for f in sorted(filenames)] + self.assertEqual(expected, sorted(provider._split_sources(zone))) + + # add a idna zone directory + makedirs(zone_idna) + for filename in filenames: + touch(join(zone_idna, filename)) + with self.assertRaises(ProviderException) as ctx: + list(provider._split_sources(zone)) + msg = str(ctx.exception) + self.assertTrue('Both UTF-8' in msg) + + # delete the utf8 version + rmtree(zone_utf8) + expected = [join(zone_idna, f) for f in sorted(filenames)] + self.assertEqual(expected, sorted(provider._split_sources(zone))) + + def test_zone_sources(self): + with TemporaryDirectory() as td: + directory = join(td.dirname) + + provider = YamlProvider('test', directory) + + zone = Zone('déjà.vu.', []) + utf8 = join(directory, f'{zone.decoded_name}yaml') + idna = join(directory, f'{zone.name}yaml') + + # create the utf8 version + touch(utf8) + # make sure that's what we get back + self.assertEqual(utf8, provider._zone_sources(zone)) + + # create idna version, both exists + touch(idna) + with self.assertRaises(ProviderException) as ctx: + provider._zone_sources(zone) + msg = str(ctx.exception) + self.assertTrue('Both UTF-8' in msg) + + # delete the utf8 version + remove(utf8) + # make sure that we get the idna one back + self.assertEqual(idna, provider._zone_sources(zone)) def test_provider(self): source = SplitYamlProvider(