Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions src/ntops/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,61 @@
bitwise_not,
bitwise_or,
bmm,
cast,
ceil,
clamp,
conv2d,
cos,
div,
dropout,
elu,
eq,
exp,
floor,
ge,
gelu,
gelu_backward,
gt,
hardswish,
hardtanh,
isinf,
isnan,
layer_norm,
le,
leaky_relu,
log,
log_softmax,
lt,
max_pool2d,
mish,
mm,
mul,
ne,
neg,
pow,
prelu,
reciprocal,
relu,
relu6,
relu_backward,
rms_norm,
rotary_position_embedding,
rsqrt,
scaled_dot_product_attention,
selu,
sigmoid,
sigmoid_backward,
silu,
sin,
softmax,
softplus,
softsign,
sqrt,
sub,
tanh,
tanh_backward,
tanhshrink,
where,
)

__all__ = [
Expand All @@ -50,36 +73,59 @@
"bitwise_not",
"bitwise_or",
"bmm",
"cast",
"ceil",
"clamp",
"conv2d",
"cos",
"div",
"dropout",
"elu",
"eq",
"exp",
"floor",
"ge",
"gelu",
"gelu_backward",
"gt",
"hardswish",
"hardtanh",
"isinf",
"isnan",
"layer_norm",
"le",
"leaky_relu",
"log",
"log_softmax",
"lt",
"max_pool2d",
"mish",
"mm",
"mul",
"ne",
"neg",
"pow",
"prelu",
"reciprocal",
"relu",
"relu6",
"relu_backward",
"rms_norm",
"rotary_position_embedding",
"rsqrt",
"scaled_dot_product_attention",
"selu",
"sigmoid",
"sigmoid_backward",
"silu",
"sin",
"softmax",
"softplus",
"softsign",
"sqrt",
"sub",
"tanh",
"tanh_backward",
"tanhshrink",
"where",
]
21 changes: 21 additions & 0 deletions src/ntops/kernels/cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import functools

import ninetoothed
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
output = input # noqa: F841


def premake(ndim, input_dtype=None, output_dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(ndim, dtype=input_dtype),
Tensor(ndim, dtype=output_dtype),
)

return arrangement_, application, tensors
18 changes: 18 additions & 0 deletions src/ntops/kernels/ceil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
output = ntl.cast(ntl.ceil(ntl.cast(input, ntl.float32)), input.dtype) # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))

return arrangement_, application, tensors
25 changes: 25 additions & 0 deletions src/ntops/kernels/elu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, alpha, output):
input_f32 = ntl.cast(input, ntl.float32)
result = ntl.where(input >= 0, input, ntl.cast(alpha * (ntl.exp(input_f32) - 1), input.dtype))
output = result # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(0, dtype=ninetoothed.float64),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
18 changes: 18 additions & 0 deletions src/ntops/kernels/floor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
output = ntl.cast(ntl.floor(ntl.cast(input, ntl.float32)), input.dtype) # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))

return arrangement_, application, tensors
25 changes: 25 additions & 0 deletions src/ntops/kernels/gelu_backward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(grad_output, input, grad_input):
input_f32 = ntl.cast(input, ntl.float32)
cdf = 0.5 * (1.0 + ntl.erf(input_f32 * 0.7071067811865476))
pdf = ntl.exp(-0.5 * input_f32 * input_f32) * 0.3989422804014327
grad_input = grad_output * ntl.cast(cdf + input_f32 * pdf, grad_output.dtype) # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
18 changes: 18 additions & 0 deletions src/ntops/kernels/hardswish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
output = input * ntl.clamp(input + 3.0, 0.0, 6.0) / 6.0 # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))

return arrangement_, application, tensors
24 changes: 24 additions & 0 deletions src/ntops/kernels/hardtanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, min_val, max_val, output):
output = ntl.clamp(input, min_val, max_val) # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(0, dtype=ninetoothed.float64),
Tensor(0, dtype=ninetoothed.float64),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
23 changes: 23 additions & 0 deletions src/ntops/kernels/leaky_relu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, negative_slope, output):
output = ntl.where(input >= 0, input, negative_slope * input) # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(0, dtype=ninetoothed.float64),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
18 changes: 18 additions & 0 deletions src/ntops/kernels/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
output = ntl.log(ntl.cast(input, ntl.float32)) # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))

return arrangement_, application, tensors
44 changes: 44 additions & 0 deletions src/ntops/kernels/log_softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement


def _exp(x, dtype):
exp_dtype = dtype if dtype != ntl.float16 else ntl.float32
return ntl.cast(ntl.exp(ntl.cast(x, exp_dtype)), dtype)


def application(input, output):
dtype = output.dtype.dtype
prev_max = ntl.cast(float("-inf"), dtype)
denominator = ntl.cast(0, dtype)

for i in range(input.shape[0]):
input_i = ntl.cast(input[i], dtype)
curr_max = ntl.cast(ntl.maximum(prev_max, ntl.max(input_i)), dtype)
input_max_diff_exp = _exp(input_i - curr_max, dtype)
prev_curr_max_diff_exp = _exp(prev_max - curr_max, dtype)
denominator = denominator * prev_curr_max_diff_exp + ntl.sum(input_max_diff_exp)
prev_max = curr_max

log_dtype = dtype if dtype != ntl.float16 else ntl.float32

for i in range(input.shape[0]):
log_denominator = ntl.log(ntl.cast(denominator, log_dtype))
output[i] = ntl.cast(ntl.cast(input[i], log_dtype) - ntl.cast(prev_max, log_dtype) - log_denominator, dtype)


def premake(ndim, dim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

tensors = (
Tensor(
ndim, dtype=dtype, other=float("-inf"), shape_options={"constexpr": True}
),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
24 changes: 24 additions & 0 deletions src/ntops/kernels/mish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
input_f32 = ntl.cast(input, ntl.float32)
sp = ntl.log(1 + ntl.exp(input_f32))
exp_sp = ntl.exp(sp)
exp_neg_sp = ntl.exp(-sp)
tanh_sp = (exp_sp - exp_neg_sp) / (exp_sp + exp_neg_sp)
result = ntl.cast(input_f32 * tanh_sp, input.dtype)
output = result # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))

return arrangement_, application, tensors
Loading