forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f861467
commit 8b94a16
Showing
13 changed files
with
330 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,5 @@ | ||
[html] | ||
directory = coverage | ||
directory = coverage | ||
|
||
[run] | ||
data_file = .coverage_$LOCAL_RANK |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
__pycache__ | ||
*.so | ||
build | ||
.coverage_* | ||
*.egg-info |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy | ||
import torch | ||
from tests.test_utilities import Utils | ||
import numpy as np | ||
|
||
def test_vocab_parallel_cross_entropy():
    """Check vocab_parallel_cross_entropy against precomputed per-token losses.

    Every rank passes the same (16, 32) logits block, built by tiling the
    values 0..7, together with the 16 targets 0, 2, ..., 30, and compares the
    result with reference losses given to four decimal places.
    """
    Utils.initialize_model_parallel(4, 2)
    # torch.arange replaces the deprecated torch.range; the explicit float
    # dtype keeps the logits floating point, as torch.range(0, 7) produced.
    vocab_parallel_logits = (
        torch.arange(0, 8, dtype=torch.float32).repeat(16, 4).cuda()
    )
    target = torch.arange(0, 32, 2).cuda()
    output = vocab_parallel_cross_entropy(vocab_parallel_logits, target)
    expected_output = torch.tensor([10.2309, 8.2309, 6.2309, 4.2309] * 4).cuda()
    # Use a tolerance that matches the precision of the reference values.
    # Rounding both sides to integers would accept errors of up to 0.5.
    assert torch.allclose(output, expected_output, atol=1e-3)
    Utils.destroy_model_parallel()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from megatron.core.tensor_parallel.data import broadcast_data | ||
import torch | ||
from tests.test_utilities import Utils | ||
|
||
def test_broadcast_data():
    """Check that broadcast_data returns the tensors stored under keys 0 and 1.

    The input dictionary maps each key k in 0..7 to an 8x8 CUDA tensor filled
    with the value k; only keys 0 and 1 are requested and compared.
    """
    Utils.initialize_model_parallel(2, 4)
    input_data = {
        key: torch.ones((8, 8)).cuda() * float(key) for key in range(8)
    }
    requested_keys = [0, 1]
    dtype = torch.float32
    actual_output = broadcast_data(requested_keys, input_data, dtype)
    assert torch.equal(actual_output[0], input_data[0])
    assert torch.equal(actual_output[1], input_data[1])
    Utils.destroy_model_parallel()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
from megatron.core.tensor_parallel import mappings | ||
from tests.test_utilities import Utils | ||
import torch | ||
|
||
def test_CopyToModelParallelRegion():
    """Check backward, forward wrapper and symbolic of _CopyToModelParallelRegion.

    The backward result is expected to be 6 on ranks 0-3 and 22 on ranks 4-7,
    i.e. the sum of the rank ids in each group of four ranks. The wrapper and
    symbolic calls are expected to return their input unchanged.
    """
    Utils.initialize_model_parallel(4, 2)
    rank_tensor = torch.ones((1)).cuda() * Utils.rank
    grad = mappings._CopyToModelParallelRegion.backward(None, rank_tensor)

    expected_total = 22 if Utils.rank >= 4 else 6
    expected = torch.ones(1).cuda() * expected_total
    assert torch.equal(grad, expected)

    # The statement order is kept as in the original test: the tensor passed
    # to backward above is reused here as-is.
    forwarded = mappings.copy_to_tensor_model_parallel_region(rank_tensor)
    assert torch.equal(rank_tensor, forwarded)
    symbolic_out = mappings._CopyToModelParallelRegion.symbolic(None, rank_tensor)
    assert torch.equal(rank_tensor, symbolic_out)
    Utils.destroy_model_parallel()
