Browse Source

Merge pull request #12 from github/multi-threaded-planning

Multi threaded planning
pull/13/head
Ross McFarland 9 years ago
committed by GitHub
parent
commit
757bfad87e
6 changed files with 122 additions and 39 deletions
  1. +1
    -3
      octodns/cmds/args.py
  2. +52
    -12
      octodns/manager.py
  3. +17
    -13
      octodns/provider/dyn.py
  4. +9
    -10
      octodns/yaml.py
  5. +2
    -0
      tests/config/simple.yaml
  6. +41
    -1
      tests/test_octodns_manager.py

+ 1
- 3
octodns/cmds/args.py View File

@ -43,9 +43,7 @@ class ArgumentParser(_Base):
return args
def _setup_logging(self, args, default_log_level):
# TODO: if/when things are multi-threaded add [%(thread)d] in to the
# format
fmt = '%(asctime)s %(levelname)-5s %(name)s %(message)s'
fmt = '%(asctime)s [%(thread)d] %(levelname)-5s %(name)s %(message)s'
formatter = Formatter(fmt=fmt, datefmt='%Y-%m-%dT%H:%M:%S ')
stream = stdout if args.log_stream_stdout else stderr
handler = StreamHandler(stream=stream)


+ 52
- 12
octodns/manager.py View File

