diff --git a/octodns/provider/ns1.py b/octodns/provider/ns1.py index cf78241..0383dbf 100644 --- a/octodns/provider/ns1.py +++ b/octodns/provider/ns1.py @@ -19,6 +19,50 @@ from ..record import Record from .base import BaseProvider +class Ns1Client(object): + log = getLogger('NS1Client') + + def __init__(self, api_key, retry_count=4): + self.retry_count = retry_count + + client = NS1(apiKey=api_key) + self._records = client.records() + self._zones = client.zones() + + def _try(self, method, *args, **kwargs): + tries = self.retry_count + while True: # We'll raise to break after our tries expire + try: + return method(*args, **kwargs) + except RateLimitException as e: + if tries <= 1: + raise + period = float(e.period) + self.log.warn('rate limit encountered, pausing ' + 'for %ds and trying again, %d remaining', + period, tries) + sleep(period) + tries -= 1 + + def zones_retrieve(self, name): + return self._try(self._zones.retrieve, name) + + def zones_create(self, name): + return self._try(self._zones.create, name) + + def records_retrieve(self, zone, domain, _type): + return self._try(self._records.retrieve, zone, domain, _type) + + def records_create(self, zone, domain, _type, **params): + return self._try(self._records.create, zone, domain, _type, **params) + + def records_update(self, zone, domain, _type, **params): + return self._try(self._records.update, zone, domain, _type, **params) + + def records_delete(self, zone, domain, _type): + return self._try(self._records.delete, zone, domain, _type) + + class Ns1Provider(BaseProvider): ''' Ns1 provider @@ -34,13 +78,12 @@ class Ns1Provider(BaseProvider): ZONE_NOT_FOUND_MESSAGE = 'server error: zone not found' - def __init__(self, id, api_key, *args, **kwargs): + def __init__(self, id, api_key, retry_count=4, *args, **kwargs): self.log = getLogger('Ns1Provider[{}]'.format(id)) - self.log.debug('__init__: id=%s, api_key=***', id) + self.log.debug('__init__: id=%s, api_key=***, retry_count=%d', id, + retry_count) super(Ns1Provider, self).__init__(id, *args, **kwargs) - client = NS1(apiKey=api_key) - self._records = client.records() - self._zones = client.zones() + self._client = Ns1Client(api_key, retry_count) def _data_for_A(self, _type, record): # record meta (which would include geo information is only @@ -192,7 +235,7 @@ class Ns1Provider(BaseProvider): try: ns1_zone_name = zone.name[:-1] - ns1_zone = self._zones.retrieve(ns1_zone_name) + ns1_zone = self._client.zones_retrieve(ns1_zone_name) records = [] geo_records = [] @@ -207,9 +250,9 @@ class Ns1Provider(BaseProvider): if record.get('tier', 1) > 1: # Need to get the full record data for geo records - record = self._records.retrieve(ns1_zone_name, - record['domain'], - record['type']) + record = self._client.records_retrieve(ns1_zone_name, + record['domain'], + record['type']) geo_records.append(record) else: records.append(record) @@ -318,14 +361,7 @@ class Ns1Provider(BaseProvider): domain = new.fqdn[:-1] _type = new._type params = getattr(self, '_params_for_{}'.format(_type))(new) - try: - self._records.create(zone, domain, _type, **params) - except RateLimitException as e: - period = float(e.period) - self.log.warn('_apply_Create: rate limit encountered, pausing ' - 'for %ds and trying again', period) - sleep(period) - self._records.create(zone, domain, _type, **params) + self._client.records_create(zone, domain, _type, **params) def _apply_Update(self, ns1_zone, change): new = change.new @@ -333,28 +369,14 @@ class Ns1Provider(BaseProvider): domain = new.fqdn[:-1] _type = new._type params = getattr(self, '_params_for_{}'.format(_type))(new) - try: - self._records.update(zone, domain, _type, **params) - except RateLimitException as e: - period = float(e.period) - self.log.warn('_apply_Update: rate limit encountered, pausing ' - 'for %ds and trying again', period) - sleep(period) - self._records.update(zone, domain, _type, **params) + self._client.records_update(zone, domain, _type, **params) def _apply_Delete(self, ns1_zone, change): existing = change.existing zone = existing.zone.name[:-1] domain = existing.fqdn[:-1] _type = existing._type - try: - self._records.delete(zone, domain, _type) - except RateLimitException as e: - period = float(e.period) - self.log.warn('_apply_Delete: rate limit encountered, pausing ' - 'for %ds and trying again', period) - sleep(period) - self._records.delete(zone, domain, _type) + self._client.records_delete(zone, domain, _type) def _apply(self, plan): desired = plan.desired @@ -364,12 +386,12 @@ class Ns1Provider(BaseProvider): domain_name = desired.name[:-1] try: - ns1_zone = self._zones.retrieve(domain_name) + ns1_zone = self._client.zones_retrieve(domain_name) except ResourceException as e: if e.message != self.ZONE_NOT_FOUND_MESSAGE: raise self.log.debug('_apply: no matching zone, creating') - ns1_zone = self._zones.create(domain_name) + ns1_zone = self._client.zones_create(domain_name) for change in changes: class_name = change.__class__.__name__ diff --git a/tests/test_octodns_provider_ns1.py b/tests/test_octodns_provider_ns1.py index 0f23222..0743943 100644 --- a/tests/test_octodns_provider_ns1.py +++ b/tests/test_octodns_provider_ns1.py @@ -9,10 +9,11 @@ from collections import defaultdict from mock import call, patch from ns1.rest.errors import AuthException, RateLimitException, \ ResourceException +from six import text_type from unittest import TestCase from octodns.record import Delete, Record, Update -from octodns.provider.ns1 import Ns1Provider +from octodns.provider.ns1 import Ns1Client, Ns1Provider from octodns.zone import Zone @@ -497,3 +498,46 @@ class TestNs1Provider(TestCase): } self.assertEqual(b_expected, provider._data_for_CNAME(b_record['type'], b_record)) + + +class TestNs1Client(TestCase): + + @patch('ns1.rest.zones.Zones.retrieve') + def test_retry_behavior(self, zone_retrieve_mock): + client = Ns1Client('dummy-key') + + # No retry required, just calls and is returned + zone_retrieve_mock.reset_mock() + zone_retrieve_mock.side_effect = ['foo'] + self.assertEquals('foo', client.zones_retrieve('unit.tests')) + zone_retrieve_mock.assert_has_calls([call('unit.tests')]) + + # One retry required + zone_retrieve_mock.reset_mock() + zone_retrieve_mock.side_effect = [ + RateLimitException('boo', period=0), + 'foo' + ] + self.assertEquals('foo', client.zones_retrieve('unit.tests')) + zone_retrieve_mock.assert_has_calls([call('unit.tests')]) + + # Two retries required + zone_retrieve_mock.reset_mock() + zone_retrieve_mock.side_effect = [ + RateLimitException('boo', period=0), + 'foo' + ] + self.assertEquals('foo', client.zones_retrieve('unit.tests')) + zone_retrieve_mock.assert_has_calls([call('unit.tests')]) + + # Exhaust our retries + zone_retrieve_mock.reset_mock() + zone_retrieve_mock.side_effect = [ + RateLimitException('first', period=0), + RateLimitException('boo', period=0), + RateLimitException('boo', period=0), + RateLimitException('last', period=0), + ] + with self.assertRaises(RateLimitException) as ctx: + client.zones_retrieve('unit.tests') + self.assertEquals('last', text_type(ctx.exception))