diff --git a/plugins/nemo-safe-synthesizer/src/nemo_safe_synthesizer_plugin/cli.py b/plugins/nemo-safe-synthesizer/src/nemo_safe_synthesizer_plugin/cli.py index 6211d16ed1..dff9348079 100644 --- a/plugins/nemo-safe-synthesizer/src/nemo_safe_synthesizer_plugin/cli.py +++ b/plugins/nemo-safe-synthesizer/src/nemo_safe_synthesizer_plugin/cli.py @@ -5,6 +5,7 @@ from __future__ import annotations +import os import subprocess from pathlib import Path from typing import ClassVar @@ -14,6 +15,9 @@ from nemo_safe_synthesizer_plugin.config import config from nemo_safe_synthesizer_plugin.runtime import runtime_info, runtime_task_command, setup_runtime +NEMO_DEPLOYMENT_TYPE_ENVVAR = "NEMO_DEPLOYMENT_TYPE" +NMP_DEPLOYMENT_TYPE = "nmp" + class SafeSynthesizerCLI(NemoCLI): """CLI extensions for host-local Safe Synthesizer development.""" @@ -100,7 +104,9 @@ def run_local_command( typer.echo(str(e), err=True) raise typer.Exit(1) from e - result = subprocess.run(command, check=False) + runtime_env = os.environ.copy() + runtime_env[NEMO_DEPLOYMENT_TYPE_ENVVAR] = NMP_DEPLOYMENT_TYPE + result = subprocess.run(command, check=False, env=runtime_env) if result.returncode != 0: raise typer.Exit(result.returncode) typer.echo(f"Wrote Safe Synthesizer results to {output_dir}") diff --git a/plugins/nemo-safe-synthesizer/tests/unit/test_cli.py b/plugins/nemo-safe-synthesizer/tests/unit/test_cli.py new file mode 100644 index 0000000000..1c65b626db --- /dev/null +++ b/plugins/nemo-safe-synthesizer/tests/unit/test_cli.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from subprocess import CompletedProcess + +from nemo_safe_synthesizer_plugin import cli +from nemo_safe_synthesizer_plugin.cli import NEMO_DEPLOYMENT_TYPE_ENVVAR, NMP_DEPLOYMENT_TYPE, SafeSynthesizerCLI +from typer.testing import CliRunner + + +def test_run_local_sets_nmp_deployment_type_for_runtime_subprocess(tmp_path, monkeypatch): + spec_file = tmp_path / "nss-job.json" + spec_file.write_text("{}", encoding="utf-8") + data_file = tmp_path / "input.csv" + data_file.write_text("name\nAda\n", encoding="utf-8") + output_dir = tmp_path / "nss-output" + captured = {} + + def fake_runtime_task_command(_config, args): + return ["runtime-python", *args] + + def fake_run(command, *, check=False, env=None): + captured["command"] = command + captured["check"] = check + captured["env"] = env + return CompletedProcess(command, 0) + + monkeypatch.setattr(cli, "runtime_task_command", fake_runtime_task_command) + monkeypatch.setattr(cli.subprocess, "run", fake_run) + + result = CliRunner().invoke( + SafeSynthesizerCLI().get_cli(), + [ + "run-local", + "--workspace", + "default", + "--spec-file", + str(spec_file), + "--data-source", + str(data_file), + "--output-dir", + str(output_dir), + ], + ) + + assert result.exit_code == 0, result.output + assert captured["check"] is False + assert captured["env"][NEMO_DEPLOYMENT_TYPE_ENVVAR] == NMP_DEPLOYMENT_TYPE