@ -6,6 +6,7 @@ from __future__ import absolute_import, division, print_function, \
unicode_literals
from StringIO import StringIO
from concurrent.futures import Future, ThreadPoolExecutor
from importlib import import_module
from os import environ
import logging
@ -36,16 +37,42 @@ class _AggregateTarget(object):
return True
class MainThreadExecutor(object):
'''
Dummy executor that runs things on the main thread during the involcation
of submit, but still returns a future object with the result. This allows
code to be written to handle async, even in the case where we don't want to
use multiple threads/workers and would prefer that things flow as if
traditionally written.
'''
def submit(self, func, *args, **kwargs):
future = Future()
try:
future.set_result(func(*args, **kwargs))
except Exception as e:
future.set_exception(e)
return future
class Manager(object):
log = logging.getLogger('Manager')
def __init__(self, config_file):
def __init__(self, config_file, max_workers=None):
self.log.info('__init__: config_file=%s', config_file)
# Read our config file
with open(config_file, 'r') as fh:
self.config = safe_load(fh, enforce_order=False)
manager_config = self.config.get('manager', {})
max_workers = manager_config.get('max_workers', 1) \
if max_workers is None else max_workers
if max_workers > 1:
self._executor = ThreadPoolExecutor(max_workers=max_workers)
else:
self._executor = MainThreadExecutor()
self.log.debug('__init__: configuring providers')
self.providers = {}
for provider_name, provider_config in self.config['providers'].items():
@ -135,6 +162,24 @@ class Manager(object):
self.log.debug('configured_sub_zones: subs=%s', sub_zone_names)
return set(sub_zone_names)
def _populate_and_plan(self, zone_name, sources, targets):
self.log.debug('sync: populating, zone=%s', zone_name)
zone = Zone(zone_name,
sub_zones=self.configured_sub_zones(zone_name))
for source in sources:
source.populate(zone)
self.log.debug('sync: planning, zone=%s', zone_name)
plans = []
for target in targets:
plan = target.plan(zone)
if plan:
plans.append((target, plan))
return plans
def sync(self, eligible_zones=[], eligible_targets=[], dry_run=True,
force=False):
self.log.info('sync: eligible_zones=%s, eligible_targets=%s, '
@ -145,7 +190,7 @@ class Manager(object):
if eligible_zones:
zones = filter(lambda d: d[0] in eligible_zones, zones)
plans = []
futures = []
for zone_name, config in zones:
self.log.info('sync: zone=%s', zone_name)
try:
@ -181,17 +226,12 @@ class Manager(object):
raise Exception('Zone {}, unknown target: {}'.format(zone_name,
target))
self.log.debug('sync: populating')
zone = Zone(zone_name,
sub_zones=self.configured_sub_zones(zone_name))
for source in sources:
source.populate(zone)
futures.append(self._executor.submit(self._populate_and_plan,
zone_name, sources, targets))
self.log.debug('sync: planning')
for target in targets:
plan = target.plan(zone)
if plan:
plans.append((target, plan))
# Wait on all results and unpack/flatten them in to a list of target &
# plan pairs.
plans = [p for f in futures for p in f.result()]
hr = '*************************************************************' \
'*******************\n'


+ 17
- 13
octodns/provider/dyn.py View File

@ -14,6 +14,7 @@ from dyn.tm.services.dsf import DSFARecord, DSFAAAARecord, DSFFailoverChain, \
from dyn.tm.session import DynectSession
from dyn.tm.zones import Zone as DynZone
from logging import getLogger
from threading import local
from uuid import uuid4
from ..record import Record
@ -134,8 +135,7 @@ class DynProvider(BaseProvider):
'AN': 17, # Continental Antartica
}
# Going to be lazy loaded b/c it makes a (slow) request, global
_dyn_sess = None
_thread_local = local()
def __init__(self, id, customer, username, password,
traffic_directors_enabled=False, *args, **kwargs):
@ -159,17 +159,21 @@ class DynProvider(BaseProvider):
return self.traffic_directors_enabled
def _check_dyn_sess(self):
if self._dyn_sess:
self.log.debug('_check_dyn_sess: exists')
return
self.log.debug('_check_dyn_sess: creating')
# Dynect's client is ugly, you create a session object, but then don't
# use it for anything. It just makes the other objects work behind the
# scences. :-( That probably means we can only support a single set of
# dynect creds
self._dyn_sess = DynectSession(self.customer, self.username,
self.password)
try:
DynProvider._thread_local.dyn_sess
except AttributeError:
self.log.debug('_check_dyn_sess: creating')
# Dynect's client is odd, you create a session object, but don't
# use it for anything. It just makes the other objects work behind
# the scences. :-( That probably means we can only support a single
# set of dynect creds, so no split accounts. They're also per
# thread so we need to create one per thread. I originally tried
# calling DynectSession.get_session to see if there was one and
# creating if not, but that was always returning None, so now I'm
# manually creating them once per-thread. I'd imagine this could be
# figured out, but ...
DynectSession(self.customer, self.username, self.password)
DynProvider._thread_local.dyn_sess = True
def _data_for_A(self, _type, records):
return {


+ 9
- 10
octodns/yaml.py View File

@ -30,11 +30,7 @@ def _zero_padded_numbers(s):
# here
class SortEnforcingLoader(SafeLoader):
def __init__(self, *args, **kwargs):
super(SortEnforcingLoader, self).__init__(*args, **kwargs)
self.add_constructor(self.DEFAULT_MAPPING_TAG, self._construct)
def _construct(self, _, node):
def _construct(self, node):
self.flatten_mapping(node)
ret = self.construct_pairs(node)
keys = [d[0] for d in ret]
@ -44,6 +40,10 @@ class SortEnforcingLoader(SafeLoader):
return dict(ret)
SortEnforcingLoader.add_constructor(SortEnforcingLoader.DEFAULT_MAPPING_TAG,
SortEnforcingLoader._construct)
def safe_load(stream, enforce_order=True):
return load(stream, SortEnforcingLoader if enforce_order else SafeLoader)
@ -57,16 +57,15 @@ class SortingDumper(SafeDumper):
more info
'''
def __init__(self, *args, **kwargs):
super(SortingDumper, self).__init__(*args, **kwargs)
self.add_representer(dict, self._representer)
def _representer(self, _, data):
def _representer(self, data):
data = data.items()
data.sort(key=lambda d: _zero_padded_numbers(d[0]))
return self.represent_mapping(self.DEFAULT_MAPPING_TAG, data)
SortingDumper.add_representer(dict, SortingDumper._representer)
def safe_dump(data, fh, **options):
kwargs = {
'canonical': False,


+ 2
- 0
tests/config/simple.yaml View File

@ -1,3 +1,5 @@
manager:
max_workers: 2
providers:
in:
class: octodns.provider.yaml.YamlProvider


+ 41
- 1
tests/test_octodns_manager.py View File

@ -10,7 +10,7 @@ from os.path import dirname, join
from unittest import TestCase
from octodns.record import Record
from octodns.manager import _AggregateTarget, Manager
from octodns.manager import _AggregateTarget, MainThreadExecutor, Manager
from octodns.zone import Zone
from helpers import GeoProvider, NoSshFpProvider, SimpleProvider, \
@ -115,6 +115,11 @@ class TestManager(TestCase):
.sync(dry_run=False, force=True)
self.assertEquals(19, tc)
# Again with max_workers = 1
tc = Manager(get_config_filename('simple.yaml'), max_workers=1) \
.sync(dry_run=False, force=True)
self.assertEquals(19, tc)
def test_eligible_targets(self):
with TemporaryDirectory() as tmpdir:
environ['YAML_TMP_DIR'] = tmpdir.dirname
@ -128,6 +133,9 @@ class TestManager(TestCase):
environ['YAML_TMP_DIR'] = tmpdir.dirname
manager = Manager(get_config_filename('simple.yaml'))
# make sure this was pulled in from the config
self.assertEquals(2, manager._executor._max_workers)
changes = manager.compare(['in'], ['in'], 'unit.tests.')
self.assertEquals([], changes)
@ -201,3 +209,35 @@ class TestManager(TestCase):
Manager(get_config_filename('unknown-provider.yaml')) \
.validate_configs()
self.assertTrue('unknown source' in ctx.exception.message)
class TestMainThreadExecutor(TestCase):
def test_success(self):
mte = MainThreadExecutor()
future = mte.submit(self.success, 42)
self.assertEquals(42, future.result())
future = mte.submit(self.success, ret=43)
self.assertEquals(43, future.result())
def test_exception(self):
mte = MainThreadExecutor()
e = Exception('boom')
future = mte.submit(self.exception, e)
with self.assertRaises(Exception) as ctx:
future.result()
self.assertEquals(e, ctx.exception)
future = mte.submit(self.exception, e=e)
with self.assertRaises(Exception) as ctx:
future.result()
self.assertEquals(e, ctx.exception)
def success(self, ret):
return ret
def exception(self, e):
raise e

Loading…
Cancel
Save