diff --git a/octodns/idna.py b/octodns/idna.py index bc91d46..bf30343 100644 --- a/octodns/idna.py +++ b/octodns/idna.py @@ -2,6 +2,8 @@ # # +from collections.abc import MutableMapping + from idna import decode as _decode, encode as _encode # Providers will need to to make calls to these at the appropriate points, @@ -12,6 +14,7 @@ from idna import decode as _decode, encode as _encode def idna_encode(name): # Based on https://github.com/psf/requests/pull/3695/files # #diff-0debbb2447ce5debf2872cb0e17b18babe3566e9d9900739e8581b355bd513f7R39 + name = name.lower() try: name.encode('ascii') # No utf8 chars, just use as-is @@ -34,3 +37,35 @@ def idna_decode(name): return _decode(name) # not idna, just return as-is return name + + +class IdnaDict(MutableMapping): + '''A dict type that is insensitive to case and utf-8/idna encoded strings''' + + def __init__(self, data=None): + self._data = dict() + if data is not None: + self.update(data) + + def __setitem__(self, k, v): + self._data[idna_encode(k)] = v + + def __getitem__(self, k): + return self._data[idna_encode(k)] + + def __delitem__(self, k): + del self._data[idna_encode(k)] + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + def decoded_keys(self): + for key in self.keys(): + yield idna_decode(key) + + def decoded_items(self): + for key, value in self.items(): + yield (idna_decode(key), value) diff --git a/tests/test_octodns_idna.py b/tests/test_octodns_idna.py index 0c6b125..2e09401 100644 --- a/tests/test_octodns_idna.py +++ b/tests/test_octodns_idna.py @@ -11,7 +11,7 @@ from __future__ import ( from unittest import TestCase -from octodns.idna import idna_decode, idna_encode +from octodns.idna import IdnaDict, idna_decode, idna_encode class TestIdna(TestCase): @@ -56,5 +56,87 @@ class TestIdna(TestCase): self.assertIdna('bleep_bloop.foo_bar.pl.', 'bleep_bloop.foo_bar.pl.') def test_case_insensitivity(self): - # Shouldn't be hit by octoDNS use cases, but checked anyway self.assertEqual('zajęzyk.pl.', idna_decode('XN--ZAJZYK-Y4A.PL.')) + self.assertEqual('xn--zajzyk-y4a.pl.', idna_encode('ZajęzyK.Pl.')) + + +class TestIdnaDict(TestCase): + plain = 'testing.tests.' + almost = 'tésting.tests.' + utf8 = 'déjà.vu.' + + normal = {plain: 42, almost: 43, utf8: 44} + + def test_basics(self): + d = IdnaDict() + + # plain ascii + d[self.plain] = 42 + self.assertEqual(42, d[self.plain]) + + # almost the same, single utf-8 char + d[self.almost] = 43 + # fetch as utf-8 + self.assertEqual(43, d[self.almost]) + # fetch as idna + self.assertEqual(43, d[idna_encode(self.almost)]) + # plain is stil there, unchanged + self.assertEqual(42, d[self.plain]) + + # lots of utf8 + d[self.utf8] = 44 + self.assertEqual(44, d[self.utf8]) + self.assertEqual(44, d[idna_encode(self.utf8)]) + + # setting with idna version replaces something set previously with utf8 + d[idna_encode(self.almost)] = 45 + self.assertEqual(45, d[self.almost]) + self.assertEqual(45, d[idna_encode(self.almost)]) + + # contains + self.assertTrue(self.plain in d) + self.assertTrue(self.almost in d) + self.assertTrue(idna_encode(self.almost) in d) + self.assertTrue(self.utf8 in d) + self.assertTrue(idna_encode(self.utf8) in d) + + # we can delete with either form + del d[self.almost] + self.assertFalse(self.almost in d) + self.assertFalse(idna_encode(self.almost) in d) + del d[idna_encode(self.utf8)] + self.assertFalse(self.utf8 in d) + self.assertFalse(idna_encode(self.utf8) in d) + + def test_keys(self): + d = IdnaDict(self.normal) + + # keys are idna versions by default + self.assertEqual( + (self.plain, idna_encode(self.almost), idna_encode(self.utf8)), + tuple(d.keys()), + ) + + # decoded keys gives the utf8 version + self.assertEqual( + (self.plain, self.almost, self.utf8), tuple(d.decoded_keys()) + ) + + def test_items(self): + d = IdnaDict(self.normal) + + # idna keys in items + self.assertEqual( + ( + (self.plain, 42), + (idna_encode(self.almost), 43), + (idna_encode(self.utf8), 44), + ), + tuple(d.items()), + ) + + # utf8 keys in decoded_items + self.assertEqual( + ((self.plain, 42), (self.almost, 43), (self.utf8, 44)), + tuple(d.decoded_items()), + )