Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions python/tvm/backend/cuda/lang/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def wait(self, stage, phase):
T.ptx.mbarrier.try_wait(self.buf.ptr_to([stage]), phase ^ self.phase_offset)

@T.inline
def arrive(self, stage, cta_id=None, pred=None):
def arrive(self, stage, cta_id=None, pred=None, count=None):
# Default: local-CTA arrive — emits the simple
# ``mbarrier.arrive.shared.b64`` form. To arrive on a remote
# CTA's mbarrier in a cluster kernel, callers must pass
Expand All @@ -119,11 +119,18 @@ def arrive(self, stage, cta_id=None, pred=None):
# the cross-CTA path was both surprising (``bar.arrive(stage)``
# silently ``mapa`` ed across the cluster) and a per-call cost
# of ~3 PTX ops on every single-CTA kernel.
#
# ``count`` (cross-CTA path only) emits the explicit arrival-count
# operand, i.e. ``mbarrier.arrive.shared::cluster.b64 _, [addr], count``.
# When ``None`` the implicit count-of-1 form is emitted. Passing
# ``count=1`` is semantically identical but spells the count explicitly.
if cta_id is None:
T.ptx.mbarrier.arrive(self.buf.ptr_to([stage]))
else:
actual_pred = True if pred is None else pred
T.ptx.mbarrier.arrive(self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred)
T.ptx.mbarrier.arrive(
self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred, count=count
)

def ptr_to(self, idx):
return self.buf.ptr_to(idx)
Expand Down
135 changes: 132 additions & 3 deletions python/tvm/backend/cuda/lang/tile_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
instances are automatically treated as meta values inside @T.prim_func.
"""

from tvm.backend.cuda.lang.pipeline import Pipeline, PipelineState
from tvm.script import tirx as T


Expand Down Expand Up @@ -753,13 +754,20 @@ class FlashAttentionLPTScheduler(BaseTileScheduler):
"""

def __init__(
self, prefix: str, num_batches: int, num_heads: int, num_m_blocks: int, l2_swizzle: int
self,
prefix: str,
num_batches: int,
num_heads: int,
num_m_blocks: int,
l2_swizzle: int,
num_ctas: int | None = None,
):
super().__init__(prefix)
self._num_batches = num_batches
self._num_heads = num_heads
self._num_m_blocks = num_m_blocks
self._l2_swizzle = l2_swizzle
self._num_ctas = num_ctas
self._total_tasks = num_batches * num_heads * num_m_blocks

# Derived constants for L2 swizzle
Expand Down Expand Up @@ -807,10 +815,131 @@ def init(self, cta_id):

@T.inline
def next_tile(self):
"""Advance to next tile by striding by num_ctas."""
self.linear_idx = self._total_tasks
"""Advance to the next tile.

Single-tile mode (``num_ctas=None``, the default): each CTA owns one
task; terminate. Persistent mode (``num_ctas=N``): stride by N, like
:class:`FlashAttentionLinearScheduler`, while keeping the LPT + L2
swizzle index mapping.
"""
if self._num_ctas is None:
self.linear_idx = self._total_tasks
else:
self.linear_idx = self.linear_idx + self._num_ctas
self.update_current_m_n_idx(self.linear_idx)
# fmt: on

def valid(self):
"""Check if there are more tiles to process."""
return self.linear_idx < self._total_tasks


class _CLCWorker(ClusterPersistentScheduler2D):
"""Per-role CLC handle: IS-A ClusterPersistentScheduler2D (so m_idx / n_idx work as
usual) plus the role-local barrier phase and handshake. A coord-free role (e.g. an
MMA warp consuming whatever a loader staged) arms the loop with reset() not init().
"""

def __init__(self, clc, prefix):
super().__init__(
prefix,
num_m_tiles=clc._num_m_tiles,
num_n_tiles=clc._num_n_tiles,
num_clusters=clc._num_m_tiles * clc._num_n_tiles,
l2_group_size=clc._l2_group_size,
)
self._clc = clc
self._sa = PipelineState(1, 0)
self._done = T.local_scalar("int32")
self._nxt = T.local_scalar("uint32")

@T.inline
def reset(self):
self._done = 0

