fix: cast initial_basis to long in 07_iterative_nqs_dci scripts (fixes #40) #41

Open
thc1006 wants to merge 1 commit into QuantumNoLab:main from thc1006:fix/iter-nqs-dci-initial-basis-dtype

thc1006 commented Apr 7, 2026

Summary

Minimal surgical fix for #40: add .long() cast to the basis tensor before passing it as initial_basis in the two affected 07_iterative_nqs_dci scripts.

Closes #40.

Affected files (2)

  • experiments/pipelines/07_iterative_nqs_dci/iter_nqs_dci_sqd.py
  • experiments/pipelines/07_iterative_nqs_dci/iter_nqs_dci_krylov_classical.py

Diff summary

Each file gets one extra .long() call + a 4-line explanatory comment:

-    basis = pipeline.extract_and_select_basis()
+    # `extract_and_select_basis()` returns float32 (inherited from the
+    # NF trainer's accumulated_basis). `run_hi_nqs_sqd(initial_basis=...)`
+    # requires integer/bool dtype for the binary occupation vectors, so
+    # cast before passing.
+    basis = pipeline.extract_and_select_basis().long()

Total: 10 insertions, 2 deletions across 2 files. No library changes, no API changes.

Why this scope

See issue #40 for the full root-cause analysis. Short version: extract_and_select_basis() returns float32 (cloned from the NF trainer's accumulated_basis), but the downstream run_hi_nqs_sqd/run_hi_nqs_skqd strictly validate integer/bool dtype. The sister scripts in group 08 work because expand_basis_via_connections() happens to cast internally; the 07 scripts don't have that intermediate step.
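
The mismatch described above can be reproduced in isolation. The following is a minimal sketch, not the library's code: `check_basis_dtype` is a hypothetical stand-in for the dtype check inside run_hi_nqs_sqd/run_hi_nqs_skqd, and the small tensor is an illustrative substitute for the output of extract_and_select_basis(). Only the error wording is taken from the traceback quoted in this PR.

```python
import torch


def check_basis_dtype(basis: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the strict integer/bool dtype check."""
    if basis.is_floating_point() or basis.is_complex():
        raise ValueError(
            "initial_basis must be integer or bool dtype "
            f"(binary occupations), got {basis.dtype}"
        )
    return basis


# Illustrative occupation vectors; float literals give torch.float32 by default.
float_basis = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                            [0.0, 1.0, 0.0, 1.0]])

# Without the cast, the check rejects the tensor.
try:
    check_basis_dtype(float_basis)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

# With the cast applied at the call site, the same data passes.
long_basis = check_basis_dtype(float_basis.long())
```

Because the float values here are exactly 0.0 and 1.0, the cast to int64 preserves every entry, which is what makes the call-site `.long()` safe for binary occupation vectors.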

Two cleaner long-term architectural fixes (cast at the library source, or relax the validator) are noted in the issue as out-of-scope. This PR is the minimum unblocking patch.

Test plan

Manual end-to-end verification on H2/CPU:

  • Before fix: both scripts raise ValueError: initial_basis must be integer or bool dtype ... got torch.float32
  • After fix:
    • iter_nqs_dci_sqd.py h2 --device cpu → Error 0.0 mHa, 11.88 s
    • iter_nqs_dci_krylov_classical.py h2 --device cpu → Error 0.0 mHa, 9.32 s
  • ruff check experiments/pipelines/07_iterative_nqs_dci/ — clean
  • ruff format --check — clean
  • (Reviewer) Consider whether to also open a follow-up issue/PR for the root-cause library fix (cast in extract_and_select_basis or lenient validator)

Relationship to PR #39

PR #39 (refactor/pipeline-catalog) renames 07_iterative_nqs_dci/ → 007_iterative_nqs_dci/ via git mv. This PR modifies file content (not paths) in 07_iterative_nqs_dci/. The two PRs touch disjoint aspects of the same files, so git's rename detection should handle the rebase cleanly regardless of merge order.

Recommended merge order:

  1. Merge this PR first (standalone, off main, no deps)
  2. PR #39 (refactor: pipeline catalog with 3-digit prefix + 010-013 method-as-pipeline entries) rebases on main (picks up this fix at the old path, then renames the folder)
  3. OR merge PR #39 first, then this PR rebases (git detects the rename and applies the fix at the new path)

Either order works.

`FlowGuidedKrylovPipeline.extract_and_select_basis()` returns a float32
tensor (cloned from the NF trainer's `accumulated_basis`, which lives in
float for gradient tracking). But `run_hi_nqs_sqd(initial_basis=...)` and
`run_hi_nqs_skqd(initial_basis=...)` strictly validate that
`initial_basis` is integer/bool dtype (binary occupation vectors) and
raise ValueError on float.

This caused both `iter_nqs_dci_sqd.py` and `iter_nqs_dci_krylov_classical.py`
to fail end-to-end with:

    ValueError: initial_basis must be integer or bool dtype
    (binary occupations), got torch.float32

at the `run_hi_nqs_sqd`/`run_hi_nqs_skqd` call site.

Interestingly, the parallel scripts in group 08 (iter_nqs_dci_pt2_*) work
because their basis is passed through `expand_basis_via_connections()`
first, which happens to cast internally. The 07 scripts do not have that
intermediate step, so they hit the bug raw.

Fix: add `.long()` cast to the `basis` result in the two affected scripts,
with a short comment explaining why. Both scripts now reach chemical
accuracy on H2:

    07_iterative_nqs_dci/iter_nqs_dci_sqd.py           h2 cpu 11.9s  0.0 mHa
    07_iterative_nqs_dci/iter_nqs_dci_krylov_classical h2 cpu  9.3s  0.0 mHa

Scope:
- Surgical script-level fix (2 files, one `.long()` call plus a short
  comment in each). Does not change library-level API contracts.
- The 3rd script in the group (`iter_nqs_dci_krylov_quantum.py`) does not
  use `initial_basis` at all — unaffected.
- Related architectural issue (float32 basis at API boundary with int-only
  validator) could eventually be fixed by either (a) making
  `extract_and_select_basis` return long, or (b) making
  `validate_initial_basis` lenient for float {0,1} values. Out of scope
  for this minimal fix.
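
Option (b) could look roughly like the sketch below. It is an assumption about how a lenient check might be written, not the library's `validate_initial_basis`; the name `coerce_binary_basis` is hypothetical. The point it illustrates is that a bare `.long()` truncates toward zero, so a lenient validator should confirm that float entries are exactly 0 or 1 before casting.

```python
import torch


def coerce_binary_basis(basis: torch.Tensor) -> torch.Tensor:
    """Hypothetical lenient check: accept float {0, 1} tensors, return int64.

    Integer and bool inputs are passed straight to the cast; only float
    inputs get the extra value check.
    """
    if basis.is_complex():
        raise ValueError(f"initial_basis must be real, got {basis.dtype}")
    if basis.is_floating_point():
        is_binary = ((basis == 0) | (basis == 1)).all().item()
        if not is_binary:
            raise ValueError(
                "initial_basis contains float values other than 0 and 1"
            )
    return basis.long()


# Float tensor holding only 0.0 and 1.0: accepted and cast to int64.
clean = coerce_binary_basis(torch.tensor([[1.0, 0.0], [0.0, 1.0]]))

# A bare cast silently truncates toward zero, which is why the check matters.
truncated = torch.tensor([0.9, 1.0]).long()

# The lenient check refuses the same non-binary input instead of truncating.
try:
    coerce_binary_basis(torch.tensor([0.9, 1.0]))
    rejected = False
except ValueError:
    rejected = True
```

Option (a), returning long from `extract_and_select_basis` itself, would make this check unnecessary at the boundary, at the cost of changing the dtype that existing callers receive.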

Discovered via end-to-end smoke test on `refactor/pipeline-catalog` branch
(PR QuantumNoLab#39). Verified that the failure is pre-existing on main before writing
this fix.

Copilot AI review requested due to automatic review settings April 7, 2026 17:32

Copilot AI left a comment

Pull request overview

Fixes an end-to-end runtime failure in the 07_iterative_nqs_dci experiment scripts by ensuring the NF/DCI-selected basis is cast to an integer dtype before being passed as initial_basis to HI+NQS runners that strictly validate integer/bool occupations.

Changes:

  • Cast pipeline.extract_and_select_basis() output to torch.long in iter_nqs_dci_sqd.py before calling run_hi_nqs_sqd(initial_basis=...).
  • Cast pipeline.extract_and_select_basis() output to torch.long in iter_nqs_dci_krylov_classical.py before calling run_hi_nqs_skqd(initial_basis=...).
  • Add short inline comments documenting the dtype mismatch and why the cast is required.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| experiments/pipelines/07_iterative_nqs_dci/iter_nqs_dci_sqd.py | Cast extracted basis to long so run_hi_nqs_sqd(initial_basis=...) passes dtype validation. |
| experiments/pipelines/07_iterative_nqs_dci/iter_nqs_dci_krylov_classical.py | Cast extracted basis to long so run_hi_nqs_skqd(initial_basis=...) passes dtype validation. |

Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated no new comments.

Development

Successfully merging this pull request may close these issues.

bug: 07_iterative_nqs_dci/iter_nqs_dci_{sqd,krylov_classical}.py fail with initial_basis dtype ValueError