Move EarlyStopper to recipes #951

Merged
1 commit merged on Jan 4, 2025
24 changes: 0 additions & 24 deletions src/fairseq2/early_stopper.py

This file was deleted.

32 changes: 32 additions & 0 deletions src/fairseq2/recipes/early_stopper.py
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import final

from typing_extensions import override


class EarlyStopper(ABC):
    """Stops training when an implementation-specific condition is not met."""

    @abstractmethod
    def should_stop(self, step_nr: int, score: float) -> bool:
        """
        :param step_nr: The number of the current training step.
        :param score: The validation score of the current training step.

        :returns: ``True`` if the training should be stopped; otherwise, ``False``.
        """


@final
class NoopEarlyStopper(EarlyStopper):
    @override
    def should_stop(self, step_nr: int, score: float) -> bool:
        return False
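
For illustration only (not part of this PR): a minimal sketch of a concrete stopper built on the EarlyStopper ABC above. The PatienceEarlyStopper name and its patience/min_delta parameters are hypothetical and assume a higher-is-better validation score; fairseq2 may ship its own implementation with a different API.

from typing import final

from typing_extensions import override

from fairseq2.recipes.early_stopper import EarlyStopper


@final
class PatienceEarlyStopper(EarlyStopper):
    """Stops training when the score has not improved for ``patience`` consecutive checks."""

    def __init__(self, patience: int, min_delta: float = 0.0) -> None:
        self._patience = patience
        self._min_delta = min_delta
        self._best_score = float("-inf")
        self._num_bad_checks = 0

    @override
    def should_stop(self, step_nr: int, score: float) -> bool:
        # Treat any improvement larger than `min_delta` as progress.
        if score > self._best_score + self._min_delta:
            self._best_score = score
            self._num_bad_checks = 0
        else:
            self._num_bad_checks += 1

        return self._num_bad_checks >= self._patience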
8 changes: 5 additions & 3 deletions src/fairseq2/recipes/trainer.py
@@ -28,7 +28,6 @@

from fairseq2.checkpoint import CheckpointManager, CheckpointNotFoundError
from fairseq2.datasets import DataReader
-from fairseq2.early_stopper import EarlyStopper
from fairseq2.error import ContractError, InternalError, InvalidOperationError
from fairseq2.gang import FakeGang, Gang, broadcast_flag
from fairseq2.logging import log
@@ -51,6 +50,7 @@
from fairseq2.optim import DynamicLossScaler
from fairseq2.optim.lr_scheduler import LRScheduler, NoopLR, get_effective_lr
from fairseq2.recipes.common_metrics import extend_batch_metrics
+from fairseq2.recipes.early_stopper import EarlyStopper, NoopEarlyStopper
from fairseq2.recipes.evaluator import EvalUnit
from fairseq2.recipes.utils.rich import create_rich_progress
from fairseq2.typing import CPU, DataType
@@ -403,7 +403,7 @@ def __init__(
)

if root_gang.rank != 0:
-early_stopper = lambda step_nr, score: False
+early_stopper = NoopEarlyStopper()

self._early_stopper = early_stopper
else:
@@ -1075,7 +1075,9 @@ def _maybe_request_early_stop(self) -> None:
if self._valid_score is None:
raise InternalError("Early stopping, but `_valid_score` is `None`.")

-should_stop = self._early_stopper(self._step_nr, self._valid_score)
+should_stop = self._early_stopper.should_stop(
+    self._step_nr, self._valid_score
+)
else:
should_stop = False

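
For context, a hypothetical, trainer-independent snippet showing the new call pattern introduced in trainer.py: the stopper is now an object with a should_stop() method rather than a bare callable. It reuses the PatienceEarlyStopper sketch above; the validation scores are made up.

validation_scores = [0.52, 0.55, 0.55, 0.54, 0.53, 0.53]

stopper = PatienceEarlyStopper(patience=3)

for step_nr, score in enumerate(validation_scores, start=1):
    if stopper.should_stop(step_nr, score):
        # No improvement for `patience` consecutive checks; stop training.
        print(f"stopping after step {step_nr}")
        break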