Skip to content

Commit 6895087

Browse files
authored
[RUST] Add conv3d transpose Rust bindings (#11471)
* Add conv3d transpose Rust bindings * Fix typename * Add base
1 parent 4a769c1 commit 6895087

File tree

2 files changed

+22
-3
lines changed
  • include/tvm/relay/attrs
  • rust/tvm/src/ir/relay/attrs

2 files changed

+22
-3
lines changed

include/tvm/relay/attrs/nn.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,9 @@ struct Conv3DTransposeAttrs : public tvm::AttrsNode<Conv3DTransposeAttrs> {
377377
Array<IndexExpr> output_padding;
378378
Array<IndexExpr> dilation;
379379
int groups;
380-
std::string data_layout;
381-
std::string kernel_layout;
382-
std::string out_layout;
380+
tvm::String data_layout;
381+
tvm::String kernel_layout;
382+
tvm::String out_layout;
383383
DataType out_dtype;
384384

385385
TVM_DECLARE_ATTRS(Conv3DTransposeAttrs, "relay.attrs.Conv3DTransposeAttrs") {

rust/tvm/src/ir/relay/attrs/nn.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,25 @@ pub struct Conv3DAttrsNode {
9494
pub out_dtype: DataType,
9595
}
9696

97+
#[repr(C)]
98+
#[derive(Object, Debug)]
99+
#[ref_name = "Conv3DTransposeAttrs"]
100+
#[type_key = "relay.attrs.Conv3DTransposeAttrs"]
101+
pub struct Conv3DTransposeAttrsNode {
102+
pub base: BaseAttrsNode,
103+
pub channels: IndexExpr,
104+
pub kernel_size: Array<IndexExpr>,
105+
pub strides: Array<IndexExpr>,
106+
pub padding: Array<IndexExpr>,
107+
pub output_padding: Array<IndexExpr>,
108+
pub dilation: Array<IndexExpr>,
109+
pub groups: i32,
110+
pub data_layout: TString,
111+
pub kernel_layout: TString,
112+
pub out_layout: TString,
113+
pub out_dtype: DataType,
114+
}
115+
97116
#[repr(C)]
98117
#[derive(Object, Debug)]
99118
#[ref_name = "BiasAddAttrs"]

0 commit comments

Comments
 (0)