Browse Source

ZoneFileSource: allow users to specify file extension

pull/660/head
Adam Smith 5 years ago
parent
commit
b2eab63d54
3 changed files with 34 additions and 4 deletions
  1. +15
    -4
      octodns/source/axfr.py
  2. +7
    -0
      tests/test_octodns_source_axfr.py
  3. +12
    -0
      tests/zones/unit.tests.extension

+ 15
- 4
octodns/source/axfr.py View File

@ -206,17 +206,24 @@ class ZoneFileSource(AxfrBaseSource):
class: octodns.source.axfr.ZoneFileSource
# The directory holding the zone files
# Filenames should match zone name (eg. example.com.)
# with optional extension specified with file_extension
directory: ./zonefiles
# File extension on zone files
# Appended to zone name to locate file
# (optional, default None)
file_extension: zone
# Should sanity checks of the origin node be done
# (optional, default true)
check_origin: false
'''
def __init__(self, id, directory, check_origin=True):
def __init__(self, id, directory, file_extension=None, check_origin=True):
self.log = logging.getLogger('ZoneFileSource[{}]'.format(id))
self.log.debug('__init__: id=%s, directory=%s, check_origin=%s', id,
directory, check_origin)
self.log.debug('__init__: id=%s, directory=%s, file_extension=%s, '
'check_origin=%s', id,
directory, file_extension, check_origin)
super(ZoneFileSource, self).__init__(id)
self.directory = directory
self.file_extension = file_extension
self.check_origin = check_origin
self._zone_records = {}
@ -225,7 +232,11 @@ class ZoneFileSource(AxfrBaseSource):
zonefiles = listdir(self.directory)
if zone_name in zonefiles:
try:
z = dns.zone.from_file(join(self.directory, zone_name),
filename = zone_name
if self.file_extension:
filename = '{}{}'.format(zone_name,
self.file_extension.lstrip('.'))
z = dns.zone.from_file(join(self.directory, filename),
zone_name, relativize=False,
check_origin=self.check_origin)
except DNSException as error:


+ 7
- 0
tests/test_octodns_source_axfr.py View File

@ -45,6 +45,13 @@ class TestAxfrSource(TestCase):
class TestZoneFileSource(TestCase):
source = ZoneFileSource('test', './tests/zones')
source_extension = ZoneFileSource('test', './tests/zones', 'extension')
def test_zonefiles_with_extension(self):
# Load zonefiles with a specified file extension
valid = Zone('unit.tests.', [])
self.source_extension.populate(valid)
self.assertEquals(1, len(valid.records))
def test_populate(self):
# Valid zone file in directory


+ 12
- 0
tests/zones/unit.tests.extension View File

@ -0,0 +1,12 @@
$ORIGIN unit.tests.
@ 3600 IN SOA ns1.unit.tests. root.unit.tests. (
2018071501 ; Serial
3600 ; Refresh (1 hour)
600 ; Retry (10 minutes)
604800 ; Expire (1 week)
3600 ; NXDOMAIN ttl (1 hour)
)
; NS Records
@ 3600 IN NS ns1.unit.tests.
@ 3600 IN NS ns2.unit.tests.

Loading…
Cancel
Save