Skip to content

fix: correct MPS execution on apple silicon#145

Open
yanghan234 wants to merge 7 commits intomainfrom
hanyang/fix-mps-device-handling
Open

fix: correct MPS execution on apple silicon#145
yanghan234 wants to merge 7 commits intomainfrom
hanyang/fix-mps-device-handling

Conversation

@yanghan234
Copy link
Copy Markdown
Collaborator

Summary

  • register SphericalBasisLayer.coef as a buffer so it moves with the model on MPS
  • precompute graph-derived indexing values in batch_to_dict() and move the input dict to the target device
    explicitly
  • remove MPS device-to-host synchronization hotspots in the M3GNet forward path and stress path

@yanghan234 yanghan234 force-pushed the hanyang/fix-mps-device-handling branch from 450d7cf to 191cf36 Compare April 7, 2026 12:20
yanghan234 and others added 6 commits April 7, 2026 13:27
- Use batch_to_dict -> move_to_device pattern in deprecated get_properties
  method, consistent with predict/fit paths
- Create index_map on the same device as num_triple_ij to avoid
  CPU/CUDA tensor mismatch in three_body_edge_map computation

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Merge the device transfer logic into batch_to_dict via its existing
device parameter (now defaulting to None). This ensures every call
site gets device placement automatically and removes the risk of
forgetting a separate move_to_device call.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add conftest.py with a device fixture that auto-detects available
torch devices. Tests using the fixture run on all available backends.
A --device flag allows restricting to a single device.

Converted test_batch_relax.py from unittest to pytest style to use
the device fixture. Verified passing on both cpu and mps.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- M3Gnet.forward now uses .get() with fallback computation for
  precomputed keys (total_num_atoms, bond_index_bias, etc.), so
  callers constructing input dicts directly won't KeyError.
- batch_to_dict creates index_map on CPU (moved to device at the end),
  avoiding intermediate device mismatches.
- Remove unused pytest import in test_batch_relax.py.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Prevents device mismatch if graph_batch tensors are already on
a non-CPU device when batch_to_dict is called.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant