#!/usr/bin/python3
"""AptCli::Hooks::Install pre-prompt hook for getpagespeed-extras-release.

Mirrors the RPM-side dnf-plugin.py UX on Debian/Ubuntu. Fires AT THE
RESOLUTION STAGE, before any download — so an unsubscribed customer
sees the friendly ANSI subscribe-link nag before they hit a 403 at the
server's gate (R3). The earlier Pre-Install-Pkgs hook ran post-download,
which would defeat that promise once the gate ships.

Protocol (apt 1.6+, JSON-RPC 2.0 on $APT_HOOK_SOCKET, messages
separated by `\\n\\n`):
  apt -> hook: org.debian.apt.hooks.hello (request, has id)
  hook -> apt: result with {"version": "0.1"}
  apt -> hook: org.debian.apt.hooks.install.pre-prompt (notification)
  apt -> hook: org.debian.apt.hooks.install.package-list (notification)
  apt -> hook: org.debian.apt.hooks.install.statistics (notification)
  apt -> hook: org.debian.apt.hooks.bye (request, has id)
  hook -> apt: result with null

The hook decides on pre-prompt. If any package to be installed, upgraded,
downgraded or reinstalled (other than getpagespeed-extras-release itself)
has an origin pointing at extras.getpagespeed.com (or the GetPageSpeed
Origin string), we pre-flight https://www.getpagespeed.com/ip2.php with
those package names; if the response body contains 'link:', the body
(which carries the ANSI subscribe-link block from ip.php) is written
verbatim to stderr, framed by '=' rules after a header listing the
affected packages, and the hook exits 1 — apt's RunJsonHook detects the
abort and bails the transaction with
`E: Failure running hook /usr/lib/getpagespeed-extras/subscribe-check`
BEFORE any download is attempted. Network failure -> silent pass (apt
proceeds; same posture as dnf-plugin.py).
"""

import json
import os
import socket
import sys
from http.client import HTTPException
from urllib.error import HTTPError, URLError
from urllib.parse import quote_plus
from urllib.request import Request, urlopen

# How the GetPageSpeed repository shows up in apt's hook payload:
# gps_origin() compares each origin's "site" field with GPS_SITE and its
# "origin" field with GPS_ORIGIN; either match marks the package as ours.
GPS_SITE = "extras.getpagespeed.com"
GPS_ORIGIN = "GetPageSpeed"
# Subscription pre-flight endpoint queried by ip2_nag(); a response body
# containing "link:" means the nag must be shown.
IP2_URL = "https://www.getpagespeed.com/ip2.php"
USER_AGENT = "getpagespeed-apt-plugin"
# HTTP timeout in seconds for the pre-flight request (passed to urlopen()).
TIMEOUT = 3

# Self-skip: the release package itself must always be installable
# (chicken-and-egg). Mirror of dnf-plugin.py's skip-list pattern.
SELF_PKG = "getpagespeed-extras-release"

# Package "mode" values paid_packages() checks; entries with any other
# mode are skipped.
INSTALL_MODES = ("install", "upgrade", "downgrade", "reinstall")
# Separator between JSON-RPC messages on the apt hook socket.
MSG_SEP = b"\n\n"


class Wire:
    """Message framing for apt's JSON-RPC hook socket.

    Messages are JSON documents separated by MSG_SEP (a blank line).
    ``recv`` returns one decoded message at a time; ``send`` writes one.
    """

    def __init__(self, sock):
        self.sock = sock
        # bytearray so appending a chunk extends in place instead of
        # copying the whole buffer (large package lists span many chunks).
        self.buf = bytearray()

    def recv(self):
        """Return the next decoded message, or None if the peer closed first.

        Any bytes after the separator stay buffered for the next call.
        Raises ValueError if a framed message is not valid JSON.
        """
        start = 0
        while True:
            idx = self.buf.find(MSG_SEP, start)
            if idx != -1:
                break
            # Only a separator straddling the old/new data boundary can be
            # new, so resume the search just before the end of what we have.
            start = max(0, len(self.buf) - len(MSG_SEP) + 1)
            chunk = self.sock.recv(8192)
            if not chunk:
                return None
            self.buf += chunk
        raw = bytes(self.buf[:idx])
        del self.buf[:idx + len(MSG_SEP)]
        return json.loads(raw)

    def send(self, msg):
        """Serialize *msg* as JSON and write it followed by MSG_SEP."""
        self.sock.sendall(json.dumps(msg).encode("utf-8") + MSG_SEP)


def gps_origin(versions):
    """Tell whether a package version comes from the GetPageSpeed repo.

    *versions* is the per-package ``versions`` mapping from the hook
    payload. Only its ``install`` and ``candidate`` entries are inspected;
    an origin matches when its ``site`` equals GPS_SITE or its ``origin``
    equals GPS_ORIGIN. Non-dict values at any level are ignored.
    """
    if not isinstance(versions, dict):
        return False

    def _is_gps(origin):
        return isinstance(origin, dict) and (
            origin.get("site") == GPS_SITE or origin.get("origin") == GPS_ORIGIN
        )

    # Lazily walk install first, then candidate, stopping at the first hit.
    wanted = (versions.get(slot) for slot in ("install", "candidate"))
    return any(
        _is_gps(origin)
        for ver in wanted
        if isinstance(ver, dict)
        for origin in (ver.get("origins") or [])
    )


