1
1
pub mod burnbenchapp;
2
2
pub mod persistence;
3
3
4
+ /// Simple parse to retrieve additional argument passed to cargo bench command
5
+ /// We cannot use clap here as clap parser does not allow to have unknown arguments.
6
+ pub fn get_argument < ' a > ( args : & ' a [ String ] , arg_name : & ' a str ) -> Option < & ' a str > {
7
+ let mut i = 0 ;
8
+ while i < args. len ( ) {
9
+ match args[ i] . as_str ( ) {
10
+ arg if arg == arg_name && i + 1 < args. len ( ) => {
11
+ return Some ( & args[ i + 1 ] ) ;
12
+ }
13
+ _ => i += 1 ,
14
+ }
15
+ }
16
+ None
17
+ }
18
+
19
+ /// Specialized function to retrieve the sharing token
20
+ pub fn get_sharing_token ( args : & [ String ] ) -> Option < & str > {
21
+ get_argument ( args, "--sharing-token" )
22
+ }
23
+
24
+ /// Specialized function to retrieve the sharing URL
25
+ pub fn get_sharing_url ( args : & [ String ] ) -> Option < & str > {
26
+ get_argument ( args, "--sharing-url" )
27
+ }
28
+
4
29
#[ macro_export]
5
30
macro_rules! bench_on_backend {
6
31
( ) => {
32
+ use std:: env;
33
+ let args: Vec <String > = env:: args( ) . collect( ) ;
34
+ let url = backend_comparison:: get_sharing_url( & args) ;
35
+ let token = backend_comparison:: get_sharing_token( & args) ;
36
+
7
37
#[ cfg( feature = "wgpu" ) ]
8
38
{
9
39
use burn:: backend:: wgpu:: { AutoGraphicsApi , Wgpu , WgpuDevice } ;
10
40
11
- bench:: <Wgpu <AutoGraphicsApi , f32 , i32 >>( & WgpuDevice :: default ( ) ) ;
41
+ bench:: <Wgpu <AutoGraphicsApi , f32 , i32 >>( & WgpuDevice :: default ( ) , url , token ) ;
12
42
}
13
43
14
44
#[ cfg( feature = "tch-gpu" ) ]
@@ -19,15 +49,15 @@ macro_rules! bench_on_backend {
19
49
let device = LibTorchDevice :: Cuda ( 0 ) ;
20
50
#[ cfg( target_os = "macos" ) ]
21
51
let device = LibTorchDevice :: Mps ;
22
- bench:: <LibTorch >( & device) ;
52
+ bench:: <LibTorch >( & device, url , token ) ;
23
53
}
24
54
25
55
#[ cfg( feature = "tch-cpu" ) ]
26
56
{
27
57
use burn:: backend:: { libtorch:: LibTorchDevice , LibTorch } ;
28
58
29
59
let device = LibTorchDevice :: Cpu ;
30
- bench:: <LibTorch >( & device) ;
60
+ bench:: <LibTorch >( & device, url , token ) ;
31
61
}
32
62
33
63
#[ cfg( any(
@@ -41,7 +71,7 @@ macro_rules! bench_on_backend {
41
71
use burn:: backend:: NdArray ;
42
72
43
73
let device = NdArrayDevice :: Cpu ;
44
- bench:: <NdArray >( & device) ;
74
+ bench:: <NdArray >( & device, url , token ) ;
45
75
}
46
76
47
77
#[ cfg( feature = "candle-cpu" ) ]
@@ -50,7 +80,7 @@ macro_rules! bench_on_backend {
50
80
use burn:: backend:: Candle ;
51
81
52
82
let device = CandleDevice :: Cpu ;
53
- bench:: <Candle >( & device) ;
83
+ bench:: <Candle >( & device, url , token ) ;
54
84
}
55
85
56
86
#[ cfg( feature = "candle-cuda" ) ]
@@ -59,7 +89,7 @@ macro_rules! bench_on_backend {
59
89
use burn:: backend:: Candle ;
60
90
61
91
let device = CandleDevice :: Cuda ( 0 ) ;
62
- bench:: <Candle >( & device) ;
92
+ bench:: <Candle >( & device, url , token ) ;
63
93
}
64
94
65
95
#[ cfg( feature = "candle-metal" ) ]
@@ -68,7 +98,35 @@ macro_rules! bench_on_backend {
68
98
use burn:: backend:: Candle ;
69
99
70
100
let device = CandleDevice :: Metal ( 0 ) ;
71
- bench:: <Candle >( & device) ;
101
+ bench:: <Candle >( & device, url , token ) ;
72
102
}
73
103
} ;
74
104
}
105
+
106
+ #[ cfg( test) ]
107
+ mod tests {
108
+ use super :: * ;
109
+ use rstest:: rstest;
110
+
111
+ #[ rstest]
112
+ #[ case:: sharing_token_argument_with_value( & [ "--sharing-token" , "token123" ] , Some ( "token123" ) ) ]
113
+ #[ case:: sharing_token_argument_no_value( & [ "--sharing-token" ] , None ) ]
114
+ #[ case:: sharing_token_argument_with_additional_arguments( & [ "--other-arg" , "value" , "--sharing-token" , "token789" ] , Some ( "token789" ) ) ]
115
+ #[ case:: other_argument( & [ "--other-arg" , "value" ] , None ) ]
116
+ #[ case:: no_argument( & [ ] , None ) ]
117
+ fn test_get_sharing_token ( #[ case] args : & [ & str ] , #[ case] expected : Option < & str > ) {
118
+ let args = args. iter ( ) . map ( |s| s. to_string ( ) ) . collect :: < Vec < String > > ( ) ;
119
+ assert_eq ! ( get_sharing_token( & args) , expected) ;
120
+ }
121
+
122
+ #[ rstest]
123
+ #[ case:: sharing_url_argument_with_value( & [ "--sharing-url" , "url123" ] , Some ( "url123" ) ) ]
124
+ #[ case:: sharing_url_argument_no_value( & [ "--sharing-url" ] , None ) ]
125
+ #[ case:: sharing_url_argument_with_additional_arguments( & [ "--other-arg" , "value" , "--sharing-url" , "url789" ] , Some ( "url789" ) ) ]
126
+ #[ case:: other_argument( & [ "--other-arg" , "value" ] , None ) ]
127
+ #[ case:: no_argument( & [ ] , None ) ]
128
+ fn test_get_sharing_url ( #[ case] args : & [ & str ] , #[ case] expected : Option < & str > ) {
129
+ let args = args. iter ( ) . map ( |s| s. to_string ( ) ) . collect :: < Vec < String > > ( ) ;
130
+ assert_eq ! ( get_sharing_url( & args) , expected) ;
131
+ }
132
+ }
0 commit comments