diff --git a/.gitignore b/.gitignore index 9bb50ff..a404f07 100644 --- a/.gitignore +++ b/.gitignore @@ -20,4 +20,5 @@ src/certapi.egg-info !/.vscode/launch.json .venv .env -test_db_temp/ \ No newline at end of file +test_db_temp/ +*.key diff --git a/src/certapi/__init__.py b/src/certapi/__init__.py index 7909e06..426508d 100644 --- a/src/certapi/__init__.py +++ b/src/certapi/__init__.py @@ -22,3 +22,4 @@ ) from .issuers import CertIssuer, SelfCertIssuer, AcmeCertIssuer from .client import CertManagerClient +from .domain_batching import create_safe_domain_batches diff --git a/src/certapi/cli.py b/src/certapi/cli.py index d717f7a..431db6f 100644 --- a/src/certapi/cli.py +++ b/src/certapi/cli.py @@ -111,7 +111,7 @@ def _ensure_port_80_available() -> None: sys.exit(1) -def obtain_certificate(domains: List[str], api_key: Optional[str] = None): +def obtain_certificate(domains: List[str], api_key: Optional[str] = None, keystore_path: str = "/etc/ssl"): challenge_solver = None server = None @@ -124,7 +124,6 @@ def obtain_certificate(domains: List[str], api_key: Optional[str] = None): print("Starting HTTP challenge server on port 80...") server, _ = _start_http_challenge_server(challenge_solver, port=80) - keystore_path = "/etc/ssl" key_store = FileSystemKeyStore(keystore_path) cert_issuer = AcmeCertIssuer.with_keystore( key_store, @@ -140,7 +139,8 @@ def obtain_certificate(domains: List[str], api_key: Optional[str] = None): cert_chain = certs_from_pem(cert.encode("utf-8")) leaf_cert = cert_chain[0] if cert_chain else None - expiry = leaf_cert.not_valid_after.isoformat() if leaf_cert else "unknown" + expiry_dt = leaf_cert.not_valid_after_utc if leaf_cert else None + expiry = expiry_dt.isoformat() if expiry_dt else "unknown" key_path = os.path.join(key_store.keys_dir, f"{key_name}.key") cert_path = os.path.join(key_store.certs_dir, f"{key_name}.crt") @@ -219,14 +219,60 @@ def verify_environment(domains: List[str], api_key: Optional[str] = None) -> Non def main(): - parser = argparse.ArgumentParser(prog="certapi") - subparsers = parser.add_subparsers(dest="command") - - verify_parser = subparsers.add_parser("verify", help="Verify certapi installation and environment.") - verify_parser.add_argument("domains", nargs="*", help="Optional domain(s) to verify.") + class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter): + pass + + parser = argparse.ArgumentParser( + prog="certapi", + formatter_class=HelpFormatter, + description=( + "Issue and verify ACME certificates.\n\n" + "Challenge mode selection:\n" + "- If CLOUDFLARE_API_KEY or CLOUDFLARE_API_TOKEN is set, DNS-01 is used.\n" + "- Otherwise HTTP-01 is used (requires binding to port 80)." + ), + epilog=( + "Cloudflare credentials are read from environment variables: " + "CLOUDFLARE_API_KEY or CLOUDFLARE_API_TOKEN.\n\n" + "Examples:\n" + " certapi verify example.com\n" + " certapi obtain --path ./certs example.com www.example.com\n" + " CLOUDFLARE_API_TOKEN= certapi obtain --path . example.com" + ), + ) + subparsers = parser.add_subparsers( + dest="command", + title="commands", + metavar="{verify,obtain}", + ) - obtain_parser = subparsers.add_parser("obtain", help="Obtain certificate for domains.") - obtain_parser.add_argument("domains", nargs="+", help="Domain(s) to obtain certificate for.") + verify_parser = subparsers.add_parser( + "verify", + help="Verify certapi installation and challenge readiness.", + description=( + "Verify environment readiness.\n" + "- With Cloudflare credentials: checks DNS zone ownership for provided domains.\n" + "- Without credentials: checks local HTTP challenge serving/routing for domains." + ), + formatter_class=HelpFormatter, + ) + verify_parser.add_argument("domains", nargs="*", help="Optional domain(s) to validate.") + + obtain_parser = subparsers.add_parser( + "obtain", + help="Issue a certificate for one or more domains.", + description=( + "Obtain a certificate and write key/cert files to the selected path.\n" + "When Cloudflare credentials are set, DNS-01 challenge is used." + ), + formatter_class=HelpFormatter, + ) + obtain_parser.add_argument("domains", nargs="+", help="Domain(s) to include in the certificate.") + obtain_parser.add_argument( + "--path", + default="/etc/ssl", + help="Directory where key/cert files will be written.", + ) args = parser.parse_args() @@ -236,7 +282,7 @@ def main(): sys.exit(0) elif args.command == "obtain": api_key = _resolve_cloudflare_api_key() - obtain_certificate(args.domains, api_key=api_key) + obtain_certificate(args.domains, api_key=api_key, keystore_path=args.path) else: parser.print_help() sys.exit(1) diff --git a/src/certapi/client/cert_manager_client.py b/src/certapi/client/cert_manager_client.py index ff8d9ab..dd3d56c 100644 --- a/src/certapi/client/cert_manager_client.py +++ b/src/certapi/client/cert_manager_client.py @@ -49,6 +49,7 @@ def issue_certificate( user_id: Optional[str] = None, renew_threshold_days: Optional[int] = None, skip_failing: bool = True, + batch_domains: bool = False, ) -> CertificateResponse: params = { "hostname": hosts if isinstance(hosts, str) else hosts, @@ -68,6 +69,7 @@ def issue_certificate( if renew_threshold_days is not None: params["renew_threshold_days"] = renew_threshold_days params["skip_failing"] = skip_failing + params["batch_domains"] = batch_domains data = self._get("/api/obtain", params=params) res = CertificateResponse.from_json(data) diff --git a/src/certapi/domain_batching.py b/src/certapi/domain_batching.py new file mode 100644 index 0000000..f04bb1c --- /dev/null +++ b/src/certapi/domain_batching.py @@ -0,0 +1,82 @@ +def _normalize_domain(domain: str) -> str: + domain = domain.strip().lower().rstrip(".") + if domain.startswith("*."): + domain = domain[2:] + return domain + + +def _split_labels(domain: str) -> list[str]: + normalized = _normalize_domain(domain) + return [label for label in normalized.split(".") if label] + + +def would_trigger(labels: list[str], blocked: set[str]) -> bool: + """ + Boulder-style recursive on-demand trigger check used by tests. + """ + if len(labels) < 4: + return False + + blocked = {label.lower() for label in blocked} + + consecutive_blocked = 0 + prev_label: str | None = None + for label in labels: + label = label.lower() + is_blocked = label in blocked + if is_blocked: + consecutive_blocked += 1 + if prev_label is not None and prev_label == label: + return True + if consecutive_blocked >= 3: + return True + else: + consecutive_blocked = 0 + prev_label = label + return False + + +def split_domain_to_safe_groups(domain: str, blocked_labels: list[str] | None = None) -> list[list[str]]: + """ + Split labels into contiguous chunks of at most 3 labels. + This mirrors the "3 labels are always safe for this specific LE heuristic" + invariant, but these label groups are not used directly as issuance domains. + """ + labels = _split_labels(domain) + n = len(labels) + if n == 0: + return [] + return [labels[i : i + 3] for i in range(0, n, 3)] + + +def create_safe_domain_batches(domains: list[str], blocked_labels: list[str] | None = None) -> list[list[str]]: + """ + Create compact issuance batches without any blocked-label list assumptions. + + Rules: + - Domains with <= 3 labels are grouped together into one compact batch. + - Longer domains are split into safe label groups and emitted as singleton batches. + + This avoids guessed label lists and keeps behavior deterministic. + """ + compact_batch: list[str] = [] + singleton_batches: list[list[str]] = [] + seen: set[str] = set() + + for domain in domains: + normalized = _normalize_domain(domain) + if not normalized or normalized in seen: + continue + seen.add(normalized) + + labels = [label for label in normalized.split(".") if label] + if len(labels) <= 3: + compact_batch.append(normalized) + continue + + for i in range(0, len(labels), 3): + singleton_batches.append([".".join(labels[i : i + 3])]) + + if compact_batch: + return [compact_batch, *singleton_batches] + return singleton_batches diff --git a/src/certapi/manager/acme_cert_manager.py b/src/certapi/manager/acme_cert_manager.py index 7d227f8..b0980e6 100644 --- a/src/certapi/manager/acme_cert_manager.py +++ b/src/certapi/manager/acme_cert_manager.py @@ -1,5 +1,5 @@ import time -from typing import List, Literal, Optional, Tuple, Union, Dict +from typing import Callable, List, Literal, Optional, Tuple, Union, Dict from datetime import datetime, timezone, timedelta from certapi import crypto @@ -11,6 +11,7 @@ from ..keystore.KeyStore import KeyStore from cryptography.x509 import Certificate, CertificateSigningRequest from ..crypto import Key, certs_to_pem, cert_to_pem, get_csr_hostnames +from ..domain_batching import create_safe_domain_batches DEFAULT_RENEW_THRESHOLD_DAYS = 62 @@ -75,8 +76,59 @@ def issue_certificate( user_id: Optional[str] = None, renew_threshold_days: Optional[int] = None, ) -> CertificateResponse: + return self._issue_certificate_internal( + hosts=hosts, + key_type=key_type, + expiry_days=expiry_days, + country=country, + state=state, + locality=locality, + organization=organization, + user_id=user_id, + renew_threshold_days=renew_threshold_days, + batch_generator=None, + ) + + def issue_certificate_in_batches( + self, + hosts: Union[str, List[str]], + key_type: Literal["rsa", "ecdsa", "ed25519"] = "ecdsa", + expiry_days: int = 90, + country: Optional[str] = None, + state: Optional[str] = None, + locality: Optional[str] = None, + organization: Optional[str] = None, + user_id: Optional[str] = None, + renew_threshold_days: Optional[int] = None, + batch_generator: Callable[[List[str]], List[List[str]]] = create_safe_domain_batches, + ) -> CertificateResponse: + return self._issue_certificate_internal( + hosts=hosts, + key_type=key_type, + expiry_days=expiry_days, + country=country, + state=state, + locality=locality, + organization=organization, + user_id=user_id, + renew_threshold_days=renew_threshold_days, + batch_generator=batch_generator, + ) - if type(hosts) == str: + def _issue_certificate_internal( + self, + hosts: Union[str, List[str]], + key_type: Literal["rsa", "ecdsa", "ed25519"] = "ecdsa", + expiry_days: int = 90, + country: Optional[str] = None, + state: Optional[str] = None, + locality: Optional[str] = None, + organization: Optional[str] = None, + user_id: Optional[str] = None, + renew_threshold_days: Optional[int] = None, + batch_generator: Optional[Callable[[List[str]], List[List[str]]]] = None, + ) -> CertificateResponse: + if isinstance(hosts, str): hosts = [hosts] existing: Dict[str, Tuple[int | str, Key, List[Certificate] | str]] = {} @@ -109,25 +161,30 @@ def issue_certificate( print(f"Warning: No challenge solver found that supports domain: {host}. Skipping.") for store, domains_to_issue in domains_by_store.items(): - - private_key, fullchain_cert = self.cert_issuer.generate_key_and_cert_for_domains( - domains_to_issue, - key_type=key_type, - expiry_days=expiry_days, - country=country, - state=state, - locality=locality, - organization=organization, - user_id=user_id, - challenge_solver=store, + issuance_batches = ( + batch_generator(domains_to_issue) if batch_generator is not None else [domains_to_issue] ) - - if fullchain_cert: - key_id = self.key_store.save_key(private_key, domains_to_issue[0]) - self.key_store.save_cert(key_id, fullchain_cert, domains_to_issue) - issued_certs_list.append(IssuedCert(key=private_key, cert=fullchain_cert, domains=domains_to_issue)) - else: - print(f"Failed to issue certificate for domains: {domains_to_issue}") + for batch in issuance_batches: + if not batch: + continue + private_key, fullchain_cert = self.cert_issuer.generate_key_and_cert_for_domains( + batch, + key_type=key_type, + expiry_days=expiry_days, + country=country, + state=state, + locality=locality, + organization=organization, + user_id=user_id, + challenge_solver=store, + ) + + if fullchain_cert: + key_id = self.key_store.save_key(private_key, batch[0]) + self.key_store.save_cert(key_id, fullchain_cert, batch) + issued_certs_list.append(IssuedCert(key=private_key, cert=fullchain_cert, domains=batch)) + else: + print(f"Failed to issue certificate for domains: {batch}") # self.cert_issuer.challenge_solver = original_challenge_solver # Restore original return createExistingResponse(existing, issued_certs_list) diff --git a/src/certapi/server/api.py b/src/certapi/server/api.py index cc2e22e..bc48982 100644 --- a/src/certapi/server/api.py +++ b/src/certapi/server/api.py @@ -97,6 +97,12 @@ def create_api_resources(api_ns, cert_manager: AcmeCertManager, renew_queue_size default=False, help="Allow issuance to proceed if some domains fail verification", ) + obtain_parser.add_argument( + "batch_domains", + type=inputs.boolean, + default=False, + help="Issue certificates in separate safe batches instead of one combined order", + ) @api_ns.route("/obtain") class ObtainCert(Resource): @@ -128,7 +134,12 @@ def get(self): if not hostnames: return CertificateResponse().to_json() - data = cert_manager.issue_certificate( + issue_fn = ( + cert_manager.issue_certificate_in_batches + if args.get("batch_domains") + else cert_manager.issue_certificate + ) + data = issue_fn( hostnames, key_type=args["key_type"], expiry_days=args["expiry_days"], diff --git a/tests/test_acme_cert_manager_batching.py b/tests/test_acme_cert_manager_batching.py new file mode 100644 index 0000000..f1c0e83 --- /dev/null +++ b/tests/test_acme_cert_manager_batching.py @@ -0,0 +1,95 @@ +from certapi.domain_batching import create_safe_domain_batches, split_domain_to_safe_groups, would_trigger +from certapi.manager.acme_cert_manager import AcmeCertManager + + +class DummyKeyStore: + def __init__(self): + self.saved_keys = [] + self.saved_certs = [] + + def find_key_and_cert_by_domain(self, _domain): + return None + + def save_key(self, key, name): + self.saved_keys.append((key, name)) + return f"key-{len(self.saved_keys)}" + + def save_cert(self, key_id, cert, domains): + self.saved_certs.append((key_id, cert, tuple(domains))) + + +class DummySolver: + def supports_domain(self, _domain): + return True + + +class DummyIssuer: + def __init__(self): + self.calls = [] + + def generate_key_and_cert_for_domains(self, domains, **_kwargs): + self.calls.append(list(domains)) + return "dummy-key", "dummy-cert" + + +def test_would_trigger_exact_rule(): + blocked = {"abc", "def", "ghi"} + assert would_trigger(["x", "abc", "def", "ghi", "example", "com"], blocked) is True + assert would_trigger(["x", "abc", "def", "example", "com"], blocked) is False + assert would_trigger(["x", "abc", "abc", "example", "com"], blocked) is True + + +def test_split_domain_to_safe_groups_greedy(): + groups = split_domain_to_safe_groups( + "x.abc.def.ghi.example.com", + blocked_labels=["abc", "def", "ghi"], + ) + assert groups == [["x", "abc", "def"], ["ghi", "example", "com"]] + + +def test_create_safe_domain_batches_compacts_non_triggering_domains(): + domains = ["x.example.com", "y.example.com", "y.example.com"] + assert create_safe_domain_batches(domains) == [["x.example.com", "y.example.com"]] + + +def test_issue_certificate_default_does_not_batch(): + key_store = DummyKeyStore() + issuer = DummyIssuer() + solver = DummySolver() + manager = AcmeCertManager(key_store=key_store, cert_issuer=issuer, challenge_solvers=[solver]) + + domains = [ + "xyz.com", + "abcd.xyz.com", + "def.abcd.xyz.com", + ] + manager.issue_certificate(hosts=domains) + + assert issuer.calls == [domains] + + +def test_issue_certificate_in_batches_uses_custom_batch_generator(): + key_store = DummyKeyStore() + issuer = DummyIssuer() + solver = DummySolver() + manager = AcmeCertManager(key_store=key_store, cert_issuer=issuer, challenge_solvers=[solver]) + + domains = [ + "x.abc.def.ghi.example.com", + "root.example.com", + ] + response = manager.issue_certificate_in_batches( + hosts=domains, + batch_generator=lambda d: create_safe_domain_batches( + d, + blocked_labels=["abc", "def", "ghi"], + ), + ) + + assert issuer.calls == [ + ["root.example.com"], + ["x.abc.def"], + ["ghi.example.com"], + ] + assert len(response.issued) == 3 + assert [issued.domains for issued in response.issued] == issuer.calls diff --git a/tests/test_cli.py b/tests/test_cli.py index e914d3e..ed92ff6 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -54,8 +54,8 @@ def generate_key_and_cert_for_domains(self, domains, key_type="rsa"): class FakeCert: - def __init__(self, not_valid_after): - self.not_valid_after = not_valid_after + def __init__(self, not_valid_after_utc): + self.not_valid_after_utc = not_valid_after_utc def test_is_root_true_false(monkeypatch): @@ -365,9 +365,10 @@ def fake_verify(domains, api_key=None): def test_main_obtain(monkeypatch): called = {} - def fake_obtain(domains, api_key=None): + def fake_obtain(domains, api_key=None, keystore_path="/etc/ssl"): called["domains"] = domains called["api_key"] = api_key + called["keystore_path"] = keystore_path monkeypatch.setattr(cli, "obtain_certificate", fake_obtain) monkeypatch.setattr(cli, "_resolve_cloudflare_api_key", lambda: None) @@ -376,6 +377,25 @@ def fake_obtain(domains, api_key=None): cli.main() assert called["domains"] == ["example.com"] assert called["api_key"] is None + assert called["keystore_path"] == "/etc/ssl" + + +def test_main_obtain_custom_keystore(monkeypatch): + called = {} + + def fake_obtain(domains, api_key=None, keystore_path="/etc/ssl"): + called["domains"] = domains + called["api_key"] = api_key + called["keystore_path"] = keystore_path + + monkeypatch.setattr(cli, "obtain_certificate", fake_obtain) + monkeypatch.setattr(cli, "_resolve_cloudflare_api_key", lambda: None) + monkeypatch.setattr(cli.sys, "argv", ["certapi", "obtain", "--path", ".", "example.com"]) + + cli.main() + assert called["domains"] == ["example.com"] + assert called["api_key"] is None + assert called["keystore_path"] == "." def test_main_help(monkeypatch): diff --git a/tests/test_domain_batching_boulder_cases.py b/tests/test_domain_batching_boulder_cases.py new file mode 100644 index 0000000..33603fd --- /dev/null +++ b/tests/test_domain_batching_boulder_cases.py @@ -0,0 +1,87 @@ +from certapi.domain_batching import would_trigger + + +def _looks_like_recursive_on_demand_request_test_only(idents, blocked): + if not blocked: + return False + + blocked_set = {label.lower() for label in blocked} + for ident in idents or []: + if ident.get("type") != "dns": + continue + labels = [label for label in ident["value"].lower().split(".") if label] + if would_trigger(labels, blocked_set): + return True + return False + + +def test_looks_like_recursive_on_demand_request_boulder_cases(): + test_cases = [ + { + "name": "no blocks", + "idents": [{"type": "dns", "value": "example.com"}], + "blocked": None, + "want_err": False, + }, + { + "name": "no idents", + "idents": None, + "blocked": ["asdf"], + "want_err": False, + }, + { + "name": "no dns idents", + "idents": [{"type": "ip", "value": "1.2.3.4"}], + "blocked": ["asdf"], + "want_err": False, + }, + { + "name": "short idents", + "idents": [ + {"type": "dns", "value": "foo.example.com"}, + {"type": "dns", "value": "bar.example.com"}, + ], + "blocked": ["asdf"], + "want_err": False, + }, + { + "name": "long but not blocked ident", + "idents": [ + {"type": "dns", "value": "foo.example.com"}, + {"type": "dns", "value": "asdf.qwer.zxcv.asdf.example.com"}, + ], + "blocked": ["asdf"], + "want_err": False, + }, + { + "name": "two identical blocked labels early", + "idents": [ + {"type": "dns", "value": "foo.example.com"}, + {"type": "dns", "value": "asdf.asdf.qwer.zxcv.example.com"}, + ], + "blocked": ["asdf"], + "want_err": True, + }, + { + "name": "two identical blocked labels late", + "idents": [ + {"type": "dns", "value": "foo.example.com"}, + {"type": "dns", "value": "qwer.zxcv.asdf.asdf.example.com"}, + ], + "blocked": ["asdf"], + "want_err": True, + }, + { + "name": "three blocked labels", + "idents": [ + {"type": "dns", "value": "foo.example.com"}, + {"type": "dns", "value": "asdf.qwer.zxcv.asdf.example.com"}, + ], + "blocked": ["asdf", "qwer", "zxcv"], + "want_err": True, + }, + ] + + for tc in test_cases: + got = _looks_like_recursive_on_demand_request_test_only(tc["idents"], tc["blocked"]) + assert got is tc["want_err"], f"{tc['name']}: got={got} want={tc['want_err']}"