Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions ci/tools/tests/test_pytest_run_parallel_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import sys
import threading
import types
from contextlib import contextmanager
from pathlib import Path

import pytest

# Locate the ``cuda_python_test_helpers`` directory relative to this file and
# put it at the front of sys.path so that the ``cuda_python_test_helpers``
# import that follows is resolved against it. This must run before that import.
_TEST_HELPERS_ROOT = Path(__file__).resolve().parents[3] / "cuda_python_test_helpers"
sys.path.insert(0, str(_TEST_HELPERS_ROOT))

from cuda_python_test_helpers.pytest_run_parallel import (
install_run_parallel_worker_context_patch,
mark_item_for_worker_context,
)


def _install_fake_pytest_run_parallel(monkeypatch):
    """Register a stand-in ``pytest_run_parallel`` package in ``sys.modules``.

    The stand-in is an empty package plus a ``plugin`` submodule exposing a
    ``wrap_function_parallel`` that raises ``AssertionError`` if it is ever
    invoked. Both entries are installed through ``monkeypatch.setitem``.

    Args:
        monkeypatch: Object providing ``setitem(mapping, key, value)``.

    Returns:
        The fake ``pytest_run_parallel.plugin`` module.
    """

    def wrap_function_parallel(fn, n_workers, n_iterations):
        raise AssertionError("unpatched fake wrapper should not be called")

    fake_package = types.ModuleType("pytest_run_parallel")
    # An (empty) __path__ makes the module object look like a package.
    fake_package.__path__ = []

    fake_plugin = types.ModuleType("pytest_run_parallel.plugin")
    fake_plugin.wrap_function_parallel = wrap_function_parallel

    # Register the package first, then its submodule.
    for module_name, module in (
        ("pytest_run_parallel", fake_package),
        ("pytest_run_parallel.plugin", fake_plugin),
    ):
        monkeypatch.setitem(sys.modules, module_name, module)
    return fake_plugin


@pytest.mark.agent_authored(model="gpt-5")
def test_install_run_parallel_worker_context_patch_is_idempotent(monkeypatch):
    """Installing the patch twice reports success both times and keeps the first wrapper."""
    fake_plugin = _install_fake_pytest_run_parallel(monkeypatch)

    first_result = install_run_parallel_worker_context_patch()
    assert first_result is True
    wrapper_after_first_install = fake_plugin.wrap_function_parallel
    # The installed wrapper carries a truthy marker attribute.
    assert wrapper_after_first_install._cuda_python_patched_run_parallel_worker_context

    second_result = install_run_parallel_worker_context_patch()
    assert second_result is True
    # A second install must not replace (or re-wrap) the already patched wrapper.
    assert fake_plugin.wrap_function_parallel is wrapper_after_first_install


@pytest.mark.agent_authored(model="gpt-5")
def test_patched_wrapper_runs_context_with_isolated_kwargs(monkeypatch):
    """The patched wrapper enters the worker context once per (thread, iteration).

    Checks that:
      * the test body runs for every (thread_index, iteration_index) pair,
      * values written into ``kwargs`` by the worker context reach the body,
      * keyword arguments passed to the wrapped callable are forwarded unchanged,
      * every run receives its own ``kwargs`` dict (six distinct dicts), and
      * each context entry is paired with an exit (twelve events in total).
    """
    plugin = _install_fake_pytest_run_parallel(monkeypatch)
    install_run_parallel_worker_context_patch()

    n_workers = 3
    n_iterations = 2
    n_runs = n_workers * n_iterations

    lock = threading.Lock()
    context_events = []
    calls = []
    # Strong references to every kwargs dict seen by the worker context.
    # id() is only unique among objects that are alive at the same time; if a
    # per-run dict were freed, a later run's dict could reuse its address and
    # the "distinct ids" assertion below would fail spuriously.
    live_kwargs = []

    @contextmanager
    def worker_context(*, thread_index, iteration_index, kwargs):
        token = object()
        kwargs["token"] = (thread_index, iteration_index, id(token))
        kwargs["context_kwargs_id"] = id(kwargs)
        with lock:
            live_kwargs.append(kwargs)
            context_events.append(("enter", thread_index, iteration_index, id(kwargs)))
        try:
            yield
        finally:
            with lock:
                context_events.append(("exit", thread_index, iteration_index, id(kwargs)))

    def test_body(*, thread_index, iteration_index, token, context_kwargs_id, static_value):
        with lock:
            calls.append(
                {
                    "thread_index": thread_index,
                    "iteration_index": iteration_index,
                    "token": token,
                    "context_kwargs_id": context_kwargs_id,
                    "static_value": static_value,
                }
            )

    item = types.SimpleNamespace(obj=test_body)
    assert mark_item_for_worker_context(item, worker_context) is True

    wrapped = plugin.wrap_function_parallel(item.obj, n_workers=n_workers, n_iterations=n_iterations)
    # The -1 placeholders must be replaced by the real per-run indices.
    wrapped(thread_index=-1, iteration_index=-1, static_value="fixture-value")

    expected_pairs = {
        (thread_index, iteration_index)
        for thread_index in range(n_workers)
        for iteration_index in range(n_iterations)
    }
    actual_pairs = {(call["thread_index"], call["iteration_index"]) for call in calls}
    assert actual_pairs == expected_pairs
    assert {call["token"][:2] for call in calls} == expected_pairs
    assert {call["static_value"] for call in calls} == {"fixture-value"}

    kwargs_ids = {call["context_kwargs_id"] for call in calls}
    assert len(kwargs_ids) == n_runs
    assert len(context_events) == 2 * n_runs
    assert {event[3] for event in context_events} == kwargs_ids


