Browse Source

EqualityTupleMixin impl, use everywhere we were doing tuple compares

pull/384/head
Ross McFarland 6 years ago
parent
commit
2b33f95c17
No known key found for this signature in database GPG Key ID: 61C10C4FC8FE4A89
4 changed files with 126 additions and 202 deletions
  1. +30
    -0
      octodns/equality.py
  2. +4
    -22
      octodns/provider/route53.py
  3. +24
    -180
      octodns/record/__init__.py
  4. +68
    -0
      tests/test_octodns_equality.py

+ 30
- 0
octodns/equality.py View File

@ -0,0 +1,30 @@
#
#
#
from __future__ import absolute_import, division, print_function, \
unicode_literals
class EqualityTupleMixin:
def _equality_tuple(self):
raise NotImplementedError('_equality_tuple method not implemented')
def __eq__(self, other):
return self._equality_tuple() == other._equality_tuple()
def __ne__(self, other):
return self._equality_tuple() != other._equality_tuple()
def __lt__(self, other):
return self._equality_tuple() < other._equality_tuple()
def __le__(self, other):
return self._equality_tuple() <= other._equality_tuple()
def __gt__(self, other):
return self._equality_tuple() > other._equality_tuple()
def __ge__(self, other):
return self._equality_tuple() >= other._equality_tuple()

+ 4
- 22
octodns/provider/route53.py View File

@ -16,6 +16,7 @@ import re
from six import text_type from six import text_type
from ..equality import EqualityTupleMixin
from ..record import Record, Update from ..record import Record, Update
from ..record.geo import GeoCodes from ..record.geo import GeoCodes
from .base import BaseProvider from .base import BaseProvider
@ -29,7 +30,7 @@ def _octal_replace(s):
return octal_re.sub(lambda m: chr(int(m.group(1), 8)), s) return octal_re.sub(lambda m: chr(int(m.group(1), 8)), s)
class _Route53Record(object):
class _Route53Record(EqualityTupleMixin):
@classmethod @classmethod
def _new_dynamic(cls, provider, record, hosted_zone_id, creating): def _new_dynamic(cls, provider, record, hosted_zone_id, creating):
@ -157,29 +158,10 @@ class _Route53Record(object):
return '{}:{}'.format(self.fqdn, self._type).__hash__() return '{}:{}'.format(self.fqdn, self._type).__hash__()
def _equality_tuple(self): def _equality_tuple(self):
'''Sub-classes should call up to this and return its value and add
any additional fields they need to hav considered.'''
return (self.__class__.__name__, self.fqdn, self._type) return (self.__class__.__name__, self.fqdn, self._type)
def __eq__(self, other):
'''Sub-classes should call up to this and return its value if true.
When it's false they should compute their own __eq__, same for other
ordering methods.'''
return self._equality_tuple() == other._equality_tuple()
def __ne__(self, other):
return self._equality_tuple() != other._equality_tuple()
def __lt__(self, other):
return self._equality_tuple() < other._equality_tuple()
def __le__(self, other):
return self._equality_tuple() <= other._equality_tuple()
def __gt__(self, other):
return self._equality_tuple() > other._equality_tuple()
def __ge__(self, other):
return self._equality_tuple() >= other._equality_tuple()
def __repr__(self): def __repr__(self):
return '_Route53Record<{} {} {} {}>'.format(self.fqdn, self._type, return '_Route53Record<{} {} {} {}>'.format(self.fqdn, self._type,
self.ttl, self.values) self.ttl, self.values)


+ 24
- 180
octodns/record/__init__.py View File

