笔者最近在工作中需要用到一些高性能计算的优化,于是准备着手系统性进行学习。有大佬建议先从triton学起,并且推荐了triton puzzles和triton的tutorial作为入门资料。以下是我练习triton puzzles时对一些解法的分析,记录一下作为心得。
https://github.com/SiriusNEO/Triton-Puzzles-Lite
@triton.jit
def add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    # Puzzle 1: z = x + 10 with a single program instance.
    # B0 covers the whole vector here, so no mask is required.
    offsets = tl.arange(0, B0)
    tl.store(z_ptr + offsets, tl.load(x_ptr + offsets) + 10.0)
@triton.jit
def add_mask2_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    """Puzzle 2: z = x + 10 across multiple blocks; N0 need not divide B0.

    The original snippet was truncated by the blog extraction (the `<` in the
    mask comparison and everything after it was lost); this restores the
    standard masked solution.
    """
    off_x = tl.arange(0, B0) + tl.program_id(0) * B0
    # Guard the tail block: lanes whose index reaches N0 must neither
    # load nor store.
    mask = off_x < N0
    x = tl.load(x_ptr + off_x, mask=mask)
    tl.store(z_ptr + off_x, x + 10.0, mask=mask)
@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    # Puzzle 3: outer sum z[i, j] = x[i] + y[j] in a single program
    # (B0 == N0 and B1 == N1, so no masking is needed).
    rows = tl.arange(0, B0)[:, None]  # (B0, 1) indices into x
    cols = tl.arange(0, B1)[None, :]  # (1, B1) indices into y
    # Loading through 2-D offset tiles gives tensors that broadcast
    # to a full (B0, B1) result.
    z = tl.load(x_ptr + rows) + tl.load(y_ptr + cols)
    tl.store(z_ptr + rows * N1 + cols, z)
@triton.jit
def add_vec_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    """Puzzle 4: blocked outer sum z[i, j] = x[i] + y[j] over a 2-D grid.

    The original snippet was cut off at the mask comparison by the blog
    extraction; this restores the standard masked, blocked solution.
    """
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)
    off_x = block_id_x * B0 + tl.arange(0, B0)
    off_y = block_id_y * B1 + tl.arange(0, B1)
    # Each axis gets its own mask; N0 / N1 need not be multiples of B0 / B1.
    mask_x = off_x < N0
    mask_y = off_y < N1
    x = tl.load(x_ptr + off_x, mask=mask_x)
    y = tl.load(y_ptr + off_y, mask=mask_y)
    # Broadcast the two 1-D tiles into this program's (B0, B1) output block.
    z = x[:, None] + y[None, :]
    off_z = off_x[:, None] * N1 + off_y[None, :]
    tl.store(z_ptr + off_z, z, mask=mask_x[:, None] & mask_y[None, :])
@triton.jit
def mul_relu_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    """Puzzle 5: fused outer product + ReLU, z[i, j] = relu(x[i] * y[j]).

    The middle of the original snippet was garbled by the blog extraction
    (only the tail of a `tl.where` ReLU survived); this restores a complete
    solution consistent with the surviving row-major store
    `z_ptr + off_x * N1 + off_y`.
    NOTE(review): assumes x has length N0, y has length N1, and z is an
    (N0, N1) row-major matrix — confirm against the puzzle spec.
    """
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)
    off_x = block_id_x * B0 + tl.arange(0, B0)
    off_y = block_id_y * B1 + tl.arange(0, B1)
    mask_x = off_x < N0
    mask_y = off_y < N1
    x = tl.load(x_ptr + off_x, mask=mask_x)
    y = tl.load(y_ptr + off_y, mask=mask_y)
    # Outer product of the two 1-D tiles, then ReLU applied element-wise.
    z = x[:, None] * y[None, :]
    z = tl.where(z > 0, z, 0.0)  # Apply ReLU
    off_z = off_x[:, None] * N1 + off_y[None, :]
    tl.store(z_ptr + off_z, z, mask=mask_x[:, None] & mask_y[None, :])
@triton.jit
def mul_relu_block_back_kernel(x_ptr, y_ptr, dz_ptr, dx_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    """Puzzle 6: backward of z = relu(x * y) with respect to x.

    dx[j, i] = dz[j, i] * y[j] when x[j, i] * y[j] > 0, else 0
    (the ReLU gate kills the gradient wherever its input was non-positive).

    The original snippet was truncated at the mask comparison by the blog
    extraction; this restores the standard solution.
    NOTE(review): assumes x and dz are (N1, N0) row-major matrices and y is a
    length-N1 vector, per the Triton-Puzzles spec — confirm against the repo.
    """
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)
    off_x = block_id_x * B0 + tl.arange(0, B0)
    off_y = block_id_y * B1 + tl.arange(0, B1)
    mask_x = off_x < N0
    mask_y = off_y < N1
    # 2-D offsets/mask for the (N1, N0) matrices x and dz.
    off_mat = off_y[:, None] * N0 + off_x[None, :]
    mask_mat = mask_y[:, None] & mask_x[None, :]
    x = tl.load(x_ptr + off_mat, mask=mask_mat)
    y = tl.load(y_ptr + off_y, mask=mask_y)
    dz = tl.load(dz_ptr + off_mat, mask=mask_mat)
    # Chain rule: d relu(x*y)/dx = y * 1[x*y > 0].
    dx = tl.where(x * y[:, None] > 0, dz * y[:, None], 0.0)
    tl.store(dx_ptr + off_mat, dx, mask=mask_mat)
@triton.jit
def sum_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    """Long-sum puzzle: z[i] = sum_t x[i, t], reducing a length-T axis
    in chunks of B1 so the whole row never has to fit in one tile.

    The original snippet was truncated at the mask comparison by the blog
    extraction; this restores the standard batched-reduction solution.
    NOTE(review): assumes x is an (N0, T) row-major matrix — confirm against
    the puzzle spec.
    """
    pid = tl.program_id(0)
    block_start = pid * B0
    offsets = block_start + tl.arange(0, B0)
    mask = offsets < N0
    # Running per-row accumulator for this block of B0 rows.
    acc = tl.zeros((B0,), dtype=tl.float32)
    for t in range(0, T, B1):
        off_t = t + tl.arange(0, B1)
        mask_t = off_t < T
        # (B0, B1) tile of x; masked lanes read 0 so they don't affect the sum.
        x = tl.load(
            x_ptr + offsets[:, None] * T + off_t[None, :],
            mask=mask[:, None] & mask_t[None, :],
            other=0.0,
        )
        acc += tl.sum(x, axis=1)
    tl.store(z_ptr + offsets, acc, mask=mask)
本文由博客一文多发平台 OpenWrite 发布!
登录查看全部
参与评论
手机查看
返回顶部