Skip to content

Commit

Permalink
feat: support Nx.conv limited padding and stride (#37)
Browse files · Browse the repository at this point in the history
  • Loading branch information
xabi authored Nov 21, 2023
1 parent fd21e30 commit b55a50b
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 26 deletions.
22 changes: 18 additions & 4 deletions lib/candlex/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,6 @@ defmodule Candlex.Backend do
unsupported_option!(opts, :feature_group_size, 1)

# For now we assume:
# strides = opts[:strides] # [1, 1]
# padding = opts[:padding] # [{0, 0}, {0, 0}]
# input_dilation = opts[:input_dilation] # [1, 1]
# kernel_dilation = opts[:kernel_dilation] # [1, 1]

Expand Down Expand Up @@ -528,10 +526,26 @@ defmodule Candlex.Backend do
|> Native.to_type(to_candle_dtype(out_type))
|> unwrap!()

padding =
case opts[:padding] do
[{p, p}] -> p
[{p, p}, {p, p}] -> p
p -> raise("unsupported padding #{inspect(p)}")
end

stride =
case opts[:strides] do
[s] -> s
[s, s] -> s
s -> raise("unsupported strides #{inspect(s)}")
end

conv_opts = %Candlex.Native.ConvOpts{padding: padding, stride: stride, dilation: 1, groups: 1}

native_result =
case Nx.rank(shape) do
3 -> Native.conv1d(native_tensor, native_kernel)
4 -> Native.conv2d(native_tensor, native_kernel)
3 -> Native.conv1d(native_tensor, native_kernel, conv_opts)
4 -> Native.conv2d(native_tensor, native_kernel, conv_opts)
rank -> raise("unsupported conv for tensor of rank #{rank}, only 3 or 4 supported")
end

Expand Down
8 changes: 6 additions & 2 deletions lib/candlex/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ defmodule Candlex.Native do
def dtype(_tensor), do: error()
def t_shape(_tensor), do: error()
def concatenate(_tensors, _axis), do: error()
# NIF stubs: the real implementations are provided by the Rust NIF at load
# time; these bodies only run if the native library failed to load.
# This commit replaced the 2-arity conv stubs with 3-arity versions that
# take a %Candlex.Native.ConvOpts{} (padding/stride/dilation/groups), so
# only the 3-arity stubs remain.
def conv1d(_tensor, _kernel, _opts), do: error()
def conv2d(_tensor, _kernel, _opts), do: error()
def slice_scatter(_tensor, _src, _dim, _start), do: error()
def pad_with_zeros(_tensor, _left, _right), do: error()
def clamp(_tensor, _min, _max), do: error()
Expand Down Expand Up @@ -136,3 +136,7 @@ defmodule Candlex.Native do

defp error(), do: :erlang.nif_error(:nif_not_loaded)
end

defmodule Candlex.Native.ConvOpts do
  @moduledoc """
  Convolution options passed to the `conv1d`/`conv2d` NIFs.

  Decoded on the Rust side via a `NifStruct` with the same field set, so
  these fields must stay in sync with the native `ConvOpts` definition.
  """

  @type t :: %__MODULE__{
          padding: non_neg_integer(),
          stride: pos_integer(),
          dilation: pos_integer(),
          groups: pos_integer()
        }

  # Defaults mirror the Rust `impl Default for ConvOpts`
  # (padding 0, stride 1, dilation 1, groups 1), so a bare
  # %ConvOpts{} describes a plain unpadded, unstrided convolution.
  defstruct padding: 0, stride: 1, dilation: 1, groups: 1
end
4 changes: 2 additions & 2 deletions native/candlex/src/kernels.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// PTX kernel sources compiled by build.rs into OUT_DIR.
// The scraped diff showed CUSTOM_BINARY both before and after CUSTOM_UNARY
// (removed + added lines with no markers); this is the resolved post-commit
// state, which simply reorders the two constants (UNARY first).
#[rustfmt::skip]
pub const CUSTOM_UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_unary.ptx"));
#[rustfmt::skip]
pub const CUSTOM_BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_binary.ptx"));
57 changes: 39 additions & 18 deletions native/candlex/src/tensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,35 +337,56 @@ pub fn concatenate(ex_tensors: Vec<ExTensor>, dim: usize) -> Result<ExTensor, Ca
Ok(ExTensor::new(Tensor::cat(&tensors[..], dim)?))
}

// Convolution options decoded from the Elixir `%Candlex.Native.ConvOpts{}`
// struct (field names must match the Elixir defstruct exactly for the
// `NifStruct` derive to decode it).
#[derive(NifStruct)]
#[module = "Candlex.Native.ConvOpts"]
pub struct ConvOpts {
    padding: usize,
    stride: usize,
    dilation: usize,
    groups: usize,
}

