Skip to content

Commit

Permalink
ADLR/megatron-lm!1822 - ci: Fix process groups and flaky tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ko3n1g authored and terrykong committed Jul 30, 2024
1 parent c8e9aa2 commit 314450e
Show file tree
Hide file tree
Showing 15 changed files with 167 additions and 91 deletions.
25 changes: 11 additions & 14 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ workflow:
- if: $CI_PIPELINE_SOURCE == "schedule"
variables:
FUNCTIONAL_TEST: "yes"
UNIT_TEST_TIMEOUT: 180
UNIT_TEST_REPEAT: 10
- if: $CI_PIPELINE_SOURCE == "web"
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
variables:
Expand Down Expand Up @@ -65,6 +67,8 @@ variables:
CI_MCORE_IMAGE: ${GITLAB_ENDPOINT}:5005/adlr/megatron-lm/mcore_ci
CI_NEMO_IMAGE: ${GITLAB_ENDPOINT}:5005/adlr/megatron-lm/nemo_ci
LINTING_IMAGE: ${GITLAB_ENDPOINT}:5005/adlr/megatron-lm/mcore_linting
UNIT_TEST_TIMEOUT: 15
UNIT_TEST_REPEAT: 1

metadata:
image: python:3.10
Expand Down Expand Up @@ -242,27 +246,20 @@ unit_tests:
image: ${CI_MCORE_IMAGE}:${CI_PIPELINE_ID}
stage: unit_tests
needs: [build_image]
timeout: 180m
tags:
- 8xL40S
rules:
- if: '$FUNCTIONAL_TEST == "no" && $CI_PIPELINE_SOURCE == "merge_request_event" && ($CI_MERGE_REQUEST_TARGET_BRANCH_NAME != $CI_DEFAULT_BRANCH && $CI_MERGE_REQUEST_TARGET_BRANCH_NAME !~ /^core_r/)'
allow_failure: true
- when: always
parallel:
matrix:
- DIR:
- data
- dist_checkpointing
- distributed
- fusions
- inference
- models
- pipeline_parallel
- tensor_parallel
- transformer
- '*.py'
script:
- torchrun --nproc_per_node=8 -m pytest -x -v -s --cov-report=term --cov-report=html --cov=megatron/core --no-cov-on-fail tests/unit_tests/$DIR
- |
for i in $(seq $UNIT_TEST_REPEAT); do
SEED=$((RANDOM % 9000 + 1000));
timeout ${UNIT_TEST_TIMEOUT}m torchrun --nproc_per_node=8 -m pytest --random-order --random-order-seed ${SEED} -xvs --cov-report=term --cov-report=html --cov=megatron/core --no-cov-on-fail tests/unit_tests
done
artifacts:
paths:
- coverage
Expand Down
31 changes: 17 additions & 14 deletions Dockerfile.ci
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,6 @@ RUN apt-get update && \
RUN wget https://github.com/mikefarah/yq/releases/download/v4.44.1/yq_linux_amd64 -O /usr/local/bin/yq && \
chmod a+x /usr/local/bin/yq

RUN pip3 install --no-cache-dir \
einops \
flask-restful \
nltk \
pytest \
pytest-cov \
pytest_mock \
sentencepiece \
wrapt \
git+https://github.com/fanshiqing/[email protected] \
zarr \
tensorstore==0.1.45 \
wandb

##### For Mamba begin #####
RUN pip uninstall -y triton && \
pip install triton==2.1.0
Expand Down Expand Up @@ -69,6 +55,23 @@ RUN apt-get install -y python3-venv && \
python -m venv /opt/jet
##### For JET-API end #####

RUN pip3 install --no-cache-dir \
einops \
flask-restful \
nltk \
pytest \
pytest-cov \
pytest_mock \
pytest-random-order \
sentencepiece \
wrapt \
git+https://github.com/fanshiqing/[email protected] \
zarr \
tensorstore==0.1.45 \
wandb

COPY . /workspace/megatron-lm

COPY . /workspace/megatron-lm
RUN cp -r /workspace/megatron-lm /opt && \
pip install /opt/megatron-lm
Expand Down
33 changes: 33 additions & 0 deletions megatron/core/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,61 +1373,94 @@ def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None

global _MODEL_AND_EXPERT_PARALLEL_GROUP
_MODEL_AND_EXPERT_PARALLEL_GROUP = None

global _TENSOR_MODEL_PARALLEL_GROUP
_TENSOR_MODEL_PARALLEL_GROUP = None

global _PIPELINE_MODEL_PARALLEL_GROUP
_PIPELINE_MODEL_PARALLEL_GROUP = None

global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None

