#!/usr/bin/env python3
import argparse
import base64
import hashlib
import os
import socket
import struct
import subprocess
import sys
import time
from pathlib import Path


# DNS resource-record TYPE codes and the IN class used when building messages.
TYPE_A = 1
TYPE_DNSKEY = 48
TYPE_RRSIG = 46
TYPE_NSEC = 47
TYPE_OPT = 41
CLASS_IN = 1
# DNSSEC algorithm number for RSA/SHA-256 (the "+008+" in the key file names).
ALG_RSASHA256 = 8
# DNSKEY flags 257 = Zone Key bit (256) + Secure Entry Point bit (1), i.e. a KSK.
DNSKEY_FLAGS_KSK = 257
# The DNSKEY protocol field is always 3.
DNSKEY_PROTOCOL = 3
# Default TTL in seconds for generated RRs and for the RRSIG original-TTL field.
TTL = 300


def canon_name(name):
    """Normalize *name* to lowercase with exactly one trailing dot.

    The root (``""``, ``"."`` or any all-dot string) normalizes to ``"."``.
    """
    bare = name.rstrip(".")
    if not bare:
        return "."
    return f"{bare.lower()}."


def label_count(name):
    """Return how many labels *name* has, not counting the root label.

    The root name yields 0; ``"a.b."`` yields 2.
    """
    stripped = canon_name(name).rstrip(".")
    return len(stripped.split(".")) if stripped else 0


def encode_name(name):
    """Encode *name* as an uncompressed, lowercased DNS wire-format name.

    Args:
        name: a domain name, with or without the trailing dot.

    Returns:
        The length-prefixed labels followed by the zero root octet
        (``b"\\x00"`` for the root itself).

    Raises:
        ValueError: if a label is empty (e.g. ``"a..b"``), longer than 63
            octets, not ASCII, or the encoded name exceeds 255 octets.
    """
    name = canon_name(name)
    if name == ".":
        return b"\x00"
    out = bytearray()
    for label in name.rstrip(".").split("."):
        raw = label.encode("ascii")
        # A zero-length label would be written as a 0x00 octet, which is the
        # root terminator, silently truncating the name on the wire.
        if not raw:
            raise ValueError(f"empty label in name: {name!r}")
        if len(raw) > 63:
            raise ValueError(f"label too long: {label!r}")
        out.append(len(raw))
        out.extend(raw)
    out.append(0)
    if len(out) > 255:
        raise ValueError(f"name too long ({len(out)} octets): {name!r}")
    return bytes(out)


def decode_name(packet, offset):
    """Decode a (possibly compressed) domain name from a DNS message.

    Args:
        packet: the complete DNS message; compression pointers are offsets
            from its first byte.
        offset: index of the name's first length octet.

    Returns:
        ``(name, end_offset)``: the name lowercased and dot-terminated
        (``"."`` for the root), and the offset just past the name at its
        original position, i.e. just past the first compression pointer if
        one was followed.

    Raises:
        ValueError: if the name runs past the end of *packet*, contains a
            compression-pointer loop, uses an extended/reserved label type,
            exceeds 255 octets on the wire, or has non-ASCII label bytes.
    """
    labels = []
    jumped = False
    end_offset = offset
    seen = set()
    wire_len = 1  # accounts for the terminating zero-length root label

    while True:
        if offset >= len(packet):
            raise ValueError("name exceeds packet")
        length = packet[offset]
        label_type = length & 0xC0
        if label_type == 0xC0:
            if offset + 1 >= len(packet):
                raise ValueError("truncated compression pointer")
            ptr = ((length & 0x3F) << 8) | packet[offset + 1]
            # Revisiting any pointer target means the chain would never end.
            if ptr in seen:
                raise ValueError("compression pointer loop")
            seen.add(ptr)
            if not jumped:
                end_offset = offset + 2
            offset = ptr
            jumped = True
            continue
        if label_type:
            # 0x40 (extended label) and 0x80 (reserved) prefixes are not plain
            # lengths; reading them as 64..191-byte labels would misparse.
            raise ValueError(f"unsupported label type 0x{label_type:02x}")
        if length == 0:
            if not jumped:
                end_offset = offset + 1
            break
        start = offset + 1
        stop = start + length
        if stop > len(packet):
            raise ValueError("truncated label")
        wire_len += 1 + length
        if wire_len > 255:
            raise ValueError("name longer than 255 octets")
        labels.append(packet[start:stop].decode("ascii").lower())
        offset = stop

    return (".".join(labels) + "." if labels else "."), end_offset


