"""Reusable signing primitives for vote-mcp management HTTP requests.

This module is client-side helper code for constructing RFC 9421 signatures and
RFC 9530 Content-Digest headers in a deterministic way.
"""

from __future__ import annotations

import base64
import hashlib
import json
import re
import secrets
import time
from typing import Iterable

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey

MANAGEMENT_SIGNATURE_SCHEME = "ed25519_statement_v1"
_ALLOWED_COMPONENTS = frozenset({"@method", "@path", "@query", "content-digest", "idempotency-key"})
_KEY_ID_RE = re.compile(r"^sha256:[0-9a-f]{64}$")
_NONCE_RE = re.compile(r"^[A-Za-z0-9._~-]{8,128}$")
_SEED_HEX_RE = re.compile(r"^[0-9a-f]{64}$")
_IDEMPOTENCY_KEY_RE = re.compile(r"^[A-Za-z0-9._~-]{20,200}$")


def canonical_json_bytes(payload: object) -> bytes:
    """Serialize *payload* to compact, key-sorted JSON encoded as UTF-8.

    Object keys are sorted and no whitespace is emitted, so equal payloads
    always produce identical bytes. These bytes can be used both as the
    request body and as the input to `Content-Digest` / signing. Non-ASCII
    text is kept as raw UTF-8 rather than `\\uXXXX` escapes.

    Raises:
        ValueError: if the payload contains NaN or infinite floats.
        TypeError: if the payload contains values `json` cannot serialize.
    """
    compact_text = json.dumps(
        payload,
        ensure_ascii=False,
        allow_nan=False,
        sort_keys=True,
        separators=(",", ":"),
    )
    return compact_text.encode("utf-8")


def content_digest_sha256_header(body_bytes: bytes | bytearray | memoryview) -> str:
    """Build RFC 9530 `Content-Digest` header value for exact body bytes."""
    if not isinstance(body_bytes, (bytes, bytearray, memoryview)):
        raise ValueError("body_bytes must be bytes-like.")
    digest = hashlib.sha256(bytes(body_bytes)).digest()
    encoded = base64.b64encode(digest).decode("ascii")
    return f"sha-256=:{encoded}:"


def private_key_from_seed_hex(seed_hex: str) -> Ed25519PrivateKey:
    """Rebuild an Ed25519 private key from its hex-encoded 32-byte seed.

    Args:
        seed_hex: exactly 64 lowercase hex characters.

    Raises:
        ValueError: if *seed_hex* is not a string of that exact shape.
    """
    seed_is_well_formed = isinstance(seed_hex, str) and _SEED_HEX_RE.fullmatch(seed_hex) is not None
    if not seed_is_well_formed:
        raise ValueError("seed_hex must be 32-byte lowercase hex (64 chars).")
    seed = bytes.fromhex(seed_hex)
    return Ed25519PrivateKey.from_private_bytes(seed)


def is_valid_idempotency_key(idempotency_key: str) -> bool:
    """Report whether *idempotency_key* is a string matching the API pattern.

    Non-string inputs are reported as invalid rather than raising.
    """
    return isinstance(idempotency_key, str) and _IDEMPOTENCY_KEY_RE.fullmatch(idempotency_key) is not None


def require_valid_idempotency_key(idempotency_key: str) -> str:
    """Return *idempotency_key* unchanged if valid, so it can be used inline.

    Raises:
        ValueError: if the key fails `is_valid_idempotency_key`.
    """
    if is_valid_idempotency_key(idempotency_key):
        return idempotency_key
    raise ValueError("idempotency_key must match `[A-Za-z0-9._~-]{20,200}`.")


def _public_key_raw_bytes(private_key: Ed25519PrivateKey) -> bytes:
    """Return the raw (unencoded, unwrapped) public-key bytes for *private_key*."""
    public_key = private_key.public_key()
    return public_key.public_bytes(
        format=serialization.PublicFormat.Raw,
        encoding=serialization.Encoding.Raw,
    )