@ -11,6 +11,7 @@ import re
from six import string_types, text_type from six import string_types, text_type
from ..equality import EqualityTupleMixin
from .geo import GeoCodes from .geo import GeoCodes
@ -76,7 +77,7 @@ class ValidationError(Exception):
self.reasons = reasons self.reasons = reasons
class Record(object):
class Record(EqualityTupleMixin):
log = getLogger('Record') log = getLogger('Record')
@classmethod @classmethod
@ -209,30 +210,15 @@ class Record(object):
def __hash__(self): def __hash__(self):
return '{}:{}'.format(self.name, self._type).__hash__() return '{}:{}'.format(self.name, self._type).__hash__()
def __eq__(self, other):
return ((self.name, self._type) == (other.name, other._type))
def __ne__(self, other):
return ((self.name, self._type) != (other.name, other._type))
def __lt__(self, other):
return ((self.name, self._type) < (other.name, other._type))
def __le__(self, other):
return ((self.name, self._type) <= (other.name, other._type))
def __gt__(self, other):
return ((self.name, self._type) > (other.name, other._type))
def __ge__(self, other):
return ((self.name, self._type) >= (other.name, other._type))
def _equality_tuple(self):
return (self.name, self._type)
def __repr__(self): def __repr__(self):
# Make sure this is always overridden # Make sure this is always overridden
raise NotImplementedError('Abstract base class, __repr__ required') raise NotImplementedError('Abstract base class, __repr__ required')
class GeoValue(object):
class GeoValue(EqualityTupleMixin):
geo_re = re.compile(r'^(?P<continent_code>\w\w)(-(?P<country_code>\w\w)' geo_re = re.compile(r'^(?P<continent_code>\w\w)(-(?P<country_code>\w\w)'
r'(-(?P<subdivision_code>\w\w))?)?$') r'(-(?P<subdivision_code>\w\w))?)?$')
@ -259,35 +245,9 @@ class GeoValue(object):
yield '-'.join(bits) yield '-'.join(bits)
bits.pop() bits.pop()
def __eq__(self, other):
return ((self.continent_code, self.country_code, self.subdivision_code,
self.values) == (other.continent_code, other.country_code,
other.subdivision_code, other.values))
def __ne__(self, other):
return ((self.continent_code, self.country_code, self.subdivision_code,
self.values) != (other.continent_code, other.country_code,
other.subdivision_code, other.values))
def __lt__(self, other):
return ((self.continent_code, self.country_code, self.subdivision_code,
self.values) < (other.continent_code, other.country_code,
other.subdivision_code, other.values))
def __le__(self, other):
return ((self.continent_code, self.country_code, self.subdivision_code,
self.values) <= (other.continent_code, other.country_code,
other.subdivision_code, other.values))
def __gt__(self, other):
return ((self.continent_code, self.country_code, self.subdivision_code,
self.values) > (other.continent_code, other.country_code,
other.subdivision_code, other.values))
def __ge__(self, other):
return ((self.continent_code, self.country_code, self.subdivision_code,
self.values) >= (other.continent_code, other.country_code,
other.subdivision_code, other.values))
def _equality_tuple(self):
return (self.continent_code, self.country_code, self.subdivision_code,
self.values)
def __repr__(self): def __repr__(self):
return "'Geo {} {} {} {}'".format(self.continent_code, return "'Geo {} {} {} {}'".format(self.continent_code,
@ -787,7 +747,7 @@ class AliasRecord(_ValueMixin, Record):
_value_type = AliasValue _value_type = AliasValue
class CaaValue(object):
class CaaValue(EqualityTupleMixin):
# https://tools.ietf.org/html/rfc6844#page-5 # https://tools.ietf.org/html/rfc6844#page-5
@classmethod @classmethod
@ -826,29 +786,8 @@ class CaaValue(object):
'value': self.value, 'value': self.value,
} }
def __eq__(self, other):
return ((self.flags, self.tag, self.value) ==
(other.flags, other.tag, other.value))
def __ne__(self, other):
return ((self.flags, self.tag, self.value) !=
(other.flags, other.tag, other.value))
def __lt__(self, other):
return ((self.flags, self.tag, self.value) <
(other.flags, other.tag, other.value))
def __le__(self, other):
return ((self.flags, self.tag, self.value) <=
(other.flags, other.tag, other.value))
def __gt__(self, other):
return ((self.flags, self.tag, self.value) >
(other.flags, other.tag, other.value))
def __ge__(self, other):
return ((self.flags, self.tag, self.value) >=
(other.flags, other.tag, other.value))
def _equality_tuple(self):
return (self.flags, self.tag, self.value)
def __repr__(self): def __repr__(self):
return '{} {} "{}"'.format(self.flags, self.tag, self.value) return '{} {} "{}"'.format(self.flags, self.tag, self.value)
@ -872,7 +811,7 @@ class CnameRecord(_DynamicMixin, _ValueMixin, Record):
return reasons return reasons
class MxValue(object):
class MxValue(EqualityTupleMixin):
@classmethod @classmethod
def validate(cls, data, _type): def validate(cls, data, _type):
@ -928,29 +867,8 @@ class MxValue(object):
def __hash__(self): def __hash__(self):
return hash((self.preference, self.exchange)) return hash((self.preference, self.exchange))
def __eq__(self, other):
return ((self.preference, self.exchange) ==
(other.preference, other.exchange))
def __ne__(self, other):
return ((self.preference, self.exchange) !=
(other.preference, other.exchange))
def __lt__(self, other):
return ((self.preference, self.exchange) <
(other.preference, other.exchange))
def __le__(self, other):
return ((self.preference, self.exchange) <=
(other.preference, other.exchange))
def __gt__(self, other):
return ((self.preference, self.exchange) >
(other.preference, other.exchange))
def __ge__(self, other):
return ((self.preference, self.exchange) >=
(other.preference, other.exchange))
def _equality_tuple(self):
return (self.preference, self.exchange)
def __repr__(self): def __repr__(self):
return "'{} {}'".format(self.preference, self.exchange) return "'{} {}'".format(self.preference, self.exchange)
@ -961,7 +879,7 @@ class MxRecord(_ValuesMixin, Record):
_value_type = MxValue _value_type = MxValue
class NaptrValue(object):
class NaptrValue(EqualityTupleMixin):
VALID_FLAGS = ('S', 'A', 'U', 'P') VALID_FLAGS = ('S', 'A', 'U', 'P')
@classmethod @classmethod
@ -1023,41 +941,9 @@ class NaptrValue(object):
def __hash__(self): def __hash__(self):
return hash(self.__repr__()) return hash(self.__repr__())
def __eq__(self, other):
return ((self.order, self.preference, self.flags, self.service,
self.regexp, self.replacement) ==
(other.order, other.preference, other.flags, other.service,
other.regexp, other.replacement))
def __ne__(self, other):
return ((self.order, self.preference, self.flags, self.service,
self.regexp, self.replacement) !=
(other.order, other.preference, other.flags, other.service,
other.regexp, other.replacement))
def __lt__(self, other):
return ((self.order, self.preference, self.flags, self.service,
self.regexp, self.replacement) <
(other.order, other.preference, other.flags, other.service,
other.regexp, other.replacement))
def __le__(self, other):
return ((self.order, self.preference, self.flags, self.service,
self.regexp, self.replacement) <=
(other.order, other.preference, other.flags, other.service,
other.regexp, other.replacement))
def __gt__(self, other):
return ((self.order, self.preference, self.flags, self.service,
self.regexp, self.replacement) >
(other.order, other.preference, other.flags, other.service,
other.regexp, other.replacement))
def __ge__(self, other):
return ((self.order, self.preference, self.flags, self.service,
self.regexp, self.replacement) >=
(other.order, other.preference, other.flags, other.service,
other.regexp, other.replacement))
def _equality_tuple(self):
return (self.order, self.preference, self.flags, self.service,
self.regexp, self.replacement)
def __repr__(self): def __repr__(self):
flags = self.flags if self.flags is not None else '' flags = self.flags if self.flags is not None else ''
@ -1107,7 +993,7 @@ class PtrRecord(_ValueMixin, Record):
_value_type = PtrValue _value_type = PtrValue
class SshfpValue(object):
class SshfpValue(EqualityTupleMixin):
VALID_ALGORITHMS = (1, 2, 3, 4) VALID_ALGORITHMS = (1, 2, 3, 4)
VALID_FINGERPRINT_TYPES = (1, 2) VALID_FINGERPRINT_TYPES = (1, 2)
@ -1161,29 +1047,8 @@ class SshfpValue(object):
def __hash__(self): def __hash__(self):
return hash(self.__repr__()) return hash(self.__repr__())
def __eq__(self, other):
return ((self.algorithm, self.fingerprint_type, self.fingerprint) ==
(other.algorithm, other.fingerprint_type, other.fingerprint))
def __ne__(self, other):
return ((self.algorithm, self.fingerprint_type, self.fingerprint) !=
(other.algorithm, other.fingerprint_type, other.fingerprint))
def __lt__(self, other):
return ((self.algorithm, self.fingerprint_type, self.fingerprint) <
(other.algorithm, other.fingerprint_type, other.fingerprint))
def __le__(self, other):
return ((self.algorithm, self.fingerprint_type, self.fingerprint) <=
(other.algorithm, other.fingerprint_type, other.fingerprint))
def __gt__(self, other):
return ((self.algorithm, self.fingerprint_type, self.fingerprint) >
(other.algorithm, other.fingerprint_type, other.fingerprint))
def __ge__(self, other):
return ((self.algorithm, self.fingerprint_type, self.fingerprint) >=
(other.algorithm, other.fingerprint_type, other.fingerprint))
def _equality_tuple(self):
return (self.algorithm, self.fingerprint_type, self.fingerprint)
def __repr__(self): def __repr__(self):
return "'{} {} {}'".format(self.algorithm, self.fingerprint_type, return "'{} {} {}'".format(self.algorithm, self.fingerprint_type,
@ -1244,7 +1109,7 @@ class SpfRecord(_ChunkedValuesMixin, Record):
_value_type = _ChunkedValue _value_type = _ChunkedValue
class SrvValue(object):
class SrvValue(EqualityTupleMixin):
@classmethod @classmethod
def validate(cls, data, _type): def validate(cls, data, _type):
@ -1302,29 +1167,8 @@ class SrvValue(object):
def __hash__(self): def __hash__(self):
return hash(self.__repr__()) return hash(self.__repr__())
def __eq__(self, other):
return ((self.priority, self.weight, self.port, self.target) ==
(other.priority, other.weight, other.port, other.target))
def __ne__(self, other):
return ((self.priority, self.weight, self.port, self.target) !=
(other.priority, other.weight, other.port, other.target))
def __lt__(self, other):
return ((self.priority, self.weight, self.port, self.target) <
(other.priority, other.weight, other.port, other.target))
def __le__(self, other):
return ((self.priority, self.weight, self.port, self.target) <=
(other.priority, other.weight, other.port, other.target))
def __gt__(self, other):
return ((self.priority, self.weight, self.port, self.target) >
(other.priority, other.weight, other.port, other.target))
def __ge__(self, other):
return ((self.priority, self.weight, self.port, self.target) >=
(other.priority, other.weight, other.port, other.target))
def _equality_tuple(self):
return (self.priority, self.weight, self.port, self.target)
def __repr__(self): def __repr__(self):
return "'{} {} {} {}'".format(self.priority, self.weight, self.port, return "'{} {} {} {}'".format(self.priority, self.weight, self.port,


+ 68
- 0
tests/test_octodns_equality.py View File

@ -0,0 +1,68 @@
#
#
#
from __future__ import absolute_import, division, print_function, \
unicode_literals
from unittest import TestCase
from octodns.equality import EqualityTupleMixin
class TestEqualityTupleMixin(TestCase):
def test_basics(self):
class Simple(EqualityTupleMixin):
def __init__(self, a, b, c):
self.a = a
self.b = b
self.c = c
def _equality_tuple(self):
return (self.a, self.b)
one = Simple(1, 2, 3)
same = Simple(1, 2, 3)
matches = Simple(1, 2, 'ignored')
doesnt = Simple(2, 3, 4)
# equality
self.assertEquals(one, one)
self.assertEquals(one, same)
self.assertEquals(same, one)
# only a & c are considered
self.assertEquals(one, matches)
self.assertEquals(matches, one)
self.assertNotEquals(one, doesnt)
self.assertNotEquals(doesnt, one)
# lt
self.assertTrue(one < doesnt)
self.assertFalse(doesnt < one)
self.assertFalse(one < same)
# le
self.assertTrue(one <= doesnt)
self.assertFalse(doesnt <= one)
self.assertTrue(one <= same)
# gt
self.assertFalse(one > doesnt)
self.assertTrue(doesnt > one)
self.assertFalse(one > same)
# ge
self.assertFalse(one >= doesnt)
self.assertTrue(doesnt >= one)
self.assertTrue(one >= same)
def test_not_implemented(self):
class MissingMethod(EqualityTupleMixin):
pass
with self.assertRaises(NotImplementedError):
MissingMethod() == MissingMethod()

Loading…
Cancel
Save