async def async_resolve(record, resolver, timeout, limit):
    '''
    Query one resolver for one record and report what it answered.

    record: object exposing `fqdn` and `_type`, used as the query name and
        the record type of the lookup.
    resolver: address of the single nameserver to send the query to.
    timeout: value assigned to the resolver's `lifetime` for this lookup.
    limit: async context manager held for the duration of the lookup; the
        caller passes an `asyncio.Semaphore` to cap concurrent queries.

    Returns a three item list `[record, resolver, answers]` where `answers`
    is the sorted list of stringified answers, or a single placeholder
    (`*no answer*`, `*does not exist*`, `*should not exist*`, `*timeout*`)
    when the lookup failed in one of the handled ways. Any other exception
    propagates to the caller.
    '''
    async with limit:
        # A dedicated, unconfigured resolver per lookup so that only the
        # requested nameserver is ever consulted.
        client = dns.asyncresolver.Resolver(configure=False)
        client.lifetime = timeout
        client.nameservers = [resolver]

        try:
            response = await client.resolve(
                qname=record.fqdn, rdtype=record._type
            )
            values = [str(rdata) for rdata in response]
        except (dns.resolver.NoAnswer, dns.resolver.NoNameservers):
            values = ['*no answer*']
        except dns.resolver.NXDOMAIN:
            values = ['*does not exist*']
        except dns.resolver.YXDOMAIN:
            values = ['*should not exist*']
        except dns.resolver.LifetimeTimeout:
            values = ['*timeout*']

    # Sorted so that answer ordering never influences later comparisons.
    values.sort()
    return [record, resolver, values]
help='Source(s) to pull data from', ) parser.add_argument( - '--num-workers', default=4, help='Number of background workers' + '--concurrency', + type=int, + default=10, + help='Maximum number of concurrent DNS queries', + ) + parser.add_argument( + '--timeout', + type=float, + default=1, + help='Number seconds to wait for an answer', ) parser.add_argument( - '--timeout', default=1, help='Number seconds to wait for an answer' + '--output-format', + choices=['csv', 'json'], + default='csv', + help='Output format', ) parser.add_argument( '--lenient', @@ -52,13 +80,17 @@ def main(): default=False, help='Ignore record validations and do a best effort dump', ) - parser.add_argument('server', nargs='+', help='Servers to query') + parser.add_argument('server', nargs='+', help='DNS resolver to query') args = parser.parse_args() + concurrency = args.concurrency + timeout = args.timeout + output_format = args.output_format manager = Manager(args.config_file) log = getLogger('report') + log.info(f'concurrency={concurrency} timeout={timeout}') try: sources = [manager.providers[source] for source in args.source] @@ -69,49 +101,102 @@ def main(): for source in sources: source.populate(zone, lenient=args.lenient) - servers = ','.join(args.server) - print(f'name,type,ttl,{servers},consistent') + servers = args.server resolvers = [] - ip_addr_re = re.compile(r'^[\d\.]+$') - for server in args.server: - resolver = AsyncResolver( - configure=False, num_workers=int(args.num_workers) - ) - if not ip_addr_re.match(server): - server = str(query(server, 'A')[0]) - log.info('server=%s', server) - resolver.nameservers = [server] - resolver.lifetime = int(args.timeout) - resolvers.append(resolver) - - queries = {} + for server in servers: + resolver = None + is_hostname = False + + try: + ip = ipaddress.ip_address(server) + # "2001:4860:4860:0:0:0:0:8888" => "2001:4860:4860::8888" + resolver = ip.compressed + + # The specified server isn't a valid IP address, maybe it's a valid + # hostname? 
So we try to resolve it. + except ValueError: + # IPv4 first, then IPv6. + for rrtype in ['A', 'AAAA']: + try: + query = dns.resolver.resolve(server, rrtype) + resolver = str(query.rrset[0]) + is_hostname = True + # Exit on first IP address found. + break + + # NXDOMAIN, NoAnswer, NoNameservers... + except: + continue + + if resolver and not resolver in resolvers: + if not is_hostname: + log.info(f'server={resolver}') + else: + log.info(f'server={resolver} ({server})') + + resolvers.append(resolver) + + if not resolvers: + print(f'Error: No valid resolver specified ({', '.join(servers)})') + sys.exit(1) + + loop = asyncio.new_event_loop() + limit = asyncio.Semaphore(concurrency) + tasks = [] for record in sorted(zone.records): - queries[record] = [ - r.query(record.fqdn, record._type) for r in resolvers - ] - - for record, futures in sorted(queries.items(), key=lambda d: d[0]): - stdout.write(record.decoded_fqdn) - stdout.write(',') - stdout.write(record._type) - stdout.write(',') - stdout.write(str(record.ttl)) - compare = {} - for future in futures: - stdout.write(',') - try: - answers = [str(r) for r in future.result()] - except (NoAnswer, NoNameservers): - answers = ['*no answer*'] - except NXDOMAIN: - answers = ['*does not exist*'] - except Timeout: - answers = ['*timeout*'] - stdout.write(' '.join(answers)) - # sorting to ignore order - answers = '*:*'.join(sorted(answers)).lower() - compare[answers] = True - stdout.write(',True\n' if len(compare) == 1 else ',False\n') + for resolver in resolvers: + tasks.append( + loop.create_task( + async_resolve(record, resolver, timeout, limit) + ) + ) + + queries = defaultdict(dict) + done, _ = loop.run_until_complete(asyncio.wait(tasks)) + for task in done: + _record, _resolver, _answer = task.result() + queries[_record][_resolver] = _answer + + loop.close() + + output = io.StringIO() + if output_format == 'csv': + csvout = csv.writer(output, quoting=csv.QUOTE_MINIMAL) + csvheader = ['Name', 'Type', 'TTL'] + csvheader = 
[*csvheader, *resolvers] + csvheader.append('Consistent') + csvout.writerow(csvheader) + + for record, answers in sorted(queries.items()): + csvrow = [record.decoded_fqdn, record._type, record.ttl] + values_check = {} + + for resolver in resolvers: + answer = f'{' '.join(answers.get(resolver))}' + values_check[answer.lower()] = True + csvrow.append(answer) + + csvrow.append(bool(len(values_check) == 1)) + csvout.writerow(csvrow) + + elif output_format == 'json': + jsonout = defaultdict(lambda: defaultdict(dict)) + for record, answers in sorted(queries.items()): + values_check = {} + + for resolver in resolvers: + answer = answers.get(resolver) + jsonout[record.fqdn][record._type][resolver] = answer + values_check[f'{' '.join(answer)}'.lower()] = True + + jsonout[record.fqdn][record._type]['consistent'] = bool( + len(values_check) == 1 + ) + + json.dump(jsonout, output) + + print(output.getvalue()) + output.close() if __name__ == '__main__':