
Commit d7085fd

chore: resolve some github comments.
1 parent 129615e commit d7085fd

File tree (5 files changed, 32 additions and 56 deletions):

    nemo/collections/llm/fn/mixin.py
    nemo/collections/vlm/neva/data/preloaded.py
    nemo/lightning/io/mixin.py
    nemo/lightning/one_logger_callback.py
    nemo/utils/meta_info_manager.py

nemo/collections/llm/fn/mixin.py

Lines changed: 7 additions & 12 deletions
@@ -14,6 +14,7 @@
 
 from torch import nn
 from typing_extensions import Self
+import lightning.pytorch as pl
 
 from nemo.collections.llm.fn import base as fn
 from nemo.utils import logging
@@ -52,18 +53,12 @@ class FNMixin:
 
     def __init_subclass__(cls, **kwargs):
         # Add OneLogger timing hooks for LightningModule subclasses to enable telemetry tracking
-        try:
-            import lightning.pytorch as pl
-
-            if issubclass(cls, pl.LightningModule):
-                from nemo.lightning.one_logger_callback import hook_class_init_with_callbacks
-
-                hook_class_init_with_callbacks(cls, "on_model_init_start", "on_model_init_end")
-        except Exception:
-            # Continue gracefully if OneLogger hooks cannot be applied
-            pass
-        finally:
-            super().__init_subclass__(**kwargs)
+        if issubclass(cls, pl.LightningModule):
+            from nemo.lightning.one_logger_callback import hook_class_init_with_callbacks
+
+            hook_class_init_with_callbacks(cls, "on_model_init_start", "on_model_init_end")
+
+        super().__init_subclass__(**kwargs)
 
     def forall(self, func: fn.ModulePredicate, recurse: bool = False) -> bool:
         """

nemo/collections/vlm/neva/data/preloaded.py

Lines changed: 3 additions & 5 deletions
@@ -516,10 +516,9 @@ def __init__(
         num_image_embeddings_per_tile: int = 576,
         seed: int = 1234,
     ) -> None:
-        if not hasattr(self, "_one_logger_init_started"):
-            from nemo.lightning.one_logger_callback import call_one_logger_callback
+        from nemo.lightning.one_logger_callback import call_one_logger_callback
 
-            call_one_logger_callback("on_dataloader_init_start")
+        call_one_logger_callback("on_dataloader_init_start")
 
         super().__init__()
         if not isinstance(paths, (list, tuple)):
@@ -581,8 +580,7 @@ def custom_on_megatron_step_start(self, step):
             dataloader_type="cyclic",
         )
 
-        if not hasattr(self, "_one_logger_init_started"):
-            call_one_logger_callback("on_dataloader_init_end")
+        call_one_logger_callback("on_dataloader_init_end")
 
     def setup(self, stage: str = "") -> None:
         assert len(self.paths) == 1, "not yet support blend dataset in Neva 2.0!"

nemo/lightning/io/mixin.py

Lines changed: 5 additions & 11 deletions
@@ -33,6 +33,7 @@
 from fiddle._src import partial
 from fiddle._src.experimental import serialization
 from typing_extensions import Self
+import lightning.pytorch as pl
 
 from nemo.lightning.io.artifact.base import Artifact
 from nemo.lightning.io.capture import IOProtocol
@@ -190,18 +191,11 @@ def __init_subclass__(cls):
         _io_register_serialization(cls)
 
         # Add OneLogger timing hooks for data modules to enable telemetry tracking
-        try:
-            import lightning.pytorch as pl
-
-            if issubclass(cls, pl.LightningDataModule):
-                from nemo.lightning.one_logger_callback import hook_class_init_with_callbacks
+        if issubclass(cls, pl.LightningDataModule):
+            from nemo.lightning.one_logger_callback import hook_class_init_with_callbacks
 
-                hook_class_init_with_callbacks(cls, "on_dataloader_init_start", "on_dataloader_init_end")
-        except Exception:
-            # Continue gracefully if OneLogger hooks cannot be applied
-            pass
-        finally:
-            super().__init_subclass__()
+            hook_class_init_with_callbacks(cls, "on_dataloader_init_start", "on_dataloader_init_end")
+        super().__init_subclass__()
 
     def io_transform_args(self, init_fn, *args, **kwargs) -> Dict[str, Any]:
         """

nemo/lightning/one_logger_callback.py

Lines changed: 9 additions & 16 deletions
@@ -153,22 +153,15 @@ def init_one_logger() -> None:
     if not enable_one_logger:
         return
 
-    try:
-        # Check if OneLogger is already configured
-        if TrainingTelemetryProvider.instance().one_logger_ready:
-            return
-
-        # Get initialization configuration
-        init_config = get_one_logger_init_config()
-        one_logger_config = OneLoggerConfig(**init_config)
-
-        # Configure the provider with entry-point exporters (automatically calls on_app_start)
-        TrainingTelemetryProvider.instance().with_base_config(
-            one_logger_config
-        ).with_export_config().configure_provider()
-        _ONE_LOGGER_CALLBACK = OneLoggerNeMoCallback(TrainingTelemetryProvider.instance())
-    except Exception:
-        _HAVE_ONE_LOGGER = False
+    # Get initialization configuration
+    init_config = get_one_logger_init_config()
+    one_logger_config = OneLoggerConfig(**init_config)
+
+    # Configure the provider with entry-point exporters (automatically calls on_app_start)
+    TrainingTelemetryProvider.instance().with_base_config(
+        one_logger_config
+    ).with_export_config().configure_provider()
+    _ONE_LOGGER_CALLBACK = OneLoggerNeMoCallback(TrainingTelemetryProvider.instance())
 
 
 def update_one_logger_config(
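
Note: the module-level _ONE_LOGGER_CALLBACK assigned here is what call_one_logger_callback(...) (used elsewhere in this commit) dispatches to. A rough sketch of that dispatch is below, assuming a module global and name-based lookup; it is illustrative, not the file's actual implementation. With the try/except and the one_logger_ready check removed, init_one_logger() now assumes a single call and lets configuration errors propagate.

_ONE_LOGGER_CALLBACK = None  # set by init_one_logger() once the provider is configured


def call_one_logger_callback(method_name: str, *args, **kwargs):
    # Illustrative sketch: forward a named event to the OneLogger callback once it exists.
    if _ONE_LOGGER_CALLBACK is None:
        return  # telemetry disabled or init_one_logger() has not run yet
    getattr(_ONE_LOGGER_CALLBACK, method_name)(*args, **kwargs)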

nemo/utils/meta_info_manager.py

Lines changed: 8 additions & 12 deletions
@@ -18,6 +18,8 @@
 import os
 from typing import Any, Dict
 
+from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
+
 enable_one_logger = True
 
 
@@ -41,7 +43,7 @@ def get_one_logger_init_config() -> Dict[str, Any]:
     # Minimal configuration - required fields only
     init_config = {
         # Required fields (from OneLoggerConfig) - no defaults
-        "application_name": "nemo-application",
+        "application_name": "nemo",
         "session_tag_or_fn": session_tag,
         # Important fields with defaults - provide if available from config
         "enable_for_current_rank": _should_enable_for_current_rank(),
@@ -82,7 +84,6 @@ def _get_base_callback_config(
 
     world_size = int(os.environ.get('WORLD_SIZE', 1))
     max_steps = getattr(trainer, 'max_steps', 1)
-    # Use hardcoded value for log_every_n_steps instead of getting from trainer
     log_every_n_steps = getattr(trainer, 'log_every_n_steps', 10)
     micro_batch_size = global_batch_size // world_size
     # Get PERF_VERSION_TAG from environment
@@ -99,7 +100,7 @@
     is_validation_iterations_enabled = False
     save_checkpoint_strategy = "sync"
 
-    checkpoint_callbacks = [cb for cb in trainer.callbacks if 'Checkpoint' in type(cb).__name__]
+    checkpoint_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)]
     is_save_checkpoint_enabled = len(checkpoint_callbacks) > 0
 
     val_check_interval = getattr(trainer, 'val_check_interval', -1)
@@ -166,6 +167,7 @@ def get_nemo_v1_callback_config(trainer: Any) -> Dict[str, Any]:
         global_batch_size = micro_batch_size * int(os.environ.get('WORLD_SIZE', 1))
     elif hasattr(model_cfg, 'train_ds') and hasattr(model_cfg.train_ds, 'bucket_batch_size'):
         # For ASR with bucketing, use the average batch size
+        # This is a temporary fix to support the bucketing
         bucket_batch_sizes = model_cfg.train_ds.bucket_batch_size
         # Handle both ListConfig and regular list types
         if hasattr(bucket_batch_sizes, '__len__') and len(bucket_batch_sizes) > 0:
@@ -232,12 +234,6 @@ def _should_enable_for_current_rank() -> bool:
     Returns:
         True if OneLogger should be enabled for the current rank, False otherwise
     """
-    try:
-        rank = int(os.environ.get('RANK', 0))
-        world_size = int(os.environ.get('WORLD_SIZE', 1))
-
-        # Enable for rank 0 or the last rank (common pattern)
-        return rank == 0 or rank == world_size - 1
-    except (ValueError, TypeError):
-        # Default to True on invalid values (as expected by tests)
-        return True
+    rank = int(os.environ.get('RANK', 0))
+    # Enable for rank 0 or the last rank (common pattern)
+    return rank == 0
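
Note: switching from name matching to isinstance(cb, ModelCheckpoint) also counts subclasses of ModelCheckpoint, while ignoring unrelated callbacks whose class name merely contains "Checkpoint". A small self-contained illustration (both callback classes here are made up for the example):

from lightning.pytorch.callbacks import Callback, ModelCheckpoint


class MyModelCheckpoint(ModelCheckpoint):  # subclass of ModelCheckpoint: should count
    pass


class CheckpointLogger(Callback):  # unrelated callback that only has "Checkpoint" in its name
    pass


callbacks = [MyModelCheckpoint(), CheckpointLogger()]

by_name = [cb for cb in callbacks if 'Checkpoint' in type(cb).__name__]  # matches both
by_type = [cb for cb in callbacks if isinstance(cb, ModelCheckpoint)]    # matches only the real checkpoint

assert len(by_name) == 2 and len(by_type) == 1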