|
||
def test_ReduceFromModelParallelRegion():
    """Check symbolic, forward wrapper and backward of _ReduceFromModelParallelRegion.

    The symbolic and wrapper results are expected to be 6 on ranks 0-3 and 22
    on ranks 4-7, i.e. the sum of the rank ids in each group of four ranks.
    Backward is expected to return its input unchanged.
    """
    Utils.initialize_model_parallel(4, 2)
    rank_tensor = torch.ones((1)).cuda() * Utils.rank
    reduced = mappings._ReduceFromModelParallelRegion.symbolic(None, rank_tensor)

    expected_total = 22 if Utils.rank >= 4 else 6
    expected = torch.ones(1).cuda() * expected_total
    assert torch.equal(reduced, expected)

    # A fresh per-rank tensor is built before the second reduction, exactly
    # as the original test does.
    rank_tensor = torch.ones((1)).cuda() * Utils.rank
    wrapper_out = mappings.reduce_from_tensor_model_parallel_region(rank_tensor)
    assert torch.equal(wrapper_out, expected)

    grad = mappings._ReduceFromModelParallelRegion.backward(None, rank_tensor)
    assert torch.equal(rank_tensor, grad)
    Utils.destroy_model_parallel()
|
||
def test_ScatterToModelParallelRegion():
    """Check forward wrapper, symbolic and backward of _ScatterToModelParallelRegion.

    Forward and symbolic are expected to return, as an (8, 1) tensor, the
    single column of the (8, 4) input selected by the rank's position within
    its group of four. Backward is expected to return the per-rank tensors of
    that group concatenated in rank order.
    """
    Utils.initialize_model_parallel(4, 2)
    full_input = torch.rand((8, 4)).cuda()
    scattered = mappings.scatter_to_tensor_model_parallel_region(full_input)
    req_dim = int(Utils.rank % (Utils.world_size / 2))
    assert torch.equal(scattered, full_input[:, req_dim].reshape((8, 1)))
    scattered = mappings._ScatterToModelParallelRegion.symbolic(None, full_input)
    assert torch.equal(scattered, full_input[:, req_dim].reshape((8, 1)))

    rank_tensor = torch.ones(8).cuda() * Utils.rank
    actual_output_data = mappings._ScatterToModelParallelRegion.backward(
        None, rank_tensor
    )
    # Blocks of eight values 0, 1, 2, 3; ranks 4-7 see 4, 5, 6, 7 instead.
    expected_output = torch.cat(
        [torch.ones(8) * block for block in range(4)]
    ).cuda()
    if Utils.rank >= 4:
        expected_output = expected_output + 4
    assert torch.equal(actual_output_data, expected_output)
    Utils.destroy_model_parallel()
|
||
def test_GatherFromModelParallelRegion():
    """Check backward, forward wrapper and symbolic of _GatherFromModelParallelRegion.

    Backward is expected to return, as an (8, 1) tensor, the column of the
    (8, 4) input selected by the rank's position within its group of four.
    The wrapper and symbolic calls are expected to return the per-rank
    tensors of that group concatenated in rank order.
    """
    Utils.initialize_model_parallel(4, 2)
    full_input = torch.rand((8, 4)).cuda()
    req_dim = int(Utils.rank % (Utils.world_size / 2))
    grad = mappings._GatherFromModelParallelRegion.backward(None, full_input)
    assert torch.equal(grad, full_input[:, req_dim].reshape((8, 1)))

    rank_tensor = torch.ones(8).cuda() * Utils.rank
    actual_output_data = mappings.gather_from_tensor_model_parallel_region(
        rank_tensor
    )
    # Blocks of eight values 0, 1, 2, 3; ranks 4-7 see 4, 5, 6, 7 instead.
    expected_output = torch.cat(
        [torch.ones(8) * block for block in range(4)]
    ).cuda()
    if Utils.rank >= 4:
        expected_output = expected_output + 4
    assert torch.equal(actual_output_data, expected_output)

    symbolic_out = mappings._GatherFromModelParallelRegion.symbolic(
        None, rank_tensor
    )
    assert torch.equal(symbolic_out, expected_output)
    Utils.destroy_model_parallel()
