diff --git a/octodns/manager.py b/octodns/manager.py index 2d96c96..c631104 100644 --- a/octodns/manager.py +++ b/octodns/manager.py @@ -6,7 +6,7 @@ from __future__ import absolute_import, division, print_function, \ unicode_literals from StringIO import StringIO -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor from importlib import import_module from os import environ import logging @@ -37,10 +37,21 @@ class _AggregateTarget(object): return True +class MainThreadExecutor(object): + + def submit(self, func, *args, **kwargs): + future = Future() + try: + future.set_result(func(*args, **kwargs)) + except Exception as e: + future.set_exception(e) + return future + + class Manager(object): log = logging.getLogger('Manager') - def __init__(self, config_file): + def __init__(self, config_file, max_workers=None): self.log.info('__init__: config_file=%s', config_file) # Read our config file @@ -48,8 +59,12 @@ class Manager(object): self.config = safe_load(fh, enforce_order=False) manager_config = self.config.get('manager', {}) - max_workers = manager_config.get('max_workers', 4) - self._executor = ThreadPoolExecutor(max_workers=max_workers) + max_workers = manager_config.get('max_workers', 1) \ + if max_workers is None else max_workers + if max_workers > 1: + self._executor = ThreadPoolExecutor(max_workers=max_workers) + else: + self._executor = MainThreadExecutor() self.log.debug('__init__: configuring providers') self.providers = {} diff --git a/tests/test_octodns_manager.py b/tests/test_octodns_manager.py index a6c4bce..f81d5bf 100644 --- a/tests/test_octodns_manager.py +++ b/tests/test_octodns_manager.py @@ -115,6 +115,11 @@ class TestManager(TestCase): .sync(dry_run=False, force=True) self.assertEquals(19, tc) + # Again with max_workers = 1 + tc = Manager(get_config_filename('simple.yaml'), max_workers=1) \ + .sync(dry_run=False, force=True) + self.assertEquals(19, tc) + def test_eligible_targets(self): with TemporaryDirectory() as tmpdir: environ['YAML_TMP_DIR'] = tmpdir.dirname