Browse Source

EqualityTupleMixin impl, use everywhere we were doing tuple compares

pull/384/head
Ross McFarland 6 years ago
parent
commit
2b33f95c17
No known key found for this signature in database GPG Key ID: 61C10C4FC8FE4A89
4 changed files with 126 additions and 202 deletions
  1. +30
    -0
      octodns/equality.py
  2. +4
    -22
      octodns/provider/route53.py
  3. +24
    -180
      octodns/record/__init__.py
  4. +68
    -0
      tests/test_octodns_equality.py

+ 30
- 0
octodns/equality.py View File

@ -0,0 +1,30 @@
#
#
#
from __future__ import absolute_import, division, print_function, \
unicode_literals
class EqualityTupleMixin:
def _equality_tuple(self):
raise NotImplementedError('_equality_tuple method not implemented')
def __eq__(self, other):
return self._equality_tuple() == other._equality_tuple()
def __ne__(self, other):
return self._equality_tuple() != other._equality_tuple()
def __lt__(self, other):
return self._equality_tuple() < other._equality_tuple()
def __le__(self, other):
return self._equality_tuple() <= other._equality_tuple()
def __gt__(self, other):
return self._equality_tuple() > other._equality_tuple()
def __ge__(self, other):
return self._equality_tuple() >= other._equality_tuple()

+ 4
- 22
octodns/provider/route53.py View File

@ -16,6 +16,7 @@ import re
from six import text_type
from ..equality import EqualityTupleMixin
from ..record import Record, Update
from ..record.geo import GeoCodes
from .base import BaseProvider
@ -29,7 +30,7 @@ def _octal_replace(s):
return octal_re.sub(lambda m: chr(int(m.group(1), 8)), s)
class _Route53Record(object):
class _Route53Record(EqualityTupleMixin):
@classmethod
def _new_dynamic(cls, provider, record, hosted_zone_id, creating):
@ -157,29 +158,10 @@ class _Route53Record(object):
return '{}:{}'.format(self.fqdn, self._type).__hash__()
def _equality_tuple(self):
'''Sub-classes should call up to this and return its value and add
any additional fields they need to hav considered.'''
return (self.__class__.__name__, self.fqdn, self._type)
def __eq__(self, other):
'''Sub-classes should call up to this and return its value if true.
When it's false they should compute their own __eq__, same for other
ordering methods.'''
return self._equality_tuple() == other._equality_tuple()
def __ne__(self, other):
return self._equality_tuple() != other._equality_tuple()
def __lt__(self, other):
return self._equality_tuple() < other._equality_tuple()
def __le__(self, other):
return self._equality_tuple() <= other._equality_tuple()
def __gt__(self, other):
return self._equality_tuple() > other._equality_tuple()
def __ge__(self, other):
return self._equality_tuple() >= other._equality_tuple()
def __repr__(self):
return '_Route53Record<{} {} {} {}>'.format(self.fqdn, self._type,
self.ttl, self.values)


+ 24
- 180
octodns/record/__init__.py View File

