diff --git a/octodns/cmds/args.py b/octodns/cmds/args.py index daec5c9..c84dd92 100644 --- a/octodns/cmds/args.py +++ b/octodns/cmds/args.py @@ -43,9 +43,7 @@ class ArgumentParser(_Base): return args def _setup_logging(self, args, default_log_level): - # TODO: if/when things are multi-threaded add [%(thread)d] in to the - # format - fmt = '%(asctime)s %(levelname)-5s %(name)s %(message)s' + fmt = '%(asctime)s [%(thread)d] %(levelname)-5s %(name)s %(message)s' formatter = Formatter(fmt=fmt, datefmt='%Y-%m-%dT%H:%M:%S ') stream = stdout if args.log_stream_stdout else stderr handler = StreamHandler(stream=stream) diff --git a/octodns/manager.py b/octodns/manager.py index 86b7f24..8d8569a 100644 --- a/octodns/manager.py +++ b/octodns/manager.py @@ -6,6 +6,7 @@ from __future__ import absolute_import, division, print_function, \ unicode_literals from StringIO import StringIO +from concurrent.futures import Future, ThreadPoolExecutor from importlib import import_module from os import environ import logging @@ -36,16 +37,42 @@ class _AggregateTarget(object): return True +class MainThreadExecutor(object): + ''' + Dummy executor that runs things on the main thread during the involcation + of submit, but still returns a future object with the result. This allows + code to be written to handle async, even in the case where we don't want to + use multiple threads/workers and would prefer that things flow as if + traditionally written. + ''' + + def submit(self, func, *args, **kwargs): + future = Future() + try: + future.set_result(func(*args, **kwargs)) + except Exception as e: + future.set_exception(e) + return future + + class Manager(object): log = logging.getLogger('Manager') - def __init__(self, config_file): + def __init__(self, config_file, max_workers=None): self.log.info('__init__: config_file=%s', config_file) # Read our config file with open(config_file, 'r') as fh: self.config = safe_load(fh, enforce_order=False) + manager_config = self.config.get('manager', {}) + max_workers = manager_config.get('max_workers', 1) \ + if max_workers is None else max_workers + if max_workers > 1: + self._executor = ThreadPoolExecutor(max_workers=max_workers) + else: + self._executor = MainThreadExecutor() + self.log.debug('__init__: configuring providers') self.providers = {} for provider_name, provider_config in self.config['providers'].items(): @@ -135,6 +162,24 @@ class Manager(object): self.log.debug('configured_sub_zones: subs=%s', sub_zone_names) return set(sub_zone_names) + def _populate_and_plan(self, zone_name, sources, targets): + + self.log.debug('sync: populating, zone=%s', zone_name) + zone = Zone(zone_name, + sub_zones=self.configured_sub_zones(zone_name)) + for source in sources: + source.populate(zone) + + self.log.debug('sync: planning, zone=%s', zone_name) + plans = [] + + for target in targets: + plan = target.plan(zone) + if plan: + plans.append((target, plan)) + + return plans + def sync(self, eligible_zones=[], eligible_targets=[], dry_run=True, force=False): self.log.info('sync: eligible_zones=%s, eligible_targets=%s, ' @@ -145,7 +190,7 @@ class Manager(object): if eligible_zones: zones = filter(lambda d: d[0] in eligible_zones, zones) - plans = [] + futures = [] for zone_name, config in zones: self.log.info('sync: zone=%s', zone_name) try: @@ -181,17 +226,12 @@ class Manager(object): raise Exception('Zone {}, unknown target: {}'.format(zone_name, target)) - self.log.debug('sync: populating') - zone = Zone(zone_name, - sub_zones=self.configured_sub_zones(zone_name)) - for source in sources: - source.populate(zone) + futures.append(self._executor.submit(self._populate_and_plan, + zone_name, sources, targets)) - self.log.debug('sync: planning') - for target in targets: - plan = target.plan(zone) - if plan: - plans.append((target, plan)) + # Wait on all results and unpack/flatten them in to a list of target & + # plan pairs. + plans = [p for f in futures for p in f.result()] hr = '*************************************************************' \ '*******************\n' diff --git a/octodns/provider/dyn.py b/octodns/provider/dyn.py index 321bad5..90b2a51 100644 --- a/octodns/provider/dyn.py +++ b/octodns/provider/dyn.py @@ -14,6 +14,7 @@ from dyn.tm.services.dsf import DSFARecord, DSFAAAARecord, DSFFailoverChain, \ from dyn.tm.session import DynectSession from dyn.tm.zones import Zone as DynZone from logging import getLogger +from threading import local from uuid import uuid4 from ..record import Record @@ -134,8 +135,7 @@ class DynProvider(BaseProvider): 'AN': 17, # Continental Antartica } - # Going to be lazy loaded b/c it makes a (slow) request, global - _dyn_sess = None + _thread_local = local() def __init__(self, id, customer, username, password, traffic_directors_enabled=False, *args, **kwargs): @@ -159,17 +159,21 @@ class DynProvider(BaseProvider): return self.traffic_directors_enabled def _check_dyn_sess(self): - if self._dyn_sess: - self.log.debug('_check_dyn_sess: exists') - return - - self.log.debug('_check_dyn_sess: creating') - # Dynect's client is ugly, you create a session object, but then don't - # use it for anything. It just makes the other objects work behind the - # scences. :-( That probably means we can only support a single set of - # dynect creds - self._dyn_sess = DynectSession(self.customer, self.username, - self.password) + try: + DynProvider._thread_local.dyn_sess + except AttributeError: + self.log.debug('_check_dyn_sess: creating') + # Dynect's client is odd, you create a session object, but don't + # use it for anything. It just makes the other objects work behind + # the scences. :-( That probably means we can only support a single + # set of dynect creds, so no split accounts. They're also per + # thread so we need to create one per thread. I originally tried + # calling DynectSession.get_session to see if there was one and + # creating if not, but that was always returning None, so now I'm + # manually creating them once per-thread. I'd imagine this could be + # figured out, but ... + DynectSession(self.customer, self.username, self.password) + DynProvider._thread_local.dyn_sess = True def _data_for_A(self, _type, records): return { diff --git a/octodns/yaml.py b/octodns/yaml.py index b6c6379..2cab58c 100644 --- a/octodns/yaml.py +++ b/octodns/yaml.py @@ -30,11 +30,7 @@ def _zero_padded_numbers(s): # here class SortEnforcingLoader(SafeLoader): - def __init__(self, *args, **kwargs): - super(SortEnforcingLoader, self).__init__(*args, **kwargs) - self.add_constructor(self.DEFAULT_MAPPING_TAG, self._construct) - - def _construct(self, _, node): + def _construct(self, node): self.flatten_mapping(node) ret = self.construct_pairs(node) keys = [d[0] for d in ret] @@ -44,6 +40,10 @@ class SortEnforcingLoader(SafeLoader): return dict(ret) +SortEnforcingLoader.add_constructor(SortEnforcingLoader.DEFAULT_MAPPING_TAG, + SortEnforcingLoader._construct) + + def safe_load(stream, enforce_order=True): return load(stream, SortEnforcingLoader if enforce_order else SafeLoader) @@ -57,16 +57,15 @@ class SortingDumper(SafeDumper): more info ''' - def __init__(self, *args, **kwargs): - super(SortingDumper, self).__init__(*args, **kwargs) - self.add_representer(dict, self._representer) - - def _representer(self, _, data): + def _representer(self, data): data = data.items() data.sort(key=lambda d: _zero_padded_numbers(d[0])) return self.represent_mapping(self.DEFAULT_MAPPING_TAG, data) +SortingDumper.add_representer(dict, SortingDumper._representer) + + def safe_dump(data, fh, **options): kwargs = { 'canonical': False, diff --git a/tests/config/simple.yaml b/tests/config/simple.yaml index 604b772..cf970a9 100644 --- a/tests/config/simple.yaml +++ b/tests/config/simple.yaml @@ -1,3 +1,5 @@ +manager: + max_workers: 2 providers: in: class: octodns.provider.yaml.YamlProvider diff --git a/tests/test_octodns_manager.py b/tests/test_octodns_manager.py index 7ee4fbd..811503a 100644 --- a/tests/test_octodns_manager.py +++ b/tests/test_octodns_manager.py @@ -10,7 +10,7 @@ from os.path import dirname, join from unittest import TestCase from octodns.record import Record -from octodns.manager import _AggregateTarget, Manager +from octodns.manager import _AggregateTarget, MainThreadExecutor, Manager from octodns.zone import Zone from helpers import GeoProvider, NoSshFpProvider, SimpleProvider, \ @@ -115,6 +115,11 @@ class TestManager(TestCase): .sync(dry_run=False, force=True) self.assertEquals(19, tc) + # Again with max_workers = 1 + tc = Manager(get_config_filename('simple.yaml'), max_workers=1) \ + .sync(dry_run=False, force=True) + self.assertEquals(19, tc) + def test_eligible_targets(self): with TemporaryDirectory() as tmpdir: environ['YAML_TMP_DIR'] = tmpdir.dirname @@ -128,6 +133,9 @@ class TestManager(TestCase): environ['YAML_TMP_DIR'] = tmpdir.dirname manager = Manager(get_config_filename('simple.yaml')) + # make sure this was pulled in from the config + self.assertEquals(2, manager._executor._max_workers) + changes = manager.compare(['in'], ['in'], 'unit.tests.') self.assertEquals([], changes) @@ -201,3 +209,35 @@ class TestManager(TestCase): Manager(get_config_filename('unknown-provider.yaml')) \ .validate_configs() self.assertTrue('unknown source' in ctx.exception.message) + + +class TestMainThreadExecutor(TestCase): + + def test_success(self): + mte = MainThreadExecutor() + + future = mte.submit(self.success, 42) + self.assertEquals(42, future.result()) + + future = mte.submit(self.success, ret=43) + self.assertEquals(43, future.result()) + + def test_exception(self): + mte = MainThreadExecutor() + + e = Exception('boom') + future = mte.submit(self.exception, e) + with self.assertRaises(Exception) as ctx: + future.result() + self.assertEquals(e, ctx.exception) + + future = mte.submit(self.exception, e=e) + with self.assertRaises(Exception) as ctx: + future.result() + self.assertEquals(e, ctx.exception) + + def success(self, ret): + return ret + + def exception(self, e): + raise e