Skip to content

Commit

Permalink
CubeCL first iteration (#1756)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored May 15, 2024
1 parent e823338 commit 542790e
Show file tree
Hide file tree
Showing 50 changed files with 3,605 additions and 11 deletions.
20 changes: 20 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions crates/burn-cube-macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
[package]
authors = [
"nathanielsimard <[email protected]>",
"louisfd <[email protected]>",
]
categories = ["science"]
description = "TODO"
edition.workspace = true
keywords = []
license.workspace = true
name = "burn-cube-macros"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cube-macros"
version.workspace = true

[lib]
proc-macro = true

[features]
default = []
std = []

[dependencies]
proc-macro2 = { workspace = true }
quote = { workspace = true }
syn = { workspace = true }
derive-new = { workspace = true }
245 changes: 245 additions & 0 deletions crates/burn-cube-macros/src/analysis.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
use std::collections::HashMap;

use syn::{PathArguments, Stmt};

use crate::VariableKey;

/// Usage information about a single variable in Cube code.
///
/// Consulted during code generation to decide whether a given occurrence of
/// the variable has to be cloned.
#[derive(Debug)]
pub(crate) struct VariableAnalysis {
    /// Number of occurrences of the variable that are still to be generated.
    num_used: usize,
    /// Nesting level at which the variable was declared.
    loop_level_declared: usize,
}

impl VariableAnalysis {
    /// Tells whether the occurrence being generated at `loop_level` must be cloned.
    ///
    /// While more than one use remains, one use is consumed and `true` is
    /// returned. Once a single use (or none) is left, the counter is no longer
    /// touched and the answer is `true` only when `loop_level` is strictly
    /// deeper than the level the variable was declared at.
    pub fn should_clone(&mut self, loop_level: usize) -> bool {
        if self.num_used <= 1 {
            return self.loop_level_declared < loop_level;
        }

        self.num_used -= 1;
        true
    }
}

#[derive(Debug)]
/// Information about all variables in the Cube code, transmitted to codegen
pub(crate) struct CodeAnalysis {
/// Per-variable usage information, keyed by the variable's `VariableKey`
pub variable_analyses: HashMap<VariableKey, VariableAnalysis>,
}

#[derive(Debug, Default)]
/// Reads the Cube code and accumulates information, to generate a CodeAnalysis artefact
pub(crate) struct CodeAnalysisBuilder {
/// One `(variable, depth)` entry per declaration found; function inputs are recorded at depth 0
declarations: Vec<(VariableKey, usize)>,
/// One entry per occurrence of a variable inside an expression, in traversal order
var_uses: Vec<VariableKey>,
}

impl CodeAnalysis {
    /// Tells whether the occurrence of `ident` being generated at `loop_level`
    /// must be cloned, updating the variable's remaining-use bookkeeping.
    ///
    /// # Panics
    ///
    /// Panics if `ident` was not recorded by the analysis.
    pub fn should_clone(&mut self, ident: &syn::Ident, loop_level: usize) -> bool {
        let key: VariableKey = ident.into();
        // Mutate the entry in place instead of removing and re-inserting it.
        match self.variable_analyses.get_mut(&key) {
            Some(var) => var.should_clone(loop_level),
            None => panic!("Ident {ident} not part of analysis"),
        }
    }

    /// Runs the analysis over `func` (signature and body) and returns the result.
    pub fn create(func: &syn::ItemFn) -> CodeAnalysis {
        CodeAnalysisBuilder::default().analyze(func)
    }
}

impl CodeAnalysisBuilder {
/// Walks `func` and produces the finished [`CodeAnalysis`].
///
/// The signature inputs are recorded first, then the body statements are
/// traversed recursively starting at depth 0; the collected declarations and
/// uses are finally folded into the per-variable map.
fn analyze(mut self, func: &syn::ItemFn) -> CodeAnalysis {
    self.signature_declarations(&func.sig);
    self.find_occurrences_in_stmts(&func.block.stmts, 0);

    let variable_analyses = self.to_map();
    CodeAnalysis { variable_analyses }
}

/// Folds the recorded declarations and uses into one `VariableAnalysis` per variable.
///
/// Every declaration creates an entry with zero uses and the depth it was
/// declared at; every recorded use then increments the matching entry.
///
/// # Panics
///
/// Panics if the same identifier is declared more than once, or if a use
/// refers to an identifier that has no declaration.
fn to_map(&self) -> HashMap<VariableKey, VariableAnalysis> {
    let mut variable_analyses =
        HashMap::<VariableKey, VariableAnalysis>::with_capacity(self.declarations.len());

    for (key, depth) in &self.declarations {
        let fresh = VariableAnalysis {
            num_used: 0,
            loop_level_declared: *depth,
        };
        // `insert` hands back the previous value, which reveals a duplicate declaration.
        if variable_analyses.insert(key.clone(), fresh).is_some() {
            panic!("Analysis: Multiple variables with the same identifier is not supported")
        }
    }

    for id in &self.var_uses {
        let analysis = variable_analyses.get_mut(id).unwrap_or_else(|| {
            panic!(
                "Analysis: Variable {:?} should be declared before it's used",
                id
            )
        });
        analysis.num_used += 1;
    }

    variable_analyses
}

/// Records every input of the function signature as a declaration at depth 0.
///
/// Only typed inputs bound to a plain identifier are handled; anything else
/// (for instance a receiver or a destructuring pattern) hits `todo!`.
fn signature_declarations(&mut self, sig: &syn::Signature) {
    for input in sig.inputs.iter() {
        let typed = match input {
            syn::FnArg::Typed(typed) => typed,
            _ => todo!("Analysis: unsupported input {input:?}"),
        };

        let ident = &*typed.pat;
        match ident {
            syn::Pat::Ident(pat_ident) => {
                self.declarations.push(((&pat_ident.ident).into(), 0));
            }
            _ => todo!("Analysis: unsupported ident {ident:?}"),
        }
    }
}

/// Traverses a list of statements, recording declarations and variable uses.
///
/// * `let` statements declare their identifier at `depth` (a wildcard pattern
///   declares nothing) and their initializer, if any, is traversed for uses.
/// * Expression statements are traversed for uses.
/// * Any other statement kind, or any other `let` pattern, hits `todo!`.
fn find_occurrences_in_stmts(&mut self, stmts: &[Stmt], depth: usize) {
    for stmt in stmts {
        match stmt {
            // Declaration
            syn::Stmt::Local(local) => {
                let declared = match &local.pat {
                    syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident),
                    // `let x: T = ...` wraps the identifier pattern in a typed pattern
                    syn::Pat::Type(pat_type) => Some(match &*pat_type.pat {
                        syn::Pat::Ident(pat_ident) => &pat_ident.ident,
                        _ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat),
                    }),
                    syn::Pat::Wild(_) => None,
                    _ => todo!("Analysis: unsupported path {:?}", local.pat),
                };
                if let Some(id) = declared {
                    self.declarations.push((id.into(), depth));
                }
                if let Some(local_init) = &local.init {
                    self.find_occurrences_in_expr(&local_init.expr, depth)
                }
            }
            syn::Stmt::Expr(expr, _) => self.find_occurrences_in_expr(expr, depth),
            _ => todo!("Analysis: unsupported stmt {stmt:?}"),
        }
    }
}

