Skip to content

Commit

Permalink
Fix use of Box instead of Arc, closes #5
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Nov 27, 2023
1 parent 273eef4 commit 8b516d4
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions candle-lora-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
continue;
}
let typname = &ty.path.segments.first().unwrap().ident;
if is_ident(typname, "Box") {
if is_ident(typname, "Arc") {
if let syn::PathArguments::AngleBracketed(bracketed) =
&ty.path.segments.first().as_ref().unwrap().arguments
{
Expand Down Expand Up @@ -241,7 +241,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
}
continue;
}
if !is_ident(&segments[0].ident, "Box") {
if !is_ident(&segments[0].ident, "Arc") {
continue;
}
if let syn::PathArguments::AngleBracketed(bracketed) =
Expand Down Expand Up @@ -334,7 +334,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
if !linear_fields.is_empty() {
quote_into::quote_into!(linear_stream_assign += [#{
for (name, n) in linear_fields.iter() {
linear_stream_assign.extend(quote::quote!((self.#name = ::std::boxed::Box::new(new_layers.linear.get(#n).unwrap().clone())),))
linear_stream_assign.extend(quote::quote!((self.#name = ::std::sync::Arc::new(new_layers.linear.get(#n).unwrap().clone())),))
}
}];);
}
Expand All @@ -345,7 +345,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
for (name, n) in linear_fields.iter() {
linear_merge_stream_assign.extend(quote::quote!(({
(new_layers.linear.get_mut(#n).unwrap().clone()).merge_weights().expect("Merge failed for linear.");
self.#name = ::std::boxed::Box::new(new_layers.linear.get(#n).unwrap().clone())
self.#name = ::std::sync::Arc::new(new_layers.linear.get(#n).unwrap().clone())
}),))
}
}];);
Expand All @@ -355,7 +355,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
if !conv1d_fields.is_empty() {
quote_into::quote_into!(conv1d_stream_assign += [#{
for (name, n) in conv1d_fields.iter() {
conv1d_stream_assign.extend(quote::quote!((self.#name = ::std::boxed::Box::new(new_layers.conv1d.get(#n).unwrap().clone())),))
conv1d_stream_assign.extend(quote::quote!((self.#name = ::std::sync::Arc::new(new_layers.conv1d.get(#n).unwrap().clone())),))
}
}];);
}
Expand All @@ -366,7 +366,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
for (name, n) in conv1d_fields.iter() {
conv1d_merge_stream_assign.extend(quote::quote!(({
(new_layers.conv1d.get_mut(#n).unwrap().clone()).merge_weights().expect("Merge failed for conv1d.");
self.#name = ::std::boxed::Box::new(new_layers.conv1d.get(#n).unwrap().clone())
self.#name = ::std::sync::Arc::new(new_layers.conv1d.get(#n).unwrap().clone())
}),))
}
}];);
Expand All @@ -376,7 +376,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
if !conv2d_fields.is_empty() {
quote_into::quote_into!(conv2d_stream_assign += [#{
for (name, n) in conv2d_fields.iter() {
conv2d_stream_assign.extend(quote::quote!((self.#name = ::std::boxed::Box::new(new_layers.conv2d.get(#n).unwrap().clone())),))
conv2d_stream_assign.extend(quote::quote!((self.#name = ::std::sync::Arc::new(new_layers.conv2d.get(#n).unwrap().clone())),))
}
}];);
}
Expand All @@ -387,7 +387,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
for (name, n) in conv2d_fields.iter() {
conv2d_merge_stream_assign.extend(quote::quote!(({
(new_layers.conv2d.get_mut(#n).unwrap().clone()).merge_weights().expect("Merge failed for conv2d.");
self.#name = ::std::boxed::Box::new(new_layers.conv2d.get(#n).unwrap().clone())
self.#name = ::std::sync::Arc::new(new_layers.conv2d.get(#n).unwrap().clone())
}),))
}
}];);
Expand All @@ -397,7 +397,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
if !embed_fields.is_empty() {
quote_into::quote_into!(embed_stream_assign += [#{
for (name, n) in embed_fields.iter() {
embed_stream_assign.extend(quote::quote!((self.#name = ::std::boxed::Box::new(new_layers.embed.get(#n).unwrap().clone())),))
embed_stream_assign.extend(quote::quote!((self.#name = ::std::sync::Arc::new(new_layers.embed.get(#n).unwrap().clone())),))
}
}];);
}
Expand All @@ -408,7 +408,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
for (name, n) in embed_fields.iter() {
embed_merge_stream_assign.extend(quote::quote!(({
(new_layers.embed.get_mut(#n).unwrap().clone()).merge_weights().expect("Merge failed for embed.");
self.#name = ::std::boxed::Box::new(new_layers.embed.get(#n).unwrap().clone())
self.#name = ::std::sync::Arc::new(new_layers.embed.get(#n).unwrap().clone())
}),))
}
}];);
Expand Down Expand Up @@ -454,7 +454,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
if !linear_option1_fields.is_empty() {
quote_into::quote_into!(linear_option1_stream_assign += [#{
for (name, n) in linear_option1_fields.iter() {
linear_option1_stream_assign.extend(quote::quote!((self.#name = ::std::option::Option::Some(::std::boxed::Box::new(new_layers.linear.get(#n).unwrap().clone()))),))
linear_option1_stream_assign.extend(quote::quote!((self.#name = ::std::option::Option::Some(::std::sync::Arc::new(new_layers.linear.get(#n).unwrap().clone()))),))
}
}];);
}
Expand All @@ -465,7 +465,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
for (name, n) in linear_option1_fields.iter() {
linear_merge_option1_stream_assign.extend(quote::quote!(({
(new_layers.linear.get_mut(#n).unwrap().clone()).merge_weights().expect("Merge failed for option linear.");
self.#name = ::std::option::Option::Some(::std::boxed::Box::new(new_layers.linear.get(#n).unwrap().clone()))
self.#name = ::std::option::Option::Some(::std::sync::Arc::new(new_layers.linear.get(#n).unwrap().clone()))
}),))
}
}];);
Expand All @@ -475,7 +475,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
if !conv1d_option1_fields.is_empty() {
quote_into::quote_into!(conv1d_option1_stream_assign += [#{
for (name, n) in conv1d_option1_fields.iter() {
conv1d_option1_stream_assign.extend(quote::quote!((self.#name = ::std::option::Option::Some(::std::boxed::Box::new(new_layers.conv1d.get(#n).unwrap().clone()))),))
conv1d_option1_stream_assign.extend(quote::quote!((self.#name = ::std::option::Option::Some(::std::sync::Arc::new(new_layers.conv1d.get(#n).unwrap().clone()))),))
}
}];);
}
Expand All @@ -486,7 +486,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
for (name, n) in conv1d_option1_fields.iter() {
conv1d_merge_option1_stream_assign.extend(quote::quote!(({
(new_layers.conv1d.get_mut(#n).unwrap().clone()).merge_weights().expect("Merge failed for option conv1d.");
self.#name = ::std::option::Option::Some(::std::boxed::Box::new(new_layers.conv1d.get(#n).unwrap().clone()))
self.#name = ::std::option::Option::Some(::std::sync::Arc::new(new_layers.conv1d.get(#n).unwrap().clone()))
}),))
}
}];);
Expand All @@ -496,7 +496,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
if !conv2d_option1_fields.is_empty() {
quote_into::quote_into!(conv2d_option1_stream_assign += [#{
for (name, n) in conv2d_option1_fields.iter() {
conv2d_option1_stream_assign.extend(quote::quote!((self.#name = ::std::option::Option::Some(::std::boxed::Box::new(new_layers.conv2d.get(#n).unwrap().clone()))),))
conv2d_option1_stream_assign.extend(quote::quote!((self.#name = ::std::option::Option::Some(::std::sync::Arc::new(new_layers.conv2d.get(#n).unwrap().clone()))),))
}
}];);
}
Expand All @@ -507,7 +507,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
for (name, n) in conv2d_option1_fields.iter() {
conv2d_merge_option1_stream_assign.extend(quote::quote!(({
(new_layers.conv2d.get_mut(#n).unwrap().clone()).merge_weights().expect("Merge failed for option conv2d.");
self.#name = ::std::option::Option::Some(::std::boxed::Box::new(new_layers.conv2d.get(#n).unwrap().clone()))
self.#name = ::std::option::Option::Some(::std::sync::Arc::new(new_layers.conv2d.get(#n).unwrap().clone()))
}),))
}
}];);
Expand All @@ -517,7 +517,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
if !embed_option1_fields.is_empty() {
quote_into::quote_into!(embed_option1_stream_assign += [#{
for (name, n) in embed_option1_fields.iter() {
embed_option1_stream_assign.extend(quote::quote!((self.#name = ::std::option::Option::Some(::std::boxed::Box::new(new_layers.embed.get(#n).unwrap().clone()))),))
embed_option1_stream_assign.extend(quote::quote!((self.#name = ::std::option::Option::Some(::std::sync::Arc::new(new_layers.embed.get(#n).unwrap().clone()))),))
}
}];);
}
Expand All @@ -528,7 +528,7 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
for (name, n) in embed_option1_fields.iter() {
embed_merge_option1_stream_assign.extend(quote::quote!(({
(new_layers.embed.get_mut(#n).unwrap().clone()).merge_weights().expect("Merge failed for option embed.");
self.#name = ::std::option::Option::Some(::std::boxed::Box::new(new_layers.embed.get(#n).unwrap().clone()))
self.#name = ::std::option::Option::Some(::std::sync::Arc::new(new_layers.embed.get(#n).unwrap().clone()))
}),))
}
}];);
Expand Down

0 comments on commit 8b516d4

Please sign in to comment.