Skip to content

Commit

Permalink
Format changes
Browse files Browse the repository at this point in the history
  • Loading branch information
John Staib Matilla committed Jan 13, 2025
1 parent 91a7054 commit f10f24d
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 14 deletions.
2 changes: 1 addition & 1 deletion modyn/config/examples/modyn_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ selector:
local_storage_directory: "/tmp/local_storage"
local_storage_max_samples_in_file: 1000000
cleanup_storage_directories_after_shutdown: true
ignore_existing_trigger_samples: false
ignore_existing_trigger_samples: true

trainer_server:
hostname: "trainer_server"
Expand Down
2 changes: 1 addition & 1 deletion modyn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .dlrm.dlrm import DLRM # noqa: F401
from .dummy.dummy import Dummy # noqa: F401
from .fmownet.fmownet import FmowNet # noqa: F401
from .gpt2.gpt2 import gpt2 # noqa: F401
from .gpt2.gpt2 import Gpt2 # noqa: F401
from .resnet18.resnet18 import ResNet18 # noqa: F401
from .resnet50.resnet50 import ResNet50 # noqa: F401
from .resnet152.resnet152 import ResNet152 # noqa: F401
Expand Down
6 changes: 3 additions & 3 deletions modyn/models/gpt2/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@
# as GPT2_Lora


class gpt2:
class Gpt2:
# pylint: disable-next=unused-argument
def __init__(self, hparams: Any, device: str, amp: bool) -> None:
self.model = gpt2Modyn(hparams)
self.model = Gpt2Modyn(hparams)
self.model.to(device)


# the following class is adapted from
# torchvision https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py


class gpt2Modyn(CoresetSupportingModule):
class Gpt2Modyn(CoresetSupportingModule):
def __init__(self, hparams: Any) -> None:
super().__init__()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class FileWrapper {
virtual std::vector<unsigned char> get_sample(uint64_t index) = 0;
virtual std::vector<std::vector<unsigned char>> get_samples(uint64_t start, uint64_t end) = 0;
virtual std::vector<std::vector<unsigned char>> get_samples_from_indices(const std::vector<uint64_t>& indices,
bool include_labels) = 0;
bool include_labels = true) = 0;
virtual void validate_file_extension() = 0;
virtual void delete_samples(const std::vector<uint64_t>& indices) = 0;
virtual void set_file_path(const std::string& path) = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class SingleSampleFileWrapper : public FileWrapper {
std::vector<unsigned char> get_sample(uint64_t index) override;
std::vector<std::vector<unsigned char>> get_samples(uint64_t start, uint64_t end) override;
std::vector<std::vector<unsigned char>> get_samples_from_indices(const std::vector<uint64_t>& indices,
bool include_labels) override;
bool include_labels = true) override;
void validate_file_extension() override;
void delete_samples(const std::vector<uint64_t>& indices) override;
void set_file_path(const std::string& path) override { file_path_ = path; }
Expand Down
15 changes: 8 additions & 7 deletions modyn/storage/include/internal/grpc/storage_service_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service {
// return millions of files
const std::vector<int64_t> file_ids = get_file_ids(session, dataset_id, start_timestamp, end_timestamp);
session.close();

if (file_ids.empty()) {
SPDLOG_INFO("No files found for dataset {} with start_timestamp = {} and end_timestamp = {}", dataset_id,
start_timestamp, end_timestamp);
Expand Down Expand Up @@ -646,8 +646,8 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service {
sample_labels.begin() + static_cast<int64_t>(sample_idx));
}
{
const std::lock_guard<std::mutex> lock(writer_mutex);
writer->Write(response);
const std::lock_guard<std::mutex> lock(writer_mutex);
writer->Write(response);
}
current_file_id = sample_fileid;
current_file_path = "",
Expand All @@ -672,7 +672,7 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service {
// Send leftovers
const std::vector<uint64_t> file_indexes(sample_indices.begin() + static_cast<int64_t>(current_file_start_idx),
sample_indices.end());

const std::vector<std::vector<unsigned char>> data =
file_wrapper->get_samples_from_indices(file_indexes, include_labels);

Expand All @@ -693,9 +693,10 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service {
response.mutable_labels()->Assign(sample_labels.begin() + static_cast<int64_t>(current_file_start_idx),
sample_labels.end());
}
{
const std::lock_guard<std::mutex> lock(writer_mutex);
writer->Write(response);}
{
const std::lock_guard<std::mutex> lock(writer_mutex);
writer->Write(response);
}
} catch (const std::exception& e) {
SPDLOG_ERROR("Error in send_sample_data_for_keys_and_file: {}", e.what());
SPDLOG_ERROR("Propagating error up the call chain to handle gRPC calls.");
Expand Down

0 comments on commit f10f24d

Please sign in to comment.