Browse Source

Merge branch 'master' into rreichel3/azure-dns-zones-improvement

pull/667/head
Ross McFarland 5 years ago
committed by GitHub
parent
commit
c0000ee627
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 13 deletions
  1. +8
    -6
      octodns/source/axfr.py
  2. +3
    -3
      tests/test_octodns_source_axfr.py
  3. +4
    -4
      tests/zones/ext.unit.tests.extension

+ 8
- 6
octodns/source/axfr.py View File

@ -229,14 +229,16 @@ class ZoneFileSource(AxfrBaseSource):
self._zone_records = {} self._zone_records = {}
def _load_zone_file(self, zone_name): def _load_zone_file(self, zone_name):
zone_filename = zone_name
if self.file_extension:
zone_filename = '{}{}'.format(zone_name,
self.file_extension.lstrip('.'))
zonefiles = listdir(self.directory) zonefiles = listdir(self.directory)
if zone_name in zonefiles:
if zone_filename in zonefiles:
try: try:
filename = zone_name
if self.file_extension:
filename = '{}{}'.format(zone_name,
self.file_extension.lstrip('.'))
z = dns.zone.from_file(join(self.directory, filename),
z = dns.zone.from_file(join(self.directory, zone_filename),
zone_name, relativize=False, zone_name, relativize=False,
check_origin=self.check_origin) check_origin=self.check_origin)
except DNSException as error: except DNSException as error:


+ 3
- 3
tests/test_octodns_source_axfr.py View File

@ -45,12 +45,12 @@ class TestAxfrSource(TestCase):
class TestZoneFileSource(TestCase): class TestZoneFileSource(TestCase):
source = ZoneFileSource('test', './tests/zones') source = ZoneFileSource('test', './tests/zones')
source_extension = ZoneFileSource('test', './tests/zones', 'extension')
def test_zonefiles_with_extension(self): def test_zonefiles_with_extension(self):
source = ZoneFileSource('test', './tests/zones', 'extension')
# Load zonefiles with a specified file extension # Load zonefiles with a specified file extension
valid = Zone('unit.tests.', [])
self.source_extension.populate(valid)
valid = Zone('ext.unit.tests.', [])
source.populate(valid)
self.assertEquals(1, len(valid.records)) self.assertEquals(1, len(valid.records))
def test_populate(self): def test_populate(self):


tests/zones/unit.tests.extension → tests/zones/ext.unit.tests.extension View File


Loading…
Cancel
Save