global _DATA_PARALLEL_GROUP_WITH_CP
_DATA_PARALLEL_GROUP_WITH_CP = None

global _CONTEXT_PARALLEL_GROUP
_CONTEXT_PARALLEL_GROUP = None

global _CONTEXT_PARALLEL_GLOBAL_RANKS
_CONTEXT_PARALLEL_GLOBAL_RANKS = None

global _EMBEDDING_GROUP
_EMBEDDING_GROUP = None

global _POSITION_EMBEDDING_GROUP
_POSITION_EMBEDDING_GROUP = None

global _TENSOR_AND_DATA_PARALLEL_GROUP
_TENSOR_AND_DATA_PARALLEL_GROUP = None

global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None

global _TENSOR_AND_CONTEXT_PARALLEL_GROUP
_TENSOR_AND_CONTEXT_PARALLEL_GROUP = None

global _EXPERT_MODEL_PARALLEL_GROUP
_EXPERT_MODEL_PARALLEL_GROUP = None

global _TENSOR_AND_EXPERT_PARALLEL_GROUP
_TENSOR_AND_EXPERT_PARALLEL_GROUP = None

global _DATA_MODULO_EXPERT_PARALLEL_GROUP
_DATA_MODULO_EXPERT_PARALLEL_GROUP = None

global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP
_DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = None

global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None

global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None

global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None

global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None

global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = None

global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

global _GLOBAL_MEMORY_BUFFER
_GLOBAL_MEMORY_BUFFER = None

global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None

global _MPU_EXPERT_MODEL_PARALLEL_RANK
_MPU_EXPERT_MODEL_PARALLEL_RANK = None

global _DATA_PARALLEL_GROUP_GLOO
if _DATA_PARALLEL_GROUP_GLOO is not None:
torch.distributed.destroy_process_group(_DATA_PARALLEL_GROUP_GLOO)
_DATA_PARALLEL_GROUP_GLOO = None

global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None

global _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
if _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO is not None:
torch.distributed.destroy_process_group(_DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO)
_DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = None

global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO
_DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = None
5 changes: 3 additions & 2 deletions tests/unit_tests/data/test_preprocess_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tempfile

import nltk
import pytest
import requests

from megatron.core.datasets.indexed_dataset import IndexedDataset
Expand Down Expand Up @@ -183,7 +184,7 @@ def gpt2_merge(odir):
writer.write(requests.get(PRETRAINED_MERGES_ARCHIVE_MAP['gpt2']).content)
return path


@pytest.mark.skip(reason="Tests are flaky and need to be debugged")
def test_preprocess_data_gpt():
with tempfile.TemporaryDirectory() as temp_dir:

Expand Down Expand Up @@ -213,7 +214,7 @@ def bert_vocab(odir):
writer.write(requests.get(__HUGGINGFACE_BERT_BASE_UNCASED_VOCAB).content)
return path


@pytest.mark.skip(reason="Tests are flaky and need to be debugged")
def test_preprocess_data_bert():
with tempfile.TemporaryDirectory() as temp_dir:

Expand Down
6 changes: 6 additions & 0 deletions tests/unit_tests/dist_checkpointing/test_async_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ def write_data_os_err_mock_fn(local_proc_idx, write_bucket, results_queue, count


class TestAsyncSave:
def setup_method(self, method):
pass

def teardown_method(self, method):
Utils.destroy_model_parallel()

def test_async_is_equivalent_to_sync(self, tmp_path_dist_ckpt):
Utils.initialize_model_parallel(2, 4)

Expand Down
6 changes: 6 additions & 0 deletions tests/unit_tests/dist_checkpointing/test_cached_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@


class TestCachedMetadata:
def setup_method(self, method):
pass

def teardown_method(self, method):
Utils.destroy_model_parallel()

def test_cached_metadata(self, tmp_path_dist_ckpt):
Utils.initialize_model_parallel(2, 4)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@


class TestFlattenedResharding:
def setup_method(self, method):
pass

def teardown_method(self, method):
Utils.destroy_model_parallel()

@pytest.mark.parametrize(
('src_tp_pp', 'dest_tp_pp',),
[
Expand Down
46 changes: 32 additions & 14 deletions tests/unit_tests/dist_checkpointing/test_fully_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,27 @@

import numpy as np
import pytest

import torch

from megatron.core import parallel_state
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.dict_utils import nested_values, \
map_reduce, dict_list_map_outplace
from megatron.core.dist_checkpointing.dict_utils import (
dict_list_map_outplace,
map_reduce,
nested_values,
)
from megatron.core.dist_checkpointing.mapping import is_main_replica
from megatron.core.dist_checkpointing.strategies.base import \
SaveShardedStrategy, LoadShardedStrategy
from megatron.core.dist_checkpointing.strategies.fully_parallel import \
FullyParallelSaveStrategyWrapper, _sharded_tensor_shard_id, \
FullyParallelLoadStrategyWrapper, _ShardId
from megatron.core.dist_checkpointing.strategies.base import (
LoadShardedStrategy,
SaveShardedStrategy,
)
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
FullyParallelLoadStrategyWrapper,
FullyParallelSaveStrategyWrapper,
_sharded_tensor_shard_id,
_ShardId,
)
from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils


Expand Down Expand Up @@ -59,6 +67,12 @@ def check_version_compatibility(self, loaded_version):


class TestFullyParallelSaveAndLoad:
def setup_method(self, method):
pass

def teardown_method(self, method):
Utils.destroy_model_parallel()

@staticmethod
def get_sharded_state_dict():
return {
Expand All @@ -75,7 +89,7 @@ def get_sharded_state_dict():
}

@pytest.mark.parametrize("parallelization_along_dp", [False, True])
def test_save_distribution(self, parallelization_along_dp):
def test_save_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt):
Utils.initialize_model_parallel(2, 1)
state_dict = self.get_sharded_state_dict()

Expand Down Expand Up @@ -122,7 +136,8 @@ def test_save_distribution(self, parallelization_along_dp):
save_strategy = FullyParallelSaveStrategyWrapper(mock_strategy,
parallelization_group,
do_cache_distribution=True)
save_strategy.save(state_dict, Path('mock_dir'))
with TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir_A:
save_strategy.save(state_dict, ckpt_dir_A)
key_to_saving_rank = dict(map_reduce(save_strategy.cached_distribution.main_rank_for_shard.items(), lambda shard_rank: shard_rank[0][0], lambda shard_rank: shard_rank[1]))
assert expected_key_to_saving_ranks == key_to_saving_rank

Expand All @@ -134,7 +149,7 @@ def test_save_distribution(self, parallelization_along_dp):
assert mock_strategy.save_keys == expected_keys_saved_by_current_rank, (Utils.rank, mock_strategy.save_keys, expected_keys_saved_by_current_rank)

@pytest.mark.parametrize("parallelization_along_dp", [False, True])
def test_load_distribution(self, parallelization_along_dp):
def test_load_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt):
Utils.initialize_model_parallel(2, 1)

state_dict = self.get_sharded_state_dict()
Expand Down Expand Up @@ -174,16 +189,18 @@ def test_load_distribution(self, parallelization_along_dp):
load_strategy = FullyParallelLoadStrategyWrapper(mock_strategy,
parallelization_group,
do_cache_distribution=True)
loaded_state_dict = load_strategy.load(state_dict, Path('mock_dir'))
with TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir_A:
loaded_state_dict = load_strategy.load(state_dict, ckpt_dir_A)
key_to_saving_rank = dict(map_reduce(load_strategy.cached_distribution.main_rank_for_shard.items(), lambda shard_rank: shard_rank[0][0], lambda shard_rank: shard_rank[1]))
assert expected_key_to_saving_ranks == key_to_saving_rank

assert mock_strategy.load_keys == expected_keys_saved_by_current_rank, (Utils.rank, mock_strategy.load_keys, expected_keys_saved_by_current_rank)

assert loaded_state_dict.keys() == state_dict.keys()

@pytest.mark.skip(reason="Tests are flaky and need to be debugged")
@pytest.mark.parametrize('state_dict_device', ['cpu', 'cuda'])
def test_memory_usage(self, state_dict_device):
def test_memory_usage(self, state_dict_device, tmp_path_dist_ckpt):
Utils.initialize_model_parallel(2, 1)

megabytes = 1024 * 1024
Expand All @@ -210,7 +227,8 @@ def _get_empty_tensor_for_exchange(self, *args, **kwargs) -> torch.Tensor:

mem_alloc_start = torch.cuda.memory_allocated()

loaded_state_dict = load_strategy.load(sharded_state_dict, Path('mock_dir'))
with TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir_A:
loaded_state_dict = load_strategy.load(sharded_state_dict, ckpt_dir_A)

# Each rank is expected to do 7 * 10 empty allocations
assert len(mem_alloc) == 7 * 10
Expand Down
6 changes: 6 additions & 0 deletions tests/unit_tests/dist_checkpointing/test_nonpersistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
from tests.unit_tests.test_utilities import Utils

class TestNonPersistentSaveAndLoad:
def setup_method(self, method):
pass

def teardown_method(self, method):
Utils.destroy_model_parallel()

@pytest.mark.parametrize(
('tp,pp'),
[
Expand Down
Loading

0 comments on commit 314450e

Please sign in to comment.