diff --git a/lib/candlex/backend.ex b/lib/candlex/backend.ex index acf1d88..ed02b08 100644 --- a/lib/candlex/backend.ex +++ b/lib/candlex/backend.ex @@ -557,13 +557,13 @@ defmodule Candlex.Backend do @impl true def dot( - %T{type: _out_type} = out, - %T{shape: left_shape, type: _left_type} = left, - [left_axis] = _left_axes, - [] = _left_batched_axes, - %T{shape: right_shape, type: _right_type} = right, - [0] = _right_axes, - [] = _right_batched_axes + %T{} = out, + %T{shape: left_shape} = left, + [left_axis], + [], + %T{shape: right_shape} = right, + [0], + [] ) when tuple_size(left_shape) >= 1 and tuple_size(right_shape) == 1 and left_axis == tuple_size(left_shape) - 1 do @@ -576,13 +576,13 @@ defmodule Candlex.Backend do end def dot( - %T{type: _out_type} = out, - %T{shape: left_shape, type: _left_type} = left, - [1] = _left_axes, - [] = _left_batched_axes, - %T{shape: right_shape, type: _right_type} = right, - [0] = _right_axes, - [] = _right_batched_axes + %T{} = out, + %T{shape: left_shape} = left, + [1], + [], + %T{shape: right_shape} = right, + [0], + [] ) when tuple_size(left_shape) == 2 and tuple_size(right_shape) == 2 do {left, right} = maybe_upcast(left, right)