Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/on_demand_checkpointing.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,13 @@ The trigger file is always at a fixed path. To trigger a checkpoint
(e.g. via `kubectl exec` into the training pod):

```bash
touch /dev/shm/instructlab_checkpoint_requested
touch /dev/shm/checkpoint_requested
```

The default filename is `checkpoint_requested`. To use a custom filename,
set the `CHECKPOINT_TRIGGER_FILENAME` environment variable before starting
training.

Workers check for the trigger file at each synchronization point in the
training loop (multiple times per step). Once any rank on any node detects
it, all ranks coordinate via `all_reduce` to save a checkpoint and exit.
Expand Down
4 changes: 2 additions & 2 deletions src/instructlab/training/on_demand_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@
# 2. Shared between all containers in the same Kubernetes pod.
# 3. Automatically cleaned up when the pod is destroyed.
_TRIGGER_DIR = Path("/dev/shm")
_TRIGGER_FILENAME = "instructlab_checkpoint_requested"


def _get_trigger_path() -> Path:
"""Return the path to the checkpoint trigger file."""
return _TRIGGER_DIR / _TRIGGER_FILENAME
filename = os.environ.get("CHECKPOINT_TRIGGER_FILENAME", "checkpoint_requested")
return _TRIGGER_DIR / filename


def write_trigger_file() -> Path:
Expand Down
10 changes: 8 additions & 2 deletions tests/unit/test_on_demand_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@
class TestGetTriggerPath:
def test_returns_correct_name(self):
path = _get_trigger_path()
assert path.name == "instructlab_checkpoint_requested"
assert path.name == "checkpoint_requested"
assert str(path.parent) == "/dev/shm"

def test_respects_env_override(self, monkeypatch):
monkeypatch.setenv("CHECKPOINT_TRIGGER_FILENAME", "my_custom_trigger")
path = _get_trigger_path()
assert path.name == "my_custom_trigger"
assert str(path.parent) == "/dev/shm"


Expand All @@ -43,7 +49,7 @@ def test_creates_file(self, tmp_path):
def test_returns_correct_path(self, tmp_path):
with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path):
path = write_trigger_file()
assert path == tmp_path / "instructlab_checkpoint_requested"
assert path == tmp_path / "checkpoint_requested"


class TestTriggerFileExists:
Expand Down
Loading