Merged
Changes from all commits
33 commits
7ba9dec
new linear kernels
AnihilatorGun Jan 7, 2025
ff804f4
black isort
AnihilatorGun Jan 7, 2025
ec09d1e
configure LinBReLULinAdd, small fix, tests
AnihilatorGun Jan 9, 2025
eb5c60e
configure
AnihilatorGun Jan 10, 2025
9242b1c
configure
AnihilatorGun Jan 10, 2025
83fb977
linear restructurization
AnihilatorGun Jan 10, 2025
0ca3ca9
conv restructurization, small fix
AnihilatorGun Jan 10, 2025
7b321cb
style
AnihilatorGun Jan 10, 2025
861774e
reformat tests
AnihilatorGun Jan 10, 2025
bc7dff9
black + isort
AnihilatorGun Jan 10, 2025
e442641
test_linbrelulin_backward, small fixes
AnihilatorGun Jan 10, 2025
522dd64
black isort
AnihilatorGun Jan 10, 2025
971dba6
new strategy to compare weight gradients
AnihilatorGun Jan 10, 2025
2400682
black isort flake8
AnihilatorGun Jan 10, 2025
2bf5f32
style
AnihilatorGun Jan 10, 2025
dbf77e8
handle configurator's exceptions
AnihilatorGun Jan 10, 2025
54fae4a
style
AnihilatorGun Jan 10, 2025
46158c8
style
AnihilatorGun Jan 10, 2025
b3c0a1f
check if function can be configured
AnihilatorGun Jan 10, 2025
2f46b41
version
AnihilatorGun Jan 10, 2025
fd439b4
function repr
AnihilatorGun Jan 10, 2025
396b383
cleaner linear test
AnihilatorGun Jan 14, 2025
77a0157
ReLULinearBackward now produces fp32 weight gradient, tests
AnihilatorGun Jan 14, 2025
f0c3c17
dwconv test cleaner
AnihilatorGun Jan 14, 2025
e54b8ae
DWConvWGRAD produces FP32 gradient now, test_conv
AnihilatorGun Jan 14, 2025
0f8d03b
black
AnihilatorGun Jan 14, 2025
0680a52
DWConvWGRAD went back to FP16
AnihilatorGun Jan 15, 2025
5e635a7
ReLULinearBackward went back to FP16
AnihilatorGun Jan 15, 2025
d1436cc
LinbReLULinBackward now produces FP16 gradients, test fix
AnihilatorGun Jan 15, 2025
dc122f3
no underscored names, fix num_warps in linear functions
AnihilatorGun Jan 21, 2025
3b7865a
black
AnihilatorGun Jan 22, 2025
0b2f978
removed 64-channel configuration of LinBReLULinAdd to make consistant…
AnihilatorGun Jan 22, 2025
2e81bf1
fix
AnihilatorGun Jan 22, 2025
3 changes: 2 additions & 1 deletion .flake8
@@ -3,4 +3,5 @@
max-line-length = 120
per-file-ignores =
    kerops/kernels/*: B007
-    __init__.py: F401
+    __init__.py: F401
+    kerops/utils.py: E731
2 changes: 1 addition & 1 deletion kerops/__version__.py
@@ -1 +1 @@
-__version__ = '0.0.2'
+__version__ = '0.0.3'
132 changes: 132 additions & 0 deletions kerops/kernels/linear.py
@@ -96,3 +96,135 @@ def _ReLULinearAddBackward(
        input_ptr += in_channels * D_block

    tl.store(weight_grad_ptr + weight_grad_offset, weight_grad)


@triton.jit
def _LinBReLULinAdd(
    input_ptr,
    weight_up_ptr,
    weight_down_ptr,
    bias_ptr,
    add_ptr,
    output_ptr,
    numel_no_channels,
    in_channels: tl.constexpr,
    hidden_channels: tl.constexpr,
    D_block: tl.constexpr,
    _ILP: tl.constexpr,
):
    pid = tl.program_id(0)
    input_ptr += pid * _ILP * in_channels * D_block
    add_ptr += pid * _ILP * in_channels * D_block
    output_ptr += pid * _ILP * in_channels * D_block

    in_channels_offset = tl.arange(0, in_channels)
    hidden_channels_offset = tl.arange(0, hidden_channels)
    d_offset = tl.arange(0, D_block)

    offset = d_offset[:, None] * in_channels + in_channels_offset[None, :]
    weight_up_offset = in_channels_offset[:, None] * hidden_channels + hidden_channels_offset[None, :]
    weight_down_offset = hidden_channels_offset[:, None] * in_channels + in_channels_offset[None, :]

    weight_up = tl.load(weight_up_ptr + weight_up_offset)
    weight_down = tl.load(weight_down_ptr + weight_down_offset)
    bias = tl.load(bias_ptr + hidden_channels_offset)[None]

    for i in tl.static_range(0, _ILP):
        mask = d_offset[:, None] < numel_no_channels - (pid * _ILP + i) * D_block

        x = tl.load(input_ptr + offset, mask=mask)  # , other=0)
        add = tl.load(add_ptr + offset, mask=mask)  # , other=0)

        hidden = tl.dot(x, weight_up, out_dtype=tl.float32, allow_tf32=True).to(tl.float16) + bias
        hidden = tl.maximum(hidden, 0.0).to(tl.float16)
        output = tl.dot(hidden, weight_down, out_dtype=tl.float32, allow_tf32=True).to(tl.float16) + add

        tl.store(output_ptr + offset, output, mask=mask)

        input_ptr += in_channels * D_block
        output_ptr += in_channels * D_block
        add_ptr += in_channels * D_block


@triton.jit
def _LinBReLULinBackward(
    input_ptr,
    grad_ptr,
    input_grad_ptr,
    weight_up_ptr,
    weight_down_ptr,
    bias_ptr,
    weight_up_grad_ptr,
    weight_down_grad_ptr,
    bias_grad_ptr,
    numel_no_channels,
    in_channels: tl.constexpr,
    hidden_channels: tl.constexpr,
    D_block: tl.constexpr,
    _ILP: tl.constexpr,
):
    pid = tl.program_id(0)

    input_ptr += pid * _ILP * in_channels * D_block
    grad_ptr += pid * _ILP * in_channels * D_block
    input_grad_ptr += pid * _ILP * in_channels * D_block
    weight_up_grad_ptr += pid * in_channels * hidden_channels
    weight_down_grad_ptr += pid * in_channels * hidden_channels
    bias_grad_ptr += pid * hidden_channels

    in_channels_offset = tl.arange(0, in_channels)
    hidden_channels_offset = tl.arange(0, hidden_channels)
    d_offset = tl.arange(0, D_block)

    offset = d_offset[:, None] * in_channels + in_channels_offset[None, :]
    weight_up_offset = in_channels_offset[:, None] * hidden_channels + hidden_channels_offset[None, :]
    weight_down_offset = hidden_channels_offset[:, None] * in_channels + in_channels_offset[None, :]

    weight_up = tl.load(weight_up_ptr + weight_up_offset)
    weight_down = tl.load(weight_down_ptr + weight_down_offset)
    bias = tl.load(bias_ptr + hidden_channels_offset)[None]

    weight_up_grad = tl.zeros([hidden_channels, in_channels], dtype=tl.float32)
    weight_down_grad = tl.zeros([in_channels, hidden_channels], dtype=tl.float32)
    bias_grad = tl.zeros([hidden_channels], dtype=tl.float32)

    out_offset = in_channels_offset[:, None] + d_offset[None, :] * in_channels

    weight_up_grad_offset = hidden_channels_offset[:, None] + in_channels_offset[None, :] * hidden_channels
    weight_down_grad_offset = in_channels_offset[:, None] + hidden_channels_offset[None, :] * in_channels

    for i in tl.static_range(0, _ILP):
        mask = d_offset[:, None] < numel_no_channels - (pid * _ILP + i) * D_block
        out_mask = d_offset[None, :] < numel_no_channels - (pid * _ILP + i) * D_block

        input = tl.load(input_ptr + offset, mask=mask, other=0.0)  # [D_block, in_channels]
        grad = tl.load(grad_ptr + offset, mask=mask, other=0.0)  # [D_block, in_channels]
        gradT = tl.trans(grad)  # [in_channels, D_block]

        linup = (
            tl.dot(input, weight_up, out_dtype=tl.float32, allow_tf32=True).to(tl.float16) + bias
        )  # [D_block, hidden_channels]
        linup_relu = tl.maximum(linup, 0.0).to(tl.float16)  # [D_block, hidden_channels]

        weight_down_grad += tl.dot(
            gradT, linup_relu, out_dtype=tl.float32, allow_tf32=True
        )  # [in_channels, hidden_channels]

        linup_gradT = tl.trans(linup > 0) * tl.dot(weight_down, gradT, out_dtype=tl.float32, allow_tf32=True).to(
            tl.float16
        )  # [hidden_channels, D_block]
        weight_up_grad += tl.dot(
            linup_gradT, input, out_dtype=tl.float32, allow_tf32=True
        )  # [hidden_channels, in_channels]
        bias_grad += tl.sum(linup_gradT, axis=1)

        input_gradT = tl.dot(weight_up, linup_gradT)
        tl.store(input_grad_ptr + out_offset, input_gradT, mask=out_mask)

        grad_ptr += in_channels * D_block
        input_grad_ptr += in_channels * D_block
        input_ptr += in_channels * D_block

    tl.store(weight_up_grad_ptr + weight_up_grad_offset, weight_up_grad)
    tl.store(weight_down_grad_ptr + weight_down_grad_offset, weight_down_grad)
    tl.store(bias_grad_ptr + hidden_channels_offset, bias_grad)
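For orientation: the fused forward pass added above is Linear → bias → ReLU → Linear → residual add over the channel dimension, computed in fp16 with fp32 tensor-core accumulation (out_dtype=tl.float32) cast back to fp16 between steps. Below is a minimal plain-PyTorch sketch of that computation, useful as a mental model or test oracle; it is illustrative only, not part of the PR, and the function name is made up.

import torch


def linb_relu_lin_add_reference(x, weight_up, bias, weight_down, add):
    # Reference for the fused kernel: out = relu(x @ W_up + b) @ W_down + add.
    # x, add: [..., in_channels] fp16; weight_up: [in_channels, hidden_channels];
    # weight_down: [hidden_channels, in_channels]; bias: [hidden_channels] fp16.
    hidden = (x.float() @ weight_up.float()).half() + bias  # fp32 matmul, cast to fp16, add bias
    hidden = torch.relu(hidden)                             # ReLU in fp16
    return (hidden.float() @ weight_down.float()).half() + add

Note that the backward kernel keeps weight_up_grad, weight_down_grad and bias_grad in fp32 accumulators and writes one slab per program instance (offset by pid), so the host-side wrapper presumably sums those partial slabs afterwards; that is consistent with the "new strategy to compare weight gradients" and the later fp16/fp32 gradient commits in the list above.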
16 changes: 8 additions & 8 deletions kerops/ops/addition.py
@@ -8,8 +8,8 @@
from ..settings import ConfigurableArg, configure, get_l1_cache


-@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8)
-def AddStats(x, y, inplace=False, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg):
+@configure(l1_cache_bytes=get_l1_cache, num_warps=8)
+def AddStats(x, y, inplace=False, *, l1_cache_bytes: ConfigurableArg, num_warps: ConfigurableArg):
    num_channels = x.shape[1]
    numel = x.numel()
    assert x.shape == y.shape
@@ -19,7 +19,7 @@ def AddStats(x, y, inplace=False, *, _l1_cache_bytes: ConfigurableArg, _num_warp
    assert x.is_contiguous(memory_format=torch.channels_last_3d)
    assert y.is_contiguous(memory_format=torch.channels_last_3d)

-    MAX_SIZE = _l1_cache_bytes // x.element_size()  # 32768 for fp16
+    MAX_SIZE = l1_cache_bytes // x.element_size()  # 32768 for fp16
    numel_no_channels = reduce(lambda x, y: x * y, [s if idx != 1 else 1 for idx, s in enumerate(x.shape)], 1)
    other = min(MAX_SIZE // num_channels, numel_no_channels)
    other = int(2 ** (floor(log2(other))))
@@ -44,14 +44,14 @@ def AddStats(x, y, inplace=False, *, _l1_cache_bytes: ConfigurableArg, _num_warp
        BLOCK_SIZE=BLOCK_SIZE,
        num_channels=num_channels,
        block_other=other,
-        num_warps=_num_warps,
+        num_warps=num_warps,
    )
    return output, mean, sqmean


-@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8)
+@configure(l1_cache_bytes=get_l1_cache, num_warps=8)
def AddStatsBackward(
-    add_grad, mean_grad, sqmean_grad, add_result, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg
+    add_grad, mean_grad, sqmean_grad, add_result, *, l1_cache_bytes: ConfigurableArg, num_warps: ConfigurableArg
):
    num_channels = add_grad.shape[1]
    numel = add_grad.numel()
@@ -63,7 +63,7 @@ def AddStatsBackward(
    assert add_grad.is_contiguous(memory_format=torch.channels_last_3d)
    assert add_result.is_contiguous(memory_format=torch.channels_last_3d)

-    MAX_SIZE = _l1_cache_bytes // add_grad.element_size()  # 32768 for fp16
+    MAX_SIZE = l1_cache_bytes // add_grad.element_size()  # 32768 for fp16
    numel_no_channels = reduce(lambda x, y: x * y, [s if idx != 1 else 1 for idx, s in enumerate(add_grad.shape)], 1)
    other = min(MAX_SIZE // num_channels, numel_no_channels)
    other = int(2 ** (floor(log2(other))))
@@ -83,6 +83,6 @@ def AddStatsBackward(
        BLOCK_SIZE=BLOCK_SIZE,
        num_channels=num_channels,
        block_other=other,
-        num_warps=_num_warps,
+        num_warps=num_warps,
    )
    return output_grad
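The block-sizing logic shared by these wrappers picks the largest power-of-two number of non-channel positions whose num_channels-wide rows still fit in the configured L1 budget. A standalone sketch of that arithmetic follows; the helper name is invented, and the 64 KiB budget is only inferred from the "# 32768 for fp16" comment (32768 elements × 2 bytes).

from math import floor, log2


def pick_block_other(l1_cache_bytes, element_size, num_channels, numel_no_channels):
    # Mirror of the sizing arithmetic in AddStats / AddStatsBackward above.
    max_size = l1_cache_bytes // element_size                # e.g. 65536 // 2 = 32768 fp16 elements
    other = min(max_size // num_channels, numel_no_channels)
    return int(2 ** floor(log2(other)))                      # round down to a power of two


# Example: 64 channels, fp16, 100_000 non-channel elements, 64 KiB budget:
# pick_block_other(65536, 2, 64, 100_000) -> 512, so each block covers 512 * 64 = 32768 elements.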
20 changes: 10 additions & 10 deletions kerops/ops/avgpool.py
@@ -9,13 +9,13 @@


@configure(
-    _l1_cache_bytes=get_l1_cache,
-    _num_warps=2,
+    l1_cache_bytes=get_l1_cache,
+    num_warps=2,
)
-def AvgPoolCeilStats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg):
+def AvgPoolCeilStats(x, *, l1_cache_bytes: ConfigurableArg, num_warps: ConfigurableArg):
    num_channels = x.shape[1]
    input_d = x.shape[-1]
-    MAX_SIZE = _l1_cache_bytes // x.element_size()  # 32768 for fp16
+    MAX_SIZE = l1_cache_bytes // x.element_size()  # 32768 for fp16

    assert input_d * num_channels <= MAX_SIZE
    assert num_channels == next_power_of_2(num_channels)
@@ -55,23 +55,23 @@ def AvgPoolCeilStats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: Configu
        numel_no_channels_output=numel_no_channels_output,
        num_channels=num_channels,
        almost_half_d=almost_half_d,
-        num_warps=_num_warps,
+        num_warps=num_warps,
    )
    return output, mean, sqmean


-@configure(_l1_cache_bytes=get_l1_cache, _num_warps=4)
+@configure(l1_cache_bytes=get_l1_cache, num_warps=4)
def AvgPoolCeilStatsBackward(
    inpgrad,
    meangrad,
    sqmeangrad,
    output,
    outgrad_shape,
    *,
-    _l1_cache_bytes: ConfigurableArg,
-    _num_warps: ConfigurableArg,
+    l1_cache_bytes: ConfigurableArg,
+    num_warps: ConfigurableArg,
):
-    MAX_SIZE = _l1_cache_bytes // inpgrad.element_size()  # 32768 for fp16
+    MAX_SIZE = l1_cache_bytes // inpgrad.element_size()  # 32768 for fp16
    bsize, num_channels, h_outgrad, w_outgrad, d_outgrad = outgrad_shape
    d_inpgrad = inpgrad.shape[-1]

@@ -117,6 +117,6 @@ def AvgPoolCeilStatsBackward(
        numel_no_channels_inpgrad=numel_no_channels_inpgrad,
        num_channels=num_channels,
        almost_half_d=almost_half_d,
-        num_warps=_num_warps,
+        num_warps=num_warps,
    )
    return outgrad
16 changes: 8 additions & 8 deletions kerops/ops/bnrelu.py
@@ -8,8 +8,8 @@
from ..settings import ConfigurableArg, configure, get_l1_cache


-@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8)
-def ApplyBNReLU(x, weight, bias, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg):
+@configure(l1_cache_bytes=get_l1_cache, num_warps=8)
+def ApplyBNReLU(x, weight, bias, *, l1_cache_bytes: ConfigurableArg, num_warps: ConfigurableArg):
    num_channels = x.shape[1]
    numel = x.numel()
    assert x.ndim == 5
@@ -18,7 +18,7 @@ def ApplyBNReLU(x, weight, bias, *, _l1_cache_bytes: ConfigurableArg, _num_warps
    assert x.is_contiguous(memory_format=torch.channels_last_3d)
    assert weight.numel() == bias.numel() == num_channels

-    MAX_SIZE = _l1_cache_bytes // x.element_size()  # 32768 for fp16
+    MAX_SIZE = l1_cache_bytes // x.element_size()  # 32768 for fp16
    numel_no_channels = reduce(lambda x, y: x * y, [s if idx != 1 else 1 for idx, s in enumerate(x.shape)], 1)
    other = min(MAX_SIZE // num_channels, numel_no_channels)
    other = int(2 ** (floor(log2(other))))
@@ -36,13 +36,13 @@ def ApplyBNReLU(x, weight, bias, *, _l1_cache_bytes: ConfigurableArg, _num_warps
        BLOCK_SIZE=BLOCK_SIZE,
        num_channels=num_channels,
        block_other=other,
-        num_warps=_num_warps,
+        num_warps=num_warps,
    )
    return output


-@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8)
-def ApplyBNReLUBackward(x, weight, bias, grad, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg):
+@configure(l1_cache_bytes=get_l1_cache, num_warps=8)
+def ApplyBNReLUBackward(x, weight, bias, grad, *, l1_cache_bytes: ConfigurableArg, num_warps: ConfigurableArg):
    num_channels = x.shape[1]
    numel = x.numel()
    assert x.ndim == 5
@@ -54,7 +54,7 @@ def ApplyBNReLUBackward(x, weight, bias, grad, *, _l1_cache_bytes: ConfigurableA
    assert grad.is_contiguous(memory_format=torch.channels_last_3d)
    assert weight.numel() == bias.numel() == num_channels

-    MAX_SIZE = _l1_cache_bytes // x.element_size()  # 32768 for fp16
+    MAX_SIZE = l1_cache_bytes // x.element_size()  # 32768 for fp16
    numel_no_channels = reduce(lambda x, y: x * y, [s if idx != 1 else 1 for idx, s in enumerate(x.shape)], 1)
    other = min(MAX_SIZE // num_channels, numel_no_channels)
    other = int(2 ** (floor(log2(other))))
@@ -77,6 +77,6 @@ def ApplyBNReLUBackward(x, weight, bias, grad, *, _l1_cache_bytes: ConfigurableA
        BLOCK_SIZE=BLOCK_SIZE,
        num_channels=num_channels,
        block_other=other,
-        num_warps=_num_warps,
+        num_warps=num_warps,
    )
    return outgrad, weight_grad, bias_grad
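Rough usage sketch for the renamed keyword-only configurables: the caller passes only the tensors, and @configure injects l1_cache_bytes / num_warps. The shapes, dtype and import path below are assumptions pieced together from the asserts visible in the diff, not code from this PR.

import torch

from kerops.ops.bnrelu import ApplyBNReLU  # import path assumed from the repository layout

# 5D fp16 tensor in channels_last_3d layout, per-channel weight and bias (see asserts above).
x = torch.randn(2, 32, 16, 16, 16, device='cuda', dtype=torch.float16)
x = x.to(memory_format=torch.channels_last_3d)
weight = torch.randn(32, device='cuda', dtype=torch.float16)
bias = torch.randn(32, device='cuda', dtype=torch.float16)

out = ApplyBNReLU(x, weight, bias)  # l1_cache_bytes / num_warps are filled in by @configure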