diff --git a/lib/candlex/backend.ex b/lib/candlex/backend.ex
index f7d1c31..acf1d88 100644
--- a/lib/candlex/backend.ex
+++ b/lib/candlex/backend.ex
@@ -495,8 +495,6 @@ defmodule Candlex.Backend do
     unsupported_option!(opts, :feature_group_size, 1)
 
     # For now we assume:
-    # strides = opts[:strides] # [1, 1]
-    # padding = opts[:padding] # [{0, 0}, {0, 0}]
     # input_dilation = opts[:input_dilation] # [1, 1]
     # kernel_dilation = opts[:kernel_dilation] # [1, 1]
 
@@ -528,10 +526,26 @@ defmodule Candlex.Backend do
       |> Native.to_type(to_candle_dtype(out_type))
       |> unwrap!()
 
+    padding =
+      case opts[:padding] do
+        [{p, p}] -> p
+        [{p, p}, {p, p}] -> p
+        p -> raise("unsupported padding #{inspect(p)}")
+      end
+
+    stride =
+      case opts[:strides] do
+        [s] -> s
+        [s, s] -> s
+        s -> raise("unsupported strides #{inspect(s)}")
+      end
+
+    conv_opts = %Candlex.Native.ConvOpts{padding: padding, stride: stride, dilation: 1, groups: 1}
+
     native_result =
       case Nx.rank(shape) do
-        3 -> Native.conv1d(native_tensor, native_kernel)
-        4 -> Native.conv2d(native_tensor, native_kernel)
+        3 -> Native.conv1d(native_tensor, native_kernel, conv_opts)
+        4 -> Native.conv2d(native_tensor, native_kernel, conv_opts)
         rank -> raise("unsupported conv for tensor of rank #{rank}, only 3 or 4 supported")
       end
 
diff --git a/lib/candlex/native.ex b/lib/candlex/native.ex
index 427d39b..f0aefa6 100644
--- a/lib/candlex/native.ex
+++ b/lib/candlex/native.ex
@@ -49,8 +49,8 @@ defmodule Candlex.Native do
   def dtype(_tensor), do: error()
   def t_shape(_tensor), do: error()
   def concatenate(_tensors, _axis), do: error()
-  def conv1d(_tensor, _kernel), do: error()
-  def conv2d(_tensor, _kernel), do: error()
+  def conv1d(_tensor, _kernel, _opts), do: error()
+  def conv2d(_tensor, _kernel, _opts), do: error()
   def slice_scatter(_tensor, _src, _dim, _start), do: error()
   def pad_with_zeros(_tensor, _left, _right), do: error()
   def clamp(_tensor, _min, _max), do: error()
@@ -136,3 +136,7 @@ defmodule Candlex.Native do
 
   defp error(), do: :erlang.nif_error(:nif_not_loaded)
 end
+
+defmodule Candlex.Native.ConvOpts do
+  defstruct [:padding, :stride, :dilation, :groups]
+end
diff --git a/native/candlex/src/kernels.rs b/native/candlex/src/kernels.rs
index 13317b3..c6627ef 100644
--- a/native/candlex/src/kernels.rs
+++ b/native/candlex/src/kernels.rs
@@ -1,4 +1,4 @@
 #[rustfmt::skip]
-pub const CUSTOM_BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_binary.ptx"));
-#[rustfmt::skip]
 pub const CUSTOM_UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_unary.ptx"));
+#[rustfmt::skip]
+pub const CUSTOM_BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_binary.ptx"));
diff --git a/native/candlex/src/tensors.rs b/native/candlex/src/tensors.rs
index 780046b..6e89df6 100644
--- a/native/candlex/src/tensors.rs
+++ b/native/candlex/src/tensors.rs
@@ -337,35 +337,56 @@ pub fn concatenate(ex_tensors: Vec<ExTensor>, dim: usize) -> Result<ExTensor, CandlexError>
 }
 
+#[derive(NifStruct)]
+#[module = "Candlex.Native.ConvOpts"]
+pub struct ConvOpts {
+    padding: usize,
+    stride: usize,
+    dilation: usize,
+    groups: usize,
+}
+
+impl Default for ConvOpts {
+    fn default() -> Self {
+        Self {
+            padding: 0,
+            stride: 1,
+            dilation: 1,
+            groups: 1,
+        }
+    }
+}
+
 #[rustler::nif(schedule = "DirtyCpu")]
