diff --git a/docs/on_demand_checkpointing.md b/docs/on_demand_checkpointing.md
index b21d71d5..b4576164 100644
--- a/docs/on_demand_checkpointing.md
+++ b/docs/on_demand_checkpointing.md
@@ -135,9 +135,15 @@ The trigger file is always at a fixed path. To trigger a checkpoint (e.g. via
 `kubectl exec` into the training pod):
 
 ```bash
-touch /dev/shm/instructlab_checkpoint_requested
+touch /dev/shm/checkpoint_requested
 ```
 
+The default filename is `checkpoint_requested`. To use a custom filename,
+set the `CHECKPOINT_TRIGGER_FILENAME` environment variable before starting
+training. The value must be a bare filename (no `/`, and not `.` or `..`);
+anything else raises a `ValueError`. An unset or empty variable selects the
+default.
+
 Workers check for the trigger file at each synchronization point in the
 training loop (multiple times per step). Once any rank on any node detects
 it, all ranks coordinate via `all_reduce` to save a checkpoint and exit.
diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py
index 2bf0361a..3598c267 100644
--- a/src/instructlab/training/on_demand_checkpoint.py
+++ b/src/instructlab/training/on_demand_checkpoint.py
@@ -89,12 +89,31 @@
 # 2. Shared between all containers in the same Kubernetes pod.
 # 3. Automatically cleaned up when the pod is destroyed.
 _TRIGGER_DIR = Path("/dev/shm")
-_TRIGGER_FILENAME = "instructlab_checkpoint_requested"
+_DEFAULT_TRIGGER_FILENAME = "checkpoint_requested"
+_TRIGGER_FILENAME_ENV_VAR = "CHECKPOINT_TRIGGER_FILENAME"
 
 
 def _get_trigger_path() -> Path:
-    """Return the path to the checkpoint trigger file."""
-    return _TRIGGER_DIR / _TRIGGER_FILENAME
+    """Return the path to the checkpoint trigger file.
+
+    The file always sits directly inside ``_TRIGGER_DIR``.  Its name defaults
+    to ``checkpoint_requested`` and can be overridden with the
+    ``CHECKPOINT_TRIGGER_FILENAME`` environment variable; an unset or empty
+    variable selects the default.
+
+    Raises:
+        ValueError: If the override is not a bare filename.
+    """
+    # An empty override would make the join below return _TRIGGER_DIR itself,
+    # so it is treated the same as an unset variable.
+    filename = os.environ.get(_TRIGGER_FILENAME_ENV_VAR) or _DEFAULT_TRIGGER_FILENAME
+    # A name containing "/" (or "." / "..") would not name a file directly
+    # inside _TRIGGER_DIR.
+    if "/" in filename or filename in (".", ".."):
+        raise ValueError(
+            f"{_TRIGGER_FILENAME_ENV_VAR} must be a bare filename, got {filename!r}"
+        )
+    return _TRIGGER_DIR / filename
 
 
 def write_trigger_file() -> Path:
diff --git a/tests/unit/test_on_demand_checkpoint.py b/tests/unit/test_on_demand_checkpoint.py
index f9db657e..584d547f 100644
--- a/tests/unit/test_on_demand_checkpoint.py
+++ b/tests/unit/test_on_demand_checkpoint.py
@@ -29,7 +29,29 @@ class TestGetTriggerPath:
 
-    def test_returns_correct_name(self):
+    def test_returns_correct_name(self, monkeypatch):
+        monkeypatch.delenv("CHECKPOINT_TRIGGER_FILENAME", raising=False)
         path = _get_trigger_path()
-        assert path.name == "instructlab_checkpoint_requested"
+        assert path.name == "checkpoint_requested"
         assert str(path.parent) == "/dev/shm"
+
+    def test_respects_env_override(self, monkeypatch):
+        monkeypatch.setenv("CHECKPOINT_TRIGGER_FILENAME", "my_custom_trigger")
+        path = _get_trigger_path()
+        assert path.name == "my_custom_trigger"
+        assert str(path.parent) == "/dev/shm"
+
+    def test_empty_env_override_uses_default(self, monkeypatch):
+        monkeypatch.setenv("CHECKPOINT_TRIGGER_FILENAME", "")
+        path = _get_trigger_path()
+        assert path.name == "checkpoint_requested"
+        assert str(path.parent) == "/dev/shm"
+
+    def test_rejects_non_bare_filename(self, monkeypatch):
+        for bad_name in ("../escape", "/etc/passwd", "sub/dir", ".", ".."):
+            monkeypatch.setenv("CHECKPOINT_TRIGGER_FILENAME", bad_name)
+            try:
+                _get_trigger_path()
+            except ValueError:
+                continue
+            raise AssertionError(f"{bad_name!r} was accepted as a filename")
 
 
@@ -43,7 +65,7 @@ def test_creates_file(self, tmp_path):
     def test_returns_correct_path(self, tmp_path):
         with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path):
             path = write_trigger_file()
-            assert path == tmp_path / "instructlab_checkpoint_requested"
+            assert path == tmp_path / "checkpoint_requested"
 
 
 class TestTriggerFileExists: