Browse Source

Merge pull request #1321 from jleroy/feature/report-refactoring

octodns-report refactoring
pull/1325/head
Ross McFarland 2 months ago
committed by GitHub
parent
commit
09dc68b2c7
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
2 changed files with 154 additions and 56 deletions
  1. +4
    -0
      .changelog/bb94cc6d9dde44b38d875cf17be4bbce.md
  2. +150
    -56
      octodns/cmds/report.py

+ 4
- 0
.changelog/bb94cc6d9dde44b38d875cf17be4bbce.md View File

@ -0,0 +1,4 @@
---
type: minor
---
Full rewrite of octodns-report: support for IPv6 resolvers, async names resolution and JSON output

+ 150
- 56
octodns/cmds/report.py View File

@ -3,25 +3,48 @@
Octo-DNS Reporter
'''
import re
from concurrent.futures import ThreadPoolExecutor
from asyncio import Semaphore, new_event_loop, wait
from collections import defaultdict
from csv import QUOTE_NONE, writer
from io import StringIO
from ipaddress import ip_address
from json import dump
from logging import getLogger
from sys import stdout
from dns.exception import Timeout
from dns.resolver import NXDOMAIN, NoAnswer, NoNameservers, Resolver, query
from sys import exit
from dns.asyncresolver import Resolver as AsyncResolver
from dns.resolver import (
NXDOMAIN,
YXDOMAIN,
LifetimeTimeout,
NoAnswer,
NoNameservers,
resolve,
)
from octodns.cmds.args import ArgumentParser
from octodns.manager import Manager
class AsyncResolver(Resolver):
def __init__(self, num_workers, *args, **kwargs):
super().__init__(*args, **kwargs)
self.executor = ThreadPoolExecutor(max_workers=num_workers)
async def async_resolve(record, resolver, timeout, limit):
async with limit:
r = AsyncResolver(configure=False)
r.lifetime = timeout
r.nameservers = [resolver]
try:
query = await r.resolve(qname=record.fqdn, rdtype=record._type)
answer = sorted([str(a) for a in query])
except (NoAnswer, NoNameservers):
answer = ['*no answer*']
except NXDOMAIN:
answer = ['*does not exist*']
except YXDOMAIN:
answer = ['*should not exist*']
except LifetimeTimeout:
answer = ['*timeout*']
def query(self, *args, **kwargs):
return self.executor.submit(super().query, *args, **kwargs)
return [record, resolver, answer]
def main():
@ -41,10 +64,22 @@ def main():
help='Source(s) to pull data from',
)
parser.add_argument(
'--num-workers', default=4, help='Number of background workers'
'--concurrency',
type=int,
default=4,
help='Maximum number of concurrent DNS queries',
)
parser.add_argument(
'--timeout',
type=float,
default=1,
help='Number seconds to wait for an answer',
)
parser.add_argument(
'--timeout', default=1, help='Number seconds to wait for an answer'
'--output-format',
choices=['csv', 'json'],
default='csv',
help='Output format',
)
parser.add_argument(
'--lenient',
@ -52,13 +87,17 @@ def main():
default=False,
help='Ignore record validations and do a best effort dump',
)
parser.add_argument('server', nargs='+', help='Servers to query')
parser.add_argument('server', nargs='+', help='DNS resolver to query')
args = parser.parse_args()
concurrency = args.concurrency
timeout = args.timeout
output_format = args.output_format
manager = Manager(args.config_file)
log = getLogger('report')
log.info(f'concurrency={concurrency} timeout={timeout}')
try:
sources = [manager.providers[source] for source in args.source]
@ -69,49 +108,104 @@ def main():
for source in sources:
source.populate(zone, lenient=args.lenient)
servers = ','.join(args.server)
print(f'name,type,ttl,{servers},consistent')
servers = args.server
resolvers = []
ip_addr_re = re.compile(r'^[\d\.]+$')
for server in args.server:
resolver = AsyncResolver(
configure=False, num_workers=int(args.num_workers)
)
if not ip_addr_re.match(server):
server = str(query(server, 'A')[0])
log.info('server=%s', server)
resolver.nameservers = [server]
resolver.lifetime = int(args.timeout)
resolvers.append(resolver)
queries = {}
for server in servers:
resolver = None
is_hostname = False
try:
ip = ip_address(server)
# "2001:4860:4860:0:0:0:0:8888" => "2001:4860:4860::8888"
resolver = ip.compressed
# The specified server isn't a valid IP address, maybe it's a valid
# hostname? So we try to resolve it.
except ValueError:
# IPv4 first, then IPv6.
for rrtype in ['A', 'AAAA']:
try:
query = resolve(server, rrtype)
resolver = str(query.rrset[0])
is_hostname = True
# Exit on first IP address found.
break
# NXDOMAIN, NoAnswer, NoNameservers...
except Exception:
continue
if resolver and resolver not in resolvers:
if not is_hostname:
log.info(f'server={resolver}')
else:
log.info(f'server={resolver} ({server})')
resolvers.append(resolver)
if not resolvers:
print(f'Error: No valid resolver specified ({", ".join(servers)})')
exit(1)
loop = new_event_loop()
limit = Semaphore(concurrency)
tasks = []
for record in sorted(zone.records):
queries[record] = [
r.query(record.fqdn, record._type) for r in resolvers
]
for record, futures in sorted(queries.items(), key=lambda d: d[0]):
stdout.write(record.decoded_fqdn)
stdout.write(',')
stdout.write(record._type)
stdout.write(',')
stdout.write(str(record.ttl))
compare = {}
for future in futures:
stdout.write(',')
try:
answers = [str(r) for r in future.result()]
except (NoAnswer, NoNameservers):
answers = ['*no answer*']
except NXDOMAIN:
answers = ['*does not exist*']
except Timeout:
answers = ['*timeout*']
stdout.write(' '.join(answers))
# sorting to ignore order
answers = '*:*'.join(sorted(answers)).lower()
compare[answers] = True
stdout.write(',True\n' if len(compare) == 1 else ',False\n')
for resolver in resolvers:
tasks.append(
loop.create_task(
async_resolve(record, resolver, timeout, limit)
)
)
queries = defaultdict(dict)
done, _ = loop.run_until_complete(wait(tasks))
for task in done:
_record, _resolver, _answer = task.result()
queries[_record][_resolver] = _answer
loop.close()
output = StringIO()
if output_format == 'csv':
csvout = writer(output, quoting=QUOTE_NONE, quotechar=None)
csvheader = ['Name', 'Type', 'TTL']
csvheader = [*csvheader, *resolvers]
csvheader.append('Consistent')
csvout.writerow(csvheader)
for record, answers in sorted(queries.items()):
csvrow = [record.decoded_fqdn, record._type, record.ttl]
values_check = {}
for resolver in resolvers:
answer = ' '.join(answers.get(resolver, []))
values_check[answer.lower()] = True
csvrow.append(answer)
csvrow.append(bool(len(values_check) == 1))
csvout.writerow(csvrow)
elif output_format == 'json':
jsonout = defaultdict(lambda: defaultdict(dict))
for record, answers in sorted(queries.items()):
values_check = {}
for resolver in resolvers:
# Stripping the surrounding quotes of TXT records values to
# avoid them being unnecessarily escaped by JSON module.
answer = [a.strip('"') for a in answers.get(resolver, [])]
jsonout[record.decoded_fqdn][record._type][resolver] = answer
values_check[' '.join(answer).lower()] = True
jsonout[record.fqdn][record._type]['consistent'] = bool(
len(values_check) == 1
)
dump(jsonout, output)
print(output.getvalue())
output.close()
if __name__ == '__main__':


Loading…
Cancel
Save