@pytest.mark.agent_authored(model="gpt-5")
def test_mark_item_for_worker_context_wraps_callables_without_attrs():
    """Marking an item whose ``obj`` cannot take new attributes replaces ``obj`` with a wrapper.

    Calling the replacement directly still forwards to the original callable,
    and the original receives an empty kwargs dict.
    """

    @contextmanager
    def worker_context(*, thread_index, iteration_index, kwargs):
        kwargs["patched"] = True
        yield

    class CallableWithoutAttrs:
        # Only ``calls`` is declared, so instances have no __dict__ to hold
        # any other attribute.
        __slots__ = ("calls",)

        def __init__(self):
            self.calls = []

        def __call__(self, **kwargs):
            self.calls.append(kwargs)

    target = CallableWithoutAttrs()
    item = types.SimpleNamespace(obj=target)

    was_marked = mark_item_for_worker_context(item, worker_context)
    assert was_marked is True
    assert item.obj is not target

    item.obj()
    assert target.calls == [{}]
54 changes: 54 additions & 0 deletions cuda_bindings/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pathlib
import sys
from contextlib import contextmanager
from importlib.metadata import PackageNotFoundError, distribution

import pytest
Expand All @@ -25,6 +26,59 @@
sys.path.insert(0, test_helpers_root)


from cuda_python_test_helpers.pytest_run_parallel import (
install_run_parallel_worker_context_patch,
mark_item_for_worker_context,
)


def pytest_configure(config):
    """Pytest configure hook: install the run-parallel worker-context patch.

    ``config`` is accepted to match the hook signature and is not used.
    """
    install_run_parallel_worker_context_patch()


@contextmanager
def _thread_context():
    """Create a CUDA context on device 0 for the ``with`` body and destroy it on exit.

    Yields:
        A ``(device, ctx)`` pair: the handle returned by ``cuDeviceGet(0)`` and
        the context returned by ``cuCtxCreate``.

    Every driver call is asserted to return ``CUDA_SUCCESS``.
    """
    (status,) = cuda.cuInit(0)
    assert status == cuda.CUresult.CUDA_SUCCESS

    status, device = cuda.cuDeviceGet(0)
    assert status == cuda.CUresult.CUDA_SUCCESS

    status, ctx = cuda.cuCtxCreate(None, 0, device)
    assert status == cuda.CUresult.CUDA_SUCCESS

    try:
        yield device, ctx
    finally:
        # Always destroy the context, even if the body raised.
        (status,) = cuda.cuCtxDestroy(ctx)
        assert status == cuda.CUresult.CUDA_SUCCESS


@contextmanager
def _cuda_bindings_worker_context(*, thread_index, iteration_index, kwargs):
    """Run the body inside a fresh ``_thread_context`` and refresh CUDA kwargs.

    Args:
        thread_index: Accepted for the worker-context signature; unused here.
        iteration_index: Accepted for the worker-context signature; unused here.
        kwargs: Mutable mapping of call arguments. The ``"device"`` and
            ``"ctx"`` entries are overwritten in place, but only when the key
            is already present; no new keys are added.
    """
    with _thread_context() as (device, ctx):
        replacements = (("device", device), ("ctx", ctx))
        for name, value in replacements:
            if name in kwargs:
                kwargs[name] = value
        yield


def _is_cudla_item(item):
    """Return True when ``item.nodeid`` points into a ``cudla`` test directory.

    Backslashes in the node id are treated as forward slashes. The id matches
    if it starts with ``tests/cudla/`` or contains
    ``cuda_bindings/tests/cudla/`` anywhere.
    """
    normalized = item.nodeid.replace("\\", "/")
    if normalized.startswith("tests/cudla/"):
        return True
    return "cuda_bindings/tests/cudla/" in normalized


def _item_needs_thread_ctx(item):
    """Return True when the item should run inside a per-thread CUDA context.

    cudla items never qualify. Any other item qualifies when its
    ``fixturenames`` (treated as empty if the attribute is missing) include at
    least one of ``device``, ``ctx``, ``driver`` or ``cufile_env_json``.
    """
    if _is_cudla_item(item):
        return False
    requested = getattr(item, "fixturenames", ())
    triggers = {"device", "ctx", "driver", "cufile_env_json"}
    return not triggers.isdisjoint(requested)


def pytest_collection_modifyitems(config, items):
    """Pytest collection hook: mark items that need a per-thread CUDA context.

    Each item accepted by ``_item_needs_thread_ctx`` is passed to
    ``mark_item_for_worker_context`` with ``_cuda_bindings_worker_context``.
    ``config`` is unused.
    """
    # filter() is lazy, so items are still checked and marked one at a time.
    for item in filter(_item_needs_thread_ctx, items):
        mark_item_for_worker_context(item, _cuda_bindings_worker_context)


