fix imports correct type with learning rate
Ioanna Nika committed Oct 20, 2024
1 parent b2e3c18 commit 2280054
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions src/goViral/train_nucleotide_transformer.py
@@ -7,8 +7,6 @@
import torch.optim as optim
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
-from vqa.data.datasets.SimulatedAmpliconSiameseReads import SiameseReads
-from vqa.lightning.TransformerBinaryNetTrainer import TransformerBinaryNetTrainer
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from transformers import TrainingArguments, AutoModelForSequenceClassification, AutoModelForMaskedLM, AutoModel
@@ -45,41 +43,44 @@ def main():
param.requires_grad = True

if args.pb:
print("Pac-Bio reads")
# pacbio-hifi training set
-data = AmpliconReads(input_path='/tudelft.net/staff-umbrella/ViralQuasispecies/inika/Read_simulators/data/tuples_pacbio_sars_cov_2_rev_compl_more/dataset/samples.tsv')
+data = AmpliconReads(input_path='/tudelft.net/staff-umbrella/ViralQuasispecies/inika/Read_simulators/data/tuples_pacbio_sars_cov_2_rev_compl/dataset/samples.tsv')
else:
print("ONT reads")
# ONT training set
data = AmpliconReads(input_path ='/tudelft.net/staff-umbrella/ViralQuasispecies/inika/Read_simulators/data/tuples_ONT_sars_cov_2_rev_compl/dataset/samples.tsv')

-train_count = int(len(data)*0.8)
-val_count = int(len(data)*0.1)
-test_count = len(data) - train_count - val_count
+train_count = int(len(data)*0.9)
+val_count = len(data) - train_count
+test_count = 0 #len(data) - train_count - val_count

train_data, val_data, test_data = random_split(data, [train_count, val_count, test_count])


train_datal = DataLoader(train_data, batch_size=batch, shuffle=True, pin_memory=True, num_workers=4, prefetch_factor=1)
val_datal = DataLoader(val_data, batch_size=batch , shuffle=False, pin_memory=True, num_workers=4, prefetch_factor=1)
test_datal = DataLoader(test_data, batch_size=batch, shuffle=False, pin_memory=True, num_workers=4, prefetch_factor=1)

-optimizer = optim.Adam(peft_model.parameters(), lr=7e-2)
+optimizer = optim.Adam(peft_model.parameters(), lr=7e-3)

os.environ["WANDB_DIR"] = "/tmp"
os.environ["WANDB_START_METHOD"]="thread"
wandb.init(project="NT_binary")
wandb.init(project="New_nt_binary")
wandb_logger = WandbLogger()
# log loss per epoch
wandb_logger.watch(peft_model)

early_stop_callback = EarlyStopping(monitor="val_acc", patience=3, verbose=False, mode="max")

-binary_transformer = TransformerBinaryNetTrainer(peft_model, train_datal, val_datal, test_datal,optimizer, batch, device = device, checkpoint_dir = check_point_dir, max_length=max_length, outdir = None)
+binary_transformer = TransformerBinaryNetTrainer(peft_model, train_datal, val_datal, test_datal, optimizer, batch, device = device, checkpoint_dir = check_point_dir, max_length=max_length, outdir = None)

-trainer = pl.Trainer(max_epochs=15, logger=wandb_logger, accumulate_grad_batches=10, strategy='ddp_find_unused_parameters_true', callbacks=[early_stop_callback], devices=n_devices, accelerator="gpu", enable_progress_bar=False)
+trainer = pl.Trainer(max_epochs=50, logger=wandb_logger, accumulate_grad_batches=10, strategy='ddp_find_unused_parameters_true', callbacks=[early_stop_callback], devices=n_devices, accelerator="gpu", enable_progress_bar=False)
trainer.fit(binary_transformer)

wandb_logger.experiment.unwatch(peft_model)

-trainer.test()
+# trainer.test()

if __name__ == "__main__":
main()
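For reference, the revised split keeps 90% of the samples for training and the remainder for validation, leaving the test subset empty, which is consistent with trainer.test() being commented out. A minimal sketch of that arithmetic, using a dummy dataset in place of AmpliconReads (numbers are illustrative only):

# Illustrative sketch only: a dummy TensorDataset stands in for AmpliconReads.
import torch
from torch.utils.data import TensorDataset, random_split

data = TensorDataset(torch.arange(100))

train_count = int(len(data) * 0.9)    # 90 samples for training
val_count = len(data) - train_count   # remaining 10 samples for validation
test_count = 0                        # test subset left empty in this commit

train_data, val_data, test_data = random_split(data, [train_count, val_count, test_count])
print(len(train_data), len(val_data), len(test_data))  # 90 10 0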

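The Trainer changes raise the epoch cap from 15 to 50 while keeping gradient accumulation over 10 batches, so the EarlyStopping callback on val_acc (patience 3, mode max) is what normally ends training. A hedged sketch of how the effective batch size per optimizer step works out; batch and n_devices are the script's own variables, and the values below are assumptions for illustration:

# Assumed values for illustration; the script reads these from its arguments.
batch = 8                      # per-device DataLoader batch size
n_devices = 2                  # GPUs used with the DDP strategy
accumulate_grad_batches = 10   # as passed to pl.Trainer above

# Samples contributing to each optimizer step under DDP with accumulation:
effective_batch = batch * accumulate_grad_batches * n_devices
print(effective_batch)  # 160 with these assumed values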