|
|
|
@ -5,7 +5,7 @@ |
|
|
|
from os.path import dirname, join |
|
|
|
|
|
|
|
from natsort import natsort_keygen |
|
|
|
from yaml import SafeDumper, SafeLoader, dump, load |
|
|
|
from yaml import SafeDumper, SafeLoader, compose, dump, load |
|
|
|
from yaml.constructor import ConstructorError |
|
|
|
from yaml.representer import SafeRepresenter |
|
|
|
|
|
|
|
@ -18,91 +18,123 @@ _natsort_key = staticmethod(natsort_keygen()) |
|
|
|
|
|
|
|
class ContextLoader(SafeLoader): |
|
|
|
|
|
|
|
def _context(self, node): |
|
|
|
start_mark = node.start_mark |
|
|
|
return f'{start_mark.name}, line {start_mark.line+1}, column {start_mark.column+1}' |
|
|
|
def construct_include(self, node): |
|
|
|
mark = self.get_mark() |
|
|
|
directory = dirname(mark.name) |
|
|
|
|
|
|
|
def _pairs(self, node): |
|
|
|
self.flatten_mapping(node) |
|
|
|
pairs = self.construct_pairs(node) |
|
|
|
context = self._context(node) |
|
|
|
return ContextDict(pairs, context=context), pairs, context |
|
|
|
filename = join(directory, self.construct_scalar(node)) |
|
|
|
|
|
|
|
def _construct(self, node): |
|
|
|
return self._pairs(node)[0] |
|
|
|
with open(filename, 'r') as fh: |
|
|
|
return load(fh, self.__class__) |
|
|
|
|
|
|
|
def include(self, node): |
|
|
|
def flatten_include(self, node): |
|
|
|
mark = self.get_mark() |
|
|
|
directory = dirname(mark.name) |
|
|
|
|
|
|
|
def load_file(filename): |
|
|
|
filename = join(directory, filename) |
|
|
|
with open(filename, 'r') as fh: |
|
|
|
return load(fh, self.__class__) |
|
|
|
|
|
|
|
if not isinstance(node.value, list): |
|
|
|
# single filename, just load and return whatever is in it |
|
|
|
scalar = node.value |
|
|
|
return load_file(scalar) |
|
|
|
|
|
|
|
scalars = node.value |
|
|
|
data = [load_file(s.value) for s in scalars] |
|
|
|
|
|
|
|
if not data: |
|
|
|
return None |
|
|
|
elif isinstance(data[0], list): |
|
|
|
# we're working with lists |
|
|
|
ret = data[0] |
|
|
|
for i, d in enumerate(data[1:]): |
|
|
|
if not isinstance(d, list): |
|
|
|
context = self._context(node) |
|
|
|
raise ConstructorError( |
|
|
|
None, |
|
|
|
None, |
|
|
|
f'!include first element contained a list, element {i+1} contained a {d.__class__.__name__} at {context}', |
|
|
|
) |
|
|
|
ret.extend(d) |
|
|
|
return ret |
|
|
|
elif isinstance(data[0], dict): |
|
|
|
# assume we're working with dict |
|
|
|
ret = data[0] |
|
|
|
for i, d in enumerate(data[1:]): |
|
|
|
if not isinstance(d, dict): |
|
|
|
context = self._context(node) |
|
|
|
filename = join(directory, self.construct_scalar(node)) |
|
|
|
|
|
|
|
with open(filename, 'r') as fh: |
|
|
|
yield compose(fh, self.__class__).value |
|
|
|
|
|
|
|
def construct_mapping(self, node, deep=False): |
|
|
|
''' |
|
|
|
Calls our parent and wraps the resulting dict with a ContextDict |
|
|
|
''' |
|
|
|
start_mark = node.start_mark |
|
|
|
context = f'{start_mark.name}, line {start_mark.line+1}, column {start_mark.column+1}' |
|
|
|
return ContextDict( |
|
|
|
super().construct_mapping(node, deep), context=context |
|
|
|
) |
|
|
|
|
|
|
|
# the following 4 methods are ported out of |
|
|
|
# https://github.com/yaml/pyyaml/pull/894 an intended to be used until we |
|
|
|
# can (hopefully) require a version of pyyaml with that PR merged. |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def add_flattener(cls, tag, flattener): |
|
|
|
if not 'yaml_flatteners' in cls.__dict__: |
|
|
|
cls.yaml_flatteners = {} |
|
|
|
cls.yaml_flatteners[tag] = flattener |
|
|
|
|
|
|
|
# this overwrites/ignores the built-in version of the method |
|
|
|
def flatten_mapping(self, node): # pragma: no cover |
|
|
|
merge = [] |
|
|
|
for key_node, value_node in node.value: |
|
|
|
if key_node.tag == 'tag:yaml.org,2002:merge': |
|
|
|
flattener = self.yaml_flatteners.get(value_node.tag) |
|
|
|
if flattener: |
|
|
|
for value in flattener(self, value_node): |
|
|
|
merge.extend(value) |
|
|
|
else: |
|
|
|
raise ConstructorError( |
|
|
|
None, |
|
|
|
None, |
|
|
|
f'!include first element contained a dict, element {i+1} contained a {d.__class__.__name__} at {context}', |
|
|
|
"while constructing a mapping", |
|
|
|
node.start_mark, |
|
|
|
"expected a mapping or list of mappings for merging, but found %s" |
|
|
|
% value_node.id, |
|
|
|
value_node.start_mark, |
|
|
|
) |
|
|
|
ret.update(d) |
|
|
|
return ret |
|
|
|
|
|
|
|
context = self._context(node) |
|
|
|
raise ConstructorError( |
|
|
|
None, |
|
|
|
None, |
|
|
|
f'!include first element contained an unsupported type, {data[0].__class__.__name__} at {context}', |
|
|
|
) |
|
|
|
elif key_node.tag == 'tag:yaml.org,2002:value': |
|
|
|
key_node.tag = 'tag:yaml.org,2002:str' |
|
|
|
merge.append((key_node, value_node)) |
|
|
|
else: |
|
|
|
merge.append((key_node, value_node)) |
|
|
|
|
|
|
|
node.value = merge |
|
|
|
|
|
|
|
def flatten_yaml_seq(self, node): # pragma: no cover |
|
|
|
submerge = [] |
|
|
|
for subnode in node.value: |
|
|
|
# we need to flatten each item in the seq, most likely they'll be mappings, |
|
|
|
# but we need to allow for custom flatteners as well. |
|
|
|
flattener = self.yaml_flatteners.get(subnode.tag) |
|
|
|
if flattener: |
|
|
|
for value in flattener(self, subnode): |
|
|
|
submerge.append(value) |
|
|
|
else: |
|
|
|
raise ConstructorError( |
|
|
|
"while constructing a mapping", |
|
|
|
node.start_mark, |
|
|
|
"expected a mapping for merging, but found %s" % subnode.id, |
|
|
|
subnode.start_mark, |
|
|
|
) |
|
|
|
submerge.reverse() |
|
|
|
for value in submerge: |
|
|
|
yield value |
|
|
|
|
|
|
|
def flatten_yaml_map(self, node): # pragma: no cover |
|
|
|
self.flatten_mapping(node) |
|
|
|
yield node.value |
|
|
|
|
|
|
|
|
|
|
|
# These 2 add's are also ported out of the PR |
|
|
|
ContextLoader.add_flattener( |
|
|
|
'tag:yaml.org,2002:seq', ContextLoader.flatten_yaml_seq |
|
|
|
) |
|
|
|
ContextLoader.add_flattener( |
|
|
|
'tag:yaml.org,2002:map', ContextLoader.flatten_yaml_map |
|
|
|
) |
|
|
|
|
|
|
|
ContextLoader.add_constructor('!include', ContextLoader.include) |
|
|
|
ContextLoader.add_constructor( |
|
|
|
ContextLoader.DEFAULT_MAPPING_TAG, ContextLoader._construct |
|
|
|
ContextLoader.DEFAULT_MAPPING_TAG, ContextLoader.construct_mapping |
|
|
|
) |
|
|
|
ContextLoader.add_constructor('!include', ContextLoader.construct_include) |
|
|
|
ContextLoader.add_flattener('!include', ContextLoader.flatten_include) |
|
|
|
|
|
|
|
|
|
|
|
# Found http://stackoverflow.com/a/21912744 which guided me on how to hook in |
|
|
|
# here |
|
|
|
class SortEnforcingLoader(ContextLoader): |
|
|
|
|
|
|
|
def _construct(self, node): |
|
|
|
ret, pairs, context = self._pairs(node) |
|
|
|
def construct_mapping(self, node, deep=False): |
|
|
|
ret = super().construct_mapping(node, deep) |
|
|
|
|
|
|
|
keys = [d[0] for d in pairs] |
|
|
|
keys = list(ret.keys()) |
|
|
|
keys_sorted = sorted(keys, key=self.KEYGEN) |
|
|
|
for key in keys: |
|
|
|
expected = keys_sorted.pop(0) |
|
|
|
if key != expected: |
|
|
|
start_mark = node.start_mark |
|
|
|
context = f'{start_mark.name}, line {start_mark.line+1}, column {start_mark.column+1}' |
|
|
|
raise ConstructorError( |
|
|
|
None, |
|
|
|
None, |
|
|
|
@ -119,7 +151,7 @@ class NaturalSortEnforcingLoader(SortEnforcingLoader): |
|
|
|
|
|
|
|
NaturalSortEnforcingLoader.add_constructor( |
|
|
|
SortEnforcingLoader.DEFAULT_MAPPING_TAG, |
|
|
|
NaturalSortEnforcingLoader._construct, |
|
|
|
NaturalSortEnforcingLoader.construct_mapping, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@ -129,7 +161,7 @@ class SimpleSortEnforcingLoader(SortEnforcingLoader): |
|
|
|
|
|
|
|
SimpleSortEnforcingLoader.add_constructor( |
|
|
|
SortEnforcingLoader.DEFAULT_MAPPING_TAG, |
|
|
|
SimpleSortEnforcingLoader._construct, |
|
|
|
SimpleSortEnforcingLoader.construct_mapping, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|