def public_key_base64url_from_private_key(private_key: Ed25519PrivateKey) -> str:
    """Return the raw public key as URL-safe base64 with `=` padding stripped."""
    raw_public = _public_key_raw_bytes(private_key)
    padded = base64.urlsafe_b64encode(raw_public).decode("ascii")
    return padded.rstrip("=")


def keyid_sha256_from_private_key(private_key: Ed25519PrivateKey) -> str:
    """Return the `keyid`: "sha256:" plus the hex SHA-256 of the raw public key."""
    fingerprint = hashlib.sha256(_public_key_raw_bytes(private_key)).hexdigest()
    return f"sha256:{fingerprint}"


def owner_declaration_from_seed_hex(seed_hex: str) -> dict[str, str]:
    """Return the `owner` object for a create payload, derived from *seed_hex*.

    The result carries the signature scheme name and the unpadded base64url
    public key. Raises ValueError if the seed is malformed.
    """
    encoded_public_key = public_key_base64url_from_private_key(private_key_from_seed_hex(seed_hex))
    return {"scheme": MANAGEMENT_SIGNATURE_SCHEME, "public_key": encoded_public_key}


def build_signature_input_value(
    *,
    covered_components: Iterable[str],
    keyid: str,
    created: int,
    expires: int,
    nonce: str,
) -> str:
    """Return the `sig1` Signature-Input parameters (without the `sig1=` label).

    The result has the shape
    `("<c1>" "<c2>" ...);created=<int>;expires=<int>;keyid="<keyid>";nonce="<nonce>"`
    with components in the order supplied.

    Raises:
        ValueError: if keyid or nonce are malformed, the timestamps are not
            integers or `expires` is not after `created`, or the component list
            is empty, contains an unsupported name, or repeats a name.
    """
    keyid_ok = isinstance(keyid, str) and _KEY_ID_RE.fullmatch(keyid) is not None
    if not keyid_ok:
        raise ValueError("keyid must be `sha256:<64 lowercase hex>`.")
    if not (isinstance(created, int) and isinstance(expires, int)):
        raise ValueError("created and expires must be integer Unix seconds.")
    if created >= expires:
        raise ValueError("expires must be greater than created.")
    nonce_ok = isinstance(nonce, str) and _NONCE_RE.fullmatch(nonce) is not None
    if not nonce_ok:
        raise ValueError("nonce must match `[A-Za-z0-9._~-]{8,128}`.")

    # Preserve caller order (it defines the signature-base line order) while
    # rejecting unknown or repeated component names.
    ordered: list[str] = []
    seen: set[str] = set()
    for name in covered_components:
        if name not in _ALLOWED_COMPONENTS:
            raise ValueError(f"unsupported covered component: {name}")
        if name in seen:
            raise ValueError(f"duplicate covered component: {name}")
        seen.add(name)
        ordered.append(name)

    if not ordered:
        raise ValueError("covered_components must not be empty.")

    quoted_list = " ".join(f'"{name}"' for name in ordered)
    params = f'created={created};expires={expires};keyid="{keyid}";nonce="{nonce}"'
    return f"({quoted_list});{params}"


def _format_component_value(
    *,
    component: str,
    method: str,
    path: str,
    query_string: str,
    idempotency_key: str | None,
    content_digest_header: str | None,
) -> str:
    if component == "@method":
        return method.lower()
    if component == "@path":
        return path
    if component == "@query":
        return f"?{query_string}" if query_string else "?"
    if component == "idempotency-key":
        if idempotency_key is None:
            raise ValueError("idempotency-key covered component requires idempotency_key value.")
        return idempotency_key
    if component == "content-digest":
        if content_digest_header is None:
            raise ValueError("content-digest covered component requires content_digest_header value.")
        return content_digest_header
    raise ValueError(f"unsupported covered component: {component}")


