diff --git a/octodns/provider/__init__.py b/octodns/provider/__init__.py index 14ccf18..dbbaaa8 100644 --- a/octodns/provider/__init__.py +++ b/octodns/provider/__init__.py @@ -4,3 +4,7 @@ from __future__ import absolute_import, division, print_function, \ unicode_literals + + +class ProviderException(Exception): + pass diff --git a/octodns/provider/base.py b/octodns/provider/base.py index 238a68a..433eab8 100644 --- a/octodns/provider/base.py +++ b/octodns/provider/base.py @@ -10,17 +10,15 @@ from six import text_type from ..source.base import BaseSource from ..zone import Zone from .plan import Plan - - -class ProviderException(Exception): - pass +from . import ProviderException class BaseProvider(BaseSource): def __init__(self, id, apply_disabled=False, update_pcent_threshold=Plan.MAX_SAFE_UPDATE_PCENT, - delete_pcent_threshold=Plan.MAX_SAFE_DELETE_PCENT): + delete_pcent_threshold=Plan.MAX_SAFE_DELETE_PCENT, + strict_supports=False): super(BaseProvider, self).__init__(id) self.log.debug('__init__: id=%s, apply_disabled=%s, ' 'update_pcent_threshold=%.2f, ' @@ -32,6 +30,21 @@ class BaseProvider(BaseSource): self.apply_disabled = apply_disabled self.update_pcent_threshold = update_pcent_threshold self.delete_pcent_threshold = delete_pcent_threshold + self.strict_supports = strict_supports + + def _process_desired_zone(self, desired): + ''' + An opportunity for providers to modify that desired zone records before + planning. + + - Must do their work and then call super with the results of that work + - Must not modify the `desired` parameter or its records and should + make a copy of anything it's modifying + - Must call supports_warn_or_except with information about any changes + that are made to have them logged or throw errors depending on the + configuration + ''' + return desired def _include_change(self, change): ''' @@ -49,18 +62,14 @@ class BaseProvider(BaseSource): return [] def supports_warn_or_except(self, msg): - # TODO: base class param to control warn vs except - if False: - raise ProviderException(msg) + if self.strict_supports: + raise ProviderException('{}: {}'.format(self.id, msg)) self.log.warning(msg) - def process_desired_zone(self, desired): - return desired - def plan(self, desired, processors=[]): self.log.info('plan: desired=%s', desired.name) - desired = self.process_desired_zone(desired) + desired = self._process_desired_zone(desired) existing = Zone(desired.name, desired.sub_zones) exists = self.populate(existing, target=True, lenient=True) diff --git a/octodns/provider/route53.py b/octodns/provider/route53.py index 8841df4..8552a8c 100644 --- a/octodns/provider/route53.py +++ b/octodns/provider/route53.py @@ -925,7 +925,7 @@ class Route53Provider(BaseProvider): return data - def process_desired_zone(self, desired): + def _process_desired_zone(self, desired): ret = Zone(desired.name, desired.sub_zones) for record in desired.records: if getattr(record, 'dynamic', False): @@ -957,7 +957,7 @@ class Route53Provider(BaseProvider): ret.add_record(record) - return super(Route53Provider, self).process_desired_zone(ret) + return super(Route53Provider, self)._process_desired_zone(ret) def populate(self, zone, target=False, lenient=False): self.log.debug('populate: name=%s, target=%s, lenient=%s', zone.name, diff --git a/tests/test_octodns_provider_base.py b/tests/test_octodns_provider_base.py index 4dfce48..b748762 100644 --- a/tests/test_octodns_provider_base.py +++ b/tests/test_octodns_provider_base.py @@ -6,11 +6,12 @@ from __future__ import absolute_import, division, print_function, \ unicode_literals from logging import getLogger +from mock import MagicMock, call from six import text_type from unittest import TestCase from octodns.processor.base import BaseProcessor -from octodns.provider.base import BaseProvider +from octodns.provider.base import BaseProvider, ProviderException from octodns.provider.plan import Plan, UnsafePlan from octodns.record import Create, Delete, Record, Update from octodns.zone import Zone @@ -429,3 +430,27 @@ class TestBaseProvider(TestCase): delete_pcent_threshold=safe_pcent).raise_if_unsafe() self.assertTrue('Too many deletes' in text_type(ctx.exception)) + + def test_supports_warn_or_except(self): + class MinimalProvider(BaseProvider): + SUPPORTS = set() + SUPPORTS_GEO = False + + def __init__(self, **kwargs): + self.log = MagicMock() + super(MinimalProvider, self).__init__('minimal', **kwargs) + + normal = MinimalProvider(strict_supports=False) + # Should log and not expect + normal.supports_warn_or_except('Hello World!') + normal.log.warning.assert_called_once() + normal.log.warning.assert_has_calls([ + call('Hello World!') + ]) + + strict = MinimalProvider(strict_supports=True) + # Should log and not expect + with self.assertRaises(ProviderException) as ctx: + strict.supports_warn_or_except('Hello World!') + self.assertEquals('minimal: Hello World!', text_type(ctx.exception)) + strict.log.warning.assert_not_called() diff --git a/tests/test_octodns_provider_route53.py b/tests/test_octodns_provider_route53.py index 34124ee..b3e5ba4 100644 --- a/tests/test_octodns_provider_route53.py +++ b/tests/test_octodns_provider_route53.py @@ -399,7 +399,7 @@ class TestRoute53Provider(TestCase): # No records, essentially a no-op desired = Zone('unit.tests.', []) - got = provider.process_desired_zone(desired) + got = provider._process_desired_zone(desired) self.assertEquals(desired.records, got.records) # Record without any geos @@ -422,7 +422,7 @@ class TestRoute53Provider(TestCase): }, }) desired.add_record(record) - got = provider.process_desired_zone(desired) + got = provider._process_desired_zone(desired) self.assertEquals(desired.records, got.records) self.assertEquals(1, len(list(got.records)[0].dynamic.rules)) self.assertFalse('geos' in list(got.records)[0].dynamic.rules[0].data) @@ -455,7 +455,7 @@ class TestRoute53Provider(TestCase): }, }) desired.add_record(record) - got = provider.process_desired_zone(desired) + got = provider._process_desired_zone(desired) self.assertEquals(2, len(list(got.records)[0].dynamic.rules)) self.assertEquals(['EU', 'NA-US-OR'], list(got.records)[0].dynamic.rules[0].data['geos']) @@ -489,7 +489,7 @@ class TestRoute53Provider(TestCase): }, }) desired.add_record(record) - got = provider.process_desired_zone(desired) + got = provider._process_desired_zone(desired) self.assertEquals(1, len(list(got.records)[0].dynamic.rules)) self.assertFalse('geos' in list(got.records)[0].dynamic.rules[0].data) @@ -521,7 +521,7 @@ class TestRoute53Provider(TestCase): }, }) desired.add_record(record) - got = provider.process_desired_zone(desired) + got = provider._process_desired_zone(desired) self.assertEquals(2, len(list(got.records)[0].dynamic.rules)) self.assertEquals(['EU', 'NA-US-OR'], list(got.records)[0].dynamic.rules[0].data['geos'])