From b47c0889446d338c52c89f3c5ca0cd7efce39657 Mon Sep 17 00:00:00 2001 From: xabi Date: Fri, 24 Nov 2023 13:24:48 +0100 Subject: [PATCH] feat: window_max (#44) --------- Co-authored-by: Gonzalo <456459+grzuy@users.noreply.github.com> --- lib/candlex/backend.ex | 18 +++++++- lib/candlex/native.ex | 1 + native/candlex/src/lib.rs | 1 + native/candlex/src/tensors.rs | 9 ++++ test/candlex_test.exs | 82 +++++++++++++++++++++++++++++++++++ 5 files changed, 110 insertions(+), 1 deletion(-) diff --git a/lib/candlex/backend.ex b/lib/candlex/backend.ex index ed02b08..d97def7 100644 --- a/lib/candlex/backend.ex +++ b/lib/candlex/backend.ex @@ -488,6 +488,23 @@ defmodule Candlex.Backend do |> to_nx(out) end + @impl true + def window_max(%T{type: out_type} = out, tensor, {1, 1, dx, dy} = _window_dimensions, opts) do + strides = + case opts[:strides] do + [1, 1, sx, sy] -> {sx, sy} + s -> raise("unsupported strides #{inspect(s)}") + end + + tensor + |> from_nx() + |> Native.to_type(to_candle_dtype(out_type)) + |> unwrap!() + |> Native.max_pool2d({dx, dy}, strides) + |> unwrap!() + |> to_nx(out) + end + @impl true def conv(%T{type: out_type} = out, %T{shape: shape} = tensor, %T{} = kernel, opts) do # TODO: Support more opts @@ -933,7 +950,6 @@ defmodule Candlex.Backend do # TODO: Remove after nx 0.7 is released :random_uniform, :triangular_solve, - :window_max, :window_min, :window_product, :window_sum diff --git a/lib/candlex/native.ex b/lib/candlex/native.ex index f0aefa6..ad43f38 100644 --- a/lib/candlex/native.ex +++ b/lib/candlex/native.ex @@ -55,6 +55,7 @@ defmodule Candlex.Native do def pad_with_zeros(_tensor, _left, _right), do: error() def clamp(_tensor, _min, _max), do: error() def reverse(_tensor, _axes), do: error() + def max_pool2d(_tensor, _dims, _strides), do: error() for op <- [ :abs, diff --git a/native/candlex/src/lib.rs b/native/candlex/src/lib.rs index 98d47bc..6faa729 100644 --- a/native/candlex/src/lib.rs +++ b/native/candlex/src/lib.rs @@ -115,6 +115,7 @@ 
rustler::init! { tensors::left_shift, tensors::right_shift, tensors::to_device, + tensors::max_pool2d, devices::is_cuda_available ], load = load diff --git a/native/candlex/src/tensors.rs b/native/candlex/src/tensors.rs index 6e89df6..ed725bf 100644 --- a/native/candlex/src/tensors.rs +++ b/native/candlex/src/tensors.rs @@ -390,6 +390,15 @@ pub fn conv2d( )?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn max_pool2d( + tensor: ExTensor, + dims: (usize, usize), + strides: (usize, usize), +) -> Result<ExTensor, CandlexError> { + Ok(ExTensor::new(tensor.max_pool2d_with_stride(dims, strides)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn divide(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError> { Ok(ExTensor::new( diff --git a/test/candlex_test.exs b/test/candlex_test.exs index 0488ce0..89ab3d8 100644 --- a/test/candlex_test.exs +++ b/test/candlex_test.exs @@ -1600,6 +1600,88 @@ defmodule CandlexTest do # )) end + test "window_max" do + Nx.iota({2, 1, 4, 4}) + |> Nx.window_max({1, 1, 2, 2}, strides: [1, 1, 2, 1]) + |> assert_equal( + t([ + [ + [ + [5, 6, 7], + [13, 14, 15] + ] + ], + [ + [ + [21, 22, 23], + [29, 30, 31] + ] + ] + ]) + ) + + Nx.iota({2, 1, 4, 4}) + |> Nx.window_max({1, 1, 2, 2}, strides: [1, 1, 1, 2]) + |> assert_equal( + t([ + [ + [ + [5, 7], + [9, 11], + [13, 15] + ] + ], + [ + [ + [21, 23], + [25, 27], + [29, 31] + ] + ] + ]) + ) + + Nx.iota({2, 1, 4, 4}) + |> Nx.window_max({1, 1, 2, 1}, strides: [1, 1, 2, 2]) + |> assert_equal( + t([ + [ + [ + [4, 6], + [12, 14] + ] + ], + [ + [ + [20, 22], + [28, 30] + ] + ] + ]) + ) + + Nx.iota({2, 1, 4, 4}) + |> Nx.window_max({1, 1, 2, 1}) + |> assert_equal( + t([ + [ + [ + [4, 5, 6, 7], + [8, 9, 10, 11], + [12, 13, 14, 15] + ] + ], + [ + [ + [20, 21, 22, 23], + [24, 25, 26, 27], + [28, 29, 30, 31] + ] + ] + ]) + ) + end + test "reduce_max" do t(42) |> Nx.reduce_max()