Browse Source

Merge pull request #12 from github/multi-threaded-planning

Multi threaded planning
pull/13/head
Ross McFarland 9 years ago
committed by GitHub
parent
commit
757bfad87e
6 changed files with 122 additions and 39 deletions
  1. +1
    -3
      octodns/cmds/args.py
  2. +52
    -12
      octodns/manager.py
  3. +17
    -13
      octodns/provider/dyn.py
  4. +9
    -10
      octodns/yaml.py
  5. +2
    -0
      tests/config/simple.yaml
  6. +41
    -1
      tests/test_octodns_manager.py

+ 1
- 3
octodns/cmds/args.py View File

@ -43,9 +43,7 @@ class ArgumentParser(_Base):
return args return args
def _setup_logging(self, args, default_log_level): def _setup_logging(self, args, default_log_level):
# TODO: if/when things are multi-threaded add [%(thread)d] in to the
# format
fmt = '%(asctime)s %(levelname)-5s %(name)s %(message)s'
fmt = '%(asctime)s [%(thread)d] %(levelname)-5s %(name)s %(message)s'
formatter = Formatter(fmt=fmt, datefmt='%Y-%m-%dT%H:%M:%S ') formatter = Formatter(fmt=fmt, datefmt='%Y-%m-%dT%H:%M:%S ')
stream = stdout if args.log_stream_stdout else stderr stream = stdout if args.log_stream_stdout else stderr
handler = StreamHandler(stream=stream) handler = StreamHandler(stream=stream)


+ 52
- 12
octodns/manager.py View File