@T.inline
def init(self, cluster_id):
# Explicit base call: TVMScript's parser has no zero-arg super().
ClusterPersistentScheduler2D.init(self, cluster_id)
self._done = 0

def valid(self):
return self._done == 0

@T.inline
def consume(self):
# Single-elected-thread scope: wait for the handle, decode, release the slot.
self._clc.sched_arr.full.wait(0, self._sa.phase)
self._sa.advance()
self._nxt = T.ptx.clc_query_cancel(T.address_of(self._clc.clc_handle[0]))
self._clc.sched_fin.empty.arrive(0, cta_id=0, pred=True)

@T.inline
def consume_wg(self, wg_id, warp_id, lane_id):
# Warpgroup scope: all threads decode; one elected lane releases the slot.
self._clc.sched_arr.full.wait(0, self._sa.phase)
self._sa.advance()
self._nxt = T.ptx.clc_query_cancel(T.address_of(self._clc.clc_handle[0]))
T.cuda.warpgroup_sync(wg_id + 1)
if (warp_id == 0) & (lane_id == 0):
self._clc.sched_fin.empty.arrive(0, cta_id=0, pred=True)

@T.inline
def advance_coords(self):
if self._nxt != 0xFFFFFFFF:
self.update_current_m_n_idx(self._nxt // self._clc._cta_group)

@T.inline
def mark_done_if_drained(self):
if self._nxt == 0xFFFFFFFF:
self._done = 1


@T.meta_class
class ClusterLaunchControlScheduler:
"""Blackwell Cluster Launch Control (CLC) tile scheduler.

A scheduler warp runs ``run_scheduler`` (issues ``try_cancel`` to steal the next
cluster); worker roles each take a ``worker()`` handle and pull the stolen tile
through the shared smem handshake. Owns the CLC smem: the 16B response handle, the
arrival barrier (handle ready), and the finished barrier (slot consumed;
``finish_arrivals`` arrivals per round). Tile-coord mapping is delegated to
``ClusterPersistentScheduler2D`` (group-major L2 ordering).
"""

def __init__(self, pool, num_m_tiles, num_n_tiles, l2_group_size, cta_group, finish_arrivals):
self._num_m_tiles = num_m_tiles
self._num_n_tiles = num_n_tiles
self._l2_group_size = l2_group_size
self._cta_group = cta_group
self.sched_arr = Pipeline(pool, 1, full="tma", empty="mbar", init_empty=1)
self.sched_fin = Pipeline(pool, 1, full="mbar", empty="mbar", init_empty=finish_arrivals)
self.clc_handle = pool.alloc((4,), "uint32", align=16)
self._s_done = T.local_scalar("int32")
self._s_nxt = T.local_scalar("uint32")

def worker(self, prefix):
return _CLCWorker(self, prefix)

@T.inline
def run_scheduler(self, cbx):
# cta0 drives try_cancel; both CTAs expect_bytes + consume the handle so the
# finished-barrier count is met and the slot can be reissued.
if T.ptx.elect_sync():
sa = PipelineState(1, 0)
sf = PipelineState(1, 1)
self._s_done = 0
while self._s_done == 0:
if cbx == 0:
self.sched_fin.empty.wait(0, sf.phase)
sf.advance()
T.ptx.clc_try_cancel(
T.address_of(self.clc_handle[0]), T.address_of(self.sched_arr.full.buf[0])
)
self.sched_arr.full.arrive(0, 16) # expect_bytes for the 16B handle
self.sched_arr.full.wait(0, sa.phase)
sa.advance()
self._s_nxt = T.ptx.clc_query_cancel(T.address_of(self.clc_handle[0]))
self.sched_fin.empty.arrive(0, cta_id=0, pred=True)
if self._s_nxt == 0xFFFFFFFF:
self._s_done = 1
79 changes: 75 additions & 4 deletions python/tvm/backend/cuda/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,12 +653,12 @@ def ptx_mbarrier_init(bar, thread_count):
return call_intrin("", "tirx.ptx_mbarrier_init", bar, thread_count)


def ptx_mbarrier_arrive(bar, cta_id=None, pred=None):
def ptx_mbarrier_arrive(bar, cta_id=None, pred=None, count=None):
"""TVM intrinsic to call
mbarrier.arrive.shared::cta.b64
or
@p mapa.shared::cluster.u32
@p mbarrier.arrive.shared::cluster.b64
@p mbarrier.arrive.shared::cluster.b64 [, count]

Parameters
----------
Expand All @@ -670,11 +670,29 @@ def ptx_mbarrier_arrive(bar, cta_id=None, pred=None):

pred : Optional[PrimExpr]
The predicate to guard the operation.

count : Optional[PrimExpr]
Explicit arrival count operand for the cross-CTA (cluster) form. When
``None`` the implicit count-of-1 form is emitted; when given, emits
``mbarrier.arrive.shared::cluster.b64 _, [addr], count``.
"""
if cta_id is None and pred is None:
return call_intrin("", "tirx.ptx_mbarrier_arrive", bar)
assert cta_id is not None and pred is not None
return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred)
if count is None:
return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred)
return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred, count)