@pytest.fixture(scope="module")
def cuda_driver():
(err,) = cuda.cuInit(0)
Expand Down
84 changes: 72 additions & 12 deletions cuda_core/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,76 @@ def xfail_if_mempool_oom(err_or_exc, api_name=None, device=0):
sys.path.insert(0, test_helpers_root)


from cuda_python_test_helpers.pytest_run_parallel import (
install_run_parallel_worker_context_patch,
mark_item_for_worker_context,
)


def pytest_configure(config):
    """Pytest configure hook: install the run-parallel worker-context patch.

    ``config`` is accepted to match the hook signature and is not used.
    """
    install_run_parallel_worker_context_patch()


@contextmanager
def _init_cuda_context():
    """Make ``Device(0)`` current for the ``with`` body and unset it on exit.

    When the ``CUDA_CORE_TEST_BLOCKING_SYNC`` environment variable parses to a
    non-zero integer, the device's primary context is additionally flagged
    with ``CU_CTX_SCHED_BLOCKING_SYNC``.

    Yields:
        The ``Device(0)`` object that was made current.
    """
    # TODO: rename this to e.g. init_context
    device = Device(0)
    device.set_current()

    # Opt-in: blocking sync avoids spin-waiting on synchronization.
    blocking_sync_requested = int(os.environ.get("CUDA_CORE_TEST_BLOCKING_SYNC", 0)) != 0
    if blocking_sync_requested:
        handle_return(
            driver.cuDevicePrimaryCtxSetFlags(
                device.device_id,
                driver.CUctx_flags.CU_CTX_SCHED_BLOCKING_SYNC,
            )
        )

    try:
        yield device
    finally:
        # The boolean result of the unset call is intentionally ignored.
        _device_unset_current()


@contextmanager
def _cuda_core_worker_context(*, thread_index, iteration_index, kwargs):
    """Run the body inside ``_init_cuda_context`` and refresh device kwargs.

    Args:
        thread_index: Accepted for the worker-context signature; unused here.
        iteration_index: Accepted for the worker-context signature; unused here.
        kwargs: Mutable mapping of call arguments. Device-related entries are
            overwritten in place, but only for keys that are already present.
    """
    with _init_cuda_context() as device:
        # (key, factory) pairs in replacement order. Factories are invoked
        # lazily, so the multi-device helpers only run when their key is
        # actually present in ``kwargs``.
        providers = (
            ("init_cuda", lambda: device),
            ("mempool_device", lambda: device),
            ("ipc_device", lambda: device),
            ("mempool_device_x2", lambda: _mempool_device_impl(2)),
            ("mempool_device_x3", lambda: _mempool_device_impl(3)),
            (
                "ipc_mempool_device_x2",
                lambda: _require_ipc_mempool_devices(_mempool_device_impl(2)),
            ),
        )
        for fixture_name, make_value in providers:
            if fixture_name in kwargs:
                kwargs[fixture_name] = make_value()
        yield


# Fixture names that make ``_item_needs_thread_ctx`` select a test item.
_CUDA_CONTEXT_FIXTURES = frozenset(
    (
        "init_cuda",
        "mempool_device",
        "mempool_device_x2",
        "mempool_device_x3",
        "ipc_device",
        "ipc_mempool_device_x2",
        "ipc_memory_resource",
        "memory_resource_factory",
    )
)


def _item_needs_thread_ctx(item):
    """Return True when the item requests any fixture in ``_CUDA_CONTEXT_FIXTURES``.

    An item without a ``fixturenames`` attribute is treated as requesting none.
    """
    requested = getattr(item, "fixturenames", ())
    return not _CUDA_CONTEXT_FIXTURES.isdisjoint(requested)


def pytest_collection_modifyitems(config, items):
    """Pytest collection hook: mark items that need a per-thread CUDA context.

    Each item accepted by ``_item_needs_thread_ctx`` is passed to
    ``mark_item_for_worker_context`` with ``_cuda_core_worker_context``.
    ``config`` is unused.
    """
    # filter() is lazy, so items are still checked and marked one at a time.
    for item in filter(_item_needs_thread_ctx, items):
        mark_item_for_worker_context(item, _cuda_core_worker_context)


def skip_if_pinned_memory_unsupported(device):
try:
if not device.properties.host_memory_pools_supported:
Expand Down Expand Up @@ -194,18 +264,8 @@ def session_setup():

@pytest.fixture
def init_cuda():
    """Yield ``Device(0)`` with it set as the current device for the test.

    Setup and teardown are delegated to ``_init_cuda_context`` so the fixture
    and the run-parallel worker context share one implementation. The fixture
    yields exactly once; the inline copy of the same setup/teardown that
    previously preceded the delegation is removed, because it duplicated the
    work and made the generator yield twice.
    """
    with _init_cuda_context() as device:
        yield device


def _device_unset_current() -> bool:
Expand Down
Loading
Loading