/// Traverses an expression, recording variable uses and nested declarations.
///
/// `for`, `while`, `loop` and `if` bodies are traversed one level deeper than
/// `depth`. A path expression counts as one use of its identifier.
/// Unsupported expression kinds hit `todo!`.
fn find_occurrences_in_expr(&mut self, expr: &syn::Expr, depth: usize) {
    match expr {
        // Leaves that carry no variable information
        syn::Expr::Lit(_) => {}
        syn::Expr::Break(_) => {}

        // Use
        syn::Expr::Path(path_expr) => {
            let ident = path_expr
                .path
                .get_ident()
                .expect("Analysis: only ident path are supported.");
            self.var_uses.push(ident.into());
        }

        // Transparent wrappers and two-operand expressions
        syn::Expr::Paren(paren) => self.find_occurrences_in_expr(&paren.expr, depth),
        syn::Expr::Assign(assign) => {
            self.find_occurrences_in_expr(&assign.left, depth);
            self.find_occurrences_in_expr(&assign.right, depth);
        }
        syn::Expr::Binary(binary) => {
            self.find_occurrences_in_expr(&binary.left, depth);
            self.find_occurrences_in_expr(&binary.right, depth);
        }
        syn::Expr::Index(index) => {
            self.find_occurrences_in_expr(&index.expr, depth);
            self.find_occurrences_in_expr(&index.index, depth);
        }

        // Calls
        syn::Expr::Call(call) => {
            let callee = match &*call.func {
                syn::Expr::Path(callee) => callee,
                _ => todo!("Analysis: unsupported func expr {:?}", call.func),
            };

            // Only type-like generic arguments on the first path segment are accepted
            let generics = callee
                .path
                .segments
                .first()
                .map(|segment| &segment.arguments);
            if let Some(PathArguments::AngleBracketed(generics)) = generics {
                for generic in &generics.args {
                    match generic {
                        syn::GenericArgument::Type(_) | syn::GenericArgument::Constraint(_) => {}
                        _ => todo!("Analysis: Generic {:?} not supported", generic),
                    }
                }
            }

            for arg in call.args.iter() {
                self.find_occurrences_in_expr(arg, depth);
            }
        }
        syn::Expr::MethodCall(method_call) => {
            self.find_occurrences_in_expr(&method_call.receiver, depth);
            for arg in method_call.args.iter() {
                self.find_occurrences_in_expr(arg, depth);
            }
        }

        // Constructs whose body is analysed one level deeper
        syn::Expr::ForLoop(for_loop) => {
            let inner = depth + 1;

            // The loop variable is declared inside the loop.
            // NOTE(review): the iterated expression (`for_loop.expr`) is not traversed
            // here, so identifiers appearing only there are not counted as uses —
            // confirm this is intended.
            if let syn::Pat::Ident(pat_ident) = &*for_loop.pat {
                self.declarations.push(((&pat_ident.ident).into(), inner));
            }

            self.find_occurrences_in_stmts(&for_loop.body.stmts, inner);
        }
        syn::Expr::While(while_loop) => {
            let inner = depth + 1;

            self.find_occurrences_in_expr(&while_loop.cond, inner);
            self.find_occurrences_in_stmts(&while_loop.body.stmts, inner);
        }
        syn::Expr::Loop(infinite_loop) => {
            self.find_occurrences_in_stmts(&infinite_loop.body.stmts, depth + 1);
        }
        syn::Expr::If(if_expr) => {
            let inner = depth + 1;

            self.find_occurrences_in_expr(&if_expr.cond, inner);
            self.find_occurrences_in_stmts(&if_expr.then_branch.stmts, inner);
            if let Some((_, else_expr)) = &if_expr.else_branch {
                match &**else_expr {
                    syn::Expr::Block(else_block) => {
                        self.find_occurrences_in_stmts(&else_block.block.stmts, inner);
                    }
                    _ => todo!("Analysis: Only block else expr is supported"),
                }
            }
        }

        _ => todo!("Analysis: unsupported expr {expr:?}"),
    }
}
}
89 changes: 89 additions & 0 deletions crates/burn-cube-macros/src/codegen/base.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use proc_macro2::TokenStream;