-pub fn conv1d(tensor: ExTensor, kernel: ExTensor) -> Result<ExTensor, CandlexError> {
-    let padding = 0;
-    let stride = 1;
-    let dilation = 1;
-    let groups = 1;
+pub fn conv1d(
+    tensor: ExTensor,
+    kernel: ExTensor,
+    options: Option<ConvOpts>,
+) -> Result<ExTensor, CandlexError> {
+    let opts = options.unwrap_or_default();
 
     Ok(ExTensor::new(tensor.conv1d(
         kernel.deref(),
-        padding,
-        stride,
-        dilation,
-        groups,
+        opts.padding,
+        opts.stride,
+        opts.dilation,
+        opts.groups,
     )?))
 }
 
 #[rustler::nif(schedule = "DirtyCpu")]
-pub fn conv2d(tensor: ExTensor, kernel: ExTensor) -> Result<ExTensor, CandlexError> {
-    let padding = 0;
-    let stride = 1;
-    let dilation = 1;
-    let groups = 1;
+pub fn conv2d(
+    tensor: ExTensor,
+    kernel: ExTensor,
+    options: Option<ConvOpts>,
+) -> Result<ExTensor, CandlexError> {
+    let opts = options.unwrap_or_default();
 
     Ok(ExTensor::new(tensor.conv2d(
         kernel.deref(),
-        padding,
-        stride,
-        dilation,
-        groups,
+        opts.padding,
+        opts.stride,
+        opts.dilation,
+        opts.groups,
     )?))
 }
diff --git a/test/candlex_test.exs b/test/candlex_test.exs
index 9ab9d07..0488ce0 100644
--- a/test/candlex_test.exs
+++ b/test/candlex_test.exs
@@ -1478,6 +1478,97 @@ defmodule CandlexTest do
         ])
       )
 
+      Nx.iota({1, 1, 10, 10})
+      |> Nx.conv(Nx.iota({1, 1, 3, 3}), strides: 3, padding: :same)
+      |> assert_equal(
+        t([
+          [
+            [
+              [163.0, 313.0, 412.0, 301.0],
+              [945.0, 1374.0, 1482.0, 930.0],
+              [1755.0, 2454.0, 2562.0, 1560.0],
+              [1057.0, 1369.0, 1414.0, 779.0]
+            ]
+          ]
+        ])
+      )
+
+      Nx.iota({1, 1, 10, 10})
+      |> Nx.conv(Nx.iota({1, 1, 3, 3}), strides: 2, padding: :same)
+      |> assert_equal(
+        t([
+          [
+            [
+              [163.0, 280.0, 346.0, 412.0, 478.0],
+              [675.0, 978.0, 1050.0, 1122.0, 1194.0],
+              [1215.0, 1698.0, 1770.0, 1842.0, 1914.0],
+              [1755.0, 2418.0, 2490.0, 2562.0, 2634.0],
+              [2295.0, 3138.0, 3210.0, 3282.0, 3354.0]
+            ]
+          ]
+        ])
+      )
+
+      Nx.iota({1, 1, 10, 10})
+      |> Nx.conv(Nx.iota({2, 1, 3, 3}), strides: 2, padding: :same)
+      |> assert_equal(
+        t([
+          [
+            [
+              [163.0, 280.0, 346.0, 412.0, 478.0],
+              [675.0, 978.0, 1050.0, 1122.0, 1194.0],
+              [1215.0, 1698.0, 1770.0, 1842.0, 1914.0],
+              [1755.0, 2418.0, 2490.0, 2562.0, 2634.0],
+              [2295.0, 3138.0, 3210.0, 3282.0, 3354.0]
+            ],
+            [
+              [361.0, 658.0, 832.0, 1006.0, 1180.0],
+              [1782.0, 2760.0, 2994.0, 3228.0, 3462.0],
+              [3402.0, 5.1e3, 5334.0, 5568.0, 5802.0],
+              [5022.0, 7440.0, 7674.0, 7908.0, 8142.0],
+              [6642.0, 9780.0, 10014.0, 10248.0, 10482.0]
+            ]
+          ]
+        ])
+      )
+
+      Nx.iota({1, 1, 10, 10})
+      |> Nx.conv(Nx.iota({2, 1, 3, 3}), strides: 4)
+      |> assert_equal(
+        t([
+          [
+            [
+              [582.0, 726.0],
+              [2022.0, 2166.0]
+            ],
+            [
+              [1473.0, 1941.0],
+              [6153.0, 6621.0]
+            ]
+          ]
+        ])
+      )
+
+      Nx.iota({1, 1, 10, 10})
+      |> Nx.conv(Nx.iota({2, 1, 3, 3}),
+        strides: 4,
+        output_permutation: [0, 3, 1, 2]
+      )
+      |> assert_equal(
+        t([
+          [
+            [
+              [582.0, 1473.0],
+              [726.0, 1941.0]
+            ],
+            [
+              [2022.0, 6153.0],
+              [2166.0, 6621.0]
+            ]
+          ]
+        ])
+      )
+
       # Nx.iota({1, 1, 3, 3})
       # |> Nx.conv(
       #   Nx.iota({4, 1, 2, 1}),