@ -6,6 +6,7 @@ from __future__ import absolute_import, division, print_function, \
unicode_literals unicode_literals
from StringIO import StringIO from StringIO import StringIO
from concurrent.futures import Future, ThreadPoolExecutor
from importlib import import_module from importlib import import_module
from os import environ from os import environ
import logging import logging
@ -36,16 +37,42 @@ class _AggregateTarget(object):
return True return True
class MainThreadExecutor(object):
'''
Dummy executor that runs things on the main thread during the involcation
of submit, but still returns a future object with the result. This allows
code to be written to handle async, even in the case where we don't want to
use multiple threads/workers and would prefer that things flow as if
traditionally written.
'''
def submit(self, func, *args, **kwargs):
future = Future()
try:
future.set_result(func(*args, **kwargs))
except Exception as e:
future.set_exception(e)
return future
class Manager(object): class Manager(object):
log = logging.getLogger('Manager') log = logging.getLogger('Manager')
def __init__(self, config_file):
def __init__(self, config_file, max_workers=None):
self.log.info('__init__: config_file=%s', config_file) self.log.info('__init__: config_file=%s', config_file)
# Read our config file # Read our config file
with open(config_file, 'r') as fh: with open(config_file, 'r') as fh:
self.config = safe_load(fh, enforce_order=False) self.config = safe_load(fh, enforce_order=False)
manager_config = self.config.get('manager', {})
max_workers = manager_config.get('max_workers', 1) \
if max_workers is None else max_workers
if max_workers > 1:
self._executor = ThreadPoolExecutor(max_workers=max_workers)
else:
self._executor = MainThreadExecutor()
self.log.debug('__init__: configuring providers') self.log.debug('__init__: configuring providers')
self.providers = {} self.providers = {}
for provider_name, provider_config in self.config['providers'].items(): for provider_name, provider_config in self.config['providers'].items():
@ -135,6 +162,24 @@ class Manager(object):
self.log.debug('configured_sub_zones: subs=%s', sub_zone_names) self.log.debug('configured_sub_zones: subs=%s', sub_zone_names)
return set(sub_zone_names) return set(sub_zone_names)
def _populate_and_plan(self, zone_name, sources, targets):
self.log.debug('sync: populating, zone=%s', zone_name)
zone = Zone(zone_name,
sub_zones=self.configured_sub_zones(zone_name))
for source in sources:
source.populate(zone)
self.log.debug('sync: planning, zone=%s', zone_name)
plans = []
for target in targets:
plan = target.plan(zone)
if plan:
plans.append((target, plan))
return plans
def sync(self, eligible_zones=[], eligible_targets=[], dry_run=True, def sync(self, eligible_zones=[], eligible_targets=[], dry_run=True,
force=False): force=False):
self.log.info('sync: eligible_zones=%s, eligible_targets=%s, ' self.log.info('sync: eligible_zones=%s, eligible_targets=%s, '
@ -145,7 +190,7 @@ class Manager(object):
if eligible_zones: if eligible_zones:
zones = filter(lambda d: d[0] in eligible_zones, zones) zones = filter(lambda d: d[0] in eligible_zones, zones)
plans = []
futures = []
for zone_name, config in zones: for zone_name, config in zones:
self.log.info('sync: zone=%s', zone_name) self.log.info('sync: zone=%s', zone_name)
try: try:
@ -181,17 +226,12 @@ class Manager(object):
raise Exception('Zone {}, unknown target: {}'.format(zone_name, raise Exception('Zone {}, unknown target: {}'.format(zone_name,
target)) target))
self.log.debug('sync: populating')
zone = Zone(zone_name,
sub_zones=self.configured_sub_zones(zone_name))
for source in sources:
source.populate(zone)
futures.append(self._executor.submit(self._populate_and_plan,
zone_name, sources, targets))
self.log.debug('sync: planning')
for target in targets:
plan = target.plan(zone)
if plan:
plans.append((target, plan))
# Wait on all results and unpack/flatten them in to a list of target &
# plan pairs.
plans = [p for f in futures for p in f.result()]
hr = '*************************************************************' \ hr = '*************************************************************' \
'*******************\n' '*******************\n'


+ 17
- 13
octodns/provider/dyn.py View File

@ -14,6 +14,7 @@ from dyn.tm.services.dsf import DSFARecord, DSFAAAARecord, DSFFailoverChain, \
from dyn.tm.session import DynectSession from dyn.tm.session import DynectSession
from dyn.tm.zones import Zone as DynZone from dyn.tm.zones import Zone as DynZone
from logging import getLogger from logging import getLogger
from threading import local
from uuid import uuid4 from uuid import uuid4
from ..record import Record from ..record import Record
@ -134,8 +135,7 @@ class DynProvider(BaseProvider):
'AN': 17, # Continental Antartica 'AN': 17, # Continental Antartica
} }
# Going to be lazy loaded b/c it makes a (slow) request, global
_dyn_sess = None
_thread_local = local()
def __init__(self, id, customer, username, password, def __init__(self, id, customer, username, password,
traffic_directors_enabled=False, *args, **kwargs): traffic_directors_enabled=False, *args, **kwargs):
@ -159,17 +159,21 @@ class DynProvider(BaseProvider):
return self.traffic_directors_enabled return self.traffic_directors_enabled
def _check_dyn_sess(self): def _check_dyn_sess(self):
if self._dyn_sess:
self.log.debug('_check_dyn_sess: exists')
return
self.log.debug('_check_dyn_sess: creating')
# Dynect's client is ugly, you create a session object, but then don't
# use it for anything. It just makes the other objects work behind the
# scences. :-( That probably means we can only support a single set of
# dynect creds
self._dyn_sess = DynectSession(self.customer, self.username,
self.password)
try:
DynProvider._thread_local.dyn_sess
except AttributeError:
self.log.debug('_check_dyn_sess: creating')
# Dynect's client is odd, you create a session object, but don't
# use it for anything. It just makes the other objects work behind
# the scences. :-( That probably means we can only support a single
# set of dynect creds, so no split accounts. They're also per
# thread so we need to create one per thread. I originally tried
# calling DynectSession.get_session to see if there was one and
# creating if not, but that was always returning None, so now I'm
# manually creating them once per-thread. I'd imagine this could be
# figured out, but ...
DynectSession(self.customer, self.username, self.password)
DynProvider._thread_local.dyn_sess = True
def _data_for_A(self, _type, records): def _data_for_A(self, _type, records):
return { return {


+ 9
- 10
octodns/yaml.py View File

@ -30,11 +30,7 @@ def _zero_padded_numbers(s):
# here # here
class SortEnforcingLoader(SafeLoader): class SortEnforcingLoader(SafeLoader):
def __init__(self, *args, **kwargs):
super(SortEnforcingLoader, self).__init__(*args, **kwargs)
self.add_constructor(self.DEFAULT_MAPPING_TAG, self._construct)
def _construct(self, _, node):
def _construct(self, node):
self.flatten_mapping(node) self.flatten_mapping(node)
ret = self.construct_pairs(node) ret = self.construct_pairs(node)
keys = [d[0] for d in ret] keys = [d[0] for d in ret]
@ -44,6 +40,10 @@ class SortEnforcingLoader(SafeLoader):
return dict(ret) return dict(ret)
SortEnforcingLoader.add_constructor(SortEnforcingLoader.DEFAULT_MAPPING_TAG,
SortEnforcingLoader._construct)
def safe_load(stream, enforce_order=True): def safe_load(stream, enforce_order=True):
return load(stream, SortEnforcingLoader if enforce_order else SafeLoader) return load(stream, SortEnforcingLoader if enforce_order else SafeLoader)
@ -57,16 +57,15 @@ class SortingDumper(SafeDumper):
more info more info
''' '''
def __init__(self, *args, **kwargs):
super(SortingDumper, self).__init__(*args, **kwargs)
self.add_representer(dict, self._representer)
def _representer(self, _, data):
def _representer(self, data):
data = data.items() data = data.items()
data.sort(key=lambda d: _zero_padded_numbers(d[0])) data.sort(key=lambda d: _zero_padded_numbers(d[0]))
return self.represent_mapping(self.DEFAULT_MAPPING_TAG, data) return self.represent_mapping(self.DEFAULT_MAPPING_TAG, data)
SortingDumper.add_representer(dict, SortingDumper._representer)
def safe_dump(data, fh, **options): def safe_dump(data, fh, **options):
kwargs = { kwargs = {
'canonical': False, 'canonical': False,


+ 2
- 0
tests/config/simple.yaml View File

@ -1,3 +1,5 @@
manager:
max_workers: 2
providers: providers:
in: in:
class: octodns.provider.yaml.YamlProvider class: octodns.provider.yaml.YamlProvider


+ 41
- 1
tests/test_octodns_manager.py View File

@ -10,7 +10,7 @@ from os.path import dirname, join
from unittest import TestCase from unittest import TestCase
from octodns.record import Record from octodns.record import Record
from octodns.manager import _AggregateTarget, Manager
from octodns.manager import _AggregateTarget, MainThreadExecutor, Manager
from octodns.zone import Zone from octodns.zone import Zone
from helpers import GeoProvider, NoSshFpProvider, SimpleProvider, \ from helpers import GeoProvider, NoSshFpProvider, SimpleProvider, \
@ -115,6 +115,11 @@ class TestManager(TestCase):
.sync(dry_run=False, force=True) .sync(dry_run=False, force=True)
self.assertEquals(19, tc) self.assertEquals(19, tc)
# Again with max_workers = 1
tc = Manager(get_config_filename('simple.yaml'), max_workers=1) \
.sync(dry_run=False, force=True)
self.assertEquals(19, tc)
def test_eligible_targets(self): def test_eligible_targets(self):
with TemporaryDirectory() as tmpdir: with TemporaryDirectory() as tmpdir:
environ['YAML_TMP_DIR'] = tmpdir.dirname environ['YAML_TMP_DIR'] = tmpdir.dirname
@ -128,6 +133,9 @@ class TestManager(TestCase):
environ['YAML_TMP_DIR'] = tmpdir.dirname environ['YAML_TMP_DIR'] = tmpdir.dirname
manager = Manager(get_config_filename('simple.yaml')) manager = Manager(get_config_filename('simple.yaml'))
# make sure this was pulled in from the config
self.assertEquals(2, manager._executor._max_workers)
changes = manager.compare(['in'], ['in'], 'unit.tests.') changes = manager.compare(['in'], ['in'], 'unit.tests.')
self.assertEquals([], changes) self.assertEquals([], changes)
@ -201,3 +209,35 @@ class TestManager(TestCase):
Manager(get_config_filename('unknown-provider.yaml')) \ Manager(get_config_filename('unknown-provider.yaml')) \
.validate_configs() .validate_configs()
self.assertTrue('unknown source' in ctx.exception.message) self.assertTrue('unknown source' in ctx.exception.message)
class TestMainThreadExecutor(TestCase):
def test_success(self):
mte = MainThreadExecutor()
future = mte.submit(self.success, 42)
self.assertEquals(42, future.result())
future = mte.submit(self.success, ret=43)
self.assertEquals(43, future.result())
def test_exception(self):
mte = MainThreadExecutor()
e = Exception('boom')
future = mte.submit(self.exception, e)
with self.assertRaises(Exception) as ctx:
future.result()
self.assertEquals(e, ctx.exception)
future = mte.submit(self.exception, e=e)
with self.assertRaises(Exception) as ctx:
future.result()
self.assertEquals(e, ctx.exception)
def success(self, ret):
return ret
def exception(self, e):
raise e

Loading…
Cancel
Save