From 8ba533533337f402ec722969aa77d2da447bf74d Mon Sep 17 00:00:00 2001 From: Ross McFarland Date: Sun, 1 Jan 2023 21:15:11 -0500 Subject: [PATCH] Add support for DsRecord type --- CHANGELOG.md | 1 + octodns/record/__init__.py | 138 +++++++++++++++++++++++- tests/test_octodns_record.py | 197 +++++++++++++++++++++++++++++++++++ 3 files changed, 333 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 901ca14..0348d50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ #### Stuff +* Added new DsRecord type (provider support will be added over time) * Added simple IgnoreRootNsFilter * Minor refactor on YamlProvider to add get_filenames making it a bit easier to create specialized providers inheriting from it diff --git a/octodns/record/__init__.py b/octodns/record/__init__.py index 6245812..2c18dba 100644 --- a/octodns/record/__init__.py +++ b/octodns/record/__init__.py @@ -902,6 +902,135 @@ class DnameValue(_TargetValue): pass +class DsValue(EqualityTupleMixin, dict): + # https://www.rfc-editor.org/rfc/rfc4034.html#section-2.1 + + @classmethod + def parse_rdata_text(cls, value): + try: + flags, protocol, algorithm, public_key = value.split(' ') + except ValueError: + raise RrParseError() + try: + flags = int(flags) + except ValueError: + pass + try: + protocol = int(protocol) + except ValueError: + pass + try: + algorithm = int(algorithm) + except ValueError: + pass + return { + 'flags': flags, + 'protocol': protocol, + 'algorithm': algorithm, + 'public_key': public_key, + } + + @classmethod + def validate(cls, data, _type): + if not isinstance(data, (list, tuple)): + data = (data,) + reasons = [] + for value in data: + try: + int(value['flags']) + except KeyError: + reasons.append('missing flags') + except ValueError: + reasons.append(f'invalid flags "{value["flags"]}"') + try: + int(value['protocol']) + except KeyError: + reasons.append('missing protocol') + except ValueError: + reasons.append(f'invalid protocol "{value["protocol"]}"') + try: + int(value['algorithm']) + except KeyError: + reasons.append('missing algorithm') + except ValueError: + reasons.append(f'invalid algorithm "{value["algorithm"]}"') + if 'public_key' not in value: + reasons.append('missing public_key') + return reasons + + @classmethod + def process(cls, values): + return [cls(v) for v in values] + + def __init__(self, value): + super().__init__( + { + 'flags': int(value['flags']), + 'protocol': int(value['protocol']), + 'algorithm': int(value['algorithm']), + 'public_key': value['public_key'], + } + ) + + @property + def flags(self): + return self['flags'] + + @flags.setter + def flags(self, value): + self['flags'] = value + + @property + def protocol(self): + return self['protocol'] + + @protocol.setter + def protocol(self, value): + self['protocol'] = value + + @property + def algorithm(self): + return self['algorithm'] + + @algorithm.setter + def algorithm(self, value): + self['algorithm'] = value + + @property + def public_key(self): + return self['public_key'] + + @public_key.setter + def public_key(self, value): + self['public_key'] = value + + @property + def data(self): + return self + + @property + def rdata_text(self): + return ( + f'{self.flags} {self.protocol} {self.algorithm} {self.public_key}' + ) + + def _equality_tuple(self): + return (self.flags, self.protocol, self.algorithm, self.public_key) + + def __repr__(self): + return ( + f'{self.flags} {self.protocol} {self.algorithm} {self.public_key}' + ) + + +class DsRecord(ValuesMixin, Record): + _type = 'DS' + _value_type = DsValue + + +Record.register_type(DsRecord) + + class _IpAddress(str): @classmethod def parse_rdata_text(cls, value): @@ -2215,9 +2344,12 @@ class TlsaValue(EqualityTupleMixin, dict): 'certificate_usage': int(value.get('certificate_usage', 0)), 'selector': int(value.get('selector', 0)), 'matching_type': int(value.get('matching_type', 0)), - 'certificate_association_data': value[ - 'certificate_association_data' - ], + # force it to a string, in case the hex has only numerical + # values and it was converted to an int at some point + # TODO: this needed on any others? + 'certificate_association_data': str( + value['certificate_association_data'] + ), } ) diff --git a/tests/test_octodns_record.py b/tests/test_octodns_record.py index c59547d..a7d8445 100644 --- a/tests/test_octodns_record.py +++ b/tests/test_octodns_record.py @@ -13,6 +13,8 @@ from octodns.record import ( CaaValue, CnameRecord, DnameRecord, + DsValue, + DsRecord, Create, Delete, GeoValue, @@ -654,6 +656,201 @@ class TestRecord(TestCase): def test_dname(self): self.assertSingleValue(DnameRecord, 'target.foo.com.', 'other.foo.com.') + def test_ds(self): + for a, b in ( + # diff flags + ( + { + 'flags': 0, + 'protocol': 1, + 'algorithm': 2, + 'public_key': 'abcdef0123456', + }, + { + 'flags': 1, + 'protocol': 1, + 'algorithm': 2, + 'public_key': 'abcdef0123456', + }, + ), + # diff protocol + ( + { + 'flags': 0, + 'protocol': 1, + 'algorithm': 2, + 'public_key': 'abcdef0123456', + }, + { + 'flags': 0, + 'protocol': 2, + 'algorithm': 2, + 'public_key': 'abcdef0123456', + }, + ), + # diff algorithm + ( + { + 'flags': 0, + 'protocol': 1, + 'algorithm': 2, + 'public_key': 'abcdef0123456', + }, + { + 'flags': 0, + 'protocol': 1, + 'algorithm': 3, + 'public_key': 'abcdef0123456', + }, + ), + # diff public_key + ( + { + 'flags': 0, + 'protocol': 1, + 'algorithm': 2, + 'public_key': 'abcdef0123456', + }, + { + 'flags': 0, + 'protocol': 1, + 'algorithm': 2, + 'public_key': 'bcdef0123456a', + }, + ), + ): + a = DsValue(a) + self.assertEqual(a, a) + b = DsValue(b) + self.assertEqual(b, b) + self.assertNotEqual(a, b) + self.assertNotEqual(b, a) + self.assertTrue(a < b) + + # empty string won't parse + with self.assertRaises(RrParseError): + DsValue.parse_rdata_text('') + + # single word won't parse + with self.assertRaises(RrParseError): + DsValue.parse_rdata_text('nope') + + # 2nd word won't parse + with self.assertRaises(RrParseError): + DsValue.parse_rdata_text('0 1') + + # 3rd word won't parse + with self.assertRaises(RrParseError): + DsValue.parse_rdata_text('0 1 2') + + # 5th word won't parse + with self.assertRaises(RrParseError): + DsValue.parse_rdata_text('0 1 2 key blah') + + # things ints, will parse + self.assertEqual( + { + 'flags': 'one', + 'protocol': 'two', + 'algorithm': 'three', + 'public_key': 'key', + }, + DsValue.parse_rdata_text('one two three key'), + ) + + # valid + data = { + 'flags': 0, + 'protocol': 1, + 'algorithm': 2, + 'public_key': '99148c81', + } + self.assertEqual(data, DsValue.parse_rdata_text('0 1 2 99148c81')) + self.assertEqual([], DsValue.validate(data, 'DS')) + + # missing flags + data = {'protocol': 1, 'algorithm': 2, 'public_key': '99148c81'} + self.assertEqual(['missing flags'], DsValue.validate(data, 'DS')) + # invalid flags + data = { + 'flags': 'a', + 'protocol': 1, + 'algorithm': 2, + 'public_key': '99148c81', + } + self.assertEqual(['invalid flags "a"'], DsValue.validate(data, 'DS')) + + # missing protocol + data = {'flags': 1, 'algorithm': 2, 'public_key': '99148c81'} + self.assertEqual(['missing protocol'], DsValue.validate(data, 'DS')) + # invalid protocol + data = { + 'flags': 1, + 'protocol': 'a', + 'algorithm': 2, + 'public_key': '99148c81', + } + self.assertEqual(['invalid protocol "a"'], DsValue.validate(data, 'DS')) + + # missing algorithm + data = {'flags': 1, 'protocol': 2, 'public_key': '99148c81'} + self.assertEqual(['missing algorithm'], DsValue.validate(data, 'DS')) + # invalid algorithm + data = { + 'flags': 1, + 'protocol': 2, + 'algorithm': 'a', + 'public_key': '99148c81', + } + self.assertEqual( + ['invalid algorithm "a"'], DsValue.validate(data, 'DS') + ) + + # missing algorithm (list) + data = {'flags': 1, 'protocol': 2, 'algorithm': 3} + self.assertEqual(['missing public_key'], DsValue.validate([data], 'DS')) + + zone = Zone('unit.tests.', []) + values = [ + { + 'flags': 0, + 'protocol': 1, + 'algorithm': 2, + 'public_key': '99148c81', + }, + { + 'flags': 1, + 'protocol': 2, + 'algorithm': 3, + 'public_key': '99148c44', + }, + ] + a = DsRecord(zone, 'ds', {'ttl': 32, 'values': values}) + self.assertEqual(0, a.values[0].flags) + a.values[0].flags += 1 + self.assertEqual(1, a.values[0].flags) + + self.assertEqual(1, a.values[0].protocol) + a.values[0].protocol += 1 + self.assertEqual(2, a.values[0].protocol) + + self.assertEqual(2, a.values[0].algorithm) + a.values[0].algorithm += 1 + self.assertEqual(3, a.values[0].algorithm) + + self.assertEqual('99148c81', a.values[0].public_key) + a.values[0].public_key = '99148c42' + self.assertEqual('99148c42', a.values[0].public_key) + + self.assertEqual(1, a.values[1].flags) + self.assertEqual(2, a.values[1].protocol) + self.assertEqual(3, a.values[1].algorithm) + self.assertEqual('99148c44', a.values[1].public_key) + + self.assertEqual(DsValue(values[1]), a.values[1].data) + self.assertEqual('1 2 3 99148c44', a.values[1].rdata_text) + self.assertEqual('1 2 3 99148c44', a.values[1].__repr__()) + def test_loc(self): a_values = [ LocValue(