|
||
def test_ScatterToSequenceParallelRegion():
    """Check symbolic, forward wrapper and backward of _ScatterToSequenceParallelRegion.

    Symbolic and the wrapper are expected to return the two consecutive rows
    of the (8, 4) input that start at twice the rank's position within its
    group of four. Backward is expected to return the per-rank tensors of
    that group concatenated in rank order.
    """
    Utils.initialize_model_parallel(4, 2)
    input_data = torch.rand((8, 4)).cuda()
    req_dim = int(Utils.rank % (Utils.world_size / 2)) * 2
    output_data = mappings._ScatterToSequenceParallelRegion.symbolic(None, input_data)
    assert torch.equal(output_data, input_data[req_dim:req_dim + 2, :])
    output_data = mappings.scatter_to_sequence_parallel_region(input_data)
    assert torch.equal(output_data, input_data[req_dim:req_dim + 2, :])

    input_data = torch.ones(4).cuda() * Utils.rank
    # Exercise the sequence-parallel backward. The test previously called
    # _ScatterToModelParallelRegion.backward here, so the class this test is
    # named after never had its backward path checked.
    output_data = mappings._ScatterToSequenceParallelRegion.backward(None, input_data)
    # Blocks of four values 0, 1, 2, 3; ranks 4-7 see 4, 5, 6, 7 instead.
    expected_output = torch.concat(
        [torch.ones(4) * block for block in range(4)]
    ).cuda()
    if Utils.rank >= 4:
        expected_output = expected_output + 4
    assert torch.equal(output_data, expected_output)
    Utils.destroy_model_parallel()
|
||
def test_GatherFromSequenceParallelRegion():
    """Check forward wrapper, symbolic and backward of _GatherFromSequenceParallelRegion.

    The wrapper and symbolic calls are expected to return the per-rank
    tensors of a four-rank group concatenated in rank order. Backward is
    called with a stub context whose tensor_parallel_output_grad is True;
    the first element of its result is expected to be a (1, 4) tensor filled
    with four times the rank's position within its group.
    """
    Utils.initialize_model_parallel(4, 2)
    rank_tensor = torch.ones(4).cuda() * Utils.rank
    gathered = mappings.gather_from_sequence_parallel_region(rank_tensor)
    # Blocks of four values 0, 1, 2, 3; ranks 4-7 see 4, 5, 6, 7 instead.
    expected_output = torch.concat(
        [torch.ones(4) * block for block in range(4)]
    ).cuda()
    if Utils.rank >= 4:
        expected_output = expected_output + 4
    assert torch.equal(gathered, expected_output)
    symbolic_out = mappings._GatherFromSequenceParallelRegion.symbolic(
        None, rank_tensor
    )
    assert torch.equal(symbolic_out, expected_output)

    # Row i of the (4, 4) backward input is filled with the value i.
    stacked_rows = torch.vstack(
        [torch.ones(4) * row for row in range(4)]
    ).cuda()

    class _StubCtx:
        # Minimal stand-in for the autograd context passed to backward.
        tensor_parallel_output_grad = True

    grads = mappings._GatherFromSequenceParallelRegion.backward(
        _StubCtx(), stacked_rows
    )
    expected_grad = torch.ones((1, 4)).cuda() * 4 * int(Utils.rank % 4)
    assert torch.equal(grads[0], expected_grad)
    Utils.destroy_model_parallel()
