Skip to content

Commit 847d185

Browse files
Mikhail Balakhno authored and drahnr committed
Do not pass batch_size to cudnnGetRNNParamsSize().
1 parent b43fdf9 commit 847d185

File tree

8 files changed

+14
-13
lines changed

8 files changed

+14
-13
lines changed

coaster-nn/src/frameworks/cuda/mod.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -844,14 +844,17 @@ where
844844
fn generate_rnn_weight_description(
845845
&self,
846846
rnn_config: &Self::CRNN,
847-
batch_size: i32,
848847
input_size: i32,
849848
) -> Result<Vec<usize>, Error> {
850849
let cudnn_framework = self.framework().cudnn();
851850
let data_type = <T as DataTypeInfo>::cudnn_data_type();
852851

853-
// MiniBatch, LayerSize, 1
854-
let dim_x = vec![batch_size, input_size, 1];
852+
// According to cuDNN API reference and examples, xDesc should have at
853+
// least 3 dimensions with batch_size being the first. However, weights
854+
// size does not depend on batch size and we'd like to avoid having to
855+
// specify batch size in advance (as it can change during execution).
856+
// So we use batch_size = 1 as it appears to work well.
857+
let dim_x = vec![1, input_size, 1];
855858
let stride_x = vec![dim_x[2] * dim_x[1], dim_x[2], 1];
856859

857860
// dummy desc to get the param size

coaster-nn/src/frameworks/native/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,6 @@ where
890890
fn generate_rnn_weight_description(
891891
&self,
892892
rnn_config: &Self::CRNN,
893-
batch_size: i32,
894893
input_size: i32,
895894
) -> Result<Vec<usize>, Error> {
896895
// This will end up being the tensor descriptor for the weights associated with the RNN pass

coaster-nn/src/plugin.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,6 @@ pub trait Rnn<F>: NN<F> {
335335
fn generate_rnn_weight_description(
336336
&self,
337337
rnn_config: &Self::CRNN,
338-
batch_size: i32,
339338
input_size: i32,
340339
) -> Result<Vec<usize>, crate::co::error::Error>;
341340

coaster-nn/src/tests/rnn.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ where
5757
.unwrap();
5858

5959
let filter_dimensions = backend
60-
.generate_rnn_weight_description(&rnn_config, BATCH_SIZE as i32, INPUT_SIZE as i32)
60+
.generate_rnn_weight_description(&rnn_config, INPUT_SIZE as i32)
6161
.unwrap();
6262

6363
let w = uniformly_random_tensor::<T, F>(

juice-examples/mackey-glass-rnn-regression/README.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,14 @@ Rustflags must be set to link natively to `cuda.lib` and `cudnn.h` in the patter
3939
A generated version of the Mackey-Glass dataset is packaged with Juice, in a format suitable for RNNs.
4040

4141
```bash
42+
cd juice-examples/mackey-glass-rnn-regression
4243
# Train a RNN Network (*nix)
43-
./target/release/example-rnn-regression train --file=SavedRNNNetwork.juice --learningRate=0.01 --batchSize=40
44+
../../target/release/example-rnn-regression train --learning-rate=0.01 --batch-size=40 SavedRNNNetwork.juice
4445
# Train a RNN Network (Windows)
45-
.\target\release\example-rnn-regression.exe train --file=SavedRNNNetwork.juice --learningRate=0.01 --batchSize=40
46+
..\..\target\release\example-rnn-regression.exe train --learning-rate=0.01 --batch-size=40 SavedRNNNetwork.juice
4647

4748
# Test the RNN Network (*nix)
48-
../target/release/example-rnn-regression test --file=SavedRNNNetwork.juice
49+
../../target/release/example-rnn-regression test --batch-size=40 SavedRNNNetwork.juice
4950
# Test the RNN Network (Windows)
50-
cd ../target/release/ && example-rnn-regression.exe test --file=SavedRNNNetwork.juice
51+
..\..\target\release\example-rnn-regression.exe test --batch-size=40 SavedRNNNetwork.juice
5152
```
2.76 KB
Binary file not shown.

juice/src/layers/common/rnn.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ impl<B: IBackend + conn::Rnn<f32>> ILayer<B> for Rnn<B> {
137137
.unwrap();
138138

139139
let filter_dimensions: TensorDesc = backend
140-
.generate_rnn_weight_description(&config, batch_size as i32, input_size as i32)
140+
.generate_rnn_weight_description(&config, input_size as i32)
141141
.unwrap();
142142

143143
// weights
@@ -492,7 +492,6 @@ mod tests {
492492
let filter_dimensions = <Backend<Cuda> as conn::Rnn<f32>>::generate_rnn_weight_description(
493493
&backend,
494494
&config,
495-
BATCH_SIZE as i32,
496495
INPUT_SIZE as i32,
497496
)
498497
.unwrap();

rcudnn/cudnn/src/api/rnn.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ impl API {
166166
) -> Result<::libc::size_t, Error> {
167167
let mut size: ::libc::size_t = 0;
168168
let size_ptr: *mut ::libc::size_t = &mut size;
169-
match cudnnGetRNNParamsSize(handle, rnn_desc,x_desc, size_ptr, data_type) {
169+
match cudnnGetRNNParamsSize(handle, rnn_desc, x_desc, size_ptr, data_type) {
170170
cudnnStatus_t::CUDNN_STATUS_SUCCESS => Ok(size),
171171
cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => Err(Error::BadParam("One of the following; rnnDesc is invalid, x_desc is invalid, x_desc isn't fully packed, dataType & tensor Description type don't match")),
172172
cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => Err(Error::NotSupported("The data type used in `rnn_desc` is not supported for RNN.")),

0 commit comments

Comments
 (0)