def parse_question(packet):
    """Parse the first question of a DNS query.

    Args:
        packet: the raw DNS query message.

    Returns:
        ``(qname, qtype, qclass, question)`` where *question* is the raw
        QNAME/QTYPE/QCLASS bytes exactly as they appeared in *packet*, ready to
        be echoed into a response.

    Raises:
        ValueError: if the packet is shorter than a DNS header, has QDCOUNT 0,
            or its first question is truncated or malformed.
    """
    if len(packet) < 12:
        raise ValueError("packet shorter than DNS header")
    (qdcount,) = struct.unpack_from("!H", packet, 4)
    if qdcount == 0:
        raise ValueError("query has no question")
    qname, offset = decode_name(packet, 12)
    if offset + 4 > len(packet):
        # Checked explicitly so callers see ValueError, not struct.error.
        raise ValueError("truncated question")
    qtype, qclass = struct.unpack_from("!HH", packet, offset)
    return qname, qtype, qclass, packet[12:offset + 4]


def rr(owner, rrtype, rrclass, ttl, rdata):
    """Serialize one resource record in DNS wire format.

    The owner name is written uncompressed via ``encode_name``; *rdata* is
    appended verbatim after its 16-bit length.
    """
    owner_wire = encode_name(owner)
    fixed_fields = struct.pack("!HHIH", rrtype, rrclass, ttl, len(rdata))
    return b"".join((owner_wire, fixed_fields, rdata))


def canonical_rr(owner, rrtype, rrclass, ttl, rdata):
    """Build the wire form of one RR as it is fed into an RRSIG signature.

    The owner name is lowercased and uncompressed by ``encode_name``. *rdata*
    is used exactly as given; any names embedded in it are not rewritten here.
    """
    parts = [encode_name(owner)]
    parts.append(struct.pack("!HHIH", rrtype, rrclass, ttl, len(rdata)))
    parts.append(rdata)
    return b"".join(parts)


def keytag(dnskey_rdata):
    acc = 0
    for i, byte in enumerate(dnskey_rdata):
        acc += byte << 8 if i % 2 == 0 else byte
    acc += (acc >> 16) & 0xFFFF
    return acc & 0xFFFF


def pkcs1_v1_5_sha256_sign(data, modulus, private_exponent):
    """Sign *data* with raw RSA using EMSA-PKCS1-v1_5 padding over SHA-256.

    Args:
        data: bytes to hash and sign.
        modulus: RSA modulus ``n``.
        private_exponent: RSA private exponent ``d``.

    Returns:
        The signature as a big-endian byte string as long as the modulus.

    Raises:
        ValueError: if the modulus leaves fewer than 8 bytes of 0xFF padding.
    """
    em_len = (modulus.bit_length() + 7) // 8
    # DER DigestInfo header for SHA-256, followed by the digest itself.
    digest_info = (
        bytes.fromhex("3031300d060960864801650304020105000420")
        + hashlib.sha256(data).digest()
    )
    pad_len = em_len - len(digest_info) - 3
    if pad_len < 8:
        raise ValueError("RSA key too small")
    encoded_message = b"".join((b"\x00\x01", b"\xff" * pad_len, b"\x00", digest_info))
    message_int = int.from_bytes(encoded_message, "big")
    return pow(message_int, private_exponent, modulus).to_bytes(em_len, "big")


