From 9b42cdb16c32d18c2116d589e2936e6398f247dd Mon Sep 17 00:00:00 2001 From: Lev Kurilenko <113481193+lekurile@users.noreply.github.com> Date: Fri, 28 Jul 2023 15:24:59 -0700 Subject: [PATCH] Fix Megatron Text Generation Issues (#188) * MPU import, local_rank, and config fixes * Change _transpose_first_dim function to source attn module using self_attention * Add p2p communication file * Update generate_text.sh to export CUDA_DEVICE_MAX_CONNECTIONS, disable DS inference for now * Update logits indexing in generation code to account for MoE * Properly provide config arg to GPTModel in text gen server code * Update all GPTModel initialization to use config * Explicit config argument assignment * Import mpu from megatron.core instead of megatron --- examples/detxoify_lm/finetune_gpt.py | 4 + examples/detxoify_lm/generate_samples_gpt.py | 5 +- examples_deepspeed/generate_text.sh | 12 +- megatron/checkpointing.py | 4 +- megatron/p2p_communication.py | 264 ++++++++++++ megatron/text_generation/generation.py | 35 +- megatron/text_generation_utils.py | 2 +- pretrain_gpt.py | 4 +- tasks/eval_harness/evaluate.py | 4 +- tasks/msdp/prompt.py | 24 +- tasks/zeroshot_gpt/evaluate.py | 423 ++++++++++--------- tests/run_megatron.py | 7 +- tools/generate_samples_gpt.py | 12 +- tools/run_text_generation_server.py | 5 +- 14 files changed, 551 insertions(+), 254 deletions(-) create mode 100644 megatron/p2p_communication.py diff --git a/examples/detxoify_lm/finetune_gpt.py b/examples/detxoify_lm/finetune_gpt.py index 70b781e0..70a91482 100644 --- a/examples/detxoify_lm/finetune_gpt.py +++ b/examples/detxoify_lm/finetune_gpt.py @@ -18,6 +18,7 @@ from megatron.data.blendable_dataset import BlendableDataset from megatron.data.gpt_dataset import build_train_valid_test_datasets from megatron.model import GPTModel +from megatron.arguments import core_transformer_config_from_args from megatron.core.enums import ModelType from megatron.training import pretrain from megatron.utils import get_ltor_masks_and_position_ids @@ -26,8 +27,11 @@ def model_provider(pre_process=True, post_process=True): """Build the model.""" + config = core_transformer_config_from_args(args) + print_rank_0('building GPT model ...') model = GPTModel( + config=config, num_tokentypes=0, parallel_output=True, pre_process=pre_process, diff --git a/examples/detxoify_lm/generate_samples_gpt.py b/examples/detxoify_lm/generate_samples_gpt.py index 47e1590e..c3c093a4 100644 --- a/examples/detxoify_lm/generate_samples_gpt.py +++ b/examples/detxoify_lm/generate_samples_gpt.py @@ -17,14 +17,17 @@ from megatron.initialize import initialize_megatron from megatron.model import GPTModel from megatron.training import get_model +from megatron.arguments import core_transformer_config_from_args from megatron.text_generation import generate_and_post_process def model_provider(pre_process=True, post_process=True): """Build the model.""" + config = core_transformer_config_from_args(args) + print_rank_0('building GPT model ...') - model = GPTModel(num_tokentypes=0, parallel_output=False, + model = GPTModel(config=config, num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process) return model diff --git a/examples_deepspeed/generate_text.sh b/examples_deepspeed/generate_text.sh index 7e330ab8..e29d521e 100755 --- a/examples_deepspeed/generate_text.sh +++ b/examples_deepspeed/generate_text.sh @@ -1,8 +1,8 @@ #!/bin/bash export TORCH_CUDA_ARCH_LIST=8.6+PTX -CHECKPOINT_PATH=checkpoints/gpt2_345m -VOCAB_FILE=gpt2-vocab.json 
-MERGE_FILE=gpt2-merges.txt +CHECKPOINT_PATH=dataset/checkpoints/gpt2_345m +VOCAB_FILE=dataset/gpt2-vocab.json +MERGE_FILE=dataset/gpt2-merges.txt b=8 mp=1 experts=1 @@ -14,8 +14,10 @@ use_tutel="" #use_tutel="--use-tutel" -#ds_inference="" -ds_inference="--ds-inference" +ds_inference="" +#ds_inference="--ds-inference" + +export CUDA_DEVICE_MAX_CONNECTIONS=1 launch_cmd="deepspeed --num_nodes $nodes --num_gpus $gpus" L=24 diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 49bc84a8..5070bdd5 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -320,8 +320,8 @@ def _transpose_first_dim(t, num_splits, num_splits_first, model): # specific to self attention so should work for cross attention as well while hasattr(model, 'module'): model = model.module - #attention_module = model.language_model.encoder.layers[0].self_attention - attention_module = model.language_model.encoder.layers[0].attention + attention_module = model.language_model.encoder.layers[0].self_attention + #attention_module = model.language_model.encoder.layers[0].attention hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition if num_splits_first: diff --git a/megatron/p2p_communication.py b/megatron/p2p_communication.py new file mode 100644 index 00000000..770060a8 --- /dev/null +++ b/megatron/p2p_communication.py @@ -0,0 +1,264 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import reduce +import operator +import torch +from deepspeed.accelerator import get_accelerator +from megatron import get_args +from megatron.core import mpu + + +def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, + use_ring_exchange=False): + """Communicate tensors between stages. Used as helper method in other + communication methods that are used in megatron/schedules.py. + + Takes the following arguments: + tensor_send_next: tensor to send to next rank (no tensor sent if + set to None). + tensor_send_prev: tensor to send to prev rank (no tensor sent if + set to None). + recv_prev: boolean for whether tensor should be received from + previous rank. + recv_next: boolean for whether tensor should be received from + next rank. + use_ring_exchange: boolean for whether torch.distributed.ring_exchange() + API should be used. + + Returns: + (tensor_recv_prev, tensor_recv_next) + """ + args = get_args() + + # Create placeholder tensors for receive in forward and backward directions + # if needed. 
+ tensor_recv_prev = None + tensor_recv_next = None + tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size) + if args.scatter_gather_tensors_in_pipeline: + tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \ + mpu.get_tensor_model_parallel_world_size() + else: + tensor_chunk_shape = tensor_shape + dtype = args.params_dtype + if args.fp32_residual_connection: + dtype = torch.float + if recv_prev: + tensor_recv_prev = torch.empty(tensor_chunk_shape, + requires_grad=True, + device=get_accelerator().current_device_name(), + dtype=dtype) + if recv_next: + tensor_recv_next = torch.empty(tensor_chunk_shape, + requires_grad=True, + device=get_accelerator().current_device_name(), + dtype=dtype) + + # Split tensor into smaller chunks if using scatter-gather optimization. + if args.scatter_gather_tensors_in_pipeline: + if tensor_send_next is not None: + tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next) + + if tensor_send_prev is not None: + tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev) + + # Send tensors in both the forward and backward directions as appropriate. + if use_ring_exchange: + torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev, + tensor_recv_prev=tensor_recv_prev, + tensor_send_next=tensor_send_next, + tensor_recv_next=tensor_recv_next, + group=mpu.get_pipeline_model_parallel_group()) + else: + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_prev, + mpu.get_pipeline_model_parallel_prev_rank()) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor_recv_prev, + mpu.get_pipeline_model_parallel_prev_rank()) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_next, + mpu.get_pipeline_model_parallel_next_rank()) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor_recv_next, + mpu.get_pipeline_model_parallel_next_rank()) + ops.append(recv_next_op) + if len(ops) > 0: + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + get_accelerator().synchronize() + + # If using scatter-gather optimization, gather smaller chunks. 
+ if args.scatter_gather_tensors_in_pipeline: + if recv_prev: + tensor_recv_prev = mpu.gather_split_1d_tensor( + tensor_recv_prev).view(tensor_shape).requires_grad_() + + if recv_next: + tensor_recv_next = mpu.gather_split_1d_tensor( + tensor_recv_next).view(tensor_shape).requires_grad_() + + return tensor_recv_prev, tensor_recv_next + + +def recv_forward(timers=None): + """Receive tensor from previous rank in pipeline (forward receive).""" + if mpu.is_pipeline_first_stage(): + input_tensor = None + else: + if timers is not None: + timers('forward-recv').start() + input_tensor, _ = _communicate( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=True, + recv_next=False) + if timers is not None: + timers('forward-recv').stop() + return input_tensor + + +def recv_backward(timers=None): + """Receive tensor from next rank in pipeline (backward receive).""" + if mpu.is_pipeline_last_stage(): + output_tensor_grad = None + else: + if timers is not None: + timers('backward-recv').start() + _, output_tensor_grad = _communicate( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=False, + recv_next=True) + if timers is not None: + timers('backward-recv').stop() + return output_tensor_grad + + +def send_forward(output_tensor, timers=None): + """Send tensor to next rank in pipeline (forward send).""" + if not mpu.is_pipeline_last_stage(): + if timers is not None: + timers('forward-send').start() + _communicate( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=False) + if timers is not None: + timers('forward-send').stop() + + +def send_backward(input_tensor_grad, timers=None): + """Send tensor to previous rank in pipeline (backward send).""" + if not mpu.is_pipeline_first_stage(): + if timers is not None: + timers('backward-send').start() + _communicate( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=False, + recv_next=False) + if timers is not None: + timers('backward-send').stop() + + +def send_forward_recv_backward(output_tensor, timers=None): + """Batched send and recv with next rank in pipeline.""" + if mpu.is_pipeline_last_stage(): + output_tensor_grad = None + else: + if timers is not None: + timers('forward-send-backward-recv').start() + _, output_tensor_grad = _communicate( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=True) + if timers is not None: + timers('forward-send-backward-recv').stop() + return output_tensor_grad + + +def send_backward_recv_forward(input_tensor_grad, timers=None): + """Batched send and recv with previous rank in pipeline.""" + if mpu.is_pipeline_first_stage(): + input_tensor = None + else: + if timers is not None: + timers('backward-send-forward-recv').start() + input_tensor, _ = _communicate( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=True, + recv_next=False) + if timers is not None: + timers('backward-send-forward-recv').stop() + return input_tensor + + +def send_forward_recv_forward(output_tensor, recv_prev, timers=None): + """Batched recv from previous rank and send to next rank in pipeline.""" + if timers is not None: + timers('forward-send-forward-recv').start() + input_tensor, _ = _communicate( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=recv_prev, + recv_next=False) + if timers is not None: + timers('forward-send-forward-recv').stop() + return input_tensor + + +def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None): + """Batched recv from next rank 
and send to previous rank in pipeline.""" + if timers is not None: + timers('backward-send-backward-recv').start() + _, output_tensor_grad = _communicate( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=False, + recv_next=recv_next) + if timers is not None: + timers('backward-send-backward-recv').stop() + return output_tensor_grad + + +def send_forward_backward_recv_forward_backward( + output_tensor, input_tensor_grad, recv_prev, + recv_next, timers=None): + """Batched send and recv with previous and next ranks in pipeline.""" + if timers is not None: + timers('forward-backward-send-forward-backward-recv').start() + input_tensor, output_tensor_grad = _communicate( + tensor_send_next=output_tensor, + tensor_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next) + if timers is not None: + timers('forward-backward-send-forward-backward-recv').stop() + return input_tensor, output_tensor_grad diff --git a/megatron/text_generation/generation.py b/megatron/text_generation/generation.py index 098706ee..79c8b4d1 100644 --- a/megatron/text_generation/generation.py +++ b/megatron/text_generation/generation.py @@ -24,7 +24,7 @@ def score_and_return_on_first_stage(model, tokens, lengths): lengths: original prompt length, size: [b] Note: Outside of model, other parameters only need to be available on rank 0. - Outputs: + Outputs: output_log_probs: log probability of the selected tokens. size: [b, s] """ @@ -33,10 +33,10 @@ def score_and_return_on_first_stage(model, tokens, lengths): batch_size = tokens.size(0) max_prompt_length = lengths.max().item() assert max_prompt_length == tokens.size(1) - + if max_prompt_length > args.max_position_embeddings: raise ValueError("Length of prompt + tokens_to_generate longer than allowed") - + if max_prompt_length * batch_size > args.max_tokens_to_oom: raise ValueError("Too many tokens. " + str(max_prompt_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom)) @@ -50,18 +50,18 @@ def score_and_return_on_first_stage(model, tokens, lengths): # Log probability of the sequence (prompt + generated tokens). output_log_probs = None output_log_probs_size = (batch_size, max_prompt_length - 1) - + if mpu.is_pipeline_last_stage(): output_log_probs = torch.empty(output_log_probs_size, dtype=torch.float32, device=torch.cuda.current_device()) - + # ============= # Run infernece # ============= with torch.no_grad(): attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens) - + # logits will be meanigful only in the last pipeline stage. logits = forward_step(tokens, position_ids, attention_mask) @@ -69,20 +69,20 @@ def score_and_return_on_first_stage(model, tokens, lengths): # Always the last stage should have an output. assert logits is not None log_probs = F.log_softmax(logits, dim=2) - + # Pick the tokens that we need to get the log # probabilities for. Note that next input token is # the token which we selected in the current logits, # so shift by 1. indices = torch.unsqueeze(tokens[:, 1:], 2) output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2) - + # ====================================== # Broadcast to the first pipeline stage. 
# ====================================== output_log_probs = broadcast_from_last_to_first_pipeline_stage( output_log_probs_size, torch.float32, output_log_probs) - + return tokens, lengths, output_log_probs def generate_tokens_probs_and_return_on_first_stage( @@ -131,7 +131,7 @@ def generate_tokens_probs_and_return_on_first_stage( if max_sequence_length > args.max_position_embeddings: raise ValueError("Length of prompt + tokens_to_generate longer than allowed") - + if max_sequence_length * batch_size > args.max_tokens_to_oom: raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom)) @@ -162,7 +162,7 @@ def generate_tokens_probs_and_return_on_first_stage( generated_sequence_lengths = torch.ones( batch_size, dtype=torch.int64, device=torch.cuda.current_device()) * max_sequence_length - + # Whether we have reached a termination id. is_generation_done = torch.zeros(batch_size, dtype=torch.uint8, device=torch.cuda.current_device()) @@ -185,6 +185,7 @@ def generate_tokens_probs_and_return_on_first_stage( # logits will be meanigful only in the last pipeline stage. logits = forward_step(tokens2use, positions2use, attention_mask2use) + logits = logits[0] if mpu.is_pipeline_last_stage(): if prevent_newline_after_colon: @@ -248,10 +249,10 @@ def generate_tokens_probs_and_return_on_first_stage( hit_double_eol = (new_sample == 628).byte() & started.byte() hit_eol = (new_sample == 198).byte() & started.byte() done_token = hit_double_eol | hit_eol - else: + else: done_token = (new_sample == termination_id).byte() & \ started.byte() - + just_finished = (done_token & ~is_generation_done).bool() generated_sequence_lengths[just_finished.view(-1)] = \ context_length + 1 @@ -261,7 +262,7 @@ def generate_tokens_probs_and_return_on_first_stage( tensor=done) if use_eod_token_for_early_termination and done: break - + # =================================================== # Update the length of based on max generated length. 
# =================================================== @@ -293,7 +294,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto prompt_length = lengths.item() final_sequence_length = tokens.size(1) final_sequence_length = min(final_sequence_length, args.max_position_embeddings) - + # If the context is too big, this happens if prompt_length >= final_sequence_length: raise ValueError("context length + tokens_to_generate too large") @@ -365,12 +366,12 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length): done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device()) - + best_batches = tokens.new([item[2] for item in next_beams]) tokens = tokens[best_batches,:] tokens[:, context_length] = tokens.new([item[0] for item in next_beams]) scores = scores.new([item[1] for item in next_beams]).unsqueeze(1) - + # torch.distributed.barrier() done = broadcast_from_last_pipeline_stage(1, torch.uint8, done) if done: diff --git a/megatron/text_generation_utils.py b/megatron/text_generation_utils.py index 95c013e4..f5ee09dd 100644 --- a/megatron/text_generation_utils.py +++ b/megatron/text_generation_utils.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from megatron import get_args from megatron import get_tokenizer -from megatron import mpu +from megatron.core import mpu from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model from megatron.p2p_communication import recv_forward, send_forward diff --git a/pretrain_gpt.py b/pretrain_gpt.py index d6389953..e119e7e8 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -43,7 +43,7 @@ def model_provider(pre_process=True, post_process=True): mpu=mpu): if args.deepspeed and not args.no_pipeline_parallel: model = GPTModelPipe( - config, + config=config, num_tokentypes=0, parallel_output=True ) @@ -89,7 +89,7 @@ def model_provider(pre_process=True, post_process=True): else: model = GPTModel( - config, + config=config, num_tokentypes=0, parallel_output=True, pre_process=pre_process, diff --git a/tasks/eval_harness/evaluate.py b/tasks/eval_harness/evaluate.py index 7b692b16..0392bd4f 100644 --- a/tasks/eval_harness/evaluate.py +++ b/tasks/eval_harness/evaluate.py @@ -23,7 +23,7 @@ from megatron import get_args from megatron import print_rank_0 from megatron import get_tokenizer -from megatron import mpu +from megatron.core import mpu from megatron.training import setup_model_and_optimizer, get_model from megatron.mpu.mappings import gather_from_tensor_model_parallel_region @@ -61,7 +61,7 @@ def __init__(self, model, tokenizer): self.is_pipe_parallel = mpu.get_pipeline_model_parallel_world_size() > 1 self.is_data_parallel = mpu.get_data_parallel_world_size() > 1 self.adaptive_seq_len = args.adaptive_seq_len - if self.is_data_parallel and args.moe_expert_parallel_size == 1: # For MoE model, allow a "fake data parallel" in order to partition model into multiple gpus + if self.is_data_parallel and args.moe_expert_parallel_size == 1: # For MoE model, allow a "fake data parallel" in order to partition model into multiple gpus raise NotImplementedError("Data parallelism is currently not supported for evaluation") self.is_last_stage = True if not self.is_pipe_parallel else mpu.is_pipeline_last_stage() # only the last stage of the pipeline model will receive the logits diff --git a/tasks/msdp/prompt.py b/tasks/msdp/prompt.py index a4e777e0..cbfeb4e4 100644 --- a/tasks/msdp/prompt.py +++ 
b/tasks/msdp/prompt.py @@ -12,6 +12,7 @@ from megatron.core import mpu from megatron.model import GPTModel from megatron.training import get_model +from megatron.arguments import core_transformer_config_from_args from megatron.checkpointing import load_checkpoint from megatron.initialize import initialize_megatron from megatron.text_generation import generate_and_post_process @@ -19,7 +20,7 @@ def call_model_api(inputs, tokens_to_generate): """Calling the model api to get the output generations""" - + args = get_args() # The following is an example of using the Megatron API @@ -32,7 +33,7 @@ def call_model_api(inputs, tokens_to_generate): input_len = len(inputs) outputs = outputs[input_len:] outputs = outputs.split("\n")[0].strip() - + return outputs @@ -48,7 +49,7 @@ def read_prompts(prompt_path, prompt_type, n_example): line = line.strip() line_dict = json.loads(line) key = list(line_dict.keys())[0] - + if key not in prompt_examples_dict: prompt_examples = line_dict[key] prompt = "" @@ -83,7 +84,7 @@ def generate_samples_by_calling_api(): # read knowledge generation prompts knwl_gen_prompt_dict = read_prompts( args.prompt_file, args.prompt_type, args.num_prompt_examples) - + else: resp_gen_prompt = read_prompts( args.prompt_file, args.prompt_type, args.num_prompt_examples) @@ -130,7 +131,7 @@ def generate_samples_by_calling_api(): inputs += "We know that: " + knowledge + " " inputs += "System replies:" - # get the output generations from the api, + # get the output generations from the api, # and write to the output file generations = call_model_api(inputs, args.out_seq_length) fname_out.write(generations) @@ -143,8 +144,11 @@ def generate_samples_by_calling_api(): def model_provider(pre_process=True, post_process=True): """Build the model.""" + config = core_transformer_config_from_args(get_args()) + print_rank_0('building GPT model ...') model = GPTModel( + config=config, num_tokentypes=0, parallel_output=True, pre_process=pre_process, @@ -155,7 +159,7 @@ def model_provider(pre_process=True, post_process=True): def generate_samples_by_prompting_input_from_file(model): """Prompt a pretrained language model to generate knowledge/response""" - + # get tokenizer args = get_args() tokenizer = get_tokenizer() @@ -234,7 +238,7 @@ def generate_samples_by_prompting_input_from_file(model): # construct inputs for knowledge generation # then add the constructed inputs into the raw_text raw_text += "( " + last_turn + " ) " + topic + " =>" - + else: # first add the prompt into the raw_text raw_text = prompt @@ -255,7 +259,7 @@ def generate_samples_by_prompting_input_from_file(model): input_pos += 1 raw_text_len = len(raw_text) - + else: raw_text = "EMPTY TEXT" @@ -263,8 +267,8 @@ def generate_samples_by_prompting_input_from_file(model): print_rank_0("input_pos: %d" % input_pos) outputs = generate_and_post_process( - model=model, - prompts=[raw_text], + model=model, + prompts=[raw_text], tokens_to_generate=args.out_seq_length, top_k_sampling=1) prompts_plus_generations = outputs[0] diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index 05e8363e..3f136a26 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -1,210 +1,213 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
- -"""GPT zero-shot evaluation.""" - -import math - -import torch - -from megatron import get_args -from megatron import print_rank_0, is_last_rank -from megatron import get_tokenizer -from megatron.core import parallel_state, tensor_parallel -from megatron.checkpointing import load_checkpoint -from megatron.model import GPTModel -from megatron.training import get_model -from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model -from megatron.p2p_communication import recv_forward, send_forward -from tasks.finetune_utils import build_data_loader -from deepspeed.accelerator import get_accelerator -from .datasets import build_dataset - -# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible? -from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP -from megatron.model import DistributedDataParallel as LocalDDP -from megatron.model import Float16Module - -def get_model_provider(eval_metric): - """Based on evaluation metric set the parallel-output flag and - return the model provider.""" - - def model_provider(pre_process=True, post_process=True): - """Build the model.""" - - if eval_metric == 'loss': - parallel_output = True - elif eval_metric == 'accuracy': - parallel_output = False - else: - raise NotImplementedError('output type for {} evaluation metric ' - 'is not supported.'.format(eval_metric)) - - print_rank_0('building GPT model ...') - model = GPTModel(num_tokentypes=0, parallel_output=parallel_output, - pre_process=pre_process, post_process=post_process) - - return model - - return model_provider - - -def process_batch(batch): - """Process batch and produce inputs for the model.""" - args = get_args() - tokenizer = get_tokenizer() - - loss_mask = batch['pad_mask'].long().to(get_accelerator().device_name()).contiguous().byte() - tokens_ = batch['text'].long().to(get_accelerator().device_name()).contiguous() - labels = tokens_[:, 1:].contiguous() - tokens = tokens_[:, :-1].contiguous() - - # Get the masks and postition ids. - attention_mask, _, position_ids = get_ltor_masks_and_position_ids( - tokens, - tokenizer.eod, - args.reset_position_ids, - args.reset_attention_mask, - args.eod_mask_loss) - - return tokens, labels, attention_mask, position_ids, loss_mask - - -def forward_step(batch, model, eval_metric): - """Forward step.""" - - # Get the batch. - tokens, labels, attention_mask, position_ids, loss_mask = process_batch( - batch) - - # Tell the model what our actual batch size will be - args = get_args() - args.micro_batch_size = len(labels) - - input_tensor = recv_forward() - - # Forward pass through the model. - unwrapped_model = unwrap_model( - model, (torchDDP, LocalDDP, Float16Module)) - unwrapped_model.set_input_tensor(input_tensor) - output = model(tokens, position_ids, attention_mask) - - send_forward(output) - - if parallel_state.is_pipeline_last_stage(): - # For loss, return the unreduced loss. - if eval_metric == 'loss': - losses = tensor_parallel.vocab_parallel_cross_entropy( - output.contiguous().float(), labels.contiguous()) - loss = torch.sum( - losses.view(-1) * loss_mask.contiguous().view(-1).float()) - return loss - - # For accuracy, return the number of correctly predicted samples. 
- if eval_metric == 'accuracy': - outputs = torch.argmax(output, -1) - correct = (outputs == labels).float() - correct[(1 - loss_mask).bool()] = 1 - correct = correct.prod(-1) - return correct.sum() - - raise NotImplementedError('forward method for evaluation metric {} ' - 'is not implemented.'.format(eval_metric)) - return None - - -def evaluate(data_loader, model, eval_metric): - """Evaluation.""" - args = get_args() - - # Turn on evaluation mode which disables dropout. - model.eval() - - total_output = 0.0 - with torch.no_grad(): - # For all the batches in the dataset. - for iteration, batch in enumerate(data_loader): - if iteration % args.log_interval == 0: - print_rank_0('> working on iteration: {}'.format(iteration)) - # Forward evaluation. - output = forward_step(batch, model, eval_metric) - - # Reduce across processes. - if parallel_state.is_pipeline_last_stage(): - torch.distributed.all_reduce(output, - group=parallel_state.get_data_parallel_group()) - - total_output += output - - return total_output - - -def evaluate_and_print_results(task, data_loader, model, eval_metric): - """Evaluate and print results on screen.""" - - # Evaluate and get results. - output = evaluate(data_loader, model, eval_metric) - - string = ' validation results on {} | '.format(task) - if is_last_rank(): - if eval_metric == 'loss': - num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens - num_original_tokens = data_loader.dataset.num_original_tokens - val_loss = output / (num_tokenized_tokens - 1) - ppl = math.exp(min(20, val_loss)) - token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1) - adjusted_ppl = math.exp(min(20, val_loss * token_ratio)) - string += 'avg loss: {:.4E} | '.format(val_loss) - string += 'ppl: {:.4E} | '.format(ppl) - string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl) - string += 'token ratio: {} |'.format(token_ratio) - - elif eval_metric == 'accuracy': - num_examples = len(data_loader.dataset) - acc = output / num_examples - string += 'number correct: {:.4E} | '.format(output) - string += 'total examples: {:.4E} | '.format(num_examples) - string += 'avg accuracy: {:.4E}'.format(acc) - - else: - raise NotImplementedError('evaluation method for {} metric is not ' - 'implemented yet.'.format(eval_metric)) - - length = len(string) + 1 - print('-' * length) - print(string) - print('-' * length) - - -def main(): - """Main program.""" - args = get_args() - - if args.num_layers_per_virtual_pipeline_stage is not None: - print("Interleaved pipeline schedule is not yet supported for text generation.") - exit() - - if args.task == 'LAMBADA': - eval_metric = 'accuracy' - elif args.task == 'WIKITEXT103': - eval_metric = 'loss' - else: - raise NotImplementedError('{} task is not implemented.'.format( - args.task)) - - # Set up model and load checkpoint. - model = get_model(get_model_provider(eval_metric), wrap_with_ddp=False) - if args.load is not None: - _ = load_checkpoint(model, None, None) - - assert len(model) == 1, "Above condition should have caught this" - model = model[0] - - # Data stuff. - dataset = build_dataset(args.task) - dataloader = build_data_loader(dataset, args.micro_batch_size, - args.num_workers, drop_last=False) - - # Run evaluation. - evaluate_and_print_results(args.task, dataloader, model, eval_metric) - - print_rank_0('done :-)') +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+ +"""GPT zero-shot evaluation.""" + +import math + +import torch + +from megatron import get_args +from megatron import print_rank_0, is_last_rank +from megatron import get_tokenizer +from megatron.core import parallel_state, tensor_parallel +from megatron.checkpointing import load_checkpoint +from megatron.model import GPTModel +from megatron.training import get_model +from megatron.arguments import core_transformer_config_from_args +from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model +from megatron.p2p_communication import recv_forward, send_forward +from tasks.finetune_utils import build_data_loader +from deepspeed.accelerator import get_accelerator +from .datasets import build_dataset + +# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible? +from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP +from megatron.model import DistributedDataParallel as LocalDDP +from megatron.model import Float16Module + +def get_model_provider(eval_metric): + """Based on evaluation metric set the parallel-output flag and + return the model provider.""" + + def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + config = core_transformer_config_from_args(get_args()) + + if eval_metric == 'loss': + parallel_output = True + elif eval_metric == 'accuracy': + parallel_output = False + else: + raise NotImplementedError('output type for {} evaluation metric ' + 'is not supported.'.format(eval_metric)) + + print_rank_0('building GPT model ...') + model = GPTModel(config=config, num_tokentypes=0, parallel_output=parallel_output, + pre_process=pre_process, post_process=post_process) + + return model + + return model_provider + + +def process_batch(batch): + """Process batch and produce inputs for the model.""" + args = get_args() + tokenizer = get_tokenizer() + + loss_mask = batch['pad_mask'].long().to(get_accelerator().device_name()).contiguous().byte() + tokens_ = batch['text'].long().to(get_accelerator().device_name()).contiguous() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss) + + return tokens, labels, attention_mask, position_ids, loss_mask + + +def forward_step(batch, model, eval_metric): + """Forward step.""" + + # Get the batch. + tokens, labels, attention_mask, position_ids, loss_mask = process_batch( + batch) + + # Tell the model what our actual batch size will be + args = get_args() + args.micro_batch_size = len(labels) + + input_tensor = recv_forward() + + # Forward pass through the model. + unwrapped_model = unwrap_model( + model, (torchDDP, LocalDDP, Float16Module)) + unwrapped_model.set_input_tensor(input_tensor) + output = model(tokens, position_ids, attention_mask) + + send_forward(output) + + if parallel_state.is_pipeline_last_stage(): + # For loss, return the unreduced loss. + if eval_metric == 'loss': + losses = tensor_parallel.vocab_parallel_cross_entropy( + output.contiguous().float(), labels.contiguous()) + loss = torch.sum( + losses.view(-1) * loss_mask.contiguous().view(-1).float()) + return loss + + # For accuracy, return the number of correctly predicted samples. 
+ if eval_metric == 'accuracy': + outputs = torch.argmax(output, -1) + correct = (outputs == labels).float() + correct[(1 - loss_mask).bool()] = 1 + correct = correct.prod(-1) + return correct.sum() + + raise NotImplementedError('forward method for evaluation metric {} ' + 'is not implemented.'.format(eval_metric)) + return None + + +def evaluate(data_loader, model, eval_metric): + """Evaluation.""" + args = get_args() + + # Turn on evaluation mode which disables dropout. + model.eval() + + total_output = 0.0 + with torch.no_grad(): + # For all the batches in the dataset. + for iteration, batch in enumerate(data_loader): + if iteration % args.log_interval == 0: + print_rank_0('> working on iteration: {}'.format(iteration)) + # Forward evaluation. + output = forward_step(batch, model, eval_metric) + + # Reduce across processes. + if parallel_state.is_pipeline_last_stage(): + torch.distributed.all_reduce(output, + group=parallel_state.get_data_parallel_group()) + + total_output += output + + return total_output + + +def evaluate_and_print_results(task, data_loader, model, eval_metric): + """Evaluate and print results on screen.""" + + # Evaluate and get results. + output = evaluate(data_loader, model, eval_metric) + + string = ' validation results on {} | '.format(task) + if is_last_rank(): + if eval_metric == 'loss': + num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens + num_original_tokens = data_loader.dataset.num_original_tokens + val_loss = output / (num_tokenized_tokens - 1) + ppl = math.exp(min(20, val_loss)) + token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1) + adjusted_ppl = math.exp(min(20, val_loss * token_ratio)) + string += 'avg loss: {:.4E} | '.format(val_loss) + string += 'ppl: {:.4E} | '.format(ppl) + string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl) + string += 'token ratio: {} |'.format(token_ratio) + + elif eval_metric == 'accuracy': + num_examples = len(data_loader.dataset) + acc = output / num_examples + string += 'number correct: {:.4E} | '.format(output) + string += 'total examples: {:.4E} | '.format(num_examples) + string += 'avg accuracy: {:.4E}'.format(acc) + + else: + raise NotImplementedError('evaluation method for {} metric is not ' + 'implemented yet.'.format(eval_metric)) + + length = len(string) + 1 + print('-' * length) + print(string) + print('-' * length) + + +def main(): + """Main program.""" + args = get_args() + + if args.num_layers_per_virtual_pipeline_stage is not None: + print("Interleaved pipeline schedule is not yet supported for text generation.") + exit() + + if args.task == 'LAMBADA': + eval_metric = 'accuracy' + elif args.task == 'WIKITEXT103': + eval_metric = 'loss' + else: + raise NotImplementedError('{} task is not implemented.'.format( + args.task)) + + # Set up model and load checkpoint. + model = get_model(get_model_provider(eval_metric), wrap_with_ddp=False) + if args.load is not None: + _ = load_checkpoint(model, None, None) + + assert len(model) == 1, "Above condition should have caught this" + model = model[0] + + # Data stuff. + dataset = build_dataset(args.task) + dataloader = build_data_loader(dataset, args.micro_batch_size, + args.num_workers, drop_last=False) + + # Run evaluation. 
+ evaluate_and_print_results(args.task, dataloader, model, eval_metric) + + print_rank_0('done :-)') diff --git a/tests/run_megatron.py b/tests/run_megatron.py index ad96165b..4f8fdf79 100644 --- a/tests/run_megatron.py +++ b/tests/run_megatron.py @@ -2,16 +2,21 @@ import deepspeed import megatron from megatron import get_args -from megatron import mpu +from megatron.core import mpu from megatron.checkpointing import load_checkpoint from megatron.initialize import initialize_megatron from megatron.model import GPTModel from megatron.training import get_model +from megatron.arguments import core_transformer_config_from_args from megatron.text_generation_utils import generate_samples_eval def model_provider(pre_process=True, post_process=True): + + config = core_transformer_config_from_args(get_args()) + model = GPTModel( + config=config, num_tokentypes=0, parallel_output=False, pre_process=pre_process, diff --git a/tools/generate_samples_gpt.py b/tools/generate_samples_gpt.py index bbd1164c..26f1e1da 100644 --- a/tools/generate_samples_gpt.py +++ b/tools/generate_samples_gpt.py @@ -25,7 +25,7 @@ from megatron import get_args from megatron import print_rank_0 from megatron import get_tokenizer -from megatron import mpu +from megatron.core import mpu from megatron.checkpointing import load_checkpoint from megatron.initialize import initialize_megatron from megatron.model import GPTModel @@ -36,11 +36,17 @@ import deepspeed import torch +from megatron.arguments import core_transformer_config_from_args +from megatron import get_args + def model_provider(pre_process=True, post_process=True): """Build the model.""" + args = get_args() + config = core_transformer_config_from_args(args) + print_rank_0('building GPT model ...') - model = GPTModel(num_tokentypes=0, parallel_output=False, + model = GPTModel(config=config, num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process, return_moe_loss=False) # we need to set "return_moe_loss" for the inference_mode return model @@ -73,6 +79,8 @@ def add_text_generate_args(parser): group.add_argument("--recompute", action='store_true', help='During generation recompute all attention ' 'instead of using previously computed keys/values.') + group.add_argument("--local_rank", type=int, default=0, + help='local_rank') return parser diff --git a/tools/run_text_generation_server.py b/tools/run_text_generation_server.py index 3fdd27be..aefad91e 100644 --- a/tools/run_text_generation_server.py +++ b/tools/run_text_generation_server.py @@ -13,6 +13,7 @@ from megatron.initialize import initialize_megatron from megatron.model import GPTModel from megatron.training import get_model +from megatron.arguments import core_transformer_config_from_args from megatron.text_generation_server import MegatronServer from megatron.text_generation import generate_and_post_process from megatron.text_generation import beam_search_and_post_process @@ -21,8 +22,10 @@ def model_provider(pre_process=True, post_process=True): """Build the model.""" + config = core_transformer_config_from_args(get_args()) + print_rank_0('building GPT model ...') - model = GPTModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process) + model = GPTModel(config=config, num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process) return model