def ptx_mbarrier_arrive_cluster_count(bar, cta_id, count):
"""Cross-CTA ``mbarrier.arrive`` on CTA ``cta_id`` with an explicit count.

Convenience for an already-elected thread: emits
``@p mapa.shared::cluster.u32`` + ``@p mbarrier.arrive.shared::cluster.b64 _,
[addr], count`` with the guard defaulted to 1.
"""
return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, True, count)



def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None):
Expand Down Expand Up @@ -706,7 +724,11 @@ def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None):
"""
if cta_id is None and pred is None:
return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, byte_count)
assert cta_id is not None and pred is not None
assert cta_id is not None
# Cross-CTA expect_tx from an already-elected thread: default the guard to 1
# (the caller has elected a single lane), so callers can pass cta_id alone.
if pred is None:
pred = True
return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, byte_count, cta_id, pred)


Expand All @@ -729,6 +751,23 @@ def ptx_mbarrier_try_wait(bar, phase):
return call_intrin("", "tirx.ptx_mbarrier_try_wait", bar, phase)


def ptx_mbarrier_try_wait_acquire_cluster(bar, phase):
"""``mbarrier.try_wait.parity.acquire.cluster`` retry loop.

Cluster-scope acquire wait — used to wait on a barrier that a remote CTA in
the cluster arrives on (a group cluster wait).

Parameters
----------
bar : Var
The pointer to barrier variable.

phase : int
The phase of the barrier.
"""
return call_intrin("", "tirx.ptx_mbarrier_try_wait_acquire_cluster", bar, phase)


def ptx_mbarrier_try_wait_once(bar, phase, ticks):
"""TVM intrinsic for one-shot non-blocking ``mbarrier.try_wait.parity``.

Expand Down Expand Up @@ -1261,6 +1300,38 @@ def ptx_barrier_cluster_wait(acquire=False, aligned=True):
return call_intrin("", "tirx.ptx_barrier_cluster_wait", acquire, aligned)


def ptx_clc_try_cancel(handle, mbar):
"""TVM intrinsic to call clusterlaunchcontrol.try_cancel.

Async-requests cancelling the next cluster's launch (work-stealing): writes the
16B response handle to smem and signals ``mbar`` (complete_tx, multicast to both
cluster CTAs).

Parameters
----------
handle : PrimExpr
Pointer to the 16B (uint4) smem response handle.

mbar : PrimExpr
Pointer to the mbarrier signalled when the handle lands.
"""
return call_intrin("", "tirx.ptx_clc_try_cancel", handle, mbar)


def ptx_clc_query_cancel(handle):
"""TVM intrinsic to call clusterlaunchcontrol.query_cancel.

Decodes the response handle written by :func:`ptx_clc_try_cancel`. Returns the
cancelled cluster's first ``ctaid.x``, or ``0xFFFFFFFF`` when no work was stolen.

Parameters
----------
handle : PrimExpr
Pointer to the 16B (uint4) smem response handle.
"""
return call_intrin("uint32", "tirx.ptx_clc_query_cancel", handle)


def ptx_elect_sync():
"""TVM intrinsic to call elect.sync"""
return call_intrin("uint32", "tirx.ptx_elect_sync")
Expand Down
Loading
Loading