Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions cli_preflight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Magic-prompt pre-flight check, isolated so it can be tested without torch
or the ideogram4 model code.

Lives at the repo root (next to run_inference.py) so the helper is
importable by both the CLI and the tests.
"""

from __future__ import annotations

import argparse
import os


def _resolve_magic_prompt_key(explicit: str | None) -> str | None:
"""Resolve the magic-prompt API key.

Precedence:
1. ``explicit`` (the ``--magic-prompt-key`` value).
2. ``$MAGIC_PROMPT_API_KEY``.
3. ``$IDEOGRAM_API_KEY``.
Returns the first non-empty match, or None if all are missing.
"""
if explicit:
return explicit
env_key = os.environ.get("MAGIC_PROMPT_API_KEY")
if env_key:
return env_key
return os.environ.get("IDEOGRAM_API_KEY") or None


def check_magic_prompt_key(args: argparse.Namespace) -> bool:
"""Return True if the pre-flight check passes.

The pre-flight check requires that either ``--magic-prompt`` is disabled
or that an API key is resolvable (via ``--magic-prompt-key`` or the
``MAGIC_PROMPT_API_KEY`` / ``IDEOGRAM_API_KEY`` environment variables).

When the check fails the caller is expected to call ``parser.error(...)``
with a user-facing message, so this function is side-effect-free.
"""
if not getattr(args, "magic_prompt", False):
return True
return _resolve_magic_prompt_key(getattr(args, "magic_prompt_key", None)) is not None
29 changes: 20 additions & 9 deletions run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
moderate_prompt,
)

from cli_preflight import check_magic_prompt_key

QUANTIZATION_REPOS = {
"nf4": "ideogram-ai/ideogram-4-nf4",
"fp8": "ideogram-ai/ideogram-4-fp8",
Expand All @@ -45,7 +47,7 @@ def _print_flags(label: str, flags: list[tuple[str, float]]) -> None:
print(f" {name}: {score:.3f}", file=sys.stderr)


def main() -> None:
def build_argparser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", required=True)
parser.add_argument("--output", default="output.png")
Expand Down Expand Up @@ -140,8 +142,23 @@ def main() -> None:
"create a Visual Content Moderation project."
),
)
return parser


def main() -> None:
parser = build_argparser()
args = parser.parse_args()

# Pre-flight: if magic-prompt is on but no key is set (via flag or env),
# fail fast before any network or model-load work. Saves a real Hive call
# and any partial model load in the common "I forgot the key" case.
if not check_magic_prompt_key(args):
parser.error(
"--magic-prompt is on but no API key was set. Provide one via "
"--magic-prompt-key, $MAGIC_PROMPT_API_KEY, or $IDEOGRAM_API_KEY; or "
"pass --no-magic-prompt to disable expansion."
)
Comment on lines +155 to +160

if args.hive_text_key:
flags = moderate_prompt(args.prompt, args.hive_text_key)
if flags:
Expand All @@ -158,14 +175,8 @@ def main() -> None:

prompt = args.prompt
if args.magic_prompt:
if not args.magic_prompt_key:
print(
"ERROR: magic prompt is enabled but no API key was found. Set "
"MAGIC_PROMPT_API_KEY, pass --magic-prompt-key, or disable expansion "
"with --no-magic-prompt.",
file=sys.stderr,
)
sys.exit(2)
# The pre-flight check above guarantees args.magic_prompt_key is set
# when args.magic_prompt is true.
aspect_ratio = aspect_ratio_from_size(args.width, args.height)
magic = MAGIC_PROMPTS[args.magic_prompt_model](api_key=args.magic_prompt_key) # type: ignore[call-arg]
Comment on lines +178 to 181
print(
Expand Down
42 changes: 42 additions & 0 deletions tests/test_cli_preflight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Tests for the magic-prompt pre-flight check (no torch, no ideogram4)."""

from __future__ import annotations

import argparse

import pytest

from cli_preflight import check_magic_prompt_key, _resolve_magic_prompt_key


def test_check_returns_true_when_magic_prompt_disabled() -> None:
args = argparse.Namespace(magic_prompt=False, magic_prompt_key=None)
assert check_magic_prompt_key(args) is True


def test_check_returns_true_when_flag_key_set() -> None:
args = argparse.Namespace(magic_prompt=True, magic_prompt_key="sk-abc")
assert check_magic_prompt_key(args) is True


def test_check_returns_false_when_no_key_anywhere() -> None:
args = argparse.Namespace(magic_prompt=True, magic_prompt_key=None)
with pytest.MonkeyPatch.context() as mp:
mp.delenv("MAGIC_PROMPT_API_KEY", raising=False)
mp.delenv("IDEOGRAM_API_KEY", raising=False)
assert check_magic_prompt_key(args) is False
Comment on lines +22 to +27


def test_resolve_prefers_explicit_then_magic_then_ideogram(monkeypatch) -> None:
monkeypatch.delenv("MAGIC_PROMPT_API_KEY", raising=False)
monkeypatch.delenv("IDEOGRAM_API_KEY", raising=False)
assert _resolve_magic_prompt_key("sk-explicit") == "sk-explicit"
monkeypatch.setenv("MAGIC_PROMPT_API_KEY", "sk-magic")
assert _resolve_magic_prompt_key(None) == "sk-magic"
monkeypatch.setenv("IDEOGRAM_API_KEY", "sk-ideo")
monkeypatch.delenv("MAGIC_PROMPT_API_KEY")
assert _resolve_magic_prompt_key(None) == "sk-ideo" # empty explicit falls through to env
# But an empty env var is treated as missing
monkeypatch.setenv("MAGIC_PROMPT_API_KEY", "")
monkeypatch.delenv("IDEOGRAM_API_KEY")
assert _resolve_magic_prompt_key(None) is None