diff --git a/modyn/common/grpc/grpc_helpers.py b/modyn/common/grpc/grpc_helpers.py index f115d0c22..5002cfc97 100644 --- a/modyn/common/grpc/grpc_helpers.py +++ b/modyn/common/grpc/grpc_helpers.py @@ -251,6 +251,10 @@ def prepare_start_training_request( enable_accurate_gpu_measurements=training_config.enable_accurate_gpu_measurements, record_loss_every=training_config.record_loss_every, drop_last_batch=training_config.drop_last_batch, + generative=training_config.generative, + grad_norm=training_config.grad_norm if training_config.grad_norm != 0.0 else None, + lora=training_config.lora, + kadapter=training_config.kadapter, ) def start_training( diff --git a/modyn/config/examples/modyn_config.yaml b/modyn/config/examples/modyn_config.yaml index 42ec1588e..822e86c90 100644 --- a/modyn/config/examples/modyn_config.yaml +++ b/modyn/config/examples/modyn_config.yaml @@ -6,8 +6,8 @@ project: storage: port: "50051" hostname: "storage" - sample_batch_size: 2000000 - sample_dbinsertion_batchsize: 1000000 + sample_batch_size: 50000000 + sample_dbinsertion_batchsize: 10000000 insertion_threads: 8 retrieval_threads: 8 sample_table_unlogged: true @@ -22,10 +22,11 @@ storage: filesystem_wrapper_type: "LocalFilesystemWrapper", file_wrapper_type: "SingleSampleFileWrapper", file_wrapper_config: - { file_extension: ".png", label_file_extension: ".label" }, + { file_extension: ".png", label_file_extension: ".label",has_labels: true }, ignore_last_timestamp: false, file_watcher_interval: 5, selector_batch_size: 128, + }, # ----------------------------------- CRITEO ----------------------------------- # { @@ -41,10 +42,12 @@ storage: record_size: 160, label_size: 4, file_extension: ".bin", + has_labels: true, }, ignore_last_timestamp: false, file_watcher_interval: 5, selector_batch_size: 2000000, + }, # ---------------------------------- YEARBOOK ---------------------------------- # { @@ -60,10 +63,12 @@ storage: record_size: 12292, label_size: 4, file_extension: ".bin", + has_labels: true, }, ignore_last_timestamp: false, file_watcher_interval: 5, selector_batch_size: 256, + }, { name: "yearbook_train", @@ -78,6 +83,7 @@ storage: record_size: 12292, label_size: 4, file_extension: ".bin", + has_labels: true, }, ignore_last_timestamp: false, file_watcher_interval: 5, @@ -96,6 +102,8 @@ storage: record_size: 12292, label_size: 4, file_extension: ".bin", + has_labels: true, + }, ignore_last_timestamp: false, file_watcher_interval: 5, @@ -110,7 +118,7 @@ storage: filesystem_wrapper_type: "LocalFilesystemWrapper", file_wrapper_type: "SingleSampleFileWrapper", file_wrapper_config: - { file_extension: ".png", label_file_extension: ".label" }, + { file_extension: ".png", label_file_extension: ".label",has_labels: true, }, ignore_last_timestamp: false, file_watcher_interval: 5, selector_batch_size: 1024, @@ -127,6 +135,7 @@ storage: file_extension: ".csv", separator: "\t", #tsv best option here since headlines contain commas and semicolons label_index: 1, + has_labels: true, }, ignore_last_timestamp: false, file_watcher_interval: 5, @@ -144,6 +153,7 @@ storage: file_extension: ".csv", separator: "\t", #tsv best option here since headlines contain commas and semicolons label_index: 1, + has_labels: true, }, ignore_last_timestamp: false, file_watcher_interval: 5, @@ -160,10 +170,12 @@ storage: file_extension: ".csv", separator: "\t", #tsv best option here since headlines contain commas and semicolons label_index: 1, + has_labels: true, }, ignore_last_timestamp: false, file_watcher_interval: 5, selector_batch_size: 4096, 
+ }, # ------------------------------------ ARXIV ----------------------------------- # { @@ -177,10 +189,12 @@ storage: file_extension: ".csv", separator: "\t", #tsv best option here since sentences contain commas and semicolons label_index: 1, + has_labels: true }, ignore_last_timestamp: false, file_watcher_interval: 5, selector_batch_size: 4096, + }, { name: "arxiv_test", @@ -193,10 +207,12 @@ storage: file_extension: ".csv", separator: "\t", #tsv best option here since sentences contain commas and semicolons label_index: 1, + has_labels: true, }, ignore_last_timestamp: false, file_watcher_interval: 5, selector_batch_size: 4096, + }, # -------------------------------- ARXIV KAGGLE -------------------------------- # { @@ -210,10 +226,12 @@ storage: file_extension: ".csv", separator: "\t", #tsv best option here since sentences contain commas and semicolons label_index: 1, + has_labels: true, }, ignore_last_timestamp: false, file_watcher_interval: 5, selector_batch_size: 4096, + }, { name: "arxiv_kaggle_test", @@ -226,10 +244,12 @@ storage: file_extension: ".csv", separator: "\t", #tsv best option here since sentences contain commas and semicolons label_index: 1, + has_labels: true, }, ignore_last_timestamp: false, file_watcher_interval: 5, selector_batch_size: 4096, + }, # ------------------------------------ CLOC ------------------------------------ # { @@ -240,10 +260,31 @@ storage: filesystem_wrapper_type: "LocalFilesystemWrapper", file_wrapper_type: "SingleSampleFileWrapper", file_wrapper_config: - { file_extension: ".jpg", label_file_extension: ".label" }, + { file_extension: ".jpg", label_file_extension: ".label", + has_labels: true, + }, ignore_last_timestamp: false, file_watcher_interval: 999999999, selector_batch_size: 100000, + + }, + # ------------------------------------ Wikipedia_21_08 ------------------------------------ # + { + name: "Wikipedia_21_08", + description: "Wikipedia text dump from August 2021", + version: "0.0.1", + base_path: "/datasets/wikipedia_21_08", + filesystem_wrapper_type: "LocalFilesystemWrapper", + file_wrapper_type: "CsvFileWrapper", + file_wrapper_config: { + file_extension: ".csv", + separator: "\t", #tsv best option here since sentences contain commas and semicolons + has_labels: false, + }, + ignore_last_timestamp: false, + file_watcher_interval: 5, + selector_batch_size: 5000000, + }, ] database: @@ -278,7 +319,7 @@ selector: local_storage_directory: "/tmp/local_storage" local_storage_max_samples_in_file: 1000000 cleanup_storage_directories_after_shutdown: true - ignore_existing_trigger_samples: false + ignore_existing_trigger_samples: true trainer_server: hostname: "trainer_server" diff --git a/modyn/config/schema/pipeline/evaluation/config.py b/modyn/config/schema/pipeline/evaluation/config.py index 967407877..0dc2bd6cc 100644 --- a/modyn/config/schema/pipeline/evaluation/config.py +++ b/modyn/config/schema/pipeline/evaluation/config.py @@ -5,6 +5,7 @@ from pydantic import Field, field_validator from modyn.config.schema.base_model import ModynBaseModel +from modyn.config.schema.pipeline.training.config import LrSchedulerConfig, OptimizationCriterion, OptimizerConfig from ..data import DataConfig from .handler import EvalHandlerConfig @@ -20,6 +21,10 @@ class EvalDataConfig(DataConfig): description="All metrics used to evaluate the model on the given dataset.", min_length=1, ) + light_tuning: bool = Field(False, description="Whether to perform a light tuning.") + tuning_config: TuningConfig = Field( + None, description="Configuration for tuning 
parameters. Unnecessary if light_tuning is false" + ) class ResultWriter(ModynBaseModel): @@ -34,6 +39,38 @@ class ResultWriter(ModynBaseModel): - tensorboard: output the evaluation to dedicated tensorboard files.""" +class TuningConfig(ModynBaseModel): + epochs: int = Field(1, description="Number of epochs for tuning.", ge=1) + num_samples_to_pass: list[int] | None = Field(None, description="Number of samples to pass per epoch.") + + batch_size: int = Field(1, description="Batch size used in tuning.", ge=1) + dataloader_workers: int = Field(1, description="Number of workers for data loading.", ge=1) + drop_last_batch: bool = Field(True, description="Whether to drop the last batch if smaller than batch size.") + shuffle: bool = Field(True, description="Whether data is shuffled during tuning.") + enable_accurate_gpu_measurements: bool = Field(False, description="Enable precise GPU measurement during tuning.") + + amp: bool = Field(False, description="Whether automatic mixed precision is enabled.") + lr_scheduler: LrSchedulerConfig | None = Field(None, description="Learning rate scheduler configuration.") + device: str = Field( + "cpu", + description="The device the model should be put on.", + pattern=r"^(cpu|cuda:\d+)$", + ) + + seed: int | None = Field(None, description="Random seed for reproducibility, if provided.") + optimizers: list[OptimizerConfig] = Field( + description="An array of the optimizers for the training", + min_length=1, + ) + optimization_criterion: OptimizationCriterion = Field( + description="Configuration for the optimization criterion that we optimize", + ) + datasets: EvalDataConfig = Field( + description="Dataset used for light tuning", + min_length=1, + ) + + class EvaluationConfig(ModynBaseModel): handlers: list[EvalHandlerConfig] = Field( description="An array of all evaluation handlers that should be used to evaluate the model.", diff --git a/modyn/config/schema/pipeline/evaluation/metrics.py b/modyn/config/schema/pipeline/evaluation/metrics.py index 6924fb756..c581d9148 100644 --- a/modyn/config/schema/pipeline/evaluation/metrics.py +++ b/modyn/config/schema/pipeline/evaluation/metrics.py @@ -76,7 +76,42 @@ class RocAucMetricConfig(_BaseMetricConfig): name: Literal["RocAuc"] = Field("RocAuc") -MetricConfig = Annotated[AccuracyMetricConfig | F1ScoreMetricConfig | RocAucMetricConfig, Field(discriminator="name")] +class PerplexityMetricConfig(_BaseMetricConfig): + name: Literal["Perplexity"] = Field("Perplexity") + + +class GlueScoreMetricConfig(_BaseMetricConfig): + name: Literal["GLUEScore"] = Field("GLUEScore") + + +class TwikiF1MetricConfig(_BaseMetricConfig): + name: Literal["TwikiF1Score"] = Field("TwikiF1Score") + + +class PerplexityWithLightTuningMetricConfig(_BaseMetricConfig): + name: Literal["PerplexityWithLightTuning"] = Field("PerplexityWithLightTuning") + + +class BleuMetricConfig(_BaseMetricConfig): + name: Literal["Bleuscore"] = Field("Bleuscore") + + +class RougeMetricConfig(_BaseMetricConfig): + name: Literal["RougeScore"] = Field("RougeScore") + + +MetricConfig = Annotated[ + AccuracyMetricConfig + | F1ScoreMetricConfig + | RocAucMetricConfig + | PerplexityMetricConfig + | GlueScoreMetricConfig + | TwikiF1MetricConfig + | PerplexityWithLightTuningMetricConfig + | BleuMetricConfig + | RougeMetricConfig, + Field(discriminator="name"), +] class _MetricWrapper(BaseModel): diff --git a/modyn/config/schema/pipeline/training/config.py b/modyn/config/schema/pipeline/training/config.py index b3f673553..5e18ac989 100644 --- 
a/modyn/config/schema/pipeline/training/config.py +++ b/modyn/config/schema/pipeline/training/config.py @@ -7,7 +7,7 @@ from modyn.config.schema.base_model import ModynBaseModel -OptimizerSource = Literal["PyTorch", "APEX"] +OptimizerSource = Literal["PyTorch", "APEX", "HuggingFace"] class OptimizerParamGroup(ModynBaseModel): @@ -119,6 +119,21 @@ class TrainingConfig(ModynBaseModel): "we start with random weights. If initial_model is 'pretrained', cannot be False." ) ) + generative: bool = Field( + False, + description=( + "If True, the training pipeline takes the generative branch and data is sampled without expecting labels." + ), + ) + lora: bool = Field( + False, + description=("Applies LoRA adapter layers to the model."), + ) + kadapter: bool = Field( + False, + description=("Applies K-Adapter layers to the model."), + ) + seed: int | None = Field( None, description=( @@ -154,6 +169,10 @@ class TrainingConfig(ModynBaseModel): None, description="Configuration for the torch.cuda.amp.GradScaler. Effective only when amp is enabled.", ) + grad_norm: int = Field( + default=0, + description="Maximum norm used for gradient clipping; if set to 0, gradient clipping is disabled.", + ) # [Additional validation] diff --git a/modyn/config/schema/system/config.py b/modyn/config/schema/system/config.py index 881f0fb01..d3188893a 100644 --- a/modyn/config/schema/system/config.py +++ b/modyn/config/schema/system/config.py @@ -58,9 +58,10 @@ class DatasetCsvFileWrapperConfig(_DatasetBaseFileWrapperConfig): quoted_linebreaks: bool = Field(True, description="Whether linebreaks are quoted in CSV files.") label_index: int = Field( + -1, description=( "Column index of the label. For columns 'width, 'height, 'age', 'label' you should set label_index to 3." - ) + ), ) ignore_first_line: bool = Field( False, description="If the first line is the table header, you can skip it setting this parameter to True." @@ -73,6 +74,7 @@ class DatasetCsvFileWrapperConfig(_DatasetBaseFileWrapperConfig): "rows are the same size and that the 'label' column exists."
), ) + has_labels: bool = Field(True, description="Whether the dataset contains a label field.") class DatasetBinaryFileWrapperConfig(_DatasetBaseFileWrapperConfig): @@ -83,12 +85,14 @@ class DatasetBinaryFileWrapperConfig(_DatasetBaseFileWrapperConfig): ) record_size: int = Field(description="The size of each full record in bytes (label + features).") label_size: int = Field(description="The size of the label field in bytes for a binary file wrapper.") + has_labels: bool = Field(True, description="Whether the dataset contains a label field.") class DatasetPngFileWrapperConfig(_DatasetBaseFileWrapperConfig): """Represents a png dataset file used by modyn.""" label_file_extension: str = Field(description="The label file extension of the dataset", pattern=r"^\..*$") + has_labels: bool = Field(True, description="Whether the dataset contains a label field.") DatasetFileWrapperConfig = Union[ # noqa: UP007 diff --git a/modyn/evaluator/internal/dataset/evaluation_dataset.py b/modyn/evaluator/internal/dataset/evaluation_dataset.py index 29cf294ff..b3f374815 100644 --- a/modyn/evaluator/internal/dataset/evaluation_dataset.py +++ b/modyn/evaluator/internal/dataset/evaluation_dataset.py @@ -38,6 +38,7 @@ def __init__( tokenizer: str | None = None, start_timestamp: int | None = None, end_timestamp: int | None = None, + generative: bool = False, ): self._evaluation_id = evaluation_id self._dataset_id = dataset_id @@ -52,7 +53,7 @@ def __init__( self._bytes_parser_function: Callable | None = None self._start_timestamp = start_timestamp self._end_timestamp = end_timestamp - + self._generative = generative # tokenizer for NLP tasks self._tokenizer = None self._tokenizer_name = tokenizer @@ -156,7 +157,7 @@ def _get_keys_from_storage(self, worker_id: int, total_workers: int) -> Iterable def _get_data_from_storage( self, keys: list[int], worker_id: int | None = None - ) -> Iterable[list[tuple[int, bytes, int]]]: + ) -> Iterable[list[tuple[int, bytes, int | None]]]: processed_keys: set[int] | list[int] = [] has_failed = False for attempt in Retrying( @@ -170,7 +171,10 @@ def _get_data_from_storage( if not has_failed: assert isinstance(processed_keys, list) processed_keys.extend(response.keys) - yield list(zip(response.keys, response.samples, response.labels)) + if not self._generative: + yield list(zip(response.keys, response.samples, response.labels)) + else: + yield list(zip(response.keys, response.samples, [None] * len(response.keys))) else: assert isinstance(processed_keys, set) new_keys: list[int] = [key for key in response.keys if key not in processed_keys] new_samples: list[bytes] = [ sample for key, sample in zip(response.keys, response.samples) if key not in processed_keys ] - new_labels: list[int] = [ - label for key, label in zip(response.keys, response.labels) if key not in processed_keys - ] - processed_keys.update(keys) - yield list(zip(new_keys, new_samples, new_labels)) + if not self._generative: + new_labels: list[int] = [ + label + for key, label in zip(response.keys, response.labels) + if key not in processed_keys + ] + processed_keys.update(keys) + yield list(zip(new_keys, new_samples, new_labels)) + else: + processed_keys.update(keys) + yield list(zip(new_keys, new_samples, [None] * len(new_keys))) except grpc.RpcError as e: # We catch and reraise to log and reconnect has_failed = True diff --git a/modyn/evaluator/internal/grpc/generated/evaluator_pb2.py b/modyn/evaluator/internal/grpc/generated/evaluator_pb2.py index
7e11a1772..1e38880fb 100644 --- a/modyn/evaluator/internal/grpc/generated/evaluator_pb2.py +++ b/modyn/evaluator/internal/grpc/generated/evaluator_pb2.py @@ -1,21 +1,24 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE # source: evaluator.proto -# Protobuf Python Version: 5.26.1 +# Protobuf Python Version: 5.27.2 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion(_runtime_version.Domain.PUBLIC, 5, 27, 2, "", "evaluator.proto") # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0f\x65valuator.proto\x12\x0fmodyn.evaluator"t\n\x12\x45valuationInterval\x12\x1c\n\x0fstart_timestamp\x18\x01 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x02 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp"}\n\x0b\x44\x61tasetInfo\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fnum_dataloaders\x18\x02 \x01(\x05\x12\x41\n\x14\x65valuation_intervals\x18\x03 \x03(\x0b\x32#.modyn.evaluator.EvaluationInterval"\x1d\n\x0cPythonString\x12\r\n\x05value\x18\x01 \x01(\t"\x1b\n\nJsonString\x12\r\n\x05value\x18\x01 \x01(\t"\xfa\x02\n\x14\x45valuateModelRequest\x12\x10\n\x08model_id\x18\x01 \x01(\x05\x12\x32\n\x0c\x64\x61taset_info\x18\x02 \x01(\x0b\x32\x1c.modyn.evaluator.DatasetInfo\x12\x0e\n\x06\x64\x65vice\x18\x03 \x01(\t\x12\x12\n\nbatch_size\x18\x04 \x01(\x05\x12,\n\x07metrics\x18\x05 \x03(\x0b\x32\x1b.modyn.evaluator.JsonString\x12\x16\n\x0etransform_list\x18\x06 \x03(\t\x12\x33\n\x0c\x62ytes_parser\x18\x07 \x01(\x0b\x32\x1d.modyn.evaluator.PythonString\x12\x38\n\x11label_transformer\x18\x08 \x01(\x0b\x32\x1d.modyn.evaluator.PythonString\x12\x35\n\ttokenizer\x18\t \x01(\x0b\x32\x1d.modyn.evaluator.PythonStringH\x00\x88\x01\x01\x42\x0c\n\n_tokenizer"|\n\x1d\x45valuateModelIntervalResponse\x12\x14\n\x0c\x64\x61taset_size\x18\x01 \x01(\x03\x12\x45\n\x13\x65val_aborted_reason\x18\x02 \x01(\x0e\x32(.modyn.evaluator.EvaluationAbortedReason"\x96\x01\n\x15\x45valuateModelResponse\x12\x1a\n\x12\x65valuation_started\x18\x01 \x01(\x08\x12\x15\n\revaluation_id\x18\x02 \x01(\x05\x12J\n\x12interval_responses\x18\x03 \x03(\x0b\x32..modyn.evaluator.EvaluateModelIntervalResponse"0\n\x17\x45valuationStatusRequest\x12\x15\n\revaluation_id\x18\x01 \x01(\x05"c\n\x18\x45valuationStatusResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x12\n\nis_running\x18\x02 \x01(\x08\x12\x16\n\texception\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x0c\n\n_exception"4\n\x12SingleMetricResult\x12\x0e\n\x06metric\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x01(\x02"n\n\x16\x45valuationIntervalData\x12\x16\n\x0einterval_index\x18\x01 \x01(\x05\x12<\n\x0f\x65valuation_data\x18\x02 \x03(\x0b\x32#.modyn.evaluator.SingleMetricResult"0\n\x17\x45valuationResultRequest\x12\x15\n\revaluation_id\x18\x01 \x01(\x05"2\n\x18\x45valuationCleanupRequest\x12\x16\n\x0e\x65valuation_ids\x18\x01 \x03(\x05"n\n\x18\x45valuationResultResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x43\n\x12\x65valuation_results\x18\x02 \x03(\x0b\x32\'.modyn.evaluator.EvaluationIntervalData".\n\x19\x45valuationCleanupResponse\x12\x11\n\tsucceeded\x18\x01 
\x03(\x05*\xcb\x01\n\x17\x45valuationAbortedReason\x12\x0f\n\x0bNOT_ABORTED\x10\x00\x12\x1f\n\x1bMODEL_NOT_EXIST_IN_METADATA\x10\x01\x12\x18\n\x14MODEL_IMPORT_FAILURE\x10\x02\x12\x1e\n\x1aMODEL_NOT_EXIST_IN_STORAGE\x10\x03\x12\x15\n\x11\x44\x41TASET_NOT_FOUND\x10\x04\x12\x11\n\rEMPTY_DATASET\x10\x05\x12\x1a\n\x16\x44OWNLOAD_MODEL_FAILURE\x10\x06\x32\xbe\x03\n\tEvaluator\x12\x61\n\x0e\x65valuate_model\x12%.modyn.evaluator.EvaluateModelRequest\x1a&.modyn.evaluator.EvaluateModelResponse"\x00\x12n\n\x15get_evaluation_status\x12(.modyn.evaluator.EvaluationStatusRequest\x1a).modyn.evaluator.EvaluationStatusResponse"\x00\x12n\n\x15get_evaluation_result\x12(.modyn.evaluator.EvaluationResultRequest\x1a).modyn.evaluator.EvaluationResultResponse"\x00\x12n\n\x13\x63leanup_evaluations\x12).modyn.evaluator.EvaluationCleanupRequest\x1a*.modyn.evaluator.EvaluationCleanupResponse"\x00\x62\x06proto3' + b'\n\x0f\x65valuator.proto\x12\x0fmodyn.evaluator"t\n\x12\x45valuationInterval\x12\x1c\n\x0fstart_timestamp\x18\x01 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x02 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp"}\n\x0b\x44\x61tasetInfo\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fnum_dataloaders\x18\x02 \x01(\x05\x12\x41\n\x14\x65valuation_intervals\x18\x03 \x03(\x0b\x32#.modyn.evaluator.EvaluationInterval"\x1d\n\x0cPythonString\x12\r\n\x05value\x18\x01 \x01(\t"\x1b\n\nJsonString\x12\r\n\x05value\x18\x01 \x01(\t"\xdb\x03\n\x14\x45valuateModelRequest\x12\x10\n\x08model_id\x18\x01 \x01(\x05\x12\x32\n\x0c\x64\x61taset_info\x18\x02 \x01(\x0b\x32\x1c.modyn.evaluator.DatasetInfo\x12\x0e\n\x06\x64\x65vice\x18\x03 \x01(\t\x12\x12\n\nbatch_size\x18\x04 \x01(\x05\x12,\n\x07metrics\x18\x05 \x03(\x0b\x32\x1b.modyn.evaluator.JsonString\x12\x16\n\x0etransform_list\x18\x06 \x03(\t\x12\x33\n\x0c\x62ytes_parser\x18\x07 \x01(\x0b\x32\x1d.modyn.evaluator.PythonString\x12\x38\n\x11label_transformer\x18\x08 \x01(\x0b\x32\x1d.modyn.evaluator.PythonString\x12\x35\n\ttokenizer\x18\t \x01(\x0b\x32\x1d.modyn.evaluator.PythonStringH\x00\x88\x01\x01\x12\x14\n\x0clight_tuning\x18\n \x01(\x08\x12\x37\n\rtuning_config\x18\x0b \x01(\x0b\x32\x1b.modyn.evaluator.JsonStringH\x01\x88\x01\x01\x42\x0c\n\n_tokenizerB\x10\n\x0e_tuning_config"|\n\x1d\x45valuateModelIntervalResponse\x12\x14\n\x0c\x64\x61taset_size\x18\x01 \x01(\x03\x12\x45\n\x13\x65val_aborted_reason\x18\x02 \x01(\x0e\x32(.modyn.evaluator.EvaluationAbortedReason"\x96\x01\n\x15\x45valuateModelResponse\x12\x1a\n\x12\x65valuation_started\x18\x01 \x01(\x08\x12\x15\n\revaluation_id\x18\x02 \x01(\x05\x12J\n\x12interval_responses\x18\x03 \x03(\x0b\x32..modyn.evaluator.EvaluateModelIntervalResponse"0\n\x17\x45valuationStatusRequest\x12\x15\n\revaluation_id\x18\x01 \x01(\x05"c\n\x18\x45valuationStatusResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x12\n\nis_running\x18\x02 \x01(\x08\x12\x16\n\texception\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x0c\n\n_exception"4\n\x12SingleMetricResult\x12\x0e\n\x06metric\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x01(\x02"n\n\x16\x45valuationIntervalData\x12\x16\n\x0einterval_index\x18\x01 \x01(\x05\x12<\n\x0f\x65valuation_data\x18\x02 \x03(\x0b\x32#.modyn.evaluator.SingleMetricResult"0\n\x17\x45valuationResultRequest\x12\x15\n\revaluation_id\x18\x01 \x01(\x05"2\n\x18\x45valuationCleanupRequest\x12\x16\n\x0e\x65valuation_ids\x18\x01 \x03(\x05"n\n\x18\x45valuationResultResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x43\n\x12\x65valuation_results\x18\x02 
\x03(\x0b\x32\'.modyn.evaluator.EvaluationIntervalData".\n\x19\x45valuationCleanupResponse\x12\x11\n\tsucceeded\x18\x01 \x03(\x05*\xcb\x01\n\x17\x45valuationAbortedReason\x12\x0f\n\x0bNOT_ABORTED\x10\x00\x12\x1f\n\x1bMODEL_NOT_EXIST_IN_METADATA\x10\x01\x12\x18\n\x14MODEL_IMPORT_FAILURE\x10\x02\x12\x1e\n\x1aMODEL_NOT_EXIST_IN_STORAGE\x10\x03\x12\x15\n\x11\x44\x41TASET_NOT_FOUND\x10\x04\x12\x11\n\rEMPTY_DATASET\x10\x05\x12\x1a\n\x16\x44OWNLOAD_MODEL_FAILURE\x10\x06\x32\xbe\x03\n\tEvaluator\x12\x61\n\x0e\x65valuate_model\x12%.modyn.evaluator.EvaluateModelRequest\x1a&.modyn.evaluator.EvaluateModelResponse"\x00\x12n\n\x15get_evaluation_status\x12(.modyn.evaluator.EvaluationStatusRequest\x1a).modyn.evaluator.EvaluationStatusResponse"\x00\x12n\n\x15get_evaluation_result\x12(.modyn.evaluator.EvaluationResultRequest\x1a).modyn.evaluator.EvaluationResultResponse"\x00\x12n\n\x13\x63leanup_evaluations\x12).modyn.evaluator.EvaluationCleanupRequest\x1a*.modyn.evaluator.EvaluationCleanupResponse"\x00\x62\x06proto3' ) _globals = globals() @@ -23,8 +26,8 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "evaluator_pb2", _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals["_EVALUATIONABORTEDREASON"]._serialized_start = 1581 - _globals["_EVALUATIONABORTEDREASON"]._serialized_end = 1784 + _globals["_EVALUATIONABORTEDREASON"]._serialized_start = 1678 + _globals["_EVALUATIONABORTEDREASON"]._serialized_end = 1881 _globals["_EVALUATIONINTERVAL"]._serialized_start = 36 _globals["_EVALUATIONINTERVAL"]._serialized_end = 152 _globals["_DATASETINFO"]._serialized_start = 154 @@ -34,27 +37,27 @@ _globals["_JSONSTRING"]._serialized_start = 312 _globals["_JSONSTRING"]._serialized_end = 339 _globals["_EVALUATEMODELREQUEST"]._serialized_start = 342 - _globals["_EVALUATEMODELREQUEST"]._serialized_end = 720 - _globals["_EVALUATEMODELINTERVALRESPONSE"]._serialized_start = 722 - _globals["_EVALUATEMODELINTERVALRESPONSE"]._serialized_end = 846 - _globals["_EVALUATEMODELRESPONSE"]._serialized_start = 849 - _globals["_EVALUATEMODELRESPONSE"]._serialized_end = 999 - _globals["_EVALUATIONSTATUSREQUEST"]._serialized_start = 1001 - _globals["_EVALUATIONSTATUSREQUEST"]._serialized_end = 1049 - _globals["_EVALUATIONSTATUSRESPONSE"]._serialized_start = 1051 - _globals["_EVALUATIONSTATUSRESPONSE"]._serialized_end = 1150 - _globals["_SINGLEMETRICRESULT"]._serialized_start = 1152 - _globals["_SINGLEMETRICRESULT"]._serialized_end = 1204 - _globals["_EVALUATIONINTERVALDATA"]._serialized_start = 1206 - _globals["_EVALUATIONINTERVALDATA"]._serialized_end = 1316 - _globals["_EVALUATIONRESULTREQUEST"]._serialized_start = 1318 - _globals["_EVALUATIONRESULTREQUEST"]._serialized_end = 1366 - _globals["_EVALUATIONCLEANUPREQUEST"]._serialized_start = 1368 - _globals["_EVALUATIONCLEANUPREQUEST"]._serialized_end = 1418 - _globals["_EVALUATIONRESULTRESPONSE"]._serialized_start = 1420 - _globals["_EVALUATIONRESULTRESPONSE"]._serialized_end = 1530 - _globals["_EVALUATIONCLEANUPRESPONSE"]._serialized_start = 1532 - _globals["_EVALUATIONCLEANUPRESPONSE"]._serialized_end = 1578 - _globals["_EVALUATOR"]._serialized_start = 1787 - _globals["_EVALUATOR"]._serialized_end = 2233 + _globals["_EVALUATEMODELREQUEST"]._serialized_end = 817 + _globals["_EVALUATEMODELINTERVALRESPONSE"]._serialized_start = 819 + _globals["_EVALUATEMODELINTERVALRESPONSE"]._serialized_end = 943 + _globals["_EVALUATEMODELRESPONSE"]._serialized_start = 946 + _globals["_EVALUATEMODELRESPONSE"]._serialized_end = 1096 + 
_globals["_EVALUATIONSTATUSREQUEST"]._serialized_start = 1098 + _globals["_EVALUATIONSTATUSREQUEST"]._serialized_end = 1146 + _globals["_EVALUATIONSTATUSRESPONSE"]._serialized_start = 1148 + _globals["_EVALUATIONSTATUSRESPONSE"]._serialized_end = 1247 + _globals["_SINGLEMETRICRESULT"]._serialized_start = 1249 + _globals["_SINGLEMETRICRESULT"]._serialized_end = 1301 + _globals["_EVALUATIONINTERVALDATA"]._serialized_start = 1303 + _globals["_EVALUATIONINTERVALDATA"]._serialized_end = 1413 + _globals["_EVALUATIONRESULTREQUEST"]._serialized_start = 1415 + _globals["_EVALUATIONRESULTREQUEST"]._serialized_end = 1463 + _globals["_EVALUATIONCLEANUPREQUEST"]._serialized_start = 1465 + _globals["_EVALUATIONCLEANUPREQUEST"]._serialized_end = 1515 + _globals["_EVALUATIONRESULTRESPONSE"]._serialized_start = 1517 + _globals["_EVALUATIONRESULTRESPONSE"]._serialized_end = 1627 + _globals["_EVALUATIONCLEANUPRESPONSE"]._serialized_start = 1629 + _globals["_EVALUATIONCLEANUPRESPONSE"]._serialized_end = 1675 + _globals["_EVALUATOR"]._serialized_start = 1884 + _globals["_EVALUATOR"]._serialized_end = 2330 # @@protoc_insertion_point(module_scope) diff --git a/modyn/evaluator/internal/grpc/generated/evaluator_pb2.pyi b/modyn/evaluator/internal/grpc/generated/evaluator_pb2.pyi index 0cfb1aa5c..4058d379b 100644 --- a/modyn/evaluator/internal/grpc/generated/evaluator_pb2.pyi +++ b/modyn/evaluator/internal/grpc/generated/evaluator_pb2.pyi @@ -174,9 +174,12 @@ class EvaluateModelRequest(google.protobuf.message.Message): BYTES_PARSER_FIELD_NUMBER: builtins.int LABEL_TRANSFORMER_FIELD_NUMBER: builtins.int TOKENIZER_FIELD_NUMBER: builtins.int + LIGHT_TUNING_FIELD_NUMBER: builtins.int + TUNING_CONFIG_FIELD_NUMBER: builtins.int model_id: builtins.int device: builtins.str batch_size: builtins.int + light_tuning: builtins.bool @property def dataset_info(self) -> global___DatasetInfo: ... @property @@ -189,6 +192,8 @@ class EvaluateModelRequest(google.protobuf.message.Message): def label_transformer(self) -> global___PythonString: ... @property def tokenizer(self) -> global___PythonString: ... + @property + def tuning_config(self) -> global___JsonString: ... def __init__( self, *, @@ -201,12 +206,16 @@ class EvaluateModelRequest(google.protobuf.message.Message): bytes_parser: global___PythonString | None = ..., label_transformer: global___PythonString | None = ..., tokenizer: global___PythonString | None = ..., + light_tuning: builtins.bool = ..., + tuning_config: global___JsonString | None = ..., ) -> None: ... def HasField( self, field_name: typing.Literal[ "_tokenizer", b"_tokenizer", + "_tuning_config", + b"_tuning_config", "bytes_parser", b"bytes_parser", "dataset_info", @@ -215,6 +224,8 @@ class EvaluateModelRequest(google.protobuf.message.Message): b"label_transformer", "tokenizer", b"tokenizer", + "tuning_config", + b"tuning_config", ], ) -> builtins.bool: ... def ClearField( @@ -222,6 +233,8 @@ class EvaluateModelRequest(google.protobuf.message.Message): field_name: typing.Literal[ "_tokenizer", b"_tokenizer", + "_tuning_config", + b"_tuning_config", "batch_size", b"batch_size", "bytes_parser", @@ -232,6 +245,8 @@ class EvaluateModelRequest(google.protobuf.message.Message): b"device", "label_transformer", b"label_transformer", + "light_tuning", + b"light_tuning", "metrics", b"metrics", "model_id", @@ -240,11 +255,18 @@ class EvaluateModelRequest(google.protobuf.message.Message): b"tokenizer", "transform_list", b"transform_list", + "tuning_config", + b"tuning_config", ], ) -> None: ... 
+ @typing.overload def WhichOneof( self, oneof_group: typing.Literal["_tokenizer", b"_tokenizer"] ) -> typing.Literal["tokenizer"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing.Literal["_tuning_config", b"_tuning_config"] + ) -> typing.Literal["tuning_config"] | None: ... global___EvaluateModelRequest = EvaluateModelRequest @@ -277,7 +299,10 @@ class EvaluateModelResponse(google.protobuf.message.Message): EVALUATION_ID_FIELD_NUMBER: builtins.int INTERVAL_RESPONSES_FIELD_NUMBER: builtins.int evaluation_started: builtins.bool - """only when all interval evaluations failed, this field will be set to false""" + """only when all interval evaluations failed, this field will be set to false + it is a field of convenience for the client to decide whether to wait for the evaluation completion. + the client can always check the interval_responses + """ evaluation_id: builtins.int @property def interval_responses( @@ -378,7 +403,14 @@ class EvaluationIntervalData(google.protobuf.message.Message): INTERVAL_INDEX_FIELD_NUMBER: builtins.int EVALUATION_DATA_FIELD_NUMBER: builtins.int interval_index: builtins.int - """multiple metrics are required on on evaluation on one interval""" + """Since not every interval evaluation from EvaluateModelRequest may be successful, + the EvaluationIntervalData contained in the EvaluationResultResponse must explicitly specify what interval this + evaluation data corresponds to. The interval_index is the index of the interval in the list + Datainfo.evaluation_intervals in the EvaluateModelRequest. + For example if Datainfo.evaluation_intervals have 3 intervals, [interval1, interval2, interval3], + and interval2 fails. Then the EvaluationResultResponse will have 2 EvaluationIntervalData, one with interval_index + 0 (which corresponds to interval1) and the other with interval_index 2 (which corresponds to interval3). + """ @property def evaluation_data( self, diff --git a/modyn/evaluator/internal/grpc/generated/evaluator_pb2_grpc.py b/modyn/evaluator/internal/grpc/generated/evaluator_pb2_grpc.py index fe8811e4b..0e5f0d277 100644 --- a/modyn/evaluator/internal/grpc/generated/evaluator_pb2_grpc.py +++ b/modyn/evaluator/internal/grpc/generated/evaluator_pb2_grpc.py @@ -1,15 +1,13 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" +import grpc import warnings -import grpc import modyn.evaluator.internal.grpc.generated.evaluator_pb2 as evaluator__pb2 -GRPC_GENERATED_VERSION = "1.63.0" +GRPC_GENERATED_VERSION = "1.67.1" GRPC_VERSION = grpc.__version__ -EXPECTED_ERROR_RELEASE = "1.65.0" -SCHEDULED_RELEASE_DATE = "June 25, 2024" _version_not_supported = False try: @@ -20,15 +18,12 @@ _version_not_supported = True if _version_not_supported: - warnings.warn( + raise RuntimeError( f"The grpc package installed is at version {GRPC_VERSION}," + f" but the generated code in evaluator_pb2_grpc.py depends on" + f" grpcio>={GRPC_GENERATED_VERSION}." + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." 
- + f" This warning will become an error in {EXPECTED_ERROR_RELEASE}," - + f" scheduled for release on {SCHEDULED_RELEASE_DATE}.", - RuntimeWarning, ) @@ -120,6 +115,7 @@ def add_EvaluatorServicer_to_server(servicer, server): } generic_handler = grpc.method_handlers_generic_handler("modyn.evaluator.Evaluator", rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers("modyn.evaluator.Evaluator", rpc_method_handlers) # This class is part of an EXPERIMENTAL API. diff --git a/modyn/evaluator/internal/metrics/accuracy.py b/modyn/evaluator/internal/metrics/accuracy.py index 5c881d8df..cd62e01cc 100644 --- a/modyn/evaluator/internal/metrics/accuracy.py +++ b/modyn/evaluator/internal/metrics/accuracy.py @@ -12,6 +12,7 @@ def __init__(self, config: AccuracyMetricConfig) -> None: self.samples_seen = 0 self.total_correct = 0 + # pylint: disable=unused-argument def _batch_evaluated_callback(self, y_true: torch.Tensor, y_pred: torch.Tensor, batch_size: int) -> None: if self.config.topn == 1: labeled_correctly = torch.sum(torch.eq(y_pred, y_true)).item() diff --git a/modyn/evaluator/internal/metrics/glue.py b/modyn/evaluator/internal/metrics/glue.py new file mode 100644 index 000000000..0ed583b2c --- /dev/null +++ b/modyn/evaluator/internal/metrics/glue.py @@ -0,0 +1,27 @@ +import numpy as np +import torch + +from modyn.config.schema.pipeline import GlueScoreMetricConfig +from modyn.evaluator.internal.metrics.abstract_holistic_metric import AbstractHolisticMetric + + +class GLUEScore(AbstractHolisticMetric): + """GLUE Score metric implementation.""" + + def __init__(self, config: GlueScoreMetricConfig) -> None: + super().__init__(config) + self.results: list[float] = [] + + # pylint: disable=unused-argument + def _dataset_evaluated_callback(self, y_true: torch.Tensor, y_pred: torch.Tensor, num_samples: int) -> None: + """Stores individual task scores to be averaged later.""" + self.results.append(torch.mean(y_pred).item()) # Example: replace with actual GLUE metric computation + + def get_evaluation_result(self) -> float: + if not self.results: + self.warning("No GLUE scores computed.") + return 0.0 + return np.mean(self.results) # type: ignore + + def get_name(self) -> str: + return "GLUE Score" diff --git a/modyn/evaluator/internal/metrics/perplexity.py b/modyn/evaluator/internal/metrics/perplexity.py new file mode 100644 index 000000000..d44f7a939 --- /dev/null +++ b/modyn/evaluator/internal/metrics/perplexity.py @@ -0,0 +1,29 @@ +import numpy as np +import torch + +from modyn.config.schema.pipeline import PerplexityMetricConfig +from modyn.evaluator.internal.metrics.abstract_decomposable_metric import AbstractDecomposableMetric + + +class Perplexity(AbstractDecomposableMetric): + """Standard Perplexity metric implementation.""" + + def __init__(self, config: PerplexityMetricConfig) -> None: + super().__init__(config) + self.total_loss = 0.0 + self.total_tokens = 0 + + def _batch_evaluated_callback(self, y_true: torch.Tensor, y_pred: torch.Tensor, batch_size: int) -> None: + loss_fn = torch.nn.CrossEntropyLoss(reduction="sum") + loss = loss_fn(y_pred.view(-1, y_pred.size(-1)), y_true.view(-1)) + self.total_loss += loss.item() + self.total_tokens += y_true.numel() + + def get_evaluation_result(self) -> float: + if self.total_tokens == 0: + self.warning("Did not see any samples.") + return float("inf") + return np.exp(self.total_loss / self.total_tokens) + + def get_name(self) -> str: + return "Perplexity" diff --git 
a/modyn/evaluator/internal/metrics/reuge_bleu.py b/modyn/evaluator/internal/metrics/reuge_bleu.py new file mode 100644 index 000000000..78218a8db --- /dev/null +++ b/modyn/evaluator/internal/metrics/reuge_bleu.py @@ -0,0 +1,65 @@ +import numpy as np +import torch +from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu +from rouge_score import rouge_scorer + +from modyn.config.schema.pipeline import BleuMetricConfig, RougeMetricConfig +from modyn.evaluator.internal.metrics.abstract_holistic_metric import AbstractHolisticMetric + + +class BLEUScore(AbstractHolisticMetric): + """BLEU Score metric implementation for text generation evaluation.""" + + def __init__(self, config: BleuMetricConfig) -> None: + super().__init__(config) + self.bleu_scores: list[float] = [] + + def _dataset_evaluated_callback(self, y_true: torch.Tensor, y_pred: torch.Tensor, num_samples: int) -> None: + """Computes BLEU score per sample and stores results.""" + chencherry = SmoothingFunction() + for ref, hyp in zip(y_true, y_pred): + reference = [ref.tolist()] # BLEU expects a list of references + hypothesis = hyp.tolist() + score = sentence_bleu(reference, hypothesis, smoothing_function=chencherry.method1) + self.bleu_scores.append(score) + + def get_evaluation_result(self) -> float: + """Returns the average BLEU score across all evaluated samples.""" + if not self.bleu_scores: + self.warning("No BLEU scores computed.") + return 0.0 + return float(np.mean(self.bleu_scores)) # Explicit cast to float + + def get_name(self) -> str: + return "BLEU Score" + + +class ROUGEScore(AbstractHolisticMetric): + """ROUGE Score metric implementation for text evaluation.""" + + def __init__(self, config: RougeMetricConfig) -> None: + super().__init__(config) + self.scores: dict[str, list[float]] = {"rouge-1": [], "rouge-2": [], "rouge-l": []} + + def _dataset_evaluated_callback(self, y_true: torch.Tensor, y_pred: torch.Tensor, num_samples: int) -> None: + """Computes ROUGE scores for predictions.""" + scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True) + for ref, hyp in zip(y_true, y_pred): + scores = scorer.score(ref.tolist(), hyp.tolist()) + self.scores["rouge-1"].append(scores["rouge1"].fmeasure) + self.scores["rouge-2"].append(scores["rouge2"].fmeasure) + self.scores["rouge-l"].append(scores["rougeL"].fmeasure) + + def get_evaluation_results(self) -> dict[str, float]: # New method name to avoid superclass type conflict + """Returns the averaged ROUGE scores.""" + if not self.scores["rouge-1"]: + self.warning("No ROUGE scores computed.") + return {"rouge-1": 0.0, "rouge-2": 0.0, "rouge-l": 0.0} + return { + "rouge-1": float(np.mean(self.scores["rouge-1"])), + "rouge-2": float(np.mean(self.scores["rouge-2"])), + "rouge-l": float(np.mean(self.scores["rouge-l"])), + } + + def get_name(self) -> str: + return "ROUGE Score" diff --git a/modyn/evaluator/internal/metrics/twikif1score.py b/modyn/evaluator/internal/metrics/twikif1score.py new file mode 100644 index 000000000..920185d0f --- /dev/null +++ b/modyn/evaluator/internal/metrics/twikif1score.py @@ -0,0 +1,33 @@ +import torch + +from modyn.config.schema.pipeline import TwikiF1MetricConfig +from modyn.evaluator.internal.metrics.abstract_decomposable_metric import AbstractDecomposableMetric + + +class TwikiF1Score(AbstractDecomposableMetric): + """TWIKI-Probes F1 Score implementation.""" + + def __init__(self, config: TwikiF1MetricConfig) -> None: + super().__init__(config) + self.true_positives = 0 + self.false_positives = 0 + 
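# Note: with the exact-match counting in _batch_evaluated_callback, every correct prediction is a true positive + # and every incorrect one is counted as both a false positive and a false negative, so the reported F1 + # reduces to exact-match accuracy. +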
self.false_negatives = 0 + + # pylint: disable=unused-argument + def _batch_evaluated_callback(self, y_true: torch.Tensor, y_pred: torch.Tensor, batch_size: int) -> None: + correct = torch.eq(y_true, y_pred) + self.true_positives += correct.sum().item() + self.false_positives += (~correct).sum().item() + self.false_negatives += (~correct).sum().item() + + def get_evaluation_result(self) -> float: + if self.true_positives == 0: + return 0.0 + precision = self.true_positives / (self.true_positives + self.false_positives) + recall = self.true_positives / (self.true_positives + self.false_negatives) + if precision + recall == 0: + return 0.0 + return 2 * (precision * recall) / (precision + recall) + + def get_name(self) -> str: + return "TWIKI-Probes F1 Score" diff --git a/modyn/evaluator/internal/pytorch_evaluator.py b/modyn/evaluator/internal/pytorch_evaluator.py index b3ad49ce8..61c80b393 100644 --- a/modyn/evaluator/internal/pytorch_evaluator.py +++ b/modyn/evaluator/internal/pytorch_evaluator.py @@ -9,7 +9,8 @@ from modyn.evaluator.internal.core_evaluation import perform_evaluation, setup_metrics from modyn.evaluator.internal.dataset.evaluation_dataset import EvaluationDataset -from modyn.evaluator.internal.utils import EvaluationInfo +from modyn.evaluator.internal.pytorch_lighttuner import PytorchTuner +from modyn.evaluator.internal.utils import EvaluationInfo, TuningInfo from modyn.utils import LABEL_TRANSFORMER_FUNC_NAME, deserialize_function @@ -108,6 +109,15 @@ def _single_interval_evaluate(self, dataloader: torch.utils.data.DataLoader, int f"Queue size = {self._metric_result_queue.qsize()}" ) + def _light_tune(self, tuning_info: TuningInfo) -> None: + tuner = PytorchTuner( + tuning_info=tuning_info, + logger=self.logger, + model=self._model.model, + storage_address=self._eval_info.storage_address, + ) + tuner.train() + def evaluate(self) -> None: for idx, interval_idx in enumerate(self._eval_info.not_failed_interval_ids): self._info(f"Evaluating interval {idx + 1}/{len(self._eval_info.not_failed_interval_ids)} ({interval_idx})") @@ -124,6 +134,7 @@ def evaluate( log_path: pathlib.Path, exception_queue: mp.Queue, metric_result_queue: mp.Queue, + light_tuning_info: TuningInfo | None = None, # Tuning parameters for the optional light-tuning phase ) -> None: logging.basicConfig( level=logging.DEBUG, @@ -136,7 +147,20 @@ def evaluate( try: evaluator = PytorchEvaluator(evaluation_info, logger, metric_result_queue) - evaluator.evaluate() + + # Perform light tuning before evaluation if enabled + if evaluation_info.light_tuning: + logger.info("Performing light tuning before evaluation.") + light_tuning_info = evaluation_info.tuning_info + # Ensure light_tuning_info is valid + if not isinstance(light_tuning_info, TuningInfo): + raise ValueError("light_tuning_info must be a TuningInfo instance with the tuning parameters.") + + evaluator._light_tune(light_tuning_info) # Pass tuning info + + logger.info("Light tuning completed.") + + evaluator.evaluate() # Run evaluation after tuning logger.info("Evaluator returned.") except Exception: # pylint: disable=broad-except exception_msg = traceback.format_exc() diff --git a/modyn/evaluator/internal/pytorch_lighttuner.py b/modyn/evaluator/internal/pytorch_lighttuner.py new file mode 100644 index 000000000..da208881e --- /dev/null +++ b/modyn/evaluator/internal/pytorch_lighttuner.py @@ -0,0 +1,426 @@ +# pylint: disable=no-name-in-module +from __future__ import annotations + +import copy +import glob +import io +import json +import logging +import os +import pathlib +import shutil
+import tempfile +from typing import Any + +import torch +import transformers + +from modyn.common.benchmark.stopwatch import Stopwatch +from modyn.evaluator.internal.dataset.evaluation_dataset import EvaluationDataset +from modyn.evaluator.internal.utils.tuning_info import TuningInfo +from modyn.models.modular_adapters.modular_adapters import apply_kadapter, apply_lora +from modyn.trainer_server.internal.trainer.gpu_measurement import GPUMeasurement +from modyn.utils import ( + LABEL_TRANSFORMER_FUNC_NAME, + deserialize_function, + dynamic_module_import, + package_available_and_can_be_imported, + seed_everything, +) + + +class PytorchTuner: + # pylint: disable=too-many-instance-attributes, too-many-locals, too-many-branches, too-many-statements + + def __init__(self, tuning_info: TuningInfo, logger: logging.Logger, model: Any, storage_address: Any) -> None: + self.logger = logger + self.pipeline_id = tuning_info.pipeline_id + self._evaluation_id = tuning_info.evaluation_id + self._info("Initializing Pytorch Tuner") + self.generative = tuning_info.generative + self._grad_norm = 0.5 # TODO: take this from tuning_info.grad_norm once grad_norm is part of the tuning config + self._lora = False + self._kadapter = False + self._light_tuning_steps = tuning_info.steps + if tuning_info.seed is not None: + self._seed_trainer_server(tuning_info.seed) + self._info("Everything seeded") + self._storage_address = storage_address + # setup model and optimizer + self._model = model + self._setup_optimizers(tuning_info) + self._info("Model and optimizer created.") + + self._scaler = torch.cuda.amp.GradScaler(enabled=tuning_info.amp, **tuning_info.grad_scaler_configuration) + self._info("Grad scaler created.") + if self._lora: + apply_lora(self._model) + if self._kadapter: + apply_kadapter(self._model) + + criterion_func = getattr(torch.nn, tuning_info.torch_criterion) + self._criterion = criterion_func(**tuning_info.criterion_dict) + + self._batch_size = tuning_info.batch_size + self._num_dataloaders = tuning_info.num_dataloaders + + self._label_transformer_function = deserialize_function( + tuning_info.label_transformer, LABEL_TRANSFORMER_FUNC_NAME + ) + + self._device = tuning_info.device + self._device_type = "cuda" if "cuda" in self._device else "cpu" + self._amp = tuning_info.amp + + self._measure_gpu_ops = tuning_info.enable_accurate_gpu_measurements + + self.epochs_per_trigger = tuning_info.epochs + + self._drop_last_batch = tuning_info.drop_last_batch + self._dataset_log_path = pathlib.Path(tempfile.mkdtemp(prefix=f"pl{self.pipeline_id}")) + self._log_file_path = tuning_info.log_file_path + if self._log_file_path is not None: + assert isinstance(self._log_file_path, pathlib.Path) + self._log_file_path.unlink(missing_ok=True) + else: + logger.warning("Log file path is None.") + + self._log: dict[str, Any] = {} + + self._num_samples = 0 + + self._expected_num_batches = -1 + self._expected_num_epochs = -1 + + self._step_lr_every: str | None = None + self._setup_lr_scheduler(tuning_info) + + self._info("LR scheduler created.") + + # setup dataloaders + self._info("Setting up data loaders.") + + self._tuning_info = tuning_info + + # ---------------------------------------------------------------------------------------------------------------- # + # Core training pipeline orchestration # + # ---------------------------------------------------------------------------------------------------------------- # + def _prepare_dataloader( + self, + tuning_info: TuningInfo, + ) -> torch.utils.data.DataLoader: + dataset = EvaluationDataset(
tuning_info.dataset_id, + tuning_info.bytes_parser, + tuning_info.transform_list, + self._storage_address, + tuning_info.evaluation_id, + tuning_info.tokenizer, + ) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=tuning_info.batch_size, + num_workers=tuning_info.num_dataloaders, + timeout=60 if tuning_info.num_dataloaders > 0 else 0, + ) + + return dataloader + + def train(self) -> None: + """Performs light tuning for a few steps before evaluation.""" + self._info(f"Process {os.getpid()} starts light tuning") + _train_dataloader = self._prepare_dataloader(self._tuning_info) + stopw = Stopwatch() + model_device = self._model.model.device + self._model.model.to(self._device) + stopw.start("TotalLightTuning") + self._model.model.train() + + for step, batch in enumerate(_train_dataloader): + if step >= self._light_tuning_steps: + break # Stop after defined steps + + stopw.start("FetchBatch", resume=True) + sample_ids, target, data = self.preprocess_batch(batch, stopw) + stopw.stop("FetchBatch") + for _, optimizer in self._optimizers.items(): + optimizer.zero_grad() + + # Forward pass + with torch.autocast(self._device_type, enabled=self._amp): + stopw.start("Forward", resume=True) + + if self.generative: + output = self._model.model(data) + output = output[..., :-1, :] # Ignore last token prediction + target = data[..., 1:, 0] # Shift target labels + + output = output.reshape(-1, output.size(-1)) + target = target.reshape(-1) + + target[target == 50256] = -100 # Mask padding tokens for GPT-style models + + else: + output = self._model.model(data, sample_ids=sample_ids) + + stopw.stop("Forward") + + # Compute loss + stopw.start("Loss", resume=True) + loss = self._criterion(output, target) + + stopw.stop("Loss") + + # Backward pass and optimizer step + stopw.start("Backward", resume=True) + self._scaler.scale(loss).backward() + if self._grad_norm is not None: + torch.nn.utils.clip_grad_norm_(self._model.model.parameters(), max_norm=self._grad_norm) + stopw.stop("Backward") + + stopw.start("OptimizerStep", resume=True) + for _, optimizer in self._optimizers.items(): + self._scaler.step(optimizer) + self._scaler.update() + self._step_lr_if_necessary(True) + stopw.stop("OptimizerStep") + + # Log loss + + stopw.stop("TotalLightTuning") + self._info(f"Light tuning complete! 
Total time: {stopw.measurements.get('TotalLightTuning', 0)} seconds") + self._model.model.to(model_device) + + # ---------------------------------------------------------------------------------------------------------------- # + # Training stages # + # ---------------------------------------------------------------------------------------------------------------- # + + def preprocess_batch( + self, batch: tuple, stopw: Stopwatch | None = None + ) -> tuple[list, torch.Tensor, torch.Tensor | dict]: + if stopw is None: + stopw = Stopwatch() + + stopw.start("PreprocSampleIDs", resume=True) + sample_ids = batch[0] + if isinstance(sample_ids, torch.Tensor): + sample_ids = sample_ids.tolist() + elif isinstance(sample_ids, tuple): + sample_ids = list(sample_ids) + assert isinstance(sample_ids, list), "Cannot parse result from DataLoader" + stopw.stop("PreprocSampleIDs") + if self.generative: + target = None + else: + stopw.start("LabelTransform", resume=True) + if self._label_transformer_function is not None: + target = self._label_transformer_function(batch[2]) + else: + target = batch[2] + stopw.stop("LabelTransform") + + with GPUMeasurement(self._measure_gpu_ops, "MoveLabelToGPU", self._device, stopw, resume=True): + target = target.to(self._device) + + with GPUMeasurement(self._measure_gpu_ops, "MoveDataToGPU", self._device, stopw, resume=True): + data: torch.Tensor | dict + if isinstance(batch[1], torch.Tensor): + data = batch[1].to(self._device) + elif isinstance(batch[1], dict): + data: dict[str, torch.Tensor] = {} # type: ignore[no-redef] + for name, tensor in batch[1].items(): + data[name] = tensor.to(self._device) + else: + raise ValueError( + "The format of the data provided is not supported in modyn. " + "Please use either torch tensors or dict[str, torch.Tensor]" + ) + return sample_ids, target, data + + def _step_lr_if_necessary(self, is_batch: bool) -> None: + if self._lr_scheduler is None: + return + assert self._step_lr_every is not None # for mypy + + if is_batch and self._step_lr_every == "batch": + self._lr_scheduler.step() + + if not is_batch and self._step_lr_every == "epoch": + self._lr_scheduler.step() + + # ------------------------------------------------------ IO ------------------------------------------------------ # + + def save_state(self, destination: pathlib.Path | io.BytesIO, iteration: int | None = None) -> None: + dict_to_save = {} + dict_to_save["model"] = self._model.state_dict() + for optimizer_name, optimizer in self._optimizers.items(): + dict_to_save[f"optimizer-{optimizer_name}"] = optimizer.state_dict() + + if iteration is not None: + dict_to_save["iteration"] = iteration + print(destination) + torch.save(dict_to_save, destination) + + def _setup_optimizers(self, tuning_info: TuningInfo) -> None: + self._optimizers = {} + for name, optimizer_config in tuning_info.torch_optimizers_configuration.items(): + if optimizer_config["source"] == "PyTorch": + optimizer_func = getattr(torch.optim, optimizer_config["algorithm"]) + elif optimizer_config["source"] == "APEX": + if package_available_and_can_be_imported("apex"): + import apex # pylint: disable=import-outside-toplevel, import-error + + optimizer_func = getattr(apex.optimizers, optimizer_config["algorithm"]) + else: + raise ValueError("Apex Optimizer defined, but apex is not available in the system") + elif optimizer_config["source"] == "HuggingFace": + optimizer_func = getattr(transformers, optimizer_config["algorithm"]) + else: + raise ValueError( + f"Unsupported optimizer from 
{optimizer_config['source']}. PyTorch and APEX are supported" + ) + optimizer_config_list = [] + for param_group in optimizer_config["param_groups"]: + module = param_group["module"] + + if optimizer_config["algorithm"] == "Adafactor": # Check if optimizer is Adafactor + no_decay = ["bias", "LayerNorm.weight"] + + # Create separate parameter group dictionaries + param_group_no_decay = copy.deepcopy(param_group["config"]) + param_group_decay = copy.deepcopy(param_group["config"]) + + param_group_decay["params"] = [ + p + for n, p in eval(f"self._model.{module}.named_parameters()") # pylint: disable=eval-used + if p.requires_grad and not any(m in n for m in no_decay) + ] + param_group_decay["weight_decay"] = 0.01 + optimizer_config_list.append(param_group_decay) + + param_group_no_decay["params"] = [ + p + for n, p in eval(f"self._model.{module}.named_parameters()") # pylint: disable=eval-used + if p.requires_grad and any(m in n for m in no_decay) + ] + param_group_no_decay["weight_decay"] = 0.0 + optimizer_config_list.append(param_group_no_decay) + + else: + param_group["config"]["params"] = [ + p + for p in eval(f"self._model.{module}.parameters()") # pylint: disable=eval-used + if p.requires_grad + ] + + optimizer_config_list.append(param_group["config"]) + self._optimizers[name] = optimizer_func(optimizer_config_list) + + def _update_lr_config_dict(self, lr_scheduler_config: dict[str, Any]) -> dict[str, Any]: + for key, value in lr_scheduler_config.items(): + if isinstance(value, dict): + self._update_lr_config_dict(value) + elif value == "MODYN_NUM_BATCHES": + lr_scheduler_config[key] = self._expected_num_batches + elif value == "MODYN_NUM_EPOCHS": + lr_scheduler_config[key] = self._expected_num_epochs + + return lr_scheduler_config + + def _setup_lr_scheduler(self, tuning_info: TuningInfo) -> None: + self._lr_scheduler = None + if tuning_info.lr_scheduler: + self._step_lr_every = tuning_info.lr_scheduler["step_every"] + + config_dict = self._update_lr_config_dict(tuning_info.lr_scheduler["config"]) + + if tuning_info.lr_scheduler["source"] == "Custom": + lr_scheduler_module = dynamic_module_import("modyn.trainer_server.custom_lr_schedulers") + custom_lr_scheduler = getattr(lr_scheduler_module, tuning_info.lr_scheduler["name"]) + optimizers = [self._optimizers[opt] for opt in tuning_info.lr_scheduler["optimizers"]] + self._lr_scheduler = custom_lr_scheduler(optimizers, config_dict) + elif tuning_info.lr_scheduler["source"] == "PyTorch": + torch_lr_scheduler = getattr(torch.optim.lr_scheduler, tuning_info.lr_scheduler["name"]) + if len(tuning_info.lr_scheduler["optimizers"]) > 1: + self._warning("Provided a LR scheduler from PyTorch, but multiple optimizers") + self._lr_scheduler = torch_lr_scheduler( + self._optimizers[tuning_info.lr_scheduler["optimizers"][0]], + **config_dict, + ) + else: + raise ValueError( + f"Unsupported LR scheduler of source {tuning_info.lr_scheduler['source']}." 
+ "PyTorch and Custom are supported" + ) + + def _seed_trainer_server(self, seed: int) -> None: + if not (0 <= seed <= 100 and isinstance(seed, int)): + raise ValueError("The seed must be an integer in the range [0,100]") + # seed the trainer server + seed_everything(seed) + + # ---------------------------------------------------- Logging --------------------------------------------------- # + + def _info(self, msg: str) -> None: + self.logger.info(f"[Training {self._evaluation_id}][PL {self.pipeline_id}] {msg}") + + def _warning(self, msg: str) -> None: + self.logger.warning(f"[Training {self._evaluation_id}][PL {self.pipeline_id}] {msg}") + + def _error(self, msg: str) -> None: + self.logger.error(f"[Training {self._evaluation_id}][PL {self.pipeline_id}] {msg}") + + def _load_dataset_log(self) -> None: + worker_log = {} + for filename in glob.glob(str(self._dataset_log_path / "*.log")): + filepath = pathlib.Path(filename) + key = filepath.stem + + with open(self._dataset_log_path / filename, encoding="utf-8") as logfile: + worker_log[key] = json.load(logfile) + + self._log["dataset_worker_log"] = worker_log + + try: + if self._dataset_log_path.exists(): + shutil.rmtree(self._dataset_log_path) + except OSError as exp: + self._error("Error while deleting OnlineDataset logging directory.") + self._error(str(exp)) + + # -------------------------------------------------- Assertions -------------------------------------------------- # + + @staticmethod + def _assert_data_size( + expected_size: int, data: torch.Tensor | dict[Any, torch.Tensor], sample_ids: list, target: torch.Tensor + ) -> None: + assert ( + all(tensor.shape[0] == expected_size for tensor in data.values()) + if isinstance(data, dict) + else data.shape[0] == expected_size + ), ( + f"expected size: {expected_size}, actual size: " + + f"{data.shape[0] if isinstance(data, torch.Tensor) else 'n/a'}" + ) + assert len(sample_ids) == expected_size, f"expected size: {expected_size}, actual size: {len(sample_ids)}" + assert target.shape[0] == expected_size, f"expected size: {expected_size}, actual size: {target.shape[0]}" + + def _assert_training_size(self, epoch: int, trained_batches: int) -> None: + if self._lr_scheduler is not None: + assert self._expected_num_epochs == epoch + 1, ( + f"Something went wrong! We expected {self._expected_num_epochs}, but trained for {epoch + 1} epochs!" + + "\nWe fail since we trained using a LR scheduler that might depend on this." + ) + assert self._expected_num_batches == trained_batches, ( + f"Something went wrong! We expected to train on {self._expected_num_batches}," + + f" but trained for {trained_batches} batches!" + + "\nWe fail since we trained using a LR scheduler that might depend on this." + ) + else: + if self._expected_num_epochs != epoch + 1 or self._expected_num_batches != trained_batches: + self._error( + "Inconsistent expected batches. Not failing since no lr scheduler was used.\n" + + f" We expected {self._expected_num_epochs}, but trained for {epoch + 1} epochs!\n" + + f"We expected to train on {self._expected_num_batches}," + + f" but trained for {trained_batches} batches!" 
+ ) diff --git a/modyn/evaluator/internal/utils/__init__.py b/modyn/evaluator/internal/utils/__init__.py index 38eda5259..796041e79 100644 --- a/modyn/evaluator/internal/utils/__init__.py +++ b/modyn/evaluator/internal/utils/__init__.py @@ -9,6 +9,7 @@ from .evaluation_info import EvaluationInfo # noqa: F401 from .evaluation_process_info import EvaluationProcessInfo # noqa: F401 from .evaluator_messages import EvaluatorMessages # noqa: F401 +from .tuning_info import TuningInfo # noqa: F401 files = os.listdir(os.path.dirname(__file__)) files.remove("__init__.py") diff --git a/modyn/evaluator/internal/utils/evaluation_info.py b/modyn/evaluator/internal/utils/evaluation_info.py index dd17cb0aa..747f88e27 100644 --- a/modyn/evaluator/internal/utils/evaluation_info.py +++ b/modyn/evaluator/internal/utils/evaluation_info.py @@ -53,3 +53,5 @@ def __init__( self.evaluation_id = evaluation_id self.storage_address = storage_address self.model_path = model_path + self.light_tuning = request.light_tuning + self.tuning_info = json.loads(request.tuning_config) if request.HasField("tuning_config") else None diff --git a/modyn/evaluator/internal/utils/tuning_info.py b/modyn/evaluator/internal/utils/tuning_info.py new file mode 100644 index 000000000..0282c0a2b --- /dev/null +++ b/modyn/evaluator/internal/utils/tuning_info.py @@ -0,0 +1,55 @@ +import json +import logging +import pathlib +from typing import Any + +# pylint: disable=no-name-in-module +# from modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 import StartTrainingtuning_info + +logger = logging.getLogger(__name__) + + +class TuningInfo: + # pylint: disable=too-many-instance-attributes + + def __init__( + self, + tuning_info: Any, + evaluation_id: int, + offline_dataset_path: str, + log_file_path: pathlib.Path, + ) -> None: + self.pipeline_id = tuning_info.pipeline_id + + self.evaluation_id = evaluation_id + self.device = tuning_info.device + self.dataset_id = tuning_info.data_info.dataset_id + self.num_dataloaders = tuning_info.data_info.num_dataloaders + self.epochs = tuning_info.epochs + self.num_samples_to_pass = tuning_info.num_samples_to_pass + + self.torch_optimizers_configuration = json.loads(tuning_info.torch_optimizers_configuration.value) + self.criterion_dict = json.loads(tuning_info.criterion_parameters.value) + self.grad_scaler_configuration = json.loads(tuning_info.grad_scaler_configuration.value) + + self.transform_list = list(tuning_info.transform_list) + self.bytes_parser = tuning_info.bytes_parser.value + self.label_transformer = tuning_info.label_transformer.value + + self.log_file_path = log_file_path + self.shuffle = tuning_info.shuffle + self.enable_accurate_gpu_measurements = tuning_info.enable_accurate_gpu_measurements + self.generative = tuning_info.generative + self.steps = tuning_info.steps + self.batch_size = tuning_info.batch_size + self.drop_last_batch = tuning_info.drop_last_batch + self.torch_criterion = tuning_info.torch_criterion + self.amp = tuning_info.amp + + self.lr_scheduler = json.loads(tuning_info.lr_scheduler.value) + + self.record_loss_every = tuning_info.record_loss_every + self.seed: int | None = tuning_info.seed if tuning_info.seed is not None else None + self.tokenizer: str | None = tuning_info.tokenizer.value if tuning_info.tokenizer is not None else None + + self.offline_dataset_path = offline_dataset_path diff --git a/modyn/models/__init__.py b/modyn/models/__init__.py index 8d92f68e8..97eb6c933 100644 --- a/modyn/models/__init__.py +++ b/modyn/models/__init__.py @@ -6,6 +6,9 @@ 
from .dlrm.dlrm import DLRM # noqa: F401 from .dummy.dummy import Dummy # noqa: F401 from .fmownet.fmownet import FmowNet # noqa: F401 +from .gpt2.gpt2 import Gpt2 # noqa: F401 +from .modular_adapters import modular_adapters # noqa: F401 +from .modular_adapters.modular_adapters import apply_kadapter, apply_lora # noqa: F401 from .resnet18.resnet18 import ResNet18 # noqa: F401 from .resnet50.resnet50 import ResNet50 # noqa: F401 from .resnet152.resnet152 import ResNet152 # noqa: F401 diff --git a/modyn/models/gpt2/_init_.py b/modyn/models/gpt2/_init_.py new file mode 100644 index 000000000..f5987b52c --- /dev/null +++ b/modyn/models/gpt2/_init_.py @@ -0,0 +1,5 @@ +import os + +files = os.listdir(os.path.dirname(__file__)) +files.remove("__init__.py") +__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/models/gpt2/gpt2.py b/modyn/models/gpt2/gpt2.py new file mode 100644 index 000000000..198a9c198 --- /dev/null +++ b/modyn/models/gpt2/gpt2.py @@ -0,0 +1,103 @@ +from typing import Any + +import torch +from torch import nn +from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel, GPT2Model + +from modyn.models.coreset_methods_support import CoresetSupportingModule + + +class Gpt2: + # pylint: disable-next=unused-argument + def __init__(self, hparams: Any, device: str, amp: bool) -> None: + self.model = Gpt2Modyn(hparams) + self.model.to(device) + + +""" +Adapted from an example implementation of a GPT-2 model. +This implementation uses the GPT-2 tokenizer from Hugging Face's Transformers library: +https://huggingface.co/docs/transformers/model_doc/gpt2 +""" + + +class Gpt2Modyn(CoresetSupportingModule): + def __init__(self, hparams: Any) -> None: + super().__init__() + + # Use hparams to decide the GPT-2 version + model_name = hparams.model_name_or_path if hasattr(hparams, "model_name_or_path") else "gpt2-large" + + # Assert that the model name is valid + valid_model_names = {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"} + assert model_name in valid_model_names, f"Invalid model name: {model_name}. Must be one of {valid_model_names}." + self.tokenizer = AutoTokenizer.from_pretrained("gpt2-large") + self.tokenizer.pad_token = self.tokenizer.eos_token + # Load the specified GPT-2 model + + self.model = GPT2LMHeadModel.from_pretrained(model_name) + self.config = GPT2Config.from_pretrained(model_name) + self.transformer = GPT2Model(self.config) + + def forward(self, data: torch.Tensor, labels: torch.Tensor = None) -> torch.Tensor: + """Forward method for text generation or language modeling tasks. + + Args: + - data (torch.Tensor): Tensor of shape (batch_size, seq_len, 2), where + the last dimension contains token IDs and attention masks. + - labels (torch.Tensor, optional): Tensor of labels for language modeling tasks. + + Returns: + - output: The output logits or loss from the GPT-2 model. + """ + # Split input into token IDs and attention masks + input_ids = data[:, :, 0] + + attention_mask = data[:, :, 1] + # Forward pass through GPT-2 + + output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + + return output + + def get_last_layer(self) -> nn.Module: + """Retrieve the last layer (lm_head) of the model. + + Returns: + The final linear layer of the GPT-2 model. 
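+        Note: exposing lm_head through get_last_layer is intended to let the
+        CoresetSupportingModule-based selection utilities, which operate on a model's
+        final layer, treat this text model like the other Modyn models.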
+ """ + return self.model.lm_head + + def freeze_params(self) -> None: + for par in self.model.parameters(): + par.requires_grad = False + for par in self.transformer.parameters(): + par.requires_grad = False + + def unfreeze_params(self) -> None: + for par in self.model.parameters(): + par.requires_grad = True + for par in self.transformer.parameters(): + par.requires_grad = True + + def generate( + self, + input_ids: torch.tensor, + max_length: int = 50, + temperature: float = 1.0, + top_k: int = 50, + top_p: float = 0.95, + num_return_sequences: int = 1, + ) -> list: + # Generate output sequences + outputs = self.model.generate( + input_ids=input_ids, + max_length=max_length, + temperature=temperature, + top_k=top_k, + top_p=top_p, + num_return_sequences=num_return_sequences, + pad_token_id=self.tokenizer.eos_token_id, + ) + + return [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs] diff --git a/modyn/models/gpt2/others.py b/modyn/models/gpt2/others.py new file mode 100644 index 000000000..d722b35c6 --- /dev/null +++ b/modyn/models/gpt2/others.py @@ -0,0 +1,125 @@ +import random +from typing import Any + +import torch +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + RobertaForMaskedLM, + T5ForConditionalGeneration, +) + +from modyn.models.coreset_methods_support import CoresetSupportingModule + + +class Roberta: + """Adapted from an example implementation of a RoBERTa model for masked language modeling.""" + + def __init__(self, hparams: Any, device: str, amp: bool) -> None: + self.model = RobertaModyn(hparams) + self.model.to(device) + + +class RobertaModyn(CoresetSupportingModule): + def __init__(self, hparams: Any) -> None: + super().__init__() + self.model = RobertaForMaskedLM.from_pretrained("roberta-large") + self.tokenizer = AutoTokenizer.from_pretrained("roberta-large") + self.mask_probability = 0.15 + self.freeze_params() + + def freeze_params(self) -> None: + for par in self.model.parameters(): + par.requires_grad = False + + def unfreeze_params(self) -> None: + for par in self.model.parameters(): + par.requires_grad = True + + def mask_input(self, input_ids: torch.Tensor) -> torch.Tensor: + masked_input_ids = input_ids.clone() + for i in range(masked_input_ids.size(1)): + if random.random() < self.mask_probability and masked_input_ids[0, i] not in [ + self.tokenizer.cls_token_id, + self.tokenizer.sep_token_id, + ]: + masked_input_ids[0, i] = self.tokenizer.mask_token_id + return masked_input_ids + + def forward(self, data: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + input_ids = data[:, :, 0] + attention_mask = data[:, :, 1] + masked_input_ids = self.mask_input(input_ids) + output = self.model(input_ids=masked_input_ids, attention_mask=attention_mask, labels=labels) + return output.logits + + +class T5: + """Adapted from an example implementation of a T5 model for sequence-to-sequence tasks.""" + + def __init__(self, hparams: Any, device: str, amp: bool) -> None: + self.model = T5Modyn(hparams) + self.model.to(device) + + +class T5Modyn(CoresetSupportingModule): + def __init__(self, hparams: Any) -> None: + super().__init__() + self.model = T5ForConditionalGeneration.from_pretrained("t5-large") + self.tokenizer = AutoTokenizer.from_pretrained("t5-large") + self.freeze_params() + + def freeze_params(self) -> None: + for par in self.model.parameters(): + par.requires_grad = False + + def unfreeze_params(self) -> None: + for par in self.model.parameters(): + par.requires_grad = True + + def forward(self, data: 
torch.Tensor, labels: torch.Tensor = None) -> torch.Tensor: + input_ids = data[:, :, 0] + attention_mask = data[:, :, 1] + decoder_input_ids = labels if labels is not None else torch.full_like(input_ids, self.tokenizer.pad_token_id) + output = self.model(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids) + return output.logits + + def generate(self, prompt: str, max_length: int = 50) -> str: + inputs = self.tokenizer(prompt, return_tensors="pt", padding=True).to(self.device) + output = self.model.generate(**inputs, max_length=max_length) + return self.tokenizer.decode(output[0], skip_special_tokens=True) + + +class LLaMA: + """Adapted from an example implementation of a LLaMA model for causal language modeling.""" + + def __init__(self, hparams: Any, device: str, amp: bool) -> None: + self.model = LLaMAModyn(hparams) + self.model.to(device) + + +class LLaMAModyn(CoresetSupportingModule): + def __init__(self, hparams: Any) -> None: + super().__init__() + self.model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-1B") + self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-1B") + self.freeze_params() + + def freeze_params(self) -> None: + for par in self.model.parameters(): + par.requires_grad = False + + def unfreeze_params(self) -> None: + for par in self.model.parameters(): + par.requires_grad = True + + def forward(self, data: torch.Tensor, labels: torch.Tensor = None) -> torch.Tensor: + input_ids = data[:, :, 0] + attention_mask = data[:, :, 1] + output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + return output.logits + + def generate(self, prompt: str, max_length: int = 50) -> str: + inputs = self.tokenizer(prompt, return_tensors="pt", padding=True).to(self.device) + output = self.model.generate(**inputs, max_length=max_length) + return self.tokenizer.decode(output[0], skip_special_tokens=True) diff --git a/modyn/trainer_server/internal/grpc/generated/__init__.py b/modyn/models/modular_adapters/__init__.py similarity index 100% rename from modyn/trainer_server/internal/grpc/generated/__init__.py rename to modyn/models/modular_adapters/__init__.py diff --git a/modyn/models/modular_adapters/modular_adapters.py b/modyn/models/modular_adapters/modular_adapters.py new file mode 100644 index 000000000..99de54e13 --- /dev/null +++ b/modyn/models/modular_adapters/modular_adapters.py @@ -0,0 +1,170 @@ +import types +from typing import Any + +import torch +from peft import LoraConfig, TaskType, get_peft_model +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers import GPT2Config +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + +# ============================================================================= +# Custom Adapter Model (KAdapter) defined manually +# ============================================================================= +# You must also have GPT2Block imported from your GPT-2 implementation. 
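# Illustrative usage (a minimal sketch, assuming the Gpt2Modyn wrapper from
# modyn/models/gpt2/gpt2.py with its `.model`, `.transformer` and `.config` attributes,
# and a (batch_size, seq_len, 2) input of stacked input_ids/attention_mask as produced by
# the HF tokenizer transforms added in this diff; `hparams` and `texts` are placeholders):
#
#     tokenized_batch = torch.stack([GPT2TokenizerTransform()(s) for s in texts])  # (B, 512, 2)
#     model = Gpt2Modyn(hparams)
#     model = apply_lora(model, adapter_dim=16, adapter_alpha=32)  # PEFT LoRA on c_attn/c_proj
#     # ...or, alternatively, attach the KAdapter together with its replacement forward:
#     model = apply_kadapter(model)
#     output = model(tokenized_batch, labels=labels)               # output.loss / output.logits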
+# For this example, we assume it’s available as follows: +from transformers.models.gpt2.modeling_gpt2 import GPT2Block + + +class AdapterModel(nn.Module): + def __init__(self, pretrained_config: GPT2Config | None = None) -> None: + self.config = pretrained_config + if self.config is None: + self.config = GPT2Config.from_pretrained("gpt2-large") + super().__init__() + + self.embed_dim = self.config.hidden_size + # Define which layers to pull hidden states from + self.adapter_list = [1, 11] # For example, use layers 1 and 11 + self.adapter_num = len(self.adapter_list) + self.layer_norm = nn.LayerNorm(self.embed_dim, eps=self.config.layer_norm_epsilon) + # Create an adapter (here using GPT2Block) for each designated layer + self.adapter = nn.ModuleList([GPT2Block(self.config) for _ in range(self.adapter_num)]) + + def forward(self, pretrained_model_outputs: CausalLMOutputWithCrossAttentions) -> torch.Tensor: + # Assume pretrained_model_outputs is a ModelOutput with hidden_states and [0]=final hidden state. + sequence_output = pretrained_model_outputs[0] + hidden_states = pretrained_model_outputs.hidden_states + # Determine device from sequence_output + device = sequence_output.device + hidden_states_last = torch.zeros(sequence_output.size(), device=device) + + for i, adapter_module in enumerate(self.adapter): + # Get hidden state from the designated layer + pretrained_hidden_state = hidden_states[self.adapter_list[i]] + # Fuse with previously adapter-processed output + + fusion_state = pretrained_hidden_state + hidden_states_last + hidden_states_last = adapter_module(fusion_state)[0] + + scale_factor = 0.1 + # Fuse adapter output (after normalization/scaling) with the final hidden state + outputs = (scale_factor * self.layer_norm(hidden_states_last)) + sequence_output + return outputs + + +# ============================================================================= +# LoRA layer & apply_lora remains unchanged +# ============================================================================= + + +def apply_lora( + model: nn.Module, target_modules: list[str] | None = None, adapter_dim: int = 16, adapter_alpha: int = 32 +) -> nn.Module: + # Use default target modules for GPT-2 if not provided. + if target_modules is None: + target_modules = ["c_attn", "c_proj"] + + # Create a LoRA configuration. + # Here, `r` corresponds to the adapter (low-rank) dimension. + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=adapter_dim, + lora_alpha=adapter_alpha, + lora_dropout=0.0, + target_modules=target_modules, + ) + + # Wrap the model with the LoRA adapter. + # This call freezes the base parameters and adds trainable LoRA parameters. + model.model = get_peft_model(model.model, lora_config) + return model + + +def apply_kadapter( + model: nn.Module, +) -> nn.Module: + # Freeze base parameters + if hasattr(model, "freeze_params"): + model.freeze_params() + else: + for param in model.parameters(): + param.requires_grad = False + + # Ensure the model has a 'config' attribute + + # (Optional) Print some config details for debugging + + # Attach the custom adapter; AdconfigapterModel should be defined to use model.config.hidden_size, etc. + model.kadapter = AdapterModel(model.config) + + for param in model.kadapter.parameters(): + param.requires_grad = True + + # Define a new forward that integrates the adapter output. 
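    # The replacement forward defined below: (1) splits the packed (batch, seq_len, 2) input
    # into input_ids and attention_mask, (2) runs the (frozen) GPT-2 transformer with hidden
    # states exposed so the adapter can read the layers listed in adapter_list, (3) fuses the
    # adapter output with the final hidden state via AdapterModel, and (4) applies the base
    # model's lm_head and, when labels are given, the usual shift-by-one causal LM loss.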
+ def forward_with_adapter( + self: nn.Module, + data: torch.Tensor, + past_key_values: Any | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + head_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool = True, + return_dict: bool | None = None, + **kwargs: Any, + ) -> Any: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + input_ids = data[:, :, 0] + + attention_mask = data[:, :, 1] + model_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + # If hidden states are returned, pass them through the adapter. + if model_outputs.hidden_states is not None: + hidden_states = self.kadapter(model_outputs) + else: + hidden_states = model_outputs[0] + + lm_logits = self.model.lm_head(hidden_states) + loss = None + if labels is not None: + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss = CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + if not return_dict: + output = (lm_logits,) + model_outputs[1:] + return ((loss,) + output) if loss is not None else output + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=model_outputs.past_key_values, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + cross_attentions=model_outputs.cross_attentions, + ) + + # Replace the model's forward with the new forward + model.forward = types.MethodType(forward_with_adapter, model) + return model diff --git a/modyn/models/tokenizers/__init__.py b/modyn/models/tokenizers/__init__.py index 93e8d9c8f..c8e98fd02 100644 --- a/modyn/models/tokenizers/__init__.py +++ b/modyn/models/tokenizers/__init__.py @@ -1,8 +1,10 @@ -"""Bert Tokenizer for NLP tasks.""" +"""Tokenizer for NLP tasks.""" import os from .distill_bert_tokenizer import DistilBertTokenizerTransform # noqa: F401 +from .gpt2_tokenizer import GPT2TokenizerTransform # noqa: F401 +from .hf_tokenizer import HFTokenizerTransform # noqa: F401 files = os.listdir(os.path.dirname(__file__)) files.remove("__init__.py") diff --git a/modyn/models/tokenizers/distill_bert_tokenizer.py b/modyn/models/tokenizers/distill_bert_tokenizer.py index ece976992..50bdc3c99 100644 --- a/modyn/models/tokenizers/distill_bert_tokenizer.py +++ b/modyn/models/tokenizers/distill_bert_tokenizer.py @@ -1,24 +1,14 @@ -import torch from transformers import DistilBertTokenizer +from .hf_tokenizer import HFTokenizerTransform -class DistilBertTokenizerTransform: - """ - Adapted from WildTime's initialize_distilbert_transform - Here you can find the original implementation: - https://github.com/huaxiuyao/Wild-Time/blob/main/wildtime/data/utils.py - """ - def __init__(self, max_token_length: int = 300) -> None: - 
self.max_token_length = max_token_length - self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") - - def __call__(self, sample: str) -> torch.Tensor: - # make the class Callable to use it as Torch Transform - tokens = self.tokenizer( - sample, padding="max_length", truncation=True, max_length=self.max_token_length, return_tensors="pt" - ) - # create a tensor whose first dimension is the input_ids and the second is the attention_mask - data = torch.stack((tokens["input_ids"], tokens["attention_mask"]), dim=2) - data = torch.squeeze(data, dim=0) # First shape dim is always 1, since the input is just one string - return data +class DistilBertTokenizerTransform(HFTokenizerTransform): + def __init__(self, max_token_length: int = 300): + """ + Adapted from WildTime's initialize_distilbert_transform + Here you can find the original implementation: + https://github.com/huaxiuyao/Wild-Time/blob/main/wildtime/data/utils.py + """ + tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") + super().__init__(tokenizer, max_token_length) diff --git a/modyn/models/tokenizers/gpt2_tokenizer.py b/modyn/models/tokenizers/gpt2_tokenizer.py new file mode 100644 index 000000000..5aab2a64f --- /dev/null +++ b/modyn/models/tokenizers/gpt2_tokenizer.py @@ -0,0 +1,17 @@ +from transformers import GPT2Tokenizer + +from .hf_tokenizer import HFTokenizerTransform + + +class GPT2TokenizerTransform(HFTokenizerTransform): + def __init__(self, max_token_length: int = 512): + """Adapted from an example implementation of a GPT-2 tokenizer. + + This implementation uses the GPT-2 tokenizer from Hugging Face's + Transformers library: + https://huggingface.co/docs/transformers/model_doc/gpt2 + """ + tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large") + tokenizer.pad_token = tokenizer.eos_token # Set pad token to eos token to avoid padding errors + tokenizer.padding_side = "right" + super().__init__(tokenizer, max_token_length) diff --git a/modyn/models/tokenizers/hf_tokenizer.py b/modyn/models/tokenizers/hf_tokenizer.py new file mode 100644 index 000000000..620d35585 --- /dev/null +++ b/modyn/models/tokenizers/hf_tokenizer.py @@ -0,0 +1,31 @@ +import torch +from transformers import PreTrainedTokenizer + + +class HFTokenizerTransform: + def __init__(self, tokenizer: PreTrainedTokenizer, max_token_length: int) -> None: + """Parent class for tokenizers based on HuggingFace's Transformers. + + Args: + tokenizer: Preloaded tokenizer object. + max_token_length: Maximum length for tokenization. + """ + self.max_token_length = max_token_length + self.tokenizer = tokenizer + + def __call__(self, sample: str) -> torch.Tensor: + """ + Tokenize the input sample and return a tensor with input_ids and attention_mask. + Args: + sample: Input string to tokenize. + Returns: + A torch.Tensor with shape (max_token_length, 2), where: + - dim 0 is the token sequence length. + - dim 1 contains input_ids and attention_mask. 
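+        Example (illustrative; uses the GPT2TokenizerTransform subclass added in this diff,
+        whose default max_token_length is 512):
+            transform = GPT2TokenizerTransform()
+            data = transform("some input text")        # shape: (512, 2)
+            input_ids, attention_mask = data[:, 0], data[:, 1]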
+ """ + tokens = self.tokenizer( + sample, padding="max_length", truncation=True, max_length=self.max_token_length, return_tensors="pt" + ) + data = torch.stack((tokens["input_ids"], tokens["attention_mask"]), dim=2) + data = torch.squeeze(data, dim=0) + return data diff --git a/modyn/protos/evaluator.proto b/modyn/protos/evaluator.proto index 426d78de5..951e6e8e2 100644 --- a/modyn/protos/evaluator.proto +++ b/modyn/protos/evaluator.proto @@ -44,6 +44,8 @@ message EvaluateModelRequest { PythonString bytes_parser = 7; PythonString label_transformer = 8; optional PythonString tokenizer = 9; + bool light_tuning = 10; + optional JsonString tuning_config = 11; } message EvaluateModelIntervalResponse { diff --git a/modyn/protos/storage.proto b/modyn/protos/storage.proto index 4b6fc9dde..429ee730d 100644 --- a/modyn/protos/storage.proto +++ b/modyn/protos/storage.proto @@ -4,6 +4,7 @@ package modyn.storage; service Storage { rpc Get(GetRequest) returns (stream GetResponse) {} + rpc GetNL(GetRequest) returns (stream GetResponse) {} rpc GetNewDataSince(GetNewDataSinceRequest) returns (stream GetNewDataSinceResponse) {} rpc GetDataInInterval(GetDataInIntervalRequest) @@ -33,6 +34,7 @@ message GetResponse { repeated int64 labels = 3; } + // https://github.com/grpc/grpc/issues/15937 message GetCurrentTimestampRequest {} diff --git a/modyn/protos/trainer_server.proto b/modyn/protos/trainer_server.proto index 0c7343c70..99ce97c90 100644 --- a/modyn/protos/trainer_server.proto +++ b/modyn/protos/trainer_server.proto @@ -59,6 +59,10 @@ message StartTrainingRequest { bool enable_accurate_gpu_measurements = 25; int64 record_loss_every = 26; bool drop_last_batch = 27; + bool generative = 28; + optional float grad_norm = 29; + bool lora = 30; + bool kadapter = 31; } message StartTrainingResponse { diff --git a/modyn/selector/internal/selector_strategies/presampling_strategies/original_data_strategy.py b/modyn/selector/internal/selector_strategies/presampling_strategies/original_data_strategy.py new file mode 100644 index 000000000..88cd815de --- /dev/null +++ b/modyn/selector/internal/selector_strategies/presampling_strategies/original_data_strategy.py @@ -0,0 +1,70 @@ +from sqlalchemy import Select, asc, func, select + +from modyn.config.schema.pipeline import PresamplingConfig +from modyn.metadata_database.models import SelectorStateMetadata +from modyn.selector.internal.selector_strategies.presampling_strategies.abstract_presampling_strategy import ( + AbstractPresamplingStrategy, +) +from modyn.selector.internal.storage_backend.abstract_storage_backend import AbstractStorageBackend + + +class OriginalSetPresamplingStrategy(AbstractPresamplingStrategy): + def __init__( + self, + presampling_config: PresamplingConfig, + modyn_config: dict, + pipeline_id: int, + storage_backend: AbstractStorageBackend, + ): + super().__init__(presampling_config, modyn_config, pipeline_id, storage_backend) + self.requires_trigger_dataset_size = True + self.first_trigger_ratio = presampling_config.get("first_trigger_ratio", 0.5) + + def get_presampling_query( + self, + next_trigger_id: int, + tail_triggers: int | None, + limit: int | None, + trigger_dataset_size: int | None, + ) -> Select: + """ + - In `trigger_id = 0`: Select **ALL** available data. + - In later triggers (`next_trigger_id > 0`): Select: + 1. A random fraction (`first_trigger_ratio`) of samples from `trigger_id = 0`. + 2. All samples from the current trigger (`next_trigger_id`). 
+ """ + assert trigger_dataset_size is not None + assert trigger_dataset_size >= 0 + + if next_trigger_id == 0: + stmt = ( + select(SelectorStateMetadata.sample_key) + .filter(SelectorStateMetadata.pipeline_id == self.pipeline_id) + .order_by(asc(SelectorStateMetadata.timestamp)) + ) + return stmt + + first_trigger_subq = ( + select(SelectorStateMetadata.sample_key) + .filter( + SelectorStateMetadata.pipeline_id == self.pipeline_id, + SelectorStateMetadata.seen_in_trigger_id == 0, + ) + .order_by(func.random()) # pylint: disable=E1102 + .limit(int(trigger_dataset_size * self.first_trigger_ratio)) + ) + + current_trigger_subq = select(SelectorStateMetadata.sample_key).filter( + SelectorStateMetadata.pipeline_id == self.pipeline_id, + SelectorStateMetadata.seen_in_trigger_id == next_trigger_id, # 🔹 Select ALL from the current trigger + ) + + stmt = ( + select(SelectorStateMetadata.sample_key) + .filter( + SelectorStateMetadata.pipeline_id == self.pipeline_id, + SelectorStateMetadata.sample_key.in_(first_trigger_subq.union(current_trigger_subq)), + ) + .order_by(asc(SelectorStateMetadata.timestamp)) + ) + return stmt diff --git a/modyn/storage/include/internal/file_wrapper/binary_file_wrapper.hpp b/modyn/storage/include/internal/file_wrapper/binary_file_wrapper.hpp index 235442a91..efc3dd451 100644 --- a/modyn/storage/include/internal/file_wrapper/binary_file_wrapper.hpp +++ b/modyn/storage/include/internal/file_wrapper/binary_file_wrapper.hpp @@ -17,10 +17,17 @@ class BinaryFileWrapper : public FileWrapper { : FileWrapper(path, fw_config, std::move(filesystem_wrapper)) { ASSERT(filesystem_wrapper_ != nullptr, "Filesystem wrapper cannot be null."); ASSERT(fw_config["record_size"], "record_size must be specified in the file wrapper config."); - ASSERT(fw_config["label_size"], "label_size be specified in the file wrapper config."); record_size_ = fw_config["record_size"].as(); - label_size_ = fw_config["label_size"].as(); + if (!file_wrapper_config_["has_labels"] || file_wrapper_config_["has_labels"].as()) { + has_labels_ = true; + ASSERT(fw_config["label_size"], "label_size be specified in the file wrapper config."); + label_size_ = fw_config["label_size"].as(); + } else { + has_labels_ = false; + label_size_ = 0; // No labels exist + } + sample_size_ = record_size_ - label_size_; validate_file_extension(); file_size_ = filesystem_wrapper_->get_file_size(path); @@ -36,8 +43,9 @@ class BinaryFileWrapper : public FileWrapper { int64_t get_label(uint64_t index) override; std::vector get_all_labels() override; std::vector get_sample(uint64_t index) override; - std::vector> get_samples(uint64_t start, uint64_t end) override; - std::vector> get_samples_from_indices(const std::vector& indices) override; + std::vector> get_samples(uint64_t start, uint64_t end, bool include_labels) override; + std::vector> get_samples_from_indices(const std::vector& indices, + bool include_labels) override; void validate_file_extension() override; void delete_samples(const std::vector& indices) override; void set_file_path(const std::string& path) override; @@ -63,7 +71,7 @@ class BinaryFileWrapper : public FileWrapper { uint64_t sample_size_; bool little_endian_; std::shared_ptr stream_; - + bool has_labels_; friend class BinaryFileWrapperTest; // let gtest access private members }; } // namespace modyn::storage diff --git a/modyn/storage/include/internal/file_wrapper/csv_file_wrapper.hpp b/modyn/storage/include/internal/file_wrapper/csv_file_wrapper.hpp index 9e95c2e41..4bff74db6 100644 --- 
a/modyn/storage/include/internal/file_wrapper/csv_file_wrapper.hpp +++ b/modyn/storage/include/internal/file_wrapper/csv_file_wrapper.hpp @@ -15,8 +15,15 @@ class CsvFileWrapper : public FileWrapper { CsvFileWrapper(const std::string& path, const YAML::Node& fw_config, std::shared_ptr filesystem_wrapper) : FileWrapper{path, fw_config, std::move(filesystem_wrapper)} { - ASSERT(file_wrapper_config_["label_index"], "Please specify the index of the column that contains the label."); - label_index_ = file_wrapper_config_["label_index"].as(); + if (!file_wrapper_config_["has_labels"] || file_wrapper_config_["has_labels"].as()) { + has_labels_ = true; + ASSERT(file_wrapper_config_["label_index"].as() != -1, + "Please specify the index of the column that contains the label."); + label_index_ = file_wrapper_config_["label_index"].as(); + } else { + has_labels_ = false; + label_index_ = -1; // No labels exist + } if (file_wrapper_config_["separator"]) { separator_ = file_wrapper_config_["separator"].as(); @@ -65,8 +72,9 @@ class CsvFileWrapper : public FileWrapper { int64_t get_label(uint64_t index) override; std::vector get_all_labels() override; std::vector get_sample(uint64_t index) override; - std::vector> get_samples(uint64_t start, uint64_t end) override; - std::vector> get_samples_from_indices(const std::vector& indices) override; + std::vector> get_samples(uint64_t start, uint64_t end, bool include_labels) override; + std::vector> get_samples_from_indices(const std::vector& indices, + bool include_labels) override; void validate_file_extension() override; void delete_samples(const std::vector& indices) override; void set_file_path(const std::string& path) override; @@ -79,5 +87,6 @@ class CsvFileWrapper : public FileWrapper { rapidcsv::Document doc_; rapidcsv::LabelParams label_params_; std::shared_ptr stream_; + bool has_labels_; }; } // namespace modyn::storage diff --git a/modyn/storage/include/internal/file_wrapper/file_wrapper.hpp b/modyn/storage/include/internal/file_wrapper/file_wrapper.hpp index 0c64cc759..7bfb7a5e5 100644 --- a/modyn/storage/include/internal/file_wrapper/file_wrapper.hpp +++ b/modyn/storage/include/internal/file_wrapper/file_wrapper.hpp @@ -20,8 +20,9 @@ class FileWrapper { virtual int64_t get_label(uint64_t index) = 0; virtual std::vector get_all_labels() = 0; virtual std::vector get_sample(uint64_t index) = 0; - virtual std::vector> get_samples(uint64_t start, uint64_t end) = 0; - virtual std::vector> get_samples_from_indices(const std::vector& indices) = 0; + virtual std::vector> get_samples(uint64_t start, uint64_t end, bool include_labels) = 0; + virtual std::vector> get_samples_from_indices(const std::vector& indices, + bool include_labels) = 0; virtual void validate_file_extension() = 0; virtual void delete_samples(const std::vector& indices) = 0; virtual void set_file_path(const std::string& path) = 0; diff --git a/modyn/storage/include/internal/file_wrapper/single_sample_file_wrapper.hpp b/modyn/storage/include/internal/file_wrapper/single_sample_file_wrapper.hpp index 5bf09fcbc..b14082723 100644 --- a/modyn/storage/include/internal/file_wrapper/single_sample_file_wrapper.hpp +++ b/modyn/storage/include/internal/file_wrapper/single_sample_file_wrapper.hpp @@ -17,8 +17,9 @@ class SingleSampleFileWrapper : public FileWrapper { int64_t get_label(uint64_t index) override; std::vector get_all_labels() override; std::vector get_sample(uint64_t index) override; - std::vector> get_samples(uint64_t start, uint64_t end) override; - std::vector> 
get_samples_from_indices(const std::vector& indices) override; + std::vector> get_samples(uint64_t start, uint64_t end, bool include_labels) override; + std::vector> get_samples_from_indices(const std::vector& indices, + bool include_labels) override; void validate_file_extension() override; void delete_samples(const std::vector& indices) override; void set_file_path(const std::string& path) override { file_path_ = path; } diff --git a/modyn/storage/include/internal/grpc/storage_service_impl.hpp b/modyn/storage/include/internal/grpc/storage_service_impl.hpp index 875bf2e53..a4aeb0aac 100644 --- a/modyn/storage/include/internal/grpc/storage_service_impl.hpp +++ b/modyn/storage/include/internal/grpc/storage_service_impl.hpp @@ -49,6 +49,7 @@ struct DatasetData { FilesystemWrapperType filesystem_wrapper_type = FilesystemWrapperType::INVALID_FSW; FileWrapperType file_wrapper_type = FileWrapperType::INVALID_FW; std::string file_wrapper_config; + bool has_labels = true; }; class StorageServiceImpl final : public modyn::storage::Storage::Service { @@ -102,7 +103,7 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { // Check if the dataset exists std::string dataset_name = request->dataset_id(); const DatasetData dataset_data = get_dataset_data(session, dataset_name); - + const bool include_labels = dataset_data.has_labels; SPDLOG_INFO(fmt::format("Received GetRequest for dataset {} (id = {}) with {} keys.", dataset_name, dataset_data.dataset_id, request->keys_size())); @@ -121,7 +122,8 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { request_keys.reserve(keys_size); std::copy(request->keys().begin(), request->keys().end(), std::back_inserter(request_keys)); - send_sample_data_from_keys(writer, request_keys, dataset_data); + // Call the appropriate function based on the include_labels flag + send_sample_data_from_keys(writer, request_keys, dataset_data, include_labels); // sqlite causes memory leaks otherwise if (session.get_backend_name() != "sqlite3" && session.is_connected()) { @@ -154,6 +156,7 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { request->dataset_id(), dataset_id, request_timestamp)); send_file_ids_and_labels(writer, dataset_id, request_timestamp); + } catch (const std::exception& e) { SPDLOG_ERROR("Error in GetNewDataSince: {}", e.what()); return {StatusCode::INTERNAL, fmt::format("Error in GetNewDataSince: {}", e.what())}; @@ -312,7 +315,7 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { template > void send_sample_data_from_keys(WriterT* writer, const std::vector& request_keys, - const DatasetData& dataset_data) { + const DatasetData& dataset_data, bool include_labels = true) { // Create mutex to protect the writer from concurrent writes as this is not supported by gRPC std::mutex writer_mutex; @@ -320,7 +323,8 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { const std::vector::const_iterator begin = request_keys.begin(); // NOLINT (modernize-use-auto) const std::vector::const_iterator end = request_keys.end(); // NOLINT (modernize-use-auto) - get_samples_and_send(begin, end, writer, &writer_mutex, &dataset_data, &config_, sample_batch_size_); + get_samples_and_send(begin, end, writer, &writer_mutex, &dataset_data, &config_, sample_batch_size_, + include_labels); } else { std::vector thread_exceptions(retrieval_threads_); @@ -332,18 +336,19 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { const 
std::vector::const_iterator begin = its_per_thread[thread_id].first; const std::vector::const_iterator end = its_per_thread[thread_id].second; - retrieval_threads_vector[thread_id] = std::thread([thread_id, begin, end, writer, &writer_mutex, &dataset_data, - &thread_exceptions, &exception_mutex, this]() { - try { - get_samples_and_send(begin, end, writer, &writer_mutex, &dataset_data, &config_, - sample_batch_size_); - } catch (const std::exception& e) { - const std::lock_guard lock(exception_mutex); - spdlog::error( - fmt::format("Error in thread {} started by send_sample_data_from_keys: {}", thread_id, e.what())); - thread_exceptions[thread_id] = std::current_exception(); - } - }); + retrieval_threads_vector[thread_id] = + std::thread([thread_id, begin, end, writer, &writer_mutex, &dataset_data, &thread_exceptions, + &exception_mutex, this, include_labels]() { + try { + get_samples_and_send(begin, end, writer, &writer_mutex, &dataset_data, &config_, + sample_batch_size_, include_labels); + } catch (const std::exception& e) { + const std::lock_guard lock(exception_mutex); + spdlog::error( + fmt::format("Error in thread {} started by send_sample_data_from_keys: {}", thread_id, e.what())); + thread_exceptions[thread_id] = std::current_exception(); + } + }); } for (uint64_t thread_id = 0; thread_id < retrieval_threads_; ++thread_id) { @@ -550,7 +555,8 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { template > static void send_sample_data_for_keys_and_file( // NOLINT(readability-function-cognitive-complexity) WriterT* writer, std::mutex& writer_mutex, const std::vector& sample_keys, - const DatasetData& dataset_data, soci::session& session, int64_t /*sample_batch_size*/) { + const DatasetData& dataset_data, soci::session& session, int64_t /*sample_batch_size*/, + bool include_labels = true) { // Note that we currently ignore the sample batch size here, under the assumption that users do not request more // keys than this try { @@ -564,6 +570,7 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { std::vector sample_labels(num_keys); std::vector sample_indices(num_keys); std::vector sample_fileids(num_keys); + const std::string sample_query = fmt::format( "SELECT label, sample_index, file_id FROM samples WHERE dataset_id = :dataset_id AND sample_id IN ({}) ORDER " "BY file_id", @@ -600,10 +607,10 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}", current_file_id, dataset_data.dataset_id)); } + const YAML::Node file_wrapper_config_node = YAML::Load(dataset_data.file_wrapper_config); auto filesystem_wrapper = get_filesystem_wrapper(static_cast(dataset_data.filesystem_wrapper_type)); - auto file_wrapper = get_file_wrapper(current_file_path, static_cast(dataset_data.file_wrapper_type), file_wrapper_config_node, filesystem_wrapper); @@ -612,12 +619,11 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { const int64_t& sample_fileid = sample_fileids.at(sample_idx); if (sample_fileid != current_file_id) { - // 1. Prepare response const std::vector file_indexes( sample_indices.begin() + static_cast(current_file_start_idx), sample_indices.begin() + static_cast(sample_idx)); - std::vector> data = file_wrapper->get_samples_from_indices(file_indexes); - + std::vector> data = + file_wrapper->get_samples_from_indices(file_indexes, include_labels); // Protobuf expects the data as std::string... 
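      // When include_labels is false (datasets without labels, see has_labels in the file
      // wrapper config), the labels field of the response is cleared below instead of being
      // filled from sample_labels.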
std::vector stringified_data; stringified_data.reserve(data.size()); @@ -631,16 +637,17 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { response.mutable_samples()->Assign(stringified_data.begin(), stringified_data.end()); response.mutable_keys()->Assign(sample_keys.begin() + static_cast(current_file_start_idx), sample_keys.begin() + static_cast(sample_idx)); - response.mutable_labels()->Assign(sample_labels.begin() + static_cast(current_file_start_idx), - sample_labels.begin() + static_cast(sample_idx)); - // 2. Send response + if (!include_labels) { + response.mutable_labels()->Clear(); // Return empty labels if include_labels is true + } else { + response.mutable_labels()->Assign(sample_labels.begin() + static_cast(current_file_start_idx), + sample_labels.begin() + static_cast(sample_idx)); + } { const std::lock_guard lock(writer_mutex); writer->Write(response); } - - // 3. Update state current_file_id = sample_fileid; current_file_path = "", session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id", @@ -664,8 +671,10 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { // Send leftovers const std::vector file_indexes(sample_indices.begin() + static_cast(current_file_start_idx), sample_indices.end()); - const std::vector> data = file_wrapper->get_samples_from_indices(file_indexes); - // Protobuf expects the data as std::string... + + const std::vector> data = + file_wrapper->get_samples_from_indices(file_indexes, include_labels); + std::vector stringified_data; stringified_data.reserve(data.size()); for (const std::vector& char_vec : data) { @@ -676,9 +685,13 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { response.mutable_samples()->Assign(stringified_data.begin(), stringified_data.end()); response.mutable_keys()->Assign(sample_keys.begin() + static_cast(current_file_start_idx), sample_keys.end()); - response.mutable_labels()->Assign(sample_labels.begin() + static_cast(current_file_start_idx), - sample_labels.end()); + if (!include_labels) { + response.mutable_labels()->Clear(); // Return empty labels if include_labels is true + } else { + response.mutable_labels()->Assign(sample_labels.begin() + static_cast(current_file_start_idx), + sample_labels.end()); + } { const std::lock_guard lock(writer_mutex); writer->Write(response); @@ -694,15 +707,18 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { static void get_samples_and_send(const std::vector::const_iterator begin, const std::vector::const_iterator end, WriterT* writer, std::mutex* writer_mutex, const DatasetData* dataset_data, const YAML::Node* config, - int64_t sample_batch_size) { + int64_t sample_batch_size, bool include_labels = true) { if (begin >= end) { return; } const StorageDatabaseConnection storage_database_connection(*config); soci::session session = storage_database_connection.get_session(); const std::vector sample_keys(begin, end); + + // Call the appropriate function based on the include_labels flag send_sample_data_for_keys_and_file(writer, *writer_mutex, sample_keys, *dataset_data, session, - sample_batch_size); + sample_batch_size, include_labels); + session.close(); } diff --git a/modyn/storage/internal/grpc/generated/storage_pb2.py b/modyn/storage/internal/grpc/generated/storage_pb2.py index 8d59a8b1f..d6f610613 100644 --- a/modyn/storage/internal/grpc/generated/storage_pb2.py +++ b/modyn/storage/internal/grpc/generated/storage_pb2.py @@ -1,23 +1,21 @@ # -*- coding: 
utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE # source: storage.proto -# Protobuf Python Version: 5.26.1 +# Protobuf Python Version: 5.27.2 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder - -# @@protoc_insertion_point(imports) +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 _sym_db = _symbol_database.Default() -from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 - DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\rstorage.proto\x12\rmodyn.storage".\n\nGetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03"<\n\x0bGetResponse\x12\x0f\n\x07samples\x18\x01 \x03(\x0c\x12\x0c\n\x04keys\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"\x1c\n\x1aGetCurrentTimestampRequest"?\n\x16GetNewDataSinceRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\ttimestamp\x18\x02 \x01(\x03"K\n\x17GetNewDataSinceResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"^\n\x18GetDataInIntervalRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fstart_timestamp\x18\x02 \x01(\x03\x12\x15\n\rend_timestamp\x18\x03 \x01(\x03"M\n\x19GetDataInIntervalResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"\xb7\x01\n\x17GetDataPerWorkerRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\x05\x12\x15\n\rtotal_workers\x18\x03 \x01(\x05\x12\x1c\n\x0fstart_timestamp\x18\x04 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x05 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp"(\n\x18GetDataPerWorkerResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03"\x8b\x01\n\x15GetDatasetSizeRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1c\n\x0fstart_timestamp\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x03 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp";\n\x16GetDatasetSizeResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x10\n\x08num_keys\x18\x02 \x01(\x03"-\n\x17\x44\x61tasetAvailableRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t"-\n\x18\x44\x61tasetAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08"\xff\x01\n\x19RegisterNewDatasetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1f\n\x17\x66ilesystem_wrapper_type\x18\x02 \x01(\t\x12\x19\n\x11\x66ile_wrapper_type\x18\x03 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\x12\x11\n\tbase_path\x18\x05 \x01(\t\x12\x0f\n\x07version\x18\x06 \x01(\t\x12\x1b\n\x13\x66ile_wrapper_config\x18\x07 \x01(\t\x12\x1d\n\x15ignore_last_timestamp\x18\x08 \x01(\x08\x12\x1d\n\x15\x66ile_watcher_interval\x18\t \x01(\x03"-\n\x1aRegisterNewDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08"0\n\x1bGetCurrentTimestampResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\x03"(\n\x15\x44\x65leteDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08"5\n\x11\x44\x65leteDataRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03"%\n\x12\x44\x65leteDataResponse\x12\x0f\n\x07success\x18\x01 
\x01(\x08\x32\xe2\x07\n\x07Storage\x12@\n\x03Get\x12\x19.modyn.storage.GetRequest\x1a\x1a.modyn.storage.GetResponse"\x00\x30\x01\x12\x64\n\x0fGetNewDataSince\x12%.modyn.storage.GetNewDataSinceRequest\x1a&.modyn.storage.GetNewDataSinceResponse"\x00\x30\x01\x12j\n\x11GetDataInInterval\x12\'.modyn.storage.GetDataInIntervalRequest\x1a(.modyn.storage.GetDataInIntervalResponse"\x00\x30\x01\x12g\n\x10GetDataPerWorker\x12&.modyn.storage.GetDataPerWorkerRequest\x1a\'.modyn.storage.GetDataPerWorkerResponse"\x00\x30\x01\x12_\n\x0eGetDatasetSize\x12$.modyn.storage.GetDatasetSizeRequest\x1a%.modyn.storage.GetDatasetSizeResponse"\x00\x12\x66\n\x11\x43heckAvailability\x12&.modyn.storage.DatasetAvailableRequest\x1a\'.modyn.storage.DatasetAvailableResponse"\x00\x12k\n\x12RegisterNewDataset\x12(.modyn.storage.RegisterNewDatasetRequest\x1a).modyn.storage.RegisterNewDatasetResponse"\x00\x12n\n\x13GetCurrentTimestamp\x12).modyn.storage.GetCurrentTimestampRequest\x1a*.modyn.storage.GetCurrentTimestampResponse"\x00\x12_\n\rDeleteDataset\x12&.modyn.storage.DatasetAvailableRequest\x1a$.modyn.storage.DeleteDatasetResponse"\x00\x12S\n\nDeleteData\x12 .modyn.storage.DeleteDataRequest\x1a!.modyn.storage.DeleteDataResponse"\x00\x62\x06proto3' + b'\n\rstorage.proto\x12\rmodyn.storage".\n\nGetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03"<\n\x0bGetResponse\x12\x0f\n\x07samples\x18\x01 \x03(\x0c\x12\x0c\n\x04keys\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"\x1c\n\x1aGetCurrentTimestampRequest"?\n\x16GetNewDataSinceRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\ttimestamp\x18\x02 \x01(\x03"K\n\x17GetNewDataSinceResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"^\n\x18GetDataInIntervalRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fstart_timestamp\x18\x02 \x01(\x03\x12\x15\n\rend_timestamp\x18\x03 \x01(\x03"M\n\x19GetDataInIntervalResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03\x12\x12\n\ntimestamps\x18\x02 \x03(\x03\x12\x0e\n\x06labels\x18\x03 \x03(\x03"\xb7\x01\n\x17GetDataPerWorkerRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\x05\x12\x15\n\rtotal_workers\x18\x03 \x01(\x05\x12\x1c\n\x0fstart_timestamp\x18\x04 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x05 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp"(\n\x18GetDataPerWorkerResponse\x12\x0c\n\x04keys\x18\x01 \x03(\x03"\x8b\x01\n\x15GetDatasetSizeRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1c\n\x0fstart_timestamp\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x1a\n\rend_timestamp\x18\x03 \x01(\x03H\x01\x88\x01\x01\x42\x12\n\x10_start_timestampB\x10\n\x0e_end_timestamp";\n\x16GetDatasetSizeResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x10\n\x08num_keys\x18\x02 \x01(\x03"-\n\x17\x44\x61tasetAvailableRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t"-\n\x18\x44\x61tasetAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08"\xff\x01\n\x19RegisterNewDatasetRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x1f\n\x17\x66ilesystem_wrapper_type\x18\x02 \x01(\t\x12\x19\n\x11\x66ile_wrapper_type\x18\x03 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\x12\x11\n\tbase_path\x18\x05 \x01(\t\x12\x0f\n\x07version\x18\x06 \x01(\t\x12\x1b\n\x13\x66ile_wrapper_config\x18\x07 \x01(\t\x12\x1d\n\x15ignore_last_timestamp\x18\x08 \x01(\x08\x12\x1d\n\x15\x66ile_watcher_interval\x18\t \x01(\x03"-\n\x1aRegisterNewDatasetResponse\x12\x0f\n\x07success\x18\x01 
\x01(\x08"0\n\x1bGetCurrentTimestampResponse\x12\x11\n\ttimestamp\x18\x01 \x01(\x03"(\n\x15\x44\x65leteDatasetResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08"5\n\x11\x44\x65leteDataRequest\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x0c\n\x04keys\x18\x02 \x03(\x03"%\n\x12\x44\x65leteDataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x32\xa6\x08\n\x07Storage\x12@\n\x03Get\x12\x19.modyn.storage.GetRequest\x1a\x1a.modyn.storage.GetResponse"\x00\x30\x01\x12\x42\n\x05GetNL\x12\x19.modyn.storage.GetRequest\x1a\x1a.modyn.storage.GetResponse"\x00\x30\x01\x12\x64\n\x0fGetNewDataSince\x12%.modyn.storage.GetNewDataSinceRequest\x1a&.modyn.storage.GetNewDataSinceResponse"\x00\x30\x01\x12j\n\x11GetDataInInterval\x12\'.modyn.storage.GetDataInIntervalRequest\x1a(.modyn.storage.GetDataInIntervalResponse"\x00\x30\x01\x12g\n\x10GetDataPerWorker\x12&.modyn.storage.GetDataPerWorkerRequest\x1a\'.modyn.storage.GetDataPerWorkerResponse"\x00\x30\x01\x12_\n\x0eGetDatasetSize\x12$.modyn.storage.GetDatasetSizeRequest\x1a%.modyn.storage.GetDatasetSizeResponse"\x00\x12\x66\n\x11\x43heckAvailability\x12&.modyn.storage.DatasetAvailableRequest\x1a\'.modyn.storage.DatasetAvailableResponse"\x00\x12k\n\x12RegisterNewDataset\x12(.modyn.storage.RegisterNewDatasetRequest\x1a).modyn.storage.RegisterNewDatasetResponse"\x00\x12n\n\x13GetCurrentTimestamp\x12).modyn.storage.GetCurrentTimestampRequest\x1a*.modyn.storage.GetCurrentTimestampResponse"\x00\x12_\n\rDeleteDataset\x12&.modyn.storage.DatasetAvailableRequest\x1a$.modyn.storage.DeleteDatasetResponse"\x00\x12S\n\nDeleteData\x12 .modyn.storage.DeleteDataRequest\x1a!.modyn.storage.DeleteDataResponse"\x00\x62\x06proto3' ) _globals = globals() @@ -64,5 +62,5 @@ _globals["_DELETEDATARESPONSE"]._serialized_start = 1466 _globals["_DELETEDATARESPONSE"]._serialized_end = 1503 _globals["_STORAGE"]._serialized_start = 1506 - _globals["_STORAGE"]._serialized_end = 2500 + _globals["_STORAGE"]._serialized_end = 2568 # @@protoc_insertion_point(module_scope) diff --git a/modyn/storage/internal/grpc/generated/storage_pb2.pyi b/modyn/storage/internal/grpc/generated/storage_pb2.pyi index cf286ada1..379c0e3a5 100644 --- a/modyn/storage/internal/grpc/generated/storage_pb2.pyi +++ b/modyn/storage/internal/grpc/generated/storage_pb2.pyi @@ -8,11 +8,17 @@ import collections.abc import google.protobuf.descriptor import google.protobuf.internal.containers import google.protobuf.message +import sys import typing +if sys.version_info >= (3, 8): + import typing as typing_extensions +else: + import typing_extensions + DESCRIPTOR: google.protobuf.descriptor.FileDescriptor -@typing.final +@typing_extensions.final class GetRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -27,11 +33,13 @@ class GetRequest(google.protobuf.message.Message): dataset_id: builtins.str = ..., keys: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "keys", b"keys"]) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["dataset_id", b"dataset_id", "keys", b"keys"] + ) -> None: ... global___GetRequest = GetRequest -@typing.final +@typing_extensions.final class GetResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -52,12 +60,12 @@ class GetResponse(google.protobuf.message.Message): labels: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... 
def ClearField( - self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "samples", b"samples"] + self, field_name: typing_extensions.Literal["keys", b"keys", "labels", b"labels", "samples", b"samples"] ) -> None: ... global___GetResponse = GetResponse -@typing.final +@typing_extensions.final class GetCurrentTimestampRequest(google.protobuf.message.Message): """https://github.com/grpc/grpc/issues/15937""" @@ -69,7 +77,7 @@ class GetCurrentTimestampRequest(google.protobuf.message.Message): global___GetCurrentTimestampRequest = GetCurrentTimestampRequest -@typing.final +@typing_extensions.final class GetNewDataSinceRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -84,12 +92,12 @@ class GetNewDataSinceRequest(google.protobuf.message.Message): timestamp: builtins.int = ..., ) -> None: ... def ClearField( - self, field_name: typing.Literal["dataset_id", b"dataset_id", "timestamp", b"timestamp"] + self, field_name: typing_extensions.Literal["dataset_id", b"dataset_id", "timestamp", b"timestamp"] ) -> None: ... global___GetNewDataSinceRequest = GetNewDataSinceRequest -@typing.final +@typing_extensions.final class GetNewDataSinceResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -110,12 +118,12 @@ class GetNewDataSinceResponse(google.protobuf.message.Message): labels: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... def ClearField( - self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"] + self, field_name: typing_extensions.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"] ) -> None: ... global___GetNewDataSinceResponse = GetNewDataSinceResponse -@typing.final +@typing_extensions.final class GetDataInIntervalRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -134,14 +142,14 @@ class GetDataInIntervalRequest(google.protobuf.message.Message): ) -> None: ... def ClearField( self, - field_name: typing.Literal[ + field_name: typing_extensions.Literal[ "dataset_id", b"dataset_id", "end_timestamp", b"end_timestamp", "start_timestamp", b"start_timestamp" ], ) -> None: ... global___GetDataInIntervalRequest = GetDataInIntervalRequest -@typing.final +@typing_extensions.final class GetDataInIntervalResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -162,12 +170,12 @@ class GetDataInIntervalResponse(google.protobuf.message.Message): labels: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... def ClearField( - self, field_name: typing.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"] + self, field_name: typing_extensions.Literal["keys", b"keys", "labels", b"labels", "timestamps", b"timestamps"] ) -> None: ... global___GetDataInIntervalResponse = GetDataInIntervalResponse -@typing.final +@typing_extensions.final class GetDataPerWorkerRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -195,7 +203,7 @@ class GetDataPerWorkerRequest(google.protobuf.message.Message): ) -> None: ... def HasField( self, - field_name: typing.Literal[ + field_name: typing_extensions.Literal[ "_end_timestamp", b"_end_timestamp", "_start_timestamp", @@ -208,7 +216,7 @@ class GetDataPerWorkerRequest(google.protobuf.message.Message): ) -> builtins.bool: ... 
def ClearField( self, - field_name: typing.Literal[ + field_name: typing_extensions.Literal[ "_end_timestamp", b"_end_timestamp", "_start_timestamp", @@ -227,16 +235,16 @@ class GetDataPerWorkerRequest(google.protobuf.message.Message): ) -> None: ... @typing.overload def WhichOneof( - self, oneof_group: typing.Literal["_end_timestamp", b"_end_timestamp"] - ) -> typing.Literal["end_timestamp"] | None: ... + self, oneof_group: typing_extensions.Literal["_end_timestamp", b"_end_timestamp"] + ) -> typing_extensions.Literal["end_timestamp"] | None: ... @typing.overload def WhichOneof( - self, oneof_group: typing.Literal["_start_timestamp", b"_start_timestamp"] - ) -> typing.Literal["start_timestamp"] | None: ... + self, oneof_group: typing_extensions.Literal["_start_timestamp", b"_start_timestamp"] + ) -> typing_extensions.Literal["start_timestamp"] | None: ... global___GetDataPerWorkerRequest = GetDataPerWorkerRequest -@typing.final +@typing_extensions.final class GetDataPerWorkerResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -248,11 +256,11 @@ class GetDataPerWorkerResponse(google.protobuf.message.Message): *, keys: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["keys", b"keys"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["keys", b"keys"]) -> None: ... global___GetDataPerWorkerResponse = GetDataPerWorkerResponse -@typing.final +@typing_extensions.final class GetDatasetSizeRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -274,7 +282,7 @@ class GetDatasetSizeRequest(google.protobuf.message.Message): ) -> None: ... def HasField( self, - field_name: typing.Literal[ + field_name: typing_extensions.Literal[ "_end_timestamp", b"_end_timestamp", "_start_timestamp", @@ -287,7 +295,7 @@ class GetDatasetSizeRequest(google.protobuf.message.Message): ) -> builtins.bool: ... def ClearField( self, - field_name: typing.Literal[ + field_name: typing_extensions.Literal[ "_end_timestamp", b"_end_timestamp", "_start_timestamp", @@ -302,16 +310,16 @@ class GetDatasetSizeRequest(google.protobuf.message.Message): ) -> None: ... @typing.overload def WhichOneof( - self, oneof_group: typing.Literal["_end_timestamp", b"_end_timestamp"] - ) -> typing.Literal["end_timestamp"] | None: ... + self, oneof_group: typing_extensions.Literal["_end_timestamp", b"_end_timestamp"] + ) -> typing_extensions.Literal["end_timestamp"] | None: ... @typing.overload def WhichOneof( - self, oneof_group: typing.Literal["_start_timestamp", b"_start_timestamp"] - ) -> typing.Literal["start_timestamp"] | None: ... + self, oneof_group: typing_extensions.Literal["_start_timestamp", b"_start_timestamp"] + ) -> typing_extensions.Literal["start_timestamp"] | None: ... global___GetDatasetSizeRequest = GetDatasetSizeRequest -@typing.final +@typing_extensions.final class GetDatasetSizeResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -325,11 +333,13 @@ class GetDatasetSizeResponse(google.protobuf.message.Message): success: builtins.bool = ..., num_keys: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["num_keys", b"num_keys", "success", b"success"]) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["num_keys", b"num_keys", "success", b"success"] + ) -> None: ... 
global___GetDatasetSizeResponse = GetDatasetSizeResponse -@typing.final +@typing_extensions.final class DatasetAvailableRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -340,11 +350,11 @@ class DatasetAvailableRequest(google.protobuf.message.Message): *, dataset_id: builtins.str = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["dataset_id", b"dataset_id"]) -> None: ... global___DatasetAvailableRequest = DatasetAvailableRequest -@typing.final +@typing_extensions.final class DatasetAvailableResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -355,11 +365,11 @@ class DatasetAvailableResponse(google.protobuf.message.Message): *, available: builtins.bool = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["available", b"available"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["available", b"available"]) -> None: ... global___DatasetAvailableResponse = DatasetAvailableResponse -@typing.final +@typing_extensions.final class RegisterNewDatasetRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -396,7 +406,7 @@ class RegisterNewDatasetRequest(google.protobuf.message.Message): ) -> None: ... def ClearField( self, - field_name: typing.Literal[ + field_name: typing_extensions.Literal[ "base_path", b"base_path", "dataset_id", @@ -420,7 +430,7 @@ class RegisterNewDatasetRequest(google.protobuf.message.Message): global___RegisterNewDatasetRequest = RegisterNewDatasetRequest -@typing.final +@typing_extensions.final class RegisterNewDatasetResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -431,11 +441,11 @@ class RegisterNewDatasetResponse(google.protobuf.message.Message): *, success: builtins.bool = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["success", b"success"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["success", b"success"]) -> None: ... global___RegisterNewDatasetResponse = RegisterNewDatasetResponse -@typing.final +@typing_extensions.final class GetCurrentTimestampResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -446,11 +456,11 @@ class GetCurrentTimestampResponse(google.protobuf.message.Message): *, timestamp: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["timestamp", b"timestamp"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["timestamp", b"timestamp"]) -> None: ... global___GetCurrentTimestampResponse = GetCurrentTimestampResponse -@typing.final +@typing_extensions.final class DeleteDatasetResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -461,11 +471,11 @@ class DeleteDatasetResponse(google.protobuf.message.Message): *, success: builtins.bool = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["success", b"success"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["success", b"success"]) -> None: ... 
global___DeleteDatasetResponse = DeleteDatasetResponse -@typing.final +@typing_extensions.final class DeleteDataRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -480,11 +490,13 @@ class DeleteDataRequest(google.protobuf.message.Message): dataset_id: builtins.str = ..., keys: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["dataset_id", b"dataset_id", "keys", b"keys"]) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["dataset_id", b"dataset_id", "keys", b"keys"] + ) -> None: ... global___DeleteDataRequest = DeleteDataRequest -@typing.final +@typing_extensions.final class DeleteDataResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -495,6 +507,6 @@ class DeleteDataResponse(google.protobuf.message.Message): *, success: builtins.bool = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["success", b"success"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["success", b"success"]) -> None: ... global___DeleteDataResponse = DeleteDataResponse diff --git a/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py b/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py index cae27014a..51114e565 100644 --- a/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py +++ b/modyn/storage/internal/grpc/generated/storage_pb2_grpc.py @@ -1,15 +1,13 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" +import grpc import warnings -import grpc import modyn.storage.internal.grpc.generated.storage_pb2 as storage__pb2 -GRPC_GENERATED_VERSION = "1.63.0" +GRPC_GENERATED_VERSION = "1.67.1" GRPC_VERSION = grpc.__version__ -EXPECTED_ERROR_RELEASE = "1.65.0" -SCHEDULED_RELEASE_DATE = "June 25, 2024" _version_not_supported = False try: @@ -20,15 +18,12 @@ _version_not_supported = True if _version_not_supported: - warnings.warn( + raise RuntimeError( f"The grpc package installed is at version {GRPC_VERSION}," + f" but the generated code in storage_pb2_grpc.py depends on" + f" grpcio>={GRPC_GENERATED_VERSION}." + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." 
- + f" This warning will become an error in {EXPECTED_ERROR_RELEASE}," - + f" scheduled for release on {SCHEDULED_RELEASE_DATE}.", - RuntimeWarning, ) @@ -47,6 +42,12 @@ def __init__(self, channel): response_deserializer=storage__pb2.GetResponse.FromString, _registered_method=True, ) + self.GetNL = channel.unary_stream( + "/modyn.storage.Storage/GetNL", + request_serializer=storage__pb2.GetRequest.SerializeToString, + response_deserializer=storage__pb2.GetResponse.FromString, + _registered_method=True, + ) self.GetNewDataSince = channel.unary_stream( "/modyn.storage.Storage/GetNewDataSince", request_serializer=storage__pb2.GetNewDataSinceRequest.SerializeToString, @@ -112,6 +113,12 @@ def Get(self, request, context): context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def GetNL(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def GetNewDataSince(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -174,6 +181,11 @@ def add_StorageServicer_to_server(servicer, server): request_deserializer=storage__pb2.GetRequest.FromString, response_serializer=storage__pb2.GetResponse.SerializeToString, ), + "GetNL": grpc.unary_stream_rpc_method_handler( + servicer.GetNL, + request_deserializer=storage__pb2.GetRequest.FromString, + response_serializer=storage__pb2.GetResponse.SerializeToString, + ), "GetNewDataSince": grpc.unary_stream_rpc_method_handler( servicer.GetNewDataSince, request_deserializer=storage__pb2.GetNewDataSinceRequest.FromString, @@ -222,6 +234,7 @@ def add_StorageServicer_to_server(servicer, server): } generic_handler = grpc.method_handlers_generic_handler("modyn.storage.Storage", rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers("modyn.storage.Storage", rpc_method_handlers) # This class is part of an EXPERIMENTAL API. 
@@ -258,6 +271,36 @@ def Get( _registered_method=True, ) + @staticmethod + def GetNL( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_stream( + request, + target, + "/modyn.storage.Storage/GetNL", + storage__pb2.GetRequest.SerializeToString, + storage__pb2.GetResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True, + ) + @staticmethod def GetNewDataSince( request, diff --git a/modyn/storage/src/internal/file_watcher/file_watcher.cpp b/modyn/storage/src/internal/file_watcher/file_watcher.cpp index e76ba4824..1f4e414c0 100644 --- a/modyn/storage/src/internal/file_watcher/file_watcher.cpp +++ b/modyn/storage/src/internal/file_watcher/file_watcher.cpp @@ -277,6 +277,7 @@ void FileWatcher::handle_files_for_insertion(std::vector& files_for const std::shared_ptr& filesystem_wrapper) { const std::string file_path = files_for_insertion.front(); std::vector file_samples; + auto file_wrapper = get_file_wrapper(file_path, file_wrapper_type, file_wrapper_config, filesystem_wrapper); int64_t current_file_samples_to_be_inserted = 0; @@ -299,6 +300,7 @@ void FileWatcher::handle_files_for_insertion(std::vector& files_for file_samples.clear(); current_file_samples_to_be_inserted = 0; } + file_samples.push_back({file_id, index, label}); index++; current_file_samples_to_be_inserted++; diff --git a/modyn/storage/src/internal/file_wrapper/binary_file_wrapper.cpp b/modyn/storage/src/internal/file_wrapper/binary_file_wrapper.cpp index 20b04bb32..69a1de037 100644 --- a/modyn/storage/src/internal/file_wrapper/binary_file_wrapper.cpp +++ b/modyn/storage/src/internal/file_wrapper/binary_file_wrapper.cpp @@ -32,7 +32,9 @@ void BinaryFileWrapper::validate_file_extension() { */ int64_t BinaryFileWrapper::get_label(uint64_t index) { ASSERT(index < get_number_of_samples(), "Invalid index"); - + if (!has_labels_) { + return -1; // Return -1 to indicate no label + } const uint64_t label_start = index * record_size_; get_stream()->seekg(static_cast(label_start), std::ios::beg); @@ -76,7 +78,8 @@ std::vector BinaryFileWrapper::get_all_labels() { /* * Offset calculation to retrieve the data of a sample interval. */ -std::vector> BinaryFileWrapper::get_samples(uint64_t start, uint64_t end) { +std::vector> BinaryFileWrapper::get_samples(uint64_t start, uint64_t end, + const bool include_labels) { ASSERT(end >= start && end <= get_number_of_samples(), "Invalid indices"); const uint64_t num_samples = end - start + 1; @@ -85,7 +88,7 @@ std::vector> BinaryFileWrapper::get_samples(uint64_t uint64_t record_start; for (uint64_t index = 0; index < num_samples; ++index) { record_start = (start + index) * record_size_; - get_stream()->seekg(static_cast(record_start + label_size_), std::ios::beg); + get_stream()->seekg(static_cast(record_start + ((!include_labels) ? 0 : label_size_)), std::ios::beg); std::vector sample_vec(sample_size_); get_stream()->read(reinterpret_cast(sample_vec.data()), static_cast(sample_size_)); @@ -116,7 +119,7 @@ std::vector BinaryFileWrapper::get_sample(uint64_t index) { * Offset calculation to retrieve the data of a sample interval. 
*/ std::vector> BinaryFileWrapper::get_samples_from_indices( - const std::vector& indices) { + const std::vector& indices, const bool include_labels) { ASSERT(std::all_of(indices.begin(), indices.end(), [&](uint64_t index) { return index < get_number_of_samples(); }), "Invalid indices"); @@ -127,7 +130,8 @@ std::vector> BinaryFileWrapper::get_samples_from_indi for (const uint64_t index : indices) { record_start = index * record_size_; - get_stream()->seekg(static_cast(record_start + label_size_), std::ios::beg); + // Adjust stream position based on the include_labels flag + get_stream()->seekg(static_cast(record_start + ((!include_labels) ? 0 : label_size_)), std::ios::beg); std::vector sample_vec(sample_size_); get_stream()->read(reinterpret_cast(sample_vec.data()), static_cast(sample_size_)); diff --git a/modyn/storage/src/internal/file_wrapper/csv_file_wrapper.cpp b/modyn/storage/src/internal/file_wrapper/csv_file_wrapper.cpp index 1958076ee..f84374131 100644 --- a/modyn/storage/src/internal/file_wrapper/csv_file_wrapper.cpp +++ b/modyn/storage/src/internal/file_wrapper/csv_file_wrapper.cpp @@ -3,8 +3,11 @@ #include #include +#include // Include for ASSERT or use your custom assert macro +#include // Include required for std::cout #include #include +#include using namespace modyn::storage; @@ -35,7 +38,8 @@ std::vector CsvFileWrapper::get_sample(uint64_t index) { return {row_string.begin(), row_string.end()}; } -std::vector> CsvFileWrapper::get_samples(uint64_t start, uint64_t end) { +std::vector> CsvFileWrapper::get_samples(uint64_t start, uint64_t end, + const bool include_labels) { ASSERT(end >= start && end <= get_number_of_samples(), "Invalid indices"); std::vector> samples; @@ -43,7 +47,9 @@ std::vector> CsvFileWrapper::get_samples(uint64_t sta const uint64_t end_t = end; for (uint64_t i = start_t; i < end_t; ++i) { std::vector row = doc_.GetRow(static_cast(i)); - row.erase(row.begin() + static_cast(label_index_)); + if (include_labels) { + row.erase(row.begin() + static_cast(label_index_)); + } std::string row_string; for (const auto& cell : row) { row_string += cell + separator_; @@ -55,19 +61,23 @@ std::vector> CsvFileWrapper::get_samples(uint64_t sta return samples; } -std::vector> CsvFileWrapper::get_samples_from_indices(const std::vector& indices) { - ASSERT(std::all_of(indices.begin(), indices.end(), [&](uint64_t index) { return index < get_number_of_samples(); }), - "Invalid indices"); - +std::vector> CsvFileWrapper::get_samples_from_indices(const std::vector& indices, + const bool include_labels) { std::vector> samples; for (const uint64_t index : indices) { std::vector row = doc_.GetRow(index); - row.erase(row.begin() + static_cast(label_index_)); + + // Erase label based on the include_labels flag + if (include_labels) { + row.erase(row.begin() + static_cast(label_index_)); + } + std::string row_string; for (const auto& cell : row) { row_string += cell + separator_; } row_string.pop_back(); + samples.emplace_back(row_string.begin(), row_string.end()); } return samples; @@ -75,15 +85,20 @@ std::vector> CsvFileWrapper::get_samples_from_indices int64_t CsvFileWrapper::get_label(uint64_t index) { ASSERT(index < get_number_of_samples(), "Invalid index"); + if (!has_labels_) { + return -1; // Return -1 to indicate no label + } return doc_.GetCell(static_cast(label_index_), static_cast(index)); } std::vector CsvFileWrapper::get_all_labels() { std::vector labels; const uint64_t num_samples = get_number_of_samples(); + for (uint64_t i = 0; i < num_samples; i++) { 
labels.push_back(get_label(i)); } + return labels; } diff --git a/modyn/storage/src/internal/file_wrapper/single_sample_file_wrapper.cpp b/modyn/storage/src/internal/file_wrapper/single_sample_file_wrapper.cpp index c6e42f5fd..bf4e13a70 100644 --- a/modyn/storage/src/internal/file_wrapper/single_sample_file_wrapper.cpp +++ b/modyn/storage/src/internal/file_wrapper/single_sample_file_wrapper.cpp @@ -45,7 +45,8 @@ std::vector SingleSampleFileWrapper::get_sample(uint64_t index) { return filesystem_wrapper_->get(file_path_); } -std::vector> SingleSampleFileWrapper::get_samples(uint64_t start, uint64_t end) { +std::vector> SingleSampleFileWrapper::get_samples(uint64_t start, uint64_t end, + const bool /*include_labels*/) { ASSERT( start == 0 && end == 1, fmt::format("Single sample file wrappers can only access the first sample. file_path = {}, start = {}, end = {}", @@ -54,7 +55,7 @@ std::vector> SingleSampleFileWrapper::get_samples(uin } std::vector> SingleSampleFileWrapper::get_samples_from_indices( - const std::vector& indices) { + const std::vector& indices, const bool /*include_labels*/) { ASSERT(indices.size() == 1 && indices[0] == 0, fmt::format("Single sample file wrappers can only access the first sample. file_path = {}, indices.size() = " "{}, indices = [{}]", diff --git a/modyn/storage/src/internal/filesystem_wrapper/local_filesystem_wrapper.cpp b/modyn/storage/src/internal/filesystem_wrapper/local_filesystem_wrapper.cpp index 50a518880..1ddb8c635 100644 --- a/modyn/storage/src/internal/filesystem_wrapper/local_filesystem_wrapper.cpp +++ b/modyn/storage/src/internal/filesystem_wrapper/local_filesystem_wrapper.cpp @@ -72,7 +72,9 @@ int64_t LocalFilesystemWrapper::get_modified_time(const std::string& path) { static_assert(sizeof(int64_t) >= sizeof(std::time_t), "Cannot cast time_t to int64_t"); // there is no portable way to get the modified time of a file in C++17 and earlier - struct stat file_attribute {}; + // clang-format off + struct stat file_attribute {};//This line keeps getting changed by clang format in my machine and not passing the format test when pushed. + // clang-format on stat(path.c_str(), &file_attribute); return static_cast(file_attribute.st_mtime); /* C++20 version, not supported by compilers yet */ diff --git a/modyn/storage/src/internal/grpc/storage_service_impl.cpp b/modyn/storage/src/internal/grpc/storage_service_impl.cpp index 4640f2593..95a624c9f 100644 --- a/modyn/storage/src/internal/grpc/storage_service_impl.cpp +++ b/modyn/storage/src/internal/grpc/storage_service_impl.cpp @@ -498,13 +498,24 @@ DatasetData StorageServiceImpl::get_dataset_data(soci::session& session, std::st auto filesystem_wrapper_type = static_cast(FilesystemWrapperType::INVALID_FSW); auto file_wrapper_type = static_cast(FileWrapperType::INVALID_FW); std::string file_wrapper_config; + int has_labels_int = 1; session << "SELECT dataset_id, base_path, filesystem_wrapper_type, file_wrapper_type, file_wrapper_config FROM " - "datasets WHERE " - "name = :name", + "datasets WHERE name = :name", soci::into(dataset_id), soci::into(base_path), soci::into(filesystem_wrapper_type), soci::into(file_wrapper_type), soci::into(file_wrapper_config), soci::use(dataset_name); - return {dataset_id, base_path, static_cast(filesystem_wrapper_type), - static_cast(file_wrapper_type), file_wrapper_config}; + YAML::Node config = YAML::Load(file_wrapper_config); + if (config["has_labels"]) { + has_labels_int = config["has_labels"].as() ? 
1 : 0; + } + + // Convert has_labels_int to bool + const bool has_labels = (has_labels_int != 0); + return DatasetData{.dataset_id = dataset_id, + .base_path = base_path, + .filesystem_wrapper_type = static_cast(filesystem_wrapper_type), + .file_wrapper_type = static_cast(file_wrapper_type), + .file_wrapper_config = file_wrapper_config, + .has_labels = has_labels}; } diff --git a/modyn/supervisor/internal/grpc/supervisor_grpc_servicer.py b/modyn/supervisor/internal/grpc/supervisor_grpc_servicer.py index 3f38d0a7e..e3790451c 100644 --- a/modyn/supervisor/internal/grpc/supervisor_grpc_servicer.py +++ b/modyn/supervisor/internal/grpc/supervisor_grpc_servicer.py @@ -30,7 +30,6 @@ def __init__(self, supervisor: Supervisor, modyn_config: dict) -> None: def start_pipeline(self, request: StartPipelineRequest, context: grpc.ServicerContext) -> PipelineResponse: tid = threading.get_native_id() pid = os.getpid() - logger.info(f"[{pid}][{tid}]: Starting pipeline with request - {request}") start_replay_at: int | None = None @@ -44,7 +43,6 @@ def start_pipeline(self, request: StartPipelineRequest, context: grpc.ServicerCo maximum_triggers = request.maximum_triggers pipeline_config = json.loads(request.pipeline_config.value) - msg = self._supervisor.start_pipeline( pipeline_config, self.modyn_config["supervisor"]["eval_directory"], diff --git a/modyn/supervisor/internal/grpc_handler.py b/modyn/supervisor/internal/grpc_handler.py index b2d98c1af..7a192769e 100644 --- a/modyn/supervisor/internal/grpc_handler.py +++ b/modyn/supervisor/internal/grpc_handler.py @@ -247,7 +247,6 @@ def prepare_evaluation_request( dataset_id = dataset_config["dataset_id"] transform_list = dataset_config.get("transformations") or [] label_transformer = dataset_config.get("label_transformer_function") or "" - bytes_parser_function = dataset_config["bytes_parser_function"] batch_size = dataset_config["batch_size"] dataloader_workers = dataset_config["dataloader_workers"] @@ -262,6 +261,11 @@ def prepare_evaluation_request( tokenizer_arg = EvaluatorPythonString(value=tokenizer) else: tokenizer_arg = None + light_tuning = False + tuning_config = None + if dataset_config["light_tuning"]: + light_tuning = dataset_config["light_tuning"] + tuning_config = dataset_config["tuning_config"] return EvaluateModelRequest( model_id=model_id, @@ -277,6 +281,8 @@ def prepare_evaluation_request( bytes_parser=EvaluatorPythonString(value=bytes_parser_function), label_transformer=EvaluatorPythonString(value=label_transformer), tokenizer=tokenizer_arg, + light_tuning=light_tuning, + tuning_config=tuning_config, ) # pylint: disable=too-many-branches diff --git a/modyn/tests/evaluator/internal/pytorch_lighttuner.py b/modyn/tests/evaluator/internal/pytorch_lighttuner.py new file mode 100644 index 000000000..202651a55 --- /dev/null +++ b/modyn/tests/evaluator/internal/pytorch_lighttuner.py @@ -0,0 +1,368 @@ +# pylint: disable=no-name-in-module +# pylint: disable=unused-argument, no-name-in-module, no-value-for-parameter +# ruff: noqa: N802 # grpc functions are not snake case + +from __future__ import annotations + +import copy +import json +import logging +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, patch + +import torch + +from modyn.config import ModynConfig +from modyn.evaluator.internal.grpc.generated.evaluator_pb2 import DatasetInfo, JsonString, PythonString +from modyn.evaluator.internal.pytorch_lighttuner import PytorchTuner +from modyn.evaluator.internal.utils.tuning_info import TuningInfo +from 
modyn.storage.internal.grpc.generated.storage_pb2 import ( + GetDataPerWorkerRequest, + GetDataPerWorkerResponse, + GetRequest, + GetResponse, +) +from modyn.trainer_server.custom_lr_schedulers.WarmupDecayLR.warmupdecay import WarmupDecayLR + + +class MockStorageStub: + def __init__(self, channel) -> None: + pass + + def Get(self, request: GetRequest): # pylint: disable=invalid-name + for key in request.keys: + yield GetResponse( + samples=[key.to_bytes(2, "big"), (key + 1).to_bytes(2, "big")], keys=[key, key + 1], labels=[5, 5] + ) + + def GetDataPerWorker(self, request: GetDataPerWorkerRequest): # pylint: disable=invalid-name + for i in range(0, 8, 4): + key = 8 * request.worker_id + i + yield GetDataPerWorkerResponse(keys=[key, key + 2]) + + +class NoneOrFalse: + def __eq__(self, other): + if other is None or not other: + return True + + return False + + +class MockModule: + def __init__(self, num_optimizers) -> None: + if num_optimizers == 1: + self.model = MockModelWrapper + else: + self.model = MockSuperModelWrapper + + def train(self) -> None: + pass + + +class MockModelWrapper: + def __init__(self, model_configuration=None, device="cpu", amp=False) -> None: + self.model = MockModel() + + +class MockSuperModelWrapper: + def __init__(self, model_configuration=None, device="cpu", amp=False) -> None: + self.model = MockSuperModel() + + +class MockSuperModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.moda = MockModel() + self.modb = MockModel() + self.modc = MockModel() + + def forward(self, data, sample_ids=None): + return self.moda(self.modb(data)) + + +class MockDataset(torch.utils.data.IterableDataset): + # pylint: disable=abstract-method, useless-parent-delegation + def __init__(self) -> None: + super().__init__() + + def __iter__(self): + return iter(range(100)) + + +class MockLRSchedulerModule: + def __init__(self) -> None: + self.custom_scheduler = CustomLRScheduler + + +# pylint: disable=dangerous-default-value +class CustomLRScheduler: + def __init__(self, optimizers, config={}) -> None: + pass + + def step(self): + pass + + +def get_mock_bytes_parser(): + return "def bytes_parser_function(x):\n\treturn x" + + +def get_mock_label_transformer(): + return "import torch\ndef label_transformer_function(x: torch.Tensor) -> " "torch.Tensor:\n\treturn x" + + +class MockModel(torch.nn.Module): + def __init__(self, num_classes=10, input_dim=10): + super().__init__() + self.fc = torch.nn.Linear(input_dim, num_classes) # Adjusted to match data input shape + self.device = "cpu" + + def forward(self, data, sample_ids=None): + data = data.to(torch.float32) # Ensure float dtype for model input + return self.fc(data) + + +class MockDataloader: + def __init__(self, tuning_info, num_batches=20): + self.batch_size = tuning_info.batch_size + self.num_batches = num_batches + self.dataset = MagicMock() + + def __iter__(self): + return iter( + [ + ( + ("1",) * self.batch_size, # sample_ids as strings (not used for training) + torch.ones(self.batch_size, 10, requires_grad=True), # Input data as (batch_size, 10) + torch.randint(0, 10, (self.batch_size,), dtype=torch.long), # Target as `long` for CE + ) + for _ in range(self.num_batches) + ] + ) + + def __len__(self): + return self.num_batches + + +def mock_get_dataloader(self, tuning_info): + """Creates a DataLoader similar to _prepare_dataloader.""" + mock_dataloader = MockDataloader(tuning_info) + return mock_dataloader + + +def noop_constructor_mock(self, channel): + pass + + +def get_tuning_info( + evaluation_id: 
int, + batch_size: int, + num_optimizers: int, + lr_scheduler: str, +): + if num_optimizers == 1: + torch_optimizers_configuration = { + "default": { + "algorithm": "SGD", + "source": "PyTorch", + "param_groups": [{"module": "model", "config": {"lr": 0.1}}], + } + } + else: + torch_optimizers_configuration = { + "opt1": { + "algorithm": "SGD", + "source": "PyTorch", + "param_groups": [{"module": "model.moda", "config": {"lr": 0.1}}], + }, + "opt2": { + "algorithm": "Adam", + "source": "PyTorch", + "param_groups": [ + {"module": "model.modb", "config": {"lr": 0.5}}, + {"module": "model.modc", "config": {"lr": 0.8}}, + ], + }, + } + + if lr_scheduler == "torch": + lr_scheduler_config = { + "name": "StepLR", + "source": "PyTorch", + "step_every": "batch", + "optimizers": ["default"] if num_optimizers == 1 else ["opt1"], + "config": {"step_size": 10}, + } + elif lr_scheduler == "custom": + lr_scheduler_config = { + "name": "WarmupDecayLR", + "source": "Custom", + "step_every": "batch", + "optimizers": ["default"] if num_optimizers == 1 else ["opt1", "opt2"], + "config": {}, + } + elif lr_scheduler == "torch_cosine": + lr_scheduler_config = { + "name": "CosineAnnealingLR", + "source": "PyTorch", + "step_every": "batch", + "optimizers": ["default"] if num_optimizers == 1 else ["opt1"], + "config": {"T_max": 100}, + } + else: + lr_scheduler_config = {} + + tuning_info = SimpleNamespace( + pipeline_id=1, + data_info=DatasetInfo(dataset_id="MNIST", num_dataloaders=2), + evaluation_id=42, + batch_size=32, + num_samples_to_pass=500, + transform_list=[], + bytes_parser=PythonString(value=get_mock_bytes_parser()), + torch_optimizers_configuration=JsonString(value=json.dumps(torch_optimizers_configuration)), + criterion_parameters=JsonString(value=json.dumps({})), + torch_criterion="CrossEntropyLoss", + lr_scheduler=JsonString(value=json.dumps(lr_scheduler_config)), + grad_scaler_configuration=JsonString(value=json.dumps({})), + epochs=10, + label_transformer=PythonString(value=get_mock_label_transformer()), + device="cpu", + amp=False, + shuffle=True, + enable_accurate_gpu_measurements=False, + generative=False, + steps=100, + drop_last_batch=False, + record_loss_every=10, + seed=42, + tokenizer=None, + ) + + return TuningInfo( + tuning_info, + 1, + None, + None, + ) + + +def get_mock_tuner( + modyn_config: ModynConfig, + num_optimizers: int, + lr_scheduler: str, + lr_scheduler_dynamic_module_patch: MagicMock, + model_dynamic_module_patch: MagicMock, + batch_size: int, + model: Any, +): + model_dynamic_module_patch.return_value = MockModule(num_optimizers) + lr_scheduler_dynamic_module_patch.return_value = MockLRSchedulerModule() + + tuning_info = get_tuning_info( + 0, + batch_size, + num_optimizers, + lr_scheduler, + ) + + # Fixing argument order: + tuner = PytorchTuner( + tuning_info, # ✅ Matches `tuning_info` + logging.getLogger(__name__), # ✅ Correctly placed `logger` + model, # ✅ Correctly assigned `model` + "localhost:1234", + ) + + return tuner + + +def test_tuner_init(dummy_system_config: ModynConfig): + tuner = get_mock_tuner(dummy_system_config, 1, "", MagicMock(), MagicMock(), 32, MockModelWrapper()) + + # Ensure model initialization is correct + assert isinstance(tuner._model, MockModelWrapper), "Expected tuner._model to be an instance of MockModule" + + # Validate optimizer setup + assert len(tuner._optimizers) == 1, "Expected one optimizer to be initialized" + assert isinstance(tuner._optimizers["default"], torch.optim.SGD), "Optimizer should be SGD" + + # Validate loss function + assert 
isinstance(tuner._criterion, torch.nn.CrossEntropyLoss), "Loss function should be CrossEntropyLoss" + + # Ensure learning rate scheduler is disabled + assert not tuner._lr_scheduler, "Expected no learning rate scheduler" + + # Verify device and batch size configurations + assert tuner._device == "cpu", "Expected tuner to run on CPU" + assert tuner._batch_size > 0, "Batch size should be greater than 0" + + +def test_tuner_init_multi_optimizers(dummy_system_config: ModynConfig): + tuner = get_mock_tuner(dummy_system_config, 2, "", MagicMock(), MagicMock(), 32, MockSuperModelWrapper()) + assert isinstance(tuner._model, MockSuperModelWrapper) + assert len(tuner._optimizers) == 2 + assert isinstance(tuner._optimizers["opt1"], torch.optim.SGD) + assert isinstance(tuner._optimizers["opt2"], torch.optim.Adam) + assert isinstance(tuner._criterion, torch.nn.CrossEntropyLoss) + assert not tuner._lr_scheduler + assert tuner._device == "cpu" + assert tuner._batch_size > 0 + assert tuner._dataset_log_path is not None + + +def test_tuner_init_torch_lr_scheduler(dummy_system_config: ModynConfig): + tuner = get_mock_tuner(dummy_system_config, 1, "torch", MagicMock(), MagicMock(), 32, MockModelWrapper()) + assert isinstance(tuner._model, MockModelWrapper) + assert len(tuner._optimizers) == 1 + assert isinstance(tuner._optimizers["default"], torch.optim.SGD) + assert isinstance(tuner._criterion, torch.nn.CrossEntropyLoss) + assert isinstance(tuner._lr_scheduler, torch.optim.lr_scheduler.StepLR) + assert tuner._device == "cpu" + assert tuner._batch_size > 0 + + +def test_tuner_init_custom_lr_scheduler(dummy_system_config: ModynConfig): + tuner = get_mock_tuner(dummy_system_config, 1, "custom", MagicMock(), MagicMock(), 32, MockModelWrapper()) + assert isinstance(tuner._model, MockModelWrapper) + assert len(tuner._optimizers) == 1 + assert isinstance(tuner._optimizers["default"], torch.optim.SGD) + assert isinstance(tuner._criterion, torch.nn.CrossEntropyLoss) + assert isinstance(tuner._lr_scheduler, WarmupDecayLR) + assert tuner._device == "cpu" + assert tuner._batch_size > 0 + + +@patch("modyn.evaluator.internal.pytorch_lighttuner.PytorchTuner._prepare_dataloader", mock_get_dataloader) +@patch("modyn.evaluator.internal.dataset.evaluation_dataset.StorageStub", MockStorageStub) +@patch("modyn.evaluator.internal.dataset.evaluation_dataset.grpc_connection_established", return_value=True) +def test_tuner_light_tuning( + dummy_system_config: ModynConfig, +): + model = MockModelWrapper() + tuner = get_mock_tuner(dummy_system_config, 1, "", MagicMock(), MagicMock(), 32, model) + + # Limit the tuning to 2 steps + tuner._light_tuning_steps = 2 + + # Capture old state before training + old_model_state = copy.deepcopy(model.model.state_dict()) + old_optim_states = {name: copy.deepcopy(opt.state_dict()) for name, opt in tuner._optimizers.items()} + + # Run the light tuning process + tuner.train() + + # Check that at least one model parameter changed + new_model_state = model.model.state_dict() + + change_detected = any(not torch.equal(old_model_state[p], new_model_state[p]) for p in old_model_state) + assert change_detected, "Expected at least one model parameter to change after light tuning." + + # Check that optimizer states changed + new_optim_states = {name: opt.state_dict() for name, opt in tuner._optimizers.items()} + for opt_name in old_optim_states: + assert ( + old_optim_states[opt_name] != new_optim_states[opt_name] + ), f"Expected optimizer {opt_name} state to change after light tuning." 
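The prepare_evaluation_request change earlier in this diff reads light_tuning and tuning_config directly from the evaluation dataset config and forwards them to EvaluateModelRequest; the test above then drives the resulting PytorchTuner for a couple of steps. A hedged sketch of what such a dataset config entry could look like; the keys inside tuning_config and the dataset name are placeholders, since their schema is not part of this diff:

# Illustrative evaluation dataset_config enabling light tuning.
dataset_config = {
    "dataset_id": "mnist_eval",                 # placeholder dataset name
    "bytes_parser_function": "def bytes_parser_function(x):\n\treturn x",
    "batch_size": 32,
    "dataloader_workers": 2,
    "light_tuning": True,                       # read by prepare_evaluation_request
    "tuning_config": {"steps": 2, "lr": 5e-5},  # forwarded to EvaluateModelRequest; keys assumed
}
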
diff --git a/modyn/tests/models/test_gpt2.py b/modyn/tests/models/test_gpt2.py new file mode 100644 index 000000000..76a84677c --- /dev/null +++ b/modyn/tests/models/test_gpt2.py @@ -0,0 +1,65 @@ +import torch +from transformers import AutoTokenizer + +from modyn.models import Gpt2 + + +class HParams: + def __init__(self, model_name_or_path="gpt2-large", device="cpu", amp=False): + self.model_name_or_path = model_name_or_path + self.device = device + self.amp = amp + + +def test_gpt2modyn_initialization(): + hparams = HParams() + model = Gpt2(hparams, hparams.device, hparams.amp) # Pass device and amp explicitly + assert isinstance(model.model, torch.nn.Module) + + +def test_forward_pass(): + hparams = HParams() + model = Gpt2(hparams, hparams.device, hparams.amp) + tokenizer = AutoTokenizer.from_pretrained("gpt2-large") + tokenizer.pad_token = tokenizer.eos_token + text = ["Hello, how are you?"] + tokens = tokenizer(text, return_tensors="pt", padding=True) + + input_data = torch.stack([tokens.input_ids, tokens.attention_mask], dim=-1) + output = model.model(input_data) # Fix incorrect model call + + assert output.logits.shape[-1] == tokenizer.vocab_size # Logits over vocab size + + +def test_get_last_layer(): + hparams = HParams() + model = Gpt2(hparams, hparams.device, hparams.amp) + last_layer = model.model.get_last_layer() # Ensure method is correctly accessed + + assert isinstance(last_layer, torch.nn.Linear) + assert last_layer.out_features == 50257 + + +def test_freeze_unfreeze_params(): + hparams = HParams() + model = Gpt2(hparams, hparams.device, hparams.amp) + + model.model.freeze_params() + assert all(not param.requires_grad for param in model.model.parameters()) + model.model.unfreeze_params() + assert all(param.requires_grad for param in model.model.parameters()) + + +def test_text_generation(): + hparams = HParams() + model = Gpt2(hparams, hparams.device, hparams.amp) + tokenizer = AutoTokenizer.from_pretrained("gpt2-large") + + input_text = "Once upon a time" + input_ids = tokenizer.encode(input_text, return_tensors="pt") + + generated_texts = model.model.generate(input_ids, max_length=20, num_return_sequences=1) + + assert isinstance(generated_texts, list) + assert len(generated_texts) == 1 + assert isinstance(generated_texts[0], str) diff --git a/modyn/tests/models/test_modular_adapters.py b/modyn/tests/models/test_modular_adapters.py new file mode 100644 index 000000000..72d0a2210 --- /dev/null +++ b/modyn/tests/models/test_modular_adapters.py @@ -0,0 +1,156 @@ +import torch +from transformers import AutoTokenizer + +from modyn.models import Gpt2, apply_kadapter, apply_lora + +# ============================================================================= +# Updated Test Suite +# ============================================================================= + +# Note: The tests below assume that you're using a GPT-2 model from transformers. +# If you want to use your custom Gpt2 class from modyn.models, adjust the imports accordingly. 
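For reference, a generic, self-contained illustration of the low-rank adapter idea exercised by the tests below. This is not the modyn apply_lora implementation (which targets GPT-2's c_attn/c_proj modules); the class and parameter names here are only chosen so that "lora" appears in named_parameters, mirroring what the tests check:

import torch


class LoRALinear(torch.nn.Module):
    """Frozen base layer plus a trainable low-rank update scaled by alpha / r."""

    def __init__(self, base: torch.nn.Linear, r: int = 16, alpha: int = 32) -> None:
        super().__init__()
        self.base = base
        for param in self.base.parameters():
            param.requires_grad = False  # freeze the pretrained weight
        # Standard LoRA init: A random, B zero, so the update starts at zero.
        self.lora_a = torch.nn.Parameter(torch.randn(r, base.in_features) * 0.02)
        self.lora_b = torch.nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling
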
+ + +class HParams: + def __init__(self, model_name_or_path="gpt2-large", device="cpu", amp=False): + self.model_name_or_path = model_name_or_path + self.device = device + self.amp = amp + + +def test_apply_lora(): + hparams = HParams() + model = Gpt2(hparams, hparams.device, hparams.amp) + tokenizer = AutoTokenizer.from_pretrained("gpt2-large") + tokenizer.pad_token = tokenizer.eos_token + target_modules = ["c_attn", "c_proj"] + modified_model = apply_lora(model.model, target_modules=target_modules, adapter_dim=16, adapter_alpha=32) + # Verify that LoRA parameters are trainable and others are frozen. + for name, param in modified_model.model.named_parameters(): + if "lora" in name: + assert param.requires_grad, f"LoRA parameter {name} should be trainable." + else: + assert not param.requires_grad, f"Non-LoRA parameter {name} should be frozen." + + +def test_apply_kadapter(): + hparams = HParams() + model = Gpt2(hparams, hparams.device, hparams.amp) + tokenizer = AutoTokenizer.from_pretrained("gpt2-large") + tokenizer.pad_token = tokenizer.eos_token + modified_model = apply_kadapter(model.model) + # Check that our custom adapter is attached. + assert hasattr(modified_model, "kadapter"), "KAdapter not attached to the model." + for name, param in modified_model.kadapter.named_parameters(): + assert param.requires_grad, f"KAdapter parameter {name} should be trainable." + + +def test_model_with_adapters_inference(): + hparams = HParams() + model = Gpt2(hparams, hparams.device, hparams.amp) + tokenizer = AutoTokenizer.from_pretrained("gpt2-large") + tokenizer.pad_token = tokenizer.eos_token + model.model = apply_lora(model.model, target_modules=["c_attn", "c_proj"]) + model.model = apply_kadapter(model.model) + + input_text = "Hello, world!" + encoding = tokenizer(input_text, return_tensors="pt", padding=True) + input_ids = encoding["input_ids"] # shape: (1, seq_len) + attention_mask = encoding["attention_mask"] # shape: (1, seq_len) + data = torch.stack([input_ids, attention_mask], dim=-1) + model.model.eval() + with torch.no_grad(): + outputs = model.model(data) + logits = outputs.logits + assert logits.shape[-1] == tokenizer.vocab_size, "Output dimension mismatch." + + +def test_model_training_with_kadapters(): + hparams = HParams() + model = Gpt2(hparams, hparams.device, hparams.amp) + tokenizer = AutoTokenizer.from_pretrained("gpt2-large") + tokenizer.pad_token = tokenizer.eos_token + + model.model = apply_kadapter(model.model) + + input_text = "Once upon a time, a king had a dream." + + encoding = tokenizer(input_text, return_tensors="pt", padding=True) + input_ids = encoding["input_ids"] # shape: (1, seq_len) + attention_mask = encoding["attention_mask"] # shape: (1, seq_len) + data = torch.stack([input_ids, attention_mask], dim=-1) + labels = input_ids.clone() + optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.model.parameters()), lr=5e-5) + model.model.train() + with torch.no_grad(): + initial_outputs = model.model(data) + initial_logits = initial_outputs.logits.clone().detach() + for _ in range(2): + optimizer.zero_grad() + outputs = model.model(data, labels=labels) + loss = outputs.loss + loss.backward() + optimizer.step() + with torch.no_grad(): + final_outputs = model.model(data) + final_logits = final_outputs.logits.clone().detach() + assert loss.item() > 0, "Loss should be > 0 after training." + assert not torch.equal(initial_logits, final_logits), "Logits should change after training." 
+ for name, param in model.model.named_parameters(): + if "kadapter" in name: + assert param.grad is not None, f"Expected gradient for {name} but found None." + else: + assert param.grad is None or torch.all(param.grad == 0), f"Unexpected gradient in {name}." + + +def test_model_training_with_lora(): + hparams = HParams() + model = Gpt2(hparams, hparams.device, hparams.amp) + tokenizer = AutoTokenizer.from_pretrained("gpt2-large") + tokenizer.pad_token = tokenizer.eos_token + target_modules = ["c_attn", "c_proj"] + model = apply_lora(model.model, target_modules=target_modules, adapter_dim=16, adapter_alpha=32) + + input_text = "Once upon a time, a king had a dream." + encoding = tokenizer(input_text, return_tensors="pt", padding=True) + input_ids = encoding["input_ids"] # shape: (1, seq_len) + attention_mask = encoding["attention_mask"] # shape: (1, seq_len) + data = torch.stack([input_ids, attention_mask], dim=-1) + labels = input_ids.clone() + + optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5) + model.train() + + # Capture initial output + with torch.no_grad(): + initial_outputs = model(data) + initial_logits = initial_outputs.logits.clone().detach() + + # Small training loop + for _ in range(2): + optimizer.zero_grad() + outputs = model(data, labels=labels) + loss = outputs.loss + loss.backward() + optimizer.step() + + # Capture final output + with torch.no_grad(): + final_outputs = model(data) + final_logits = final_outputs.logits.clone().detach() + + # Check that loss is > 0 and that the model's output changes + assert loss.item() > 0, "Loss should be > 0 after training." + assert not torch.equal(initial_logits, final_logits), "Logits should change after training." + + # Verify that gradients exist only for LoRA parameters + for name, param in model.named_parameters(): + if "lora" in name: + assert param.grad is not None, f"Expected gradient for {name} but found None." + else: + assert param.grad is None or torch.all(param.grad == 0), f"Unexpected gradient in {name}." 
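The kadapter tests above only rely on two observable properties: a kadapter attribute on the model and trainable parameters whose names contain "kadapter". A generic sketch of an adapter attachment with those properties; the attach_adapter name, dimensions, and layer choices are placeholders, not modyn's actual apply_kadapter architecture:

import torch


def attach_adapter(model: torch.nn.Module, hidden_dim: int = 1280, bottleneck: int = 64) -> torch.nn.Module:
    # Freeze the pretrained backbone; only the adapter remains trainable.
    for param in model.parameters():
        param.requires_grad = False
    # Assigning an nn.Module attribute registers it as a submodule, so its
    # parameters show up under the "kadapter." prefix in named_parameters().
    model.kadapter = torch.nn.Sequential(
        torch.nn.Linear(hidden_dim, bottleneck),
        torch.nn.ReLU(),
        torch.nn.Linear(bottleneck, hidden_dim),
    )
    return model
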
+ + +if __name__ == "__main__": + test_model_training_with_lora() + print("All tests passed!") diff --git a/modyn/tests/storage/internal/file_wrapper/binary_file_wrapper_test.cpp b/modyn/tests/storage/internal/file_wrapper/binary_file_wrapper_test.cpp index 1e15446a5..6d08ca05d 100644 --- a/modyn/tests/storage/internal/file_wrapper/binary_file_wrapper_test.cpp +++ b/modyn/tests/storage/internal/file_wrapper/binary_file_wrapper_test.cpp @@ -255,31 +255,31 @@ TEST_F(BinaryFileWrapperTest, TestGetSamples) { EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillRepeatedly(testing::Return(stream_ptr)); BinaryFileWrapper file_wrapper(file_name_, config_, filesystem_wrapper_); - std::vector> samples = file_wrapper.get_samples(0, 3); + std::vector> samples = file_wrapper.get_samples(0, 3, /*include_labels=*/true); ASSERT_EQ(samples.size(), 4); ASSERT_EQ((samples)[0][0], 12); ASSERT_EQ((samples)[1][0], 13); ASSERT_EQ((samples)[2][0], 14); ASSERT_EQ((samples)[3][0], 15); - samples = file_wrapper.get_samples(1, 3); + samples = file_wrapper.get_samples(1, 3, /*include_labels=*/true); ASSERT_EQ(samples.size(), 3); ASSERT_EQ((samples)[0][0], 13); ASSERT_EQ((samples)[1][0], 14); ASSERT_EQ((samples)[2][0], 15); - samples = file_wrapper.get_samples(2, 3); + samples = file_wrapper.get_samples(2, 3, /*include_labels=*/true); ASSERT_EQ(samples.size(), 2); ASSERT_EQ((samples)[0][0], 14); ASSERT_EQ((samples)[1][0], 15); - samples = file_wrapper.get_samples(3, 3); + samples = file_wrapper.get_samples(3, 3, /*include_labels=*/true); ASSERT_EQ(samples.size(), 1); ASSERT_EQ((samples)[0][0], 15); - ASSERT_THROW(file_wrapper.get_samples(4, 3), modyn::utils::ModynException); + ASSERT_THROW(file_wrapper.get_samples(4, 3, /*include_labels=*/true), modyn::utils::ModynException); - samples = file_wrapper.get_samples(1, 2); + samples = file_wrapper.get_samples(1, 2, /*include_labels=*/true); ASSERT_EQ(samples.size(), 2); ASSERT_EQ((samples)[0][0], 13); ASSERT_EQ((samples)[1][0], 14); @@ -295,7 +295,8 @@ TEST_F(BinaryFileWrapperTest, TestGetSamplesFromIndices) { BinaryFileWrapper file_wrapper(file_name_, config_, filesystem_wrapper_); std::vector label_indices{0, 1, 2, 3}; - std::vector> samples = file_wrapper.get_samples_from_indices(label_indices); + std::vector> samples = + file_wrapper.get_samples_from_indices(label_indices, /*include_labels=*/true); ASSERT_EQ(samples.size(), 4); ASSERT_EQ((samples)[0][0], 12); ASSERT_EQ((samples)[1][0], 13); @@ -303,25 +304,25 @@ TEST_F(BinaryFileWrapperTest, TestGetSamplesFromIndices) { ASSERT_EQ((samples)[3][0], 15); label_indices = {1, 2, 3}; - samples = file_wrapper.get_samples_from_indices(label_indices); + samples = file_wrapper.get_samples_from_indices(label_indices, /*include_labels=*/true); ASSERT_EQ(samples.size(), 3); ASSERT_EQ((samples)[0][0], 13); ASSERT_EQ((samples)[1][0], 14); ASSERT_EQ((samples)[2][0], 15); label_indices = {2}; - samples = file_wrapper.get_samples_from_indices(label_indices); + samples = file_wrapper.get_samples_from_indices(label_indices, /*include_labels=*/true); ASSERT_EQ(samples.size(), 1); ASSERT_EQ((samples)[0][0], 14); label_indices = {1, 3}; - samples = file_wrapper.get_samples_from_indices(label_indices); + samples = file_wrapper.get_samples_from_indices(label_indices, /*include_labels=*/true); ASSERT_EQ(samples.size(), 2); ASSERT_EQ((samples)[0][0], 13); ASSERT_EQ((samples)[1][0], 15); label_indices = {3, 1, 3}; - samples = file_wrapper.get_samples_from_indices(label_indices); + samples = 
file_wrapper.get_samples_from_indices(label_indices, /*include_labels=*/true); ASSERT_EQ(samples.size(), 3); ASSERT_EQ((samples)[0][0], 15); ASSERT_EQ((samples)[1][0], 13); @@ -343,4 +344,20 @@ TEST_F(BinaryFileWrapperTest, TestDeleteSamples) { ASSERT_NO_THROW(file_wrapper.delete_samples(label_indices)); } +TEST_F(BinaryFileWrapperTest, TestGetSamplesFromIndicesWithoutLabels) { + EXPECT_CALL(*filesystem_wrapper_, get_file_size(testing::_)).WillOnce(testing::Return(16)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + ASSERT_TRUE(stream_ptr->is_open()); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + + BinaryFileWrapper file_wrapper(file_name_, config_, filesystem_wrapper_); + const std::vector indices = {0, 1}; + const std::vector> samples = + file_wrapper.get_samples_from_indices(indices, /*include_labels=*/false); + + ASSERT_EQ(samples.size(), 2); +} + } // namespace modyn::storage diff --git a/modyn/tests/storage/internal/file_wrapper/csv_file_wrapper_test.cpp b/modyn/tests/storage/internal/file_wrapper/csv_file_wrapper_test.cpp index e19bd998b..700459fe0 100644 --- a/modyn/tests/storage/internal/file_wrapper/csv_file_wrapper_test.cpp +++ b/modyn/tests/storage/internal/file_wrapper/csv_file_wrapper_test.cpp @@ -109,7 +109,8 @@ TEST_F(CsvFileWrapperTest, TestGetSamples) { {'J', 'a', 'n', 'e', ',', 'S', 'm', 'i', 't', 'h', ',', '3', '0'}, {'M', 'i', 'c', 'h', 'a', 'e', 'l', ',', 'J', 'o', 'h', 'n', 's', 'o', 'n', ',', '3', '5'}, }; - const std::vector> actual_samples = file_wrapper.get_samples(start, end); + const std::vector> actual_samples = + file_wrapper.get_samples(start, end, /*include_labels=*/true); ASSERT_EQ(actual_samples, expected_samples); } @@ -144,7 +145,8 @@ TEST_F(CsvFileWrapperTest, TestGetSamplesFromIndices) { {'J', 'o', 'h', 'n', ',', 'D', 'o', 'e', ',', '2', '5'}, {'M', 'i', 'c', 'h', 'a', 'e', 'l', ',', 'J', 'o', 'h', 'n', 's', 'o', 'n', ',', '3', '5'}, }; - const std::vector> actual_samples = file_wrapper.get_samples_from_indices(indices); + const std::vector> actual_samples = + file_wrapper.get_samples_from_indices(indices, /*include_labels=*/true); ASSERT_EQ(actual_samples, expected_samples); } @@ -175,3 +177,66 @@ TEST_F(CsvFileWrapperTest, TestDeleteSamples) { ASSERT_EQ(buffer, expected_samples[0]); } +TEST_F(CsvFileWrapperTest, TestGetSamplesFromIndicesWithoutLabels) { + // Create a test CSV file without labels + std::ofstream file_without_labels(file_name_); + file_without_labels << "first_name,last_name\n"; + file_without_labels << "John,Doe\n"; + file_without_labels << "Jane,Smith\n"; + file_without_labels << "Michael,Johnson\n"; + file_without_labels.close(); + ASSERT_TRUE(std::filesystem::exists(file_name_)); + + EXPECT_CALL(*filesystem_wrapper_, exists(testing::_)).WillOnce(testing::Return(true)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + ASSERT_TRUE(stream_ptr->is_open()); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + CsvFileWrapper file_wrapper{file_name_, config_, filesystem_wrapper_}; + + // Test get_samples_from_indices without labels + const std::vector indices = {0, 2}; + const std::vector> expected_samples = { + {'J', 'o', 'h', 'n', ',', 'D', 'o', 'e'}, + {'M', 'i', 'c', 'h', 'a', 'e', 'l', ',', 'J', 'o', 'h', 'n', 's', 'o', 'n'}, + }; + + const std::vector> actual_samples = + 
file_wrapper.get_samples_from_indices(indices, /*include_labels=*/false); + + ASSERT_EQ(actual_samples, expected_samples); +} + +TEST_F(CsvFileWrapperTest, TestGetSamplesWithoutLabels) { + // Create a test CSV file without labels + std::ofstream file_without_labels(file_name_); + file_without_labels << "first_name,last_name\n"; + file_without_labels << "John,Doe\n"; + file_without_labels << "Jane,Smith\n"; + file_without_labels << "Michael,Johnson\n"; + file_without_labels.close(); + ASSERT_TRUE(std::filesystem::exists(file_name_)); + + EXPECT_CALL(*filesystem_wrapper_, exists(testing::_)).WillOnce(testing::Return(true)); + const std::shared_ptr stream_ptr = std::make_shared(); + stream_ptr->open(file_name_, std::ios::binary); + ASSERT_TRUE(stream_ptr->is_open()); + + EXPECT_CALL(*filesystem_wrapper_, get_stream(testing::_)).WillOnce(testing::Return(stream_ptr)); + CsvFileWrapper file_wrapper{file_name_, config_, filesystem_wrapper_}; + + // Test get_samples without labels + const uint64_t start = 0; + const uint64_t end = 3; + const std::vector> expected_samples = { + {'J', 'o', 'h', 'n', ',', 'D', 'o', 'e'}, + {'J', 'a', 'n', 'e', ',', 'S', 'm', 'i', 't', 'h'}, + {'M', 'i', 'c', 'h', 'a', 'e', 'l', ',', 'J', 'o', 'h', 'n', 's', 'o', 'n'}, + }; + + const std::vector> actual_samples = + file_wrapper.get_samples(start, end, /*include_labels=*/false); + + ASSERT_EQ(actual_samples, expected_samples); +} diff --git a/modyn/tests/storage/internal/file_wrapper/mock_file_wrapper.hpp b/modyn/tests/storage/internal/file_wrapper/mock_file_wrapper.hpp index 0b58a07fb..16d418a1d 100644 --- a/modyn/tests/storage/internal/file_wrapper/mock_file_wrapper.hpp +++ b/modyn/tests/storage/internal/file_wrapper/mock_file_wrapper.hpp @@ -19,8 +19,8 @@ class MockFileWrapper : public FileWrapper { MOCK_METHOD(int64_t, get_label, (int64_t index), (override)); MOCK_METHOD(std::vector*, get_all_labels, (), (override)); MOCK_METHOD(std::vector*, get_sample, (int64_t index), (override)); - MOCK_METHOD(std::vector>*, get_samples_from_indices, (std::vector * indices), - (override)); + MOCK_METHOD(std::vector>*, get_samples_from_indices, + (std::vector * indices, true), (override)); MOCK_METHOD(FileWrapperType, get_type, (), (override)); MOCK_METHOD(void, validate_file_extension, (), (override)); MOCK_METHOD(void, delete_samples, (std::vector * indices), (override)); diff --git a/modyn/tests/storage/internal/file_wrapper/single_sample_file_wrapper_test.cpp b/modyn/tests/storage/internal/file_wrapper/single_sample_file_wrapper_test.cpp index ac3a1bd9a..ccdd81169 100644 --- a/modyn/tests/storage/internal/file_wrapper/single_sample_file_wrapper_test.cpp +++ b/modyn/tests/storage/internal/file_wrapper/single_sample_file_wrapper_test.cpp @@ -47,7 +47,7 @@ TEST(SingleSampleFileWrapperTest, TestGetSamples) { const std::shared_ptr filesystem_wrapper = std::make_shared(); EXPECT_CALL(*filesystem_wrapper, get(testing::_)).WillOnce(testing::Return(bytes)); ::SingleSampleFileWrapper file_wrapper = ::SingleSampleFileWrapper(file_name, config, filesystem_wrapper); - const std::vector> samples = file_wrapper.get_samples(0, 1); + const std::vector> samples = file_wrapper.get_samples(0, 1, /*include_labels=*/true); ASSERT_EQ(samples.size(), 1); ASSERT_EQ(samples[0].size(), 8); ASSERT_EQ((samples)[0][0], '1'); @@ -87,7 +87,8 @@ TEST(SingleSampleFileWrapperTest, TestGetSamplesFromIndices) { EXPECT_CALL(*filesystem_wrapper, get(testing::_)).WillOnce(testing::Return(bytes)); ::SingleSampleFileWrapper file_wrapper = 
::SingleSampleFileWrapper(file_name, config, filesystem_wrapper); const std::vector indices = {0}; - const std::vector> samples = file_wrapper.get_samples_from_indices(indices); + const std::vector> samples = + file_wrapper.get_samples_from_indices(indices, /*include_labels*/ true); ASSERT_EQ(samples.size(), 1); ASSERT_EQ(samples[0].size(), 8); ASSERT_EQ((samples)[0][0], '1'); @@ -111,3 +112,26 @@ TEST(SingleSampleFileWrapperTest, TestDeleteSamples) { const std::vector indices = {0}; file_wrapper.delete_samples(indices); } +TEST(SingleSampleFileWrapperTest, TestGetSamplesFromIndicesWithoutLabels) { + const std::string file_name = "test.txt"; + const YAML::Node config = StorageTestUtils::get_dummy_file_wrapper_config(); + const std::vector bytes = {'1', '2', '3', '4', '5', '6', '7', '8'}; + const std::shared_ptr filesystem_wrapper = std::make_shared(); + EXPECT_CALL(*filesystem_wrapper, get(testing::_)).WillOnce(testing::Return(bytes)); + + ::SingleSampleFileWrapper file_wrapper = ::SingleSampleFileWrapper(file_name, config, filesystem_wrapper); + const std::vector indices = {0}; + const std::vector> samples = + file_wrapper.get_samples_from_indices(indices, /*include_labels=*/false); + + ASSERT_EQ(samples.size(), 1); + ASSERT_EQ(samples[0].size(), 8); + ASSERT_EQ((samples)[0][0], '1'); + ASSERT_EQ((samples)[0][1], '2'); + ASSERT_EQ((samples)[0][2], '3'); + ASSERT_EQ((samples)[0][3], '4'); + ASSERT_EQ((samples)[0][4], '5'); + ASSERT_EQ((samples)[0][5], '6'); + ASSERT_EQ((samples)[0][6], '7'); + ASSERT_EQ((samples)[0][7], '8'); +} diff --git a/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py b/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py index 7818a7417..f638c55c0 100644 --- a/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py +++ b/modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py @@ -159,6 +159,7 @@ def mock_get_dataloaders( log_path, drop_last, num_batches: int = 100, + include_labels=True, ): mock_train_dataloader = MockDataloader(batch_size, num_batches) return mock_train_dataloader, None diff --git a/modyn/trainer_server/custom_lr_schedulers/WarmupDecayLR/Readme.txt b/modyn/trainer_server/custom_lr_schedulers/WarmupDecayLR/Readme.txt new file mode 100644 index 000000000..a7b91e433 --- /dev/null +++ b/modyn/trainer_server/custom_lr_schedulers/WarmupDecayLR/Readme.txt @@ -0,0 +1,219 @@ +# Info + +This folder contains the implementation of the learning rate scheduler used by the Deepspeed library. +We use the implementation provided by the Microsoft [repository](https://github.com/microsoft/DeepSpeed/blob/1640f6df4f6576636c472ba47ce76a507e3b4373/deepspeed/runtime/lr_schedules.py#L723) + +# Changes to Microsoft Version + +In according with the LICENSE of the NVIDIA code, we list our changes here: + +- We move the learning rate scheduler to a separate file. +- We rename the class `LearningRateScheduler` to `DLRMScheduler` +- We change the `__init__` function of the class to take as argument a configuration dictionary, containing all the arguments originally passed. +TODO find out what to put here exactly + +# Original License + +The code falls under the Apache 2.0 LICENSE + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
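For reference, the learning-rate trajectory that the `WarmupDecayLR` scheduler introduced below follows can be written as a small standalone function: the rate rises from `warmup_min_lr` to `warmup_max_lr` over `warmup_num_steps` (logarithmically or linearly), then decays linearly back toward `warmup_min_lr` by `total_num_steps`. This is an illustrative sketch only; the parameter names mirror the `scheduler_config` keys used by the new scheduler, and the defaults match its fallback values.

```python
import math


def warmup_decay_lr(
    step: int,
    warmup_min_lr: float = 0.0,
    warmup_max_lr: float = 0.001,
    warmup_num_steps: int = 1000,
    total_num_steps: int = 10000,
    warmup_type: str = "log",
) -> float:
    """Illustrative reimplementation of the warmup + linear-decay schedule."""
    warmup_num_steps = max(2, warmup_num_steps)
    if step < warmup_num_steps:
        if warmup_type == "log":
            # Logarithmic ramp-up: 0 at step 0, 1 once warmup is complete.
            gamma = math.log(step + 1) / math.log(warmup_num_steps)
        else:
            # Linear ramp-up.
            gamma = step / warmup_num_steps
    else:
        # Linear decay from warmup_max_lr back to warmup_min_lr at total_num_steps.
        gamma = max(0.0, (total_num_steps - step) / max(1.0, total_num_steps - warmup_num_steps))
    return warmup_min_lr + (warmup_max_lr - warmup_min_lr) * gamma


if __name__ == "__main__":
    for step in (0, 500, 1000, 5000, 10000):
        print(step, warmup_decay_lr(step))
```

With the defaults above, the rate peaks at `warmup_max_lr` once step 1000 is reached and returns to `warmup_min_lr` at step 10000.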
diff --git a/modyn/trainer_server/custom_lr_schedulers/WarmupDecayLR/_init_.py b/modyn/trainer_server/custom_lr_schedulers/WarmupDecayLR/_init_.py new file mode 100644 index 000000000..e3592b91e --- /dev/null +++ b/modyn/trainer_server/custom_lr_schedulers/WarmupDecayLR/_init_.py @@ -0,0 +1,7 @@ +"""DLRM LR Scheduler.""" + +import os + +files = os.listdir(os.path.dirname(__file__)) +files.remove("__init__.py") +__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/trainer_server/custom_lr_schedulers/WarmupDecayLR/warmupdecay.py b/modyn/trainer_server/custom_lr_schedulers/WarmupDecayLR/warmupdecay.py new file mode 100644 index 000000000..8d649e2b2 --- /dev/null +++ b/modyn/trainer_server/custom_lr_schedulers/WarmupDecayLR/warmupdecay.py @@ -0,0 +1,132 @@ +import math +from typing import Any + +from torch.optim import Optimizer + +WARMUP_LOG_RATE = "log" +WARMUP_LINEAR_RATE = "linear" + + +def get_torch_optimizer(optimizer: Any) -> Optimizer: + if isinstance(optimizer, Optimizer): + return optimizer + + if hasattr(optimizer, "optimizer") and isinstance(optimizer.optimizer, Optimizer): + return optimizer.optimizer + + raise TypeError(f"{type(optimizer).__name__} is not a subclass of torch.optim.Optimizer") + + +def update_lr(param_groups: list[dict[str, Any]], lrs: list[float]) -> list[float]: + for param_group, lr in zip(param_groups, lrs): + param_group["lr"] = lr + return [group["lr"] for group in param_groups] + + +class WarmupLR: + def __init__( + self, + optimizer: Optimizer, + warmup_min_lr: float | list[float] = 0.0, + warmup_max_lr: float | list[float] = 0.001, + warmup_num_steps: int = 1000, + warmup_type: str = WARMUP_LOG_RATE, + last_batch_iteration: int = -1, + ) -> None: + self.optimizer = get_torch_optimizer(optimizer) + + self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr") + self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr") + self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)] + self.warmup_num_steps = max(2, warmup_num_steps) + if warmup_type not in {WARMUP_LOG_RATE, WARMUP_LINEAR_RATE}: + print("Using unknown warmup_type. 
The increasing function " "is set to default (log)") + warmup_type = WARMUP_LOG_RATE + self.warmup_type = warmup_type + self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps) + self.last_batch_iteration = last_batch_iteration + if last_batch_iteration == -1: + self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr()) + + def get_lr(self) -> list[float]: + if self.last_batch_iteration < 0: + print("Attempting to get learning rate from scheduler before it has started") + return self.min_lrs + gamma = self._get_gamma() + return [min_lr + (delta_lr * gamma) for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)] + + def get_last_lr(self) -> list[float]: + assert getattr(self, "_last_lr", None) is not None, "need to call step() first" + return self._last_lr + + def step(self, last_batch_iteration: int | None = None) -> None: + if last_batch_iteration is None: + last_batch_iteration = self.last_batch_iteration + 1 + self.last_batch_iteration = last_batch_iteration + self._last_lr = update_lr(self.optimizer.param_groups, self.get_lr()) + + def state_dict(self) -> dict[str, int]: + return {"last_batch_iteration": self.last_batch_iteration} + + def load_state_dict(self, sd: dict[str, int]) -> None: + self.last_batch_iteration = sd["last_batch_iteration"] + + def _get_gamma(self) -> float: + if self.last_batch_iteration < self.warmup_num_steps: + if self.warmup_type == WARMUP_LOG_RATE: + return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1) + elif self.warmup_type == WARMUP_LINEAR_RATE: + return self.last_batch_iteration / self.warmup_num_steps + return 1.0 + + def _format_param(self, optimizer: Optimizer, param_value: float | list[float], param_name: str) -> list[float]: + if isinstance(param_value, list) or isinstance(param_value, tuple): + if len(param_value) != len(optimizer.param_groups): + raise ValueError( + f"expected {len(optimizer.param_groups)} value for {param_name}, got {FileNotFoundError(param_value)}" + ) + return list(param_value) + return [param_value] * len(optimizer.param_groups) + + +class WarmupDecayLR(WarmupLR): + def __init__(self, optimizers: list[Optimizer], scheduler_config: dict[str, Any]) -> None: + if len(optimizers) != 1: + raise ValueError("Only a single optimizer is supported.") + + self.optimizer = optimizers[0] + self.current_step = 0 + + self.total_num_steps = scheduler_config.get("total_num_steps", 10000) + self.warmup_min_lr = scheduler_config.get("warmup_min_lr", 0.0) + self.warmup_max_lr = scheduler_config.get("warmup_max_lr", 0.001) + self.warmup_num_steps = scheduler_config.get("warmup_num_steps", 1000) + self.warmup_type = scheduler_config.get("warmup_type", "log") + self.last_batch_iteration = scheduler_config.get("last_batch_iteration", -1) + + if self.total_num_steps < self.warmup_num_steps: + raise ValueError("total_num_steps must be greater than or equal to warmup_num_steps.") + + if self.warmup_min_lr >= self.warmup_max_lr: + raise ValueError("warmup_min_lr must be less than warmup_max_lr.") + + super().__init__( + self.optimizer, + self.warmup_min_lr, + self.warmup_max_lr, + self.warmup_num_steps, + self.warmup_type, + self.last_batch_iteration, + ) + + def _get_gamma(self) -> float: + if self.last_batch_iteration < self.warmup_num_steps: + if self.warmup_type == WARMUP_LOG_RATE: + return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1) + elif self.warmup_type == WARMUP_LINEAR_RATE: + return self.last_batch_iteration / self.warmup_num_steps + return max( + 0.0, + float(self.total_num_steps 
- self.last_batch_iteration) + / float(max(1.0, self.total_num_steps - self.warmup_num_steps)), + ) diff --git a/modyn/trainer_server/custom_lr_schedulers/__init__.py b/modyn/trainer_server/custom_lr_schedulers/__init__.py index a371b04da..92e2a276a 100644 --- a/modyn/trainer_server/custom_lr_schedulers/__init__.py +++ b/modyn/trainer_server/custom_lr_schedulers/__init__.py @@ -3,6 +3,7 @@ import os from .dlrm_lr_scheduler.dlrm_scheduler import DLRMScheduler # noqa: F401 +from .WarmupDecayLR.warmupdecay import WarmupDecayLR, WarmupLR # noqa: F401 files = os.listdir(os.path.dirname(__file__)) files.remove("__init__.py") diff --git a/modyn/evaluator/internal/grpc/generated/__init__.py b/modyn/trainer_server/custom_optimizers/__init__.py similarity index 55% rename from modyn/evaluator/internal/grpc/generated/__init__.py rename to modyn/trainer_server/custom_optimizers/__init__.py index ffacdd929..d53dc988c 100644 --- a/modyn/evaluator/internal/grpc/generated/__init__.py +++ b/modyn/trainer_server/custom_optimizers/__init__.py @@ -1,11 +1,9 @@ -"""Evaluator module. - -The evaluator module contains all classes and functions related the -evaluation of models. -""" +"""Learning rate schedulers.""" import os +from .recadam.recadam import RecAdam # noqa: F401 + files = os.listdir(os.path.dirname(__file__)) files.remove("__init__.py") __all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/trainer_server/custom_optimizers/recadam/recadam.py b/modyn/trainer_server/custom_optimizers/recadam/recadam.py new file mode 100644 index 000000000..4348cc236 --- /dev/null +++ b/modyn/trainer_server/custom_optimizers/recadam/recadam.py @@ -0,0 +1,139 @@ +import logging +import math +from collections.abc import Callable +from typing import Any + +import numpy as np +import torch +from torch import Tensor +from torch.optim import Optimizer + +logger = logging.getLogger(__name__) + + +def anneal_function(function: str, step: int, k: float, t0: float, weight: float) -> float: + """Computes the annealing factor for RecAdam optimization.""" + if function == "sigmoid": + return float(1 / (1 + np.exp(-k * (step - t0)))) * weight + elif function == "linear": + return min(1, step / t0) * weight + elif function == "constant": + return weight + else: + raise ValueError(f"Invalid anneal function type: {function}") + + +class RecAdam(Optimizer): + """Implementation of RecAdam optimizer, a variant of the Adam optimizer.""" + + def __init__( + self, + params: list[Any] | list[dict[str, Any]], + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + correct_bias: bool = True, + anneal_fun: str = "sigmoid", + anneal_k: float = 0.0, + anneal_t0: float = 0.0, + anneal_w: float = 1.0, + pretrain_cof: float = 5000.0, + pretrain_params: list[Tensor] | None = None, + ) -> None: + """Initializes the RecAdam optimizer.""" + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0[") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0[") + if eps < 0.0: + raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + correct_bias=correct_bias, + anneal_fun=anneal_fun, + anneal_k=anneal_k, + anneal_t0=anneal_t0, + anneal_w=anneal_w, + pretrain_cof=pretrain_cof, + 
pretrain_params=pretrain_params, + ) + super().__init__(params, defaults) + + def step(self, closure: Callable[[], float] | None = None) -> float | None: + """Performs a single optimization step. + + Args: + closure (Callable, optional): A function that reevaluates the model + and returns the loss. Defaults to None. + + Returns: + Optional[float]: The loss value if a closure is provided, otherwise None. + """ + loss: float | None = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + pretrain_params: list[Tensor] | None = group["pretrain_params"] + if pretrain_params is None: + continue + + for p, pp in zip(group["params"], pretrain_params): + if p.grad is None: + continue + grad: Tensor = p.grad.data + if grad.is_sparse: + raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") + + state: dict[str, Any] = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg: Tensor = state["exp_avg"] + exp_avg_sq: Tensor = state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1.0 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) + denom: Tensor = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size: float = group["lr"] + if group["correct_bias"]: + bias_correction1: float = 1.0 - beta1 ** state["step"] + bias_correction2: float = 1.0 - beta2 ** state["step"] + step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 + + # Apply RecAdam adjustments + if group["anneal_w"] > 0.0: + anneal_lambda: float = anneal_function( + group["anneal_fun"], state["step"], group["anneal_k"], group["anneal_t0"], group["anneal_w"] + ) + assert anneal_lambda <= group["anneal_w"] + p.data.addcdiv_(-step_size * anneal_lambda, exp_avg, denom) + p.data.add_( + -group["lr"] * (group["anneal_w"] - anneal_lambda) * group["pretrain_cof"], p.data - pp.data + ) + else: + p.data.addcdiv_(-step_size, exp_avg, denom) + + # Apply weight decay + if group["weight_decay"] > 0.0: + p.data.add_(-group["lr"] * group["weight_decay"], p.data) + + return loss diff --git a/modyn/trainer_server/internal/dataset/data_utils.py b/modyn/trainer_server/internal/dataset/data_utils.py index 73edf7eca..572584e04 100644 --- a/modyn/trainer_server/internal/dataset/data_utils.py +++ b/modyn/trainer_server/internal/dataset/data_utils.py @@ -28,6 +28,7 @@ def prepare_dataloaders( tokenizer: str | None, log_path: pathlib.Path | None, drop_last: bool = True, + include_labels: bool = True, ) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader | None]: """Gets the proper dataset according to the dataset id, and creates the proper dataloaders. 
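The `RecAdam` optimizer added above combines the usual Adam update with an annealed quadratic pull toward a frozen snapshot of the pretrained parameters: early in training the pull dominates, and as the annealing factor approaches `anneal_w` the task gradient takes over. The snippet below is a usage sketch under the assumption that the module is importable at the path created in this patch; the toy model, hyperparameters, and random data are placeholders chosen for illustration, not values taken from this change.

```python
import torch
from torch import nn

from modyn.trainer_server.custom_optimizers.recadam.recadam import RecAdam

# Toy model standing in for a pretrained network.
model = nn.Linear(16, 4)

# Parameter-aligned snapshot of the "pretrained" weights the optimizer anneals toward.
pretrain_params = [p.detach().clone() for p in model.parameters()]

optimizer = RecAdam(
    model.parameters(),
    lr=1e-3,
    anneal_fun="sigmoid",  # annealing schedule for the Adam/recall trade-off
    anneal_k=0.1,
    anneal_t0=100.0,
    anneal_w=1.0,
    pretrain_cof=5000.0,   # strength of the pull toward the pretrained snapshot
    pretrain_params=pretrain_params,
)

for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
```

The important design point is that `pretrain_params` must be captured before fine-tuning starts and must line up one-to-one with the parameters passed to the optimizer; `anneal_k` and `anneal_t0` then control how quickly the pull toward that snapshot fades.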
@@ -63,6 +64,7 @@ def prepare_dataloaders( shuffle, tokenizer, log_path, + include_labels, ) logger.debug("Creating DataLoader.") train_dataloader = torch.utils.data.DataLoader( @@ -100,5 +102,6 @@ def prepare_per_class_dataloader_from_online_dataset( online_dataset._parallel_prefetch_requests, online_dataset._shuffle, online_dataset._tokenizer_name, + online_dataset._include_labels, ) return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, drop_last=drop_last) diff --git a/modyn/trainer_server/internal/dataset/online_dataset.py b/modyn/trainer_server/internal/dataset/online_dataset.py index 48856aaa8..c6947967d 100644 --- a/modyn/trainer_server/internal/dataset/online_dataset.py +++ b/modyn/trainer_server/internal/dataset/online_dataset.py @@ -12,7 +12,14 @@ from typing import Any, cast import grpc -from tenacity import Retrying, after_log, before_log, retry, stop_after_attempt, wait_random_exponential +from tenacity import ( + Retrying, + after_log, + before_log, + retry, + stop_after_attempt, + wait_random_exponential, +) from torch.utils.data import IterableDataset, get_worker_info from torchvision import transforms @@ -22,7 +29,10 @@ GetResponse, ) from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub -from modyn.trainer_server.internal.dataset.key_sources import AbstractKeySource, SelectorKeySource +from modyn.trainer_server.internal.dataset.key_sources import ( + AbstractKeySource, + SelectorKeySource, +) from modyn.utils import ( BYTES_PARSER_FUNC_NAME, deserialize_function, @@ -53,7 +63,9 @@ def __init__( shuffle: bool, tokenizer: str | None, log_path: pathlib.Path | None, + include_labels: bool = True, ): + self._include_labels = include_labels self._pipeline_id = pipeline_id self._trigger_id = trigger_id self._training_id = training_id @@ -147,7 +159,8 @@ def _init_grpc(self, worker_id: int | None = None) -> None: # pragma: no cover def _silence_pil(self) -> None: # pragma: no cover pil_logger = logging.getLogger("PIL") - pil_logger.setLevel(logging.INFO) # by default, PIL on DEBUG spams the console + # by default, PIL on DEBUG spams the console + pil_logger.setLevel(logging.INFO) def _info(self, msg: str, worker_id: int | None) -> None: # pragma: no cover logger.info(f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}") @@ -155,45 +168,64 @@ def _info(self, msg: str, worker_id: int | None) -> None: # pragma: no cover def _debug(self, msg: str, worker_id: int | None) -> None: # pragma: no cover logger.debug(f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}") + # pylint: disable=too-many-locals def _get_data_from_storage( self, selector_keys: list[int], worker_id: int | None = None - ) -> Iterator[tuple[list[int], list[bytes], list[int], int]]: + ) -> Iterator[tuple[list[int], list[bytes], list[int] | None, int]]: processed_keys: set[int] | list[int] = [] has_failed = False for attempt in Retrying( - stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True + stop=stop_after_attempt(5), + wait=wait_random_exponential(multiplier=1, min=2, max=60), + reraise=True, ): with attempt: try: - req = GetRequest(dataset_id=self._dataset_id, keys=selector_keys) - stopw = Stopwatch() + new_keys: list[int] + new_samples: list[bytes] + response_time: int + # Select appropriate response type response: GetResponse + + # response: GetResponse # type: ignore[no-redef] + rpc_call = self._storagestub.Get + + # Request setup + req = 
GetRequest(dataset_id=self._dataset_id, keys=selector_keys) + stopw = Stopwatch() stopw.start("ResponseTime", overwrite=True) - for response in self._storagestub.Get(req): + + for response in rpc_call(req): response_time = stopw.stop("ResponseTime") keys = list(response.keys) + if not has_failed: assert isinstance(processed_keys, list) processed_keys.extend(keys) - yield keys, list(response.samples), list(response.labels), response_time - else: # If we have failed, we need to filter out yielded samples - # Note that the returned order by storage is non-deterministic + if not self._include_labels: + yield keys, list(response.samples), None, response_time + else: + yield keys, list(response.samples), list(response.labels), response_time + else: # Handle failure and deduplication assert isinstance(processed_keys, set) - new_keys: list[int] = [key for key in keys if key not in processed_keys] - new_samples: list[bytes] = [ + new_keys = [key for key in keys if key not in processed_keys] + new_samples = [ sample for key, sample in zip(keys, response.samples) if key not in processed_keys ] - new_labels: list[int] = [ - label for key, label in zip(keys, response.labels) if key not in processed_keys - ] processed_keys.update(keys) - yield new_keys, new_samples, new_labels, response_time - stopw.start("ResponseTime", overwrite=True) + if self._include_labels: + yield new_keys, new_samples, None, response_time + else: + new_labels = [ + label for key, label in zip(keys, response.labels) if key not in processed_keys + ] + yield new_keys, new_samples, new_labels, response_time - except grpc.RpcError as e: # We catch and reraise to reconnect to rpc and do logging + stopw.start("ResponseTime", overwrite=True) + except grpc.RpcError as e: has_failed = True # Convert processed keys to set on first failure processed_keys = set(processed_keys) if isinstance(processed_keys, list) else processed_keys @@ -202,7 +234,10 @@ def _get_data_from_storage( worker_id, ) self._info(f"Stringified exception: {str(e)}", worker_id) - self._info(f"Error occurred while asking {self._dataset_id} for keys:\n{selector_keys}", worker_id) + self._info( + f"Error occurred while asking {self._dataset_id} for keys:\n{selector_keys}", + worker_id, + ) self._init_grpc(worker_id=worker_id) raise e @@ -223,7 +258,8 @@ def _get_data( get_data_log = {} self._sw.start(f"GetKeysAndWeightsPart{partition_id}", overwrite=True) keys, weights = self._key_source.get_keys_and_weights( - worker_id, shuffled_partition_id if shuffled_partition_id is not None else partition_id + worker_id, + (shuffled_partition_id if shuffled_partition_id is not None else partition_id), ) get_data_log["get_keys_and_weights"] = self._sw.stop(f"GetKeysAndWeightsPart{partition_id}") get_data_log["num_items"] = len(keys) @@ -231,29 +267,48 @@ def _get_data( self._info("Getting data from storage", worker_id) self._sw.start(f"GetDataPart{partition_id}", overwrite=True) all_response_times = [] - key_weight_map = {key: weights[idx] for idx, key in enumerate(keys)} if weights is not None else None + if not self._include_labels: + for data_tuple_gen in self._get_data_from_storage(keys, worker_id=worker_id): + stor_keys, data, _, response_time = data_tuple_gen + + all_response_times.append(response_time) + num_items = len(stor_keys) + with partition_locks[partition_id] if partition_locks is not None else contextlib.suppress(): + data_container["data"].extend(data) + data_container["keys"].extend(stor_keys) + data_container["weights"].extend( + [cast(float | None, 
key_weight_map[key]) for key in stor_keys] + if key_weight_map is not None + else [None for _ in range(len(stor_keys))] + ) + if partition_valid_until is not None: + partition_valid_until[partition_id] += num_items - for data_tuple in self._get_data_from_storage(keys, worker_id=worker_id): - stor_keys, data, labels, response_time = data_tuple - - all_response_times.append(response_time) - num_items = len(stor_keys) - with partition_locks[partition_id] if partition_locks is not None else contextlib.suppress(): - data_container["data"].extend(data) - data_container["keys"].extend(stor_keys) - data_container["labels"].extend(labels) - data_container["weights"].extend( - [cast(float | None, key_weight_map[key]) for key in stor_keys] - if key_weight_map is not None - else [None for _ in range(len(stor_keys))] - ) - if partition_valid_until is not None: - partition_valid_until[partition_id] += num_items + if partition_signals is not None: + with partition_signals[partition_id]: + partition_signals[partition_id].notify_all() + else: + for data_tuple in self._get_data_from_storage(keys, worker_id=worker_id): + stor_keys, data, labels, response_time = data_tuple + + all_response_times.append(response_time) + num_items = len(stor_keys) + with partition_locks[partition_id] if partition_locks is not None else contextlib.suppress(): + data_container["data"].extend(data) + data_container["keys"].extend(stor_keys) + data_container["labels"].extend(labels) + data_container["weights"].extend( + [cast(float | None, key_weight_map[key]) for key in stor_keys] + if key_weight_map is not None + else [None for _ in range(len(stor_keys))] + ) + if partition_valid_until is not None: + partition_valid_until[partition_id] += num_items - if partition_signals is not None: - with partition_signals[partition_id]: - partition_signals[partition_id].notify_all() + if partition_signals is not None: + with partition_signals[partition_id]: + partition_signals[partition_id].notify_all() get_data_log["get_data"] = self._sw.stop(f"GetDataPart{partition_id}") get_data_log["response_times"] = all_response_times @@ -270,13 +325,19 @@ def _get_data( callback() def _get_transformed_data_tuple( - self, key: int, sample: memoryview, label: int, weight: float | None + self, key: int, sample: memoryview, label: int | None = None, weight: float | None = None ) -> tuple | None: assert self._uses_weights is not None self._sw.start("transform", resume=True) - # mypy complains here because _transform has unknown type, which is ok transformed_sample = self._transform(sample) # type: ignore self._sw.stop("transform") + + if not self._include_labels: + if self._uses_weights: + return key, transformed_sample, weight + return key, transformed_sample + + # Non-include_labels case with labels if self._uses_weights: return key, transformed_sample, label, weight return key, transformed_sample, label @@ -324,18 +385,23 @@ def _prefetch_partition(self, worker_id: int, maybe_continue: bool = False) -> N if maybe_continue: # Do this as early as possible to avoid running into the "problem" above frequently self._launched_prefetches += 1 - assert self._next_partition_to_fetch >= 0 assert ( self._next_partition_to_fetch not in self._data_threads ), f"Prefetching for partition {self._next_partition_to_fetch} has already been started" - - self._thread_data_container[self._next_partition_to_fetch] = { - "data": [], - "keys": [], - "labels": [], - "weights": [], - } + if self._include_labels: + self._thread_data_container[self._next_partition_to_fetch] = { + 
"data": [], + "keys": [], + "labels": [], + "weights": [], + } + else: + self._thread_data_container[self._next_partition_to_fetch] = { + "data": [], + "keys": [], + "weights": [], + } self._partition_valid[self._next_partition_to_fetch] = False self._partition_valid_until[self._next_partition_to_fetch] = -1 self._partition_locks[self._next_partition_to_fetch] = threading.Lock() @@ -391,23 +457,51 @@ def callback_func() -> None: def _fetch_partition_noprefetch( self, worker_id: int, partition_id: int - ) -> Iterator[tuple[int, memoryview, int, float | None]]: + ) -> Iterator[tuple[int, memoryview, int, float | None]] | Iterator[tuple[int, memoryview, float | None]]: assert self._num_prefetched_partitions < 1 - container: dict[str, Any] = {"data": [], "keys": [], "labels": [], "weights": []} + container: dict[str, Any] = { + "data": [], + "keys": [], + "weights": [], + } + if self._include_labels: + container["labels"] = [] + shuffle_partition_id = self._shuffled_partition_indices[partition_id] if self._shuffle else None - self._get_data(container, worker_id, partition_id, None, None, None, None, None, shuffle_partition_id) - assert "data" in container and "labels" in container and "keys" in container and "weights" in container + + self._get_data( + container, + worker_id, + partition_id, + None, + None, + None, + None, + None, + shuffle_partition_id, + ) + + assert "data" in container and "keys" in container and "weights" in container + if self._include_labels: + assert "labels" in container if self._shuffle: self._shuffle_partition(partition_id, worker_id, container=container) for idx in range(len(container["keys"])): - yield ( - container["keys"][idx], - memoryview(container["data"][idx]), - container["labels"][idx], - container["weights"][idx], - ) + if not self._include_labels: + yield ( + container["keys"][idx], + memoryview(container["data"][idx]), + container["weights"][idx], + ) + else: + yield ( + container["keys"][idx], + memoryview(container["data"][idx]), + container["labels"][idx], + container["weights"][idx], + ) def _is_partition_fetched(self, partition_id: int) -> bool: if partition_id not in self._partition_locks or partition_id not in self._partition_valid: @@ -422,14 +516,22 @@ def _partition_max_index(self, partition_id: int) -> int: def _get_partition_data( self, last_idx: int, max_idx: int, partition_id: int - ) -> Iterator[tuple[int, memoryview, int, float | None]]: - for idx in range(last_idx + 1, max_idx + 1): - yield ( - self._thread_data_container[partition_id]["keys"][idx], - memoryview(self._thread_data_container[partition_id]["data"][idx]), - self._thread_data_container[partition_id]["labels"][idx], - self._thread_data_container[partition_id]["weights"][idx], - ) + ) -> Iterator[tuple[int, memoryview, int, float | None]] | Iterator[tuple[int, memoryview, float | None]]: + if not self._include_labels: + for idx in range(last_idx + 1, max_idx + 1): + yield ( + self._thread_data_container[partition_id]["keys"][idx], + memoryview(self._thread_data_container[partition_id]["data"][idx]), + self._thread_data_container[partition_id]["weights"][idx], + ) + else: + for idx in range(last_idx + 1, max_idx + 1): + yield ( + self._thread_data_container[partition_id]["keys"][idx], + memoryview(self._thread_data_container[partition_id]["data"][idx]), + self._thread_data_container[partition_id]["labels"][idx], + self._thread_data_container[partition_id]["weights"][idx], + ) def _wait_for_new_partition_data(self, partition_id: int) -> None: with 
self._partition_signals[partition_id]: @@ -447,21 +549,18 @@ def _shuffle_partition(self, partition_id: int, worker_id: int, container: dict indices = list(range(data_length)) random.shuffle(indices) - new_data = [container["data"][i] for i in indices] - new_keys = [container["keys"][i] for i in indices] - new_labels = [container["labels"][i] for i in indices] - new_weights = [container["weights"][i] for i in indices] + container["data"] = [container["data"][i] for i in indices] + container["keys"] = [container["keys"][i] for i in indices] + container["weights"] = [container["weights"][i] for i in indices] - container["data"] = new_data - container["keys"] = new_keys - container["labels"] = new_labels - container["weights"] = new_weights + if self._include_labels: + container["labels"] = [container["labels"][i] for i in indices] self._info(f"Shuffled partition {partition_id}", worker_id) def prefetched_partition_generator( self, worker_id: int, partition_id: int - ) -> Iterator[tuple[int, memoryview, int, float | None]]: + ) -> Iterator[tuple[int, memoryview, int, float | None]] | Iterator[tuple[int, memoryview, float | None]]: last_idx = -1 if not self._shuffle: # If we do not shuffle, we can emit data as soon as it streamed over @@ -470,13 +569,11 @@ def prefetched_partition_generator( max_idx = self._partition_max_index(partition_id) if max_idx <= last_idx: # No new data self._wait_for_new_partition_data(partition_id) - yield from self._get_partition_data(last_idx, max_idx, partition_id) last_idx = max_idx else: while not self._is_partition_fetched(partition_id): self._wait_for_new_partition_data(partition_id) - self._shuffle_partition(partition_id, worker_id) # Yield potential remaining data (when not shuffling) or all data (when shuffling) @@ -504,18 +601,20 @@ def start_prefetching(self, worker_id: int) -> None: for _ in range(self._parallel_prefetch_requests): self._prefetch_partition(worker_id, True) - def all_partition_generator(self, worker_id: int) -> Iterator[tuple[int, memoryview, int, float | None]]: + def all_partition_generator( + self, worker_id: int + ) -> Iterator[tuple[int, memoryview, int, float | None]] | Iterator[tuple[int, memoryview, float | None]]: self.start_prefetching(worker_id) for partition_id in range(self._num_partitions): self._persist_log(worker_id) - if self._num_prefetched_partitions > 0: if partition_id < self._num_partitions - 1: # As we consume one partition, prefetch exactly one more partition self._prefetch_partition(worker_id, False) yield from self.prefetched_partition_generator(worker_id, partition_id) + else: yield from self._fetch_partition_noprefetch(worker_id, partition_id) @@ -562,7 +661,10 @@ def __iter__(self) -> Generator: if self._shuffle: self._shuffled_partition_indices = list(range(0, self._num_partitions)) random.shuffle(self._shuffled_partition_indices) - self._info(f"Shuffled partitions into random order: {self._shuffled_partition_indices}", worker_id) + self._info( + f"Shuffled partitions into random order: {self._shuffled_partition_indices}", + worker_id, + ) self._info( f"Total number of partitions will be {self._num_partitions}.\n" diff --git a/modyn/trainer_server/internal/dataset/per_class_online_dataset.py b/modyn/trainer_server/internal/dataset/per_class_online_dataset.py index 6bcd51c36..b833126f8 100644 --- a/modyn/trainer_server/internal/dataset/per_class_online_dataset.py +++ b/modyn/trainer_server/internal/dataset/per_class_online_dataset.py @@ -23,6 +23,7 @@ def __init__( parallel_prefetch_requests: int, shuffle: 
bool, tokenizer: str | None, + include_labels: bool = True, ): super().__init__( pipeline_id, @@ -37,12 +38,14 @@ def __init__( parallel_prefetch_requests, shuffle, tokenizer, - None, + include_labels, # type: ignore[arg-type] ) assert initial_filtered_label is not None self.filtered_label = initial_filtered_label - def _get_transformed_data_tuple(self, key: int, sample: bytes, label: int, weight: float | None) -> tuple | None: + def _get_transformed_data_tuple( + self, key: int, sample: memoryview, label: int | None = None, weight: float | None = None + ) -> tuple | None: assert self.filtered_label is not None if self.filtered_label != label: diff --git a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py index c3a21053b..ae7828750 100644 --- a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py +++ b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py @@ -1,21 +1,24 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE # source: trainer_server.proto -# Protobuf Python Version: 5.26.1 +# Protobuf Python Version: 5.27.2 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion(_runtime_version.Domain.PUBLIC, 5, 27, 2, "", "trainer_server.proto") # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x14trainer_server.proto\x12\x07trainer"\x1b\n\nJsonString\x12\r\n\x05value\x18\x01 \x01(\t"\x1d\n\x0cPythonString\x12\r\n\x05value\x18\x01 \x01(\t"3\n\x04\x44\x61ta\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fnum_dataloaders\x18\x02 \x01(\x05"\x19\n\x17TrainerAvailableRequest"-\n\x18TrainerAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08"F\n\x0e\x43heckpointInfo\x12\x1b\n\x13\x63heckpoint_interval\x18\x01 \x01(\x05\x12\x17\n\x0f\x63heckpoint_path\x18\x02 \x01(\t"\xbb\x07\n\x14StartTrainingRequest\x12\x13\n\x0bpipeline_id\x18\x01 \x01(\x05\x12\x12\n\ntrigger_id\x18\x02 \x01(\x05\x12\x0e\n\x06\x64\x65vice\x18\x03 \x01(\t\x12\x1c\n\x14use_pretrained_model\x18\x04 \x01(\x08\x12\x1c\n\x14load_optimizer_state\x18\x05 \x01(\x08\x12\x1b\n\x13pretrained_model_id\x18\x06 \x01(\x05\x12\x12\n\nbatch_size\x18\x07 \x01(\x05\x12;\n\x1etorch_optimizers_configuration\x18\x08 \x01(\x0b\x32\x13.trainer.JsonString\x12\x17\n\x0ftorch_criterion\x18\t \x01(\t\x12\x31\n\x14\x63riterion_parameters\x18\n \x01(\x0b\x32\x13.trainer.JsonString\x12 \n\tdata_info\x18\x0b \x01(\x0b\x32\r.trainer.Data\x12\x30\n\x0f\x63heckpoint_info\x18\x0c \x01(\x0b\x32\x17.trainer.CheckpointInfo\x12+\n\x0c\x62ytes_parser\x18\r \x01(\x0b\x32\x15.trainer.PythonString\x12\x16\n\x0etransform_list\x18\x0e \x03(\t\x12)\n\x0clr_scheduler\x18\x0f \x01(\x0b\x32\x13.trainer.JsonString\x12\x30\n\x11label_transformer\x18\x10 \x01(\x0b\x32\x15.trainer.PythonString\x12\x36\n\x19grad_scaler_configuration\x18\x11 \x01(\x0b\x32\x13.trainer.JsonString\x12\x1a\n\x12\x65pochs_per_trigger\x18\x12 \x01(\x05\x12!\n\x19num_prefetched_partitions\x18\x13 \x01(\x05\x12"\n\x1aparallel_prefetch_requests\x18\x14 \x01(\x05\x12\x11\n\x04seed\x18\x15 
\x01(\x05H\x00\x88\x01\x01\x12-\n\ttokenizer\x18\x16 \x01(\x0b\x32\x15.trainer.PythonStringH\x01\x88\x01\x01\x12\x1b\n\x13num_samples_to_pass\x18\x17 \x01(\x03\x12\x0f\n\x07shuffle\x18\x18 \x01(\x08\x12(\n enable_accurate_gpu_measurements\x18\x19 \x01(\x08\x12\x19\n\x11record_loss_every\x18\x1a \x01(\x03\x12\x17\n\x0f\x64rop_last_batch\x18\x1b \x01(\x08\x42\x07\n\x05_seedB\x0c\n\n_tokenizer"F\n\x15StartTrainingResponse\x12\x18\n\x10training_started\x18\x01 \x01(\x08\x12\x13\n\x0btraining_id\x18\x02 \x01(\x05",\n\x15TrainingStatusRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"\xa6\x03\n\x16TrainingStatusResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x12\n\nis_running\x18\x02 \x01(\x08\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x17\n\x0fstate_available\x18\x04 \x01(\x08\x12\x0f\n\x07\x62locked\x18\x05 \x01(\x08\x12 \n\x03log\x18\x06 \x01(\x0b\x32\x13.trainer.JsonString\x12\x16\n\texception\x18\x07 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0c\x62\x61tches_seen\x18\x08 \x01(\x03H\x01\x88\x01\x01\x12\x19\n\x0csamples_seen\x18\t \x01(\x03H\x02\x88\x01\x01\x12&\n\x19\x64ownsampling_batches_seen\x18\n \x01(\x03H\x03\x88\x01\x01\x12&\n\x19\x64ownsampling_samples_seen\x18\x0b \x01(\x03H\x04\x88\x01\x01\x42\x0c\n\n_exceptionB\x0f\n\r_batches_seenB\x0f\n\r_samples_seenB\x1c\n\x1a_downsampling_batches_seenB\x1c\n\x1a_downsampling_samples_seen"-\n\x16StoreFinalModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"@\n\x17StoreFinalModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x10\n\x08model_id\x18\x02 \x01(\x05",\n\x15GetLatestModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"A\n\x16GetLatestModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x12\n\nmodel_path\x18\x02 \x01(\t2\xc9\x03\n\rTrainerServer\x12Z\n\x11trainer_available\x12 .trainer.TrainerAvailableRequest\x1a!.trainer.TrainerAvailableResponse"\x00\x12Q\n\x0estart_training\x12\x1d.trainer.StartTrainingRequest\x1a\x1e.trainer.StartTrainingResponse"\x00\x12X\n\x13get_training_status\x12\x1e.trainer.TrainingStatusRequest\x1a\x1f.trainer.TrainingStatusResponse"\x00\x12X\n\x11store_final_model\x12\x1f.trainer.StoreFinalModelRequest\x1a .trainer.StoreFinalModelResponse"\x00\x12U\n\x10get_latest_model\x12\x1e.trainer.GetLatestModelRequest\x1a\x1f.trainer.GetLatestModelResponse"\x00\x62\x06proto3' + b'\n\x14trainer_server.proto\x12\x07trainer"\x1b\n\nJsonString\x12\r\n\x05value\x18\x01 \x01(\t"\x1d\n\x0cPythonString\x12\r\n\x05value\x18\x01 \x01(\t"3\n\x04\x44\x61ta\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fnum_dataloaders\x18\x02 \x01(\x05"\x19\n\x17TrainerAvailableRequest"-\n\x18TrainerAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08"F\n\x0e\x43heckpointInfo\x12\x1b\n\x13\x63heckpoint_interval\x18\x01 \x01(\x05\x12\x17\n\x0f\x63heckpoint_path\x18\x02 \x01(\t"\x95\x08\n\x14StartTrainingRequest\x12\x13\n\x0bpipeline_id\x18\x01 \x01(\x05\x12\x12\n\ntrigger_id\x18\x02 \x01(\x05\x12\x0e\n\x06\x64\x65vice\x18\x03 \x01(\t\x12\x1c\n\x14use_pretrained_model\x18\x04 \x01(\x08\x12\x1c\n\x14load_optimizer_state\x18\x05 \x01(\x08\x12\x1b\n\x13pretrained_model_id\x18\x06 \x01(\x05\x12\x12\n\nbatch_size\x18\x07 \x01(\x05\x12;\n\x1etorch_optimizers_configuration\x18\x08 \x01(\x0b\x32\x13.trainer.JsonString\x12\x17\n\x0ftorch_criterion\x18\t \x01(\t\x12\x31\n\x14\x63riterion_parameters\x18\n \x01(\x0b\x32\x13.trainer.JsonString\x12 \n\tdata_info\x18\x0b \x01(\x0b\x32\r.trainer.Data\x12\x30\n\x0f\x63heckpoint_info\x18\x0c \x01(\x0b\x32\x17.trainer.CheckpointInfo\x12+\n\x0c\x62ytes_parser\x18\r 
\x01(\x0b\x32\x15.trainer.PythonString\x12\x16\n\x0etransform_list\x18\x0e \x03(\t\x12)\n\x0clr_scheduler\x18\x0f \x01(\x0b\x32\x13.trainer.JsonString\x12\x30\n\x11label_transformer\x18\x10 \x01(\x0b\x32\x15.trainer.PythonString\x12\x36\n\x19grad_scaler_configuration\x18\x11 \x01(\x0b\x32\x13.trainer.JsonString\x12\x1a\n\x12\x65pochs_per_trigger\x18\x12 \x01(\x05\x12!\n\x19num_prefetched_partitions\x18\x13 \x01(\x05\x12"\n\x1aparallel_prefetch_requests\x18\x14 \x01(\x05\x12\x11\n\x04seed\x18\x15 \x01(\x05H\x00\x88\x01\x01\x12-\n\ttokenizer\x18\x16 \x01(\x0b\x32\x15.trainer.PythonStringH\x01\x88\x01\x01\x12\x1b\n\x13num_samples_to_pass\x18\x17 \x01(\x03\x12\x0f\n\x07shuffle\x18\x18 \x01(\x08\x12(\n enable_accurate_gpu_measurements\x18\x19 \x01(\x08\x12\x19\n\x11record_loss_every\x18\x1a \x01(\x03\x12\x17\n\x0f\x64rop_last_batch\x18\x1b \x01(\x08\x12\x12\n\ngenerative\x18\x1c \x01(\x08\x12\x16\n\tgrad_norm\x18\x1d \x01(\x02H\x02\x88\x01\x01\x12\x0c\n\x04lora\x18\x1e \x01(\x08\x12\x10\n\x08kadapter\x18\x1f \x01(\x08\x42\x07\n\x05_seedB\x0c\n\n_tokenizerB\x0c\n\n_grad_norm"F\n\x15StartTrainingResponse\x12\x18\n\x10training_started\x18\x01 \x01(\x08\x12\x13\n\x0btraining_id\x18\x02 \x01(\x05",\n\x15TrainingStatusRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"\xa6\x03\n\x16TrainingStatusResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x12\n\nis_running\x18\x02 \x01(\x08\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x17\n\x0fstate_available\x18\x04 \x01(\x08\x12\x0f\n\x07\x62locked\x18\x05 \x01(\x08\x12 \n\x03log\x18\x06 \x01(\x0b\x32\x13.trainer.JsonString\x12\x16\n\texception\x18\x07 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0c\x62\x61tches_seen\x18\x08 \x01(\x03H\x01\x88\x01\x01\x12\x19\n\x0csamples_seen\x18\t \x01(\x03H\x02\x88\x01\x01\x12&\n\x19\x64ownsampling_batches_seen\x18\n \x01(\x03H\x03\x88\x01\x01\x12&\n\x19\x64ownsampling_samples_seen\x18\x0b \x01(\x03H\x04\x88\x01\x01\x42\x0c\n\n_exceptionB\x0f\n\r_batches_seenB\x0f\n\r_samples_seenB\x1c\n\x1a_downsampling_batches_seenB\x1c\n\x1a_downsampling_samples_seen"-\n\x16StoreFinalModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"@\n\x17StoreFinalModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x10\n\x08model_id\x18\x02 \x01(\x05",\n\x15GetLatestModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"A\n\x16GetLatestModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x12\n\nmodel_path\x18\x02 \x01(\t2\xc9\x03\n\rTrainerServer\x12Z\n\x11trainer_available\x12 .trainer.TrainerAvailableRequest\x1a!.trainer.TrainerAvailableResponse"\x00\x12Q\n\x0estart_training\x12\x1d.trainer.StartTrainingRequest\x1a\x1e.trainer.StartTrainingResponse"\x00\x12X\n\x13get_training_status\x12\x1e.trainer.TrainingStatusRequest\x1a\x1f.trainer.TrainingStatusResponse"\x00\x12X\n\x11store_final_model\x12\x1f.trainer.StoreFinalModelRequest\x1a .trainer.StoreFinalModelResponse"\x00\x12U\n\x10get_latest_model\x12\x1e.trainer.GetLatestModelRequest\x1a\x1f.trainer.GetLatestModelResponse"\x00\x62\x06proto3' ) _globals = globals() @@ -36,21 +39,21 @@ _globals["_CHECKPOINTINFO"]._serialized_start = 220 _globals["_CHECKPOINTINFO"]._serialized_end = 290 _globals["_STARTTRAININGREQUEST"]._serialized_start = 293 - _globals["_STARTTRAININGREQUEST"]._serialized_end = 1248 - _globals["_STARTTRAININGRESPONSE"]._serialized_start = 1250 - _globals["_STARTTRAININGRESPONSE"]._serialized_end = 1320 - _globals["_TRAININGSTATUSREQUEST"]._serialized_start = 1322 - _globals["_TRAININGSTATUSREQUEST"]._serialized_end = 1366 - 
_globals["_TRAININGSTATUSRESPONSE"]._serialized_start = 1369 - _globals["_TRAININGSTATUSRESPONSE"]._serialized_end = 1791 - _globals["_STOREFINALMODELREQUEST"]._serialized_start = 1793 - _globals["_STOREFINALMODELREQUEST"]._serialized_end = 1838 - _globals["_STOREFINALMODELRESPONSE"]._serialized_start = 1840 - _globals["_STOREFINALMODELRESPONSE"]._serialized_end = 1904 - _globals["_GETLATESTMODELREQUEST"]._serialized_start = 1906 - _globals["_GETLATESTMODELREQUEST"]._serialized_end = 1950 - _globals["_GETLATESTMODELRESPONSE"]._serialized_start = 1952 - _globals["_GETLATESTMODELRESPONSE"]._serialized_end = 2017 - _globals["_TRAINERSERVER"]._serialized_start = 2020 - _globals["_TRAINERSERVER"]._serialized_end = 2477 + _globals["_STARTTRAININGREQUEST"]._serialized_end = 1338 + _globals["_STARTTRAININGRESPONSE"]._serialized_start = 1340 + _globals["_STARTTRAININGRESPONSE"]._serialized_end = 1410 + _globals["_TRAININGSTATUSREQUEST"]._serialized_start = 1412 + _globals["_TRAININGSTATUSREQUEST"]._serialized_end = 1456 + _globals["_TRAININGSTATUSRESPONSE"]._serialized_start = 1459 + _globals["_TRAININGSTATUSRESPONSE"]._serialized_end = 1881 + _globals["_STOREFINALMODELREQUEST"]._serialized_start = 1883 + _globals["_STOREFINALMODELREQUEST"]._serialized_end = 1928 + _globals["_STOREFINALMODELRESPONSE"]._serialized_start = 1930 + _globals["_STOREFINALMODELRESPONSE"]._serialized_end = 1994 + _globals["_GETLATESTMODELREQUEST"]._serialized_start = 1996 + _globals["_GETLATESTMODELREQUEST"]._serialized_end = 2040 + _globals["_GETLATESTMODELRESPONSE"]._serialized_start = 2042 + _globals["_GETLATESTMODELRESPONSE"]._serialized_end = 2107 + _globals["_TRAINERSERVER"]._serialized_start = 2110 + _globals["_TRAINERSERVER"]._serialized_end = 2567 # @@protoc_insertion_point(module_scope) diff --git a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi index 96b20dde5..b18d0f26b 100644 --- a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi +++ b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi @@ -141,6 +141,10 @@ class StartTrainingRequest(google.protobuf.message.Message): ENABLE_ACCURATE_GPU_MEASUREMENTS_FIELD_NUMBER: builtins.int RECORD_LOSS_EVERY_FIELD_NUMBER: builtins.int DROP_LAST_BATCH_FIELD_NUMBER: builtins.int + GENERATIVE_FIELD_NUMBER: builtins.int + GRAD_NORM_FIELD_NUMBER: builtins.int + LORA_FIELD_NUMBER: builtins.int + KADAPTER_FIELD_NUMBER: builtins.int pipeline_id: builtins.int trigger_id: builtins.int device: builtins.str @@ -158,6 +162,10 @@ class StartTrainingRequest(google.protobuf.message.Message): enable_accurate_gpu_measurements: builtins.bool record_loss_every: builtins.int drop_last_batch: builtins.bool + generative: builtins.bool + grad_norm: builtins.float + lora: builtins.bool + kadapter: builtins.bool @property def torch_optimizers_configuration(self) -> global___JsonString: ... @property @@ -208,10 +216,16 @@ class StartTrainingRequest(google.protobuf.message.Message): enable_accurate_gpu_measurements: builtins.bool = ..., record_loss_every: builtins.int = ..., drop_last_batch: builtins.bool = ..., + generative: builtins.bool = ..., + grad_norm: builtins.float | None = ..., + lora: builtins.bool = ..., + kadapter: builtins.bool = ..., ) -> None: ... 
def HasField( self, field_name: typing.Literal[ + "_grad_norm", + b"_grad_norm", "_seed", b"_seed", "_tokenizer", @@ -224,6 +238,8 @@ class StartTrainingRequest(google.protobuf.message.Message): b"criterion_parameters", "data_info", b"data_info", + "grad_norm", + b"grad_norm", "grad_scaler_configuration", b"grad_scaler_configuration", "label_transformer", @@ -241,6 +257,8 @@ class StartTrainingRequest(google.protobuf.message.Message): def ClearField( self, field_name: typing.Literal[ + "_grad_norm", + b"_grad_norm", "_seed", b"_seed", "_tokenizer", @@ -263,12 +281,20 @@ class StartTrainingRequest(google.protobuf.message.Message): b"enable_accurate_gpu_measurements", "epochs_per_trigger", b"epochs_per_trigger", + "generative", + b"generative", + "grad_norm", + b"grad_norm", "grad_scaler_configuration", b"grad_scaler_configuration", + "kadapter", + b"kadapter", "label_transformer", b"label_transformer", "load_optimizer_state", b"load_optimizer_state", + "lora", + b"lora", "lr_scheduler", b"lr_scheduler", "num_prefetched_partitions", @@ -302,6 +328,10 @@ class StartTrainingRequest(google.protobuf.message.Message): ], ) -> None: ... @typing.overload + def WhichOneof( + self, oneof_group: typing.Literal["_grad_norm", b"_grad_norm"] + ) -> typing.Literal["grad_norm"] | None: ... + @typing.overload def WhichOneof(self, oneof_group: typing.Literal["_seed", b"_seed"]) -> typing.Literal["seed"] | None: ... @typing.overload def WhichOneof( diff --git a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py index a1978b37a..ca5b1c31e 100644 --- a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py +++ b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py @@ -1,15 +1,13 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" +import grpc import warnings -import grpc import modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 as trainer__server__pb2 -GRPC_GENERATED_VERSION = "1.63.0" +GRPC_GENERATED_VERSION = "1.67.1" GRPC_VERSION = grpc.__version__ -EXPECTED_ERROR_RELEASE = "1.65.0" -SCHEDULED_RELEASE_DATE = "June 25, 2024" _version_not_supported = False try: @@ -20,15 +18,12 @@ _version_not_supported = True if _version_not_supported: - warnings.warn( + raise RuntimeError( f"The grpc package installed is at version {GRPC_VERSION}," + f" but the generated code in trainer_server_pb2_grpc.py depends on" + f" grpcio>={GRPC_GENERATED_VERSION}." + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." - + f" This warning will become an error in {EXPECTED_ERROR_RELEASE}," - + f" scheduled for release on {SCHEDULED_RELEASE_DATE}.", - RuntimeWarning, ) @@ -137,6 +132,7 @@ def add_TrainerServerServicer_to_server(servicer, server): } generic_handler = grpc.method_handlers_generic_handler("trainer.TrainerServer", rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers("trainer.TrainerServer", rpc_method_handlers) # This class is part of an EXPERIMENTAL API. 
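On the client side, the new `StartTrainingRequest` fields added in this proto revision (`generative`, the optional `grad_norm`, `lora`, `kadapter`) are set like any other scalar field; only `grad_norm` participates in presence tracking. The snippet below is a hedged illustration: the concrete values and the omitted remaining request fields are placeholders, not defaults taken from this patch.

```python
from modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 import StartTrainingRequest

# Placeholder values for illustration; a real request also carries the
# optimizer/criterion configuration, data_info, checkpoint_info, etc.
req = StartTrainingRequest(
    pipeline_id=1,
    trigger_id=0,
    device="cuda:0",
    generative=True,   # train with next-token targets instead of stored labels
    lora=True,         # wrap the model with LoRA adapters
    kadapter=False,    # leave K-Adapter layers disabled
)
req.grad_norm = 0.5    # optional field: only set when gradient clipping is wanted

# Because grad_norm is declared optional, presence can be checked explicitly,
# which distinguishes "not set" from an explicit 0.0.
print(req.HasField("grad_norm"), req.generative, req.lora, req.kadapter)
```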
diff --git a/modyn/trainer_server/internal/trainer/pytorch_trainer.py b/modyn/trainer_server/internal/trainer/pytorch_trainer.py
index 47acce439..dcd684d4c 100644
--- a/modyn/trainer_server/internal/trainer/pytorch_trainer.py
+++ b/modyn/trainer_server/internal/trainer/pytorch_trainer.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import contextlib
+import copy
 import glob
 import io
 import itertools
@@ -21,10 +22,12 @@
 import grpc
 import numpy as np
 import torch
+import transformers
 from modyn.common.benchmark.stopwatch import Stopwatch
 from modyn.models.coreset_methods_support import CoresetSupportingModule
 from modyn.models.dlrm.dlrm import DLRM
+from modyn.models.modular_adapters.modular_adapters import apply_kadapter, apply_lora
 from modyn.selector.internal.grpc.generated.selector_pb2 import (
     AvailableLabelsResponse,
     GetAvailableLabelsRequest,
@@ -85,7 +88,11 @@ def __init__(
         self.pipeline_id = training_info.pipeline_id
         self.training_id = training_info.training_id
         self.trigger_id = training_info.trigger_id
-
+        self._info("Initializing PyTorch Trainer")
+        self.generative = training_info.generative
+        self._grad_norm = training_info.grad_norm  # max gradient norm for clipping; None disables clipping
+        self._lora = training_info.lora
+        self._kadapter = training_info.kadapter
         self.selector_stub = self.connect_to_selector(training_info.selector_address)
 
         if training_info.seed is not None:
@@ -101,7 +108,10 @@ def __init__(
         self._scaler = torch.cuda.amp.GradScaler(enabled=training_info.amp, **training_info.grad_scaler_configuration)
         self._info("Grad scaler created.")
-
+        if self._lora:
+            apply_lora(self._model.model)
+        if self._kadapter:
+            apply_kadapter(self._model.model)
         if training_info.use_pretrained_model:
             self._info("Loading model state from pretrained model.")
             self.load_state_if_given(training_info.pretrained_model_path, training_info.load_optimizer_state)
@@ -190,6 +200,7 @@ def __init__(
             training_info.tokenizer,
             self._dataset_log_path,
             drop_last=self._drop_last_batch,
+            include_labels=not self.generative,
         )
 
         # Create callbacks
@@ -208,6 +219,8 @@ def train(self) -> None:  # pylint: disable=too-many-locals, too-many-branches
         stopw = Stopwatch()
 
         total_stopw.start("TotalTrain")
+        # Put the model into training mode before the first epoch
+
         self._model.model.train()
 
         stopw.start("OnBeginCallbacks")
@@ -217,7 +230,6 @@ def train(self) -> None:  # pylint: disable=too-many-locals, too-many-branches
         self._info("Handled OnBegin Callbacks.")
 
         self._log["epochs"] = []
-        training_loss: list[float] = []
 
         if self.num_samples_to_pass == 0:
             epoch_num_generator: Iterable[int] = range(self.epochs_per_trigger)
@@ -244,6 +256,7 @@ def train(self) -> None:  # pylint: disable=too-many-locals, too-many-branches
 
         trained_batches = 0
         passed_batches = 0
+
         for epoch in epoch_num_generator:
             stopw = Stopwatch()  # Reset timings per epoch
             self._log["epochs"].append({})
@@ -258,16 +271,15 @@ def train(self) -> None:  # pylint: disable=too-many-locals, too-many-branches
             stopw.start("IndivFetchBatch", overwrite=True)
             stopw.start("FetchBatch", resume=True)
+
             for batch in self._train_dataloader:
                 stopw.stop("FetchBatch")
                 batch_timings.append(stopw.stop("IndivFetchBatch"))
                 retrieve_weights_from_dataloader, weighted_optimization = self.weights_handling(len(batch))
-
                 stopw.start("OnBatchBeginCallbacks", resume=True)
                 for _, callback in self._callbacks.items():
                     callback.on_batch_begin(self._model.model, self._optimizers, batch, passed_batches)
                 stopw.stop()
-
                 self.update_queue("TRAINING", trained_batches, trained_batches * self._batch_size, training_active=True)
                 passed_batches += 1
 
                 with GPUMeasurement(self._measure_gpu_ops, "PreprocessBatch", self._device, stopw, resume=True):
@@ -277,7 +289,6 @@ def train(self) -> None:  # pylint: disable=too-many-locals, too-many-branches
                         # model output is a torch.FloatTensor but weights is a torch.DoubleTensor.
                         # We need to cast to do the dot product
                         weights = batch[3].float().to(self._device)
-
                 for _, optimizer in self._optimizers.items():
                     optimizer.zero_grad()
@@ -297,15 +308,59 @@ def train(self) -> None:  # pylint: disable=too-many-locals, too-many-branches
                 self._assert_data_size(self._batch_size, data, sample_ids, target)
 
                 with GPUMeasurement(self._measure_gpu_ops, "Forward", self._device, stopw, resume=True):
-                    output = self._model.model(data, sample_ids)
+                    # Measure memory usage before forward pass
+                    # initial_memory = torch.cuda.memory_allocated()
+                    # print(f"Before forward pass: {initial_memory / 1e9:.2f} GB")
+                    # torch.cuda.reset_peak_memory_stats()  # Reset peak memory tracking
+
+                    if self.generative:
+                        output = self._model.model(data)
+
+                    else:
+                        # Non-generative task: Pass data, and optionally sample_ids if required
+                        output = self._model.model(data, sample_ids=sample_ids)
+
+                    # Measure memory usage after forward pass
+                    # final_memory = torch.cuda.memory_allocated()
+                    # peak_memory = torch.cuda.max_memory_allocated()
+                    # print(f"After forward pass: {final_memory / 1e9:.2f} GB")
+                    # print(f"Peak memory during forward pass: {peak_memory / 1e9:.2f} GB")
 
                 with GPUMeasurement(self._measure_gpu_ops, "Loss", self._device, stopw, resume=True):
-                    if weighted_optimization:
-                        # weighted gradient descent
-                        assert weights is not None
-                        loss = torch.dot(self._criterion_nored(output, target), weights / weights.sum())
+                    # Measure memory usage before loss computation
+                    # initial_memory = torch.cuda.memory_allocated()
+                    # print(f"Before loss computation: {initial_memory / 1e9:.2f} GB")
+                    # torch.cuda.reset_peak_memory_stats()  # Reset peak memory tracking
+
+                    if self.generative:
+                        # Shift logits and labels for next-token prediction
+                        output = output[..., :-1, :]  # Output for all tokens except the last one
+                        target = data[..., 1:, 0]  # Target for all tokens except the first one
+
+                        # Use reshape instead of view to handle non-contiguous tensors safely
+                        output = output.reshape(-1, output.size(-1))
+                        target = target.reshape(-1)
+                        target[target == 50256] = -100
+                        # Calculate loss
+                        if weighted_optimization:
+                            # Weighted gradient descent
+                            assert weights is not None
+                            loss = torch.dot(self._criterion_nored(output, target), weights / weights.sum())
+                        else:
+                            loss = self._criterion(output, target)
                     else:
-                        loss = self._criterion(output, target)
+                        if weighted_optimization:
+                            # Weighted gradient descent
+                            assert weights is not None
+                            loss = torch.dot(self._criterion_nored(output, target), weights / weights.sum())
+                        else:
+                            loss = self._criterion(output, target)
+
+                    # Measure memory usage after loss computation
+                    # final_memory = torch.cuda.memory_allocated()
+                    # peak_memory = torch.cuda.max_memory_allocated()
+                    # print(f"After loss computation: {final_memory / 1e9:.2f} GB")
+                    # print(f"Peak memory during loss computation: {peak_memory / 1e9:.2f} GB")
 
                 stopw.start("OnBatchBeforeUpdate", resume=True)
                 for _, callback in self._callbacks.items():
@@ -316,6 +371,8 @@ def train(self) -> None:  # pylint: disable=too-many-locals, too-many-branches
 
                 with GPUMeasurement(self._measure_gpu_ops, "Backward", self._device, stopw, resume=True):
                     self._scaler.scale(loss).backward()
+                    if self._grad_norm is not None:
+                        torch.nn.utils.clip_grad_norm_(self._model.model.parameters(), max_norm=self._grad_norm)
 
                 with GPUMeasurement(self._measure_gpu_ops, "OptimizerStep", self._device, stopw, resume=True):
                     for _, optimizer in self._optimizers.items():
@@ -334,6 +391,14 @@ def train(self) -> None:  # pylint: disable=too-many-locals, too-many-branches
 
                 if self._record_loss_every > 0 and trained_batches % self._record_loss_every == 0:
                     training_loss.append(loss.item())
+                    print(loss.item())
+                    # Log loss and batch number
+                    log_file = self._checkpoint_path / "training_log.txt"
+                    with (
+                        open(log_file, "a") as f  # pylint: disable=unspecified-encoding
+                    ):  # 'a' mode appends if the file exists, else creates it
+                        f.write(f"{trained_batches},{loss.item()}\n")
+                    # One line per recorded batch: "<batch_number>,<loss>"
 
                 self._num_samples += len(sample_ids)
@@ -410,6 +475,8 @@ def train(self) -> None:  # pylint: disable=too-many-locals, too-many-branches
         self._metadata_collector.cleanup()
 
         # save final model
+        print("Final checkpoint path")
+        print(self._final_checkpoint_path)
         final_checkpoint_file_name = self._final_checkpoint_path / "model_final.modyn"
         self.save_state(final_checkpoint_file_name)
@@ -514,19 +581,20 @@ def preprocess_batch(
             sample_ids = sample_ids.tolist()
         elif isinstance(sample_ids, tuple):
             sample_ids = list(sample_ids)
-
         assert isinstance(sample_ids, list), "Cannot parse result from DataLoader"
         stopw.stop("PreprocSampleIDs")
-
-        stopw.start("LabelTransform", resume=True)
-        if self._label_transformer_function is not None:
-            target = self._label_transformer_function(batch[2])
+        if self.generative:
+            target = None
         else:
-            target = batch[2]
-        stopw.stop("LabelTransform")
+            stopw.start("LabelTransform", resume=True)
+            if self._label_transformer_function is not None:
+                target = self._label_transformer_function(batch[2])
+            else:
+                target = batch[2]
+            stopw.stop("LabelTransform")
 
-        with GPUMeasurement(self._measure_gpu_ops, "MoveLabelToGPU", self._device, stopw, resume=True):
-            target = target.to(self._device)
+            with GPUMeasurement(self._measure_gpu_ops, "MoveLabelToGPU", self._device, stopw, resume=True):
+                target = target.to(self._device)
 
         with GPUMeasurement(self._measure_gpu_ops, "MoveDataToGPU", self._device, stopw, resume=True):
             data: torch.Tensor | dict
@@ -541,7 +609,6 @@ def preprocess_batch(
                     "The format of the data provided is not supported in modyn. "
" "Please use either torch tensors or dict[str, torch.Tensor]" ) - return sample_ids, target, data def downsample_batch( @@ -573,9 +640,7 @@ def downsample_batch( big_batch_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor() embeddings = self.get_embeddings_if_recorded() self._downsampler.inform_samples(sample_ids, data, big_batch_output, target, embeddings) - self.end_embedding_recorder_if_needed() - # TODO(#218) Persist information on the sample IDs/weights when downsampling is performed selected_indexes, weights = self._downsampler.select_points() selected_data, selected_target = get_tensors_subset(selected_indexes, data, target, sample_ids) @@ -705,7 +770,7 @@ def save_state(self, destination: pathlib.Path | io.BytesIO, iteration: int | No if iteration is not None: dict_to_save["iteration"] = iteration - + print(destination) torch.save(dict_to_save, destination) def load_state_if_given(self, path: pathlib.Path | None, load_optimizer_state: bool = False) -> None: @@ -793,6 +858,11 @@ def _setup_optimizers(self, training_info: TrainingInfo) -> None: optimizer_func = getattr(apex.optimizers, optimizer_config["algorithm"]) else: raise ValueError("Apex Optimizer defined, but apex is not available in the system") + elif optimizer_config["source"] == "HuggingFace": + optimizer_func = getattr(transformers, optimizer_config["algorithm"]) + elif optimizer_config["source"] == "Custom": + optimizer_module = dynamic_module_import("modyn.trainer_server.custom_optimizers") + optimizer_func = getattr(optimizer_module, optimizer_config["algorithm"]) else: raise ValueError( f"Unsupported optimizer from {optimizer_config['source']}. PyTorch and APEX are supported" @@ -800,10 +870,38 @@ def _setup_optimizers(self, training_info: TrainingInfo) -> None: optimizer_config_list = [] for param_group in optimizer_config["param_groups"]: module = param_group["module"] - param_group["config"]["params"] = eval( # pylint: disable=eval-used - f"self._model.{module}.parameters()" - ) - optimizer_config_list.append(param_group["config"]) + + if optimizer_config["algorithm"] == "Adafactor": # Check if optimizer is Adafactor + # Debug: Print the type of self._model + no_decay = ["bias", "LayerNorm.weight"] + + # Create separate parameter group dictionaries + param_group_no_decay = copy.deepcopy(param_group["config"]) + param_group_decay = copy.deepcopy(param_group["config"]) + + param_group_decay["params"] = [ + p + for n, p in eval(f"self._model.{module}.named_parameters()") # pylint: disable=eval-used + if not any(m in n for m in no_decay) + ] + param_group_decay["weight_decay"] = 0.01 + optimizer_config_list.append(param_group_decay) + + param_group_no_decay["params"] = [ + p + for n, p in eval(f"self._model.{module}.named_parameters()") # pylint: disable=eval-used + if any(m in n for m in no_decay) + ] + + param_group_no_decay["weight_decay"] = 0.0 + optimizer_config_list.append(param_group_no_decay) + + else: + param_group["config"]["params"] = eval( # pylint: disable=eval-used + f"self._model.{module}.parameters()" + ) + + optimizer_config_list.append(param_group["config"]) self._optimizers[name] = optimizer_func(optimizer_config_list) def _update_lr_config_dict(self, lr_scheduler_config: dict[str, Any]) -> dict[str, Any]: @@ -898,30 +996,31 @@ def _iterate_dataloader_and_compute_scores( Args: dataloader: torch.dataloader to get the data previous_batch_number: The batch number returned from the last call to this method. 
-                function is called several times to keep track of previous invocations (ex label by label dataloader). We
-                need to have a total to correctly update the queue and show the progress in the supervisor counter.
-            previous_number_of_samples: number of samples processed before calling this function. See above for the use.
+                function is called several times to keep track of previous invocations (e.g., label-by-label dataloader).
+            previous_number_of_samples: Number of samples processed before calling this function. See above for usage.
 
         Returns:
             Updated number of batches and samples
         """
         number_of_samples = previous_number_of_samples
         batch_number = previous_batch_number
+
         for batch in dataloader:
             self.update_queue("DOWNSAMPLING", batch_number, number_of_samples, training_active=False)
             batch_number += 1
             sample_ids, target, data = self.preprocess_batch(batch)
+            # Handle cases where target is None for generative tasks
+            if self.generative and target is None:
+                target = torch.Tensor()
             number_of_samples += len(sample_ids)
-
             no_grad_mgr = torch.no_grad() if isinstance(self._model, DLRM) else torch.inference_mode()
             context_manager = contextlib.nullcontext() if self._downsampler.requires_grad else no_grad_mgr
             with context_manager:
                 with torch.autocast(self._device_type, enabled=self._amp):
-                    # compute the scores and accumulate them
                     model_output = self._model.model(data) if self._downsampler.forward_required else torch.Tensor()
                     embeddings = self.get_embeddings_if_recorded()
+                    # Inform the downsampler
                     self._downsampler.inform_samples(sample_ids, data, model_output, target, embeddings)
-
         return batch_number, number_of_samples
 
     # ---------------------------------------------------- Logging --------------------------------------------------- #
@@ -959,16 +1058,20 @@ def _load_dataset_log(self) -> None:
     def _assert_data_size(
         expected_size: int, data: torch.Tensor | dict[Any, torch.Tensor], sample_ids: list, target: torch.Tensor
     ) -> None:
+        def _get_tensor_size_in_gb(tensor: torch.Tensor) -> float:
+            """Calculate the size of a tensor in GB."""
+            return tensor.element_size() * tensor.nelement() / (1024**3)
+
         assert (
             all(tensor.shape[0] == expected_size for tensor in data.values())
             if isinstance(data, dict)
            else data.shape[0] == expected_size
         ), (
-            f"expected size: {expected_size} actual size: "
+            f"expected size: {expected_size}, actual size: "
             + f"{data.shape[0] if isinstance(data, torch.Tensor) else 'n/a'}"
         )
-        assert len(sample_ids) == expected_size, f"expected size: {expected_size} actual size: {len(sample_ids)}"
-        assert target.shape[0] == expected_size, f"expected size: {expected_size} actual size: {target.shape[0]}"
+        assert len(sample_ids) == expected_size, f"expected size: {expected_size}, actual size: {len(sample_ids)}"
+        assert target.shape[0] == expected_size, f"expected size: {expected_size}, actual size: {target.shape[0]}"
 
     def _assert_training_size(self, epoch: int, trained_batches: int) -> None:
         if self._lr_scheduler is not None:
diff --git a/modyn/trainer_server/internal/utils/training_info.py b/modyn/trainer_server/internal/utils/training_info.py
index 07a246b35..60ed06caf 100644
--- a/modyn/trainer_server/internal/utils/training_info.py
+++ b/modyn/trainer_server/internal/utils/training_info.py
@@ -31,7 +31,8 @@ def __init__(
         self.training_id = training_id
         self.num_prefetched_partitions = request.num_prefetched_partitions
         self.parallel_prefetch_requests = request.parallel_prefetch_requests
-
+        self.selector_address = selector_address
+        self.storage_address = storage_address
         self.dataset_id = request.data_info.dataset_id
         self.num_dataloaders = request.data_info.num_dataloaders
         self.epochs_per_trigger = request.epochs_per_trigger
@@ -54,10 +55,9 @@ def __init__(
         self.load_optimizer_state = request.load_optimizer_state
         self.pretrained_model_path = pretrained_model_path
         self.log_file_path = log_file_path
-
         self.shuffle = request.shuffle
         self.enable_accurate_gpu_measurements = request.enable_accurate_gpu_measurements
-
+        self.generative = request.generative
         assert (
             self.pretrained_model_path or not self.use_pretrained_model
         ), "Inconsistent pretrained model configuration"
@@ -80,5 +80,7 @@ def __init__(
         self.seed: int | None = request.seed if request.HasField("seed") else None
         self.tokenizer: str | None = request.tokenizer.value if request.HasField("tokenizer") else None
-
+        self.grad_norm: float | None = request.grad_norm if request.HasField("grad_norm") else None
+        self.lora = request.lora
+        self.kadapter = request.kadapter
         self.offline_dataset_path = offline_dataset_path
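
Aside (not part of the patch): the generative branch added to pytorch_trainer.py shifts logits and token ids by one position, flattens them, and masks token id 50256 with -100 before computing the loss. The standalone sketch below mirrors that computation under some assumptions: token ids live in channel 0 of the data tensor, 50256 is treated as the id to ignore (GPT-2's end-of-text token), and plain cross entropy stands in for the criterion configured in the pipeline.

# Standalone sketch of the shifted next-token loss from the generative branch.
import torch
import torch.nn.functional as F

def generative_loss(logits: torch.Tensor, data: torch.Tensor, masked_id: int = 50256) -> torch.Tensor:
    shifted_logits = logits[..., :-1, :].reshape(-1, logits.size(-1))  # drop the prediction for the last position
    shifted_targets = data[..., 1:, 0].reshape(-1).clone()             # drop the target for the first position
    shifted_targets[shifted_targets == masked_id] = -100               # ignored by the loss
    return F.cross_entropy(shifted_logits, shifted_targets, ignore_index=-100)

# Toy example: batch of 2, sequence length 8, vocabulary of 50257, 2 input channels.
logits = torch.randn(2, 8, 50257, requires_grad=True)
data = torch.randint(0, 50257, (2, 8, 2))
loss = generative_loss(logits, data)
loss.backward()  # followed by clip_grad_norm_ when grad_norm is set, as in the patch

Using reshape rather than view mirrors the patch's handling of non-contiguous tensors after slicing.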