Skip to content

Commit a7efc10

Browse files
authored
Replaced str with Path (#1919)
* replaced str with Path * minor change (Path to AsRef<Path>) * fixed clippy lint
1 parent 98a58c8 commit a7efc10

File tree

6 files changed

+90
-66
lines changed

6 files changed

+90
-66
lines changed

crates/burn-train/src/checkpoint/file.rs

+22-10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::path::{Path, PathBuf};
2+
13
use super::{Checkpointer, CheckpointerError};
24
use burn_core::{
35
record::{FileRecorder, Record},
@@ -6,7 +8,7 @@ use burn_core::{
68

79
/// The file checkpointer.
810
pub struct FileCheckpointer<FR> {
9-
directory: String,
11+
directory: PathBuf,
1012
name: String,
1113
recorder: FR,
1214
}
@@ -19,17 +21,19 @@ impl<FR> FileCheckpointer<FR> {
1921
/// * `recorder` - The file recorder.
2022
/// * `directory` - The directory to save the checkpoints.
2123
/// * `name` - The name of the checkpoint.
22-
pub fn new(recorder: FR, directory: &str, name: &str) -> Self {
24+
pub fn new(recorder: FR, directory: impl AsRef<Path>, name: &str) -> Self {
25+
let directory = directory.as_ref();
2326
std::fs::create_dir_all(directory).ok();
2427

2528
Self {
26-
directory: directory.to_string(),
29+
directory: directory.to_path_buf(),
2730
name: name.to_string(),
2831
recorder,
2932
}
3033
}
31-
fn path_for_epoch(&self, epoch: usize) -> String {
32-
format!("{}/{}-{}", self.directory, self.name, epoch)
34+
35+
fn path_for_epoch(&self, epoch: usize) -> PathBuf {
36+
self.directory.join(format!("{}-{}", self.name, epoch))
3337
}
3438
}
3539

@@ -41,28 +45,36 @@ where
4145
{
4246
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
4347
let file_path = self.path_for_epoch(epoch);
44-
log::info!("Saving checkpoint {} to {}", epoch, file_path);
48+
log::info!("Saving checkpoint {} to {}", epoch, file_path.display());
4549

4650
self.recorder
47-
.record(record, file_path.into())
51+
.record(record, file_path)
4852
.map_err(CheckpointerError::RecorderError)?;
4953

5054
Ok(())
5155
}
5256

5357
fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
5458
let file_path = self.path_for_epoch(epoch);
55-
log::info!("Restoring checkpoint {} from {}", epoch, file_path);
59+
log::info!(
60+
"Restoring checkpoint {} from {}",
61+
epoch,
62+
file_path.display()
63+
);
5664
let record = self
5765
.recorder
58-
.load(file_path.into(), device)
66+
.load(file_path, device)
5967
.map_err(CheckpointerError::RecorderError)?;
6068

6169
Ok(record)
6270
}
6371

6472
fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
65-
let file_to_remove = format!("{}.{}", self.path_for_epoch(epoch), FR::file_extension(),);
73+
let file_to_remove = format!(
74+
"{}.{}",
75+
self.path_for_epoch(epoch).display(),
76+
FR::file_extension(),
77+
);
6678

6779
if std::path::Path::new(&file_to_remove).exists() {
6880
log::info!("Removing checkpoint {}", file_to_remove);

crates/burn-train/src/learner/application_logger.rs

+10-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::path::Path;
1+
use std::path::{Path, PathBuf};
22
use tracing_core::{Level, LevelFilter};
33
use tracing_subscriber::filter::filter_fn;
44
use tracing_subscriber::prelude::*;
@@ -12,14 +12,14 @@ pub trait ApplicationLoggerInstaller {
1212

1313
/// This struct is used to install a local file application logger to output logs to a given file path.
1414
pub struct FileApplicationLoggerInstaller {
15-
path: String,
15+
path: PathBuf,
1616
}
1717

1818
impl FileApplicationLoggerInstaller {
1919
/// Create a new file application logger.
20-
pub fn new(path: &str) -> Self {
20+
pub fn new(path: impl AsRef<Path>) -> Self {
2121
Self {
22-
path: path.to_string(),
22+
path: path.as_ref().to_path_buf(),
2323
}
2424
}
2525
}
@@ -29,8 +29,9 @@ impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
2929
let path = Path::new(&self.path);
3030
let writer = tracing_appender::rolling::never(
3131
path.parent().unwrap_or_else(|| Path::new(".")),
32-
path.file_name()
33-
.unwrap_or_else(|| panic!("The path '{}' to point to a file.", self.path)),
32+
path.file_name().unwrap_or_else(|| {
33+
panic!("The path '{}' to point to a file.", self.path.display())
34+
}),
3435
);
3536
let layer = tracing_subscriber::fmt::layer()
3637
.with_ansi(false)
@@ -51,13 +52,14 @@ impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
5152
}
5253

5354
let hook = std::panic::take_hook();
54-
let file_path: String = self.path.to_owned();
55+
let file_path = self.path.to_owned();
5556

5657
std::panic::set_hook(Box::new(move |info| {
5758
log::error!("PANIC => {}", info.to_string());
5859
eprintln!(
5960
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
60-
'{file_path}'\n============="
61+
'{}'\n=============",
62+
file_path.display()
6163
);
6264
hook(info);
6365
}));

crates/burn-train/src/learner/builder.rs

+15-26
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::collections::HashSet;
2+
use std::path::{Path, PathBuf};
23
use std::rc::Rc;
34

45
use super::Learner;
@@ -45,7 +46,7 @@ where
4546
)>,
4647
num_epochs: usize,
4748
checkpoint: Option<usize>,
48-
directory: String,
49+
directory: PathBuf,
4950
grad_accumulation: Option<usize>,
5051
devices: Vec<B::Device>,
5152
renderer: Option<Box<dyn MetricsRenderer + 'static>>,
@@ -74,20 +75,22 @@ where
7475
/// # Arguments
7576
///
7677
/// * `directory` - The directory to save the checkpoints.
77-
pub fn new(directory: &str) -> Self {
78+
pub fn new(directory: impl AsRef<Path>) -> Self {
79+
let directory = directory.as_ref().to_path_buf();
80+
let experiment_log_file = directory.join("experiment.log");
7881
Self {
7982
num_epochs: 1,
8083
checkpoint: None,
8184
checkpointers: None,
82-
directory: directory.to_string(),
85+
directory,
8386
grad_accumulation: None,
8487
devices: vec![B::Device::default()],
8588
metrics: Metrics::default(),
8689
event_store: LogEventStore::default(),
8790
renderer: None,
8891
interrupter: TrainingInterrupter::new(),
8992
tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
90-
format!("{}/experiment.log", directory).as_str(),
93+
experiment_log_file,
9194
))),
9295
num_loggers: 0,
9396
checkpointer_strategy: Box::new(
@@ -256,21 +259,12 @@ where
256259
M::Record: 'static,
257260
S::Record: 'static,
258261
{
259-
let checkpointer_model = FileCheckpointer::new(
260-
recorder.clone(),
261-
format!("{}/checkpoint", self.directory).as_str(),
262-
"model",
263-
);
264-
let checkpointer_optimizer = FileCheckpointer::new(
265-
recorder.clone(),
266-
format!("{}/checkpoint", self.directory).as_str(),
267-
"optim",
268-
);
269-
let checkpointer_scheduler: FileCheckpointer<FR> = FileCheckpointer::new(
270-
recorder,
271-
format!("{}/checkpoint", self.directory).as_str(),
272-
"scheduler",
273-
);
262+
let checkpoint_dir = self.directory.join("checkpoint");
263+
let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
264+
let checkpointer_optimizer =
265+
FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
266+
let checkpointer_scheduler: FileCheckpointer<FR> =
267+
FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");
274268

275269
self.checkpointers = Some((
276270
AsyncCheckpointer::new(checkpointer_model),
@@ -325,17 +319,12 @@ where
325319
let renderer = self.renderer.unwrap_or_else(|| {
326320
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
327321
});
328-
let directory = &self.directory;
329322

330323
if self.num_loggers == 0 {
331324
self.event_store
332-
.register_logger_train(FileMetricLogger::new(
333-
format!("{directory}/train").as_str(),
334-
));
325+
.register_logger_train(FileMetricLogger::new(self.directory.join("train")));
335326
self.event_store
336-
.register_logger_valid(FileMetricLogger::new(
337-
format!("{directory}/valid").as_str(),
338-
));
327+
.register_logger_valid(FileMetricLogger::new(self.directory.join("valid")));
339328
}
340329

341330
let event_store = Rc::new(EventStoreClient::new(self.event_store));

crates/burn-train/src/learner/summary.rs

+16-9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use core::cmp::Ordering;
2-
use std::{fmt::Display, path::Path};
2+
use std::{
3+
fmt::Display,
4+
path::{Path, PathBuf},
5+
};
36

47
use crate::{
58
logger::FileMetricLogger,
@@ -73,16 +76,20 @@ impl LearnerSummary {
7376
///
7477
/// * `directory` - The directory containing the training artifacts (checkpoints and logs).
7578
/// * `metrics` - The list of metrics to collect for the summary.
76-
pub fn new<S: AsRef<str>>(directory: &str, metrics: &[S]) -> Result<Self, String> {
77-
let directory_path = Path::new(directory);
78-
if !directory_path.exists() {
79-
return Err(format!("Artifact directory does not exist at: {directory}"));
79+
pub fn new<S: AsRef<str>>(directory: impl AsRef<Path>, metrics: &[S]) -> Result<Self, String> {
80+
let directory = directory.as_ref();
81+
if !directory.exists() {
82+
return Err(format!(
83+
"Artifact directory does not exist at: {}",
84+
directory.display()
85+
));
8086
}
81-
let train_dir = directory_path.join("train");
82-
let valid_dir = directory_path.join("valid");
87+
let train_dir = directory.join("train");
88+
let valid_dir = directory.join("valid");
8389
if !train_dir.exists() & !valid_dir.exists() {
8490
return Err(format!(
85-
"No training or validation artifacts found at: {directory}"
91+
"No training or validation artifacts found at: {}",
92+
directory.display()
8693
));
8794
}
8895

@@ -219,7 +226,7 @@ impl Display for LearnerSummary {
219226
}
220227

221228
pub(crate) struct LearnerSummaryConfig {
222-
pub(crate) directory: String,
229+
pub(crate) directory: PathBuf,
223230
pub(crate) metrics: Vec<String>,
224231
}
225232

crates/burn-train/src/logger/file.rs

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::Logger;
2-
use std::{fs::File, io::Write};
2+
use std::{fs::File, io::Write, path::Path};
33

44
/// File logger.
55
pub struct FileLogger {
@@ -16,14 +16,21 @@ impl FileLogger {
1616
/// # Returns
1717
///
1818
/// The file logger.
19-
pub fn new(path: &str) -> Self {
19+
pub fn new(path: impl AsRef<Path>) -> Self {
20+
let path = path.as_ref();
2021
let mut options = std::fs::File::options();
2122
let file = options
2223
.write(true)
2324
.truncate(true)
2425
.create(true)
2526
.open(path)
26-
.unwrap_or_else(|err| panic!("Should be able to create the new file '{path}': {err}"));
27+
.unwrap_or_else(|err| {
28+
panic!(
29+
"Should be able to create the new file '{}': {}",
30+
path.display(),
31+
err
32+
)
33+
});
2734

2835
Self { file }
2936
}

crates/burn-train/src/logger/metric.rs

+17-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
22
use crate::metric::{MetricEntry, NumericEntry};
3-
use std::{collections::HashMap, fs};
3+
use std::{
4+
collections::HashMap,
5+
fs,
6+
path::{Path, PathBuf},
7+
};
48

59
const EPOCH_PREFIX: &str = "epoch-";
610

@@ -27,7 +31,7 @@ pub trait MetricLogger: Send {
2731
/// The file metric logger.
2832
pub struct FileMetricLogger {
2933
loggers: HashMap<String, AsyncLogger<String>>,
30-
directory: String,
34+
directory: PathBuf,
3135
epoch: usize,
3236
}
3337

@@ -41,10 +45,10 @@ impl FileMetricLogger {
4145
/// # Returns
4246
///
4347
/// The file metric logger.
44-
pub fn new(directory: &str) -> Self {
48+
pub fn new(directory: impl AsRef<Path>) -> Self {
4549
Self {
4650
loggers: HashMap::new(),
47-
directory: directory.to_string(),
51+
directory: directory.as_ref().to_path_buf(),
4852
epoch: 1,
4953
}
5054
}
@@ -76,15 +80,18 @@ impl FileMetricLogger {
7680
max_epoch
7781
}
7882

79-
fn epoch_directory(&self, epoch: usize) -> String {
80-
format!("{}/{}{}", self.directory, EPOCH_PREFIX, epoch)
83+
fn epoch_directory(&self, epoch: usize) -> PathBuf {
84+
let name = format!("{}{}", EPOCH_PREFIX, epoch);
85+
self.directory.join(name)
8186
}
82-
fn file_path(&self, name: &str, epoch: usize) -> String {
87+
88+
fn file_path(&self, name: &str, epoch: usize) -> PathBuf {
8389
let directory = self.epoch_directory(epoch);
8490
let name = name.replace(' ', "_");
85-
86-
format!("{directory}/{name}.log")
91+
let name = format!("{name}.log");
92+
directory.join(name)
8793
}
94+
8895
fn create_directory(&self, epoch: usize) {
8996
let directory = self.epoch_directory(epoch);
9097
std::fs::create_dir_all(directory).ok();
@@ -102,7 +109,7 @@ impl MetricLogger for FileMetricLogger {
102109
self.create_directory(self.epoch);
103110

104111
let file_path = self.file_path(key, self.epoch);
105-
let logger = FileLogger::new(&file_path);
112+
let logger = FileLogger::new(file_path);
106113
let logger = AsyncLogger::new(logger);
107114

108115
self.loggers.insert(key.clone(), logger);

0 commit comments

Comments
 (0)