1
+ use std:: path:: { Path , PathBuf } ;
2
+
1
3
use super :: { Checkpointer , CheckpointerError } ;
2
4
use burn_core:: {
3
5
record:: { FileRecorder , Record } ,
@@ -6,7 +8,7 @@ use burn_core::{
6
8
7
9
/// The file checkpointer.
8
10
pub struct FileCheckpointer < FR > {
9
- directory : String ,
11
+ directory : PathBuf ,
10
12
name : String ,
11
13
recorder : FR ,
12
14
}
@@ -19,17 +21,19 @@ impl<FR> FileCheckpointer<FR> {
19
21
/// * `recorder` - The file recorder.
20
22
/// * `directory` - The directory to save the checkpoints.
21
23
/// * `name` - The name of the checkpoint.
22
- pub fn new ( recorder : FR , directory : & str , name : & str ) -> Self {
24
+ pub fn new ( recorder : FR , directory : impl AsRef < Path > , name : & str ) -> Self {
25
+ let directory = directory. as_ref ( ) ;
23
26
std:: fs:: create_dir_all ( directory) . ok ( ) ;
24
27
25
28
Self {
26
- directory : directory. to_string ( ) ,
29
+ directory : directory. to_path_buf ( ) ,
27
30
name : name. to_string ( ) ,
28
31
recorder,
29
32
}
30
33
}
31
- fn path_for_epoch ( & self , epoch : usize ) -> String {
32
- format ! ( "{}/{}-{}" , self . directory, self . name, epoch)
34
+
35
+ fn path_for_epoch ( & self , epoch : usize ) -> PathBuf {
36
+ self . directory . join ( format ! ( "{}-{}" , self . name, epoch) )
33
37
}
34
38
}
35
39
@@ -41,28 +45,36 @@ where
41
45
{
42
46
fn save ( & self , epoch : usize , record : R ) -> Result < ( ) , CheckpointerError > {
43
47
let file_path = self . path_for_epoch ( epoch) ;
44
- log:: info!( "Saving checkpoint {} to {}" , epoch, file_path) ;
48
+ log:: info!( "Saving checkpoint {} to {}" , epoch, file_path. display ( ) ) ;
45
49
46
50
self . recorder
47
- . record ( record, file_path. into ( ) )
51
+ . record ( record, file_path)
48
52
. map_err ( CheckpointerError :: RecorderError ) ?;
49
53
50
54
Ok ( ( ) )
51
55
}
52
56
53
57
fn restore ( & self , epoch : usize , device : & B :: Device ) -> Result < R , CheckpointerError > {
54
58
let file_path = self . path_for_epoch ( epoch) ;
55
- log:: info!( "Restoring checkpoint {} from {}" , epoch, file_path) ;
59
+ log:: info!(
60
+ "Restoring checkpoint {} from {}" ,
61
+ epoch,
62
+ file_path. display( )
63
+ ) ;
56
64
let record = self
57
65
. recorder
58
- . load ( file_path. into ( ) , device)
66
+ . load ( file_path, device)
59
67
. map_err ( CheckpointerError :: RecorderError ) ?;
60
68
61
69
Ok ( record)
62
70
}
63
71
64
72
fn delete ( & self , epoch : usize ) -> Result < ( ) , CheckpointerError > {
65
- let file_to_remove = format ! ( "{}.{}" , self . path_for_epoch( epoch) , FR :: file_extension( ) , ) ;
73
+ let file_to_remove = format ! (
74
+ "{}.{}" ,
75
+ self . path_for_epoch( epoch) . display( ) ,
76
+ FR :: file_extension( ) ,
77
+ ) ;
66
78
67
79
if std:: path:: Path :: new ( & file_to_remove) . exists ( ) {
68
80
log:: info!( "Removing checkpoint {}" , file_to_remove) ;
0 commit comments