diff --git a/octodns/processor/spf.py b/octodns/processor/spf.py index fbecc7e..0f5b5d6 100644 --- a/octodns/processor/spf.py +++ b/octodns/processor/spf.py @@ -29,7 +29,7 @@ class SpfDnsLookupProcessor(BaseProcessor): super().__init__(name) def _get_spf_from_txt_values( - self, values: list[str], record: Record + self, record: Record, values: list[str] ) -> Optional[str]: self.log.debug( f"_get_spf_from_txt_values: record={record.fqdn} values={values}" @@ -43,7 +43,7 @@ class SpfDnsLookupProcessor(BaseProcessor): if len(spf) > 1: raise SpfValueException( - f"{record.fqdn} has more than one SPF value" + f"{record.fqdn} has more than one SPF value in the TXT record" ) match = re.search(r"(v=spf1\s.+(?:all|redirect=))", "".join(values)) @@ -60,7 +60,7 @@ class SpfDnsLookupProcessor(BaseProcessor): f"_check_dns_lookups: record={record.fqdn} values={values} lookups={lookups}" ) - spf = self._get_spf_from_txt_values(values, record) + spf = self._get_spf_from_txt_values(record, values) if spf is None: return lookups @@ -84,11 +84,11 @@ class SpfDnsLookupProcessor(BaseProcessor): # The include mechanism can result in further lookups after resolving the DNS record if term.startswith('include:'): - answer = dns.resolver.resolve( - term.removeprefix('include:'), 'TXT' - ) + domain = term.removeprefix('include:') + answer = dns.resolver.resolve(domain, 'TXT') + answer_values = [value.to_text()[1:-1] for value in answer] lookups = self._check_dns_lookups( - record, [value.to_text()[1:-1] for value in answer], lookups + record, answer_values, lookups ) return lookups @@ -101,6 +101,6 @@ class SpfDnsLookupProcessor(BaseProcessor): if record._octodns.get('lenient'): continue - self._check_dns_lookups(record, record.values) + self._check_dns_lookups(record, record.values, 0) return zone diff --git a/tests/test_octodns_processor_spf.py b/tests/test_octodns_processor_spf.py index a9b9828..6784b9d 100644 --- a/tests/test_octodns_processor_spf.py +++ b/tests/test_octodns_processor_spf.py @@ -1,5 +1,5 @@ from unittest import TestCase -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch from octodns.processor.spf import ( SpfDnsLookupException, @@ -22,36 +22,36 @@ class TestSpfDnsLookupProcessor(TestCase): self.assertEqual( 'v=spf1 include:example.com ~all', processor._get_spf_from_txt_values( - ['v=DMARC1\; p=reject\;', 'v=spf1 include:example.com ~all'], record, + ['v=DMARC1\; p=reject\;', 'v=spf1 include:example.com ~all'], ), ) with self.assertRaises(SpfValueException): processor._get_spf_from_txt_values( + record, [ 'v=spf1 include:example.com ~all', 'v=spf1 include:example.com ~all', ], - record, ) self.assertEqual( 'v=spf1 include:example.com ~all', processor._get_spf_from_txt_values( - ['v=DMARC1\; p=reject\;', 'v=spf1 include:example.com ~all'], record, + ['v=DMARC1\; p=reject\;', 'v=spf1 include:example.com ~all'], ), ) with self.assertRaises(SpfValueException): processor._get_spf_from_txt_values( - ['v=spf1 include:example.com'], record + record, ['v=spf1 include:example.com'] ) self.assertIsNone( processor._get_spf_from_txt_values( - ['v=DMARC1\; p=reject\;'], record + record, ['v=DMARC1\; p=reject\;'] ) ) @@ -59,20 +59,20 @@ class TestSpfDnsLookupProcessor(TestCase): self.assertEqual( 'v=spf1 include:example.com ip4:1.2.3.4 ~all', processor._get_spf_from_txt_values( + record, [ 'v=spf1 include:example.com', ' ip4:1.2.3.4 ~all', 'v=DMARC1\; p=reject\;', ], - record, ), ) self.assertEqual( 'v=spf1 +mx redirect=', processor._get_spf_from_txt_values( - ['v=spf1 +mx redirect=example.com', 'v=DMARC1\; p=reject\;'], record, + ['v=spf1 +mx redirect=example.com', 'v=DMARC1\; p=reject\;'], ), ) @@ -117,11 +117,13 @@ class TestSpfDnsLookupProcessor(TestCase): ) ) + resolver_mock.reset_mock() txt_value_mock = MagicMock() txt_value_mock.to_text.return_value = '"v=spf1 -all"' resolver_mock.return_value = [txt_value_mock] self.assertEqual(zone, processor.process_source_zone(zone)) + resolver_mock.assert_called_once_with('example.com', 'TXT') zone = Zone('unit.tests.', []) zone.add_record( @@ -176,6 +178,7 @@ class TestSpfDnsLookupProcessor(TestCase): ) ) + resolver_mock.reset_mock() txt_value_mock = MagicMock() txt_value_mock.to_text.return_value = ( '"v=spf1 a a a a a a a a a a a -all"' @@ -184,6 +187,40 @@ class TestSpfDnsLookupProcessor(TestCase): with self.assertRaises(SpfDnsLookupException): processor.process_source_zone(zone) + resolver_mock.assert_called_once_with('example.com', 'TXT') + + zone = Zone('unit.tests.', []) + zone.add_record( + Record.new( + zone, + '', + { + 'type': 'TXT', + 'ttl': 86400, + 'values': [ + 'v=spf1 include:example.com -all', + 'v=DMARC1\; p=reject\;', + ], + }, + ) + ) + + resolver_mock.reset_mock() + first_txt_value_mock = MagicMock() + first_txt_value_mock.to_text.return_value = ( + '"v=spf1 include:_spf.example.com -all"' + ) + second_txt_value_mock = MagicMock() + second_txt_value_mock.to_text.return_value = '"v=spf1 a -all"' + resolver_mock.side_effect = [ + [first_txt_value_mock], + [second_txt_value_mock], + ] + + self.assertEqual(zone, processor.process_source_zone(zone)) + resolver_mock.assert_has_calls( + [call('example.com', 'TXT'), call('_spf.example.com', 'TXT')] + ) def test_processor_with_long_txt_value(self): processor = SpfDnsLookupProcessor('test') @@ -311,3 +348,33 @@ class TestSpfDnsLookupProcessor(TestCase): 'unit.tests. uses the deprecated ptr mechanism', str(context.exception), ) + resolver_mock.assert_called_once_with('example.com', 'TXT') + + @patch('dns.resolver.resolve') + def test_processor_errors_on_recursive_include_mechanism( + self, resolver_mock + ): + processor = SpfDnsLookupProcessor('test') + zone = Zone('unit.tests.', []) + + zone.add_record( + Record.new( + zone, + '', + { + 'type': 'TXT', + 'ttl': 86400, + 'values': ['v=spf1 include:example.com ~all'], + }, + ) + ) + + txt_value_mock = MagicMock() + txt_value_mock.to_text.return_value = ( + '"v=spf1 include:example.com ~all"' + ) + resolver_mock.return_value = [txt_value_mock] + + with self.assertRaises(SpfDnsLookupException): + processor.process_source_zone(zone) + resolver_mock.assert_called_with('example.com', 'TXT')