Text generation #647

Open
wants to merge 41 commits into main

Changes from all commits (41 commits)
4df0bba
Copied all my changes to this repository
Jan 9, 2025
91a7054
fixed a bug
Jan 10, 2025
f10f24d
Format changes
Jan 13, 2025
eff5a2f
Format changes
Jan 13, 2025
174bca2
Format changes
Jan 13, 2025
6147697
Format changes
Jan 13, 2025
6dd260b
Format changes
Jan 13, 2025
f37bcfd
Format changes
Jan 13, 2025
b57eb67
Format changes
Jan 13, 2025
c64710a
Format changes
Jan 13, 2025
4cb4d34
Format changes
Jan 13, 2025
dd66329
Format changes
Jan 13, 2025
51222e6
Format changes
Jan 14, 2025
915e67f
Format changes
Jan 14, 2025
f91ae61
Format changes
Jan 14, 2025
b5c1a6d
Format changes
Jan 14, 2025
e8ca218
Format changes
Jan 14, 2025
082abc2
Format changes
Jan 14, 2025
80e0a6d
changed the code according to feedback
Jan 21, 2025
7c65830
changed the code according to feedback
Jan 21, 2025
61ef6b8
Format changes
sjohn4 Jan 21, 2025
a3eb57a
Format changes
sjohn4 Jan 21, 2025
1b1d9e6
Format changes
sjohn4 Jan 21, 2025
225f21c
Format changes
sjohn4 Jan 22, 2025
e9b3093
lets keep the docformatter happy
sjohn4 Jan 22, 2025
923f627
lets keep the docformatter happy
sjohn4 Jan 22, 2025
20c33bf
lets keep the docformatter happy
sjohn4 Jan 22, 2025
8a64e64
Fixed some errors and added some tests
sjohn4 Jan 22, 2025
266b3a8
Fixed some errors and added some tests
sjohn4 Jan 22, 2025
91e92f9
Fixed some errors and added some tests
sjohn4 Jan 22, 2025
900504a
Fixed some errors and added some tests
sjohn4 Jan 22, 2025
789ac93
Fixed some errors and added some tests
sjohn4 Jan 22, 2025
be606b4
Fixed some errors and added some tests
sjohn4 Jan 22, 2025
ed95f67
Fixed some errors and added some tests
sjohn4 Jan 22, 2025
63e9cb5
Fixed some errors and added some tests
sjohn4 Jan 22, 2025
e296816
Fixed some errors and added some tests
sjohn4 Jan 22, 2025
18f4896
Fixed format
sjohn4 Jan 22, 2025
401cf82
Fixed format
sjohn4 Jan 22, 2025
2f5bcaf
I think everything should truly work now
sjohn4 Jan 22, 2025
9e3a9e8
I think everything should truly work now
sjohn4 Jan 22, 2025
ed4dcb4
This is not gonna pass any test but I want my changes today in the ot…
sjohn4 Jan 27, 2025
1 change: 1 addition & 0 deletions modyn/common/grpc/grpc_helpers.py
@@ -251,6 +251,7 @@ def prepare_start_training_request(
enable_accurate_gpu_measurements=training_config.enable_accurate_gpu_measurements,
record_loss_every=training_config.record_loss_every,
drop_last_batch=training_config.drop_last_batch,
generative=training_config.generative,
)

def start_training(
49 changes: 45 additions & 4 deletions modyn/config/examples/modyn_config.yaml
@@ -22,10 +22,11 @@ storage:
filesystem_wrapper_type: "LocalFilesystemWrapper",
file_wrapper_type: "SingleSampleFileWrapper",
file_wrapper_config:
{ file_extension: ".png", label_file_extension: ".label" },
{ file_extension: ".png", label_file_extension: ".label",has_labels: true },
ignore_last_timestamp: false,
file_watcher_interval: 5,
selector_batch_size: 128,

},
# ----------------------------------- CRITEO ----------------------------------- #
{
@@ -41,10 +42,12 @@
record_size: 160,
label_size: 4,
file_extension: ".bin",
has_labels: true,
},
ignore_last_timestamp: false,
file_watcher_interval: 5,
selector_batch_size: 2000000,

},
# ---------------------------------- YEARBOOK ---------------------------------- #
{
@@ -60,10 +63,12 @@
record_size: 12292,
label_size: 4,
file_extension: ".bin",
has_labels: true,
},
ignore_last_timestamp: false,
file_watcher_interval: 5,
selector_batch_size: 256,

},
{
name: "yearbook_train",
@@ -78,6 +83,7 @@
record_size: 12292,
label_size: 4,
file_extension: ".bin",
has_labels: true,
},
ignore_last_timestamp: false,
file_watcher_interval: 5,
@@ -96,6 +102,8 @@
record_size: 12292,
label_size: 4,
file_extension: ".bin",
has_labels: true,

},
ignore_last_timestamp: false,
file_watcher_interval: 5,
@@ -110,7 +118,7 @@
filesystem_wrapper_type: "LocalFilesystemWrapper",
file_wrapper_type: "SingleSampleFileWrapper",
file_wrapper_config:
{ file_extension: ".png", label_file_extension: ".label" },
{ file_extension: ".png", label_file_extension: ".label",has_labels: true, },
ignore_last_timestamp: false,
file_watcher_interval: 5,
selector_batch_size: 1024,
@@ -127,6 +135,7 @@
file_extension: ".csv",
separator: "\t", #tsv best option here since headlines contain commas and semicolons
label_index: 1,
has_labels: true,
},
ignore_last_timestamp: false,
file_watcher_interval: 5,
@@ -144,6 +153,7 @@
file_extension: ".csv",
separator: "\t", #tsv best option here since headlines contain commas and semicolons
label_index: 1,
has_labels: true,
},
ignore_last_timestamp: false,
file_watcher_interval: 5,
@@ -160,10 +170,12 @@
file_extension: ".csv",
separator: "\t", #tsv best option here since headlines contain commas and semicolons
label_index: 1,
has_labels: true,
},
ignore_last_timestamp: false,
file_watcher_interval: 5,
selector_batch_size: 4096,

},
# ------------------------------------ ARXIV ----------------------------------- #
{
@@ -177,10 +189,12 @@
file_extension: ".csv",
separator: "\t", #tsv best option here since sentences contain commas and semicolons
label_index: 1,
has_labels: true
},
ignore_last_timestamp: false,
file_watcher_interval: 5,
selector_batch_size: 4096,

},
{
name: "arxiv_test",
@@ -193,10 +207,12 @@
file_extension: ".csv",
separator: "\t", #tsv best option here since sentences contain commas and semicolons
label_index: 1,
has_labels: true,
},
ignore_last_timestamp: false,
file_watcher_interval: 5,
selector_batch_size: 4096,

},
# -------------------------------- ARXIV KAGGLE -------------------------------- #
{
@@ -210,10 +226,12 @@
file_extension: ".csv",
separator: "\t", #tsv best option here since sentences contain commas and semicolons
label_index: 1,
has_labels: true,
},
ignore_last_timestamp: false,
file_watcher_interval: 5,
selector_batch_size: 4096,

},
{
name: "arxiv_kaggle_test",
@@ -226,10 +244,12 @@
file_extension: ".csv",
separator: "\t", #tsv best option here since sentences contain commas and semicolons
label_index: 1,
has_labels: true,
},
ignore_last_timestamp: false,
file_watcher_interval: 5,
selector_batch_size: 4096,

},
# ------------------------------------ CLOC ------------------------------------ #
{
@@ -240,10 +260,31 @@
filesystem_wrapper_type: "LocalFilesystemWrapper",
file_wrapper_type: "SingleSampleFileWrapper",
file_wrapper_config:
{ file_extension: ".jpg", label_file_extension: ".label" },
{ file_extension: ".jpg", label_file_extension: ".label",
has_labels: true,
},
ignore_last_timestamp: false,
file_watcher_interval: 999999999,
selector_batch_size: 100000,

},
# ------------------------------------ Wikipedia ------------------------------------ #
{
name: "Wikipedia",
description: "Wikipedia text dump from 2021",
version: "0.0.1",
base_path: "/datasets/readablewiki",
filesystem_wrapper_type: "LocalFilesystemWrapper",
file_wrapper_type: "CsvFileWrapper",
file_wrapper_config: {
file_extension: ".csv",
separator: "\t", #tsv best option here since sentences contain commas and semicolons
has_labels: false,
},
ignore_last_timestamp: false,
file_watcher_interval: 5,
selector_batch_size: 4096,

},
]
database:
@@ -278,7 +319,7 @@ selector:
local_storage_directory: "/tmp/local_storage"
local_storage_max_samples_in_file: 1000000
cleanup_storage_directories_after_shutdown: true
ignore_existing_trigger_samples: false
ignore_existing_trigger_samples: true

trainer_server:
hostname: "trainer_server"
9 changes: 8 additions & 1 deletion modyn/config/schema/pipeline/training/config.py
@@ -7,7 +7,7 @@

from modyn.config.schema.base_model import ModynBaseModel

OptimizerSource = Literal["PyTorch", "APEX"]
OptimizerSource = Literal["PyTorch", "APEX", "HuggingFace"]


class OptimizerParamGroup(ModynBaseModel):
@@ -119,6 +119,13 @@ class TrainingConfig(ModynBaseModel):
"we start with random weights. If initial_model is 'pretrained', cannot be False."
)
)
generative: bool = Field(
False,
description=(
"If True then, then the training pipeline goes into the generative branch, data is sampled without expecting labels."
),
)

seed: int | None = Field(
None,
description=(
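Note that generative defaults to False, so existing labeled pipelines are unaffected. As a rough illustration (not part of this diff), a generative pipeline's training section only needs to flip the flag; the dict below stands in for the full TrainingConfig, which has further required fields, and uses only field names that already appear in this PR:

# Illustrative excerpt of a pipeline's training section; a real pipeline must also
# set the remaining required TrainingConfig fields.
training_section = {
    "generative": True,        # take the generative branch: samples are fetched without labels
    "drop_last_batch": True,   # existing field, forwarded in grpc_helpers.py above
    "record_loss_every": 10,   # existing field, forwarded in grpc_helpers.py above
}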
6 changes: 5 additions & 1 deletion modyn/config/schema/system/config.py
@@ -58,9 +58,10 @@ class DatasetCsvFileWrapperConfig(_DatasetBaseFileWrapperConfig):
quoted_linebreaks: bool = Field(True, description="Whether linebreaks are quoted in CSV files.")

label_index: int = Field(
-1,
description=(
"Column index of the label. For columns 'width, 'height, 'age', 'label' you should set label_index to 3."
)
),
)
ignore_first_line: bool = Field(
False, description="If the first line is the table header, you can skip it setting this parameter to True."
@@ -73,6 +74,7 @@ class DatasetCsvFileWrapperConfig(_DatasetBaseFileWrapperConfig):
"rows are the same size and that the 'label' column exists."
),
)
has_labels: bool = Field(True, description=("Describes whether the dataset contains a label field or not"))


class DatasetBinaryFileWrapperConfig(_DatasetBaseFileWrapperConfig):
@@ -83,12 +85,14 @@
)
record_size: int = Field(description="The size of each full record in bytes (label + features).")
label_size: int = Field(description="The size of the label field in bytes for a binary file wrapper.")
has_labels: bool = Field(True, description=("Describes whether the dataset contains a label field or not"))


class DatasetPngFileWrapperConfig(_DatasetBaseFileWrapperConfig):
"""Represents a png dataset file used by modyn."""

label_file_extension: str = Field(description="The label file extension of the dataset", pattern=r"^\..*$")
has_labels: bool = Field(True, description=("Describes whether the dataset contains a label field or not"))


DatasetFileWrapperConfig = Union[ # noqa: UP007
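Since has_labels defaults to True in all three wrapper configs, existing datasets keep their behaviour; only unlabeled (generative) datasets need to opt out. For illustration, the wrapper config of the unlabeled Wikipedia dataset added to modyn_config.yaml maps to the following plain dict (a sketch, not code from this PR):

# Mirrors the Wikipedia entry above: a TSV text dump with no label column.
wikipedia_file_wrapper_config = {
    "file_extension": ".csv",
    "separator": "\t",
    "has_labels": False,  # no label column; label_index stays at its new default of -1
}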
1 change: 1 addition & 0 deletions modyn/models/__init__.py
@@ -6,6 +6,7 @@
from .dlrm.dlrm import DLRM # noqa: F401
from .dummy.dummy import Dummy # noqa: F401
from .fmownet.fmownet import FmowNet # noqa: F401
from .gpt2.gpt2 import Gpt2 # noqa: F401
from .resnet18.resnet18 import ResNet18 # noqa: F401
from .resnet50.resnet50 import ResNet50 # noqa: F401
from .resnet152.resnet152 import ResNet152 # noqa: F401
5 changes: 5 additions & 0 deletions modyn/models/gpt2/_init_.py
@@ -0,0 +1,5 @@
import os

files = os.listdir(os.path.dirname(__file__))
files.remove("__init__.py")
__all__ = [f[:-3] for f in files if f.endswith(".py")]
56 changes: 56 additions & 0 deletions modyn/models/gpt2/gpt2.py
@@ -0,0 +1,56 @@
from typing import Any

import torch
from torch import nn
from transformers import GPT2LMHeadModel

from modyn.models.coreset_methods_support import CoresetSupportingModule


class Gpt2:
# pylint: disable-next=unused-argument
def __init__(self, hparams: Any, device: str, amp: bool) -> None:
self.model = Gpt2Modyn(hparams)
self.model.to(device)


"""
Adapted from an example implementation of a GPT-2 model.
This implementation uses the GPT-2 tokenizer from Hugging Face's Transformers library:
https://huggingface.co/docs/transformers/model_doc/gpt2
"""


class Gpt2Modyn(CoresetSupportingModule):
def __init__(self, hparams: Any) -> None:
super().__init__()

self.model = GPT2LMHeadModel.from_pretrained("gpt2-large") # hparams.model_name_or_path

def forward(self, data: torch.Tensor, labels: torch.Tensor = None) -> torch.Tensor:
"""Forward method for text generation or language modeling tasks.

Args:
- data (torch.Tensor): Tensor of shape (batch_size, seq_len, 2), where
the last dimension contains token IDs and attention masks.
- labels (torch.Tensor, optional): Tensor of labels for language modeling tasks.

Returns:
- output: The output logits or loss from the GPT-2 model.
"""
# Split input into token IDs and attention masks
input_ids = data[:, :, 0]
attention_mask = data[:, :, 1]
# Forward pass through GPT-2

output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

return output.logits

def get_last_layer(self) -> nn.Module:
"""Retrieve the last layer (lm_head) of the model.

Returns:
The final linear layer of the GPT-2 model.
"""
return self.model.lm_head
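For orientation, a short usage sketch (not part of the PR) of the forward contract: the model consumes the stacked (input_ids, attention_mask) tensor produced by the tokenizer transforms and returns per-token logits. Constructing the model downloads the gpt2-large checkpoint, so this is illustrative only:

import torch

model = Gpt2Modyn(hparams=None)                          # hparams is currently unused

input_ids = torch.randint(0, 50257, (2, 16))             # 2 samples, 16 tokens, GPT-2 vocab size
attention_mask = torch.ones_like(input_ids)              # no padding in this toy batch
data = torch.stack((input_ids, attention_mask), dim=2)   # shape (2, 16, 2), as forward() expects

logits = model(data)                                     # shape (2, 16, vocab_size)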
4 changes: 3 additions & 1 deletion modyn/models/tokenizers/__init__.py
@@ -1,8 +1,10 @@
"""Bert Tokenizer for NLP tasks."""
"""Tokenizer for NLP tasks."""

import os

from .distill_bert_tokenizer import DistilBertTokenizerTransform # noqa: F401
from .gpt2_tokenizer import GPT2TokenizerTransform # noqa: F401
from .hf_tokenizer import HFTokenizerTransform # noqa: F401

files = os.listdir(os.path.dirname(__file__))
files.remove("__init__.py")
30 changes: 10 additions & 20 deletions modyn/models/tokenizers/distill_bert_tokenizer.py
@@ -1,24 +1,14 @@
import torch
from transformers import DistilBertTokenizer

from .hf_tokenizer import HFTokenizerTransform

class DistilBertTokenizerTransform:
"""
Adapted from WildTime's initialize_distilbert_transform
Here you can find the original implementation:
https://github.com/huaxiuyao/Wild-Time/blob/main/wildtime/data/utils.py
"""

def __init__(self, max_token_length: int = 300) -> None:
self.max_token_length = max_token_length
self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

def __call__(self, sample: str) -> torch.Tensor:
# make the class Callable to use it as Torch Transform
tokens = self.tokenizer(
sample, padding="max_length", truncation=True, max_length=self.max_token_length, return_tensors="pt"
)
# create a tensor whose first dimension is the input_ids and the second is the attention_mask
data = torch.stack((tokens["input_ids"], tokens["attention_mask"]), dim=2)
data = torch.squeeze(data, dim=0) # First shape dim is always 1, since the input is just one string
return data
class DistilBertTokenizerTransform(HFTokenizerTransform):
def __init__(self, max_token_length: int = 300):
"""
Adapted from WildTime's initialize_distilbert_transform
Here you can find the original implementation:
https://github.com/huaxiuyao/Wild-Time/blob/main/wildtime/data/utils.py
"""
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
super().__init__(tokenizer, max_token_length)
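The shared base class lives in modyn/models/tokenizers/hf_tokenizer.py, which is not included in this excerpt. Reconstructed from the logic removed above, it presumably looks roughly like this (a sketch, not the PR's actual file):

import torch


class HFTokenizerTransform:
    """Sketch of the shared Hugging Face tokenizer transform, inferred from the removed code."""

    def __init__(self, tokenizer, max_token_length: int) -> None:
        self.tokenizer = tokenizer
        self.max_token_length = max_token_length

    def __call__(self, sample: str) -> torch.Tensor:
        # Make the class callable so it can be used as a Torch transform.
        tokens = self.tokenizer(
            sample, padding="max_length", truncation=True,
            max_length=self.max_token_length, return_tensors="pt",
        )
        # Stack token IDs and attention mask into a tensor of shape (max_token_length, 2).
        data = torch.stack((tokens["input_ids"], tokens["attention_mask"]), dim=2)
        return torch.squeeze(data, dim=0)  # the leading batch dimension is always 1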
17 changes: 17 additions & 0 deletions modyn/models/tokenizers/gpt2_tokenizer.py
@@ -0,0 +1,17 @@
from transformers import GPT2Tokenizer

from .hf_tokenizer import HFTokenizerTransform


class GPT2TokenizerTransform(HFTokenizerTransform):
def __init__(self, max_token_length: int = 512):
"""Adapted from an example implementation of a GPT-2 tokenizer.

This implementation uses the GPT-2 tokenizer from Hugging Face's
Transformers library:
https://huggingface.co/docs/transformers/model_doc/gpt2
"""
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
tokenizer.pad_token = tokenizer.eos_token # Set pad token to eos token to avoid padding errors
tokenizer.padding_side = "right"
super().__init__(tokenizer, max_token_length)
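For context, a small usage sketch (assuming the base transform keeps the stacking behaviour sketched above): the transform turns a raw string into the (max_token_length, 2) tensor that Gpt2Modyn.forward expects once a batch dimension is added.

transform = GPT2TokenizerTransform(max_token_length=32)
sample = transform("Modyn trains models on growing datasets.")  # any raw text sample
print(sample.shape)           # torch.Size([32, 2]): token IDs and attention mask stacked
batch = sample.unsqueeze(0)   # (1, 32, 2), ready to feed into Gpt2Modyn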