@@ -3,14 +3,18 @@ extern crate derive_new;
3
3
4
4
mod analyzer;
5
5
mod codegen_function;
6
+ mod codegen_trait;
6
7
mod codegen_type;
7
8
mod tracker;
8
9
10
+ pub ( crate ) mod codegen_common;
11
+
9
12
use analyzer:: VariableAnalyzer ;
13
+ use codegen_common:: signature:: expand_sig;
10
14
use codegen_function:: { codegen_launch, codegen_statement} ;
15
+ use codegen_trait:: { expand_trait_def, expand_trait_impl} ;
11
16
use codegen_type:: generate_cube_type;
12
17
use proc_macro:: TokenStream ;
13
- use quote:: ToTokens ;
14
18
use syn:: { parse_macro_input, punctuated:: Punctuated , token:: Comma , Meta } ;
15
19
use tracker:: VariableTracker ;
16
20
@@ -38,20 +42,36 @@ pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
38
42
generate_cube_type ( & input, false )
39
43
}
40
44
45
+ struct SupportedAttributes {
46
+ mode : CubeMode ,
47
+ launch : bool ,
48
+ }
49
+
41
50
/// Derive macro for the module.
42
51
#[ proc_macro_attribute]
43
52
pub fn cube ( attr : TokenStream , tokens : TokenStream ) -> TokenStream {
44
53
let args = parse_macro_input ! ( attr with Punctuated :: <Meta , syn:: Token ![ , ] >:: parse_terminated) ;
45
- let ( mode, launch) = parse_attributes ( & args) ;
54
+ let attrs = parse_attributes ( & args) ;
55
+
56
+ let code: TokenStream = match syn:: parse :: < syn:: Item > ( tokens) . unwrap ( ) {
57
+ syn:: Item :: Fn ( func) => cube_fn ( func, & attrs) ,
58
+ syn:: Item :: Impl ( item) => expand_trait_impl ( item) . into ( ) ,
59
+ syn:: Item :: Trait ( item) => expand_trait_def ( item) . into ( ) ,
60
+ _ => panic ! ( "Cube annotations only supported for functions" ) ,
61
+ } ;
46
62
47
- let func: syn:: ItemFn =
48
- syn:: parse ( tokens) . expect ( "Cube annotations only supported for functions" ) ;
63
+ match attrs. mode {
64
+ CubeMode :: Default => code,
65
+ CubeMode :: Debug => panic ! ( "{code}" ) ,
66
+ }
67
+ }
49
68
69
+ fn cube_fn ( func : syn:: ItemFn , attrs : & SupportedAttributes ) -> TokenStream {
50
70
let mut variable_tracker = VariableAnalyzer :: create_tracker ( & func) ;
51
71
52
- let code : TokenStream = match codegen_cube ( & func, & mut variable_tracker) {
72
+ match codegen_cube ( & func, & mut variable_tracker) {
53
73
Ok ( code) => {
54
- if launch {
74
+ if attrs . launch {
55
75
let launch = codegen_launch ( & func. sig ) ;
56
76
57
77
quote:: quote! {
@@ -64,15 +84,10 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
64
84
}
65
85
}
66
86
Err ( err) => err. into ( ) ,
67
- } ;
68
-
69
- match mode {
70
- CubeMode :: Default => code,
71
- CubeMode :: Debug => panic ! ( "{code}" ) ,
72
87
}
73
88
}
74
89
75
- fn parse_attributes ( args : & Punctuated < Meta , Comma > ) -> ( CubeMode , bool ) {
90
+ fn parse_attributes ( args : & Punctuated < Meta , Comma > ) -> SupportedAttributes {
76
91
let mut mode = CubeMode :: Default ;
77
92
let mut launch = false ;
78
93
@@ -98,15 +113,15 @@ fn parse_attributes(args: &Punctuated<Meta, Comma>) -> (CubeMode, bool) {
98
113
}
99
114
}
100
115
101
- ( mode, launch)
116
+ SupportedAttributes { mode, launch }
102
117
}
103
118
104
119
/// Generate the expanded version of a function marked with the cube macro
105
120
fn codegen_cube (
106
121
func : & syn:: ItemFn ,
107
122
variable_tracker : & mut VariableTracker ,
108
123
) -> Result < proc_macro2:: TokenStream , proc_macro2:: TokenStream > {
109
- let signature = expand_sig ( & func. sig , & func. vis , variable_tracker) ;
124
+ let signature = expand_sig ( & func. sig , & func. vis , Some ( variable_tracker) ) ;
110
125
let mut body = quote:: quote! { } ;
111
126
112
127
for statement in func. block . stmts . iter ( ) {
@@ -145,58 +160,3 @@ fn codegen_cube(
145
160
}
146
161
} )
147
162
}
148
-
149
- fn expand_sig (
150
- sig : & syn:: Signature ,
151
- visibility : & syn:: Visibility ,
152
- variable_tracker : & mut VariableTracker ,
153
- ) -> proc_macro2:: TokenStream {
154
- let mut inputs = quote:: quote!( ) ;
155
-
156
- for input in & sig. inputs {
157
- match input {
158
- syn:: FnArg :: Typed ( pat) => {
159
- let ident = pat. pat . clone ( ) ;
160
-
161
- if let syn:: Pat :: Ident ( ident) = ident. as_ref ( ) {
162
- variable_tracker. codegen_declare ( ident. ident . to_string ( ) , 0 ) ;
163
- }
164
-
165
- let ty = no_ref ( pat. ty . as_ref ( ) ) ;
166
- inputs. extend ( quote:: quote! {
167
- #ident: <#ty as burn_cube:: frontend:: CubeType >:: ExpandType ,
168
- } ) ;
169
- }
170
- _ => todo ! ( "Only Typed inputs are supported" ) ,
171
- }
172
- }
173
-
174
- let mut output = quote:: quote!( ) ;
175
-
176
- match & sig. output {
177
- syn:: ReturnType :: Default => output. extend ( quote:: quote! { ( ) } ) ,
178
- syn:: ReturnType :: Type ( _, ty) => {
179
- let ty = no_ref ( ty. as_ref ( ) ) ;
180
- output. extend ( quote:: quote! {
181
- <#ty as burn_cube:: frontend:: CubeType >:: ExpandType
182
- } ) ;
183
- }
184
- }
185
-
186
- let ident = & sig. ident ;
187
- let ident = syn:: Ident :: new ( format ! ( "{ident}_expand" ) . as_str ( ) , ident. span ( ) ) ;
188
-
189
- let generics = sig. generics . clone ( ) . into_token_stream ( ) ;
190
-
191
- quote:: quote! {
192
- /// Expanded Cube function
193
- #visibility fn #ident #generics ( context: & mut burn_cube:: frontend:: CubeContext , #inputs) -> #output
194
- }
195
- }
196
-
197
- fn no_ref ( ty : & syn:: Type ) -> & syn:: Type {
198
- match ty {
199
- syn:: Type :: Reference ( val) => & val. elem ,
200
- _ => ty,
201
- }
202
- }
0 commit comments