Skip to content

Commit

Permalink
Integrate Triton up to [88c704e](https://github.com/openai/triton/com…
Browse files Browse the repository at this point in the history
  • Loading branch information
Moerafaat authored and Google-ML-Automation committed Dec 23, 2024
1 parent c3419d6 commit 859cc39
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
10 changes: 6 additions & 4 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,9 @@ def get_or_create_triton_kernel(
specialization_attr = backend.get_attrs_descriptor(fn.params[:len(args_for_specialization_attr)], args_for_specialization_attr) # pylint: disable=protected-access
constants = dict(metaparams)
constants.update({k: None for _, k, v in scalar_args if v is None})
constants.update({fn.arg_names[i]: 1 for i in specialization_attr.equal_to_1})
constants.update({fn.arg_names[i]: 1 for (i,) in specialization_attr.equal_to_1})
for constant in constants:
signature[constant] = "constexpr"

# Cache key should contain any parameter that can affect the compiler output.
cache_key = (
Expand Down Expand Up @@ -413,7 +415,7 @@ def get_or_create_triton_kernel(
fn,
specialization=tc.ASTSource(
fn,
constants=constants,
constexprs=constants,
signature=signature,
attrs=specialization_attr,
),
Expand All @@ -429,7 +431,7 @@ def get_or_create_triton_kernel(
fn,
specialization=tc.ASTSource(
fn,
constants=constants,
constexprs=constants,
signature=signature,
attrs=specialization_attr,
),
Expand Down Expand Up @@ -634,7 +636,7 @@ def prune_configs(configs, named_args, **kwargs):
16 if (i in specialization_attr.divisibility_16) else 0,
)
)
elif i not in specialization_attr.equal_to_1:
elif (i,) not in specialization_attr.equal_to_1:
kernel_params.append(
triton_kernel_call_lib.create_scalar_parameter(arg, dtype)
)
Expand Down
4 changes: 2 additions & 2 deletions tests/triton_call_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,10 +564,10 @@ def test_specialization(self):
# Pointers are assumed to divide by 16, as do `M`, `N`, `stride_{bk,cm}`.
# However, we've marked `a_ptr`, `M`, `stride_bk`, and `c_ptr` as "do not
# specialize", leaving `b_ptr`, `N`, and `stride_cm`.
self.assertEqual(specialization.attrs.divisibility_16, [1, 3, 9])
self.assertEqual(specialization.attrs.divisibility_16, [(1,), (3,), (9,)])
# `stride_{ak,bn,cn}` equal 1, but we've marked `stride_ak` as "do not
# specialize" leaving `stride_{bn,cn}`.
self.assertEqual(specialization.attrs.equal_to_1, [8, 10])
self.assertEqual(specialization.attrs.equal_to_1, [(8,), (10,)])


if __name__ == "__main__":
Expand Down

0 comments on commit 859cc39

Please sign in to comment.