Skip to content

Commit 630044e

Browse files
Fix/devices api (#990)
1 parent 3d6c738 commit 630044e

File tree

10 files changed

+42
-25
lines changed

10 files changed

+42
-25
lines changed

burn-core/src/module/base.rs

+8-3
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,14 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
8484
/// Type to save and load the module.
8585
type Record: Record;
8686

87-
/// Collects devices in the given vector and returns it with the devices found in the module
88-
/// structure without duplicates.
89-
fn devices(&self, devices: Devices<B>) -> Devices<B>;
87+
/// Return all the devices found in the underneath module tree added to the given vector
88+
/// without duplicates.
89+
fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;
90+
91+
/// Return all the devices found in the underneath module tree without duplicates.
92+
fn devices(&self) -> Devices<B> {
93+
self.collect_devices(Devices::<B>::new())
94+
}
9095

9196
/// Fork the module and all of its sub-modules to the given device.
9297
///

burn-core/src/module/param/constant.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ macro_rules! constant {
7575
self
7676
}
7777

78-
fn devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
78+
fn collect_devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
7979
devices
8080
}
8181
};
@@ -147,7 +147,7 @@ impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
147147
self.to_device(device)
148148
}
149149

150-
fn devices(&self, mut devices: Devices<B>) -> Devices<B> {
150+
fn collect_devices(&self, mut devices: Devices<B>) -> Devices<B> {
151151
let device = self.device();
152152

153153
if !devices.contains(&device) {
@@ -195,7 +195,7 @@ impl<B: Backend> Module<B> for PhantomData<B> {
195195
self
196196
}
197197

198-
fn devices(&self, devices: Devices<B>) -> Devices<B> {
198+
fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
199199
devices
200200
}
201201
}

burn-core/src/module/param/primitive.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ where
3737
self.map(|module| module.fork(device))
3838
}
3939

40-
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
40+
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
4141
if let Some(module) = self.as_ref() {
42-
devices = module.devices(devices);
42+
devices = module.collect_devices(devices);
4343
}
4444

4545
devices
@@ -105,9 +105,9 @@ where
105105
self.into_iter().map(|module| module.fork(device)).collect()
106106
}
107107

108-
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
108+
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
109109
for module in self.iter() {
110-
devices = module.devices(devices);
110+
devices = module.collect_devices(devices);
111111
}
112112

113113
devices
@@ -134,9 +134,9 @@ where
134134
{
135135
type Record = [T::Record; N];
136136

137-
fn devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
137+
fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {
138138
for module in self.iter() {
139-
devices = module.devices(devices);
139+
devices = module.collect_devices(devices);
140140
}
141141

142142
devices

burn-core/src/module/param/running.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,10 @@ impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
9595
self.to_device(device) // Same thing here since no grad.
9696
}
9797

98-
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
98+
fn collect_devices(
99+
&self,
100+
mut devices: Vec<<B as Backend>::Device>,
101+
) -> Vec<<B as Backend>::Device> {
99102
let device = self.value.read().unwrap().device();
100103

101104
if !devices.contains(&device) {

burn-core/src/module/param/tensor.rs

+12-3
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
7575
})
7676
}
7777

78-
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
78+
fn collect_devices(
79+
&self,
80+
mut devices: Vec<<B as Backend>::Device>,
81+
) -> Vec<<B as Backend>::Device> {
7982
let device = self.device();
8083

8184
if !devices.contains(&device) {
@@ -122,7 +125,10 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
122125
self.to_device(device) // Don't support autodiff.
123126
}
124127

125-
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
128+
fn collect_devices(
129+
&self,
130+
mut devices: Vec<<B as Backend>::Device>,
131+
) -> Vec<<B as Backend>::Device> {
126132
let device = self.device();
127133

128134
if !devices.contains(&device) {
@@ -169,7 +175,10 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
169175
self.to_device(device) // Don't support autodiff.
170176
}
171177

172-
fn devices(&self, mut devices: Vec<<B as Backend>::Device>) -> Vec<<B as Backend>::Device> {
178+
fn collect_devices(
179+
&self,
180+
mut devices: Vec<<B as Backend>::Device>,
181+
) -> Vec<<B as Backend>::Device> {
173182
let device = self.device();
174183

175184
if !devices.contains(&device) {

burn-derive/src/module/base.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
2929
let num_params_fn = generator.gen_num_params();
3030
let visit = generator.gen_visit();
3131
let map_mut = generator.gen_map();
32-
let devices = generator.gen_devices();
32+
let collect_devices = generator.gen_collect_devices();
3333
let to_device = generator.gen_to_device();
3434
let fork = generator.gen_fork();
3535
let valid_fn = generator.gen_valid();
@@ -54,7 +54,7 @@ pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {
5454
#visit
5555
#map_mut
5656

57-
#devices
57+
#collect_devices
5858
#to_device
5959
#fork
6060
}

burn-derive/src/module/codegen.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use proc_macro2::TokenStream;
44
pub(crate) trait ModuleCodegen {
55
fn gen_num_params(&self) -> TokenStream;
66
fn gen_visit(&self) -> TokenStream;
7-
fn gen_devices(&self) -> TokenStream;
7+
fn gen_collect_devices(&self) -> TokenStream;
88
fn gen_to_device(&self) -> TokenStream;
99
fn gen_fork(&self) -> TokenStream;
1010
fn gen_map(&self) -> TokenStream;

burn-derive/src/module/codegen_struct.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@ impl ModuleCodegen for StructModuleCodegen {
3939
}
4040
}
4141

42-
fn gen_devices(&self) -> TokenStream {
42+
fn gen_collect_devices(&self) -> TokenStream {
4343
let body = self.gen_fields_fn(|name| {
4444
quote! {
45-
let devices = burn::module::Module::<B>::devices(&self.#name, devices);
45+
let devices = burn::module::Module::<B>::collect_devices(&self.#name, devices);
4646
}
4747
});
4848

4949
quote! {
50-
fn devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
50+
fn collect_devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
5151
#body
5252

5353
devices

examples/text-classification/src/model.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ impl<B: Backend> TextClassificationModel<B> {
8888
pub fn forward(&self, item: TextClassificationTrainingBatch<B>) -> ClassificationOutput<B> {
8989
// Get batch and sequence length, and the device
9090
let [batch_size, seq_length] = item.tokens.dims();
91-
let device = &self.embedding_token.devices(Vec::new())[0];
91+
let device = &self.embedding_token.devices()[0];
9292

9393
// Move tensors to the correct device
9494
let tokens = item.tokens.to_device(device);
@@ -128,7 +128,7 @@ impl<B: Backend> TextClassificationModel<B> {
128128
pub fn infer(&self, item: TextClassificationInferenceBatch<B>) -> Tensor<B, 2> {
129129
// Get batch and sequence length, and the device
130130
let [batch_size, seq_length] = item.tokens.dims();
131-
let device = &self.embedding_token.devices(Vec::new())[0];
131+
let device = &self.embedding_token.devices()[0];
132132

133133
// Move tensors to the correct device
134134
let tokens = item.tokens.to_device(device);

examples/text-generation/src/model.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ impl<B: Backend> TextGenerationModel<B> {
5858
item: TrainingTextGenerationBatch<B>,
5959
) -> ClassificationOutput<B> {
6060
let [batch_size, seq_length] = item.tokens_inputs.dims();
61-
let device = &self.devices(Vec::new())[0];
61+
let device = &self.devices()[0];
6262

6363
let inputs = item.tokens_inputs.to_device(device);
6464
let targets = item.targets.to_device(device);

0 commit comments

Comments
 (0)