Skip to content

Commit 8af2b71

Browse files
Feat: Support trait with CubeCL (#1980)
1 parent c9e9054 commit 8af2b71

File tree

7 files changed

+341
-154
lines changed

7 files changed

+341
-154
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub(crate) mod signature;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
use quote::ToTokens;
2+
3+
use crate::tracker::VariableTracker;
4+
5+
pub fn expand_sig(
6+
sig: &syn::Signature,
7+
visibility: &syn::Visibility,
8+
mut variable_tracker: Option<&mut VariableTracker>,
9+
) -> proc_macro2::TokenStream {
10+
let mut inputs = quote::quote!();
11+
12+
for input in &sig.inputs {
13+
match input {
14+
syn::FnArg::Typed(pat) => {
15+
let ident = pat.pat.clone();
16+
17+
if let syn::Pat::Ident(ident) = ident.as_ref() {
18+
if let Some(vars) = &mut variable_tracker {
19+
vars.codegen_declare(ident.ident.to_string(), 0);
20+
}
21+
}
22+
23+
let ty = no_ref(pat.ty.as_ref());
24+
inputs.extend(quote::quote! {
25+
#ident: <#ty as burn_cube::frontend::CubeType>::ExpandType,
26+
});
27+
}
28+
_ => todo!("Only Typed inputs are supported"),
29+
}
30+
}
31+
32+
let mut output = quote::quote!();
33+
34+
match &sig.output {
35+
syn::ReturnType::Default => output.extend(quote::quote! { ()}),
36+
syn::ReturnType::Type(_, ty) => {
37+
let ty = no_ref(ty.as_ref());
38+
output.extend(quote::quote! {
39+
<#ty as burn_cube::frontend::CubeType>::ExpandType
40+
});
41+
}
42+
}
43+
44+
let ident = &sig.ident;
45+
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
46+
47+
let generics = sig.generics.clone().into_token_stream();
48+
49+
quote::quote! {
50+
/// Expanded Cube function
51+
#visibility fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output
52+
}
53+
}
54+
55+
pub fn no_ref(ty: &syn::Type) -> &syn::Type {
56+
match ty {
57+
syn::Type::Reference(val) => &val.elem,
58+
_ => ty,
59+
}
60+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
use crate::codegen_common::signature::expand_sig;
2+
3+
pub fn expand_trait_def(mut tr: syn::ItemTrait) -> proc_macro2::TokenStream {
4+
let mut expand_items = Vec::new();
5+
6+
for item in tr.items.iter() {
7+
match item {
8+
syn::TraitItem::Fn(func) => {
9+
let expand = expand_sig(&func.sig, &syn::Visibility::Inherited, None);
10+
expand_items.push(syn::parse_quote!(#expand;));
11+
}
12+
_ => continue,
13+
}
14+
}
15+
tr.items.append(&mut expand_items);
16+
17+
quote::quote! {
18+
#tr
19+
}
20+
}
21+
22+
pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream {
23+
let mut expand_items = Vec::new();
24+
25+
for item in tr.items.iter() {
26+
match item {
27+
syn::ImplItem::Fn(func) => {
28+
let ident = &func.sig.ident;
29+
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
30+
let mut inputs = quote::quote!();
31+
32+
for input in &func.sig.inputs {
33+
match input {
34+
syn::FnArg::Typed(pat) => {
35+
let ident = pat.pat.clone();
36+
inputs.extend(quote::quote! {
37+
#ident,
38+
});
39+
}
40+
_ => todo!("Only Typed inputs are supported"),
41+
}
42+
}
43+
44+
let expand = expand_sig(&func.sig, &syn::Visibility::Inherited, None);
45+
46+
let tokens = if !tr.generics.params.is_empty() {
47+
let mut func = func.clone();
48+
for param in tr.generics.params.iter() {
49+
func.sig.generics.params.push(param.clone());
50+
}
51+
register_expand(&func, &ident, expand, inputs)
52+
} else {
53+
register_expand(func, &ident, expand, inputs)
54+
};
55+
56+
expand_items.push(syn::parse2(tokens).unwrap());
57+
}
58+
_ => continue,
59+
}
60+
}
61+
tr.items.append(&mut expand_items);
62+
63+
quote::quote! {
64+
#tr
65+
}
66+
}
67+
68+
fn register_expand(
69+
func: &syn::ImplItemFn,
70+
name: &syn::Ident,
71+
expand: proc_macro2::TokenStream,
72+
inputs: proc_macro2::TokenStream,
73+
) -> proc_macro2::TokenStream {
74+
let (func, func_expand) = if func.sig.generics.params.is_empty() {
75+
(
76+
quote::quote! { #func },
77+
quote::quote! {
78+
#name(context, #inputs)
79+
},
80+
)
81+
} else {
82+
let (_, gen, _) = &func.sig.generics.split_for_impl();
83+
(
84+
quote::quote! { #func },
85+
quote::quote! {
86+
#name::#gen(context, #inputs)
87+
},
88+
)
89+
};
90+
91+
quote::quote! (
92+
#expand {
93+
#[cube]
94+
#func
95+
#func_expand
96+
}
97+
)
98+
}

crates/burn-cube-macros/src/lib.rs

+29-69
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@ extern crate derive_new;
33

44
mod analyzer;
55
mod codegen_function;
6+
mod codegen_trait;
67
mod codegen_type;
78
mod tracker;
89

10+
pub(crate) mod codegen_common;
11+
912
use analyzer::VariableAnalyzer;
13+
use codegen_common::signature::expand_sig;
1014
use codegen_function::{codegen_launch, codegen_statement};
15+
use codegen_trait::{expand_trait_def, expand_trait_impl};
1116
use codegen_type::generate_cube_type;
1217
use proc_macro::TokenStream;
13-
use quote::ToTokens;
1418
use syn::{parse_macro_input, punctuated::Punctuated, token::Comma, Meta};
1519
use tracker::VariableTracker;
1620

@@ -38,20 +42,36 @@ pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
3842
generate_cube_type(&input, false)
3943
}
4044

45+
struct SupportedAttributes {
46+
mode: CubeMode,
47+
launch: bool,
48+
}
49+
4150
/// Derive macro for the module.
4251
#[proc_macro_attribute]
4352
pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
4453
let args = parse_macro_input!(attr with Punctuated::<Meta, syn::Token![,]>::parse_terminated);
45-
let (mode, launch) = parse_attributes(&args);
54+
let attrs = parse_attributes(&args);
55+
56+
let code: TokenStream = match syn::parse::<syn::Item>(tokens).unwrap() {
57+
syn::Item::Fn(func) => cube_fn(func, &attrs),
58+
syn::Item::Impl(item) => expand_trait_impl(item).into(),
59+
syn::Item::Trait(item) => expand_trait_def(item).into(),
60+
_ => panic!("Cube annotations only supported for functions"),
61+
};
4662

47-
let func: syn::ItemFn =
48-
syn::parse(tokens).expect("Cube annotations only supported for functions");
63+
match attrs.mode {
64+
CubeMode::Default => code,
65+
CubeMode::Debug => panic!("{code}"),
66+
}
67+
}
4968

69+
fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream {
5070
let mut variable_tracker = VariableAnalyzer::create_tracker(&func);
5171

52-
let code: TokenStream = match codegen_cube(&func, &mut variable_tracker) {
72+
match codegen_cube(&func, &mut variable_tracker) {
5373
Ok(code) => {
54-
if launch {
74+
if attrs.launch {
5575
let launch = codegen_launch(&func.sig);
5676

5777
quote::quote! {
@@ -64,15 +84,10 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
6484
}
6585
}
6686
Err(err) => err.into(),
67-
};
68-
69-
match mode {
70-
CubeMode::Default => code,
71-
CubeMode::Debug => panic!("{code}"),
7287
}
7388
}
7489

75-
fn parse_attributes(args: &Punctuated<Meta, Comma>) -> (CubeMode, bool) {
90+
fn parse_attributes(args: &Punctuated<Meta, Comma>) -> SupportedAttributes {
7691
let mut mode = CubeMode::Default;
7792
let mut launch = false;
7893

@@ -98,15 +113,15 @@ fn parse_attributes(args: &Punctuated<Meta, Comma>) -> (CubeMode, bool) {
98113
}
99114
}
100115

101-
(mode, launch)
116+
SupportedAttributes { mode, launch }
102117
}
103118

104119
/// Generate the expanded version of a function marked with the cube macro
105120
fn codegen_cube(
106121
func: &syn::ItemFn,
107122
variable_tracker: &mut VariableTracker,
108123
) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
109-
let signature = expand_sig(&func.sig, &func.vis, variable_tracker);
124+
let signature = expand_sig(&func.sig, &func.vis, Some(variable_tracker));
110125
let mut body = quote::quote! {};
111126

112127
for statement in func.block.stmts.iter() {
@@ -145,58 +160,3 @@ fn codegen_cube(
145160
}
146161
})
147162
}
148-
149-
fn expand_sig(
150-
sig: &syn::Signature,
151-
visibility: &syn::Visibility,
152-
variable_tracker: &mut VariableTracker,
153-
) -> proc_macro2::TokenStream {
154-
let mut inputs = quote::quote!();
155-
156-
for input in &sig.inputs {
157-
match input {
158-
syn::FnArg::Typed(pat) => {
159-
let ident = pat.pat.clone();
160-
161-
if let syn::Pat::Ident(ident) = ident.as_ref() {
162-
variable_tracker.codegen_declare(ident.ident.to_string(), 0);
163-
}
164-
165-
let ty = no_ref(pat.ty.as_ref());
166-
inputs.extend(quote::quote! {
167-
#ident: <#ty as burn_cube::frontend::CubeType>::ExpandType,
168-
});
169-
}
170-
_ => todo!("Only Typed inputs are supported"),
171-
}
172-
}
173-
174-
let mut output = quote::quote!();
175-
176-
match &sig.output {
177-
syn::ReturnType::Default => output.extend(quote::quote! { ()}),
178-
syn::ReturnType::Type(_, ty) => {
179-
let ty = no_ref(ty.as_ref());
180-
output.extend(quote::quote! {
181-
<#ty as burn_cube::frontend::CubeType>::ExpandType
182-
});
183-
}
184-
}
185-
186-
let ident = &sig.ident;
187-
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
188-
189-
let generics = sig.generics.clone().into_token_stream();
190-
191-
quote::quote! {
192-
/// Expanded Cube function
193-
#visibility fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output
194-
}
195-
}
196-
197-
fn no_ref(ty: &syn::Type) -> &syn::Type {
198-
match ty {
199-
syn::Type::Reference(val) => &val.elem,
200-
_ => ty,
201-
}
202-
}

0 commit comments

Comments
 (0)