diff --git a/.coverage b/.coverage
deleted file mode 100644
index 26a3c59252..0000000000
Binary files a/.coverage and /dev/null differ
diff --git a/.coveragerc b/.coveragerc
index 13612a43ee..29de6ff8a3 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,2 +1,5 @@
 [html]
-directory = coverage
\ No newline at end of file
+directory = coverage
+
+[run]
+data_file = .coverage_$LOCAL_RANK
diff --git a/.gitignore b/.gitignore
index 0cca053883..e99e246e1a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 __pycache__
 *.so
 build
+.coverage_*
 *.egg-info
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 91d9330d60..115a6e59a2 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -4,8 +4,7 @@ test:
   tags:
   - docker_gpu_enabled
   script:
-  - nvidia-smi
-  - torchrun --nproc_per_node=2 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
+  - torchrun --nproc_per_node=8 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
   coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
   artifacts:
     paths:
diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py
index dc76695aab..23059fc1f5 100644
--- a/megatron/core/tensor_parallel/random.py
+++ b/megatron/core/tensor_parallel/random.py
@@ -22,6 +22,8 @@
     gather_split_1d_tensor,
 )
 
+from megatron.core.utils import safely_set_viewless_tensor_data
+
 # Default name for the model parallel rng tracker.
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
 
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/tensor_parallel/test_cross_entropy.py b/tests/tensor_parallel/test_cross_entropy.py
new file mode 100644
index 0000000000..2a725a2715
--- /dev/null
+++ b/tests/tensor_parallel/test_cross_entropy.py
@@ -0,0 +1,14 @@
+from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
+import torch
+from tests.test_utilities import Utils
+import numpy as np
+
+def test_vocab_parallel_cross_entropy():
+    Utils.initialize_model_parallel(4,2)
+    vocab_parallel_logits = torch.range(0,7).repeat(16,4).cuda()
+    target = torch.arange(0,32,2).cuda()
+    output = vocab_parallel_cross_entropy(vocab_parallel_logits, target)
+    expected_output = torch.tensor([10.2309, 8.2309, 6.2309, 4.2309, 10.2309, 8.2309, 6.2309, 4.2309,
+        10.2309, 8.2309, 6.2309, 4.2309, 10.2309, 8.2309, 6.2309, 4.2309]).cuda()
+    assert(torch.equal(torch.round(expected_output), torch.round(output)))
+    Utils.destroy_model_parallel()
\ No newline at end of file
diff --git a/tests/tensor_parallel/test_data.py b/tests/tensor_parallel/test_data.py
new file mode 100644
index 0000000000..d7948474a7
--- /dev/null
+++ b/tests/tensor_parallel/test_data.py
@@ -0,0 +1,21 @@
+from megatron.core.tensor_parallel.data import broadcast_data
+import torch
+from tests.test_utilities import Utils
+
+def test_broadcast_data():
+    Utils.initialize_model_parallel(2,4)
+    input_data = {
+        0 : torch.ones((8,8)).cuda() * 0.0,
+        1 : torch.ones((8,8)).cuda() * 1.0,
+        2 : torch.ones((8,8)).cuda() * 2.0,
+        3 : torch.ones((8,8)).cuda() * 3.0,
+        4 : torch.ones((8,8)).cuda() * 4.0,
+        5 : torch.ones((8,8)).cuda() * 5.0,
+        6 : torch.ones((8,8)).cuda() * 6.0,
+        7 : torch.ones((8,8)).cuda() * 7.0
+    }
+    dtype = torch.float32
+    actual_output = broadcast_data([0,1],input_data, dtype)
+    assert(torch.equal(actual_output[0], input_data[0]))
+    assert(torch.equal(actual_output[1], input_data[1]))
+    Utils.destroy_model_parallel()
\ No newline at end of file
diff --git a/tests/tensor_parallel/test_mappings.py b/tests/tensor_parallel/test_mappings.py
new file mode 100644
index 0000000000..52040a2edf
--- /dev/null
+++ b/tests/tensor_parallel/test_mappings.py
@@ -0,0 +1,135 @@
+from megatron.core.tensor_parallel import mappings
+from tests.test_utilities import Utils
+import torch
+
+def test_CopyToModelParallelRegion():
+    Utils.initialize_model_parallel(4,2)
+    input_data = torch.ones((1)).cuda()*Utils.rank
+    output_data = mappings._CopyToModelParallelRegion.backward(None, input_data)
+    result = torch.ones(1).cuda()
+    result = result * 22 if Utils.rank >= 4 else result * 6
+    assert(torch.equal(output_data, result))
+    assert(torch.equal(input_data, mappings.copy_to_tensor_model_parallel_region(input_data)))
+    assert(torch.equal(input_data, mappings._CopyToModelParallelRegion.symbolic(None, input_data)))
+    Utils.destroy_model_parallel()
+
+def test_ReduceFromModelParallelRegion():
+    Utils.initialize_model_parallel(4,2)
+    input_data = torch.ones((1)).cuda()*Utils.rank
+    output_data = mappings._ReduceFromModelParallelRegion.symbolic(None, input_data)
+    result = torch.ones(1).cuda()
+    result = result * 22 if Utils.rank >= 4 else result * 6
+    assert(torch.equal(output_data, result))
+    input_data = torch.ones((1)).cuda()*Utils.rank
+    assert(torch.equal(mappings.reduce_from_tensor_model_parallel_region(input_data), result))
+    assert(torch.equal(input_data, mappings._ReduceFromModelParallelRegion.backward(None, input_data)))
+    Utils.destroy_model_parallel()
+
+def test_ScatterToModelParallelRegion():
+    Utils.initialize_model_parallel(4,2)
+    input_data = torch.rand((8,4)).cuda()
+    output_data = mappings.scatter_to_tensor_model_parallel_region(input_data)
+    req_dim = int(Utils.rank%(Utils.world_size/2))
+    assert(torch.equal(output_data, input_data[:,req_dim].reshape((8,1))))
+    output_data = mappings._ScatterToModelParallelRegion.symbolic(None, input_data)
+    assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1))))
+
+    input_data = torch.ones(8).cuda() * Utils.rank
+    actual_output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data)
+    expected_output = torch.cat((
+        torch.ones(8)*0,
+        torch.ones(8)*1,
+        torch.ones(8)*2,
+        torch.ones(8)*3)).cuda()
+    if (Utils.rank >= 4):
+        expected_output = expected_output + 4
+    assert(torch.equal(actual_output_data, expected_output))
+    Utils.destroy_model_parallel()
+
+def test_GatherFromModelParallelRegion():
+    Utils.initialize_model_parallel(4,2)
+    input_data = torch.rand((8,4)).cuda()
+    req_dim = int(Utils.rank%(Utils.world_size/2))
+    output_data = mappings._GatherFromModelParallelRegion.backward(None, input_data)
+    assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1))))
+    input_data = torch.ones(8).cuda() * Utils.rank
+    actual_output_data = mappings.gather_from_tensor_model_parallel_region(input_data)
+    expected_output = torch.cat((
+        torch.ones(8)*0,
+        torch.ones(8)*1,
+        torch.ones(8)*2,
+        torch.ones(8)*3)).cuda()
+    if (Utils.rank >= 4):
+        expected_output = expected_output + 4
+    assert(torch.equal(actual_output_data, expected_output))
+    assert(torch.equal(mappings._GatherFromModelParallelRegion.symbolic(None, input_data), expected_output))
+    Utils.destroy_model_parallel()
+
+def test_ScatterToSequenceParallelRegion():
+    Utils.initialize_model_parallel(4,2)
+    input_data = torch.rand((8,4)).cuda()
+    req_dim = int(Utils.rank%(Utils.world_size/2))*2
+    output_data = mappings._ScatterToSequenceParallelRegion.symbolic(None, input_data)
+    assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :]))
+    output_data = mappings.scatter_to_sequence_parallel_region(input_data)
+    assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :]))
+    input_data = torch.ones(4).cuda() * Utils.rank
+    output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data)
+    expected_output = torch.concat((
+        torch.ones(4)*0,
+        torch.ones(4)*1,
+        torch.ones(4)*2,
+        torch.ones(4)*3)).cuda()
+    if (Utils.rank >= 4):
+        expected_output = expected_output + 4
+    assert(torch.equal(output_data, expected_output))
+    Utils.destroy_model_parallel()
+
+def test_GatherFromSequenceParallelRegion():
+    Utils.initialize_model_parallel(4,2)
+    input_data = torch.ones(4).cuda() * Utils.rank
+    output_data = mappings.gather_from_sequence_parallel_region(input_data)
+    expected_output = torch.concat((
+        torch.ones(4)*0,
+        torch.ones(4)*1,
+        torch.ones(4)*2,
+        torch.ones(4)*3)).cuda()
+    if (Utils.rank >= 4):
+        expected_output = expected_output + 4
+    assert(torch.equal(output_data, expected_output))
+    assert(torch.equal(mappings._GatherFromSequenceParallelRegion.symbolic(None, input_data), expected_output))
+    input_data = torch.vstack((
+        torch.ones(4)*0,
+        torch.ones(4)*1,
+        torch.ones(4)*2,
+        torch.ones(4)*3)).cuda()
+    class Ctx:
+        tensor_parallel_output_grad = True
+    output_data = mappings._GatherFromSequenceParallelRegion.backward(Ctx(), input_data)
+    expected_output = torch.ones((1,4)).cuda() * 4 * int(Utils.rank % 4)
+    assert(torch.equal(output_data[0], expected_output))
+    Utils.destroy_model_parallel()
+
+def test_ReduceScatterToSequenceParallelRegion():
+    Utils.initialize_model_parallel(4,2)
+    input_data = torch.vstack((
+        torch.ones(4)*0,
+        torch.ones(4)*1,
+        torch.ones(4)*2,
+        torch.ones(4)*3)).cuda()
+    output_data = mappings.reduce_scatter_to_sequence_parallel_region(input_data)
+    expected_output = torch.ones(4).cuda() * 4 * int(Utils.rank % 4)
+    assert(torch.equal(output_data[0], expected_output))
+    assert(torch.equal(mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, input_data) , expected_output.reshape((1,4))))
+    input_data = torch.ones(4).cuda() * Utils.rank
+    output_data = mappings._ReduceScatterToSequenceParallelRegion.backward(None,input_data)
+    expected_output = torch.concat((
+        torch.ones(4)*0,
+        torch.ones(4)*1,
+        torch.ones(4)*2,
+        torch.ones(4)*3)).cuda()
+    if (Utils.rank >= 4):
+        expected_output = expected_output + 4
+    assert(torch.equal(output_data, expected_output))
+    Utils.destroy_model_parallel()
+
diff --git a/tests/tensor_parallel/test_random.py b/tests/tensor_parallel/test_random.py
new file mode 100644
index 0000000000..8aaf4b855c
--- /dev/null
+++ b/tests/tensor_parallel/test_random.py
@@ -0,0 +1,44 @@
+from megatron.core.tensor_parallel.random import CudaRNGStatesTracker
+from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+from megatron.core.tensor_parallel.random import _CUDA_RNG_STATE_TRACKER
+from megatron.core.tensor_parallel.random import checkpoint
+from tests.test_utilities import Utils
+import pytest
+import torch
+
+def test_cuda_rng_states_tracker():
+    rng_tracker = CudaRNGStatesTracker()
+    rng_tracker.set_states({"state1":1234})
+    assert(rng_tracker.get_states()["state1"] == 1234)
+    rng_tracker.reset()
+    assert(rng_tracker.get_states() == {})
+    seed = 1111
+    rng_tracker.add("state2",seed)
+    with pytest.raises(Exception):
+        assert(rng_tracker.add("state3",seed))
+    with pytest.raises(Exception):
+        assert(rng_tracker.add("state2",111))
+    assert(rng_tracker.get_states()['state2'] is not None)
+    with pytest.raises(Exception):
+        assert()
+
+    rng_tracker.fork("state2")
+    torch.cuda.manual_seed(seed)
+    rng_state = torch.cuda.get_rng_state()
+    assert torch.equal(rng_tracker.get_states()['state2'], rng_state)
+
+def test_model_parallel_cuda_manual_seed():
+    Utils.initialize_model_parallel(4,2)
+    model_parallel_cuda_manual_seed(0)
+    assert(_CUDA_RNG_STATE_TRACKER.get_states()['model-parallel-rng'] is not None)
+    Utils.destroy_model_parallel()
+
+def test_checkpoint():
+    def test_forward(*input):
+        return input[0]+input[1]
+    assert(torch.equal(torch.ones(16)*3,checkpoint(test_forward, None, torch.ones(16), torch.ones(16)*2)))
+    Utils.initialize_model_parallel()
+    input1 = torch.ones((4,4))
+    checkpoint(test_forward, True, input1, torch.ones((4,4))*2)
+    assert(torch.equal(torch.ones(input1.numel()).cuda(), input1))
+    Utils.destroy_model_parallel()
\ No newline at end of file
diff --git a/tests/tensor_parallel/test_tensor_parallel_utils.py b/tests/tensor_parallel/test_tensor_parallel_utils.py
index 872be90c17..5aae470f4f 100644
--- a/tests/tensor_parallel/test_tensor_parallel_utils.py
+++ b/tests/tensor_parallel/test_tensor_parallel_utils.py
@@ -1,7 +1,43 @@
 import torch
 import megatron.core.tensor_parallel.utils as util
+import megatron.core.parallel_state as ps
+from tests.test_utilities import Utils
+
+rank = Utils.rank
 
 def test_split_tensor_along_last_dim():
     input_tensor = torch.rand((3,4))
     torch.equal(input_tensor[0:2,0:2], util.split_tensor_along_last_dim(input_tensor,2)[0])
     torch.equal(input_tensor[2:,2:], util.split_tensor_along_last_dim(input_tensor,2)[1])
+
+def test_split_tensor_into_1d_equal_chunks():
+    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
+    input_tensor = torch.rand((3,4))
+    output_tensor = util.split_tensor_into_1d_equal_chunks(input_tensor)
+    if rank % 2 == 0 :
+        start = 0
+        end = int(input_tensor.numel()/2)
+    else :
+        start = int(input_tensor.numel()/2)
+        end = input_tensor.numel()
+
+    assert torch.equal(output_tensor, input_tensor.flatten()[start:end])
+    Utils.destroy_model_parallel()
+
+def test_gather_split_1d_tensor():
+    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
+    input_tensor = torch.ones((2,4)).cuda() * rank
+    actual_output_tensor = util.gather_split_1d_tensor(input_tensor)
+    if rank % 2 == 0:
+        expected_output_tensor = torch.concat((input_tensor.flatten(), input_tensor.flatten() + 1))
+    else :
+        expected_output_tensor = torch.concat((input_tensor.flatten() - 1, input_tensor.flatten()))
+    assert(torch.equal(actual_output_tensor, expected_output_tensor))
+    Utils.destroy_model_parallel()
+
+def test_vocab():
+    global_vocab_size = 1600
+    per_partition_vocab_size = 1600 / Utils.world_size
+    assert((rank * per_partition_vocab_size, (rank + 1)* per_partition_vocab_size) == (util.VocabUtility.vocab_range_from_per_partition_vocab_size(global_vocab_size // Utils.world_size, rank, Utils.world_size)))
+    assert((rank * per_partition_vocab_size, (rank + 1)* per_partition_vocab_size) == (util.VocabUtility.vocab_range_from_global_vocab_size(global_vocab_size, rank, Utils.world_size)))
+
\ No newline at end of file
diff --git a/tests/test_parallel_state.py b/tests/test_parallel_state.py
index 5fdd09fee4..de9c550e60 100644
--- a/tests/test_parallel_state.py
+++ b/tests/test_parallel_state.py
@@ -1,41 +1,16 @@
-import os
 import torch
 import megatron.core.parallel_state as ps
-from datetime import timedelta
 import pytest
+from tests.test_utilities import Utils
+import os
+rank = Utils.rank
+world_size = Utils.world_size
 
-world_size = torch.cuda.device_count()
-rank = int(os.environ['LOCAL_RANK'])
-print('Ranks is : ' + str(rank))
-
-def initialize_distributed():
-    print(f'Initializing torch.distributed with rank: {rank}, world_size: {world_size}')
-    torch.cuda.set_device(rank % torch.cuda.device_count())
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(backend='nccl', world_size=world_size, rank=rank, init_method=init_method, timeout=timedelta(seconds=10))
-
-def initialize_model_parallel(
-    tensor_model_parallel_size: int = 1,
-    pipeline_model_parallel_size: int = 1,
-    virtual_pipeline_model_parallel_size = None,
-    pipeline_model_parallel_split_rank = None,
-):
-    # This might not be the right way to do this.
-    try:
-        ps.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank)
-    except:
-        ps.destroy_model_parallel()
-        ps.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank)
-    pass
-
-def test_initialize_model_parallel():
+def test_initialize_and_destroy_model_parallel():
     with pytest.raises(AssertionError):
         assert(ps.initialize_model_parallel())
-    initialize_distributed()
+    Utils.initialize_distributed()
     with pytest.raises(RuntimeError):
         assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))
     with pytest.raises(RuntimeError):
@@ -44,124 +19,86 @@ def test_initialize_model_parallel():
         assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))
     with pytest.raises(RuntimeError):
         assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))
-    initialize_model_parallel()
+    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
     assert(ps.model_parallel_is_initialized())
     assert(ps.get_model_parallel_group() is not None)
     assert(ps.get_tensor_model_parallel_group() is not None)
     assert(ps.get_pipeline_model_parallel_group() is not None)
     assert(ps.get_data_parallel_group() is not None)
-    assert(ps.get_embedding_group() is not None)
-    assert(ps.get_position_embedding_group() is not None)
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()
+    assert(ps._MODEL_PARALLEL_GROUP is None)
 
 def test_pipeline_parallel_initializations():
-    initialize_model_parallel(pipeline_model_parallel_size=2)
-    assert(ps.get_pipeline_model_parallel_first_rank() == 0)
+    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
+    assert(ps.get_pipeline_model_parallel_first_rank() == rank % 2 )
     assert(ps.get_data_parallel_src_rank() == rank)
-    assert(ps.get_pipeline_model_parallel_next_rank() == 0 if rank == world_size - 1 else rank + 1)
-    assert(ps.get_pipeline_model_parallel_prev_rank() == rank - 1 if rank > 0 else world_size - 1)
-    ps.destroy_model_parallel()
-
+    assert(ps.get_pipeline_model_parallel_next_rank() == ((rank + 2) % world_size))
+    assert(ps.get_pipeline_model_parallel_prev_rank() == ((rank - 2) % world_size))
+    Utils.destroy_model_parallel()
+
 def test_data_parallel_initializations():
-    initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.get_data_parallel_src_rank() == rank)
-    assert(ps.get_data_parallel_world_size() == world_size-1)
+    assert(ps.get_data_parallel_world_size() == 1)
     assert(ps.get_data_parallel_rank() == 0)
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()
+
 def test_tensor_model_parellel_world_size():
-    initialize_model_parallel(tensor_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)
     assert(ps.get_tensor_model_parallel_world_size() == world_size)
     ps.set_tensor_model_parallel_world_size(None)
     assert(ps.get_tensor_model_parallel_world_size() == world_size)
-    ps.destroy_model_parallel()
-
+    Utils.destroy_model_parallel()
+
 def test_pipeline_model_parallel_world_size():
-    initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
    assert(ps.get_pipeline_model_parallel_world_size() == world_size)
     ps.set_pipeline_model_parallel_world_size(None)
     assert(ps.get_pipeline_model_parallel_world_size() == world_size)
-    ps.destroy_model_parallel()
-
+    Utils.destroy_model_parallel()
+
 def test_tensor_model_parallel_rank():
-    initialize_model_parallel(tensor_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)
     assert(ps.get_tensor_model_parallel_rank() == rank)
     ps.set_tensor_model_parallel_rank(None)
     assert(ps.get_tensor_model_parallel_rank() == rank)
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()
+
 def test_pipeline_model_parallel_rank():
-    initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.get_pipeline_model_parallel_rank() == rank)
     ps.set_pipeline_model_parallel_rank(None)
     assert(ps.get_pipeline_model_parallel_rank() == rank)
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()
+
 def test_is_pipeline_first_stage():
-    initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.is_pipeline_first_stage(ignore_virtual=True) == (rank == 0))
     assert(ps.is_pipeline_first_stage() == (rank == 0))
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()
+
 def test_is_pipeline_last_stage():
-    initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.is_pipeline_last_stage(ignore_virtual=True) == (rank == world_size-1))
     assert(ps.is_pipeline_last_stage() == (rank == world_size-1))
-    ps.destroy_model_parallel()
-
+    Utils.destroy_model_parallel()
+
 def test_virtual_pipeline_model_parallel_rank():
-    initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     ps.set_virtual_pipeline_model_parallel_rank(rank)
     assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()
+
 def test_get_tensor_model_parallel_src_rank():
-    initialize_model_parallel(tensor_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)
     assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))
-    ps.destroy_model_parallel()
-
-"""
-def test_get_virtual_pipeline_model_parallel_world_size():
-    initialize_model_parallel(pipeline_model_parallel_size=world_size)
-    ps.set_virtual_pipeline_model_parallel_rank(world_size)
-    assert(ps.get_virtual_pipeline_model_parallel_world_size() == world_size)
-    ps.destroy_model_parallel()
-
-def test_is_rank_in_embedding_group():
-    assert(ps.is_rank_in_embedding_group(ignore_virtual=True) == (rank in ps._EMBEDDING_GLOBAL_RANKS))
-    if rank in ps._EMBEDDING_GLOBAL_RANKS:
-        assert(ps.is_rank_in_embedding_group() == ps.is_pipeline_first_stage())
-    elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
-        assert(ps.is_rank_in_embedding_group() == ps.is_pipeline_last_stage())
-    else:
-        assert(ps.is_rank_in_embedding_group())
-
-def test_is_rank_in_position_embedding_group():
-    assert(ps.is_rank_in_position_embedding_group() == (rank in ps._POSITION_EMBEDDING_GLOBAL_RANKS))
-
-def test_is_pipeline_stage_before_split():
-    if world_size == 1:
-        assert(ps.is_pipeline_stage_before_split())
-    # TODO: Changes here for more than one world size
-    assert(ps.is_pipeline_stage_before_split())
-
-def test_is_pipeline_stage_after_split():
-    if world_size == 1:
-        assert(ps.is_pipeline_stage_after_split())
-    # TODO: Changes here for more than one world size
-    assert(ps.is_pipeline_stage_before_split())
-
-def test_is_pipeline_stage_at_split():
-    assert(
-        ps.is_pipeline_stage_at_split() ==
-        (ps.is_pipeline_stage_before_split(rank) and ps.is_pipeline_stage_after_split(rank+1))
-    )
-
-def test_destroy_model_parallel():
-    ps.destroy_model_parallel()
-    assert(ps._MODEL_PARALLEL_GROUP is None)
-"""
\ No newline at end of file
+    Utils.destroy_model_parallel()
\ No newline at end of file
diff --git a/tests/test_utilities.py b/tests/test_utilities.py
new file mode 100644
index 0000000000..b35c77b58d
--- /dev/null
+++ b/tests/test_utilities.py
@@ -0,0 +1,30 @@
+import os
+import torch
+import megatron.core.parallel_state as ps
+
+class Utils:
+
+    world_size = torch.cuda.device_count()
+    rank = int(os.environ['LOCAL_RANK'])
+
+    @staticmethod
+    def initialize_distributed():
+        print(f'Initializing torch.distributed with rank: {Utils.rank}, world_size: {Utils.world_size}')
+        torch.cuda.set_device(Utils.rank % torch.cuda.device_count())
+        init_method = 'tcp://'
+        master_ip = os.getenv('MASTER_ADDR', 'localhost')
+        master_port = os.getenv('MASTER_PORT', '6000')
+        init_method += master_ip + ':' + master_port
+        torch.distributed.init_process_group(backend='nccl', world_size=Utils.world_size, rank=Utils.rank, init_method=init_method)
+
+    @staticmethod
+    def destroy_model_parallel():
+        ps.destroy_model_parallel()
+        torch.distributed.barrier()
+
+    @staticmethod
+    def initialize_model_parallel(tensor_model_parallel_size = 1, pipeline_model_parallel_size = 1, virtual_pipeline_model_parallel_size = None, pipeline_model_parallel_split_rank = None):
+        ps.destroy_model_parallel()
+        if not torch.distributed.is_initialized():
+            Utils.initialize_distributed()
+        ps.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank)
\ No newline at end of file
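
Note (not part of the patch): a minimal sketch of how an additional test would plug into tests/test_utilities.py, assuming the same torchrun-launched, 8-GPU environment the updated CI job uses; the file name and assertions below are hypothetical and only illustrate the pattern the new tests follow (initialize model parallel, assert against megatron.core.parallel_state, then tear down).

    # tests/tensor_parallel/test_example.py (hypothetical)
    import megatron.core.parallel_state as ps
    from tests.test_utilities import Utils

    def test_example_parallel_sizes():
        # Utils reads LOCAL_RANK (set by torchrun) and initializes NCCL once per process.
        Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
        assert ps.get_tensor_model_parallel_world_size() == 2
        assert ps.get_pipeline_model_parallel_world_size() == 4
        Utils.destroy_model_parallel()

To run the suite the same way the updated .gitlab-ci.yml does (assuming a node with 8 GPUs, since several tests hard-code tensor/pipeline sizes that multiply to 8):

    torchrun --nproc_per_node=8 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/

Each rank is pointed at its own coverage data file via the new [run] section in .coveragerc (data_file = .coverage_$LOCAL_RANK), so the parallel pytest processes do not overwrite each other's results; the per-rank files are ignored by the new .coverage_* entry in .gitignore.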