@@ -32,7 +32,7 @@
 from pydantic import BaseModel
 
 from bionemo.core.utils.dtypes import get_autocast_dtype
-from bionemo.llm.lightning import BionemoLightningModule, PerplexityLoggingCallback
+from bionemo.llm.lightning import BionemoLightningModule
 from bionemo.llm.model.biobert.lightning import biobert_lightning_module
 from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler
 from bionemo.llm.run.config_models import (
@@ -132,9 +132,6 @@ def setup_trainer(
         LearningRateMonitor(),
     ]
 
-    if training_config.include_perplexity:
-        callbacks.append(PerplexityLoggingCallback())
-
     if training_config.gc_interval > 0:
         callbacks.append(
             nl_callbacks.GarbageCollectionCallback(
@@ -252,7 +249,11 @@ def train(
     )
 
     model: BionemoLightningModule = biobert_lightning_module(
-        config=bionemo_model_config, tokenizer=data.tokenizer, optimizer=optimizer
+        config=bionemo_model_config,
+        tokenizer=data.tokenizer,
+        optimizer=optimizer,
+        log_train_ppl=training_config.log_train_ppl,
+        log_val_ppl=training_config.log_val_ppl,
     )
     trainer: nl.Trainer = setup_trainer(parallel_config, training_config, nsys_config=nsys_config)
     nemo_logger: nl.NeMoLogger = nemo_logger_factory(experiment_config, wandb_config=wandb_config)
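For context, a minimal, self-contained sketch of the pattern this change moves to: perplexity is computed and logged from inside the LightningModule, gated by `log_train_ppl` / `log_val_ppl`, instead of via a separate `PerplexityLoggingCallback` added in `setup_trainer`. This is an illustrative toy module, not the actual `BionemoLightningModule` implementation; everything here except the two flag names is made up for the example.

```python
# Illustrative sketch only -- not the BioNeMo implementation.
import torch
import lightning.pytorch as pl


class ToyLMModule(pl.LightningModule):
    """Toy language-model-like module that logs perplexity when the flags are set."""

    def __init__(self, log_train_ppl: bool = False, log_val_ppl: bool = True):
        super().__init__()
        self.log_train_ppl = log_train_ppl
        self.log_val_ppl = log_val_ppl
        self.layer = torch.nn.Linear(16, 32)  # stand-in for a real transformer

    def _loss(self, batch):
        # batch["tokens"]: (B, 16) float features; batch["labels"]: (B,) class ids.
        logits = self.layer(batch["tokens"])
        return torch.nn.functional.cross_entropy(logits, batch["labels"])

    def training_step(self, batch, batch_idx):
        loss = self._loss(batch)
        if self.log_train_ppl:
            # Perplexity is exp of the mean token cross-entropy.
            self.log("train_ppl", torch.exp(loss), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._loss(batch)
        if self.log_val_ppl:
            self.log("val_ppl", torch.exp(loss), sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```

With this pattern, enabling or disabling perplexity reporting is a constructor argument (mirroring `training_config.log_train_ppl` / `training_config.log_val_ppl` in the diff) rather than a conditional callback appended to the trainer's callback list.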