def paid_packages(params):
    """Return names of GetPageSpeed-origin packages apt is about to change.

    *params* is the ``params`` object of an install hook message; its
    ``packages`` list is scanned. A package counts when:

    * its ``mode`` is one of INSTALL_MODES,
    * gps_origin() says its install/candidate version is ours, and
    * it is not SELF_PKG, which must always stay installable.

    Each name is returned once, in first-seen order. apt may list the same
    name more than once (e.g. for different architectures), and a repeat
    would otherwise be shown and queried twice. Entries without a non-empty
    string name are skipped, because they can be neither reported nor URL-
    encoded for the pre-flight query. A non-dict *params* yields ``[]``.
    """
    if not isinstance(params, dict):
        return []
    paid = []
    seen = set()
    for pkg in params.get("packages") or []:
        if not isinstance(pkg, dict):
            continue
        name = pkg.get("name")
        if not isinstance(name, str) or not name:
            continue
        if name == SELF_PKG or name in seen:
            continue
        if pkg.get("mode") not in INSTALL_MODES:
            continue
        if not gps_origin(pkg.get("versions")):
            continue
        seen.add(name)
        paid.append(name)
    return paid


def ip2_nag(paid):
    """Ask the GetPageSpeed server whether *paid* packages need a subscription.

    The package names are sent as repeated ``packages[]`` query parameters
    to IP2_URL. Returns the UTF-8-decoded response body (undecodable bytes
    replaced) when it contains ``link:``; otherwise returns None.

    Any transport failure also yields None, so the hook fails open and apt
    proceeds. This covers an HTTP error status, a DNS or connection error, a
    timeout, and a malformed or truncated HTTP response.
    """
    query = "&".join("packages[]=" + quote_plus(p) for p in paid)
    url = "{}?{}".format(IP2_URL, query)
    req = Request(url, headers={"User-Agent": USER_AGENT})
    try:
        # The context manager closes the connection even if read() fails.
        with urlopen(req, timeout=TIMEOUT) as resp:
            body = resp.read()
    except (HTTPError, URLError, OSError, HTTPException):
        # HTTPException (e.g. IncompleteRead) is not an OSError subclass.
        # Catch it here so the caller can still finish the apt session
        # instead of dying mid-protocol.
        return None
    text = body.decode("utf-8", errors="replace")
    if "link:" not in text:
        return None
    return text


def emit_nag(paid, text):
    """Write the subscription nag to stderr and flush it.

    Layout, top to bottom:

    1. a rule of 71 ``=`` characters;
    2. a fixed header line;
    3. one `` - name`` line per entry of *paid*;
    4. the server-supplied *text* unchanged, newline-terminated if it was not already;
    5. a closing rule.
    """
    rule = "=" * 71 + "\n"
    pieces = [rule, "Premium packages in your transaction require a subscription:\n"]
    pieces.extend(" - {}\n".format(name) for name in paid)
    pieces.append(text if text.endswith("\n") else text + "\n")
    pieces.append(rule)
    out = sys.stderr
    for piece in pieces:
        out.write(piece)
    out.flush()


def main():
    """Serve one apt JSON-RPC hook session and return the exit status.

    Returns 0, letting apt carry on, in these cases:

    * APT_HOOK_SOCKET is unset or not an integer;
    * the first message is not the expected hello;
    * apt says bye;
    * the socket closes.

    Returns 1 right after printing the nag for a pre-prompt message that
    names subscription-gated packages. Per the module docstring, that
    non-zero status is what makes apt abort the transaction.
    """
    raw_fd = os.environ.get("APT_HOOK_SOCKET")
    if raw_fd is None:
        return 0
    try:
        fd = int(raw_fd)
    except ValueError:
        return 0
    conn = Wire(socket.socket(fileno=fd))

    def reply(request, result):
        # JSON-RPC response to a request (hello/bye) that carried an id.
        conn.send({
            "jsonrpc": "2.0",
            "id": request.get("id"),
            "result": result,
        })

    greeting = conn.recv()
    if not greeting or greeting.get("method") != "org.debian.apt.hooks.hello":
        return 0
    reply(greeting, {"version": "0.1"})

    # iter(callable, None) stops once recv() reports the peer has gone.
    for msg in iter(conn.recv, None):
        method = msg.get("method", "")
        if method == "org.debian.apt.hooks.bye":
            reply(msg, None)
            return 0
        if method != "org.debian.apt.hooks.install.pre-prompt":
            # Other notifications need neither action nor a reply.
            continue
        paid = paid_packages(msg.get("params") or {})
        if not paid:
            continue
        text = ip2_nag(paid)
        if text is None:
            continue
        emit_nag(paid, text)
        return 1
    return 0


if __name__ == "__main__":
    # Fail open: any unexpected error lets apt proceed (status 0), while
    # main()'s own status (including 1 = abort) is passed through as-is.
    # SystemExit is not an Exception subclass, so it is never swallowed here.
    try:
        status = main()
    except Exception:
        status = 0
    sys.exit(status)
