Skip to content

Commit ff371de

Browse files
authored
Format all type names (fixes #2324) (#2436)
* Format all type names (fixes #2324) * Fix references
1 parent bb9f5b1 commit ff371de

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

crates/burn-import/src/burn/graph.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
553553
input_names.iter().for_each(|input| {
554554
self.graph_input_types.push(
555555
inputs
556-
.get(&TensorType::format_name(input))
556+
.get(&Type::format_name(input))
557557
.unwrap_or_else(|| panic!("Input type not found for {input}"))
558558
.clone(),
559559
);
@@ -562,7 +562,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
562562
output_names.iter().for_each(|output| {
563563
self.graph_output_types.push(
564564
outputs
565-
.get(&TensorType::format_name(output))
565+
.get(&Type::format_name(output))
566566
.unwrap_or_else(|| panic!("Output type not found for {output}"))
567567
.clone(),
568568
);

crates/burn-import/src/burn/ty.rs

+18-15
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,16 @@ pub enum Type {
6363
}
6464

6565
impl Type {
66+
// This is used, because types might have number literal name, which cannot be
67+
// used as a variable name.
68+
pub fn format_name(name: &str) -> String {
69+
let name_is_number = name.bytes().all(|digit| digit.is_ascii_digit());
70+
if name_is_number {
71+
format!("_{}", name)
72+
} else {
73+
name.to_string()
74+
}
75+
}
6676
pub fn name(&self) -> &Ident {
6777
match self {
6878
Type::Tensor(tensor) => &tensor.name,
@@ -107,8 +117,10 @@ impl ScalarType {
107117
if name.as_ref().is_empty() {
108118
panic!("Scalar of Type {:?} was passed with empty name", kind);
109119
}
120+
121+
let formatted_name = Type::format_name(name.as_ref());
110122
Self {
111-
name: Ident::new(name.as_ref(), Span::call_site()),
123+
name: Ident::new(&formatted_name, Span::call_site()),
112124
kind,
113125
}
114126
}
@@ -150,8 +162,9 @@ impl ShapeType {
150162
if name.as_ref().is_empty() {
151163
panic!("Shape was passed with empty name");
152164
}
165+
let formatted_name = Type::format_name(name.as_ref());
153166
Self {
154-
name: Ident::new(name.as_ref(), Span::call_site()),
167+
name: Ident::new(&formatted_name, Span::call_site()),
155168
dim,
156169
}
157170
}
@@ -173,17 +186,6 @@ impl ShapeType {
173186
}
174187

175188
impl TensorType {
176-
// This is used, because Tensors might have number literal name, which cannot be
177-
// used as a variable name.
178-
pub fn format_name(name: &str) -> String {
179-
let name_is_number = name.bytes().all(|digit| digit.is_ascii_digit());
180-
if name_is_number {
181-
format!("_{}", name)
182-
} else {
183-
name.to_string()
184-
}
185-
}
186-
187189
pub fn new<S: AsRef<str>>(
188190
name: S,
189191
dim: usize,
@@ -196,7 +198,7 @@ impl TensorType {
196198
kind, shape
197199
);
198200
}
199-
let formatted_name = Self::format_name(name.as_ref());
201+
let formatted_name = Type::format_name(name.as_ref());
200202
assert_ne!(
201203
dim, 0,
202204
"Trying to create TensorType with dim = 0 - should be a Scalar instead!"
@@ -277,8 +279,9 @@ impl OtherType {
277279
tokens
278280
);
279281
}
282+
let formatted_name = Type::format_name(name.as_ref());
280283
Self {
281-
name: Ident::new(name.as_ref(), Span::call_site()),
284+
name: Ident::new(&formatted_name, Span::call_site()),
282285
ty: tokens,
283286
}
284287
}

0 commit comments

Comments
 (0)