class RootKey:
    """RSA/SHA-256 key-signing key for the root zone, kept in a state directory.

    On construction the key is loaded from the first ``K.+008+*.private`` file
    in *state_dir* (sorted by filename). If none exists, one is generated with
    ``dnssec-keygen``. The DNSKEY RDATA, key tag and SHA-256 DS digest are
    precomputed so the key can sign RRsets and produce a trust-anchor option.
    """

    def __init__(self, state_dir):
        self.state_dir = Path(state_dir)
        self.private_file = self._find_or_create_key()
        fields = self._parse_private(self.private_file)
        self.modulus = int.from_bytes(fields["Modulus"], "big")
        self.private_exponent = int.from_bytes(fields["PrivateExponent"], "big")
        self.public_exponent = fields["PublicExponent"]  # raw big-endian bytes
        self.public_key = self._dnskey_public_key(fields["PublicExponent"], fields["Modulus"])
        self.dnskey_rdata = struct.pack("!HBB", DNSKEY_FLAGS_KSK, DNSKEY_PROTOCOL, ALG_RSASHA256) + self.public_key
        self.keytag = keytag(self.dnskey_rdata)
        # DS digest input: owner name in wire form followed by the DNSKEY RDATA.
        self.ds_digest = hashlib.sha256(encode_name(".") + self.dnskey_rdata).hexdigest().upper()

    def _find_or_create_key(self):
        """Return the path of the root private-key file, generating one if needed.

        Raises:
            RuntimeError: if ``dnssec-keygen`` exits non-zero (its stderr is
                included in the message) or creates no matching file.
        """
        self.state_dir.mkdir(parents=True, exist_ok=True)
        pattern = "K.+008+*.private"
        existing = sorted(self.state_dir.glob(pattern))
        if existing:
            return existing[0]
        try:
            subprocess.run(
                ["dnssec-keygen", "-K", str(self.state_dir), "-a", "RSASHA256", "-b", "1024", "-n", "ZONE", "-f", "KSK", "."],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
        except subprocess.CalledProcessError as exc:
            # stderr was captured above, and CalledProcessError's own message
            # omits it, so surface it here or the failure reason is lost.
            detail = (exc.stderr or "").strip()
            raise RuntimeError(f"dnssec-keygen failed (exit {exc.returncode}): {detail}") from exc
        existing = sorted(self.state_dir.glob(pattern))
        if not existing:
            raise RuntimeError("dnssec-keygen did not create a root private key")
        return existing[0]

    @staticmethod
    def _parse_private(path):
        """Read the base64 RSA fields from a ``.private`` key file.

        Returns:
            dict mapping ``"Modulus"``, ``"PublicExponent"`` and
            ``"PrivateExponent"`` to their decoded bytes.

        Raises:
            RuntimeError: if any of those three fields is missing.
        """
        fields = {}
        for line in Path(path).read_text().splitlines():
            if ": " not in line:
                continue
            key, value = line.split(": ", 1)
            if key in {"Modulus", "PublicExponent", "PrivateExponent"}:
                fields[key] = base64.b64decode(value)
        missing = {"Modulus", "PublicExponent", "PrivateExponent"} - set(fields)
        if missing:
            raise RuntimeError(f"missing private key fields: {sorted(missing)}")
        return fields

    @staticmethod
    def _dnskey_public_key(exponent, modulus):
        """Build the DNSKEY public-key field: exponent length, exponent, modulus.

        Exponents shorter than 256 bytes use a one-byte length. Longer ones use
        a zero byte followed by a two-byte length.
        """
        if len(exponent) < 256:
            exp_len = bytes([len(exponent)])
        else:
            exp_len = b"\x00" + struct.pack("!H", len(exponent))
        return exp_len + exponent + modulus

    @staticmethod
    def _canonical_rdata_order(rdata_list):
        """Return the distinct RDATA values sorted as unsigned octet strings.

        Canonical RRset order compares RDATA alone. Sorting complete wire RRs
        instead would let the RDLENGTH field (which precedes RDATA) decide the
        order whenever RDATA lengths differ. Duplicate RDATA is kept only once.
        """
        return sorted({bytes(rdata) for rdata in rdata_list})

    def trust_anchor(self):
        """Return a ``--trust-anchor=`` option string for this key's root DS.

        Digest type 2 (SHA-256) matches how ``ds_digest`` is computed.
        """
        return f"--trust-anchor=.,{self.keytag},{ALG_RSASHA256},2,{self.ds_digest}"

    def sign_rrset(self, owner, rrtype, rdata_list, ttl=TTL, signer="."):
        """Return RRSIG RDATA covering the IN-class RRset described by the arguments.

        The signature is valid from one hour before now until 24 hours after
        now. The signed data is the RRSIG RDATA without its signature field,
        followed by every RR of the set in canonical form and canonical order.

        Args:
            owner: owner name of the RRset.
            rrtype: type code of the RRset being covered.
            rdata_list: RDATA bytes of each RR in the set.
            ttl: original TTL recorded in the RRSIG and used for the signed RRs.
            signer: signer's name; defaults to the root.
        """
        now = int(time.time())
        inception = now - 3600
        expiration = now + 86400
        owner = canon_name(owner)
        signer = canon_name(signer)
        rrset = b"".join(
            canonical_rr(owner, rrtype, CLASS_IN, ttl, rdata)
            for rdata in self._canonical_rdata_order(rdata_list)
        )
        signed_prefix = (
            struct.pack(
                "!HBBIIIH",
                rrtype,
                ALG_RSASHA256,
                label_count(owner),
                ttl,
                expiration,
                inception,
                self.keytag,
            )
            + encode_name(signer)
        )
        signature = pkcs1_v1_5_sha256_sign(signed_prefix + rrset, self.modulus, self.private_exponent)
        return signed_prefix + signature

    def malformed_rrsig_prefix_only(self, owner, rrtype, ttl=TTL):
        """Return only the fixed 18-byte RRSIG header, with no signer name or signature.

        The result is deliberately incomplete RRSIG RDATA; ``Upstream`` sends
        it in its ``crash.`` response.
        """
        now = int(time.time())
        return struct.pack(
            "!HBBIIIH",
            rrtype,
            ALG_RSASHA256,
            label_count(owner),
            ttl,
            now + 86400,
            now - 3600,
            self.keytag,
        )


class Upstream:
    """Minimal UDP DNS server that answers a fixed set of test queries.

    Answers, built with the root key stored in ``state_dir``:

    * ``. DNSKEY`` -> the root DNSKEY plus a valid RRSIG over it.
    * ``crash. A`` -> an A record, an RRSIG whose RDATA is only the fixed
      header (no signer name or signature), and an OPT record.
    * ``hang. A`` -> no answers. The authority section holds an NSEC record
      (type bitmap: window 1 with length 0) and a valid RRSIG over it.
    * anything else -> SERVFAIL.
    """

    def __init__(self, state_dir):
        self.key = RootKey(state_dir)

    def response_header(self, query, ancount, nscount, arcount, rcode=0):
        """Build the 12-byte response header for *query*.

        Copies the query ID and RD bit and sets QR and AA. QDCOUNT is always 1
        because every response echoes exactly one question section.
        """
        qid, flags = struct.unpack("!HH", query[:4])
        rd = flags & 0x0100
        resp_flags = 0x8400 | rd | (rcode & 0x000F)
        # Copying the query's QDCOUNT would be wrong when it is not 1: only
        # the first question is echoed back.
        return struct.pack("!HHHHHH", qid, resp_flags, 1, ancount, nscount, arcount)

    def dnskey_response(self, query, question):
        """Answer ``. DNSKEY`` with the root DNSKEY and its RRSIG."""
        dnskey = rr(".", TYPE_DNSKEY, CLASS_IN, TTL, self.key.dnskey_rdata)
        sig = rr(".", TYPE_RRSIG, CLASS_IN, TTL, self.key.sign_rrset(".", TYPE_DNSKEY, [self.key.dnskey_rdata]))
        return self.response_header(query, 2, 0, 0) + question + dnskey + sig

    def crash_response(self, query, question):
        """Answer ``crash. A`` with an A record plus a deliberately truncated RRSIG.

        An OPT record advertising a 1232-byte UDP payload size is added.
        """
        a_rdata = socket.inet_aton("192.0.2.123")
        a_rr = rr("crash.", TYPE_A, CLASS_IN, TTL, a_rdata)
        malformed_sig = rr("crash.", TYPE_RRSIG, CLASS_IN, TTL, self.key.malformed_rrsig_prefix_only("crash.", TYPE_A))
        opt_rr = encode_name(".") + struct.pack("!HHIH", TYPE_OPT, 1232, 0, 0)
        return self.response_header(query, 2, 0, 1) + question + a_rr + malformed_sig + opt_rr

    def hang_response(self, query, question):
        """Answer ``hang. A`` with a signed NSEC whose type bitmap is empty.

        The bitmap is window 1 with a length of 0.
        """
        nsec_rdata = encode_name("zzz.") + b"\x01\x00"
        nsec_rr = rr("hang.", TYPE_NSEC, CLASS_IN, TTL, nsec_rdata)
        sig_rr = rr("hang.", TYPE_RRSIG, CLASS_IN, TTL, self.key.sign_rrset("hang.", TYPE_NSEC, [nsec_rdata]))
        return self.response_header(query, 0, 2, 0) + question + nsec_rr + sig_rr

    def servfail_response(self, query, question):
        """Answer with RCODE 2 (SERVFAIL) and only the echoed question."""
        return self.response_header(query, 0, 0, 0, rcode=2) + question

    def build_response(self, query):
        """Log the query's question and dispatch to the matching responder."""
        qname, qtype, qclass, question = parse_question(query)
        print(f"query {qname} type={qtype} class={qclass}", flush=True)
        if qclass == CLASS_IN and qname == "." and qtype == TYPE_DNSKEY:
            return self.dnskey_response(query, question)
        if qclass == CLASS_IN and qname == "crash." and qtype == TYPE_A:
            return self.crash_response(query, question)
        if qclass == CLASS_IN and qname == "hang." and qtype == TYPE_A:
            return self.hang_response(query, question)
        return self.servfail_response(query, question)

    def serve(self, host, port):
        """Serve UDP queries on (*host*, *port*) forever.

        The listening address and the trust-anchor option are printed at
        startup. The socket is closed when the loop exits by any exception
        (e.g. KeyboardInterrupt).
        """
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.bind((host, port))
            print(f"fake upstream listening on {host}:{port}", flush=True)
            print(self.key.trust_anchor(), flush=True)
            while True:
                query, addr = sock.recvfrom(4096)
                try:
                    response = self.build_response(query)
                except Exception as exc:
                    # A bad query is logged and dropped; the server keeps running.
                    print(f"error building response: {exc}", file=sys.stderr, flush=True)
                    continue
                sock.sendto(response, addr)


def init_state(state_dir):
    """Create or reuse the root key in *state_dir* and record its trust anchor.

    The trust-anchor option is written to ``trust-anchor.txt`` inside
    *state_dir*. A short summary is printed to stdout.
    """
    key = RootKey(state_dir)
    anchor = key.trust_anchor()
    trust_path = Path(state_dir) / "trust-anchor.txt"
    trust_path.write_text(f"{anchor}\n")
    summary = (
        f"state_dir={state_dir}",
        f"keytag={key.keytag}",
        f"trust_anchor={anchor}",
        f"wrote {trust_path}",
    )
    for line in summary:
        print(line)


def main():
    """Parse command-line flags, then initialize state and/or run the server.

    At least one of ``--init`` and ``--serve`` is required. When both are
    given, initialization runs before serving.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--state-dir", default="/private/tmp/dnsmasq-live-state")
    parser.add_argument("--init", action="store_true")
    parser.add_argument("--serve", action="store_true")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=5300)
    options = parser.parse_args()

    if not (options.init or options.serve):
        parser.error("use --init and/or --serve")

    if options.init:
        init_state(options.state_dir)
    if options.serve:
        Upstream(options.state_dir).serve(options.host, options.port)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()

