Changes from all commits (148 commits)
642e360
feat: add callback group definition & callback ABC
PytLab May 5, 2025
1badf29
Apply isort and black reformatting
PytLab May 5, 2025
3bf3367
feat: insert callback functions of CallbackGroup
PytLab May 6, 2025
2b51e12
Apply isort and black reformatting
PytLab May 6, 2025
249dad3
chore: PR test for jiashang
liquor233 May 7, 2025
db2b15d
feat: use __init_subclass__ to cover all ModelPT subclasses
PytLab May 12, 2025
d921d64
Apply isort and black reformatting
PytLab May 12, 2025
3e32f1a
feat: Adding metadata config manager poc
May 12, 2025
e1074f6
Apply isort and black reformatting
sajup-oss May 12, 2025
d79f4f1
feat: revert test changes.
liquor233 May 13, 2025
263f7e9
fix: Updating metadata attributes
sajup-oss May 21, 2025
81cd1d9
fix: Merging changes
sajup-oss May 21, 2025
4852936
Apply isort and black reformatting
sajup-oss May 21, 2025
48d6d87
fix: Adding OneloggerCallback
sajup-oss May 22, 2025
2ba6cc5
fix: Reverting changes in examples/multimodal/speech_llm/modular_audi…
sajup-oss May 22, 2025
c908b53
fix: Merge branch 'zshao/add_callback_group' of github.com:NVIDIA/NeM…
sajup-oss May 23, 2025
bd39d8f
Apply isort and black reformatting
sajup-oss May 23, 2025
ba4e4a6
fix: update modular models and megatron GPT models
liquor233 May 26, 2025
515136c
Apply isort and black reformatting
liquor233 May 26, 2025
bc030f7
feat: add on_app_start and on_app_end
liquor233 May 26, 2025
2ed58f4
Apply isort and black reformatting
liquor233 May 26, 2025
35d2f2c
fix: Adding small test example for testing
sajup-oss May 26, 2025
ddc99fb
Apply isort and black reformatting
sajup-oss May 26, 2025
ca6ff4d
fix: Fixing review comments as discussed with Jiashang
May 26, 2025
9f11d01
Apply isort and black reformatting
sajup-oss May 26, 2025
64e0e03
fix: updating nemo code to v2
sajup-oss Jun 13, 2025
181bb3e
fix: updating code to v2
sajup-oss Jun 13, 2025
61d631c
Apply isort and black reformatting
sajup-oss Jun 13, 2025
8eb4fc6
fix: updating wandb to get info from env
sajup-oss Jun 13, 2025
2900246
fix: updating wandb to get info from env
sajup-oss Jun 13, 2025
4acbc2c
Apply isort and black reformatting
sajup-oss Jun 13, 2025
dffccfa
fix: fix som impl issue
liquor233 Jul 4, 2025
60eb727
Apply isort and black reformatting
liquor233 Jul 4, 2025
b97fbda
fix: fix issue for exp manager.
liquor233 Jul 7, 2025
5c144ed
feat: Merge branch 'zshao/add_callback_group' of https://github.com/N…
liquor233 Jul 7, 2025
041a32b
Apply isort and black reformatting
liquor233 Jul 7, 2025
b70f85b
feat: remove callback_group
liquor233 Jul 10, 2025
f473d1b
feat: fix timingtracker issue
liquor233 Jul 10, 2025
1705b19
Apply isort and black reformatting
liquor233 Jul 10, 2025
e6b4e64
feat: fix for startup callbcaks
liquor233 Jul 14, 2025
5b7bd1c
Apply isort and black reformatting
liquor233 Jul 14, 2025
c687003
feat: change to adapter
liquor233 Jul 14, 2025
42181c5
Apply isort and black reformatting
liquor233 Jul 14, 2025
f522e9c
feat: use new nv-one-logger
liquor233 Jul 16, 2025
07aaa05
feat: add on_app_end
liquor233 Jul 16, 2025
5f0f184
Apply isort and black reformatting
liquor233 Jul 16, 2025
c75373a
feat: make OneLogger configurable
liquor233 Jul 17, 2025
f5640f9
Apply isort and black reformatting
liquor233 Jul 17, 2025
06520f0
feat: remove NeMocallback import
liquor233 Jul 17, 2025
51615ac
feat: fix the enable_onelogger setting.
liquor233 Jul 17, 2025
56feca2
Apply isort and black reformatting
liquor233 Jul 17, 2025
acf0c5a
feat: clean the code.
liquor233 Jul 17, 2025
57a5b0e
feat: enable onelogger
liquor233 Jul 17, 2025
f3e7f83
Apply isort and black reformatting
liquor233 Jul 17, 2025
d2d49c3
test: Adding few unit tests
Jul 17, 2025
6350923
Apply isort and black reformatting
sajup-oss Jul 17, 2025
dafb75d
feat: tmp fix for functional testing.
liquor233 Jul 18, 2025
1d4be52
Apply isort and black reformatting
liquor233 Jul 18, 2025
bc2a9d6
fix: add on_app_end for NeMov2
liquor233 Jul 18, 2025
ef9c503
fix: typo.
liquor233 Jul 18, 2025
0c027d5
Apply isort and black reformatting
liquor233 Jul 18, 2025
7b9ea68
fix: fix the get attributes
liquor233 Jul 18, 2025
1a1e1b7
fix: moving test test_meta_info_manager.py to tests/collections/common/
Jul 18, 2025
5d03d87
fix:Merge branch 'zshao/add_callback_group' of github.com:NVIDIA/NeMo…
Jul 18, 2025
84b076e
fix: fix format issue.
liquor233 Jul 18, 2025
c1f853b
Apply isort and black reformatting
liquor233 Jul 18, 2025
304a7bd
feat: Merge remote-tracking branch 'origin/main' into zshao/add_callb…
liquor233 Jul 21, 2025
8e47ecd
fix: fix lint errors
liquor233 Jul 21, 2025
de6994d
Apply isort and black reformatting
liquor233 Jul 21, 2025
32a3371
Revert "Apply isort and black reformatting"
liquor233 Jul 21, 2025
729e020
Revert "fix: fix lint errors"
liquor233 Jul 21, 2025
a679703
fix: fix linting issues.
liquor233 Jul 21, 2025
1c0b9cf
Apply isort and black reformatting
liquor233 Jul 21, 2025
0869066
fix: fix linting issue
liquor233 Jul 21, 2025
0dca014
Apply isort and black reformatting
liquor233 Jul 21, 2025
1060ca7
fix: add copyright info
liquor233 Jul 21, 2025
4f3b901
Apply isort and black reformatting
liquor233 Jul 21, 2025
a143550
fix: small fix.
liquor233 Jul 21, 2025
0b034b8
fix: fix small issues for t5
liquor233 Jul 22, 2025
e1ffef0
fix: fix dataloader issue.
liquor233 Jul 22, 2025
87de1ee
fix: remove dataloader setting.
liquor233 Jul 22, 2025
1a0a2a6
feat: update OneLogger.
liquor233 Jul 22, 2025
6e827a7
fix: fix hydra runner.
liquor233 Jul 22, 2025
2239787
Apply isort and black reformatting
liquor233 Jul 22, 2025
8c74641
fix: start using partial config.
liquor233 Jul 23, 2025
fe0618b
Apply isort and black reformatting
liquor233 Jul 23, 2025
461885f
fix: fix the unused variables
liquor233 Jul 24, 2025
383eb6a
fix: change get_one_logger name
liquor233 Jul 24, 2025
a445401
fix: code clean up.
liquor233 Jul 24, 2025
eda1072
Apply isort and black reformatting
liquor233 Jul 24, 2025
9adcb60
fix: import more specific to avoid circular dependency. (#14306)
PeiyuanQi Jul 24, 2025
558bbde
fix: use ptl callback from ls
liquor233 Jul 25, 2025
0025f87
Apply isort and black reformatting
liquor233 Jul 25, 2025
2f485c7
feat: fix meta info manager.
liquor233 Aug 4, 2025
de3c9d8
fix: fix meta data issue.
liquor233 Aug 5, 2025
3d357a9
Apply isort and black reformatting
liquor233 Aug 5, 2025
1f739a9
fix: fix the lint issue
liquor233 Aug 6, 2025
ee04438
fix: fix the unit tests.
liquor233 Aug 6, 2025
ae63eb2
fix: fix minor metadata issue.
liquor233 Aug 6, 2025
4d509ec
Apply isort and black reformatting
liquor233 Aug 6, 2025
05b78a2
Merge branch 'main' into zshao/add_callback_group
liquor233 Aug 6, 2025
2e6dd6f
fix: fix some test issues
liquor233 Aug 6, 2025
0c736ad
fix: fix pytest issue for meta info manager
liquor233 Aug 6, 2025
d0a25ad
fix: fix lint issues for optimizers.
liquor233 Aug 6, 2025
182e68f
chore: Merge branch 'main' into zshao/add_callback_group
liquor233 Aug 6, 2025
313e49d
fix: fix pytest issues.
liquor233 Aug 7, 2025
acea1bf
Apply isort and black reformatting
liquor233 Aug 7, 2025
1c4071f
chore: Merge branch 'main' into zshao/add_callback_group
liquor233 Aug 7, 2025
ece8b51
fix: fix CICD issues.
liquor233 Aug 11, 2025
b89b6bd
fix: fix all pytests
liquor233 Aug 11, 2025
69f1080
Apply isort and black reformatting
liquor233 Aug 11, 2025
f970db3
Merge branch 'main' into zshao/add_callback_group
liquor233 Aug 11, 2025
d783893
chore: fix lint
liquor233 Aug 11, 2025
55b2539
chore: fix unused import issues.
liquor233 Aug 11, 2025
516c9a2
chore: fix CICD issues.
liquor233 Aug 11, 2025
3e3ab98
Apply isort and black reformatting
liquor233 Aug 11, 2025
b5dd037
fix: fix the CICD issues.
liquor233 Aug 13, 2025
8f07246
Apply isort and black reformatting
liquor233 Aug 13, 2025
2036123
Merge branch 'main' into zshao/add_callback_group
liquor233 Aug 14, 2025
ee6cba1
fix: fix the linting issue
liquor233 Aug 14, 2025
89ec5f2
fix: fix CICD issues.
liquor233 Aug 15, 2025
6594abe
Merge branch 'main' into zshao/add_callback_group
liquor233 Aug 15, 2025
645791e
Merge branch 'main' into zshao/add_callback_group
liquor233 Aug 18, 2025
b248254
fix: fix the circular import issue.
liquor233 Aug 18, 2025
e296668
Apply isort and black reformatting
liquor233 Aug 18, 2025
c2551b1
fix: fix some pytests.
liquor233 Aug 18, 2025
6718ece
fix: revert some change.
liquor233 Aug 18, 2025
85daa2d
fix: error handling for init onelogger
liquor233 Aug 18, 2025
6cbc033
Apply isort and black reformatting
liquor233 Aug 18, 2025
0a7d53e
chore: fix one_logger code.
liquor233 Aug 26, 2025
de408ee
Apply isort and black reformatting
liquor233 Aug 26, 2025
ad7d40a
Merge branch 'main' into zshao/add_callback_group
liquor233 Aug 27, 2025
2de6230
Merge branch 'main' into zshao/add_callback_group
liquor233 Aug 27, 2025
6624733
chore: remove unused vars.
liquor233 Aug 27, 2025
0ec6257
fix: fix CICD for nemo
liquor233 Aug 28, 2025
9a50ae8
Merge branch 'main' into zshao/add_callback_group
liquor233 Aug 28, 2025
2f09433
chore: fix NeMo CICD.
liquor233 Aug 28, 2025
88fb787
chore: renaming onelogger
liquor233 Sep 1, 2025
d8156dd
chore: fix some exception.
liquor233 Sep 1, 2025
a449cc6
Merge branch 'main' into zshao/add_callback_group
liquor233 Sep 1, 2025
951e143
chore: renaming.
liquor233 Sep 1, 2025
3eea3b2
chore: resolve some comments.
liquor233 Sep 2, 2025
129615e
chore: remove duplicate init.
liquor233 Sep 2, 2025
d7085fd
chore: resolve some github comments.
liquor233 Sep 2, 2025
09d8347
Apply isort and black reformatting
liquor233 Sep 2, 2025
a9fc88b
chore: fix the linting issue.
liquor233 Sep 2, 2025
4dc1c91
Merge branch 'main' into zshao/add_callback_group
liquor233 Sep 3, 2025
68f5caf
chore(callbacks): restore generic CallbackGroup and route telemetry v…
liquor233 Sep 3, 2025
12 changes: 12 additions & 0 deletions nemo/collections/llm/api.py
@@ -58,6 +58,7 @@
from nemo.lightning.base import NEMO_MODELS_CACHE
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.pytorch.callbacks import PEFT, JitTransform, ModelTransform
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero

@@ -135,6 +136,9 @@ def train(

trainer.fit(model, data)

# Track app end for NeMo v2 recipe-based applications
CallbackGroup.get_instance().on_app_end()
Review comment:
Where is the matching on_app_start() called? Why not consider a function decorator to call both?
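For example, a minimal sketch of such a decorator (the name `track_app_lifecycle` is hypothetical; it assumes only the `on_app_start`/`on_app_end` hooks already referenced in this PR):

```python
import functools

from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup


def track_app_lifecycle(func):
    """Pair on_app_start/on_app_end around an entrypoint, even if it raises."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        CallbackGroup.get_instance().on_app_start()
        try:
            return func(*args, **kwargs)
        finally:
            CallbackGroup.get_instance().on_app_end()

    return wrapper
```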


return app_state.exp_dir


@@ -1255,11 +1259,19 @@ def _setup(
resume_if_exists=getattr(resume, "resume_if_exists", False),
task_config=getattr(train, "__io__", None),
)

# Configure telemetry via CallbackGroup
CallbackGroup.get_instance().update_config(nemo_version='v2', trainer=trainer, data=data)

if resume is not None:
CallbackGroup.get_instance().on_load_checkpoint_start()
resume.setup(trainer, model)
CallbackGroup.get_instance().on_load_checkpoint_end()
Review comment:
Consider introducing a context manager for cases similar to this, so that the end function is always called, and we can log any exception as an error event. For Heimdall error events are required.
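A minimal sketch of what such a context manager could look like (the helper name `callback_span` is hypothetical; it assumes only the `*_start`/`*_end` hooks already used in this diff):

```python
from contextlib import contextmanager

from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup


@contextmanager
def callback_span(start_hook: str, end_hook: str):
    """Invoke the start hook, run the body, and always invoke the end hook."""
    group = CallbackGroup.get_instance()
    getattr(group, start_hook)()
    try:
        yield
    finally:
        # Any exception still propagates; an error-event hook could be reported
        # here if CallbackGroup exposes one (an assumption, not shown in this diff).
        getattr(group, end_hook)()


# Usage sketch:
# with callback_span("on_load_checkpoint_start", "on_load_checkpoint_end"):
#     resume.setup(trainer, model)
```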


if optim:
CallbackGroup.get_instance().on_optimizer_init_start()
optim.connect(model)
CallbackGroup.get_instance().on_optimizer_init_end()
if tokenizer: # TODO: Improve this
_use_tokenizer(model, data, tokenizer)

10 changes: 10 additions & 0 deletions nemo/collections/llm/fn/mixin.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import lightning.pytorch as pl
from torch import nn
from typing_extensions import Self

@@ -50,6 +51,15 @@ class FNMixin:
True
"""

def __init_subclass__(cls, **kwargs):
# Add OneLogger timing hooks for LightningModule subclasses to enable telemetry tracking
if issubclass(cls, pl.LightningModule):
from nemo.lightning.pytorch.callbacks.callback_group import hook_class_init_with_callbacks

hook_class_init_with_callbacks(cls, "on_model_init_start", "on_model_init_end")

super().__init_subclass__(**kwargs)

def forall(self, func: fn.ModulePredicate, recurse: bool = False) -> bool:
"""
Evaluates a predicate for all modules in the container, optionally recursively.
6 changes: 6 additions & 0 deletions nemo/collections/llm/gpt/data/mock.py
@@ -68,6 +68,10 @@ def __init__(
vocab_file: Optional[str] = None,
merges_file: Optional[str] = None,
):
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup

CallbackGroup.get_instance().on_dataloader_init_start()

super().__init__()
self.seq_length = seq_length
self.micro_batch_size = micro_batch_size
@@ -96,6 +100,8 @@ def __init__(
rampup_batch_size=rampup_batch_size,
)

CallbackGroup.get_instance().on_dataloader_init_end()

def setup(self, stage: str = "") -> None:
"""
Setup the data module.
@@ -24,7 +24,7 @@

try:
ALGORITHMS = {
"eagle3": mtsp.EAGLE3_DEFAULT_CFG,
"eagle3": mtsp.EAGLE3_DEFAULT_CFG if hasattr(mtsp, "EAGLE3_DEFAULT_CFG") else None,
# more TBD
}
except UnavailableError:
6 changes: 6 additions & 0 deletions nemo/collections/llm/t5/data/mock.py
@@ -49,6 +49,10 @@ def __init__(
persistent_workers: bool = False,
create_attention_mask: bool = False,
):
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup

CallbackGroup.get_instance().on_dataloader_init_start()

super().__init__()
self.seq_length = seq_length
self.seq_length_dec = seq_length_dec
@@ -72,6 +76,8 @@
rampup_batch_size=rampup_batch_size,
)

CallbackGroup.get_instance().on_dataloader_init_end()

def setup(self, stage: str = "") -> None:
"""Setup the datasets"""
self._train_ds = _MockT5Dataset(
2 changes: 1 addition & 1 deletion nemo/collections/speechlm2/parts/optim_setup.py
@@ -88,7 +88,7 @@ def freeze_and_subset(

>>> model = MyModel()
... # freeze all LLM parameters in "model.llm"
... params = freeze_and_subset(model.named_parameters(), ['^llm\..+$'])
... params = freeze_and_subset(model.named_parameters(), [r'^llm\\.\\..+$'])
... optimizer = torch.optim.AdamW(params, lr=1e-3)

"""
6 changes: 6 additions & 0 deletions nemo/collections/vlm/neva/data/preloaded.py
@@ -516,6 +516,10 @@ def __init__(
num_image_embeddings_per_tile: int = 576,
seed: int = 1234,
) -> None:
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup

CallbackGroup.get_instance().on_dataloader_init_start()

super().__init__()
if not isinstance(paths, (list, tuple)):
paths = [paths]
@@ -576,6 +580,8 @@ def custom_on_megatron_step_start(self, step):
dataloader_type="cyclic",
)

CallbackGroup.get_instance().on_dataloader_init_end()

def setup(self, stage: str = "") -> None:
assert len(self.paths) == 1, "not yet support blend dataset in Neva 2.0!"
self._train_ds = NevaDataset(
37 changes: 36 additions & 1 deletion nemo/core/classes/modelPT.py
@@ -48,6 +48,7 @@
from nemo.core.classes.common import Model
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.core.optim import McoreDistributedOptimizer, prepare_lr_scheduler
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup
from nemo.utils import logging, model_utils
from nemo.utils.app_state import AppState
from nemo.utils.debug_hook import register_debug_hooks
@@ -86,6 +87,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
f"trainer constructor argument must be either None or lightning.pytorch.Trainer. "
f"But got {type(trainer)} instead."
)

# Track model init start
CallbackGroup.get_instance().on_model_init_start()

super().__init__()

"""
@@ -152,6 +157,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
if torch.cuda.is_available() and torch.cuda.current_device() is not None:
app_state.device_id = torch.cuda.current_device()

CallbackGroup.get_instance().on_model_init_end()
CallbackGroup.get_instance().on_dataloader_init_start()
if self._cfg is not None and not self._is_model_being_restored():
# Setup data loaders now (default) or defer setup to `self.setup()`
# if `defer_setup` is set in the config of the corresponding dataloader.
@@ -198,6 +205,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
f"Test config : \n{OmegaConf.to_yaml(self._cfg.test_ds)}"
)

CallbackGroup.get_instance().on_dataloader_init_end()

# Create list of lists for val and test outputs to support multiple dataloaders
# Initialize an empty list as sometimes self._validation_dl can be None at this stage
self._validation_step_outputs = None
@@ -469,6 +478,8 @@ def restore_from(
Returns:
An instance of type cls or its underlying config (if return_config is set).
"""
# Notify OneLogger of checkpoint loading start for telemetry tracking
CallbackGroup.get_instance().on_load_checkpoint_start()

if save_restore_connector is None:
save_restore_connector = SaveRestoreConnector()
@@ -502,6 +513,10 @@
)
if isinstance(instance, ModelPT):
instance._save_restore_connector = save_restore_connector

# Notify OneLogger of checkpoint loading completion for telemetry tracking
CallbackGroup.get_instance().on_load_checkpoint_end()

return instance

@classmethod
@@ -518,6 +533,9 @@ def load_from_checkpoint(
Loads ModelPT from checkpoint, with some maintenance of restoration.
For documentation, please refer to LightningModule.load_from_checkpoint() documentation.
"""
# Notify OneLogger of checkpoint loading start for telemetry tracking
CallbackGroup.get_instance().on_load_checkpoint_start()

checkpoint = None
try:
cls._set_model_restore_state(is_being_restored=True)
@@ -533,6 +551,10 @@

finally:
cls._set_model_restore_state(is_being_restored=False)

# Notify OneLogger of checkpoint loading completion for telemetry tracking
CallbackGroup.get_instance().on_load_checkpoint_end()

return checkpoint

@abstractmethod
@@ -729,7 +751,8 @@ def setup_optimization(

if optimizer_cls is None:
# Try to get optimizer name for dynamic resolution, defaulting to Adam
optimizer_name = optim_config.get('name', 'adam')
# Use `or` instead of a dict-get default: if 'name' is present but set to None, the default would not be applied.
optimizer_name = optim_config.get('name') or 'adam'
else:
if inspect.isclass(optimizer_cls):
optimizer_name = optimizer_cls.__name__.lower()
@@ -890,8 +913,12 @@ def configure_optimizers(self):
"""
Configure the optimizer and scheduler.
"""
# Track optimizer init start
CallbackGroup.get_instance().on_optimizer_init_start()
self.setup_optimization()

CallbackGroup.get_instance().on_optimizer_init_end()

if self._scheduler is None:
return self._optimizer
else:
@@ -955,6 +982,9 @@ def setup(self, stage: Optional[str] = None):
if no_test_dataloader and test_deferred_setup:
self.setup_multiple_test_data(test_data_config=self._cfg.test_ds)

if stage == 'fit':
CallbackGroup.get_instance().update_config(nemo_version='v1', trainer=self._trainer)

def train_dataloader(self):
"""
Get the training dataloader.
@@ -1344,6 +1374,8 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st
f"Found : {[args[idx] for idx, arg_present in enumerate(arg_matches) if arg_present]}"
)

CallbackGroup.get_instance().on_load_checkpoint_start()

if 'init_from_nemo_model' in cfg and cfg.init_from_nemo_model is not None:
with open_dict(cfg):
if isinstance(cfg.init_from_nemo_model, str):
@@ -1460,6 +1492,9 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st
else:
raise TypeError("Invalid type: init_from_ptl_ckpt is not a string or a dict!")

# Track load checkpoint end
CallbackGroup.get_instance().on_load_checkpoint_end()

def teardown(self, stage: str):
"""
Called at the end of fit and test.
7 changes: 5 additions & 2 deletions nemo/core/config/hydra_runner.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import functools
import os
import sys
@@ -103,7 +102,7 @@ def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
# Make sure the path is not set - as this will disable validation scheme.
if path != '':
sys.stderr.write(
f"ERROR Cannot set config file path using `--config-name` when "
"ERROR Cannot set config file path using `--config-name` when "
"using schema. Please set path using `--config-path` and file name using "
"`--config-name` separately.\n"
)
@@ -133,6 +132,10 @@ def parse_args(self, args=None, namespace=None):
config_path=config_path,
config_name=config_name,
)
# Import here to avoid circular import
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup

CallbackGroup.get_instance().on_app_end()

return wrapper

2 changes: 1 addition & 1 deletion nemo/core/optim/lr_scheduler.py
@@ -31,7 +31,7 @@
from omegaconf import DictConfig, OmegaConf
from torch.optim.lr_scheduler import _LRScheduler

from nemo.core.config import SchedulerParams, get_scheduler_config, register_scheduler_params
from nemo.core.config.schedulers import SchedulerParams, get_scheduler_config, register_scheduler_params
from nemo.utils import logging
from nemo.utils.model_utils import maybe_update_config_version

15 changes: 12 additions & 3 deletions nemo/core/optim/optimizers.py
@@ -23,11 +23,11 @@
from torch.optim import adadelta, adagrad, adamax, rmsprop, rprop
from torch.optim.optimizer import Optimizer

from nemo.core.config import OptimizerParams, get_optimizer_config, register_optimizer_params
from nemo.core.config.optimizers import OptimizerParams, get_optimizer_config, register_optimizer_params
from nemo.core.optim.adafactor import Adafactor
from nemo.core.optim.adan import Adan
from nemo.core.optim.novograd import Novograd
from nemo.utils import logging

from nemo.utils.model_utils import maybe_update_config_version

AVAILABLE_OPTIMIZERS = {
@@ -195,14 +195,23 @@ def get_optimizer(name: str, **kwargs: Optional[Dict[str, Any]]) -> Optimizer:
)
if name == 'fused_adam':
if not torch.cuda.is_available():
raise ValueError(f'CUDA must be available to use fused_adam.')
raise ValueError('CUDA must be available to use fused_adam.')

optimizer = AVAILABLE_OPTIMIZERS[name]
optimizer = partial(optimizer, **kwargs)
return optimizer


def init_optimizer_states(optimizer: Optimizer):
"""
Initialize optimizer states for Adam-based optimizers.

This function initializes the exponential moving averages (exp_avg and exp_avg_sq)
for Adam, AdamW, and FusedAdam optimizers if they haven't been initialized yet.

Args:
optimizer: The optimizer instance to initialize states for
"""
adam_nondist_optims = (optim.Adam, optim.AdamW)
if HAVE_APEX:
adam_nondist_optims += (FusedAdam,)
8 changes: 8 additions & 0 deletions nemo/lightning/io/mixin.py
@@ -27,6 +27,7 @@

import fiddle as fdl
import fiddle._src.experimental.dataclasses as fdl_dc
import lightning.pytorch as pl
from cloudpickle import dump
from cloudpickle import load as pickle_load
from fiddle._src import config as config_lib
@@ -189,6 +190,13 @@ def __new__(cls, *args, **kwargs):
def __init_subclass__(cls):
_io_register_serialization(cls)

# Add OneLogger timing hooks for data modules to enable telemetry tracking
if issubclass(cls, pl.LightningDataModule):
from nemo.lightning.pytorch.callbacks.callback_group import hook_class_init_with_callbacks

hook_class_init_with_callbacks(cls, "on_dataloader_init_start", "on_dataloader_init_end")
super().__init_subclass__()

def io_transform_args(self, init_fn, *args, **kwargs) -> Dict[str, Any]:
"""
Transforms and captures the arguments passed to the `__init__` method, filtering out