|
||
def test_ReduceScatterToSequenceParallelRegion():
    """Check forward wrapper, symbolic and backward of _ReduceScatterToSequenceParallelRegion.

    For a (4, 4) input whose row i is filled with i, the wrapper and symbolic
    calls are expected to return a single row filled with four times the
    rank's position within its group of four. Backward is expected to return
    the per-rank tensors of that group concatenated in rank order.
    """
    Utils.initialize_model_parallel(4, 2)
    stacked_rows = torch.vstack(
        [torch.ones(4) * row for row in range(4)]
    ).cuda()
    scattered = mappings.reduce_scatter_to_sequence_parallel_region(stacked_rows)
    expected_row = torch.ones(4).cuda() * 4 * int(Utils.rank % 4)
    assert torch.equal(scattered[0], expected_row)
    symbolic_out = mappings._ReduceScatterToSequenceParallelRegion.symbolic(
        None, stacked_rows
    )
    assert torch.equal(symbolic_out, expected_row.reshape((1, 4)))

    rank_tensor = torch.ones(4).cuda() * Utils.rank
    grad = mappings._ReduceScatterToSequenceParallelRegion.backward(
        None, rank_tensor
    )
    # Blocks of four values 0, 1, 2, 3; ranks 4-7 see 4, 5, 6, 7 instead.
    expected_grad = torch.concat(
        [torch.ones(4) * block for block in range(4)]
    ).cuda()
    if Utils.rank >= 4:
        expected_grad = expected_grad + 4
    assert torch.equal(grad, expected_grad)
    Utils.destroy_model_parallel()
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from megatron.core.tensor_parallel.random import CudaRNGStatesTracker | ||
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed | ||
from megatron.core.tensor_parallel.random import _CUDA_RNG_STATE_TRACKER | ||
from megatron.core.tensor_parallel.random import checkpoint | ||
from tests.test_utilities import Utils | ||
import pytest | ||
import torch | ||
|
||
def test_cuda_rng_states_tracker():
    """Exercise set_states, get_states, reset and add on CudaRNGStatesTracker.

    Verifies that states set explicitly can be read back, that reset() leaves
    an empty state dict, that add() raises when a seed or a name is reused,
    and that the state stored for a seed equals the CUDA RNG state obtained
    right after torch.cuda.manual_seed with the same seed.
    """
    rng_tracker = CudaRNGStatesTracker()
    rng_tracker.set_states({"state1": 1234})
    assert rng_tracker.get_states()["state1"] == 1234
    rng_tracker.reset()
    assert rng_tracker.get_states() == {}

    seed = 1111
    rng_tracker.add("state2", seed)
    # Call add() directly inside pytest.raises. Wrapping the call in assert
    # raises AssertionError whenever the return value is falsy (presumably
    # None), which satisfies pytest.raises(Exception) even if add() itself
    # never raises.
    with pytest.raises(Exception):
        rng_tracker.add("state3", seed)  # new name, seed already used above
    with pytest.raises(Exception):
        rng_tracker.add("state2", 111)  # new seed, name already used above
    assert rng_tracker.get_states()['state2'] is not None

    # NOTE(review): the object returned by fork() is discarded and never
    # entered with a `with` statement. If fork() is a context manager, this
    # call has no effect. Confirm the intent.
    rng_tracker.fork("state2")
    torch.cuda.manual_seed(seed)
    rng_state = torch.cuda.get_rng_state()
    assert torch.equal(rng_tracker.get_states()['state2'], rng_state)
|
||
def test_model_parallel_cuda_manual_seed():
    """Check that seeding registers a 'model-parallel-rng' state in the global tracker."""
    Utils.initialize_model_parallel(4, 2)
    model_parallel_cuda_manual_seed(0)
    tracked_states = _CUDA_RNG_STATE_TRACKER.get_states()
    assert tracked_states['model-parallel-rng'] is not None
    Utils.destroy_model_parallel()
|
||
def test_checkpoint():
    """Check checkpoint() with its second argument set to None and to True.

    With None, the returned value must equal the sum of the two inputs. With
    True (after model-parallel initialization), the test asserts that the
    first input afterwards equals a flat CUDA tensor of ones with the same
    number of elements.
    """
    def _sum_first_two(*tensors):
        # Forward function handed to checkpoint(): adds its first two inputs.
        return tensors[0] + tensors[1]

    expected_sum = torch.ones(16) * 3
    plain_result = checkpoint(
        _sum_first_two, None, torch.ones(16), torch.ones(16) * 2
    )
    assert torch.equal(expected_sum, plain_result)

    Utils.initialize_model_parallel()
    input1 = torch.ones((4, 4))
    checkpoint(_sum_first_two, True, input1, torch.ones((4, 4)) * 2)
    flat_ones = torch.ones(input1.numel()).cuda()
    assert torch.equal(flat_ones, input1)
    Utils.destroy_model_parallel()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,43 @@ | ||