// Defaults used when the NIF is called with `options = nil`:
// no padding, unit stride/dilation, a single group — i.e. the same
// behavior the conv NIFs had before options were supported.
impl Default for ConvOpts {
    fn default() -> Self {
        Self {
            padding: 0,
            stride: 1,
            dilation: 1,
            groups: 1,
        }
    }
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn conv1d(tensor: ExTensor, kernel: ExTensor) -> Result<ExTensor, CandlexError> {
let padding = 0;
let stride = 1;
let dilation = 1;
let groups = 1;
pub fn conv1d(
tensor: ExTensor,
kernel: ExTensor,
options: Option<ConvOpts>,
) -> Result<ExTensor, CandlexError> {
let opts = options.unwrap_or_default();

Ok(ExTensor::new(tensor.conv1d(
kernel.deref(),
padding,
stride,
dilation,
groups,
opts.padding,
opts.stride,
opts.dilation,
opts.groups,
)?))
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn conv2d(tensor: ExTensor, kernel: ExTensor) -> Result<ExTensor, CandlexError> {
let padding = 0;
let stride = 1;
let dilation = 1;
let groups = 1;
pub fn conv2d(
tensor: ExTensor,
kernel: ExTensor,
options: Option<ConvOpts>,
) -> Result<ExTensor, CandlexError> {
let opts = options.unwrap_or_default();

Ok(ExTensor::new(tensor.conv2d(
kernel.deref(),
padding,
stride,
dilation,
groups,
opts.padding,
opts.stride,
opts.dilation,
opts.groups,
)?))
}

Expand Down
91 changes: 91 additions & 0 deletions test/candlex_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,97 @@ defmodule CandlexTest do
])
)

Nx.iota({1, 1, 10, 10})
|> Nx.conv(Nx.iota({1, 1, 3, 3}), strides: 3, padding: :same)
|> assert_equal(
t([
[
[
[163.0, 313.0, 412.0, 301.0],
[945.0, 1374.0, 1482.0, 930.0],
[1755.0, 2454.0, 2562.0, 1560.0],
[1057.0, 1369.0, 1414.0, 779.0]
]
]
])
)

Nx.iota({1, 1, 10, 10})
|> Nx.conv(Nx.iota({1, 1, 3, 3}), strides: 2, padding: :same)
|> assert_equal(
t([
[
[
[163.0, 280.0, 346.0, 412.0, 478.0],
[675.0, 978.0, 1050.0, 1122.0, 1194.0],
[1215.0, 1698.0, 1770.0, 1842.0, 1914.0],
[1755.0, 2418.0, 2490.0, 2562.0, 2634.0],
[2295.0, 3138.0, 3210.0, 3282.0, 3354.0]
]
]
])
)

Nx.iota({1, 1, 10, 10})
|> Nx.conv(Nx.iota({2, 1, 3, 3}), strides: 2, padding: :same)
|> assert_equal(
t([
[
[
[163.0, 280.0, 346.0, 412.0, 478.0],
[675.0, 978.0, 1050.0, 1122.0, 1194.0],
[1215.0, 1698.0, 1770.0, 1842.0, 1914.0],
[1755.0, 2418.0, 2490.0, 2562.0, 2634.0],
[2295.0, 3138.0, 3210.0, 3282.0, 3354.0]
],
[
[361.0, 658.0, 832.0, 1006.0, 1180.0],
[1782.0, 2760.0, 2994.0, 3228.0, 3462.0],
[3402.0, 5.1e3, 5334.0, 5568.0, 5802.0],
[5022.0, 7440.0, 7674.0, 7908.0, 8142.0],
[6642.0, 9780.0, 10014.0, 10248.0, 10482.0]
]
]
])
)

Nx.iota({1, 1, 10, 10})
|> Nx.conv(Nx.iota({2, 1, 3, 3}), strides: 4)
|> assert_equal(
t([
[
[
[582.0, 726.0],
[2022.0, 2166.0]
],
[
[1473.0, 1941.0],
[6153.0, 6621.0]
]
]
])
)

Nx.iota({1, 1, 10, 10})
|> Nx.conv(Nx.iota({2, 1, 3, 3}),
strides: 4,
output_permutation: [0, 3, 1, 2]
)
|> assert_equal(
t([
[
[
[582.0, 1473.0],
[726.0, 1941.0]
],
[
[2022.0, 6153.0],
[2166.0, 6621.0]
]
]
])
)

# Nx.iota({1, 1, 3, 3})
# |> Nx.conv(
# Nx.iota({4, 1, 2, 1}),
Expand Down

0 comments on commit b55a50b

Please sign in to comment.