use crate::analysis::CodeAnalysis;

use super::{
branch::{codegen_break, codegen_for_loop, codegen_if, codegen_loop, codegen_while_loop},
function::{codegen_call, codegen_closure, codegen_expr_method_call},
operation::codegen_binary,
variable::{codegen_assign, codegen_index, codegen_lit, codegen_local, codegen_path_rhs},
};

/// Generates the tokens for a single statement (generally one line).
///
/// This is the entry point of code generation: `let` statements go to
/// `codegen_local`, expression statements go to `codegen_expr` and keep their
/// trailing semicolon when the source had one. Other statements hit `todo!`.
pub fn codegen_statement(
    statement: &syn::Stmt,
    loop_level: usize,
    variable_analyses: &mut CodeAnalysis,
) -> TokenStream {
    match statement {
        syn::Stmt::Expr(expr, semi) => {
            let tokens = codegen_expr(expr, loop_level, variable_analyses);
            if semi.is_some() {
                quote::quote!(#tokens;)
            } else {
                tokens
            }
        }
        syn::Stmt::Local(local) => codegen_local(local, loop_level, variable_analyses),
        _ => todo!("Codegen: statement {statement:?} not supported"),
    }
}

/// Generates the tokens for a code block, i.e. a list of statements.
///
/// Each statement is generated in source order at the given `loop_level`, and
/// the concatenated result is wrapped in a pair of braces.
pub(crate) fn codegen_block(
    block: &syn::Block,
    loop_level: usize,
    variable_analyses: &mut CodeAnalysis,
) -> TokenStream {
    let statements: TokenStream = block
        .stmts
        .iter()
        .map(|statement| codegen_statement(statement, loop_level, variable_analyses))
        .collect();

    quote::quote! {
        {
            #statements
        }
    }
}

/// Generates the tokens for a block used in expression position.
///
/// Simply forwards the inner block to `codegen_block`.
pub(crate) fn codegen_expr_block(
    block: &syn::ExprBlock,
    loop_level: usize,
    variable_analyses: &mut CodeAnalysis,
) -> TokenStream {
    let inner = &block.block;
    codegen_block(inner, loop_level, variable_analyses)
}

/// Codegen for expressions
/// There are many variants of expression, treated differently
pub(crate) fn codegen_expr(
expr: &syn::Expr,
loop_level: usize,
variable_analyses: &mut CodeAnalysis,
) -> TokenStream {
match expr {
syn::Expr::Binary(op) => codegen_binary(op, loop_level, variable_analyses),
syn::Expr::Path(path) => codegen_path_rhs(path, loop_level, variable_analyses),
syn::Expr::Call(call) => codegen_call(call, loop_level, variable_analyses),
syn::Expr::Lit(lit) => codegen_lit(lit),
syn::Expr::Closure(closure) => codegen_closure(closure, loop_level, variable_analyses),
syn::Expr::Block(block) => codegen_expr_block(block, loop_level, variable_analyses),
syn::Expr::Assign(assign) => codegen_assign(assign, loop_level, variable_analyses),
syn::Expr::ForLoop(for_loop) => codegen_for_loop(for_loop, loop_level, variable_analyses),
syn::Expr::While(while_loop) => {
codegen_while_loop(while_loop, loop_level, variable_analyses)
}
syn::Expr::Loop(loop_expr) => codegen_loop(loop_expr, loop_level, variable_analyses),
syn::Expr::Break(_) => codegen_break(),
syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_analyses),
syn::Expr::MethodCall(call) => codegen_expr_method_call(call),
syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses),
syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_analyses),
_ => panic!("Codegen: Unsupported {:?}", expr),
}
}
Loading

0 comments on commit 542790e

Please sign in to comment.