@ -11,6 +11,7 @@ import re
from six import string_types, text_type
from ..equality import EqualityTupleMixin
from .geo import GeoCodes
@ -76,7 +77,7 @@ class ValidationError(Exception):
self.reasons = reasons
class Record(object):
class Record(EqualityTupleMixin):
log = getLogger('Record')
@classmethod
@ -209,30 +210,15 @@ class Record(object):
def __hash__(self):
return '{}:{}'.format(self.name, self._type).__hash__()
def __eq__(self, other):
return ((self.name, self._type) == (other.name, other._type))
def __ne__(self, other):
return ((self.name, self._type) != (other.name, other._type))
def __lt__(self, other):
return ((self.name, self._type) < (other.name, other._type))
def __le__(self, other):
return ((self.name, self._type) <= (other.name, other._type))
def __gt__(self, other):
return ((self.name, self._type) > (other.name, other._type))
def __ge__(self, other):
return ((self.name, self._type) >= (other.name, other._type))
def _equality_tuple(self):
return (self.name, self._type)
def __repr__(self):
# Make sure this is always overridden
raise NotImplementedError('Abstract base class, __repr__ required')
class GeoValue(object):
class GeoValue(EqualityTupleMixin):
geo_re = re.compile(r'^(?P<continent_code>\w\w)(-(?P<country_code>\w\w)'
r'(-(?P<subdivision_code>\w\w))?)?$')
@ -259,35 +245,9 @@ class GeoValue(object):
yield '-'.join(bits)
bits.pop()
def __eq__(self, other):
return ((self.continent_code, self.country_code, self.subdivision_code,
self.values) == (other.continent_code, other.country_code,
other.subdivision_code, other.values))
def __ne__(self, other):
return ((self.continent_code, self.country_code, self.subdivision_code,
self.values) != (other.continent_code, other.country_code,
other.subdivision_code, other.values))
def __lt__(self, other):
return ((self.continent_code, self.country_code, self.subdivision_code,
self.values) < (other.continent_code, other.country_code,
other.subdivision_code, other.values))
def __le__(self, other):
return ((self.continent_code, self.country_code, self.subdivision_code,
self.values) <= (other.continent_code, other.country_code,
other.subdivision_code, other.values))
def __gt__(self, other):
return ((self.continent_code, self.country_code, self.subdivision_code,
self.values) > (other.continent_code, other.country_code,
other.subdivision_code, other.values))
def __ge__(self, other):
return ((self.continent_code, self.country_code, self.subdivision_code,
self.values) >= (other.continent_code, other.country_code,
other.subdivision_code, other.values))
def _equality_tuple(self):
return (self.continent_code, self.country_code, self.subdivision_code,
self.values)
def __repr__(self):
return "'Geo {} {} {} {}'".format(self.continent_code,
@ -787,7 +747,7 @@ class AliasRecord(_ValueMixin, Record):
_value_type = AliasValue
class CaaValue(object):
class CaaValue(EqualityTupleMixin):
# https://tools.ietf.org/html/rfc6844#page-5
@classmethod
@ -826,29 +786,8 @@ class CaaValue(object):
'value': self.value,
}
def __eq__(self, other):
return ((self.flags, self.tag, self.value) ==
(other.flags, other.tag, other.value))
def __ne__(self, other):
return ((self.flags, self.tag, self.value) !=
(other.flags, other.tag, other.value))
def __lt__(self, other):
return ((self.flags, self.tag, self.value) <
(other.flags, other.tag, other.value))
def __le__(self, other):
return ((self.flags, self.tag, self.value) <=
(other.flags, other.tag, other.value))
def __gt__(self, other):
return ((self.flags, self.tag, self.value) >
(other.flags, other.tag, other.value))
def __ge__(self, other):
return ((self.flags, self.tag, self.value) >=
(other.flags, other.tag, other.value))
def _equality_tuple(self):
return (self.flags, self.tag, self.value)
def __repr__(self):
return '{} {} "{}"'.format(self.flags, self.tag, self.value)
@ -872,7 +811,7 @@ class CnameRecord(_DynamicMixin, _ValueMixin, Record):
return reasons
class MxValue(object):
class MxValue(EqualityTupleMixin):
@classmethod
def validate(cls, data, _type):
@ -928,29 +867,8 @@ class MxValue(object):
def __hash__(self):
return hash((self.preference, self.exchange))
def __eq__(self, other):
return ((self.preference, self.exchange) ==
(other.preference, other.exchange))
def __ne__(self, other):
return ((self.preference, self.exchange) !=
(other.preference, other.exchange))
def __lt__(self, other):
return ((self.preference, self.exchange) <
(other.preference, other.exchange))
def __le__(self, other):
return ((self.preference, self.exchange) <=
(other.preference, other.exchange))
def __gt__(self, other):
return ((self.preference, self.exchange) >
(other.preference, other.exchange))
def __ge__(self, other):
return ((self.preference, self.exchange) >=
(other.preference, other.exchange))
def _equality_tuple(self):
return (self.preference, self.exchange)
def __repr__(self):
return "'{} {}'".format(self.preference, self.exchange)
@ -961,7 +879,7 @@ class MxRecord(_ValuesMixin, Record):
_value_type = MxValue
class NaptrValue(object):
class NaptrValue(EqualityTupleMixin):
VALID_FLAGS = ('S', 'A', 'U', 'P')
@classmethod
@ -1023,41 +941,9 @@ class NaptrValue(object):
def __hash__(self):
return hash(self.__repr__())
def __eq__(self, other):
return ((self.order, self.preference, self.flags, self.service,
self.regexp, self.replacement) ==
(other.order, other.preference, other.flags, other.service,
other.regexp, other.replacement))
def __ne__(self, other):
return ((self.order, self.preference, self.flags, self.service,
self.regexp, self.replacement) !=
(other.order, other.preference, other.flags, other.service,
other.regexp, other.replacement))
def __lt__(self, other):
return ((self.order, self.preference, self.flags, self.service,
self.regexp, self.replacement) <
(other.order, other.preference, other.flags, other.service,
other.regexp, other.replacement))
def __le__(self, other):
return ((self.order, self.preference, self.flags, self.service,
self.regexp, self.replacement) <=
(other.order, other.preference, other.flags, other.service,
other.regexp, other.replacement))
def __gt__(self, other):
return ((self.order, self.preference, self.flags, self.service,
self.regexp, self.replacement) >
(other.order, other.preference, other.flags, other.service,
other.regexp, other.replacement))
def __ge__(self, other):
return ((self.order, self.preference, self.flags, self.service,
self.regexp, self.replacement) >=
(other.order, other.preference, other.flags, other.service,
other.regexp, other.replacement))
def _equality_tuple(self):
return (self.order, self.preference, self.flags, self.service,
self.regexp, self.replacement)
def __repr__(self):
flags = self.flags if self.flags is not None else ''
@ -1107,7 +993,7 @@ class PtrRecord(_ValueMixin, Record):
_value_type = PtrValue
class SshfpValue(object):
class SshfpValue(EqualityTupleMixin):
VALID_ALGORITHMS = (1, 2, 3, 4)
VALID_FINGERPRINT_TYPES = (1, 2)
@ -1161,29 +1047,8 @@ class SshfpValue(object):
def __hash__(self):
return hash(self.__repr__())
def __eq__(self, other):
return ((self.algorithm, self.fingerprint_type, self.fingerprint) ==
(other.algorithm, other.fingerprint_type, other.fingerprint))
def __ne__(self, other):
return ((self.algorithm, self.fingerprint_type, self.fingerprint) !=
(other.algorithm, other.fingerprint_type, other.fingerprint))
def __lt__(self, other):
return ((self.algorithm, self.fingerprint_type, self.fingerprint) <
(other.algorithm, other.fingerprint_type, other.fingerprint))
def __le__(self, other):
return ((self.algorithm, self.fingerprint_type, self.fingerprint) <=
(other.algorithm, other.fingerprint_type, other.fingerprint))
def __gt__(self, other):
return ((self.algorithm, self.fingerprint_type, self.fingerprint) >
(other.algorithm, other.fingerprint_type, other.fingerprint))
def __ge__(self, other):
return ((self.algorithm, self.fingerprint_type, self.fingerprint) >=
(other.algorithm, other.fingerprint_type, other.fingerprint))
def _equality_tuple(self):
return (self.algorithm, self.fingerprint_type, self.fingerprint)
def __repr__(self):
return "'{} {} {}'".format(self.algorithm, self.fingerprint_type,
@ -1244,7 +1109,7 @@ class SpfRecord(_ChunkedValuesMixin, Record):
_value_type = _ChunkedValue
class SrvValue(object):
class SrvValue(EqualityTupleMixin):
@classmethod
def validate(cls, data, _type):
@ -1302,29 +1167,8 @@ class SrvValue(object):
def __hash__(self):
return hash(self.__repr__())
def __eq__(self, other):
return ((self.priority, self.weight, self.port, self.target) ==
(other.priority, other.weight, other.port, other.target))
def __ne__(self, other):
return ((self.priority, self.weight, self.port, self.target) !=
(other.priority, other.weight, other.port, other.target))
def __lt__(self, other):
return ((self.priority, self.weight, self.port, self.target) <
(other.priority, other.weight, other.port, other.target))
def __le__(self, other):
return ((self.priority, self.weight, self.port, self.target) <=
(other.priority, other.weight, other.port, other.target))
def __gt__(self, other):
return ((self.priority, self.weight, self.port, self.target) >
(other.priority, other.weight, other.port, other.target))
def __ge__(self, other):
return ((self.priority, self.weight, self.port, self.target) >=
(other.priority, other.weight, other.port, other.target))
def _equality_tuple(self):
return (self.priority, self.weight, self.port, self.target)
def __repr__(self):
return "'{} {} {} {}'".format(self.priority, self.weight, self.port,


+ 68
- 0
tests/test_octodns_equality.py View File

@ -0,0 +1,68 @@
#
#
#
from __future__ import absolute_import, division, print_function, \
unicode_literals
from unittest import TestCase
from octodns.equality import EqualityTupleMixin
class TestEqualityTupleMixin(TestCase):
def test_basics(self):
class Simple(EqualityTupleMixin):
def __init__(self, a, b, c):
self.a = a
self.b = b
self.c = c
def _equality_tuple(self):
return (self.a, self.b)
one = Simple(1, 2, 3)
same = Simple(1, 2, 3)
matches = Simple(1, 2, 'ignored')
doesnt = Simple(2, 3, 4)
# equality
self.assertEquals(one, one)
self.assertEquals(one, same)
self.assertEquals(same, one)
# only a & c are considered
self.assertEquals(one, matches)
self.assertEquals(matches, one)
self.assertNotEquals(one, doesnt)
self.assertNotEquals(doesnt, one)
# lt
self.assertTrue(one < doesnt)
self.assertFalse(doesnt < one)
self.assertFalse(one < same)
# le
self.assertTrue(one <= doesnt)
self.assertFalse(doesnt <= one)
self.assertTrue(one <= same)
# gt
self.assertFalse(one > doesnt)
self.assertTrue(doesnt > one)
self.assertFalse(one > same)
# ge
self.assertFalse(one >= doesnt)
self.assertTrue(doesnt >= one)
self.assertTrue(one >= same)
def test_not_implemented(self):
class MissingMethod(EqualityTupleMixin):
pass
with self.assertRaises(NotImplementedError):
MissingMethod() == MissingMethod()

Loading…
Cancel
Save