def build_signature_base_bytes(
    *,
    method: str,
    path: str,
    query_string: str,
    covered_components: Iterable[str],
    signature_input_value: str,
    idempotency_key: str | None = None,
    content_digest_header: str | None = None,
) -> bytes:
    """Build canonical signature-base bytes, including `@signature-params` line."""
    if not isinstance(method, str) or not method:
        raise ValueError("method must be a non-empty string.")
    if not isinstance(path, str) or not path.startswith("/"):
        raise ValueError("path must be an absolute route path.")
    if not isinstance(query_string, str):
        raise ValueError("query_string must be a string.")
    if "?" in query_string:
        raise ValueError("query_string must not include '?'.")
    if not isinstance(signature_input_value, str) or not signature_input_value:
        raise ValueError("signature_input_value must be non-empty.")

    lines: list[str] = []
    for component in covered_components:
        value = _format_component_value(
            component=component,
            method=method,
            path=path,
            query_string=query_string,
            idempotency_key=idempotency_key,
            content_digest_header=content_digest_header,
        )
        lines.append(f'"{component}": {value}')
    lines.append(f'"@signature-params": {signature_input_value}')
    return "\n".join(lines).encode("utf-8")


def build_management_headers(
    *,
    private_key_seed_hex: str,
    method: str,
    path: str,
    body: bytes | bytearray | memoryview,
    query_string: str = "",
    idempotency_key: str | None = None,
    include_content_digest: bool | None = None,
    created: int | None = None,
    expires: int | None = None,
    expires_in_seconds: int = 300,
    nonce: str | None = None,
) -> dict[str, str]:
    """Assemble management headers (`Signature-Input`, `Signature`, optional digest/idempotency)."""
    private_key = private_key_from_seed_hex(private_key_seed_hex)
    body_bytes = bytes(body)

    if not isinstance(method, str) or not method:
        raise ValueError("method must be a non-empty string.")
    method_upper = method.upper()

    use_content_digest = include_content_digest
    if use_content_digest is None:
        use_content_digest = method_upper != "GET"

    resolved_created = int(time.time()) if created is None else created
    if not isinstance(resolved_created, int):
        raise ValueError("created must be integer Unix seconds.")

    resolved_expires: int
    if expires is None:
        if not isinstance(expires_in_seconds, int) or expires_in_seconds <= 0:
            raise ValueError("expires_in_seconds must be a positive integer.")
        resolved_expires = resolved_created + expires_in_seconds
    else:
        resolved_expires = expires

    resolved_nonce = nonce or f"nonce_{secrets.token_hex(8)}"

    covered_components: list[str] = ["@method", "@path", "@query"]
    content_digest_header = None
    if use_content_digest:
        content_digest_header = content_digest_sha256_header(body_bytes)
        covered_components.append("content-digest")
    if idempotency_key is not None:
        covered_components.append("idempotency-key")
        idempotency_key = require_valid_idempotency_key(idempotency_key)

    keyid = keyid_sha256_from_private_key(private_key)
    signature_input_value = build_signature_input_value(
        covered_components=covered_components,
        keyid=keyid,
        created=resolved_created,
        expires=resolved_expires,
        nonce=resolved_nonce,
    )
    signature_base = build_signature_base_bytes(
        method=method_upper,
        path=path,
        query_string=query_string,
        covered_components=covered_components,
        signature_input_value=signature_input_value,
        idempotency_key=idempotency_key,
        content_digest_header=content_digest_header,
    )
    signature = base64.b64encode(private_key.sign(signature_base)).decode("ascii")

    headers: dict[str, str] = {
        "Signature-Input": f"sig1={signature_input_value}",
        "Signature": f"sig1=:{signature}:",
    }
    if content_digest_header is not None:
        headers["Content-Digest"] = content_digest_header
    if idempotency_key is not None:
        headers["Idempotency-Key"] = idempotency_key
    return headers
