From 4920c0697bed93181e5e171a6a44bb3de063c94b Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Tue, 16 Jun 2026 00:53:19 -0400 Subject: [PATCH 1/3] [Docs] Modernize test-gating documentation Update the contributor guide and tvm.testing docstrings/comments to describe the current gating API (@pytest.mark.gpu + tvm.testing.env.has_*() + skipif, and pytest.importorskip for optional packages) instead of the removed Feature class and @tvm.testing.requires_*/uses_gpu decorators. --- docs/contribute/code_guide.rst | 11 +++++- docs/contribute/testing.rst | 62 +++++++++++++++++----------------- python/tvm/testing/env.py | 22 ++++++------ python/tvm/testing/plugin.py | 8 +++-- python/tvm/testing/utils.py | 58 +++++++++++++++---------------- 5 files changed, 87 insertions(+), 74 deletions(-) diff --git a/docs/contribute/code_guide.rst b/docs/contribute/code_guide.rst index fd40cec579cf..6419b7f9d77a 100644 --- a/docs/contribute/code_guide.rst +++ b/docs/contribute/code_guide.rst @@ -139,7 +139,16 @@ If you want your test to run over a variety of targets, use the :py:func:`tvm.te def test_mytest(target, dev): ... -will run ``test_mytest`` with ``target="llvm"``, ``target="cuda"``, and few others. This also ensures that your test is run on the correct hardware by the CI. If you only want to test against a couple targets use ``@tvm.testing.parametrize_targets("target_1", "target_2")``. If you want to test on a single target, use the associated decorator from :py:func:`tvm.testing`. For example, CUDA tests use the ``@tvm.testing.requires_cuda`` decorator. +will run ``test_mytest`` with ``target="llvm"``, ``target="cuda"``, and few others. This also ensures that your test is run on the correct hardware by the CI. If you only want to test against a couple targets use ``@tvm.testing.parametrize_targets("target_1", "target_2")``. If you want to test on a single target, gate the test on the corresponding capability probe instead of using a per-target decorator. Mark GPU tests with ``@pytest.mark.gpu`` so the CI can select them, and skip when the required feature is unavailable with ``@pytest.mark.skipif``. For example, CUDA tests use: + +.. code:: python + + @pytest.mark.gpu + @pytest.mark.skipif(not tvm.testing.env.has_cuda(), reason="need cuda") + def test_mycudatest(): + ... + +The ``tvm.testing.env`` module exposes a ``has_*()`` probe for each runtime and hardware feature (e.g. ``has_cuda()``, ``has_rocm()``, ``has_vulkan()``, ``has_llvm()``). To skip a test when an optional Python package is missing, use ``pytest.importorskip("package_name")``. Network Resources diff --git a/docs/contribute/testing.rst b/docs/contribute/testing.rst index c2f502503099..c5777bad38a0 100644 --- a/docs/contribute/testing.rst +++ b/docs/contribute/testing.rst @@ -111,9 +111,9 @@ parameters. For instance, there may be target-specific implementations that should be tested, where some targets have more than one implementation. These can be done by explicitly parametrizing over tuples of arguments, such as shown below. In these -cases, only the explicitly listed targets will run, but they will -still have the appropriate ``@tvm.testing.requires_RUNTIME`` mark -applied to them. +cases, only the explicitly listed targets will run, and each target is +automatically gated on whether it can run on the current machine (a GPU +target gets ``@pytest.mark.gpu`` plus a skip when no device is present). .. code-block:: python @@ -134,34 +134,34 @@ marks are as follows. - ``@pytest.mark.gpu`` - Tags a function as using GPU capabilities. This has no effect on its own, but can be paired with - command-line arguments ``-m gpu`` or ``-m 'not gpu'`` to restrict - which tests pytest will execute. This should not be called on its - own, but is part of other marks used in unit-tests. - -- ``@tvm.testing.uses_gpu`` - Applies ``@pytest.mark.gpu``. This - should be used to mark unit tests that may use the GPU, if one is - present. This decorator is only needed for tests that explicitly - loop over ``tvm.testing.enabled_targets()``, but that is no longer - the preferred style of writing unit tests (see below). When using - ``tvm.testing.parametrize_targets()``, this decorator is implicit - for GPU targets, and does not need to be explicitly applied. - -- ``@tvm.testing.requires_gpu`` - Applies ``@tvm.testing.uses_gpu``, - and additionally marks that the test should be skipped - (``@pytest.mark.skipif``) entirely if no GPU is present. - -- ``@tvm.testing.requires_RUNTIME`` - Several decorators - (e.g. ``@tvm.testing.requires_cuda``), each of which skips a test if - the specified runtime cannot be used. A runtime cannot be used if it - is disabled in the ``config.cmake``, or if a compatible device is - not present. For runtimes that use the GPU, this includes - ``@tvm.testing.requires_gpu``. - -When using parametrized targets, each test run is decorated with the -``@tvm.testing.requires_RUNTIME`` that corresponds to the target -being used. As a result, if a target is disabled in ``config.cmake`` -or does not have appropriate hardware to run, it will be explicitly -listed as skipped. + the command-line arguments ``-m gpu`` or ``-m 'not gpu'`` to restrict + which tests pytest will execute. Apply it to any test that needs a + GPU so that the CI runs it only on GPU nodes. + +- ``@pytest.mark.skipif(not tvm.testing.env.has_X(), reason=...)`` - + Skips a test when a required runtime or hardware feature is not + available. The :py:mod:`tvm.testing.env` module exposes one memoized + probe per capability (e.g. ``has_cuda()``, ``has_rocm()``, + ``has_vulkan()``, ``has_gpu()``, ``has_llvm()``), each of which + returns ``False`` when the runtime is disabled in ``config.cmake`` or + no compatible device is present. Pair it with ``@pytest.mark.gpu`` + for tests that use the GPU:: + + @pytest.mark.gpu + @pytest.mark.skipif(not tvm.testing.env.has_cuda(), reason="need cuda") + def test_cuda_vectorize_add(): + # Test code goes here + +- ``pytest.importorskip("package_name")`` - Skips a test (or the whole + module, when called at import time) if an optional Python package is + not installed. Use this instead of a ``skipif`` for package + dependencies. + +When using ``tvm.testing.parametrize_targets()``, each parametrized run +is gated automatically on whether its target can run on the current +machine. As a result, if a target is disabled in ``config.cmake`` or +does not have appropriate hardware to run, it will be explicitly listed +as skipped, and GPU targets are tagged with ``@pytest.mark.gpu`` for you. There also exists a ``tvm.testing.enabled_targets()`` that returns all targets that are enabled and runnable on the current machine, diff --git a/python/tvm/testing/env.py b/python/tvm/testing/env.py index 6b39b5f2f674..e70105a95a10 100644 --- a/python/tvm/testing/env.py +++ b/python/tvm/testing/env.py @@ -115,9 +115,9 @@ def _device_exists(kind: str, index: int = 0) -> bool: def _build_flag_enabled(flag: str) -> bool: """Return whether an optional build flag (e.g. ``USE_CUTLASS``) is on. - Mirrors the historical ``Feature`` check: a flag counts as enabled - unless it is explicitly disabled, so library flags carrying a path - still register as present. + A flag counts as enabled unless it is explicitly disabled, so library + flags carrying a path (rather than a boolean) still register as present. + Callers gate on this via ``@pytest.mark.skipif(not has_cutlass(), ...)``. """ try: value = tvm.support.libinfo().get(flag, "OFF") @@ -130,8 +130,8 @@ def _build_flag_enabled(flag: str) -> bool: def _target_enabled(kind: str) -> bool: """True if ``kind`` is selected by ``TVM_TEST_TARGETS`` (or the default set). - Restores the historical ``target_kind_enabled`` opt-out, so CI can exclude a - flaky backend (e.g. opencl) via ``TVM_TEST_TARGETS`` and have its tests skip + Honors the ``TVM_TEST_TARGETS`` opt-out, so CI can exclude a flaky + backend (e.g. opencl) via ``TVM_TEST_TARGETS`` and have its tests skip even when a device is physically present. """ try: @@ -343,8 +343,9 @@ def _nvcc_version() -> tuple: def has_nvcc_version(major: int, minor: int = 0, release: int = 0) -> bool: """True if a CUDA device is present and nvcc is at least ``(major, minor, release)``. - Implies :func:`has_cuda`, matching the historical ``requires_nvcc_version`` - decorator which also required the CUDA runtime. + Returns False when no CUDA device is present, so it implies :func:`has_cuda`. + Gate a test with ``@pytest.mark.skipif(not env.has_nvcc_version(11, 4), + reason="need nvcc >= 11.4")`` (add ``@pytest.mark.gpu`` for GPU selection). """ return has_cuda() and _nvcc_version() >= (major, minor, release) @@ -389,9 +390,10 @@ def has_matrixcore() -> bool: def has_cudagraph() -> bool: """True if a CUDA device is present and the toolkit supports CUDA Graphs. - Implies :func:`has_cuda`, matching the historical ``requires_cudagraph`` - decorator (``parent_features="cuda"``): ``nvcc.have_cudagraph()`` only - checks the toolkit version, so the device guard must be explicit. + Implies :func:`has_cuda`: ``nvcc.have_cudagraph()`` only checks the + toolkit version, so the device guard must be explicit. Gate a test with + ``@pytest.mark.skipif(not tvm.testing.env.has_cudagraph(), reason=...)`` + (add ``@pytest.mark.gpu`` for CI selection). """ try: from tvm.support import nvcc # pylint: disable=import-outside-toplevel diff --git a/python/tvm/testing/plugin.py b/python/tvm/testing/plugin.py index bba2da6aee0d..91aeb6374f34 100644 --- a/python/tvm/testing/plugin.py +++ b/python/tvm/testing/plugin.py @@ -210,9 +210,11 @@ def update_parametrize_target_arg( raise TypeError(msg) from err if "target" in metafunc.fixturenames: - # Update any explicit use of @pytest.mark.parmaetrize to - # parametrize over targets. This adds the appropriate - # @tvm.testing.requires_* markers for each target. + # Update any explicit use of @pytest.mark.parametrize to + # parametrize over targets. This attaches the appropriate + # per-target gating markers (pytest.mark.gpu for GPU-family + # targets, plus a pytest.mark.skipif guarded by the relevant + # tvm.testing.env.has_*() probe) via _target_to_requirement. for mark in metafunc.definition.iter_markers("parametrize"): update_parametrize_target_arg(mark, *mark.args, **mark.kwargs) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index c90e610af4d6..5e9820864340 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -29,38 +29,38 @@ Testing Markers *************** -We use pytest markers to specify the requirements of test functions. Currently -there is a single distinction that matters for our testing environment: does -the test require a gpu. For tests that require just a gpu or just a cpu, we -have the decorator :py:func:`requires_gpu` that enables the test when a gpu is -available. To avoid running tests that don't require a gpu on gpu nodes, this -decorator also sets the pytest marker `gpu` so we can use select the gpu subset -of tests (using `pytest -m gpu`). - -Unfortunately, many tests are written like this: +We use pytest markers to specify the requirements of test functions. +Currently there is a single distinction that matters for our testing +environment: does the test require a gpu. Tests that require a gpu are +tagged with the ``gpu`` pytest marker -- the only registered marker (see +the ``markers`` entry in ``pyproject.toml``). This lets us select the +gpu subset of tests with ``pytest -m gpu`` (and exclude them on cpu-only +nodes with ``pytest -m "not gpu"``). + +The ``gpu`` marker only controls which testing node a test runs on; it +does not check whether the required hardware or libraries are actually +present. To gate a test on a specific capability, combine the marker +with a ``skipif`` that consults the memoized environment probes in +:py:mod:`tvm.testing.env`: .. code-block:: python - def test_something(): - for target in all_targets(): - do_something() - -The test uses both gpu and cpu targets, so the test needs to be run on both cpu -and gpu nodes. But we still want to only run the cpu targets on the cpu testing -node. The solution is to mark these tests with the gpu marker so they will be -run on the gpu nodes. But we also modify all_targets (renamed to -enabled_targets) so that it only returns gpu targets on gpu nodes and cpu -targets on cpu nodes (using an environment variable). - -Instead of using the all_targets function, future tests that would like to -test against a variety of targets should use the -:py:func:`tvm.testing.parametrize_targets` functionality. This allows us -greater control over which targets are run on which testing nodes. - -If in the future we want to add a new type of testing node (for example -fpgas), we need to add a new marker in `tests/python/pytest.ini` and a new -function in this module. Then targets using this node should be added to the -`TVM_TEST_TARGETS` environment variable in the CI. + @pytest.mark.gpu + @pytest.mark.skipif(not tvm.testing.env.has_cuda(), reason="need cuda") + def test_cuda_vectorize_add(): + ... + +There is one ``has_*`` (or ``is_*``) probe per capability -- for example +:py:func:`tvm.testing.env.has_gpu`, :py:func:`tvm.testing.env.has_cuda`, +and :py:func:`tvm.testing.env.has_vulkan`. For optional Python packages, +prefer ``pytest.importorskip("pkg_name")`` instead of a ``skipif``. + +To run a test against a variety of targets, use +:py:func:`tvm.testing.parametrize_targets`; it parametrizes the test over +the enabled targets and applies the appropriate ``gpu`` tag and skip +conditions per target automatically. The set of enabled targets is +controlled by the ``TVM_TEST_TARGETS`` environment variable, so the CI +can run different targets on different testing nodes. """ From 0c1ed7fc71cf41796e18737767593edfb46be346 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 16 Jun 2026 07:30:38 -0400 Subject: [PATCH 2/3] Update python/tvm/testing/env.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/tvm/testing/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/testing/env.py b/python/tvm/testing/env.py index e70105a95a10..502de47218f4 100644 --- a/python/tvm/testing/env.py +++ b/python/tvm/testing/env.py @@ -344,7 +344,7 @@ def has_nvcc_version(major: int, minor: int = 0, release: int = 0) -> bool: """True if a CUDA device is present and nvcc is at least ``(major, minor, release)``. Returns False when no CUDA device is present, so it implies :func:`has_cuda`. - Gate a test with ``@pytest.mark.skipif(not env.has_nvcc_version(11, 4), + Gate a test with ``@pytest.mark.skipif(not tvm.testing.env.has_nvcc_version(11, 4), reason="need nvcc >= 11.4")`` (add ``@pytest.mark.gpu`` for GPU selection). """ return has_cuda() and _nvcc_version() >= (major, minor, release) From 417a64561d3227a6f81537f17a6951968661deb6 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 16 Jun 2026 07:30:44 -0400 Subject: [PATCH 3/3] Update python/tvm/testing/env.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/tvm/testing/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/testing/env.py b/python/tvm/testing/env.py index 502de47218f4..0c9b48e5c16a 100644 --- a/python/tvm/testing/env.py +++ b/python/tvm/testing/env.py @@ -117,7 +117,7 @@ def _build_flag_enabled(flag: str) -> bool: A flag counts as enabled unless it is explicitly disabled, so library flags carrying a path (rather than a boolean) still register as present. - Callers gate on this via ``@pytest.mark.skipif(not has_cutlass(), ...)``. + Callers gate on this via ``@pytest.mark.skipif(not tvm.testing.env.has_cutlass(), ...)``. """ try: value = tvm.support.libinfo().get(flag, "OFF")