Commit: Adding proper test cases
shanmugamr1992 authored and jaredcasper committed Oct 13, 2022
1 parent f861467 commit 8b94a16
Showing 13 changed files with 330 additions and 108 deletions.
Binary file removed .coverage
5 changes: 4 additions & 1 deletion .coveragerc
@@ -1,2 +1,5 @@
 [html]
-directory = coverage
+directory = coverage
+
+[run]
+data_file = .coverage_$LOCAL_RANK
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
 __pycache__
 *.so
 build
+.coverage_*
 *.egg-info
3 changes: 1 addition & 2 deletions .gitlab-ci.yml
@@ -4,8 +4,7 @@ test:
   tags:
     - docker_gpu_enabled
   script:
-    - nvidia-smi
-    - torchrun --nproc_per_node=2 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
+    - torchrun --nproc_per_node=8 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
   coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
   artifacts:
     paths:
2 changes: 2 additions & 0 deletions megatron/core/tensor_parallel/random.py
@@ -22,6 +22,8 @@
     gather_split_1d_tensor,
 )
 
+from megatron.core.utils import safely_set_viewless_tensor_data
+
 # Default name for the model parallel rng tracker.
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'

Empty file added tests/__init__.py
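Note: the test files below import a Utils helper from tests/test_utilities.py, which is part of this commit but not expanded in this view. The following is a minimal sketch of what such a helper plausibly looks like, assuming it wraps torch.distributed setup and megatron.core.parallel_state; the names and body here are illustrative only, not the commit's actual file.

# Illustrative sketch only -- the real tests/test_utilities.py in this commit may differ.
import os
import torch
import megatron.core.parallel_state as ps

class Utils:
    # torchrun exports WORLD_SIZE and LOCAL_RANK for each worker process.
    world_size = int(os.environ.get('WORLD_SIZE', '8'))
    rank = int(os.environ.get('LOCAL_RANK', '0'))

    @staticmethod
    def initialize_distributed():
        # Initialize the NCCL process group once per process.
        if not torch.distributed.is_initialized():
            torch.cuda.set_device(Utils.rank % torch.cuda.device_count())
            torch.distributed.init_process_group(
                backend='nccl', world_size=Utils.world_size, rank=Utils.rank)

    @staticmethod
    def initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1):
        Utils.initialize_distributed()
        ps.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size)

    @staticmethod
    def destroy_model_parallel():
        # Tear down model-parallel groups so the next test can re-initialize them.
        ps.destroy_model_parallel()
        torch.distributed.barrier()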
14 changes: 14 additions & 0 deletions tests/tensor_parallel/test_cross_entropy.py
@@ -0,0 +1,14 @@
from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
import torch
from tests.test_utilities import Utils
import numpy as np

def test_vocab_parallel_cross_entropy():
    Utils.initialize_model_parallel(4,2)
    vocab_parallel_logits = torch.range(0,7).repeat(16,4).cuda()
    target = torch.arange(0,32,2).cuda()
    output = vocab_parallel_cross_entropy(vocab_parallel_logits, target)
    expected_output = torch.tensor([10.2309, 8.2309, 6.2309, 4.2309, 10.2309, 8.2309, 6.2309, 4.2309,
        10.2309, 8.2309, 6.2309, 4.2309, 10.2309, 8.2309, 6.2309, 4.2309]).cuda()
    assert(torch.equal(torch.round(expected_output), torch.round(output)))
    Utils.destroy_model_parallel()
21 changes: 21 additions & 0 deletions tests/tensor_parallel/test_data.py
@@ -0,0 +1,21 @@
from megatron.core.tensor_parallel.data import broadcast_data
import torch
from tests.test_utilities import Utils

def test_broadcast_data():
    Utils.initialize_model_parallel(2,4)
    input_data = {
        0 : torch.ones((8,8)).cuda() * 0.0,
        1 : torch.ones((8,8)).cuda() * 1.0,
        2 : torch.ones((8,8)).cuda() * 2.0,
        3 : torch.ones((8,8)).cuda() * 3.0,
        4 : torch.ones((8,8)).cuda() * 4.0,
        5 : torch.ones((8,8)).cuda() * 5.0,
        6 : torch.ones((8,8)).cuda() * 6.0,
        7 : torch.ones((8,8)).cuda() * 7.0
    }
    dtype = torch.float32
    actual_output = broadcast_data([0,1], input_data, dtype)
    assert(torch.equal(actual_output[0], input_data[0]))
    assert(torch.equal(actual_output[1], input_data[1]))
    Utils.destroy_model_parallel()
135 changes: 135 additions & 0 deletions tests/tensor_parallel/test_mappings.py
@@ -0,0 +1,135 @@
from megatron.core.tensor_parallel import mappings
from tests.test_utilities import Utils
import torch

def test_CopyToModelParallelRegion():
    Utils.initialize_model_parallel(4,2)
    input_data = torch.ones((1)).cuda()*Utils.rank
    output_data = mappings._CopyToModelParallelRegion.backward(None, input_data)
    result = torch.ones(1).cuda()
    result = result * 22 if Utils.rank >= 4 else result * 6
    assert(torch.equal(output_data, result))
    assert(torch.equal(input_data, mappings.copy_to_tensor_model_parallel_region(input_data)))
    assert(torch.equal(input_data, mappings._CopyToModelParallelRegion.symbolic(None, input_data)))
    Utils.destroy_model_parallel()

def test_ReduceFromModelParallelRegion():
    Utils.initialize_model_parallel(4,2)
    input_data = torch.ones((1)).cuda()*Utils.rank
    output_data = mappings._ReduceFromModelParallelRegion.symbolic(None, input_data)
    result = torch.ones(1).cuda()
    result = result * 22 if Utils.rank >= 4 else result * 6
    assert(torch.equal(output_data, result))
    input_data = torch.ones((1)).cuda()*Utils.rank
    assert(torch.equal(mappings.reduce_from_tensor_model_parallel_region(input_data), result))
    assert(torch.equal(input_data, mappings._ReduceFromModelParallelRegion.backward(None, input_data)))
    Utils.destroy_model_parallel()

def test_ScatterToModelParallelRegion():
    Utils.initialize_model_parallel(4,2)
    input_data = torch.rand((8,4)).cuda()
    output_data = mappings.scatter_to_tensor_model_parallel_region(input_data)
    req_dim = int(Utils.rank%(Utils.world_size/2))
    assert(torch.equal(output_data, input_data[:,req_dim].reshape((8,1))))
    output_data = mappings._ScatterToModelParallelRegion.symbolic(None, input_data)
    assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1))))

    input_data = torch.ones(8).cuda() * Utils.rank
    actual_output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data)
    expected_output = torch.cat((
        torch.ones(8)*0,
        torch.ones(8)*1,
        torch.ones(8)*2,
        torch.ones(8)*3)).cuda()
    if (Utils.rank >= 4):
        expected_output = expected_output + 4
    assert(torch.equal(actual_output_data, expected_output))
    Utils.destroy_model_parallel()

def test_GatherFromModelParallelRegion():
    Utils.initialize_model_parallel(4,2)
    input_data = torch.rand((8,4)).cuda()
    req_dim = int(Utils.rank%(Utils.world_size/2))
    output_data = mappings._GatherFromModelParallelRegion.backward(None, input_data)
    assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1))))
    input_data = torch.ones(8).cuda() * Utils.rank
    actual_output_data = mappings.gather_from_tensor_model_parallel_region(input_data)
    expected_output = torch.cat((
        torch.ones(8)*0,
        torch.ones(8)*1,
        torch.ones(8)*2,
        torch.ones(8)*3)).cuda()
    if (Utils.rank >= 4):
        expected_output = expected_output + 4
    assert(torch.equal(actual_output_data, expected_output))
    assert(torch.equal(mappings._GatherFromModelParallelRegion.symbolic(None, input_data), expected_output))
    Utils.destroy_model_parallel()

def test_ScatterToSequenceParallelRegion():
    Utils.initialize_model_parallel(4,2)
    input_data = torch.rand((8,4)).cuda()
    req_dim = int(Utils.rank%(Utils.world_size/2))*2
    output_data = mappings._ScatterToSequenceParallelRegion.symbolic(None, input_data)
    assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :]))
    output_data = mappings.scatter_to_sequence_parallel_region(input_data)
    assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :]))
    input_data = torch.ones(4).cuda() * Utils.rank
    output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data)
    expected_output = torch.concat((
        torch.ones(4)*0,
        torch.ones(4)*1,
        torch.ones(4)*2,
        torch.ones(4)*3)).cuda()
    if (Utils.rank >= 4):
        expected_output = expected_output + 4
    assert(torch.equal(output_data, expected_output))
    Utils.destroy_model_parallel()

def test_GatherFromSequenceParallelRegion():
    Utils.initialize_model_parallel(4,2)
    input_data = torch.ones(4).cuda() * Utils.rank
    output_data = mappings.gather_from_sequence_parallel_region(input_data)
    expected_output = torch.concat((
        torch.ones(4)*0,
        torch.ones(4)*1,
        torch.ones(4)*2,
        torch.ones(4)*3)).cuda()
    if (Utils.rank >= 4):
        expected_output = expected_output + 4
    assert(torch.equal(output_data, expected_output))
    assert(torch.equal(mappings._GatherFromSequenceParallelRegion.symbolic(None, input_data), expected_output))
    input_data = torch.vstack((
        torch.ones(4)*0,
        torch.ones(4)*1,
        torch.ones(4)*2,
        torch.ones(4)*3)).cuda()
    class Ctx:
        tensor_parallel_output_grad = True
    output_data = mappings._GatherFromSequenceParallelRegion.backward(Ctx(), input_data)
    expected_output = torch.ones((1,4)).cuda() * 4 * int(Utils.rank % 4)
    assert(torch.equal(output_data[0], expected_output))
    Utils.destroy_model_parallel()

def test_ReduceScatterToSequenceParallelRegion():
    Utils.initialize_model_parallel(4,2)
    input_data = torch.vstack((
        torch.ones(4)*0,
        torch.ones(4)*1,
        torch.ones(4)*2,
        torch.ones(4)*3)).cuda()
    output_data = mappings.reduce_scatter_to_sequence_parallel_region(input_data)
    expected_output = torch.ones(4).cuda() * 4 * int(Utils.rank % 4)
    assert(torch.equal(output_data[0], expected_output))
    assert(torch.equal(mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, input_data), expected_output.reshape((1,4))))
    input_data = torch.ones(4).cuda() * Utils.rank
    output_data = mappings._ReduceScatterToSequenceParallelRegion.backward(None, input_data)
    expected_output = torch.concat((
        torch.ones(4)*0,
        torch.ones(4)*1,
        torch.ones(4)*2,
        torch.ones(4)*3)).cuda()
    if (Utils.rank >= 4):
        expected_output = expected_output + 4
    assert(torch.equal(output_data, expected_output))
    Utils.destroy_model_parallel()

44 changes: 44 additions & 0 deletions tests/tensor_parallel/test_random.py
@@ -0,0 +1,44 @@
from megatron.core.tensor_parallel.random import CudaRNGStatesTracker
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.tensor_parallel.random import _CUDA_RNG_STATE_TRACKER
from megatron.core.tensor_parallel.random import checkpoint
from tests.test_utilities import Utils
import pytest
import torch

def test_cuda_rng_states_tracker():
    rng_tracker = CudaRNGStatesTracker()
    rng_tracker.set_states({"state1":1234})
    assert(rng_tracker.get_states()["state1"] == 1234)
    rng_tracker.reset()
    assert(rng_tracker.get_states() == {})
    seed = 1111
    rng_tracker.add("state2",seed)
    with pytest.raises(Exception):
        assert(rng_tracker.add("state3",seed))
    with pytest.raises(Exception):
        assert(rng_tracker.add("state2",111))
    assert(rng_tracker.get_states()['state2'] is not None)
    with pytest.raises(Exception):
        assert()

    rng_tracker.fork("state2")
    torch.cuda.manual_seed(seed)
    rng_state = torch.cuda.get_rng_state()
    assert torch.equal(rng_tracker.get_states()['state2'], rng_state)

def test_model_parallel_cuda_manual_seed():
    Utils.initialize_model_parallel(4,2)
    model_parallel_cuda_manual_seed(0)
    assert(_CUDA_RNG_STATE_TRACKER.get_states()['model-parallel-rng'] is not None)
    Utils.destroy_model_parallel()

def test_checkpoint():
    def test_forward(*input):
        return input[0]+input[1]
    assert(torch.equal(torch.ones(16)*3, checkpoint(test_forward, None, torch.ones(16), torch.ones(16)*2)))
    Utils.initialize_model_parallel()
    input1 = torch.ones((4,4))
    checkpoint(test_forward, True, input1, torch.ones((4,4))*2)
    assert(torch.equal(torch.ones(input1.numel()).cuda(), input1))
    Utils.destroy_model_parallel()
36 changes: 36 additions & 0 deletions tests/tensor_parallel/test_tensor_parallel_utils.py
@@ -1,7 +1,43 @@
import torch
import megatron.core.tensor_parallel.utils as util
import megatron.core.parallel_state as ps
from tests.test_utilities import Utils

rank = Utils.rank

def test_split_tensor_along_last_dim():
    input_tensor = torch.rand((3,4))
    torch.equal(input_tensor[0:2,0:2], util.split_tensor_along_last_dim(input_tensor,2)[0])
    torch.equal(input_tensor[2:,2:], util.split_tensor_along_last_dim(input_tensor,2)[1])

def test_split_tensor_into_1d_equal_chunks():
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    input_tensor = torch.rand((3,4))
    output_tensor = util.split_tensor_into_1d_equal_chunks(input_tensor)
    if rank % 2 == 0 :
        start = 0
        end = int(input_tensor.numel()/2)
    else :
        start = int(input_tensor.numel()/2)
        end = input_tensor.numel()

    assert torch.equal(output_tensor, input_tensor.flatten()[start:end])
    Utils.destroy_model_parallel()

def test_gather_split_1d_tensor():
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    input_tensor = torch.ones((2,4)).cuda() * rank
    actual_output_tensor = util.gather_split_1d_tensor(input_tensor)
    if rank % 2 == 0:
        expected_output_tensor = torch.concat((input_tensor.flatten(), input_tensor.flatten() + 1))
    else :
        expected_output_tensor = torch.concat((input_tensor.flatten() - 1, input_tensor.flatten()))
    assert(torch.equal(actual_output_tensor, expected_output_tensor))
    Utils.destroy_model_parallel()

def test_vocab():
    global_vocab_size = 1600
    per_partition_vocab_size = 1600 / Utils.world_size
    assert((rank * per_partition_vocab_size, (rank + 1) * per_partition_vocab_size) == (util.VocabUtility.vocab_range_from_per_partition_vocab_size(global_vocab_size // Utils.world_size, rank, Utils.world_size)))
    assert((rank * per_partition_vocab_size, (rank + 1) * per_partition_vocab_size) == (util.VocabUtility.vocab_range_from_global_vocab_size(global_vocab_size, rank, Utils.world_size)))