import torch | ||
import megatron.core.tensor_parallel.utils as util | ||
import megatron.core.parallel_state as ps | ||
from tests.test_utilities import Utils | ||
|
||
rank = Utils.rank | ||
|
||
def test_split_tensor_along_last_dim():
    """Check that a (3, 4) tensor is split into two (3, 2) column blocks.

    The comparisons are asserted; previously the torch.equal results were
    discarded, so the test could never fail. The expected slices keep every
    row and take columns 0-1 and 2-3, which is what a split into two
    partitions along the last dimension should produce.
    """
    input_tensor = torch.rand((3, 4))
    first_half, second_half = util.split_tensor_along_last_dim(input_tensor, 2)
    assert torch.equal(input_tensor[:, 0:2], first_half)
    assert torch.equal(input_tensor[:, 2:], second_half)
|
||
def test_split_tensor_into_1d_equal_chunks():
    """Check that each rank receives its half of the flattened (3, 4) tensor.

    Even ranks are expected to get the first half of the flattened input and
    odd ranks the second half.
    """
    Utils.initialize_model_parallel(
        tensor_model_parallel_size=2, pipeline_model_parallel_size=4
    )
    input_tensor = torch.rand((3, 4))
    output_tensor = util.split_tensor_into_1d_equal_chunks(input_tensor)

    total = input_tensor.numel()
    midpoint = int(total / 2)
    start, end = (0, midpoint) if rank % 2 == 0 else (midpoint, total)

    expected_chunk = input_tensor.flatten()[start:end]
    assert torch.equal(output_tensor, expected_chunk)
    Utils.destroy_model_parallel()
|
||
def test_gather_split_1d_tensor():
    """Check that gather_split_1d_tensor concatenates the tensors of a rank pair.

    Each rank contributes a (2, 4) tensor filled with its rank id. The
    expected result is the flattened tensor of the even rank followed by that
    of the odd rank in the same pair.
    """
    Utils.initialize_model_parallel(
        tensor_model_parallel_size=2, pipeline_model_parallel_size=4
    )
    input_tensor = torch.ones((2, 4)).cuda() * rank
    actual_output_tensor = util.gather_split_1d_tensor(input_tensor)

    own_flat = input_tensor.flatten()
    if rank % 2 == 0:
        # Even rank: its own data comes first, the partner (rank + 1) second.
        pieces = (own_flat, input_tensor.flatten() + 1)
    else:
        # Odd rank: the partner (rank - 1) comes first, its own data second.
        pieces = (input_tensor.flatten() - 1, own_flat)
    expected_output_tensor = torch.concat(pieces)

    assert torch.equal(actual_output_tensor, expected_output_tensor)
    Utils.destroy_model_parallel()
|
||
def test_vocab():
    """Check both VocabUtility range helpers against the expected rank range.

    For a global vocabulary of 1600 entries, the range expected for a rank is
    [rank * per_partition, (rank + 1) * per_partition).
    """
    global_vocab_size = 1600
    # Integer division keeps the partition size, and therefore the expected
    # bounds, integral. It also matches the value passed to the helper below.
    per_partition_vocab_size = global_vocab_size // Utils.world_size
    expected_range = (
        rank * per_partition_vocab_size,
        (rank + 1) * per_partition_vocab_size,
    )
    assert expected_range == util.VocabUtility.vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size, rank, Utils.world_size
    )
    assert expected_range == util.VocabUtility.vocab_range_from_global_vocab_size(
        global_vocab_size, rank, Utils.world_size
    )
|
Oops, something went wrong.