From 8b34b14e98581be640cdf82097386f61e0872819 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Mon, 15 Jun 2026 17:34:39 -0400 Subject: [PATCH] [TIRX][CUDA] Framework support for FA4, CLC intrinsics, and nvfp4 tcgen05 GEMM Batch of tirx CUDA backend framework updates: - FA4: env-driven ptxas register level and scheduler num_ctas support. - clusterlaunchcontrol (CLC) device intrinsics and a CLC-based tile scheduler. - Framework support for nvfp4 tcgen05 GEMM. - Scope-level operands for warp/wg/cta register elementwise ops. - LLVM codegen diagnostic for duplicate PrimFunc global symbols. - Default CUDA compilation to NVRTC. Robustness / tests: - Cross-CTA mbarrier arrive intrinsics: guard with `setp.ne.s32 p, %2, 0` instead of `setp.eq.u32 p, %2, 1` so any non-zero `int pred` is treated as true (C boolean semantics), matching the `int pred` signature. - Harden the NVRTC path to always define the vector-deprecation silencing macros, so device-code compilation does not depend on which CUDA header include chain is pulled in. - Wire tests/python/tirx into the unittest CI task. The suite targets Blackwell (sm_100a); a directory conftest gates it on a real sm_100a device so it skips cleanly on CPU nodes / pre-sm_100 GPUs (where ptxas/NVRTC would otherwise reject tcgen05 / cp.async `.async` / fp8) and runs in full where the hardware is present. - Add `gpu` markers and CUDA compute-capability skipifs across the tirx tests. Tests under tests/python/tirx pass locally on sm_100a (B200). Signed-off-by: spectrometerHBH --- python/tvm/backend/cuda/lang/pipeline.py | 11 +- .../tvm/backend/cuda/lang/tile_scheduler.py | 135 +++++++++++++- python/tvm/backend/cuda/op.py | 79 +++++++- .../backend/cuda/operator/intrinsics/sync.py | 100 ++++++++++- .../tile_primitive/copy_async/tcgen05_ldst.py | 35 ++-- .../tile_primitive/elementwise/reg.py | 67 +++++++ python/tvm/backend/cuda/script.py | 6 + python/tvm/support/nvcc.py | 76 ++++++-- .../tirx/script/builder/external_kernel.py | 2 +- src/backend/cuda/op/target_builtin.cc | 6 + src/target/llvm/codegen_llvm.cc | 17 ++ src/target/llvm/codegen_llvm.h | 3 + src/tirx/ir/layout/tile_slice.cc | 6 +- .../codegen/test_target_codegen_llvm.py | 39 ++++ .../python/tirx/codegen/test_codegen_cuda.py | 11 ++ .../tirx/codegen/test_codegen_nvshmem.py | 3 + tests/python/tirx/codegen/test_cuda_copy.py | 11 ++ .../tirx/codegen/test_cuda_cta_reduce.py | 13 ++ .../tirx/codegen/test_cuda_warp_reduce.py | 13 ++ tests/python/tirx/conftest.py | 40 +++++ .../tile_primitive/cuda/copy/test_fallback.py | 5 + .../cuda/copy/test_gmem_smem.py | 4 + .../tile_primitive/cuda/copy/test_reg.py | 5 + .../cuda/copy_async/test_ldgsts.py | 3 + .../cuda/copy_async/test_tmem.py | 7 + .../cuda/copy_async/test_tmem_16xnb.py | 144 +++++++++++++++ .../cuda/elementwise/test_binary.py | 13 ++ .../cuda/elementwise/test_fma.py | 15 ++ .../cuda/elementwise/test_unary.py | 168 +++++++++++++++++- .../cuda/gemm_async/test_gemm_async.py | 23 +++ .../permute_layout/test_permute_layout.py | 7 + .../cuda/reduction/test_reduction.py | 23 +++ tests/python/tirx/test_buffer_print.py | 4 + tests/python/tirx/test_control_flow.py | 8 + tests/python/tirx/test_layout.py | 35 ++++ tests/scripts/task_python_unittest.sh | 1 + 36 files changed, 1096 insertions(+), 42 deletions(-) create mode 100644 tests/python/tirx/conftest.py diff --git a/python/tvm/backend/cuda/lang/pipeline.py b/python/tvm/backend/cuda/lang/pipeline.py index ee86090398e9..40fd40c3fac6 100644 --- a/python/tvm/backend/cuda/lang/pipeline.py +++ b/python/tvm/backend/cuda/lang/pipeline.py @@ -110,7 +110,7 @@ def wait(self, stage, phase): T.ptx.mbarrier.try_wait(self.buf.ptr_to([stage]), phase ^ self.phase_offset) @T.inline - def arrive(self, stage, cta_id=None, pred=None): + def arrive(self, stage, cta_id=None, pred=None, count=None): # Default: local-CTA arrive — emits the simple # ``mbarrier.arrive.shared.b64`` form. To arrive on a remote # CTA's mbarrier in a cluster kernel, callers must pass @@ -119,11 +119,18 @@ def arrive(self, stage, cta_id=None, pred=None): # the cross-CTA path was both surprising (``bar.arrive(stage)`` # silently ``mapa`` ed across the cluster) and a per-call cost # of ~3 PTX ops on every single-CTA kernel. + # + # ``count`` (cross-CTA path only) emits the explicit arrival-count + # operand, i.e. ``mbarrier.arrive.shared::cluster.b64 _, [addr], count``. + # When ``None`` the implicit count-of-1 form is emitted. Passing + # ``count=1`` is semantically identical but spells the count explicitly. if cta_id is None: T.ptx.mbarrier.arrive(self.buf.ptr_to([stage])) else: actual_pred = True if pred is None else pred - T.ptx.mbarrier.arrive(self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred) + T.ptx.mbarrier.arrive( + self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred, count=count + ) def ptr_to(self, idx): return self.buf.ptr_to(idx) diff --git a/python/tvm/backend/cuda/lang/tile_scheduler.py b/python/tvm/backend/cuda/lang/tile_scheduler.py index 3fd27f25ee5f..c6154f2462f6 100644 --- a/python/tvm/backend/cuda/lang/tile_scheduler.py +++ b/python/tvm/backend/cuda/lang/tile_scheduler.py @@ -20,6 +20,7 @@ instances are automatically treated as meta values inside @T.prim_func. """ +from tvm.backend.cuda.lang.pipeline import Pipeline, PipelineState from tvm.script import tirx as T @@ -753,13 +754,20 @@ class FlashAttentionLPTScheduler(BaseTileScheduler): """ def __init__( - self, prefix: str, num_batches: int, num_heads: int, num_m_blocks: int, l2_swizzle: int + self, + prefix: str, + num_batches: int, + num_heads: int, + num_m_blocks: int, + l2_swizzle: int, + num_ctas: int | None = None, ): super().__init__(prefix) self._num_batches = num_batches self._num_heads = num_heads self._num_m_blocks = num_m_blocks self._l2_swizzle = l2_swizzle + self._num_ctas = num_ctas self._total_tasks = num_batches * num_heads * num_m_blocks # Derived constants for L2 swizzle @@ -807,10 +815,131 @@ def init(self, cta_id): @T.inline def next_tile(self): - """Advance to next tile by striding by num_ctas.""" - self.linear_idx = self._total_tasks + """Advance to the next tile. + + Single-tile mode (``num_ctas=None``, the default): each CTA owns one + task; terminate. Persistent mode (``num_ctas=N``): stride by N, like + :class:`FlashAttentionLinearScheduler`, while keeping the LPT + L2 + swizzle index mapping. + """ + if self._num_ctas is None: + self.linear_idx = self._total_tasks + else: + self.linear_idx = self.linear_idx + self._num_ctas + self.update_current_m_n_idx(self.linear_idx) # fmt: on def valid(self): """Check if there are more tiles to process.""" return self.linear_idx < self._total_tasks + + +class _CLCWorker(ClusterPersistentScheduler2D): + """Per-role CLC handle: IS-A ClusterPersistentScheduler2D (so m_idx / n_idx work as + usual) plus the role-local barrier phase and handshake. A coord-free role (e.g. an + MMA warp consuming whatever a loader staged) arms the loop with reset() not init(). + """ + + def __init__(self, clc, prefix): + super().__init__( + prefix, + num_m_tiles=clc._num_m_tiles, + num_n_tiles=clc._num_n_tiles, + num_clusters=clc._num_m_tiles * clc._num_n_tiles, + l2_group_size=clc._l2_group_size, + ) + self._clc = clc + self._sa = PipelineState(1, 0) + self._done = T.local_scalar("int32") + self._nxt = T.local_scalar("uint32") + + @T.inline + def reset(self): + self._done = 0 + + @T.inline + def init(self, cluster_id): + # Explicit base call: TVMScript's parser has no zero-arg super(). + ClusterPersistentScheduler2D.init(self, cluster_id) + self._done = 0 + + def valid(self): + return self._done == 0 + + @T.inline + def consume(self): + # Single-elected-thread scope: wait for the handle, decode, release the slot. + self._clc.sched_arr.full.wait(0, self._sa.phase) + self._sa.advance() + self._nxt = T.ptx.clc_query_cancel(T.address_of(self._clc.clc_handle[0])) + self._clc.sched_fin.empty.arrive(0, cta_id=0, pred=True) + + @T.inline + def consume_wg(self, wg_id, warp_id, lane_id): + # Warpgroup scope: all threads decode; one elected lane releases the slot. + self._clc.sched_arr.full.wait(0, self._sa.phase) + self._sa.advance() + self._nxt = T.ptx.clc_query_cancel(T.address_of(self._clc.clc_handle[0])) + T.cuda.warpgroup_sync(wg_id + 1) + if (warp_id == 0) & (lane_id == 0): + self._clc.sched_fin.empty.arrive(0, cta_id=0, pred=True) + + @T.inline + def advance_coords(self): + if self._nxt != 0xFFFFFFFF: + self.update_current_m_n_idx(self._nxt // self._clc._cta_group) + + @T.inline + def mark_done_if_drained(self): + if self._nxt == 0xFFFFFFFF: + self._done = 1 + + +@T.meta_class +class ClusterLaunchControlScheduler: + """Blackwell Cluster Launch Control (CLC) tile scheduler. + + A scheduler warp runs ``run_scheduler`` (issues ``try_cancel`` to steal the next + cluster); worker roles each take a ``worker()`` handle and pull the stolen tile + through the shared smem handshake. Owns the CLC smem: the 16B response handle, the + arrival barrier (handle ready), and the finished barrier (slot consumed; + ``finish_arrivals`` arrivals per round). Tile-coord mapping is delegated to + ``ClusterPersistentScheduler2D`` (group-major L2 ordering). + """ + + def __init__(self, pool, num_m_tiles, num_n_tiles, l2_group_size, cta_group, finish_arrivals): + self._num_m_tiles = num_m_tiles + self._num_n_tiles = num_n_tiles + self._l2_group_size = l2_group_size + self._cta_group = cta_group + self.sched_arr = Pipeline(pool, 1, full="tma", empty="mbar", init_empty=1) + self.sched_fin = Pipeline(pool, 1, full="mbar", empty="mbar", init_empty=finish_arrivals) + self.clc_handle = pool.alloc((4,), "uint32", align=16) + self._s_done = T.local_scalar("int32") + self._s_nxt = T.local_scalar("uint32") + + def worker(self, prefix): + return _CLCWorker(self, prefix) + + @T.inline + def run_scheduler(self, cbx): + # cta0 drives try_cancel; both CTAs expect_bytes + consume the handle so the + # finished-barrier count is met and the slot can be reissued. + if T.ptx.elect_sync(): + sa = PipelineState(1, 0) + sf = PipelineState(1, 1) + self._s_done = 0 + while self._s_done == 0: + if cbx == 0: + self.sched_fin.empty.wait(0, sf.phase) + sf.advance() + T.ptx.clc_try_cancel( + T.address_of(self.clc_handle[0]), T.address_of(self.sched_arr.full.buf[0]) + ) + self.sched_arr.full.arrive(0, 16) # expect_bytes for the 16B handle + self.sched_arr.full.wait(0, sa.phase) + sa.advance() + self._s_nxt = T.ptx.clc_query_cancel(T.address_of(self.clc_handle[0])) + self.sched_fin.empty.arrive(0, cta_id=0, pred=True) + if self._s_nxt == 0xFFFFFFFF: + self._s_done = 1 diff --git a/python/tvm/backend/cuda/op.py b/python/tvm/backend/cuda/op.py index e76d5fbe2452..9570e266623c 100644 --- a/python/tvm/backend/cuda/op.py +++ b/python/tvm/backend/cuda/op.py @@ -653,12 +653,12 @@ def ptx_mbarrier_init(bar, thread_count): return call_intrin("", "tirx.ptx_mbarrier_init", bar, thread_count) -def ptx_mbarrier_arrive(bar, cta_id=None, pred=None): +def ptx_mbarrier_arrive(bar, cta_id=None, pred=None, count=None): """TVM intrinsic to call mbarrier.arrive.shared::cta.b64 or @p mapa.shared::cluster.u32 - @p mbarrier.arrive.shared::cluster.b64 + @p mbarrier.arrive.shared::cluster.b64 [, count] Parameters ---------- @@ -670,11 +670,29 @@ def ptx_mbarrier_arrive(bar, cta_id=None, pred=None): pred : Optional[PrimExpr] The predicate to guard the operation. + + count : Optional[PrimExpr] + Explicit arrival count operand for the cross-CTA (cluster) form. When + ``None`` the implicit count-of-1 form is emitted; when given, emits + ``mbarrier.arrive.shared::cluster.b64 _, [addr], count``. """ if cta_id is None and pred is None: return call_intrin("", "tirx.ptx_mbarrier_arrive", bar) assert cta_id is not None and pred is not None - return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred) + if count is None: + return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred) + return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred, count) + + +def ptx_mbarrier_arrive_cluster_count(bar, cta_id, count): + """Cross-CTA ``mbarrier.arrive`` on CTA ``cta_id`` with an explicit count. + + Convenience for an already-elected thread: emits + ``@p mapa.shared::cluster.u32`` + ``@p mbarrier.arrive.shared::cluster.b64 _, + [addr], count`` with the guard defaulted to 1. + """ + return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, True, count) + def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None): @@ -706,7 +724,11 @@ def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None): """ if cta_id is None and pred is None: return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, byte_count) - assert cta_id is not None and pred is not None + assert cta_id is not None + # Cross-CTA expect_tx from an already-elected thread: default the guard to 1 + # (the caller has elected a single lane), so callers can pass cta_id alone. + if pred is None: + pred = True return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, byte_count, cta_id, pred) @@ -729,6 +751,23 @@ def ptx_mbarrier_try_wait(bar, phase): return call_intrin("", "tirx.ptx_mbarrier_try_wait", bar, phase) +def ptx_mbarrier_try_wait_acquire_cluster(bar, phase): + """``mbarrier.try_wait.parity.acquire.cluster`` retry loop. + + Cluster-scope acquire wait — used to wait on a barrier that a remote CTA in + the cluster arrives on (a group cluster wait). + + Parameters + ---------- + bar : Var + The pointer to barrier variable. + + phase : int + The phase of the barrier. + """ + return call_intrin("", "tirx.ptx_mbarrier_try_wait_acquire_cluster", bar, phase) + + def ptx_mbarrier_try_wait_once(bar, phase, ticks): """TVM intrinsic for one-shot non-blocking ``mbarrier.try_wait.parity``. @@ -1261,6 +1300,38 @@ def ptx_barrier_cluster_wait(acquire=False, aligned=True): return call_intrin("", "tirx.ptx_barrier_cluster_wait", acquire, aligned) +def ptx_clc_try_cancel(handle, mbar): + """TVM intrinsic to call clusterlaunchcontrol.try_cancel. + + Async-requests cancelling the next cluster's launch (work-stealing): writes the + 16B response handle to smem and signals ``mbar`` (complete_tx, multicast to both + cluster CTAs). + + Parameters + ---------- + handle : PrimExpr + Pointer to the 16B (uint4) smem response handle. + + mbar : PrimExpr + Pointer to the mbarrier signalled when the handle lands. + """ + return call_intrin("", "tirx.ptx_clc_try_cancel", handle, mbar) + + +def ptx_clc_query_cancel(handle): + """TVM intrinsic to call clusterlaunchcontrol.query_cancel. + + Decodes the response handle written by :func:`ptx_clc_try_cancel`. Returns the + cancelled cluster's first ``ctaid.x``, or ``0xFFFFFFFF`` when no work was stolen. + + Parameters + ---------- + handle : PrimExpr + Pointer to the 16B (uint4) smem response handle. + """ + return call_intrin("uint32", "tirx.ptx_clc_query_cancel", handle) + + def ptx_elect_sync(): """TVM intrinsic to call elect.sync""" return call_intrin("uint32", "tirx.ptx_elect_sync") diff --git a/python/tvm/backend/cuda/operator/intrinsics/sync.py b/python/tvm/backend/cuda/operator/intrinsics/sync.py index 0fcdb31a46f1..791d9cc981fc 100644 --- a/python/tvm/backend/cuda/operator/intrinsics/sync.py +++ b/python/tvm/backend/cuda/operator/intrinsics/sync.py @@ -168,6 +168,54 @@ def _ptx_barrier_cluster_wait(acquire, aligned): ) +# ============================================================================= +# clusterlaunchcontrol.try_cancel / query_cancel — Blackwell Cluster Launch +# Control (CLC) work-stealing, written from the PTX ISA spec (section +# "clusterlaunchcontrol", PTX ISA 8.6). try_cancel async-requests cancelling the +# next cluster's launch, writing a 16B response to smem + signalling mbar. query +# decodes the response: on success it extracts the cancelled cluster's first +# ctaid.x (via the get_first_ctaid::x form); a single uint32 is returned, with +# 0xFFFFFFFF as the "no work stolen" sentinel (a device helper returns one scalar). +# ============================================================================= +device_intrinsic( + "ptx_clc_try_cancel", + c_signature="(void* handle, void* mbar)", + body=( + " unsigned int addr = (unsigned int)__cvta_generic_to_shared(handle);\n" + " unsigned int bar = (unsigned int)__cvta_generic_to_shared(mbar);\n" + " asm volatile(\n" + ' "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes"\n' + ' ".multicast::cluster::all.b128 [%0], [%1];\\n"\n' + ' :: "r"(addr), "r"(bar) : "memory");' + ), +) + + +device_intrinsic( + "ptx_clc_query_cancel", + c_signature="(void* handle)", + return_type="uint32_t", + tvm_return_type="uint32", + body=( + " unsigned int addr = (unsigned int)__cvta_generic_to_shared(handle);\n" + " unsigned int first_ctaid_x;\n" + " asm volatile(\n" + ' "{\\n"\n' + ' ".reg .pred canceled;\\n"\n' + ' ".reg .b128 response;\\n"\n' + ' "ld.shared.b128 response, [%1];\\n"\n' + ' "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 canceled, response;\\n"\n' + ' "mov.u32 %0, 0xffffffff;\\n"\n' + ' "@canceled clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128"\n' + ' " %0, response;\\n"\n' + ' "}\\n"\n' + ' : "=r"(first_ctaid_x) : "r"(addr) : "memory");\n' + ' asm volatile("fence.proxy.async.shared::cta;\\n" ::: "memory");\n' + " return first_ctaid_x;" + ), +) + + # ============================================================================= # mbarrier.init.shared.b64 [addr], count ; — 1 form. # ============================================================================= @@ -208,7 +256,7 @@ def _ptx_barrier_cluster_wait(acquire, aligned): ' "{\\n"\n' ' ".reg .pred p;\\n"\n' ' ".reg .b32 remAddr32;\\n"\n' - ' "setp.eq.u32 p, %2, 1;\\n"\n' + ' "setp.ne.s32 p, %2, 0;\\n"\n' ' "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\\n"\n' ' "@p mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\\n"\n' ' "}\\n"\n' @@ -217,15 +265,38 @@ def _ptx_barrier_cluster_wait(acquire, aligned): ) +# Same cross-CTA arrive, but with an explicit arrival-count operand +# (``..., [remAddr32], count``). Matches the ``tma::cluster::arrive`` spelling. +device_intrinsic( + "_ptx_mbarrier_arrive_remote_count", + helper_name="tvm_builtin_ptx_mbarrier_arrive_remote_count", + c_signature="(void* barrier, int cta_id, int pred, int count)", + body=( + " unsigned int barrier_addr = __cvta_generic_to_shared(barrier);\n" + " asm volatile(\n" + ' "{\\n"\n' + ' ".reg .pred p;\\n"\n' + ' ".reg .b32 remAddr32;\\n"\n' + ' "setp.ne.s32 p, %2, 0;\\n"\n' + ' "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\\n"\n' + ' "@p mbarrier.arrive.shared::cluster.b64 _, [remAddr32], %3;\\n"\n' + ' "}\\n"\n' + ' :: "r"(barrier_addr), "r"(cta_id), "r"(pred), "r"(count) : "memory");' + ), +) + + @register_codegen("ptx_mbarrier_arrive") def _codegen_mbarrier_arrive(*args): - """Dispatch by arg count: 1 -> local, 3 -> remote (cluster-mapped).""" + """Dispatch by arg count: 1 -> local, 3 -> remote, 4 -> remote+count.""" if len(args) == 1: result = CODEGEN_REGISTRY["tirx._ptx_mbarrier_arrive_local"](list(args)) elif len(args) == 3: result = CODEGEN_REGISTRY["tirx._ptx_mbarrier_arrive_remote"](list(args)) + elif len(args) == 4: + result = CODEGEN_REGISTRY["tirx._ptx_mbarrier_arrive_remote_count"](list(args)) else: - raise ValueError(f"ptx_mbarrier_arrive expects 1 or 3 args, got {len(args)}") + raise ValueError(f"ptx_mbarrier_arrive expects 1, 3, or 4 args, got {len(args)}") return result[0] if isinstance(result, tuple) else result @@ -252,7 +323,7 @@ def _codegen_mbarrier_arrive(*args): ' "{\\n"\n' ' ".reg .pred p;\\n"\n' ' ".reg .b32 remAddr32;\\n"\n' - ' "setp.eq.u32 p, %2, 1;\\n"\n' + ' "setp.ne.s32 p, %2, 0;\\n"\n' ' "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\\n"\n' ' "@p mbarrier.arrive.expect_tx.shared::cluster.b64 _, [remAddr32], %3;\\n"\n' ' "}\\n"\n' @@ -303,6 +374,27 @@ def _codegen_mbarrier_arrive_expect_tx(*args): ) +# mbarrier.try_wait.parity.acquire.cluster — cluster-scope acquire wait used for +# cross-CTA barrier handshakes (e.g. the tmem-finished handoff). +device_intrinsic( + "ptx_mbarrier_try_wait_acquire_cluster", + c_signature="(void* barrier, int phase)", + body=( + " unsigned int barrier_addr_int = __cvta_generic_to_shared(barrier);\n" + " asm volatile(\n" + ' "{\\n"\n' + ' ".reg .pred P1;\\n"\n' + ' "LAB_WAIT_AC:\\n"\n' + ' "mbarrier.try_wait.parity.acquire.cluster.shared::cta.b64 P1, [%0], %1;\\n"\n' + ' "@P1 bra.uni DONE_AC;\\n"\n' + ' "bra.uni LAB_WAIT_AC;\\n"\n' + ' "DONE_AC:\\n"\n' + ' "}\\n"\n' + ' :: "r"(barrier_addr_int), "r"(phase) : "memory");' + ), +) + + # ============================================================================= # mbarrier.try_wait.parity — ONE-SHOT non-blocking variant. Returns true # if the requested parity has already been reached, false otherwise. diff --git a/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py index ffd5e18a3a5c..081ea5a772d3 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py @@ -369,20 +369,24 @@ def _emit_16xnb_path( tmem_st, tmem_extent = get_st_extent(tmem_region) local_st, local_extent = get_st_extent(local_region) - # Local slice must be the full (frag_rows, K_cols) view. + # Rows must span the full frag. The COLUMN extent may be a sub-multiple of + # the atom's full width ``width_elems`` — i.e. a per-chunk column slice of a + # wider frag (e.g. an epilogue that loads one big (128, MMA_N) frag in + # EPI_TILE-wide chunks). The atom layout maps consecutive columns to + # consecutive registers within each slab, so a column slice occupies a + # contiguous register window; we emit ``num_eff`` (the slice's atom rep) at + # the slab base + the column's register offset. When the slice IS the full + # atom (the common case), num_eff == num and reg offset == 0 (no change). assert analyzer.can_prove_equal(local_st[0], 0) assert analyzer.can_prove_equal(local_extent[0], frag_rows) - assert analyzer.can_prove_equal(local_extent[1], width_elems) - - # TMEM slice must start at row 0 and span ``frag_rows`` rows. For Layout - # F the buffer is already (64, W) so frag_rows=64 covers the full slice; - # for Layout D + frag_rows=64 the slice reads the *first* half-slab and - # the rest of the buffer's 128 rows is invisible to this atom. For - # Layout D + frag_rows=128 the slice covers all 128 physical lanes via - # two PTX issues (row=0 + row=16). assert analyzer.can_prove_equal(tmem_st[0], 0) assert analyzer.can_prove_equal(tmem_extent[0], frag_rows) - assert analyzer.can_prove_equal(tmem_extent[1], width_elems) + # local and tmem column slices must match and divide the atom's full width. + assert analyzer.can_prove_equal(local_extent[1], tmem_extent[1]) + slice_w = int(local_extent[1]) + assert width_elems % slice_w == 0, f"slice width {slice_w} must divide atom width {width_elems}" + num_eff = num * slice_w // width_elems + regs_eff = regs_per_thread_per_slab * slice_w // width_elems del tmem_rows # only used for the structural check above col_off = tmem_st[1] @@ -410,13 +414,18 @@ def impl(): # for the register-pointer arguments of the PTX builtin. local_storage = local_buf.view(per_thread_elems, layout=TileLayout(S[per_thread_elems])) local_32b = local_storage.view("uint32") - local_reg_base = local_col_off_elems // elem_per_32b + # Register offset of the column slice within each slab. The old + # ``local_col_off // elem_per_32b`` is only correct when the slice IS the + # full atom; in general consecutive columns advance registers at the rate + # (regs_per_thread_per_slab / width_elems). For a full-atom load the + # offset is 0 either way, so existing callers are unaffected. + local_reg_base = local_col_off_elems * regs_per_thread_per_slab // width_elems for slab in range(n_slabs): reg_base = slab * regs_per_thread_per_slab op( tmem_buf.allocated_addr[0], - *[local_32b[local_reg_base + reg_base + i] for i in range(regs_per_thread_per_slab)], # noqa: E501 - shape=shape, num=num, row=slab * 16, col=col_off_32b, + *[local_32b[local_reg_base + reg_base + i] for i in range(regs_eff)], + shape=shape, num=num_eff, row=slab * 16, col=col_off_32b, ) # fmt: on return impl diff --git a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py index eddf9f3d8eac..64d77a21cf69 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py @@ -45,8 +45,10 @@ from ..copy._common import _carve_tail, _verify_s_tail_contig from ..layout_utils import get_sublayout_from_region, layout_signature from ._common import ( + _TID_AXIS_FOR_SCOPE, _all_threads_active, _tensor_shape_of, + _thread_cnt, align_operands_to_anchor, buffer_regions, compute_dtype_of, @@ -67,6 +69,68 @@ def _validate_anchor_layout(anchor_br) -> tuple[bool, str | None]: return True, None +def _validate_scope_level_anchor(anchor_br, sctx: DispatchContext) -> tuple[bool, str | None]: + """For warp/warpgroup/cta scope, require dst to be scope-level: after + canonicalizing with the target its thread axes are the scope's intra-thread + axis (laneid/tid_in_wg/tx) and, sorted by stride, tile a complete ``T:1`` + chain over all ``T`` threads of the scope. Rejects thread-local ``.local()`` + views; thread scope is exempt. + """ + scope = sctx.scope_kind + if scope == "thread": + return True, None + expected_axis = _TID_AXIS_FOR_SCOPE.get(scope) + if expected_axis is None: + return True, None + expected_cnt = _thread_cnt(sctx) + + # Canonicalize the sliced anchor with the target so warp/lane axes fuse. + st, ext = get_st_extent(anchor_br) + sliced = get_sublayout_from_region(anchor_br.buffer.layout, anchor_br.buffer.shape, st, ext) + with sctx.target: + canon = sliced.canonicalize() if hasattr(sliced, "canonicalize") else sliced + shard = getattr(canon, "shard", None) + if shard is None: + return False, f"{scope}-scope op operand layout is not a TileLayout after slicing" + + thread_iters = [it for it in shard if it.axis.is_thread()] + if not thread_iters: + return ( + False, + f"{scope}-scope op needs a {scope}-level operand whose layout carries " + f"thread axes ({expected_axis} composing to {expected_cnt}:1); got a " + f"thread-local view with no thread axes — pass the {scope}-level tensor, " + f"not its `.local()` (per-thread) view", + ) + bad = sorted({it.axis.name for it in thread_iters if it.axis.name != expected_axis}) + if bad: + return ( + False, + f"{scope}-scope op operand carries thread axes {bad}; after " + f"canonicalization a {scope}-level layout must use only {expected_axis!r}", + ) + # Sorted by stride the thread iters must tile a complete chain 1, e0, + # e0*e1, ... up to the scope thread count — i.e. cover all T threads with + # no gap or overlap (extents alone would miss gaps/overlaps). + running = 1 + for it in sorted(thread_iters, key=lambda i: int(i.stride)): + stride, extent = int(it.stride), int(it.extent) + if stride != running: + return ( + False, + f"{scope}-scope op operand thread axes do not tile a complete " + f"{expected_cnt}:1 (sorted by stride: expected {running}, got {stride})", + ) + running *= extent + if running != expected_cnt: + return ( + False, + f"{scope}-scope op operand thread axes span {running} threads, not the " + f"full {expected_cnt} of the {scope}", + ) + return True, None + + def _check_layout_operands_agree(plan) -> tuple[bool, str | None]: """Replica sigs must match across non-trivial-layout operands. @@ -133,6 +197,9 @@ def check(op_call: TilePrimitiveCall, sctx: DispatchContext) -> tuple[bool, str ok3, reason3 = _validate_anchor_layout(anchor) if not ok3: return False, reason3 + ok_scope, reason_scope = _validate_scope_level_anchor(anchor, sctx) + if not ok_scope: + return False, reason_scope # Shape compat (NumPy-style broadcast): anchor's tensor shape is the # result shape; every operand must broadcast TO anchor. anchor_tshape = _tensor_shape_of(anchor.region) diff --git a/python/tvm/backend/cuda/script.py b/python/tvm/backend/cuda/script.py index a1148f9b67ee..a46aa7e7e472 100644 --- a/python/tvm/backend/cuda/script.py +++ b/python/tvm/backend/cuda/script.py @@ -53,6 +53,8 @@ def __init__(self): self.stmatrix = _op_wrapper(_cuda_op.ptx_stmatrix) self.setmaxnreg: Callable[..., Any] = _op_wrapper(_cuda_op.ptx_setmaxnreg) self.elect_sync: Callable[..., Any] = _op_wrapper(_cuda_op.ptx_elect_sync) + self.clc_try_cancel = _op_wrapper(_cuda_op.ptx_clc_try_cancel) + self.clc_query_cancel = _op_wrapper(_cuda_op.ptx_clc_query_cancel) self.fetch_register: Callable[..., Any] = _op_wrapper(_cuda_op.ptx_fetch_register) self.ld = _op_wrapper(_cuda_op.ptx_ld) self.ld_acquire = _op_wrapper(_cuda_op.ptx_ld_acquire) @@ -276,6 +278,9 @@ def __init__(self): self.init = _op_wrapper(_cuda_op.ptx_mbarrier_init) self.try_wait = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait) self.try_wait_once = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait_once) + self.try_wait_acquire_cluster = _op_wrapper( + _cuda_op.ptx_mbarrier_try_wait_acquire_cluster + ) self.arrive = MbarrierArriveNamespace() @@ -284,6 +289,7 @@ class MbarrierArriveNamespace: def __init__(self): self.expect_tx = _op_wrapper(_cuda_op.ptx_mbarrier_arrive_expect_tx) + self.cluster_count = _op_wrapper(_cuda_op.ptx_mbarrier_arrive_cluster_count) def __call__(self, *args, **kwds): return _op_wrapper(_cuda_op.ptx_mbarrier_arrive)(*args, **kwds) diff --git a/python/tvm/support/nvcc.py b/python/tvm/support/nvcc.py index ea5939fceffc..b421042fb30b 100644 --- a/python/tvm/support/nvcc.py +++ b/python/tvm/support/nvcc.py @@ -32,7 +32,7 @@ def compile_cuda( - code, target_format=None, arch=None, options=None, path_target=None, compiler="nvcc" + code, target_format=None, arch=None, options=None, path_target=None, compiler="nvrtc" ): """Compile CUDA code with NVCC or NVRTC. @@ -54,7 +54,7 @@ def compile_cuda( Output file. compiler : str, optional - Compiler backend: "nvcc" or "nvrtc". + Compiler backend: "nvrtc" (default) or "nvcc". This can be set by the TVM_CUDA_COMPILE_MODE environment variable. Returns @@ -191,7 +191,7 @@ def _compile_cuda_nvcc( "--expt-extended-lambda", "--use_fast_math", "--ptxas-options=-v", # printing out number of registers - "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers # noqa: E501 + f"--ptxas-options=--verbose,--register-usage-level={os.environ.get('TVM_CUDA_PTXAS_REG_LEVEL', '10')},--warn-on-local-memory-usage", # noqa: E501 ] major, _ = parse_compute_version(get_target_compute_version(Target.current(allow_none=True))) @@ -342,14 +342,23 @@ def _compile_cuda_nvrtc( line for line in code.splitlines() if line.strip() not in headers_to_strip ) - # NVRTC compiles device code and does not include the host-side cuda.h. - # CUtensorMap is a host-side structure, to reference and use it in device code, - # we must forward-declare it for NVRTC. + # NVRTC compiles device code and does not include the host-side cuda.h + # (it is guarded behind ``#ifndef __CUDACC_RTC__`` in generated code and is + # stripped above), so the complete ``CUtensorMap_st`` layout that cuda.h + # normally provides is missing. TMA kernels take ``CUtensorMap`` by value as + # ``__grid_constant__`` params, which requires the complete type. Define the + # ``CUtensorMap_st`` tag with cuda.h's layout (64-byte aligned, 128 bytes) + # plus the typedef alias. This is compatible with cccl's ````, + # which only forward-declares ``struct CUtensorMap_st;`` and re-typedefs the + # alias (a redundant typedef to the same type is legal in C++); defining the + # tag rather than ``struct CUtensorMap`` avoids the previous redefinition + # clash with that header. if "CUtensorMap" in code_filtered: code_filtered = ( - "struct __align__(128) CUtensorMap {\n" + "struct alignas(64) CUtensorMap_st {\n" " unsigned long long opaque[16];\n" - "};\n\n" + code_filtered + "};\n" + "typedef struct CUtensorMap_st CUtensorMap;\n\n" + code_filtered ) # Add standard type definitions and compatibility macros that NVRTC doesn't provide. @@ -371,6 +380,13 @@ def _compile_cuda_nvrtc( #define __volatile__ volatile #endif +// NVRTC does not pull in the host , so INFINITY is undefined. Provide it +// from libcu++ (same float +inf value nvcc's yields). +#include +#ifndef INFINITY +#define INFINITY (::cuda::std::numeric_limits::infinity()) +#endif + """ code_filtered = nvrtc_preamble + code_filtered @@ -406,6 +422,9 @@ def _compile_cuda_nvrtc( compile_opts = [ f"--gpu-architecture={arch}".encode(), b"-default-device", + # nvcc enables 128-bit integers by default on Linux; NVRTC requires the + # flag to be passed explicitly for kernels that use __int128_t. + b"--device-int128", ] if use_nvshmem: @@ -469,6 +488,21 @@ def _compile_cuda_nvrtc( ] ) + # Define the vector-deprecation silencing macros as no-ops for every NVRTC + # compile. These live in vector_types.h, which the fp4/fp6/fp8 headers use + # but do not include; depending on the include chain NVRTC pulls in, the + # macro can be left undefined and trigger a bogus "declaration has no storage + # class" error. Defining them empty is harmless (they only gate host-side + # deprecation warnings) and matches what the NVSHMEM path already did. + compile_opts.extend( + [ + b"-D__NV_SILENCE_DEPRECATION_BEGIN=", + b"-D__NV_SILENCE_DEPRECATION_END=", + b"-D__NV_SILENCE_HOST_DEPRECATION_BEGIN=", + b"-D__NV_SILENCE_HOST_DEPRECATION_END=", + ] + ) + compile_opts.extend( [ b"-U__CUDA_NO_HALF_OPERATORS__", @@ -481,6 +515,24 @@ def _compile_cuda_nvrtc( ] ) + # Mirror the nvcc path's ptxas options. register-usage-level drives ptxas + # register allocation / instruction scheduling and is perf-relevant (FA4 was + # tuned around it, hence the env-driven default); -v and + # --warn-on-local-memory-usage are diagnostic. NVRTC rejects -O3 and + # --register-usage-level as top-level flags but forwards them to its internal + # ptxas via --ptxas-options (ptxas already defaults to -O3). NB: unlike nvcc, + # NVRTC does not comma-split --ptxas-options, so each ptxas flag must be its + # own entry. The nvcc-only --expt-relaxed-constexpr / --expt-extended-lambda + # have no NVRTC equivalent and are intentionally not mirrored. + reg_level = os.environ.get("TVM_CUDA_PTXAS_REG_LEVEL", "10") + compile_opts.extend( + [ + b"--ptxas-options=-v", + f"--ptxas-options=--register-usage-level={reg_level}".encode(), + b"--ptxas-options=--warn-on-local-memory-usage", + ] + ) + # Add user-provided options, filtering out nvcc-specific flags that nvrtc doesn't support if options: nvcc_only_prefixes = ( @@ -802,7 +854,7 @@ def tvm_callback_cuda_compile(code): Compile CUDA code using the configured backend (nvcc or nvrtc). This callback is invoked by TVM's C++ backend during CUDA module compilation. - By default, uses nvcc to generate fatbin. The current target is fetched + By default, uses nvrtc to generate cubin. The current target is fetched inside the callback (via ``tvm.target.Target.current(allow_none=True)``) so the caller does not need to push/pop a target scope around the invocation. @@ -810,9 +862,9 @@ def tvm_callback_cuda_compile(code): Environment Variables --------------------- TVM_CUDA_COMPILE_MODE : str - Compiler backend: "nvcc" (default) or "nvrtc" - - "nvcc": Use nvcc subprocess, generates fatbin + Compiler backend: "nvrtc" (default) or "nvcc" - "nvrtc": Use NVRTC via cuda-bindings for faster JIT, generates cubin + - "nvcc": Use nvcc subprocess, generates fatbin TVM_KERNEL_DUMP : str If set, dump generated CUDA/intermediate files and append "-lineinfo" so profilers can correlate SASS back to the dumped source. @@ -830,7 +882,7 @@ def tvm_callback_cuda_compile(code): # The current Target is fetched inside compile_cuda via # tvm.target.Target.current(allow_none=True) when arch is unset; the # caller no longer needs to push/pop a target scope. - compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc").lower() + compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvrtc").lower() if compiler == "nvrtc": return compile_cuda(code, target_format="cubin", compiler="nvrtc") diff --git a/python/tvm/tirx/script/builder/external_kernel.py b/python/tvm/tirx/script/builder/external_kernel.py index c1f5d5871655..d56ed9ea0384 100644 --- a/python/tvm/tirx/script/builder/external_kernel.py +++ b/python/tvm/tirx/script/builder/external_kernel.py @@ -159,7 +159,7 @@ def compile_to_device_module( # pylint: disable=arguments-differ target_format = "cubin" if use_nvshmem else "ptx" output_path = f"{temp_dir}/{kernel_name}.{target_format}" - compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc") + compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvrtc") nvcc.compile_cuda( source_code, target_format=target_format, diff --git a/src/backend/cuda/op/target_builtin.cc b/src/backend/cuda/op/target_builtin.cc index 005fe5b32263..353c04b501ec 100644 --- a/src/backend/cuda/op/target_builtin.cc +++ b/src/backend/cuda/op/target_builtin.cc @@ -152,6 +152,9 @@ TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_arrive_expect_tx) TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_try_wait) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); +TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_try_wait_acquire_cluster) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); + TIRX_DEFINE_BUILTIN_FUNC(ptx_bar_arrive) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); @@ -497,6 +500,8 @@ const DeviceIntrinsicRegistration kDeviceIntrinsics[] = { TIRX_DEVICE_INTRIN_ALIAS(ptx_bar_sync, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_barrier_cluster_arrive, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_barrier_cluster_wait, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_clc_query_cancel, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_clc_try_cancel, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_commit_group, ptx, kOpaque), @@ -540,6 +545,7 @@ const DeviceIntrinsicRegistration kDeviceIntrinsics[] = { TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_init, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_test_wait_parity, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_try_wait, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_try_wait_acquire_cluster, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_try_wait_once, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_mma, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_mma_legacy, ptx, kOpaque), diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 88a28ebccb5f..f32dcdde11fd 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -133,6 +133,8 @@ void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target, builder_.reset(new IRBuilder(*ctx)); module_.reset(new llvm::Module(module_name, *ctx)); md_builder_.reset(new llvm::MDBuilder(*ctx)); + functions_.clear(); + function_symbol_owners_.clear(); // types t_void_ = llvm::Type::getVoidTy(*ctx); t_void_p_ = llvmGetPointerTo(llvm::Type::getInt8Ty(*ctx), GetGlobalAddressSpace()); @@ -260,6 +262,21 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, cons llvm::FunctionType::get(GetLLVMType(func->ret_type), param_types, false); auto [symbol_name, linkage_type] = GetLinkage(gvar, func); + if (auto it = function_symbol_owners_.find(symbol_name); it != function_symbol_owners_.end()) { + constexpr const char* kFFISymbolPrefix = "__tvm_ffi_"; + std::string user_symbol = symbol_name; + if (user_symbol.rfind(kFFISymbolPrefix, 0) == 0) { + user_symbol = user_symbol.substr(std::char_traits::length(kFFISymbolPrefix)); + } + TVM_FFI_THROW(InternalError) << "Duplicate PrimFunc global_symbol '" << user_symbol + << "' in LLVM codegen: IRModule keys '" << it->second + << "' and '" << gvar->name_hint + << "' both lower to the same exported symbol '" << symbol_name + << "'. " + << "Each exposed PrimFunc in one IRModule must have a unique " + "global_symbol."; + } + function_symbol_owners_[symbol_name] = gvar->name_hint; auto function = module_->getFunction(MakeStringRef(symbol_name)); if (function == nullptr) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 8526b3f642df..08396d596daa 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -547,6 +547,9 @@ class CodeGenLLVM : public ExprFunctor, // that function. std::unordered_map functions_; + // Map from the generated LLVM function symbol to the GlobalVar that owns it. + std::unordered_map function_symbol_owners_; + // Whether current function is restricted bool is_restricted_{true}; // The analyzer information diff --git a/src/tirx/ir/layout/tile_slice.cc b/src/tirx/ir/layout/tile_slice.cc index 3f4db4837964..ce1809ae9907 100644 --- a/src/tirx/ir/layout/tile_slice.cc +++ b/src/tirx/ir/layout/tile_slice.cc @@ -144,7 +144,11 @@ ffi::Optional SlicePerGroup(TileLayout layout, PrimExpr begin, PrimE ffi::Optional TileLayoutNode::Slice(const Array& shape, const Region& region) const { arith::Analyzer analyzer; - auto [grouped_layout, seps] = Group(ffi::GetRef(this), shape); + // Canonicalize the whole layout first so scope fusion (e.g. wid_in_wg+laneid + // -> tid_in_wg) runs globally; otherwise grouping can split sibling thread + // axes and SlicePerGroup's per-group fusion leaves an ill-formed mix. + TileLayout canon = this->Canonicalize().as().value(); + auto [grouped_layout, seps] = Group(canon, shape); std::vector new_shard; ffi::Map new_offset; for (size_t i = 0; i < seps.size() - 1; ++i) { diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index 7c093f9be27b..624d587b825f 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -30,6 +30,45 @@ from tvm.testing import env +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") +def test_duplicate_primfunc_global_symbol_diagnostic(): + @I.ir_module(s_tir=True) + class Module: + @T.prim_func(s_tir=True) + def first_unique_key(A: T.Buffer((1,), "float32")): + T.func_attr({"global_symbol": "dup_symbol", "tirx.noalias": True}) + A[0] = T.float32(1) + + @T.prim_func(s_tir=True) + def second_unique_key(A: T.Buffer((1,), "float32")): + T.func_attr({"global_symbol": "dup_symbol", "tirx.noalias": True}) + A[0] = T.float32(2) + + with pytest.raises( + tvm.error.InternalError, match="Duplicate PrimFunc global_symbol 'dup_symbol'" + ) as err: + tvm.compile(Module, target="llvm") + assert "first_unique_key" in str(err.value) + assert "second_unique_key" in str(err.value) + + +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") +def test_unique_primfunc_global_symbols_compile(): + @I.ir_module(s_tir=True) + class Module: + @T.prim_func(s_tir=True) + def first_unique_key(A: T.Buffer((1,), "float32")): + T.func_attr({"global_symbol": "dup_symbol_a", "tirx.noalias": True}) + A[0] = T.float32(1) + + @T.prim_func(s_tir=True) + def second_unique_key(A: T.Buffer((1,), "float32")): + T.func_attr({"global_symbol": "dup_symbol_b", "tirx.noalias": True}) + A[0] = T.float32(2) + + tvm.compile(Module, target="llvm") + + @pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_intrin(): @I.ir_module(s_tir=True) diff --git a/tests/python/tirx/codegen/test_codegen_cuda.py b/tests/python/tirx/codegen/test_codegen_cuda.py index f253d6d375c6..521a72f6d732 100644 --- a/tests/python/tirx/codegen/test_codegen_cuda.py +++ b/tests/python/tirx/codegen/test_codegen_cuda.py @@ -21,6 +21,7 @@ import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env DEV = tvm.device("cuda") @@ -118,6 +119,8 @@ def main(A: T.Buffer((1,), "uint64")): assert "*(void* *)" not in src +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_atomic_add(): @T.prim_func def main(A: T.Buffer((1,), "int32"), B: T.Buffer((1,), "float32")): @@ -442,6 +445,8 @@ def main(A: T.Buffer((16, 16), "int32")): assert "tvm_builtin_cuda_atomic_cas" in src +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_func_call(): def test_add_one(): add_one = """ @@ -497,6 +502,8 @@ def main(a: T.Buffer((16, 16), "int32")): test_print() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_warp_shuffle_xor_sync(): # fmt: off @T.prim_func @@ -532,6 +539,8 @@ def func(A_ptr: T.handle): np.testing.assert_allclose(A.numpy(), A_ref) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("cp_size", [4, 8, 16]) @pytest.mark.parametrize("cache_hint", ["", "evict_last"]) @pytest.mark.parametrize("prefetch_size", [-1, 64, 128, 256]) @@ -575,6 +584,8 @@ def main(A: T.Buffer((N), "float16")): print(src) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("trans", [False, True]) @pytest.mark.parametrize("num", [1, 2, 4]) def test_ptx_ldmatrix(trans, num): diff --git a/tests/python/tirx/codegen/test_codegen_nvshmem.py b/tests/python/tirx/codegen/test_codegen_nvshmem.py index ff9f17170ddd..d3869077428e 100644 --- a/tests/python/tirx/codegen/test_codegen_nvshmem.py +++ b/tests/python/tirx/codegen/test_codegen_nvshmem.py @@ -28,6 +28,7 @@ from tvm.runtime import disco as di from tvm.script import tirx as T from tvm.support.popen_pool import PopenWorker +from tvm.testing import env NUM_WORKERS = 4 @@ -61,6 +62,8 @@ def create_nvshmem_array(sess, shape, dtype, init_data_fn=None, zero_out=True): return arr +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.skip(reason="nvshmem doesn't work with pytest") def test_codegen_nvshmem(): def _test_func(): diff --git a/tests/python/tirx/codegen/test_cuda_copy.py b/tests/python/tirx/codegen/test_cuda_copy.py index cb08f4247318..047eb1f12ca3 100644 --- a/tests/python/tirx/codegen/test_cuda_copy.py +++ b/tests/python/tirx/codegen/test_cuda_copy.py @@ -21,6 +21,7 @@ import tvm from tvm.script import tirx as T +from tvm.testing import env DEV = tvm.cuda(0) TARGET = tvm.target.Target("cuda") @@ -34,6 +35,8 @@ def _build_and_run(func, *np_args): return (*tuple(a.numpy() for a in rt_args), mod) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_copy_128b(): """copy_128b: copies 16 bytes (4 float32 elements) via uint4 load/store.""" @@ -63,6 +66,8 @@ def func(out_ptr: T.handle): assert "tvm_builtin_copy_128b" in mod.mod.imports[0].inspect_source() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_copy_64b(): """copy_64b: copies 8 bytes (2 float32 elements) via uint2 load/store.""" @@ -92,6 +97,8 @@ def func(out_ptr: T.handle): assert "tvm_builtin_copy_64b" in mod.mod.imports[0].inspect_source() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_copy_32b(): """copy_32b: copies 4 bytes (1 float32 element) via unsigned int load/store.""" @@ -121,6 +128,8 @@ def func(out_ptr: T.handle): assert "tvm_builtin_copy_32b" in mod.mod.imports[0].inspect_source() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_copy_16b(): """copy_16b: copies 2 bytes (1 float16 element) via unsigned short load/store.""" @@ -150,6 +159,8 @@ def func(out_ptr: T.handle): assert "tvm_builtin_copy_16b" in mod.mod.imports[0].inspect_source() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_copy_8b(): """copy_8b: copies 1 byte (1 uint8 element) via unsigned char load/store.""" diff --git a/tests/python/tirx/codegen/test_cuda_cta_reduce.py b/tests/python/tirx/codegen/test_cuda_cta_reduce.py index 51b8f1099a91..bf07da1b6798 100644 --- a/tests/python/tirx/codegen/test_cuda_cta_reduce.py +++ b/tests/python/tirx/codegen/test_cuda_cta_reduce.py @@ -21,6 +21,7 @@ import tvm from tvm.script import tirx as T +from tvm.testing import env DEV = tvm.cuda(0) TARGET = tvm.target.Target("cuda") @@ -35,6 +36,8 @@ def _build_and_run(func, n): return out.numpy(), mod +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cta_sum_4_warps(): """CTA sum with 4 warps (128 threads): all threads get the same sum.""" NUM_WARPS = 4 @@ -61,6 +64,8 @@ def func(out_ptr: T.handle): assert "cta_reduce_sum_4" in mod.mod.imports[0].inspect_source() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cta_sum_8_warps(): """CTA sum with 8 warps (256 threads).""" NUM_WARPS = 8 @@ -86,6 +91,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, np.full(N, expected)) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cta_max_4_warps(): """CTA max with 4 warps: all threads get the maximum value.""" NUM_WARPS = 4 @@ -110,6 +117,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, np.full(N, float(N))) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cta_min_4_warps(): """CTA min with 4 warps: all threads get the minimum value.""" NUM_WARPS = 4 @@ -134,6 +143,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, np.full(N, 1.0)) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cta_sum_1_warp(): """CTA sum with 1 warp: degenerates to a pure warp reduce.""" NUM_WARPS = 1 @@ -159,6 +170,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, np.full(N, expected)) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("num_warps", [1, 2, 4, 8, 16]) def test_cta_sum_all_warp_counts(num_warps): """Parametric test: cta_sum with various warp counts.""" diff --git a/tests/python/tirx/codegen/test_cuda_warp_reduce.py b/tests/python/tirx/codegen/test_cuda_warp_reduce.py index df568a95e483..e5167a055c9a 100644 --- a/tests/python/tirx/codegen/test_cuda_warp_reduce.py +++ b/tests/python/tirx/codegen/test_cuda_warp_reduce.py @@ -21,6 +21,7 @@ import tvm from tvm.script import tirx as T +from tvm.testing import env DEV = tvm.cuda(0) TARGET = tvm.target.Target("cuda") @@ -35,6 +36,8 @@ def _build_and_run(func, n=32): return out.numpy(), mod +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_warp_sum_full(): """Full warp sum (width=32): each lane gets the sum of all 32 values.""" @@ -57,6 +60,8 @@ def func(out_ptr: T.handle): assert "warp_reduce_sum_32" in mod.mod.imports[0].inspect_source() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_warp_sum_partial_8(): """Partial warp sum (width=8): 4 groups of 8 lanes, each group sums independently.""" @@ -85,6 +90,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, expected) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_warp_max_partial_4(): """Partial warp max (width=4): 8 groups of 4 lanes.""" @@ -109,6 +116,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, expected) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_warp_min_full(): """Full warp min (width=32).""" @@ -129,6 +138,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, np.full(32, 1.0)) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_warp_sum_partial_2(): """Smallest partial warp sum (width=2): 16 pairs of adjacent lanes.""" @@ -155,6 +166,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, expected) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("width", [2, 4, 8, 16, 32]) def test_warp_sum_all_widths(width): """Parametric test: warp_sum with every valid width.""" diff --git a/tests/python/tirx/conftest.py b/tests/python/tirx/conftest.py new file mode 100644 index 000000000000..fb8ba62f4f41 --- /dev/null +++ b/tests/python/tirx/conftest.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Suite-level hardware gate for the tirx tests. + +The tirx kernels and codegen paths target Blackwell (sm_100a) — they emit +PTX/SASS (tcgen05, tmem, cp.async ``.async`` modifiers, fp8 conversions, ...) +that ptxas/NVRTC reject for older targets, and many tests execute on the +device. Running the suite on a CPU-only node or a pre-sm_100 GPU therefore +fails at compile/run time rather than skipping. Gate the whole directory on a +real sm_100a device so it skips cleanly where the hardware is absent and runs +in full where it is present. +""" + +import pytest + +from tvm.testing import env + + +def pytest_collection_modifyitems(config, items): + if env.has_cuda_compute(10): + return + skip = pytest.mark.skip( + reason="tirx suite requires a CUDA compute capability 10.0 (sm_100a) device" + ) + for item in items: + item.add_marker(skip) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py index 75faf61366fe..1824b41eae43 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py @@ -32,6 +32,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env # Force the fallback dispatch to register before any test compiles a kernel. # Without this import, in fresh pytest workers the `copy/fallback` variant @@ -128,6 +129,8 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: return kernel +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize( "scope,n_threads,shape,why", [ @@ -158,6 +161,8 @@ def test_fallback_round_trip(scope, n_threads, shape, why): np.testing.assert_array_equal(B.numpy(), A_np) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_fallback_thread_scope(): """``T.thread()`` — single thread, no gate. Either ``gmem_smem`` picks it up (n_elements % 1 == 0) or ``fallback`` does — both end up emitting diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py index dc5a46a751ec..c31ca79db918 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py @@ -103,6 +103,8 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: ] +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize( "scope,n_threads,shape", [pytest.param(*t, id=f"{t[0]}-{t[1]}-{'x'.join(map(str, t[2]))}") for t in TASKS], @@ -194,6 +196,8 @@ def test_gmem_smem_roundtrip(scope, n_threads, shape, dtype): ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize( "dtype", ["int8", "float8_e4m3fn", "float8_e5m2", "float16", "bfloat16", "float32"] ) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py index 451622530318..26c4d5de9b18 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py @@ -35,6 +35,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import S, TileLayout, laneid, tid_in_wg, tx @@ -228,6 +229,8 @@ def _expected(shape, dtype): return out +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize("non_r_scope", ["shared", "global"]) @pytest.mark.parametrize( "scope,n_threads,k", @@ -287,6 +290,8 @@ def test_reg_roundtrip(scope, n_threads, k, dtype, non_r_scope): ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize( "dtype", ["int8", "float8_e4m3fn", "float8_e5m2", "float16", "bfloat16", "float32"] ) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py index b4d54d2b4109..96f92832532a 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py @@ -24,6 +24,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import S, TileLayout @@ -65,6 +66,8 @@ ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize( "dtype", ["int8", "float8_e4m3fn", "float8_e5m2", "float16", "bfloat16", "float32"] ) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py index 0f910a43766d..55e32339c72d 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py @@ -24,10 +24,13 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import S, TCol, TileLayout, TLane from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("dtype", ["float16", "float32"]) @pytest.mark.parametrize("width_32b", [4, 8, 16, 32]) def test_copy_tmem2reg_async(dtype, width_32b): @@ -132,6 +135,8 @@ def copy_async_test(A_ptr: T.handle, B_ptr: T.handle) -> None: # ---------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("dtype", ["uint8", "float16", "float32"]) @pytest.mark.parametrize("width_32b", [2, 4, 8, 16, 32, 64, 128]) @pytest.mark.parametrize("offset_32b", [0, 3, 10]) @@ -224,6 +229,8 @@ def copy_sync(A_ptr: T.handle, B_ptr: T.handle) -> None: np.testing.assert_allclose(B.numpy(), A_np) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("dtype", ["float16", "float32"]) @pytest.mark.parametrize("width_32b", [4, 8, 16, 32]) @pytest.mark.parametrize("local_offset_32b", [0, 2, 4]) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py index 420935946028..aac93c0252c7 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py @@ -43,6 +43,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import ( S, TCol, @@ -152,6 +153,8 @@ def _expected_reg_value_16b( # -------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("shape", list(_SHAPE_REPS)) @pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32]) # subset; full reps below @pytest.mark.parametrize("dtype", ["float32"]) @@ -162,6 +165,8 @@ def test_tcgen05_ld_16xnb_load_fp32(shape, rep, dtype): _run_load_test(shape, rep, dtype) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize( "shape, rep", [ @@ -175,6 +180,8 @@ def test_tcgen05_ld_16xnb_load_fp32_large_rep(shape, rep): _run_load_test(shape, rep, "float32") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("shape", list(_SHAPE_REPS)) @pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32]) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @@ -201,6 +208,8 @@ def test_tcgen05_16xnb_roundtrip_16b(shape, rep, dtype): # We only need to spot-check that the dispatch fires correctly and the per- # thread reg ↔ TMEM mapping round-trips bit-exactly — the M=64 sweep above # already covers the (lane, reg) decomposition, so a sparse rep set suffices. +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"]) @pytest.mark.parametrize("rep", [1, 2, 4]) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @@ -214,6 +223,8 @@ def test_tcgen05_16xnb_roundtrip_16b_M128(shape, rep, dtype): # with the scatter-encoded TileLayout that ``tmem_datapath_layout("F", ...)`` # produces. ``.16x*b`` M=64 PTX has the matching scatter built in, so the # round-trip is bit-exact in the same way as Layout D + M=64. +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"]) @pytest.mark.parametrize("rep", [1, 2, 4]) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @@ -639,6 +650,8 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: # -------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("shape", list(_SHAPE_REPS)) @pytest.mark.parametrize("rep", [1, 4, 16]) @pytest.mark.parametrize("dtype", ["float32"]) @@ -853,5 +866,136 @@ def kernel(A_ptr: T.handle) -> None: ) +# -------------------------------------------------------------------------- +# Test 3: column-slice loads of a wider frag +# +# An epilogue may allocate one wide ``(128, K)`` frag and load it from TMEM in +# EPI_TILE-wide column chunks (``frag[:, c:c+w]``) so all loads are in flight +# before a single ``wait.ld``. The ``.16x*b`` dispatch must emit each slice as +# its own atom (``num_eff`` derived from the slice width) at the correct +# per-slab register offset. We verify this is *bit-exact identical* to one +# full-width load of the same frag — which the sweeps above already validate +# against the layout-derived expectation. M=128 here exercises the 2-slab path +# (the slice's two slabs live ``regs_per_thread_per_slab`` apart, not adjacent). +# -------------------------------------------------------------------------- + + +def _run_sliced_vs_full_load(shape, full_rep, n_chunks): + dtype = "float32" + K_cols_fp32 = _COL_FACTOR_FP32[shape] * full_rep + assert K_cols_fp32 % n_chunks == 0 + chunk_elem = K_cols_fp32 // n_chunks # fp32: elem == fp32 col + frag_rows = 128 # M=128 => 2 slabs + per_thread_elems = _REGS_FACTOR[shape] * full_rep * 2 # *2 for the second slab + + tmem_col_width_32b = max(32, _next_pow2(K_cols_fp32)) + stage_width_elem = tmem_col_width_32b + CHUNK_FP32 = 128 + n_stage = tmem_col_width_32b // CHUNK_FP32 if tmem_col_width_32b > CHUNK_FP32 else 1 + stage_w = tmem_col_width_32b if n_stage == 1 else CHUNK_FP32 + VEC_LEN = 4 # 128-bit / fp32 + + atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_fp32), dtype) + stage_view = TileLayout(S[(128, stage_w) : (1 @ axis_tid_in_wg, 1)]) + + @T.prim_func + def kernel(A_ptr: T.handle, Bf_ptr: T.handle, Bs_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128, stage_width_elem), dtype) + Bf = T.match_buffer(Bf_ptr, (128, per_thread_elems), dtype) # full-load dump + Bs = T.match_buffer(Bs_ptr, (128, per_thread_elems), dtype) # sliced-load dump + A_flat = A.view(-1) + + T.device_entry() + warp_id = T.warp_id([4]) + T.cta_id([2]) + wg_id = T.warpgroup_id([1]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + tid_in_wg = T.thread_id([128]) + + tmem_addr = T.alloc_shared([1], "uint32") + if wg_id == 0: + if warp_id == 0: + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=tmem_col_width_32b, cta_group=1) + T.tvm_storage_sync("shared") + tmem = T.decl_buffer( + (128, stage_width_elem), + dtype, + scope="tmem", + allocated_addr=tmem_addr[0], + layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ TCol)]), + ) + # Stage A -> TMEM via the standard .32x32b path. + stage_reg = T.alloc_local((stage_w,), dtype) + stage_local = stage_reg.view(128, stage_w, layout=stage_view) + for ci in range(n_stage): + coff = ci * stage_w + for i in range(stage_w // VEC_LEN): + g = T.meta_var(tid_in_wg * stage_width_elem + coff + i * VEC_LEN) + Tx.copy(stage_reg[i * VEC_LEN : i * VEC_LEN + VEC_LEN], A_flat[g : g + VEC_LEN]) + T.cuda.cta_sync() + Tx.wg.copy_async(tmem[:, coff : coff + stage_w], stage_local[:, :]) + T.ptx.tcgen05.wait.st() + T.cuda.cta_sync() + + # (a) one full-width load + ff = T.alloc_local((per_thread_elems,), dtype) + ffl = ff.view(frag_rows, K_cols_fp32, layout=atom_view) + Tx.wg.copy_async(ffl[:, :], tmem[0:frag_rows, 0:K_cols_fp32]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + for i in range(per_thread_elems): + Bf[tid_in_wg, i] = ff[i] + + # (b) the same frag loaded in n_chunks column slices + sf = T.alloc_local((per_thread_elems,), dtype) + sfl = sf.view(frag_rows, K_cols_fp32, layout=atom_view) + for ck in range(n_chunks): + lo = T.meta_var(ck * chunk_elem) + Tx.wg.copy_async( + sfl[:, lo : lo + chunk_elem], tmem[0:frag_rows, lo : lo + chunk_elem] + ) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + for i in range(per_thread_elems): + Bs[tid_in_wg, i] = sf[i] + + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=tmem_col_width_32b, cta_group=1) + + target = tvm.target.Target("cuda") + with target: + mod = tvm.IRModule({"main": kernel}) + mod = tvm.compile(mod, target=target, tir_pipeline="tirx") + A_np = tvm.testing.generate_random_array(dtype, (128, stage_width_elem)) + Bf_np = np.zeros((128, per_thread_elems), dtype=dtype) + Bs_np = np.zeros((128, per_thread_elems), dtype=dtype) + DEV = tvm.cuda(0) + A = tvm.runtime.tensor(A_np, DEV) + Bf = tvm.runtime.tensor(Bf_np, DEV) + Bs = tvm.runtime.tensor(Bs_np, DEV) + mod(A, Bf, Bs) + # Sliced load must reproduce the full-width load bit-for-bit. + np.testing.assert_array_equal(Bs.numpy().view(np.uint32), Bf.numpy().view(np.uint32)) + + +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") +@pytest.mark.parametrize( + "full_rep, n_chunks", + [ + (32, 8), # 16x256b.x32 (256 fp32 cols) loaded in 8 chunks of 32 cols (nvfp4 EPI_TILE=32) + (32, 16), # ...in 16 chunks of 16 cols (nvfp4 EPI_TILE=16) + (32, 4), # ...in 4 chunks of 64 cols + (16, 8), # 16x256b.x16 (128 fp32 cols) in 8 chunks of 16 cols + (16, 2), # ...in 2 chunks of 64 cols + ], +) +def test_tcgen05_ld_16x256b_sliced_matches_full_M128(full_rep, n_chunks): + """Per-chunk column-slice load of a wide M=128 frag == full-width load.""" + _run_sliced_vs_full_load("16x256b", full_rep, n_chunks) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py index 1ce0d34ea6e0..8d39ba355633 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py @@ -23,6 +23,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import S, TileLayout, wg_local_layout @@ -67,6 +68,8 @@ ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["add", "sub", "mul", "fdiv"]) @pytest.mark.parametrize("operands_type", ["region_region", "region_const", "const_region"]) @pytest.mark.parametrize("dtype", ["float16"]) @@ -223,6 +226,8 @@ def bad_kernel() -> None: tvm.compile(mod, target=target, tir_pipeline="tirx") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("exec_scope", ["warp", "warpgroup"]) @pytest.mark.parametrize("op_type", ["add", "mul"]) def test_binary_op_shared_subcta_scope(exec_scope, op_type): @@ -276,6 +281,8 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(A_ref, A.numpy(), atol=1e-3) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("exec_scope", ["cta", "warpgroup", "warp"]) @pytest.mark.parametrize("rhs_kind", ["region", "broadcast", "const"]) @pytest.mark.parametrize("op_type", ["add", "sub", "mul", "fdiv"]) @@ -392,6 +399,8 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("storage_scope", ["shared", "local"]) @pytest.mark.parametrize("exec_scope", ["cta", "thread"]) @pytest.mark.parametrize("op_type", ["add", "sub", "mul", "fdiv"]) @@ -495,6 +504,8 @@ def get_prim_func(): tvm.testing.assert_allclose(A_ref, A.numpy(), atol=atol) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["add", "sub", "mul"]) def test_binary_op_packed_f32x2_auto_dispatch(op_type): target = tvm.target.Target("cuda") @@ -568,6 +579,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(A_ref, A.numpy(), atol=1e-3) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_name", ["add", "sub", "mul"]) def test_binary_op_warpgroup_wg_local_layout(op_name): dtype = "float32" diff --git a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py index aa0f5ced8f58..02352638e4d6 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py @@ -26,6 +26,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import S, TileLayout, wg_local_layout @@ -41,6 +42,8 @@ def _get_sm_version(): # --------------------------------------------------------------------------- # FMA op: scalar scale + scalar bias # --------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_fma_scalar_scalar(): sm = _get_sm_version() if sm < 100: @@ -78,6 +81,8 @@ def test_func(A_ptr: T.handle) -> None: # --------------------------------------------------------------------------- # FMA op: buffer scale + scalar bias (Horner pattern) # --------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_fma_buffer_scale_scalar_bias(): sm = _get_sm_version() if sm < 100: @@ -119,6 +124,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: # --------------------------------------------------------------------------- # Binary op with scalar broadcast (PrimExpr scalar, e.g. BufferLoad) # --------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_mul_scalar_broadcast(): sm = _get_sm_version() if sm < 100: @@ -158,6 +165,8 @@ def test_func(A_ptr: T.handle, S_ptr: T.handle) -> None: # --------------------------------------------------------------------------- # Binary add with rounding mode # --------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_add_rounding_mode(): sm = _get_sm_version() if sm < 100: @@ -199,6 +208,8 @@ def test_func(A_ptr: T.handle) -> None: # --------------------------------------------------------------------------- # FMA op: layout=None local buffer (no TileLayout) # --------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_fma_no_layout(): sm = _get_sm_version() if sm < 100: @@ -238,6 +249,8 @@ def test_func(A_ptr: T.handle) -> None: # --------------------------------------------------------------------------- # Binary sub with rounding mode (buffer-buffer) # --------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_sub_buffer_buffer_rounding(): sm = _get_sm_version() if sm < 100: @@ -278,6 +291,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(expected, A_dev.numpy(), atol=1e-6) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_fma_warpgroup_wg_local_layout(): rows, cols = 128, 8 dtype = "float32" diff --git a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py index c20df63bebf0..fb70b3754123 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py @@ -23,6 +23,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.cuda.operator.tile_primitive.layout_utils import ( cast_layout_supported_for_local as _cast_layout_supported_for_local, ) @@ -54,6 +55,8 @@ ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["zero", "sqrt"]) @pytest.mark.parametrize( "src_dtype,dst_dtype", [("float16", "float16"), ("float32", "float16"), ("float32", "bfloat16")] @@ -145,6 +148,8 @@ def get_ref(A_np): tvm.testing.assert_allclose(B_ref, B.numpy(), atol=1e-2, rtol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("exec_scope", ["warp", "warpgroup"]) def test_unary_op_shared_subcta_scope(exec_scope): dtype = "float16" @@ -209,6 +214,8 @@ def unary_op_subcta(A_ptr: T.handle) -> None: ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["sqrt", "exp"]) @pytest.mark.parametrize("bias_type", ["const", "region"]) @pytest.mark.parametrize( @@ -432,6 +439,8 @@ def get_ref(A_np, bias_np): ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["reciprocal", "exp", "exp2"]) @pytest.mark.parametrize( "src_dtype,dst_dtype", [("float16", "float16"), ("float32", "float16"), ("float32", "bfloat16")] @@ -554,6 +563,8 @@ def test_unary(A_ptr: T.handle, B_ptr: T.handle) -> None: ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["sqrt", "exp"]) @pytest.mark.parametrize("bias_type", ["const", "region"]) @pytest.mark.parametrize( @@ -682,6 +693,8 @@ def get_ref(A_np, bias_np): tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("shape", [(128, 8), (128, 4, 16), (128, 5, 5)]) @pytest.mark.parametrize("op_type", ["fill"]) @pytest.mark.parametrize("exec_scope", ["thread", "cta"]) @@ -740,6 +753,8 @@ def test_unary_cta(A_ptr: T.handle) -> None: tvm.testing.assert_allclose(A.numpy(), np.full(shape, value.value), atol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["zero", "sqrt", "reciprocal", "exp", "silu"]) @pytest.mark.parametrize("dtype", ["float16"]) def test_unary_op_local_thread_wise(op_type, dtype): @@ -791,6 +806,8 @@ def kernel(A_ptr: T.handle) -> None: tvm.testing.assert_allclose(A_ref, A.numpy(), atol=1e-2, rtol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("shape", [(8,), (16, 16), (5, 5)]) @pytest.mark.parametrize("A_dtype", ["float16", "float32"]) @pytest.mark.parametrize("B_dtype", ["float16", "float32"]) @@ -831,6 +848,8 @@ def test_cast(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), ("float32", "bfloat16")]) def test_cast_warpgroup_local_view(A_dtype, B_dtype): """T.cast in warpgroup scope with offset (tid_in_wg + layout offset). Covers offset/tid_in_wg/warpgroup scope.""" # noqa: E501 @@ -884,6 +903,8 @@ def test_cast(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), ("float32", "bfloat16")]) def test_cast_warpgroup_src_layout_to_flat_uses_vec2_intrinsic(A_dtype, B_dtype): """Regression: GEMM-epilogue cast pattern must emit the packed vec2 cuda intrinsic. @@ -944,6 +965,8 @@ def test_cast(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), ("float32", "bfloat16")]) def test_cast_cta_local_view(A_dtype, B_dtype): """T.cast with view+layout in CTA scope (128 threads, register->register).""" @@ -988,6 +1011,8 @@ def test_cast(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), ("float32", "bfloat16")]) @pytest.mark.parametrize("slice_start,slice_end", [(0, 4), (2, 6), (4, 8)]) def test_cast_local_view_sliced(A_dtype, B_dtype, slice_start, slice_end): @@ -1087,6 +1112,8 @@ def test_cast_layout_partition_and_validation(): check(part) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("slice_start,slice_end", [(0, 2), (2, 4)]) def test_cast_mixed_axes_and_subregion(slice_start, slice_end): """Test cast with mixed axes and subregion.""" @@ -1095,7 +1122,7 @@ def test_cast_mixed_axes_and_subregion(slice_start, slice_end): LOCAL_LEN = 4 full_shape = (8, N_WARPS, 4, LOCAL_LEN) g_layout = TileLayout(S[full_shape]) - cast_layout = TileLayout(S[full_shape : (4 @ laneid, 2 @ warpid, 1 @ laneid, 1)]) + cast_layout = TileLayout(S[full_shape : (4 @ laneid, 1 @ warpid, 1 @ laneid, 1)]) A_ref = np.zeros(full_shape, dtype="float32") for j in range(full_shape[0]): @@ -1207,8 +1234,12 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: target = tvm.target.Target("cuda") with target: mod = tvm.IRModule({"main": kernel}) + # The mismatched dst also fails the scope-level check (thread axes don't + # span the full CTA), which fires first — either rejection is fine. with pytest.raises( - Exception, match="tile_local_valid|layout signature mismatch|thread part mismatch" + Exception, + match="tile_local_valid|layout signature mismatch|thread part mismatch" + "|do not tile a complete|not the full", ): tvm.compile(mod, target=target, tir_pipeline="tirx") @@ -1277,5 +1308,138 @@ def k(A_ptr: T.handle, B_ptr: T.handle) -> None: ), f"expected packed vec2 cast {intrinsic}; got:\n{src[:2000]}" +# ----------------------------------------------------------------------------- +# Scope-level operand check: a warp/wg/cta reg op needs a scope-level layout +# (thread axes spanning all the scope's threads), not a thread-local .local(). +# ----------------------------------------------------------------------------- +_SL_ROWS, _SL_COLS = 128, 8 + + +def _sl_compile(fn): + target = tvm.target.Target("cuda") + with target: + tvm.compile(tvm.IRModule({"main": fn}), target=target, tir_pipeline="tirx") + + +def test_cast_wg_rejects_thread_local_view(): + """Tx.wg.cast on a .local() (thread-axis-stripped) view is rejected.""" + + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + T.device_entry() + _bx = T.cta_id([1]) + _wg = T.warpgroup_id([1]) + tid = T.thread_id_in_wg([_SL_ROWS]) + src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)])) + dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)])) + src_row = src.local(_SL_COLS) + for i in T.serial(_SL_COLS): + src_row[i] = A[tid, i] + Tx.wg.cast(dst.local(), src.local()) + dst_row = dst.local(_SL_COLS) + for i in T.serial(_SL_COLS): + B[tid, i] = dst_row[i] + + with pytest.raises(Exception, match="thread-local view"): + _sl_compile(kernel) + + +def test_cast_cta_rejects_thread_local_view(): + """Tx.cta.cast on a .local() view is rejected (cta -> tx).""" + + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + T.device_entry() + _bx = T.cta_id([1]) + tx_var = T.thread_id([_SL_ROWS]) + src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)])) + dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)])) + src_row = src.local(_SL_COLS) + for i in T.serial(_SL_COLS): + src_row[i] = A[tx_var, i] + Tx.cta.cast(dst.local(), src.local()) + dst_row = dst.local(_SL_COLS) + for i in T.serial(_SL_COLS): + B[tx_var, i] = dst_row[i] + + with pytest.raises(Exception, match="thread-local view"): + _sl_compile(kernel) + + +def test_cast_wg_rejects_partial_thread_coverage(): + """A tid_in_wg layout covering only 64 of the 128 wg threads is rejected.""" + half = 64 + + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (half, _SL_COLS), "float32", layout=TileLayout(S[(half, _SL_COLS)])) + B = T.match_buffer(B_ptr, (half, _SL_COLS), "float16", layout=TileLayout(S[(half, _SL_COLS)])) + T.device_entry() + _bx = T.cta_id([1]) + _wg = T.warpgroup_id([1]) + tid = T.thread_id_in_wg([_SL_ROWS]) + src = T.alloc_buffer((half, _SL_COLS), "float32", scope="local", layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)])) + dst = T.alloc_buffer((half, _SL_COLS), "float16", scope="local", layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)])) + src_row = src.local(_SL_COLS) + for i in T.serial(_SL_COLS): + src_row[i] = A[tid, i] + Tx.wg.cast(dst, src) + dst_row = dst.local(_SL_COLS) + for i in T.serial(_SL_COLS): + B[tid, i] = dst_row[i] + + with pytest.raises(Exception, match="not the full 128"): + _sl_compile(kernel) + + +def test_cast_wg_accepts_wg_level_layout(): + """Tx.wg.cast on a wg-level (tid_in_wg-distributed) layout compiles.""" + + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + T.device_entry() + _bx = T.cta_id([1]) + _wg = T.warpgroup_id([1]) + tid = T.thread_id_in_wg([_SL_ROWS]) + src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)])) + dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)])) + src_row = src.local(_SL_COLS) + for i in T.serial(_SL_COLS): + src_row[i] = A[tid, i] + Tx.wg.cast(dst, src) + dst_row = dst.local(_SL_COLS) + for i in T.serial(_SL_COLS): + B[tid, i] = dst_row[i] + + _sl_compile(kernel) + + +def test_cast_thread_accepts_local_view(): + """thread scope is exempt: a thread-axis-free local tile still compiles.""" + + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + T.device_entry() + _bx = T.cta_id([1]) + tx_var = T.thread_id([_SL_ROWS]) + src = T.alloc_buffer((_SL_COLS,), "float32", scope="local", layout=TileLayout(S[(_SL_COLS,)])) + dst = T.alloc_buffer((_SL_COLS,), "float16", scope="local", layout=TileLayout(S[(_SL_COLS,)])) + for i in T.serial(_SL_COLS): + src[i] = A[tx_var, i] + Tx.cast(dst, src) + for i in T.serial(_SL_COLS): + B[tx_var, i] = dst[i] + + _sl_compile(kernel) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py index e0a270e7091a..32ac00e39d5f 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py @@ -32,6 +32,7 @@ from tvm.ir.type import PointerType, PrimType from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.cuda.operator.tile_primitive.gemm_async import sf_tmem_layout from tvm.tirx.cuda.operator.tile_primitive.tma_utils import ( mma_atom_layout, @@ -167,6 +168,8 @@ def pack_sf_fp8_uint32(sf_uint8, n_total=128): return packed +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize( "task", [ @@ -293,6 +296,8 @@ def gemm_async(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_gemm_tcgen05_cta_group_1_layout_f_m64(): """M=64 MMA with C operand allocated as Layout F (datapath="F"). @@ -405,6 +410,8 @@ def gemm_layout_f(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-2, rtol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize( "task", [ @@ -545,6 +552,8 @@ def gemm_async(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_gemm_tcgen05_cta_group_2_layout_b(): """Test cta_group=2 with Layout B (2x2 datapath, M=128 total, 64 per CTA). @@ -675,6 +684,8 @@ def gemm_async(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes") @pytest.mark.parametrize( "task", @@ -864,6 +875,8 @@ def gemm_async_fn(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, SFA_ptr: T. np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes") @pytest.mark.parametrize( "task", @@ -1089,6 +1102,8 @@ def gemm_async_fn(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, SFA_ptr: T. np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes") def test_gemm_block_scaled_nvfp4_cta_group_1(): """Test block-scaled nvfp4 GEMM with cta_group=1. @@ -1258,6 +1273,8 @@ def gemm_async_fn(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, SFA_ptr: T. np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes") def test_gemm_block_scaled_nvfp4_cta_group_2(): """Test block-scaled nvfp4 GEMM with cta_group=2. @@ -1462,6 +1479,8 @@ def gemm_async_fn(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, SFA_ptr: T. np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes") def test_gemm_block_scaled_fp8_sf_id(): """Test sf_id auto-derivation from layout for fp8 block-scaled MMA. @@ -1681,6 +1700,8 @@ def per_block_quantize_fp8(mat, block_size=32): ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize( "task", [ @@ -1960,6 +1981,8 @@ def gemm_async(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("k_lo,k_hi", [(0, 16), (0, 32), (16, 32), (16, 48), (32, 64)]) def test_gemm_tcgen05_contiguous_kslice_partial_k(k_lo, k_hi): """A slice on the *contiguous* (K) axis of a swizzled gemm_async operand must diff --git a/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py b/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py index 67cc1e0bd6fa..0402719ba1e5 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py @@ -43,6 +43,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env # Helpers exposed by the dispatcher module for direct algorithm tests. from tvm.tirx.cuda.operator.tile_primitive.permute_layout.warp_xor_swizzle import ( @@ -167,6 +168,8 @@ def _compile_and_run(prim_func, np_inputs): return [t.numpy() for t in tensors], mod.mod.imports[0].inspect_source() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @needs_cuda @pytest.mark.parametrize( "name, pipe, blk, dtype", @@ -231,6 +234,8 @@ def f(A: T.handle, B: T.handle): np.testing.assert_array_equal(B_flat, ref, err_msg=f"{name} stage {s}") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @needs_cuda def test_identity_passes_through_as_copy(): """L_src == L_dst should still compile and produce a correct (identity) copy.""" @@ -255,6 +260,8 @@ def f(A: T.handle, B: T.handle): np.testing.assert_array_equal(B_out, A_np) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @needs_cuda @pytest.mark.parametrize("dtype", ["uint32", "int32", "float32"]) @pytest.mark.parametrize( diff --git a/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py b/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py index 0474ad2dc46a..9031aa4f487f 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py @@ -21,6 +21,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import R, S, TileLayout, laneid, wg_local_layout @@ -41,6 +42,8 @@ ((32, 32), (32,), (-1,), (1, 1), (2,), (5, 8), (5,)), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["sum", "max", "min"]) @pytest.mark.parametrize("dtype", ["float32", "float16"]) @pytest.mark.parametrize("accum", [False, True]) @@ -129,6 +132,8 @@ def test_reduction(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(ref, B.numpy()[tuple(reduce_slice_dst)], atol=atol) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("exec_scope", ["warp", "warpgroup", "thread"]) @pytest.mark.parametrize("op_type", ["sum", "max", "min"]) @pytest.mark.parametrize("accum", [False, True]) @@ -264,6 +269,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: ((2, 3, 4), (3, 4), (0,)), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["sum", "max", "min"]) @pytest.mark.parametrize("accum", [False, True]) def test_reduction_local_thread_wise(src_shape, dst_shape, axes, op_type, accum): @@ -367,6 +374,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: ((4, 8), (1, 8), (1,), False, None), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["sum", "max", "min"]) def test_reduction_local_view_basic(inner_dims, dst_dims, axes, accum, slice_end, op_type): """Test view-based local reduction with simple purely-local layouts.""" @@ -484,6 +493,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(ref, B.numpy(), atol=1e-5) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("n_groups, n_warps", [(1, 1), (1, 4), (2, 8)]) @pytest.mark.parametrize("op_type", ["sum", "max", "min"]) @pytest.mark.parametrize("dtype", ["float32", "float16"]) @@ -616,6 +627,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("reduction_len", [8, 16, 64, 128, 256, 7, 10, 15, 100]) @pytest.mark.parametrize("op_type", ["max", "min"]) @pytest.mark.parametrize("accum", [False, True]) @@ -685,6 +698,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B_ref, B.numpy()[0], atol=1e-5) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("reduction_len", [8, 16, 64, 128, 256, 9, 17, 63, 65, 100]) @pytest.mark.parametrize("accum", [False, True]) def test_reduction_local_optimized_packed_add_sum(reduction_len, accum): @@ -746,6 +761,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B_ref, B.numpy()[0], atol=1e-4) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["sum", "max"]) @pytest.mark.parametrize("dtype", ["float32", "float16"]) def test_reduction_op_warp_shuffle(op_type, dtype): @@ -807,6 +824,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["sum", "max"]) @pytest.mark.parametrize("dtype", ["float32", "float16"]) def test_reduction_op_warp_shuffle_multi_elem(op_type, dtype): @@ -875,6 +894,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_reduction_warp_shuffle_multi_warp_loop(): """Test intra-warp + cross-warp reduction via T.sum in a for loop with multiple warps. @@ -951,6 +972,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B_ref, B_dev.numpy(), atol=1e-3) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_name", ["sum", "max"]) def test_reduction_warpgroup_wg_local_layout(op_name): rows, cols = 128, 16 diff --git a/tests/python/tirx/test_buffer_print.py b/tests/python/tirx/test_buffer_print.py index 211f4d390313..dbd0da8f849a 100644 --- a/tests/python/tirx/test_buffer_print.py +++ b/tests/python/tirx/test_buffer_print.py @@ -18,10 +18,12 @@ import re import numpy as np +import pytest import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env def generate_random_data(shape, dtype): @@ -181,6 +183,8 @@ def verify_cuda_code_string(func, expected_var_name, expected_string_literal): ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_print(): DEV = tvm.cuda() target = tvm.target.Target("cuda") diff --git a/tests/python/tirx/test_control_flow.py b/tests/python/tirx/test_control_flow.py index 1f905bd03cc9..9085c2b0213b 100644 --- a/tests/python/tirx/test_control_flow.py +++ b/tests/python/tirx/test_control_flow.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm from tvm.script import tirx as T +from tvm.testing import env def run_test_break_continue(func, shape, expected): @@ -32,6 +34,8 @@ def run_test_break_continue(func, shape, expected): np.testing.assert_allclose(arr.numpy(), expected) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_break_continue1(): # fmt: off @T.prim_func @@ -53,6 +57,8 @@ def func(A_ptr: T.handle): run_test_break_continue(func, (10,), expected) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_break_continue2(): # fmt: off @T.prim_func @@ -79,6 +85,8 @@ def func(A_ptr: T.handle): run_test_break_continue(func, (9,), expected) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_break_continue3(): # fmt: off @T.prim_func diff --git a/tests/python/tirx/test_layout.py b/tests/python/tirx/test_layout.py index e3711cb00cd2..0dcf212ce271 100644 --- a/tests/python/tirx/test_layout.py +++ b/tests/python/tirx/test_layout.py @@ -1733,5 +1733,40 @@ def test_slice_single_shard_skips_defensive_floormod(): # we just assert offset is non-empty and structurally sane (not None). +def test_slice_tcgen05_frag_layout_scope_consistent(): + """Slicing a wid_in_wg+laneid frag layout (tcgen05 16x256b) must stay + scope-consistent: the sliced result canonicalizes to a single tid_in_wg + chain over the full 128 threads (regression for the per-group-fusion bug). + """ + frag = TileLayout( + S[(4, 2, 2, 8, 4, 4, 2) : (1 @ wid_in_wg, 16, 2, 4 @ laneid, 4, 1 @ laneid, 1)] + ) + + def thread_chain(layout): + canon = layout.canonicalize() + names = {it.axis.name for it in canon.shard if it.axis.is_thread()} + titers = sorted( + ((int(it.stride), int(it.extent)) for it in canon.shard if it.axis.is_thread()), + ) + running = 1 + for stride, extent in titers: + assert stride == running, f"non-contiguous thread chain: {titers}" + running *= extent + return names, running + + with tvm.target.Target("cuda"): + # Full-region slice and a column sub-slice must both canonicalize to a + # single tid_in_wg chain covering all 128 warpgroup threads. + full = frag.slice([128, 32], [(0, 128), (0, 32)]) + names, total = thread_chain(full) + assert names == {"tid_in_wg"}, names + assert total == 128, total + + col = frag.slice([128, 32], [(0, 128), (16, 32)]) + names_c, total_c = thread_chain(col) + assert names_c == {"tid_in_wg"}, names_c + assert total_c == 128, total_c + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/scripts/task_python_unittest.sh b/tests/scripts/task_python_unittest.sh index ec052281ad11..15bb51bdf73d 100755 --- a/tests/scripts/task_python_unittest.sh +++ b/tests/scripts/task_python_unittest.sh @@ -55,6 +55,7 @@ TEST_FILES=( "tirx-analysis" "tirx-base" "tirx-transform" + "tirx" "tvmscript" "relax" )