-
Notifications
You must be signed in to change notification settings - Fork 489
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
50 changed files
with
3,605 additions
and
11 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
[package] | ||
authors = [ | ||
"nathanielsimard <[email protected]>", | ||
"louisfd <[email protected]", | ||
] | ||
categories = ["science"] | ||
description = "TODO" | ||
edition.workspace = true | ||
keywords = [] | ||
license.workspace = true | ||
name = "burn-cube-macros" | ||
readme.workspace = true | ||
repository = "https://github.com/tracel-ai/burn/tree/main/burn-cube-macros" | ||
version.workspace = true | ||
|
||
[lib] | ||
proc-macro = true | ||
|
||
[features] | ||
default = [] | ||
std = [] | ||
|
||
[dependencies] | ||
proc-macro2 = { workspace = true } | ||
quote = { workspace = true } | ||
syn = { workspace = true } | ||
derive-new = { workspace = true } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
use std::collections::HashMap; | ||
|
||
use syn::{PathArguments, Stmt}; | ||
|
||
use crate::VariableKey; | ||
|
||
#[derive(Debug)] | ||
/// Information about a single variable's use in Cube code | ||
/// Information about a single variable's use in Cube code | ||
/// Useful to figure out when the generated variable will need cloning | ||
pub(crate) struct VariableAnalysis { | ||
num_used: usize, | ||
loop_level_declared: usize, | ||
} | ||
|
||
impl VariableAnalysis { | ||
pub fn should_clone(&mut self, loop_level: usize) -> bool { | ||
if self.num_used > 1 { | ||
self.num_used -= 1; | ||
true | ||
} else { | ||
self.loop_level_declared < loop_level | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug)] | ||
/// Information about all variables in the Cube code, transmitted to codegen | ||
pub(crate) struct CodeAnalysis { | ||
pub variable_analyses: HashMap<VariableKey, VariableAnalysis>, | ||
} | ||
|
||
#[derive(Debug, Default)] | ||
/// Reads the Cube code and accumulates information, to generate a CodeAnalysis artefact | ||
pub(crate) struct CodeAnalysisBuilder { | ||
declarations: Vec<(VariableKey, usize)>, | ||
var_uses: Vec<VariableKey>, | ||
} | ||
|
||
impl CodeAnalysis { | ||
pub fn should_clone(&mut self, ident: &syn::Ident, loop_level: usize) -> bool { | ||
let key: VariableKey = ident.into(); | ||
match self.variable_analyses.remove(&key) { | ||
Some(mut var) => { | ||
let should_clone = var.should_clone(loop_level); | ||
self.variable_analyses.insert(key, var); | ||
should_clone | ||
} | ||
None => panic!("Ident {ident} not part of analysis"), | ||
} | ||
} | ||
|
||
pub fn create(func: &syn::ItemFn) -> CodeAnalysis { | ||
let code_analysis_builder = CodeAnalysisBuilder::default(); | ||
code_analysis_builder.analyze(func) | ||
} | ||
} | ||
|
||
impl CodeAnalysisBuilder { | ||
fn analyze(mut self, func: &syn::ItemFn) -> CodeAnalysis { | ||
// Build the vector of (Id, depth), using recursion | ||
self.signature_declarations(&func.sig); | ||
self.find_occurrences_in_stmts(&func.block.stmts, 0); | ||
|
||
CodeAnalysis { | ||
variable_analyses: self.to_map(), | ||
} | ||
} | ||
|
||
fn to_map(&self) -> HashMap<VariableKey, VariableAnalysis> { | ||
// Run through the vec and build hashmap, without recursion | ||
let mut variable_analyses = HashMap::<VariableKey, VariableAnalysis>::new(); | ||
for declaration in self.declarations.iter() { | ||
let id = declaration.0.clone(); | ||
let new_analysis = match variable_analyses.remove(&id) { | ||
Some(_) => { | ||
panic!("Analysis: Multiple variables with the same identifier is not supported") | ||
} | ||
None => VariableAnalysis { | ||
num_used: 0, | ||
loop_level_declared: declaration.1, | ||
}, | ||
}; | ||
|
||
variable_analyses.insert(id, new_analysis); | ||
} | ||
|
||
for id in self.var_uses.iter() { | ||
let prev_analysis = variable_analyses.remove(id).unwrap_or_else(|| { | ||
panic!( | ||
"Analysis: Variable {:?} should be declared before it's used", | ||
id | ||
) | ||
}); | ||
let new_analysis = VariableAnalysis { | ||
num_used: prev_analysis.num_used + 1, | ||
loop_level_declared: prev_analysis.loop_level_declared, | ||
}; | ||
variable_analyses.insert(id.clone(), new_analysis); | ||
} | ||
|
||
variable_analyses | ||
} | ||
|
||
fn signature_declarations(&mut self, sig: &syn::Signature) { | ||
for input in &sig.inputs { | ||
match input { | ||
syn::FnArg::Typed(pat) => { | ||
let ident = &*pat.pat; | ||
match ident { | ||
syn::Pat::Ident(pat_ident) => { | ||
let id = &pat_ident.ident; | ||
self.declarations.push((id.into(), 0)); | ||
} | ||
_ => todo!("Analysis: unsupported ident {ident:?}"), | ||
} | ||
} | ||
_ => todo!("Analysis: unsupported input {input:?}"), | ||
} | ||
} | ||
} | ||
|
||
fn find_occurrences_in_stmts(&mut self, stmts: &Vec<Stmt>, depth: usize) { | ||
for stmt in stmts { | ||
match stmt { | ||
// Declaration | ||
syn::Stmt::Local(local) => { | ||
let id = match &local.pat { | ||
syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident), | ||
syn::Pat::Type(pat_type) => Some(match &*pat_type.pat { | ||
syn::Pat::Ident(pat_ident) => &pat_ident.ident, | ||
_ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat), | ||
}), | ||
syn::Pat::Wild(_) => None, | ||
_ => todo!("Analysis: unsupported path {:?}", local.pat), | ||
}; | ||
if let Some(id) = id { | ||
self.declarations.push((id.into(), depth)); | ||
} | ||
if let Some(local_init) = &local.init { | ||
self.find_occurrences_in_expr(&local_init.expr, depth) | ||
} | ||
} | ||
syn::Stmt::Expr(expr, _) => self.find_occurrences_in_expr(expr, depth), | ||
_ => todo!("Analysis: unsupported stmt {stmt:?}"), | ||
} | ||
} | ||
} | ||
|
||
fn find_occurrences_in_expr(&mut self, expr: &syn::Expr, depth: usize) { | ||
match expr { | ||
syn::Expr::ForLoop(expr) => { | ||
let depth = depth + 1; | ||
|
||
// Declaration of iterator | ||
if let syn::Pat::Ident(pat_ident) = &*expr.pat { | ||
let id = &pat_ident.ident; | ||
self.declarations.push((id.into(), depth)); | ||
} | ||
|
||
self.find_occurrences_in_stmts(&expr.body.stmts, depth); | ||
} | ||
syn::Expr::While(expr) => { | ||
let depth = depth + 1; | ||
|
||
self.find_occurrences_in_expr(&expr.cond, depth); | ||
self.find_occurrences_in_stmts(&expr.body.stmts, depth); | ||
} | ||
syn::Expr::Loop(expr) => { | ||
let depth = depth + 1; | ||
|
||
self.find_occurrences_in_stmts(&expr.body.stmts, depth); | ||
} | ||
syn::Expr::If(expr) => { | ||
let depth = depth + 1; | ||
|
||
self.find_occurrences_in_expr(&expr.cond, depth); | ||
self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth); | ||
if let Some((_, expr)) = &expr.else_branch { | ||
if let syn::Expr::Block(expr_block) = &**expr { | ||
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth); | ||
} else { | ||
todo!("Analysis: Only block else expr is supported") | ||
} | ||
} | ||
} | ||
syn::Expr::Assign(expr) => { | ||
self.find_occurrences_in_expr(&expr.left, depth); | ||
self.find_occurrences_in_expr(&expr.right, depth); | ||
} | ||
syn::Expr::Index(expr) => { | ||
self.find_occurrences_in_expr(&expr.expr, depth); | ||
self.find_occurrences_in_expr(&expr.index, depth); | ||
} | ||
syn::Expr::Path(expr) => { | ||
let ident = expr | ||
.path | ||
.get_ident() | ||
.expect("Analysis: only ident path are supported."); | ||
|
||
// Use | ||
self.var_uses.push(ident.into()); | ||
} | ||
syn::Expr::Binary(expr) => { | ||
self.find_occurrences_in_expr(&expr.left, depth); | ||
self.find_occurrences_in_expr(&expr.right, depth); | ||
} | ||
syn::Expr::Lit(_) => {} | ||
syn::Expr::Call(expr) => { | ||
match &*expr.func { | ||
syn::Expr::Path(expr_path) => { | ||
if let Some(first_segment) = expr_path.path.segments.first() { | ||
// Check if the path segment has generic arguments | ||
if let PathArguments::AngleBracketed(arguments) = | ||
&first_segment.arguments | ||
{ | ||
// Extract the generic arguments | ||
for arg in &arguments.args { | ||
match arg { | ||
syn::GenericArgument::Type(_) | ||
| syn::GenericArgument::Constraint(_) => {} | ||
_ => todo!("Analysis: Generic {:?} not supported", arg), | ||
} | ||
} | ||
} | ||
} | ||
} | ||
_ => todo!("Analysis: unsupported func expr {:?}", expr.func), | ||
} | ||
for arg in expr.args.iter() { | ||
self.find_occurrences_in_expr(arg, depth); | ||
} | ||
} | ||
syn::Expr::MethodCall(expr) => { | ||
self.find_occurrences_in_expr(&expr.receiver, depth); | ||
for arg in expr.args.iter() { | ||
self.find_occurrences_in_expr(arg, depth); | ||
} | ||
} | ||
syn::Expr::Break(_) => {} | ||
syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth), | ||
_ => todo!("Analysis: unsupported expr {expr:?}"), | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
use proc_macro2::TokenStream; | ||
|
||
use crate::analysis::CodeAnalysis; | ||
|
||
use super::{ | ||
branch::{codegen_break, codegen_for_loop, codegen_if, codegen_loop, codegen_while_loop}, | ||
function::{codegen_call, codegen_closure, codegen_expr_method_call}, | ||
operation::codegen_binary, | ||
variable::{codegen_assign, codegen_index, codegen_lit, codegen_local, codegen_path_rhs}, | ||
}; | ||
|
||
/// Codegen for a statement (generally one line) | ||
/// Entry point of code generation | ||
pub fn codegen_statement( | ||
statement: &syn::Stmt, | ||
loop_level: usize, | ||
variable_analyses: &mut CodeAnalysis, | ||
) -> TokenStream { | ||
match statement { | ||
syn::Stmt::Local(local) => codegen_local(local, loop_level, variable_analyses), | ||
syn::Stmt::Expr(expr, semi) => { | ||
let expr = codegen_expr(expr, loop_level, variable_analyses); | ||
match semi { | ||
Some(_semi) => quote::quote!( | ||
#expr; | ||
), | ||
None => expr, | ||
} | ||
} | ||
_ => todo!("Codegen: statement {statement:?} not supported"), | ||
} | ||
} | ||
|
||
/// Codegen for a code block (a list of statements) | ||
pub(crate) fn codegen_block( | ||
block: &syn::Block, | ||
loop_level: usize, | ||
variable_analyses: &mut CodeAnalysis, | ||
) -> TokenStream { | ||
let mut statements = quote::quote!(); | ||
|
||
for statement in block.stmts.iter() { | ||
statements.extend(codegen_statement(statement, loop_level, variable_analyses)); | ||
} | ||
|
||
quote::quote! { | ||
{ | ||
#statements | ||
} | ||
} | ||
} | ||
|
||
/// Codegen for an expression containing a block | ||
pub(crate) fn codegen_expr_block( | ||
block: &syn::ExprBlock, | ||
loop_level: usize, | ||
variable_analyses: &mut CodeAnalysis, | ||
) -> TokenStream { | ||
codegen_block(&block.block, loop_level, variable_analyses) | ||
} | ||
|
||
/// Codegen for expressions | ||
/// There are many variants of expression, treated differently | ||
pub(crate) fn codegen_expr( | ||
expr: &syn::Expr, | ||
loop_level: usize, | ||
variable_analyses: &mut CodeAnalysis, | ||
) -> TokenStream { | ||
match expr { | ||
syn::Expr::Binary(op) => codegen_binary(op, loop_level, variable_analyses), | ||
syn::Expr::Path(path) => codegen_path_rhs(path, loop_level, variable_analyses), | ||
syn::Expr::Call(call) => codegen_call(call, loop_level, variable_analyses), | ||
syn::Expr::Lit(lit) => codegen_lit(lit), | ||
syn::Expr::Closure(closure) => codegen_closure(closure, loop_level, variable_analyses), | ||
syn::Expr::Block(block) => codegen_expr_block(block, loop_level, variable_analyses), | ||
syn::Expr::Assign(assign) => codegen_assign(assign, loop_level, variable_analyses), | ||
syn::Expr::ForLoop(for_loop) => codegen_for_loop(for_loop, loop_level, variable_analyses), | ||
syn::Expr::While(while_loop) => { | ||
codegen_while_loop(while_loop, loop_level, variable_analyses) | ||
} | ||
syn::Expr::Loop(loop_expr) => codegen_loop(loop_expr, loop_level, variable_analyses), | ||
syn::Expr::Break(_) => codegen_break(), | ||
syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_analyses), | ||
syn::Expr::MethodCall(call) => codegen_expr_method_call(call), | ||
syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses), | ||
syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_analyses), | ||
_ => panic!("Codegen: Unsupported {:?}", expr), | ||
} | ||
} |
Oops, something went wrong.