diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000..261f59bc24 --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 100 +extend-ignore = E203 +per-file-ignores = __init__.py:F401 \ No newline at end of file diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000000..5e550f1703 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,7 @@ +[MASTER] +ignore=tests + +[MESSAGES CONTROL] +disable=all + +enable=C0115,C0116 \ No newline at end of file diff --git a/Dockerfile.linting b/Dockerfile.linting index 910df314f8..b0670af9d1 100644 --- a/Dockerfile.linting +++ b/Dockerfile.linting @@ -10,7 +10,9 @@ RUN sed -i -e 's/^APT/# APT/' -e 's/^DPkg/# DPkg/' \ RUN pip3 install --no-cache-dir \ black==24.4.2 \ - isort + isort==5.13.2 \ + flake8==7.1.0 \ + pylint==3.2.6 COPY . /opt/megatron-lm diff --git a/megatron/core/datasets/bert_dataset.py b/megatron/core/datasets/bert_dataset.py index 657cc6a78a..78ae2edf62 100644 --- a/megatron/core/datasets/bert_dataset.py +++ b/megatron/core/datasets/bert_dataset.py @@ -21,8 +21,7 @@ class BERTMaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig): """Option to perform the next sequence prediction during sampling""" def __post_init__(self) -> None: - """Do asserts and set fields post init - """ + """Do asserts and set fields post init""" super().__post_init__() assert self.classification_head is not None @@ -73,22 +72,20 @@ def _key_config_attributes() -> List[str]: """ return super( BERTMaskedWordPieceDataset, BERTMaskedWordPieceDataset - )._key_config_attributes() + ["classification_head",] + )._key_config_attributes() + ["classification_head"] def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]: """Abstract method implementation - + Args: idx (int): The index into the dataset Returns: - Dict[str, Union[int, numpy.ndarray]]: The + Dict[str, Union[int, numpy.ndarray]]: The """ idx_beg, idx_end, target_sequence_length = self.sample_index[idx] sample = [self.dataset[i] for i in range(idx_beg, idx_end)] - numpy_random_state = numpy.random.RandomState( - seed=(self.config.random_seed + idx) % 2 ** 32 - ) + numpy_random_state = numpy.random.RandomState(seed=(self.config.random_seed + idx) % 2**32) assert target_sequence_length <= self.config.sequence_length @@ -127,11 +124,7 @@ def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]: truncated = True # Merge the subsegments and create the token assignment labels - tokens = [ - self.config.tokenizer.cls, - *split_A, - self.config.tokenizer.sep, - ] + tokens = [self.config.tokenizer.cls, *split_A, self.config.tokenizer.sep] assignments = [0 for _ in range(1 + len(split_A) + 1)] if split_B: tokens += [*split_B, self.config.tokenizer.sep] diff --git a/megatron/core/datasets/blended_dataset.py b/megatron/core/datasets/blended_dataset.py index f7883d9b14..be0b7a4a08 100644 --- a/megatron/core/datasets/blended_dataset.py +++ b/megatron/core/datasets/blended_dataset.py @@ -93,10 +93,7 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]: dataset_id = self.dataset_index[idx] dataset_sample_id = self.dataset_sample_index[idx] - return { - "dataset_id": dataset_id, - **self.datasets[dataset_id][dataset_sample_id], - } + return {"dataset_id": dataset_id, **self.datasets[dataset_id][dataset_sample_id]} def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]: """Build and optionally cache the dataset index and the dataset sample index @@ -129,9 +126,7 @@ def _build_indices(self) -> 
Tuple[numpy.ndarray, numpy.ndarray]: if not path_to_cache or (not cache_hit and torch.distributed.get_rank() == 0): log_single_rank( - logger, - logging.INFO, - f"Build and save the {type(self).__name__} indices", + logger, logging.INFO, f"Build and save the {type(self).__name__} indices" ) self.built_anew_on_cache_miss = True diff --git a/megatron/core/datasets/blended_megatron_dataset_builder.py b/megatron/core/datasets/blended_megatron_dataset_builder.py index 0230faf5e0..c9cf4abf63 100644 --- a/megatron/core/datasets/blended_megatron_dataset_builder.py +++ b/megatron/core/datasets/blended_megatron_dataset_builder.py @@ -156,9 +156,7 @@ def build(self) -> List[Optional[TopLevelDataset]]: return datasets - def _build_blended_dataset_splits( - self, - ) -> List[Optional[TopLevelDataset]]: + def _build_blended_dataset_splits(self) -> List[Optional[TopLevelDataset]]: """Build all dataset splits according to the provided blend(s) See the BlendedMegatronDatasetBuilder.build alias for more information. @@ -306,10 +304,7 @@ def _build_blended_dataset_splits( return blended_datasets def _build_megatron_datasets_parallel( - self, - prefixes: List[str], - split: List[float], - sizes_per_dataset: List[List[int]], + self, prefixes: List[str], split: List[float], sizes_per_dataset: List[List[int]] ) -> List[List[Optional[MegatronDataset]]]: """Build the megatron datasets for a list of prefixes in parallel @@ -369,11 +364,7 @@ def _threading_helper( # i.e. meant for serial build, do not scale up. num_workers *= min(2, max(1, torch.cuda.device_count())) _threading_helper( - megatron_datasets, - num_workers, - prefixes, - split, - sizes_per_dataset, + megatron_datasets, num_workers, prefixes, split, sizes_per_dataset ) torch.distributed.barrier() @@ -389,11 +380,7 @@ def _threading_helper( ) else: _threading_helper( - megatron_datasets, - num_dataset_builder_threads, - prefixes, - split, - sizes_per_dataset, + megatron_datasets, num_dataset_builder_threads, prefixes, split, sizes_per_dataset ) return megatron_datasets diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index c5b2bbe7b4..115727de92 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -108,11 +108,9 @@ def __init__( except Exception: self._pad_token_id = _PAD_TOKEN_ID - ( - self.document_index, - self.sample_index, - self.shuffle_index, - ) = self._build_document_sample_shuffle_indices() + (self.document_index, self.sample_index, self.shuffle_index) = ( + self._build_document_sample_shuffle_indices() + ) @staticmethod def numel_low_level_dataset(low_level_dataset: IndexedDataset) -> int: diff --git a/megatron/core/datasets/indexed_dataset.py b/megatron/core/datasets/indexed_dataset.py index ae05bcbc6a..29975336f1 100644 --- a/megatron/core/datasets/indexed_dataset.py +++ b/megatron/core/datasets/indexed_dataset.py @@ -385,12 +385,7 @@ def read(self, dtype: Type[numpy.number], count: int, offset: int) -> numpy.ndar Returns: numpy.ndarray: An array with `count` items and data-type `dtype` constructed from reading bytes from the data file starting at `offset`. 
""" - return numpy.frombuffer( - self._bin_buffer, - dtype=dtype, - count=count, - offset=offset, - ) + return numpy.frombuffer(self._bin_buffer, dtype=dtype, count=count, offset=offset) def __del__(self) -> None: """Clean up the object.""" @@ -633,9 +628,7 @@ def __getitem__( if isinstance(idx, (int, numpy.integer)): sequence_pointer, sequence_length, sequence_mode = self.index[idx] sequence = self.bin_reader.read( - dtype=self.index.dtype, - count=sequence_length, - offset=sequence_pointer, + dtype=self.index.dtype, count=sequence_length, offset=sequence_pointer ) return (sequence, sequence_mode) if sequence_mode is not None else sequence elif isinstance(idx, slice): diff --git a/megatron/core/datasets/masked_dataset.py b/megatron/core/datasets/masked_dataset.py index 081d58525b..9db6c67eb1 100644 --- a/megatron/core/datasets/masked_dataset.py +++ b/megatron/core/datasets/masked_dataset.py @@ -154,15 +154,7 @@ def _build_sample_index( ) path_to_description = get_path_to("description.txt") path_to_sample_index = get_path_to("sample_index.npy") - cache_hit = all( - map( - os.path.isfile, - [ - path_to_description, - path_to_sample_index, - ], - ) - ) + cache_hit = all(map(os.path.isfile, [path_to_description, path_to_sample_index])) if self.num_samples is not None: num_epochs = numpy.iinfo(numpy.int32).max - 1 diff --git a/megatron/core/datasets/retro/db/build.py b/megatron/core/datasets/retro/db/build.py index 780cc9e503..44b9038230 100644 --- a/megatron/core/datasets/retro/db/build.py +++ b/megatron/core/datasets/retro/db/build.py @@ -95,23 +95,13 @@ def build_partial_db( if proc_id in progress_proc_ids: log_retro_rank_0( " > building partial chunk db, proc %d / %d, docs %d:%d / %d." - % ( - proc_id, - n_procs, - doc_start_id, - doc_end_id, - n_docs, - ) + % (proc_id, n_procs, doc_start_id, doc_end_id, n_docs) ) # Progress bars (snapshot of overall progress). doc_id_iter = range(doc_start_id, doc_end_id) pbar = ( - tqdm( - doc_id_iter, - "parse doc chunks", - miniters=len(doc_id_iter) // 20, - ) + tqdm(doc_id_iter, "parse doc chunks", miniters=len(doc_id_iter) // 20) if proc_id in progress_proc_ids else doc_id_iter ) @@ -156,9 +146,7 @@ def build_partial_db( # Re-tokenize. chunk_end_idx = chunk_end_idxs[i] gpt_token_ids = indexed_dataset.get( - idx=doc_id, - offset=chunk_start_idx, - length=chunk_end_idx - chunk_start_idx, + idx=doc_id, offset=chunk_start_idx, length=chunk_end_idx - chunk_start_idx ) text = config.gpt_detokenize(gpt_token_ids.tolist()) bert_token_ids = config.bert_tokenize(text) @@ -169,14 +157,7 @@ def build_partial_db( else: _chunk_db = chunk_db_valid doc_size_map[doc_id] += 1 - _chunk_db.append( - ( - doc_id, - chunk_start_idx, - chunk_end_idx, - len(bert_token_ids), - ) - ) + _chunk_db.append((doc_id, chunk_start_idx, chunk_end_idx, len(bert_token_ids))) return proc_id, chunk_db_valid, chunk_db_invalid, doc_size_map @@ -269,10 +250,7 @@ def build_block_db( def save_block_db( - block: dict, - chunk_db_valid: np.ndarray, - chunk_db_invalid: np.ndarray, - doc_offsets: np.ndarray, + block: dict, chunk_db_valid: np.ndarray, chunk_db_invalid: np.ndarray, doc_offsets: np.ndarray ) -> None: """Save block of chunked tokens to disk. These blocks are later used for training and adding to the vector index. 
@@ -291,10 +269,7 @@ def save_block_db( def build_individual_db( - config: RetroPreprocessingConfig, - dataset_idx: int, - n_datasets: int, - dataset_info: dict, + config: RetroPreprocessingConfig, dataset_idx: int, n_datasets: int, dataset_info: dict ) -> None: """Process a single indexed dataset & extract chunks. @@ -395,8 +370,7 @@ def build_individual_db( def build_individual_dbs( - config: RetroPreprocessingConfig, - indexed_dataset_infos: List[Dict], + config: RetroPreprocessingConfig, indexed_dataset_infos: List[Dict] ) -> None: """Iterate each indexed dataset & process its chunks. @@ -412,11 +386,7 @@ def build_individual_dbs( # Progress. log_retro_rank_0( " > building individual db, dataset %d / %d ... '%s'." - % ( - ds_idx, - len(indexed_dataset_infos), - ds_info["prefix"], - ) + % (ds_idx, len(indexed_dataset_infos), ds_info["prefix"]) ) # Process single dataset. @@ -562,7 +532,7 @@ def merge_dbs(project_dir: str, indexed_dataset_infos: List[Dict], db_type: str) for ds_idx, ds_info in enumerate(indexed_dataset_infos): log_retro_rank_0( " > merging dbs; '%s', dataset %d / %d ... '%s'." - % (db_type, ds_idx, len(indexed_dataset_infos), ds_info["prefix"]), + % (db_type, ds_idx, len(indexed_dataset_infos), ds_info["prefix"]) ) individual_chunk_db: np.ndarray = get_individual_chunk_db(project_dir, ds_idx, ds_info) individual_doc_offsets: np.ndarray = ( diff --git a/megatron/core/datasets/retro/db/dataset.py b/megatron/core/datasets/retro/db/dataset.py index 1de6e02b10..f9053622ab 100644 --- a/megatron/core/datasets/retro/db/dataset.py +++ b/megatron/core/datasets/retro/db/dataset.py @@ -17,7 +17,7 @@ class DBDataset(torch.utils.data.Dataset): """Dataset for iterating chunks. - + Args: db_path (str): Path of HDF5-format chunk database. indexed_datasets (List[IndexedDataset]): Indexed datasets used to build database. @@ -85,10 +85,7 @@ def __getitem__(self, chunk_id: int) -> dict: token_ids = token_ids.tolist() token_ids += [self.eod_token_id] * (self.max_chunk_length - chunk_length) - return { - "doc_id": doc_id, - "text": np.array(token_ids, dtype=np.int64), - } + return {"doc_id": doc_id, "text": np.array(token_ids, dtype=np.int64)} def load_doc_tuples(self) -> None: """Load the dataset & document ids. diff --git a/megatron/core/datasets/retro/db/utils.py b/megatron/core/datasets/retro/db/utils.py index df13089840..e8578a09d5 100644 --- a/megatron/core/datasets/retro/db/utils.py +++ b/megatron/core/datasets/retro/db/utils.py @@ -22,7 +22,7 @@ def get_db_dir(project_dir: str) -> str: Args: project_dir (str): Path to Retro project dir. - + Returns: Path of the DB sub-directory within the project. """ @@ -55,9 +55,7 @@ def init_indexed_dataset_infos(config: RetroPreprocessingConfig) -> List[Dict]: prefix = data_blend[i + 1] path = os.path.join(data_dir, prefix + ".bin") assert os.path.exists(path), "couldn't find '%s'." % path - infos.append( - {"ratio": ratio, "prefix": prefix,} - ) + infos.append({"ratio": ratio, "prefix": prefix}) # Load indexed datasets. 
load_indexed_datasets(config.retro_project_dir, infos) diff --git a/megatron/core/datasets/retro/external_libs.py b/megatron/core/datasets/retro/external_libs.py index 98b28728d4..c057eba25c 100644 --- a/megatron/core/datasets/retro/external_libs.py +++ b/megatron/core/datasets/retro/external_libs.py @@ -4,11 +4,7 @@ import importlib -required_libs = [ - "faiss", - "h5py", - "transformers", # for huggingface bert -] +required_libs = ["faiss", "h5py", "transformers"] # for huggingface bert for lib in required_libs: try: diff --git a/megatron/core/datasets/retro/index/build.py b/megatron/core/datasets/retro/index/build.py index a5659e92db..1f310d89c3 100644 --- a/megatron/core/datasets/retro/index/build.py +++ b/megatron/core/datasets/retro/index/build.py @@ -41,7 +41,7 @@ def get_empty_index_path(config: RetroPreprocessingConfig) -> str: Args: config (RetroPreprocessingConfig): Retro preprocessing config. - + Returns: Path to the empty (trained, but without added samples) vector index. """ diff --git a/megatron/core/datasets/retro/index/factory.py b/megatron/core/datasets/retro/index/factory.py index 293d58c678..f88084ddb1 100644 --- a/megatron/core/datasets/retro/index/factory.py +++ b/megatron/core/datasets/retro/index/factory.py @@ -23,7 +23,7 @@ def get_index_class(cls, index_type: str) -> type: Returns: An `Index` sub-type corresponding to the `index_type`. """ - return {"faiss-base": FaissBaseIndex, "faiss-par-add": FaissParallelAddIndex,}[index_type] + return {"faiss-base": FaissBaseIndex, "faiss-par-add": FaissParallelAddIndex}[index_type] @classmethod def get_index(cls, index_type: str) -> Index: diff --git a/megatron/core/datasets/retro/index/index.py b/megatron/core/datasets/retro/index/index.py index a8c086fb94..c6bd13fbee 100644 --- a/megatron/core/datasets/retro/index/index.py +++ b/megatron/core/datasets/retro/index/index.py @@ -27,7 +27,6 @@ class Index(abc.ABC): - """Abstract base class for indexes. *Note* : While currently only Faiss-based classes are implemented, in the @@ -60,7 +59,7 @@ def get_empty_index_path(self, config: RetroPreprocessingConfig) -> str: File path to empty index (i.e., this index has had index.train() called, but not yet index.add()). """ return os.path.join( - get_index_dir(config), "empty_%.3f.faissindex" % config.retro_index_train_load_fraction, + get_index_dir(config), "empty_%.3f.faissindex" % config.retro_index_train_load_fraction ) def get_empty_index(self, config: RetroPreprocessingConfig) -> faiss.Index: @@ -86,7 +85,7 @@ def get_added_index_path(self, config: RetroPreprocessingConfig) -> str: return os.path.join( get_index_dir(config), "added_%.3f_%.3f.faissindex" - % (config.retro_index_train_load_fraction, config.retro_index_add_load_fraction,), + % (config.retro_index_train_load_fraction, config.retro_index_add_load_fraction), ) def get_added_index(self, config: RetroPreprocessingConfig) -> faiss.Index: diff --git a/megatron/core/datasets/retro/index/indexes/faiss_base.py b/megatron/core/datasets/retro/index/indexes/faiss_base.py index 1ffc72528c..c1daf3f533 100644 --- a/megatron/core/datasets/retro/index/indexes/faiss_base.py +++ b/megatron/core/datasets/retro/index/indexes/faiss_base.py @@ -52,7 +52,7 @@ def _train(self, config: RetroPreprocessingConfig) -> None: # Load data. merged_path = get_training_data_merged_path(config) - inp = np.memmap(merged_path, dtype="f4", mode="r",).reshape((-1, config.hidden_size)) + inp = np.memmap(merged_path, dtype="f4", mode="r").reshape((-1, config.hidden_size)) # Init index. 
index = faiss.index_factory(config.hidden_size, config.retro_index_str) diff --git a/megatron/core/datasets/retro/index/indexes/faiss_par_add.py b/megatron/core/datasets/retro/index/indexes/faiss_par_add.py index 6d9d68f821..e014217262 100644 --- a/megatron/core/datasets/retro/index/indexes/faiss_par_add.py +++ b/megatron/core/datasets/retro/index/indexes/faiss_par_add.py @@ -58,7 +58,7 @@ def encode_block( """ # Embed block. - embeddings = self.embed_text_dataset_block(embedder, text_dataset, block["range"],) + embeddings = self.embed_text_dataset_block(embedder, text_dataset, block["range"]) # Encode block. log_retro_rank_0("encode.") @@ -108,7 +108,7 @@ def validate(f: h5py.File) -> None: assert len(f["data"].shape) == 2 blocks = get_blocks_by_rank( - codes_dir, len(text_dataset), config.retro_block_size, validate=validate, + codes_dir, len(text_dataset), config.retro_block_size, validate=validate ) # Encode each block. @@ -119,7 +119,7 @@ def validate(f: h5py.File) -> None: # Progress. log_retro_rank_0( "encode block %d / %d ... %s." - % (block_index, len(blocks.missing), block["path"],) + % (block_index, len(blocks.missing), block["path"]) ) # Encode and save. @@ -156,7 +156,7 @@ def add_codes(self, config: RetroPreprocessingConfig) -> None: for code_path in pbar: pbar.set_description( "add codes, mem %.3f gb, %.1f%%" - % (psutil.virtual_memory()[3] / 1024 ** 3, psutil.virtual_memory()[2],) + % (psutil.virtual_memory()[3] / 1024**3, psutil.virtual_memory()[2]) ) with h5py.File(code_path) as f: diff --git a/megatron/core/datasets/retro/index/utils.py b/megatron/core/datasets/retro/index/utils.py index 321cd659d8..58229439ae 100644 --- a/megatron/core/datasets/retro/index/utils.py +++ b/megatron/core/datasets/retro/index/utils.py @@ -22,7 +22,7 @@ def get_index_dir(config: RetroPreprocessingConfig) -> str: # Directory path. index_dir_path = os.path.join( - config.retro_project_dir, "index", config.retro_index_type, config.retro_index_str, + config.retro_project_dir, "index", config.retro_index_type, config.retro_index_str ) # Make directory. diff --git a/megatron/core/datasets/retro/index/validate.py b/megatron/core/datasets/retro/index/validate.py index 6783df6492..57306707c4 100644 --- a/megatron/core/datasets/retro/index/validate.py +++ b/megatron/core/datasets/retro/index/validate.py @@ -74,7 +74,7 @@ def validate_training_embeddings(config: RetroPreprocessingConfig) -> None: # Progress. (*note*: move world progress to here.) log_retro_rank_0( "embed training block %d / %d ... %s." - % (block_idx, len(blocks.existing), block["path"],) + % (block_idx, len(blocks.existing), block["path"]) ) # Load existing block embeddings. @@ -147,7 +147,7 @@ def validate(f: h5py.File) -> None: # Progress. log_retro_rank_0( - "encode block %d / %d ... %s." % (block_idx, len(blocks.existing), block["path"],) + "encode block %d / %d ... %s." % (block_idx, len(blocks.existing), block["path"]) ) # Load existing codes. diff --git a/megatron/core/datasets/retro/query/gpt_chunk_dataset.py b/megatron/core/datasets/retro/query/gpt_chunk_dataset.py index 34a2ee6c87..6191a30a31 100644 --- a/megatron/core/datasets/retro/query/gpt_chunk_dataset.py +++ b/megatron/core/datasets/retro/query/gpt_chunk_dataset.py @@ -73,14 +73,11 @@ def __getitem__(self, idx: int) -> dict: chunk_token_ids = sample_token_ids[token_start_idx:token_end_idx] # Sample. 
- return { - "doc_ids": sample_doc_ids, - "text": chunk_token_ids, - } + return {"doc_ids": sample_doc_ids, "text": chunk_token_ids} def build_gpt_chunk_datasets_from_gpt_datasets( - project_dir: str, gpt_datasets: dict, sample_length: int, chunk_length: int, + project_dir: str, gpt_datasets: dict, sample_length: int, chunk_length: int ) -> dict: """Get train, valid, test GPT chunk datasets. @@ -96,14 +93,16 @@ def build_gpt_chunk_datasets_from_gpt_datasets( # GPT chunk datasets. chunk_datasets = { - key: { - "dataset": GPTChunkDataset(sample_ds, sample_length, chunk_length), - "neighbor_dir": get_neighbor_dir(project_dir, key, sample_ds), - "num_active_chunks": num_active_samples - * get_num_chunks_per_sample(sample_length, chunk_length), - } - if sample_ds - else None + key: ( + { + "dataset": GPTChunkDataset(sample_ds, sample_length, chunk_length), + "neighbor_dir": get_neighbor_dir(project_dir, key, sample_ds), + "num_active_chunks": num_active_samples + * get_num_chunks_per_sample(sample_length, chunk_length), + } + if sample_ds + else None + ) for key, (sample_ds, num_active_samples) in gpt_datasets.items() } diff --git a/megatron/core/datasets/retro/query/query.py b/megatron/core/datasets/retro/query/query.py index 165792f9a0..9da3381712 100644 --- a/megatron/core/datasets/retro/query/query.py +++ b/megatron/core/datasets/retro/query/query.py @@ -39,7 +39,7 @@ from .gpt_chunk_dataset import build_gpt_chunk_datasets_from_gpt_datasets -def get_index(config: RetroPreprocessingConfig, ondisk: bool = False,) -> faiss.Index: +def get_index(config: RetroPreprocessingConfig, ondisk: bool = False) -> faiss.Index: """Read index from disk. Args: @@ -67,7 +67,7 @@ def get_index(config: RetroPreprocessingConfig, ondisk: bool = False,) -> faiss. def embed_block( - config: RetroPreprocessingConfig, gpt_dataset: GPTChunkDataset, block: dict, + config: RetroPreprocessingConfig, gpt_dataset: GPTChunkDataset, block: dict ) -> np.ndarray: """Embed block of chunks. @@ -80,7 +80,7 @@ def embed_block( Embeddings array, with shape (len(block["range"]), dimension(embedder)). """ text_block_dataset = torch.utils.data.Subset( - GPTToTextDataset(gpt_dataset, config.retro_tokenizers.gpt), range(*block["range"]), + GPTToTextDataset(gpt_dataset, config.retro_tokenizers.gpt), range(*block["range"]) ) return config.retro_bert_embedders.mem.embed_text_dataset(text_block_dataset) @@ -248,17 +248,14 @@ def query_block_neighbors( sample_map = {} for i in sample_ids: sample = query_dataset.sample_dataset[i] - sample_map[i] = { - "dataset_idx": sample["dataset_id"], - "doc_ids": sample["document_ids"], - } + sample_map[i] = {"dataset_idx": sample["dataset_id"], "doc_ids": sample["document_ids"]} # Embed block. embeddings = embed_block(config, query_dataset, block) # Query embeddings. _, filtered_neighbor_ids = query_embedding_block( - config, db_dataset, index, embeddings, block["range"], sample_map, n_chunks_per_sample, + config, db_dataset, index, embeddings, block["range"], sample_map, n_chunks_per_sample ) if config.retro_task_validate is None: @@ -303,15 +300,17 @@ def validate(f: h5py.File) -> None: Args: f (h5py.File): File containing save neighbor IDs. """ - assert f["neighbors"].shape[1] == config.retro_query_num_neighbors_save, ( - "neighbors.shape == %s; num_neighbors_target == %d." - % (str(f["neighbors"].shape), config.retro_num_neighbors_target,) + assert ( + f["neighbors"].shape[1] == config.retro_query_num_neighbors_save + ), "neighbors.shape == %s; num_neighbors_target == %d." 
% ( + str(f["neighbors"].shape), + config.retro_num_neighbors_target, ) if config.retro_task_validate is None: retro_makedir(config, neighbor_dir) blocks = get_blocks_by_rank( - neighbor_dir, num_active_chunks, config.retro_block_size, validate=validate, + neighbor_dir, num_active_chunks, config.retro_block_size, validate=validate ) active_blocks = blocks.missing else: @@ -339,7 +338,7 @@ def validate(f: h5py.File) -> None: block_index, len(active_blocks), os.path.basename(block["path"]), - psutil.virtual_memory()[3] / 1024 ** 3, + psutil.virtual_memory()[3] / 1024**3, psutil.virtual_memory()[2], ) ) diff --git a/megatron/core/datasets/retro/query/retro_dataset.py b/megatron/core/datasets/retro/query/retro_dataset.py index 07af161693..6c3b9ae60c 100644 --- a/megatron/core/datasets/retro/query/retro_dataset.py +++ b/megatron/core/datasets/retro/query/retro_dataset.py @@ -94,7 +94,7 @@ def __getitem__(self, sample_idx: int) -> dict: # Sample idx to chunk idxs. chunk_idxs = list( - range(sample_idx * n_chunks_per_sample, (sample_idx + 1) * n_chunks_per_sample,) + range(sample_idx * n_chunks_per_sample, (sample_idx + 1) * n_chunks_per_sample) ) # Collect retrieved tokens. @@ -144,7 +144,7 @@ def __getitem__(self, sample_idx: int) -> dict: def get_retro_datasets( - config: RetroConfig, gpt_datasets: dict, sample_length: int, eod_token_id: int, + config: RetroConfig, gpt_datasets: dict, sample_length: int, eod_token_id: int ) -> Tuple[Optional[RetroDataset], Optional[RetroDataset], Optional[RetroDataset]]: """Get train, valid, test retro datasets. @@ -190,7 +190,7 @@ def get_retro_datasets( # preprocessing and pretraining. chunk_dataset = chunk_ds_info["dataset"] chunk_ds_info["neighbor_dir"] = os.path.join( - query_dir, config.retro_neighbor_dirs[data_key], + query_dir, config.retro_neighbor_dirs[data_key] ) neighbor_dir = chunk_ds_info["neighbor_dir"] neighbor_path_map = BlockPathMap.from_dir( @@ -235,8 +235,4 @@ def get_retro_datasets( neighbor_path_map=neighbor_path_map, ) - return ( - retro_dataset_map["train"], - retro_dataset_map["valid"], - retro_dataset_map["test"], - ) + return (retro_dataset_map["train"], retro_dataset_map["valid"], retro_dataset_map["test"]) diff --git a/megatron/core/datasets/retro/query/utils.py b/megatron/core/datasets/retro/query/utils.py index f07920d48c..b4e0c67009 100644 --- a/megatron/core/datasets/retro/query/utils.py +++ b/megatron/core/datasets/retro/query/utils.py @@ -31,5 +31,5 @@ def get_neighbor_dir(project_dir: str, key: str, dataset: MegatronDataset) -> st Path to directory containing this dataset's neighbors within Retro project. """ return os.path.join( - get_query_dir(project_dir), os.path.basename(f"{key}_{dataset.unique_description_hash}"), + get_query_dir(project_dir), os.path.basename(f"{key}_{dataset.unique_description_hash}") ) diff --git a/megatron/core/datasets/retro/utils.py b/megatron/core/datasets/retro/utils.py index dbef86a38d..31c0be14c8 100644 --- a/megatron/core/datasets/retro/utils.py +++ b/megatron/core/datasets/retro/utils.py @@ -110,10 +110,7 @@ def __getitem__(self, idx: int) -> dict: def get_blocks( - dirname: str, - n_samples: int, - block_size: int, - validate: Callable = None, + dirname: str, n_samples: int, block_size: int, validate: Callable = None ) -> SimpleNamespace: """Divide range [0, num_samples) to sequence of block ranges. 
@@ -147,8 +144,7 @@ def get_blocks( { "range": r, "path": os.path.join( - dirname, - "%s-%s.hdf5" % tuple([str(i).zfill(n_digits) for i in r]), + dirname, "%s-%s.hdf5" % tuple([str(i).zfill(n_digits) for i in r]) ), } for r in block_ranges diff --git a/megatron/core/datasets/t5_dataset.py b/megatron/core/datasets/t5_dataset.py index 33792c8636..b54e4f5315 100644 --- a/megatron/core/datasets/t5_dataset.py +++ b/megatron/core/datasets/t5_dataset.py @@ -30,8 +30,7 @@ class T5MaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig): """The sequence length for the decoder""" def __post_init__(self) -> None: - """Do asserts and set fields post init - """ + """Do asserts and set fields post init""" super().__post_init__() self.sequence_length_encoder = self.sequence_length @@ -85,23 +84,21 @@ def _key_config_attributes() -> List[str]: """ return super( T5MaskedWordPieceDataset, T5MaskedWordPieceDataset - )._key_config_attributes() + ["sequence_length_decoder",] + )._key_config_attributes() + ["sequence_length_decoder"] def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]: """Abstract method implementation - + Args: idx (int): The index into the dataset Returns: - Dict[str, Union[int, numpy.ndarray]]: The + Dict[str, Union[int, numpy.ndarray]]: The """ idx_beg, idx_end, target_sequence_length = self.sample_index[idx] sample = [self.dataset[i] for i in range(idx_beg, idx_end)] - numpy_random_state = numpy.random.RandomState( - seed=(self.config.random_seed + idx) % 2 ** 32 - ) + numpy_random_state = numpy.random.RandomState(seed=(self.config.random_seed + idx) % 2**32) assert target_sequence_length <= self.config.sequence_length @@ -113,7 +110,7 @@ def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]: tokens = tokens[:target_sequence_length] # Masking - (tokens, _, _, _, masked_spans,) = self._create_masked_lm_predictions( + (tokens, _, _, _, masked_spans) = self._create_masked_lm_predictions( tokens, target_sequence_length, numpy_random_state ) diff --git a/megatron/core/datasets/utils.py b/megatron/core/datasets/utils.py index 45203c256a..8d887d4a4a 100644 --- a/megatron/core/datasets/utils.py +++ b/megatron/core/datasets/utils.py @@ -19,8 +19,7 @@ class Split(Enum): def compile_helpers(): - """Compile C++ helper functions at runtime. Make sure this is invoked on a single process. - """ + """Compile C++ helper functions at runtime. Make sure this is invoked on a single process.""" import os import subprocess @@ -51,7 +50,7 @@ def get_blend_from_list( blend: Optional[List[str]], ) -> Optional[Tuple[List[str], Optional[List[float]]]]: """Get the megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig blend from the blend list - + Args: blend (Optional[List[str]]): The blend list, which can be either (1) a list of prefixes, e.g. ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], or (2) a flattened, zipped list of weights and prefixes, e.g. ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"] diff --git a/megatron/core/dist_checkpointing/core.py b/megatron/core/dist_checkpointing/core.py index 50384e661b..af6ebff6ec 100644 --- a/megatron/core/dist_checkpointing/core.py +++ b/megatron/core/dist_checkpointing/core.py @@ -11,14 +11,14 @@ class CheckpointingException(Exception): - """ Base checkpointing related exception """ + """Base checkpointing related exception""" pass @dataclass class CheckpointingConfig: - """ Documents backends used in the checkpoint. + """Documents backends used in the checkpoint. 
Checkpoint config keeps track of formats used for storing the sharded tensors (sharded_backend) and other objects (common_backend). @@ -34,7 +34,7 @@ class CheckpointingConfig: def check_is_distributed_checkpoint(checkpoint_dir): - """ Checks if `metadata.json` exists in the checkpoint and is a valid config. + """Checks if `metadata.json` exists in the checkpoint and is a valid config. Args: checkpoint_dir: checkpoint directory @@ -46,7 +46,7 @@ def check_is_distributed_checkpoint(checkpoint_dir): def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]: - """ Returns checkpoint config if `checkpoint_dir` is a distributed checkpoint and None otherwise + """Returns checkpoint config if `checkpoint_dir` is a distributed checkpoint and None otherwise Args: checkpoint_dir: checkpoint directory @@ -63,7 +63,7 @@ def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]: def save_config(config: CheckpointingConfig, checkpoint_dir: str): - """ Save given config to checkpoint directory. + """Save given config to checkpoint directory. Args: config: checkpoint config diff --git a/megatron/core/dist_checkpointing/serialization.py b/megatron/core/dist_checkpointing/serialization.py index f37aadc913..43ad3bc49e 100644 --- a/megatron/core/dist_checkpointing/serialization.py +++ b/megatron/core/dist_checkpointing/serialization.py @@ -182,8 +182,7 @@ def load_common_state_dict(checkpoint_dir: Path) -> StateDict: def load_tensors_metadata( - checkpoint_dir: str, - sharded_strategy: Union[LoadShardedStrategy, None] = None, + checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None ) -> CkptShardedMetadata: """Load tensors metadata from the checkpoint. diff --git a/megatron/core/dist_checkpointing/strategies/async_utils.py b/megatron/core/dist_checkpointing/strategies/async_utils.py index 24ee43d7e0..7cdda8ac32 100644 --- a/megatron/core/dist_checkpointing/strategies/async_utils.py +++ b/megatron/core/dist_checkpointing/strategies/async_utils.py @@ -76,11 +76,7 @@ def __init__(self): self.process: Optional[mp.Process] = None self.start_time: Optional[float] = None - def schedule_async_call( - self, - async_fn: Optional[Callable], - save_args: Tuple, - ) -> None: + def schedule_async_call(self, async_fn: Optional[Callable], save_args: Tuple) -> None: """Spawn a process with `async_fn` as the target. This method must be called on all ranks. 
@@ -101,10 +97,7 @@ def schedule_async_call( ctx = mp.get_context('fork') self.start_time = time() - self.process = ctx.Process( - target=async_fn, - args=save_args, - ) + self.process = ctx.Process(target=async_fn, args=save_args) self.process.start() init_time = time() logger.debug( diff --git a/megatron/core/dist_checkpointing/strategies/filesystem_async.py b/megatron/core/dist_checkpointing/strategies/filesystem_async.py index bfa609128a..9d0be4d6e7 100644 --- a/megatron/core/dist_checkpointing/strategies/filesystem_async.py +++ b/megatron/core/dist_checkpointing/strategies/filesystem_async.py @@ -284,11 +284,7 @@ def write_preloaded_data( f"{local_proc_idx} consumed: {mem_after - mem_before}, before: {mem_before}, after: {mem_after}" ) - def write_data( - self, - plan: SavePlan, - planner: SavePlanner, - ) -> Future[List[WriteResult]]: + def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]: raise NotImplementedError('write_data not implemented for FileSystemWriterAsync') def retrieve_write_results(self) -> List[WriteResult]: diff --git a/megatron/core/dist_checkpointing/strategies/fully_parallel.py b/megatron/core/dist_checkpointing/strategies/fully_parallel.py index 0b004e2bce..238c381378 100644 --- a/megatron/core/dist_checkpointing/strategies/fully_parallel.py +++ b/megatron/core/dist_checkpointing/strategies/fully_parallel.py @@ -97,11 +97,7 @@ def __init__( self.cached_distribution: Optional[SaveLoadDistribution] = None - def async_save( - self, - sharded_state_dict: ShardedStateDict, - checkpoint_dir: Path, - ): + def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): if not isinstance(self.base_strategy, AsyncSaveShardedStrategy): raise CheckpointingException( f'Cannot apply async_save to non-async base strategy {self.base_strategy}' @@ -109,11 +105,7 @@ def async_save( self.apply_saving_parallelization(sharded_state_dict) return self.base_strategy.async_save(sharded_state_dict, checkpoint_dir) - def save( - self, - sharded_state_dict: ShardedStateDict, - checkpoint_dir: Path, - ): + def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): self.apply_saving_parallelization(sharded_state_dict) return self.base_strategy.save(sharded_state_dict, checkpoint_dir) @@ -248,12 +240,9 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St # Step 3: load part of the checkpoint. # Load only sharded objects first. 
ShardedTensors will be loaded separately # so that we can keep track of sharded tensors loaded by this rank - ( - sharded_tensors, - sharded_state_dict, - to_load_shards, - unloaded_shards, - ) = self._defer_loading_sharded_tensors(sharded_state_dict) + (sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards) = ( + self._defer_loading_sharded_tensors(sharded_state_dict) + ) loaded_state_dict = self.base_strategy.load(sharded_state_dict, checkpoint_dir) end = time() @@ -279,10 +268,7 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St raise NotImplementedError(f'Unrecognized gather algorithm: {self.exchange_algo}') all_loaded_tensors = exchange_fn( - loaded_tensors, - unloaded_shards, - precomputed_distribution, - self.parallelization_group, + loaded_tensors, unloaded_shards, precomputed_distribution, self.parallelization_group ) if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()): missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys() @@ -300,7 +286,9 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St merge(loaded_state_dict, sharded_tensors) return loaded_state_dict - def _defer_loading_sharded_tensors(self, sharded_state_dict: ShardedStateDict) -> Tuple[ + def _defer_loading_sharded_tensors( + self, sharded_state_dict: ShardedStateDict + ) -> Tuple[ ShardedStateDict, ShardedStateDict, Dict[_ShardId, ShardedTensor], diff --git a/megatron/core/dist_checkpointing/strategies/state_dict_saver.py b/megatron/core/dist_checkpointing/strategies/state_dict_saver.py index 092e91d2f8..8e1d2c5523 100644 --- a/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +++ b/megatron/core/dist_checkpointing/strategies/state_dict_saver.py @@ -124,9 +124,7 @@ def global_step(all_local_plans): def save_state_dict_async_finalize( - storage_writer: 'FileSystemWriterAsync', - global_metadata: Metadata, - dist_wrapper: _DistWrapper, + storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper ) -> None: """ Finalization of save_state_dict_async_plan. 
diff --git a/megatron/core/dist_checkpointing/strategies/tensorstore.py b/megatron/core/dist_checkpointing/strategies/tensorstore.py index 61972ec95b..9b4eeb3185 100644 --- a/megatron/core/dist_checkpointing/strategies/tensorstore.py +++ b/megatron/core/dist_checkpointing/strategies/tensorstore.py @@ -115,10 +115,7 @@ def open_ts_array(arr_path: Path): arr_path (Path): path to a Zarr (Tensorstore) array """ spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}} - spec['kvstore'] = { - 'driver': 'file', - 'path': str(arr_path), - } + spec['kvstore'] = {'driver': 'file', 'path': str(arr_path)} try: arr = ts.open(ts.Spec(spec), open=True).result() except Exception as e: diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py index d42d3ccda0..2fccba1f8d 100644 --- a/megatron/core/dist_checkpointing/strategies/torch.py +++ b/megatron/core/dist_checkpointing/strategies/torch.py @@ -524,8 +524,7 @@ def resolve_tensor(self, read_item: ReadItem): ): self._intermediate_read_item_and_target = (read_item, target_tensor) target_tensor = Float8Tensor.make_like( - target_tensor, - data=target_tensor._data.contiguous(), + target_tensor, data=target_tensor._data.contiguous() ) return target_tensor @@ -588,9 +587,7 @@ def __init__( self.use_cached_ckpt_structure: bool = cached_metadata def async_save( - self, - sharded_state_dict: ShardedStateDict, - checkpoint_dir: Path, + self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path ) -> AsyncRequest: """Translates MCore ShardedTensors to PyT ShardedTensors and saves in PyT Distributed format. @@ -601,12 +598,10 @@ def async_save( Returns: None """ # Translate the state dict - ( - sharded_state_dict, - flat_mapping, - rename_mapping, - ) = _replace_state_dict_keys_with_sharded_keys( - sharded_state_dict, self.keep_only_main_replica + (sharded_state_dict, flat_mapping, rename_mapping) = ( + _replace_state_dict_keys_with_sharded_keys( + sharded_state_dict, self.keep_only_main_replica + ) ) pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False) # Use PyT saving mechanism @@ -716,11 +711,9 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St orig_sharded_state_dict = sharded_state_dict # MCore state dict to PyT Distributed compatible - ( - sharded_state_dict, - flat_mapping, - rename_mapping, - ) = _replace_state_dict_keys_with_sharded_keys(sharded_state_dict) + (sharded_state_dict, flat_mapping, rename_mapping) = ( + _replace_state_dict_keys_with_sharded_keys(sharded_state_dict) + ) pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True) # Load PyT Distributed format checkpoint.load_state_dict( @@ -764,8 +757,7 @@ def load_tensors_metadata(self, checkpoint_dir: Path, metadata: Metadata = None) if nd_orig_global_shape is None: # Regular tensor sharded_metadata[k] = ShardedTensor.from_rank_offsets( - k, - torch.empty(tp.size, **tp.properties.__dict__, device='meta'), + k, torch.empty(tp.size, **tp.properties.__dict__, device='meta') ).without_data() else: # N-D flattened tensor diff --git a/megatron/core/dist_checkpointing/strategies/two_stage.py b/megatron/core/dist_checkpointing/strategies/two_stage.py index 8d20c32bbb..72e60bc79b 100644 --- a/megatron/core/dist_checkpointing/strategies/two_stage.py +++ b/megatron/core/dist_checkpointing/strategies/two_stage.py @@ -59,10 +59,7 @@ class _ShardedTensorMetadata: def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor): - return ( - sharded_tensor.key, - 
sharded_tensor.global_offset, - ) + return (sharded_tensor.key, sharded_tensor.global_offset) class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy): @@ -177,7 +174,7 @@ def _build_load_plan( @timed() def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]): - """ Group tensors by chunk and then pick the tensor with the lowest rank. + """Group tensors by chunk and then pick the tensor with the lowest rank. NOTE: with proper loading overlap, loading from randomized ranks (instead of the smallest one) could be beneficial here. diff --git a/megatron/core/dist_checkpointing/utils.py b/megatron/core/dist_checkpointing/utils.py index 98ce01dd37..ff12b32662 100644 --- a/megatron/core/dist_checkpointing/utils.py +++ b/megatron/core/dist_checkpointing/utils.py @@ -73,18 +73,14 @@ def extract_sharded_tensors_or_nonpersistent( def extract_sharded_base( sharded_state_dict: ShardedStateDict, ) -> Tuple[ShardedStateDict, StateDict]: - return extract_matching_values( - sharded_state_dict, - lambda v: isinstance(v, ShardedBase), - ) + return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedBase)) def extract_nonpersistent( sharded_state_dict: ShardedStateDict, ) -> Tuple[ShardedStateDict, StateDict]: return extract_matching_values( - sharded_state_dict, - lambda v: isinstance(v, LocalNonpersistentObject), + sharded_state_dict, lambda v: isinstance(v, LocalNonpersistentObject) ) diff --git a/megatron/core/dist_checkpointing/validation.py b/megatron/core/dist_checkpointing/validation.py index c45245b2e5..cd11b82ed6 100644 --- a/megatron/core/dist_checkpointing/validation.py +++ b/megatron/core/dist_checkpointing/validation.py @@ -100,10 +100,7 @@ def requires_global_app_metadata(val: 'StrictHandling') -> bool: @staticmethod def requires_returning_mismatch_keys(val: 'StrictHandling') -> bool: """Whether a given strict option results in extra return value from the `load` function.""" - return val in ( - StrictHandling.RETURN_UNEXPECTED, - StrictHandling.RETURN_ALL, - ) + return val in (StrictHandling.RETURN_UNEXPECTED, StrictHandling.RETURN_ALL) def parse_strict_flag(strict: Union[str, StrictHandling]) -> StrictHandling: @@ -253,8 +250,7 @@ def verify_checkpoint_and_load_strategy( def adjust_non_strict_load( - sharded_state_dict: ShardedStateDict, - sharded_keys_to_remove: Set[str], + sharded_state_dict: ShardedStateDict, sharded_keys_to_remove: Set[str] ) -> ShardedStateDict: """Adjusts sharded state dict removing keys not existing in the checkpoint. diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index 2c02e5f7d1..0451a6e4fb 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -97,9 +97,7 @@ def __init__( expert_parallel_params.append(param) def allocate_buffers_for_parameters( - input_params, - data_parallel_group, - gradient_scaling_factor, + input_params, data_parallel_group, gradient_scaling_factor ): param_and_grad_dtype_to_params = {} diff --git a/megatron/core/distributed/finalize_model_grads.py b/megatron/core/distributed/finalize_model_grads.py index f1a1c2b88c..ff5046afa5 100644 --- a/megatron/core/distributed/finalize_model_grads.py +++ b/megatron/core/distributed/finalize_model_grads.py @@ -150,11 +150,7 @@ def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torc # need to do a broadcast for every pp group, even though num_tokens should be the same. 
num_tokens_list = [] for lr, group in zip(last_rank, pp_group): - torch.distributed.broadcast( - num_tokens, - src=lr, - group=group, - ) + torch.distributed.broadcast(num_tokens, src=lr, group=group) num_tokens_list.append(torch.clone(num_tokens)) assert all(x.item() == num_tokens_list[0] for x in num_tokens_list) diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index efed47c5ba..65c8eeb1be 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -324,11 +324,7 @@ def _does_param_require_new_bucket(param): assert data_start_index % self.data_parallel_world_size == 0 _create_new_bucket(data_start_index) - self.param_index_map[param] = ( - data_start_index, - data_end_index, - bucket_id, - ) + self.param_index_map[param] = (data_start_index, data_end_index, bucket_id) bucket_params.add(param) # If we have enough elements already or the current param is part of the shared embedding diff --git a/megatron/core/fusions/fused_bias_dropout.py b/megatron/core/fusions/fused_bias_dropout.py index 08af02b099..c7fa8419a0 100644 --- a/megatron/core/fusions/fused_bias_dropout.py +++ b/megatron/core/fusions/fused_bias_dropout.py @@ -47,14 +47,14 @@ def _bias_dropout_add(x_with_bias, residual, prob): @jit_fuser def bias_dropout_add_fused_train( - x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float, + x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float ) -> torch.Tensor: return _bias_dropout_add_func(x_with_bias, residual, prob, True) @jit_fuser def bias_dropout_add_fused_inference( - x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float, + x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float ) -> torch.Tensor: return _bias_dropout_add_func(x_with_bias, residual, prob, False) diff --git a/megatron/core/fusions/fused_cross_entropy.py b/megatron/core/fusions/fused_cross_entropy.py index e10c04c23b..909cc403cf 100644 --- a/megatron/core/fusions/fused_cross_entropy.py +++ b/megatron/core/fusions/fused_cross_entropy.py @@ -33,14 +33,10 @@ def calculate_predicted_logits( vocab_end_index: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - ( - target_mask, - masked_target_1d, - predicted_logits, - sum_exp_logits, - exp_logits, - ) = VocabParallelCrossEntropy.calculate_predicted_logits( - vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index + (target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits) = ( + VocabParallelCrossEntropy.calculate_predicted_logits( + vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index + ) ) predicted_logits_sum_exp_logits = torch.cat((predicted_logits, sum_exp_logits)) @@ -71,12 +67,9 @@ def calculate_gradients( masked_target_1d: torch.Tensor, ) -> torch.Tensor: - ( - grad_2d, - arange_1d, - softmax_update, - grad_input, - ) = VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask) + (grad_2d, arange_1d, softmax_update, grad_input) = ( + VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask) + ) grad_input = VocabParallelCrossEntropy.calculate_gradients( grad_2d, arange_1d, masked_target_1d, softmax_update, grad_input, grad_output @@ -103,13 +96,10 @@ def forward(ctx, vocab_parallel_logits, target): world_size = 
get_tensor_model_parallel_world_size() vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size) - ( - target_mask, - masked_target_1d, - predicted_logits_sum_exp_logits, - exp_logits, - ) = calculate_predicted_logits( - vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index + (target_mask, masked_target_1d, predicted_logits_sum_exp_logits, exp_logits) = ( + calculate_predicted_logits( + vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index + ) ) # All reduce is needed to get the chunks from other GPUs. diff --git a/megatron/core/inference/modelopt_support/gpt/model_specs.py b/megatron/core/inference/modelopt_support/gpt/model_specs.py index e3d8e08d30..50415ac006 100644 --- a/megatron/core/inference/modelopt_support/gpt/model_specs.py +++ b/megatron/core/inference/modelopt_support/gpt/model_specs.py @@ -47,8 +47,7 @@ def get_gpt_layer_modelopt_spec( mlp=ModuleSpec( module=MLP, submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, - linear_fc2=RowParallelLinear, + linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear ), ), mlp_bda=get_bias_dropout_add, diff --git a/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py b/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py index f81c4f5e03..15c3527c94 100644 --- a/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py +++ b/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py @@ -8,13 +8,7 @@ def mcore_gpt_load_legacy_state_dict_pre_hook( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): """Register a pre-hook to fix the state_dict key difference. @@ -87,13 +81,7 @@ def mcore_gpt_load_legacy_state_dict_pre_hook( def mcore_gpt_load_te_state_dict_pre_hook( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): """Register a pre-hook to fix the state_dict key difference of. diff --git a/megatron/core/inference/scheduler.py b/megatron/core/inference/scheduler.py index 35efb935f0..abcb325185 100644 --- a/megatron/core/inference/scheduler.py +++ b/megatron/core/inference/scheduler.py @@ -85,10 +85,9 @@ def add_earliest_waiting_request_to_active_pool(self): len(self.active_request_pool) < self.max_batch_size ), "Active request pool is already full. 
Cant add any more requests" if len(self.waiting_request_pool) > 0: - ( - earliest_waiting_request_request_id, - earliest_waiting_request, - ) = self.waiting_request_pool.popitem(last=False) + (earliest_waiting_request_request_id, earliest_waiting_request) = ( + self.waiting_request_pool.popitem(last=False) + ) earliest_waiting_request.status = Status.ACTIVE_BUT_NOT_GENERATING_TOKENS self.active_request_pool[earliest_waiting_request_request_id] = earliest_waiting_request diff --git a/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py b/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py index b5eed123bc..e4db83f6b3 100644 --- a/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py @@ -189,8 +189,7 @@ def pad_input_prompt_tokens( return torch.tensor(batch_prompt_tokens_list).cuda() def generate_output_tokens_dynamic_batch( - self, - active_requests: OrderedDict[int, InferenceRequest], + self, active_requests: OrderedDict[int, InferenceRequest] ) -> OrderedDict[int, InferenceRequest]: """Utility to generate the output tokens and probabilities for the prompts @@ -205,8 +204,7 @@ def generate_output_tokens_dynamic_batch( raise Exception("Not implemented yet") def generate_all_output_tokens_static_batch( - self, - active_requests: OrderedDict[int, InferenceRequest], + self, active_requests: OrderedDict[int, InferenceRequest] ) -> OrderedDict[int, InferenceRequest]: """Utility to generate the all the output tokens and probabilities for the prompts . @@ -305,15 +303,14 @@ def generate_all_output_tokens_static_batch( context_start_position = context_end_position # Check end of generation status for each tensor and update generated sequence lengths - ( - is_generation_done_tensor, - generated_sequence_lengths, - ) = self.update_generation_status( - updated_prompts_tokens=batch_prompt_tokens, - generation_started=generation_started, - current_context_end_position=context_end_position, - is_generation_done_tensor=is_generation_done_tensor, - generated_sequence_lengths=generated_sequence_lengths, + (is_generation_done_tensor, generated_sequence_lengths) = ( + self.update_generation_status( + updated_prompts_tokens=batch_prompt_tokens, + generation_started=generation_started, + current_context_end_position=context_end_position, + is_generation_done_tensor=is_generation_done_tensor, + generated_sequence_lengths=generated_sequence_lengths, + ) ) # Boolean flag indicating if all prompts are finished all_prompts_done = torch.all(is_generation_done_tensor) diff --git a/megatron/core/models/T5/t5_model.py b/megatron/core/models/T5/t5_model.py index 37a395ea47..8266757433 100644 --- a/megatron/core/models/T5/t5_model.py +++ b/megatron/core/models/T5/t5_model.py @@ -247,12 +247,10 @@ def forward( Tensor: loss tensor """ - ( - encoder_attn_mask, - decoder_attn_mask, - encoder_decoder_attn_mask, - ) = t5_extended_attention_mask( - [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask] + (encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask) = ( + t5_extended_attention_mask( + [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask] + ) ) ## Encoder forward diff --git a/megatron/core/models/T5/t5_spec.py b/megatron/core/models/T5/t5_spec.py index f195dcac35..520c3c5c8a 100644 --- a/megatron/core/models/T5/t5_spec.py +++ b/megatron/core/models/T5/t5_spec.py @@ -69,8 +69,7 @@ def 
encoder_model_with_transformer_engine_default_spec() -> ModuleSpec: mlp=ModuleSpec( module=MLP, submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, - linear_fc2=TERowParallelLinear, + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear ), ), mlp_bda=get_bias_dropout_add, @@ -110,8 +109,7 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec: mlp=ModuleSpec( module=MLP, submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, - linear_fc2=TERowParallelLinear, + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear ), ), mlp_bda=get_bias_dropout_add, @@ -142,8 +140,7 @@ def encoder_model_with_local_spec() -> ModuleSpec: mlp=ModuleSpec( module=MLP, submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, - linear_fc2=RowParallelLinear, + linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear ), ), mlp_bda=get_bias_dropout_add, @@ -189,8 +186,7 @@ def decoder_model_with_local_spec() -> ModuleSpec: mlp=ModuleSpec( module=MLP, submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, - linear_fc2=RowParallelLinear, + linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear ), ), mlp_bda=get_bias_dropout_add, diff --git a/megatron/core/models/bert/bert_layer_specs.py b/megatron/core/models/bert/bert_layer_specs.py index 1eb965c299..b5b117b498 100644 --- a/megatron/core/models/bert/bert_layer_specs.py +++ b/megatron/core/models/bert/bert_layer_specs.py @@ -54,8 +54,7 @@ mlp=ModuleSpec( module=MLP, submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, - linear_fc2=TERowParallelLinear, + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear ), ), mlp_bda=get_bias_dropout_add, @@ -82,10 +81,7 @@ pre_mlp_layernorm=LNImpl, mlp=ModuleSpec( module=MLP, - submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, - linear_fc2=RowParallelLinear, - ), + submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear), ), mlp_bda=get_bias_dropout_add, sharded_state_dict_keys_map={ diff --git a/megatron/core/models/bert/bert_lm_head.py b/megatron/core/models/bert/bert_lm_head.py index ff0411dc59..fd26ebd16f 100644 --- a/megatron/core/models/bert/bert_lm_head.py +++ b/megatron/core/models/bert/bert_lm_head.py @@ -30,11 +30,7 @@ class BertLMHead(MegatronModule): config (TransformerConfig): TransformerConfig object """ - def __init__( - self, - hidden_size: int, - config: TransformerConfig, - ): + def __init__(self, hidden_size: int, config: TransformerConfig): super().__init__(config=config) # TODO: Should switch this to TE ? 
@@ -46,9 +42,7 @@ def __init__( setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel) self.layer_norm = LNImpl( - config=config, - hidden_size=hidden_size, - eps=config.layernorm_epsilon, + config=config, hidden_size=hidden_size, eps=config.layernorm_epsilon ) self.gelu = torch.nn.functional.gelu diff --git a/megatron/core/models/bert/bert_model.py b/megatron/core/models/bert/bert_model.py index eb94ebbb9f..0b571ca68d 100644 --- a/megatron/core/models/bert/bert_model.py +++ b/megatron/core/models/bert/bert_model.py @@ -122,10 +122,7 @@ def __init__( # Output if post_process: # TODO: Make sure you are passing in the mpu_vocab_size properly - self.lm_head = BertLMHead( - config.hidden_size, - config, - ) + self.lm_head = BertLMHead(config.hidden_size, config) self.output_layer = tensor_parallel.ColumnParallelLinear( config.hidden_size, diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index 207706d0be..0a4e5bf6de 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -223,10 +223,7 @@ def apply_rotary_pos_emb_thd( def apply_rotary_pos_emb( - t: Tensor, - freqs: Tensor, - config: TransformerConfig, - cu_seqlens: Optional[Tensor] = None, + t: Tensor, freqs: Tensor, config: TransformerConfig, cu_seqlens: Optional[Tensor] = None ): """ Reroute to the appropriate apply_rotary_pos_emb function depending on diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index 91224bf6b3..8fcfc424e6 100755 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -24,8 +24,7 @@ mixer=ModuleSpec( module=MambaMixer, submodules=MambaMixerSubmodules( - in_proj=TELayerNormColumnParallelLinear, - out_proj=TERowParallelLinear, + in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear ), ), mamba_bda=get_bias_dropout_add, @@ -58,8 +57,7 @@ mlp=ModuleSpec( module=MLP, submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, - linear_fc2=TERowParallelLinear, + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear ), ), mlp_bda=get_bias_dropout_add, diff --git a/megatron/core/models/retro/base_attention.py b/megatron/core/models/retro/base_attention.py index 741f712b72..ee8656d96a 100644 --- a/megatron/core/models/retro/base_attention.py +++ b/megatron/core/models/retro/base_attention.py @@ -9,7 +9,6 @@ class BaseRetroCrossAttention(MegatronModule): - """Base class for Retro cross attention, for both encoder & decoder layers. This class collects the retro arguments below (i.e., num neighbors, chunk diff --git a/megatron/core/models/retro/config.py b/megatron/core/models/retro/config.py index b9a5eb9648..3e3d0b538a 100644 --- a/megatron/core/models/retro/config.py +++ b/megatron/core/models/retro/config.py @@ -14,7 +14,7 @@ @dataclass class RetroConfig(TransformerConfig): - """Configuration object for Retro models. """ + """Configuration object for Retro models.""" # Retro. 
retro_project_dir: str = None diff --git a/megatron/core/models/retro/decoder_attention.py b/megatron/core/models/retro/decoder_attention.py index f459163ccc..6b7a04d884 100644 --- a/megatron/core/models/retro/decoder_attention.py +++ b/megatron/core/models/retro/decoder_attention.py @@ -22,7 +22,6 @@ class RetroDecoderCrossAttention(BaseRetroCrossAttention): - """Retro decoder's chunked cross attention operator. See this paper for more details: https://arxiv.org/abs/2112.04426. @@ -69,7 +68,7 @@ def __init__( if encoder_block_spec: self.encoder = TransformerBlock( - config=config, spec=encoder_block_spec, pre_process=True, post_process=False, + config=config, spec=encoder_block_spec, pre_process=True, post_process=False ) # self._encoder_key = 'encoder' # ... necessary? else: @@ -124,7 +123,7 @@ def forward( # Pad partial chunk with zeros. first_chunk = torch.nn.functional.pad( - first_chunk, (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), 'constant', 0, + first_chunk, (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), 'constant', 0 ) # Concatenate padded chunk with remaining chunks. @@ -169,7 +168,7 @@ def forward( # Pad attending tokens to sequence length. padded_chunks = torch.nn.functional.pad( - attending_chunks, (0, 0, 0, 0, 0, self.retro_chunk_length - 1), 'constant', 0, + attending_chunks, (0, 0, 0, 0, 0, self.retro_chunk_length - 1), 'constant', 0 ) # Permute attending chunks. @@ -210,7 +209,6 @@ def forward( class RetroDecoderBiasDropoutAdd(MegatronModule): - """Retro decoder's bias-dropout-add operator. This operator takes care of reshaping and permuting the output from the @@ -220,9 +218,7 @@ class RetroDecoderBiasDropoutAdd(MegatronModule): config (RetroConfig): Retro config. """ - def __init__( - self, config: RetroConfig, - ): + def __init__(self, config: RetroConfig): super().__init__(config=config) self.retro_chunk_length = config.retro_chunk_length @@ -282,7 +278,7 @@ def _forward( ) # Prepend zeros for non-attending tokens. 
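Several of the Retro attention hunks above pass a six-element pad tuple to `torch.nn.functional.pad`, which reads dimensions from the last one backwards. A short sketch with arbitrary shapes may make the convention easier to follow, including the prepend case used in the hunk just below:

```python
import torch
import torch.nn.functional as F

retro_chunk_length = 8        # illustrative value, not taken from a real config
ns, bs, d = 21, 2, 16         # sequence length, batch size, hidden size
x = torch.randn(ns, bs, d)

# A partial leading chunk, as in the hunk above: pad it up to the chunk length.
first_ns = ns % retro_chunk_length                    # 5 leftover positions
first_chunk = x[:first_ns]                            # [5, bs, d]

# F.pad reads the pad tuple from the last dimension backwards:
# (d_left, d_right, bs_left, bs_right, seq_left, seq_right).
first_chunk = F.pad(first_chunk, (0, 0, 0, 0, 0, retro_chunk_length - first_ns), 'constant', 0)
assert first_chunk.shape == (retro_chunk_length, bs, d)

# Prepending zeros along the sequence dimension uses the "seq_left" slot instead,
# which is what the "prepend zeros for non-attending tokens" code below does.
pad = 3
y = F.pad(x, (0, 0, 0, 0, pad, 0), 'constant', 0)
assert y.shape == (ns + pad, bs, d)
```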
- x = torch.nn.functional.pad(x, (0, 0, 0, 0, pad, 0), 'constant', 0,)[ + x = torch.nn.functional.pad(x, (0, 0, 0, 0, pad, 0), 'constant', 0)[ :ns ] # [ ns, bs, d ] diff --git a/megatron/core/models/retro/decoder_spec.py b/megatron/core/models/retro/decoder_spec.py index 0c16ccc8cb..d9cc69eacd 100644 --- a/megatron/core/models/retro/decoder_spec.py +++ b/megatron/core/models/retro/decoder_spec.py @@ -73,9 +73,7 @@ def get_retro_decoder_layer_te_spec( spec.submodules.pre_cross_attn_layernorm = TENorm spec.submodules.cross_attention = ModuleSpec( module=RetroDecoderCrossAttention, - params={ - "encoder_block_spec": encoder_block_spec, - }, + params={"encoder_block_spec": encoder_block_spec}, submodules=CrossAttentionSubmodules( linear_q=TEColumnParallelLinear, linear_kv=TEColumnParallelLinear, @@ -108,9 +106,7 @@ def get_retro_decoder_layer_local_spec( spec.submodules.pre_cross_attn_layernorm = LNImpl spec.submodules.cross_attention = ModuleSpec( module=RetroDecoderCrossAttention, - params={ - "encoder_block_spec": encoder_block_spec, - }, + params={"encoder_block_spec": encoder_block_spec}, submodules=CrossAttentionSubmodules( linear_q=ColumnParallelLinear, linear_kv=ColumnParallelLinear, diff --git a/megatron/core/models/retro/encoder_attention.py b/megatron/core/models/retro/encoder_attention.py index a2226c08da..76625abe33 100644 --- a/megatron/core/models/retro/encoder_attention.py +++ b/megatron/core/models/retro/encoder_attention.py @@ -17,7 +17,6 @@ class RetroEncoderCrossAttention(BaseRetroCrossAttention): - """Retro encoder's cross attention operator. See this paper for more details: https://arxiv.org/abs/2112.04426. @@ -96,14 +95,13 @@ def forward( residual = chunked_output # Collect tensors. - attention_output_tuples.append((attention_output, attention_bias, residual,)) + attention_output_tuples.append((attention_output, attention_bias, residual)) # Output. (List[Tuple[( [ r, bs*l, d ], [ d ] )]]) return attention_output_tuples class RetroEncoderBiasDropoutAdd(MegatronModule): - """Retro encoder's bias-dropout-add operator. This operator applies bias-dropout-add individually on each neighboring @@ -113,9 +111,7 @@ class RetroEncoderBiasDropoutAdd(MegatronModule): config (RetroConfig): Retro config. """ - def __init__( - self, config: RetroConfig, - ): + def __init__(self, config: RetroConfig): super().__init__(config=config) self.retro_num_neighbors = config.retro_num_neighbors @@ -186,7 +182,6 @@ def forward(self, training: bool, fused: bool) -> partial: class RetroEncoderLayerNorm(MegatronModule): - """Retro encoder's layernorm operator. This operator applies layernorm individually on each neighboring chunk that @@ -198,9 +193,7 @@ class RetroEncoderLayerNorm(MegatronModule): submodules (Type): Layer norm class. (Named 'submodules' to fit external interface.) """ - def __init__( - self, config: RetroConfig, submodules: Type, **kwargs: dict, - ): + def __init__(self, config: RetroConfig, submodules: Type, **kwargs: dict): super().__init__(config=config) norm_class = submodules self.norm = norm_class(config=config, **kwargs) @@ -211,7 +204,7 @@ def forward(self, input: Tensor) -> Tensor: Args: input (Tensor): Input chunks, concatenated into a single tensor. - + Returns: Output of the layer norm. 
""" diff --git a/megatron/core/models/retro/encoder_spec.py b/megatron/core/models/retro/encoder_spec.py index ac0eb15598..777b5324d8 100644 --- a/megatron/core/models/retro/encoder_spec.py +++ b/megatron/core/models/retro/encoder_spec.py @@ -63,9 +63,7 @@ def get_retro_encoder_layer_te_spec() -> ModuleSpec: spec.submodules.pre_cross_attn_layernorm = TENorm spec.submodules.cross_attention = ModuleSpec( module=RetroEncoderCrossAttention, - params={ - "attn_mask_type": AttnMaskType.padding, - }, + params={"attn_mask_type": AttnMaskType.padding}, submodules=CrossAttentionSubmodules( linear_q=TEColumnParallelLinear, linear_kv=TEColumnParallelLinear, @@ -74,16 +72,10 @@ def get_retro_encoder_layer_te_spec() -> ModuleSpec: ), ) spec.submodules.cross_attn_bda = ModuleSpec(module=RetroEncoderBiasDropoutAdd) - spec.submodules.pre_mlp_layernorm = ModuleSpec( - module=RetroEncoderLayerNorm, - submodules=TENorm, - ) + spec.submodules.pre_mlp_layernorm = ModuleSpec(module=RetroEncoderLayerNorm, submodules=TENorm) spec.submodules.mlp = ModuleSpec( module=MLP, - submodules=MLPSubmodules( - linear_fc1=TEColumnParallelLinear, - linear_fc2=TERowParallelLinear, - ), + submodules=MLPSubmodules(linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear), ) return spec @@ -103,9 +95,7 @@ def get_retro_encoder_layer_local_spec() -> ModuleSpec: spec.submodules.pre_cross_attn_layernorm = LNImpl spec.submodules.cross_attention = ModuleSpec( module=RetroEncoderCrossAttention, - params={ - "attn_mask_type": AttnMaskType.padding, - }, + params={"attn_mask_type": AttnMaskType.padding}, submodules=CrossAttentionSubmodules( linear_q=ColumnParallelLinear, linear_kv=ColumnParallelLinear, @@ -114,19 +104,13 @@ def get_retro_encoder_layer_local_spec() -> ModuleSpec: ), ) spec.submodules.cross_attn_bda = ModuleSpec(module=RetroEncoderBiasDropoutAdd) - spec.submodules.pre_mlp_layernorm = ModuleSpec( - module=RetroEncoderLayerNorm, - submodules=LNImpl, - ) + spec.submodules.pre_mlp_layernorm = ModuleSpec(module=RetroEncoderLayerNorm, submodules=LNImpl) spec.submodules.mlp = ModuleSpec( module=MLP, - submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, - linear_fc2=RowParallelLinear, - ), + submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear), ) spec.submodules.sharded_state_dict_keys_map = { - 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', + 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_' } # pre_mlp_layernorm doesn't need remapping return spec @@ -168,9 +152,7 @@ def get_retro_encoder_block_spec( spec.submodules.self_attention.params["attn_mask_type"] = AttnMaskType.padding spec.submodules.self_attention.submodules.core_attention = ModuleSpec( module=TEDotProductAttention if use_transformer_engine else DotProductAttention, - params={ - "attention_dropout": config.retro_encoder_attention_dropout, - }, + params={"attention_dropout": config.retro_encoder_attention_dropout}, ) layer_specs = [] diff --git a/megatron/core/models/retro/model.py b/megatron/core/models/retro/model.py index 32c6d26a62..8142c91f7a 100644 --- a/megatron/core/models/retro/model.py +++ b/megatron/core/models/retro/model.py @@ -11,7 +11,6 @@ class RetroModel(GPTModel): - """Retro Model. 
A Retro model mostly re-uses the GPTModel interface, with the only difference @@ -79,7 +78,7 @@ def forward( decoder_input=decoder_input, labels=labels, inference_params=inference_params, - extra_block_kwargs={"context": context, "context_mask": context_mask,}, + extra_block_kwargs={"context": context, "context_mask": context_mask}, ) def sharded_state_dict( diff --git a/megatron/core/models/vision/multimodal_projector.py b/megatron/core/models/vision/multimodal_projector.py index a5363ac45d..18e62c68a5 100644 --- a/megatron/core/models/vision/multimodal_projector.py +++ b/megatron/core/models/vision/multimodal_projector.py @@ -61,9 +61,7 @@ def forward(self, hidden_states): # deallocate_output_tensor() throwing an error, so a viewless tensor is # created to prevent this. encoder_output = make_viewless_tensor( - inp=encoder_output, - requires_grad=True, - keep_graph=True, + inp=encoder_output, requires_grad=True, keep_graph=True ) return encoder_output diff --git a/megatron/core/models/vision/vit_layer_specs.py b/megatron/core/models/vision/vit_layer_specs.py index a879d25398..876c14dce4 100644 --- a/megatron/core/models/vision/vit_layer_specs.py +++ b/megatron/core/models/vision/vit_layer_specs.py @@ -80,9 +80,7 @@ def get_vit_layer_with_local_spec() -> ModuleSpec: # Helper function to get module spec for MLP/MoE -def _get_mlp_module_spec( - use_te: bool = True, -) -> ModuleSpec: +def _get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: # Dense MLP w/ or w/o TE modules. return ModuleSpec( module=MLP, diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index 04bffc8ff5..65f72ec8c8 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -247,12 +247,7 @@ def init_state_fn(opt): hysteresis=config.hysteresis, ) - optimizer_args = [ - optimizer, - config, - grad_scaler, - init_state_fn, - ] + optimizer_args = [optimizer, config, grad_scaler, init_state_fn] if config.use_distributed_optimizer: optimizer = DistributedOptimizer( *optimizer_args, @@ -266,11 +261,7 @@ def init_state_fn(opt): setattr(optimizer, 'model_parallel_group', model_parallel_group) else: # FP32 optimizer. - optimizer = FP32Optimizer( - optimizer, - config, - init_state_fn, - ) + optimizer = FP32Optimizer(optimizer, config, init_state_fn) setattr(optimizer, 'model_parallel_group', model_parallel_group) return optimizer diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index ee5551d616..8eee169c7b 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -168,9 +168,7 @@ def _build_model_gbuf_range(cls, param_and_grad_buffer: ParamAndGradBuffer, buck ) # Group into dict. - data = { - "param_map": param_range_map, - } + data = {"param_map": param_range_map} return data @@ -417,12 +415,7 @@ def __init__( HAVE_APEX_OR_TE ), f'Please install Apex or Transformer Engine to use DistributedOptimizer.' - super().__init__( - optimizer, - config, - grad_scaler, - init_state_fn, - ) + super().__init__(optimizer, config, grad_scaler, init_state_fn) assert isinstance( optimizer, Adam @@ -464,10 +457,9 @@ def __init__( self.model_param_gbuf_map = self._build_model_param_gbuf_map(self.gbuf_ranges) # Optimizer ranges. 
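The optimizer-construction hunks above thread a grad scaler built with a `hysteresis` setting into the mixed-precision optimizers. A rough sketch of what dynamic loss scaling with hysteresis means, using illustrative constants and a simplified update rule rather than the real scaler class:

```python
import torch


class DynamicLossScalerSketch:
    """Grow the scale after long streaks of finite grads; shrink it on repeated overflow."""

    def __init__(self, initial_scale=2.0**16, growth_interval=1000,
                 backoff_factor=0.5, growth_factor=2.0, hysteresis=2):
        self.scale = initial_scale
        self.growth_interval = growth_interval
        self.backoff_factor = backoff_factor
        self.growth_factor = growth_factor
        self.hysteresis = hysteresis
        self._growth_tracker = 0
        self._hysteresis_tracker = hysteresis

    def update(self, found_inf: bool) -> None:
        if found_inf:
            self._growth_tracker = 0
            self._hysteresis_tracker -= 1
            # Only back off after `hysteresis` consecutive overflowing steps.
            if self._hysteresis_tracker <= 0:
                self.scale = max(self.scale * self.backoff_factor, 1.0)
        else:
            self._growth_tracker += 1
            self._hysteresis_tracker = self.hysteresis
            if self._growth_tracker == self.growth_interval:
                self._growth_tracker = 0
                self.scale *= self.growth_factor


scaler = DynamicLossScalerSketch()
steps = [torch.randn(4), torch.tensor([float("inf"), 1.0, 2.0]), torch.tensor([float("nan")])]
for grad in steps:
    found_inf = not torch.isfinite(grad).all().item()
    scaler.update(found_inf)
    print(found_inf, scaler.scale)
```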
- ( - self.model_param_group_index_map, - self.opt_group_ranges, - ) = self._build_optimizer_group_ranges(self.optimizer.param_groups, self.gbuf_ranges) + (self.model_param_group_index_map, self.opt_group_ranges) = ( + self._build_optimizer_group_ranges(self.optimizer.param_groups, self.gbuf_ranges) + ) # Allocate main param shards. ( @@ -626,10 +618,7 @@ def load_state_dict(self, state_dict): # list. inner_state_dict = self.optimizer.state_dict() state_dict_param_groups = [ - { - **group, - "params": list(inner_state_dict["param_groups"][idx]["params"]), - } + {**group, "params": list(inner_state_dict["param_groups"][idx]["params"])} for idx, group in enumerate(state_dict["optimizer"]["param_groups"]) ] @@ -655,13 +644,7 @@ def load_state_dict(self, state_dict): ) state_dict_state.append( - ( - state_order, - { - "exp_avg": init_shard(), - "exp_avg_sq": init_shard(), - }, - ) + (state_order, {"exp_avg": init_shard(), "exp_avg_sq": init_shard()}) ) # Sort by state order (see method docstring for details). @@ -680,10 +663,7 @@ def load_state_dict(self, state_dict): # Optimizer. self.optimizer.load_state_dict( - { - "state": state_dict_state, - "param_groups": state_dict_param_groups, - } + {"state": state_dict_state, "param_groups": state_dict_param_groups} ) # Grad scaler. @@ -776,9 +756,7 @@ def get_parameter_state_dp_zero(self): ) # Collect param states. - state = { - "buckets_coalesced": True, - } + state = {"buckets_coalesced": True} for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges): # Iterate grad buffers (by data type). @@ -822,10 +800,7 @@ def get_parameter_state_dp_zero(self): main_param = self.optimizer.param_groups[group_index]["params"][group_order] optim_state = self.optimizer.state[main_param] - tensors = { - "param": main_param, - **optim_state, - } + tensors = {"param": main_param, **optim_state} # Copy states into contiguous shard. 
gbuf_local_start = param_range_map["gbuf_local"].start @@ -1012,9 +987,7 @@ def sharded_param_state_fs_bucket_space( if next_param_start != cur_param_end: pad_tensors = { k: torch.empty( - next_param_start - cur_param_end, - dtype=v.dtype, - device=v.device, + next_param_start - cur_param_end, dtype=v.dtype, device=v.device ) for k, v in bucket_state[i].items() if isinstance(v, torch.Tensor) @@ -1112,10 +1085,7 @@ def sharded_param_state_fs_model_space( main_param = self.optimizer.param_groups[group_index]["params"][group_order] optim_state = self.optimizer.state[main_param] - tensors = { - "fp32_param": main_param, - **optim_state, - } + tensors = {"fp32_param": main_param, **optim_state} # Match optimizer parameter with model ShardedTensor (or ShardedTensorFactory) try: sharded_metadata = param_to_sharded_metadata[model_param] @@ -1188,10 +1158,7 @@ def load_parameter_state_from_fs_bucket_space(self, state_dict): main_param = self.optimizer.param_groups[group_index]["params"][group_order] optim_state = self.optimizer.state[main_param] - dst_tensors = { - "param": main_param, - **optim_state, - } + dst_tensors = {"param": main_param, **optim_state} for key in dst_tensors: dst_tensors[key].copy_(src_tensors[key]) @@ -1211,10 +1178,7 @@ def load_parameter_state_from_fs_model_space(self, state_dict): optim_state = self.optimizer.state[main_param] src_tensors = state_dict[param_idx] - dst_tensors = { - "fp32_param": main_param, - **optim_state, - } + dst_tensors = {"fp32_param": main_param, **optim_state} for key in dst_tensors: dst_tensors[key].copy_(src_tensors[key]) @@ -1561,10 +1525,7 @@ def _dispatch_gather_model_params(self, all_gather_handle_index: int, force_sync ] assert all_gather_handle_index < len(self.all_gather_handles) all_gather_handle = torch.distributed._all_gather_base( - pbuf, - pbuf_views[data_parallel_rank], - group=data_parallel_group, - async_op=async_op, + pbuf, pbuf_views[data_parallel_rank], group=data_parallel_group, async_op=async_op ) self.all_gather_handles[all_gather_handle_index] = all_gather_handle assert self.all_gather_handle_index_to_bucket_index_map[all_gather_handle_index] == ( diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py index 3d6142d207..2a48c12d37 100644 --- a/megatron/core/optimizer/optimizer.py +++ b/megatron/core/optimizer/optimizer.py @@ -156,8 +156,7 @@ def step_with_ready_grads(self) -> bool: def get_grad_norm(self): grads_for_norm = self.get_main_grads_for_grad_norm() total_norm = get_grad_norm_fp32( - grads_for_norm, - model_parallel_group=self.get_model_parallel_group(), + grads_for_norm, model_parallel_group=self.get_model_parallel_group() ) return total_norm @@ -301,11 +300,7 @@ def __init__( if has_config_logger_enabled(config): log_config_to_disk(config, locals(), prefix=type(self).__name__) - super().__init__( - optimizer, - config, - init_state_fn, - ) + super().__init__(optimizer, config, init_state_fn) self.grad_scaler = grad_scaler # None grad scaler is only supported for bf16. @@ -477,12 +472,7 @@ def __init__( init_state_fn: Callable, ): - super().__init__( - optimizer, - config, - grad_scaler, - init_state_fn, - ) + super().__init__(optimizer, config, grad_scaler, init_state_fn) # Handle main parameters. 
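`get_grad_norm` above delegates to `get_grad_norm_fp32(grads_for_norm, model_parallel_group=...)`; the essential computation is a local sum of squares followed by an all-reduce over the model-parallel group. A minimal sketch under that assumption; it ignores the de-duplication of parameters replicated across ranks that a real implementation needs, and it skips the collective when `torch.distributed` is not initialized:

```python
from typing import Iterable, Optional

import torch
import torch.distributed as dist


def grad_norm_fp32_sketch(
    grads: Iterable[torch.Tensor],
    model_parallel_group: Optional[dist.ProcessGroup] = None,
) -> float:
    """L2 norm over gradients that may be sharded across model-parallel ranks."""
    # Each rank sums the squares of the gradient shards it owns.
    local_sq = torch.zeros(1, dtype=torch.float32)
    for grad in grads:
        local_sq += grad.detach().float().pow(2).sum()

    # Combine the partial sums across the model-parallel group, then take the root.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(local_sq, op=dist.ReduceOp.SUM, group=model_parallel_group)
    return local_sq.sqrt().item()


# Single-process usage; no collective is issued when torch.distributed is not initialized.
params = [torch.randn(10, requires_grad=True) for _ in range(3)]
loss = sum(p.square().sum() for p in params)
loss.backward()
print(grad_norm_fp32_sketch(p.grad for p in params))
```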
@@ -713,19 +703,12 @@ class FP32Optimizer(MegatronOptimizer): """ def __init__( - self, - optimizer: torch.optim.Optimizer, - config: OptimizerConfig, - init_state_fn: Callable, + self, optimizer: torch.optim.Optimizer, config: OptimizerConfig, init_state_fn: Callable ): if has_config_logger_enabled(config): log_config_to_disk(config, locals(), prefix=type(self).__name__) - super(FP32Optimizer, self).__init__( - optimizer, - config, - init_state_fn, - ) + super(FP32Optimizer, self).__init__(optimizer, config, init_state_fn) self._scale = torch.tensor([1.0], dtype=torch.float, device='cuda') diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index d271fab225..19c19ff5a1 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -118,9 +118,7 @@ def get_nccl_options(pg_name, nccl_comm_cfgs): def generate_masked_orthogonal_rank_groups( - world_size: int, - parallel_size: List[int], - mask: List[bool], + world_size: int, parallel_size: List[int], mask: List[bool] ) -> List[List[int]]: """Generate orthogonal parallel groups based on the parallel size and mask. @@ -748,9 +746,7 @@ def generator_wrapper(group_type, **kwargs): embedding_ranks = get_embedding_ranks(ranks) group = torch.distributed.new_group( - embedding_ranks, - timeout=timeout, - pg_options=get_nccl_options('embd', nccl_comm_cfgs), + embedding_ranks, timeout=timeout, pg_options=get_nccl_options('embd', nccl_comm_cfgs) ) if rank in embedding_ranks: _EMBEDDING_GROUP = group @@ -871,10 +867,7 @@ def is_unitialized() -> bool: Deprecated. Use is_initialized instead. """ - warnings.warn( - "is_unitialized is deprecated, use is_initialized instead", - DeprecationWarning, - ) + warnings.warn("is_unitialized is deprecated, use is_initialized instead", DeprecationWarning) return not is_initialized() diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py index 137929a13e..3e33e7c2f8 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -131,34 +131,22 @@ def _batched_p2p_ops( ops = [] if tensor_send_prev is not None: send_prev_op = torch.distributed.P2POp( - torch.distributed.isend, - tensor_send_prev, - prev_pipeline_rank, - group, + torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group ) ops.append(send_prev_op) if tensor_recv_prev is not None: recv_prev_op = torch.distributed.P2POp( - torch.distributed.irecv, - tensor_recv_prev, - prev_pipeline_rank, - group, + torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group ) ops.append(recv_prev_op) if tensor_send_next is not None: send_next_op = torch.distributed.P2POp( - torch.distributed.isend, - tensor_send_next, - next_pipeline_rank, - group, + torch.distributed.isend, tensor_send_next, next_pipeline_rank, group ) ops.append(send_next_op) if tensor_recv_next is not None: recv_next_op = torch.distributed.P2POp( - torch.distributed.irecv, - tensor_recv_next, - next_pipeline_rank, - group, + torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group ) ops.append(recv_next_op) if len(ops) > 0: @@ -193,66 +181,50 @@ def _p2p_ops( if get_pipeline_model_parallel_rank() % 2 == 0: if tensor_send_next is not None: send_next_req = torch.distributed.isend( - tensor=tensor_send_next, - dst=next_pipeline_rank, - group=even_send_odd_recv_group, + tensor=tensor_send_next, dst=next_pipeline_rank, group=even_send_odd_recv_group ) reqs.append(send_next_req) if tensor_recv_prev is 
not None: recv_prev_req = torch.distributed.irecv( - tensor=tensor_recv_prev, - src=prev_pipeline_rank, - group=even_recv_odd_send_group, + tensor=tensor_recv_prev, src=prev_pipeline_rank, group=even_recv_odd_send_group ) reqs.append(recv_prev_req) if tensor_send_prev is not None: send_prev_req = torch.distributed.isend( - tensor=tensor_send_prev, - dst=prev_pipeline_rank, - group=even_send_odd_recv_group, + tensor=tensor_send_prev, dst=prev_pipeline_rank, group=even_send_odd_recv_group ) reqs.append(send_prev_req) if tensor_recv_next is not None: recv_next_req = torch.distributed.irecv( - tensor=tensor_recv_next, - src=next_pipeline_rank, - group=even_recv_odd_send_group, + tensor=tensor_recv_next, src=next_pipeline_rank, group=even_recv_odd_send_group ) reqs.append(recv_next_req) else: if tensor_recv_prev is not None: recv_prev_req = torch.distributed.irecv( - tensor=tensor_recv_prev, - src=prev_pipeline_rank, - group=even_send_odd_recv_group, + tensor=tensor_recv_prev, src=prev_pipeline_rank, group=even_send_odd_recv_group ) reqs.append(recv_prev_req) if tensor_send_next is not None: send_next_req = torch.distributed.isend( - tensor=tensor_send_next, - dst=next_pipeline_rank, - group=even_recv_odd_send_group, + tensor=tensor_send_next, dst=next_pipeline_rank, group=even_recv_odd_send_group ) reqs.append(send_next_req) if tensor_recv_next is not None: recv_next_req = torch.distributed.irecv( - tensor=tensor_recv_next, - src=next_pipeline_rank, - group=even_send_odd_recv_group, + tensor=tensor_recv_next, src=next_pipeline_rank, group=even_send_odd_recv_group ) reqs.append(recv_next_req) if tensor_send_prev is not None: send_prev_req = torch.distributed.isend( - tensor=tensor_send_prev, - dst=prev_pipeline_rank, - group=even_recv_odd_send_group, + tensor=tensor_send_prev, dst=prev_pipeline_rank, group=even_recv_odd_send_group ) reqs.append(send_prev_req) return reqs diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 432420f63e..b7669ccb45 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -121,11 +121,7 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): return assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ assert out._base is None, "counter-productive to free a view of another tensor." - out.data = torch.empty( - (1,), - device=out.device, - dtype=out.dtype, - ) + out.data = torch.empty((1,), device=out.device, dtype=out.dtype) def custom_backward(output, grad_output): @@ -146,10 +142,7 @@ def custom_backward(output, grad_output): # Handle scalar output if grad_output is None: assert output.numel() == 1, "implicit grad requires scalar output." 
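The `_batched_p2p_ops` and `_p2p_ops` hunks above are pure argument reflow, but the pattern they compress, building `torch.distributed.P2POp` objects and launching them together with `batch_isend_irecv`, stands on its own. A hedged two-rank sketch; it assumes a `torchrun --nproc_per_node=2` launch with the gloo backend, and the payload is arbitrary:

```python
# Launch with: torchrun --nproc_per_node=2 p2p_sketch.py
import torch
import torch.distributed as dist


def exchange_with_neighbors_sketch(tensor_send: torch.Tensor) -> torch.Tensor:
    """Send to the next rank and receive from the previous rank in a ring."""
    rank, world = dist.get_rank(), dist.get_world_size()
    next_rank, prev_rank = (rank + 1) % world, (rank - 1) % world

    tensor_recv = torch.empty_like(tensor_send)
    ops = [
        dist.P2POp(dist.isend, tensor_send, next_rank),
        dist.P2POp(dist.irecv, tensor_recv, prev_rank),
    ]
    # Posting all ops in one batch avoids the ordering deadlocks that the
    # even/odd send-first logic in the unbatched path above guards against.
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return tensor_recv


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    payload = torch.full((4,), float(dist.get_rank()))
    print(dist.get_rank(), "received", exchange_with_neighbors_sketch(payload))
    dist.destroy_process_group()
```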
- grad_output = torch.ones_like( - output, - memory_format=torch.preserve_format, - ) + grad_output = torch.ones_like(output, memory_format=torch.preserve_format) # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ] Variable._execution_engine.run_backward( @@ -752,9 +745,7 @@ def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activation collect_non_loss_data, checkpoint_activations_microbatch, check_first_val_step( - first_val_step, - forward_only, - is_first_microbatch_for_model_chunk(microbatch_id), + first_val_step, forward_only, is_first_microbatch_for_model_chunk(microbatch_id) ), current_microbatch=current_microbatch, ) @@ -863,16 +854,15 @@ def backward_step_helper(microbatch_id): recv_next = True if parallel_state.is_pipeline_last_stage(ignore_virtual=True): recv_next = False - ( - input_tensor, - output_tensor_grad, - ) = p2p_communication.send_forward_backward_recv_forward_backward( - output_tensor, - input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - tensor_shape=tensor_shape, - config=config, + (input_tensor, output_tensor_grad) = ( + p2p_communication.send_forward_backward_recv_forward_backward( + output_tensor, + input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + ) ) output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) else: @@ -899,15 +889,14 @@ def backward_step_helper(microbatch_id): if parallel_state.is_pipeline_last_stage(ignore_virtual=True): recv_next = False - ( - output_tensor_grad, - bwd_wait_handles, - ) = p2p_communication.send_backward_recv_backward( - input_tensor_grad, - recv_next=recv_next, - tensor_shape=tensor_shape, - config=config, - overlap_p2p_comm=True, + (output_tensor_grad, bwd_wait_handles) = ( + p2p_communication.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) ) output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) @@ -1073,16 +1062,15 @@ def backward_step_helper(microbatch_id): recv_prev = False # Communicate tensors. 
- ( - input_tensor, - output_tensor_grad, - ) = p2p_communication.send_forward_backward_recv_forward_backward( - output_tensor, - input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - tensor_shape=tensor_shape, - config=config, + (input_tensor, output_tensor_grad) = ( + p2p_communication.send_forward_backward_recv_forward_backward( + output_tensor, + input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + ) ) deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py index ef444e8d2c..0bb9acce8d 100644 --- a/megatron/core/ssm/mamba_block.py +++ b/megatron/core/ssm/mamba_block.py @@ -146,12 +146,7 @@ def __init__( eps=self.config.layernorm_epsilon, ) - self.apply( - partial( - _init_weights, - n_layer=self.config.num_layers, - ) - ) + self.apply(partial(_init_weights, n_layer=self.config.num_layers)) def _select_layers_for_pipeline_parallel(self, layer_type_list): pipeline_rank = parallel_state.get_pipeline_model_parallel_rank() diff --git a/megatron/core/tensor_parallel/cross_entropy.py b/megatron/core/tensor_parallel/cross_entropy.py index 45fa07515d..0066d126fd 100644 --- a/megatron/core/tensor_parallel/cross_entropy.py +++ b/megatron/core/tensor_parallel/cross_entropy.py @@ -80,8 +80,7 @@ def calculate_cross_entropy_loss( @staticmethod def prepare_gradient_calculation_operands( - softmax: torch.Tensor, - target_mask: torch.Tensor, + softmax: torch.Tensor, target_mask: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # All the inputs have softmax as thier gradient. @@ -133,14 +132,10 @@ def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): world_size = get_tensor_model_parallel_world_size() vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size) - ( - target_mask, - masked_target_1d, - predicted_logits, - sum_exp_logits, - exp_logits, - ) = VocabParallelCrossEntropy.calculate_predicted_logits( - vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index + (target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits) = ( + VocabParallelCrossEntropy.calculate_predicted_logits( + vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index + ) ) # All reduce is needed to get the chunks from other GPUs. 
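The `VocabParallelCrossEntropy` hunks above only regroup a tuple assignment; the algorithm being reformatted, in which each rank owns a slice of the vocabulary and the max/sum statistics are all-reduced, is easy to lose in the diff. A single-process sketch of the math, with the cross-rank reductions left as comments:

```python
import torch


def vocab_parallel_cross_entropy_sketch(
    logits_partition: torch.Tensor,  # [tokens, vocab_per_rank] slice owned by this rank
    target: torch.Tensor,            # [tokens] global token ids
    vocab_start: int,
    vocab_end: int,
) -> torch.Tensor:
    """Cross entropy over one vocabulary shard; the cross-rank reductions are omitted.

    In the real implementation, logits_max, predicted_logits and sum_exp_logits are
    each all-reduced over the tensor-parallel group at the marked points.
    """
    # Subtract the (globally reduced) max for numerical stability.
    logits_max = logits_partition.max(dim=-1).values        # all_reduce(MAX) here
    logits = logits_partition - logits_max.unsqueeze(-1)

    # Mask targets that live on another rank's vocabulary shard.
    target_mask = (target < vocab_start) | (target >= vocab_end)
    masked_target = (target - vocab_start).clamp(min=0)
    masked_target[target_mask] = 0

    # Logit of the true class, zeroed where this rank does not own the target.
    predicted_logits = logits.gather(1, masked_target.unsqueeze(1)).squeeze(1)
    predicted_logits[target_mask] = 0.0                      # all_reduce(SUM) here

    sum_exp_logits = logits.exp().sum(dim=-1)                # all_reduce(SUM) here
    return torch.log(sum_exp_logits) - predicted_logits


# With a single "rank" owning the whole vocabulary, this matches torch directly.
logits = torch.randn(5, 11)
target = torch.randint(0, 11, (5,))
ours = vocab_parallel_cross_entropy_sketch(logits, target, 0, 11)
ref = torch.nn.functional.cross_entropy(logits, target, reduction="none")
print(torch.allclose(ours, ref, atol=1e-5))
```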
@@ -193,12 +188,9 @@ def backward(ctx, grad_output): softmax, target_mask, masked_target_1d = ctx.saved_tensors label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size - ( - grad_2d, - arange_1d, - softmax_update, - grad_input, - ) = VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask) + (grad_2d, arange_1d, softmax_update, grad_input) = ( + VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask) + ) if label_smoothing > 0: smoothing = label_smoothing * vocab_size / (vocab_size - 1) diff --git a/megatron/core/tensor_parallel/data.py b/megatron/core/tensor_parallel/data.py index 01dd90de51..c549f74d73 100644 --- a/megatron/core/tensor_parallel/data.py +++ b/megatron/core/tensor_parallel/data.py @@ -14,9 +14,10 @@ def _check_data_types(keys, data, target_dtype): """Check that all the keys have the same target data type.""" for key in keys: - assert data[key].dtype == target_dtype, ( - '{} has data type {} which ' - 'is different than {}'.format(key, data[key].dtype, target_dtype) + assert ( + data[key].dtype == target_dtype + ), '{} has data type {} which ' 'is different than {}'.format( + key, data[key].dtype, target_dtype ) diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index d644eb89ef..5707a0b529 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -179,11 +179,12 @@ def __init__( self.reduce_scatter_embeddings = reduce_scatter_embeddings self.tensor_model_parallel_size = get_tensor_model_parallel_world_size() # Divide the weight matrix along the vocaburaly dimension. - ( - self.vocab_start_index, - self.vocab_end_index, - ) = VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size + (self.vocab_start_index, self.vocab_end_index) = ( + VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, + get_tensor_model_parallel_rank(), + self.tensor_model_parallel_size, + ) ) self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index self.deterministic_mode = config.deterministic_mode @@ -276,13 +277,7 @@ class LinearWithFrozenWeight(torch.autograd.Function): @staticmethod @custom_fwd - def forward( - ctx, - input, - weight, - bias, - allreduce_dgrad, - ): + def forward(ctx, input, weight, bias, allreduce_dgrad): ctx.save_for_backward(weight) ctx.allreduce_dgrad = allreduce_dgrad output = torch.matmul(input, weight.t()) @@ -372,12 +367,7 @@ def linear_with_frozen_weight( ) allreduce_dgrad = async_grad_allreduce - args = [ - input, - weight, - bias, - allreduce_dgrad, - ] + args = [input, weight, bias, allreduce_dgrad] return LinearWithFrozenWeight.apply(*args) diff --git a/megatron/core/tensor_parallel/mappings.py b/megatron/core/tensor_parallel/mappings.py index 88e77541d1..3eed700ceb 100644 --- a/megatron/core/tensor_parallel/mappings.py +++ b/megatron/core/tensor_parallel/mappings.py @@ -368,7 +368,7 @@ def symbolic(graph, input_): @staticmethod def forward(ctx, input_): - return _gather_along_last_dim(input_,) + return _gather_along_last_dim(input_) @staticmethod def backward(ctx, grad_output): @@ -384,7 +384,7 @@ def symbolic(graph, input_): @staticmethod def forward(ctx, input_): - return _reduce_scatter_along_last_dim(input_,) + return _reduce_scatter_along_last_dim(input_) @staticmethod def backward(ctx, grad_output): @@ -514,7 +514,7 @@ def all_to_all_hp2sp(input_): Args: input_ (torch.Tensor): The 
input tensor which has been distributed along the hidden dimension. - + Returns: torch.Tensor: The output tensor with shape [num_tokens/TP, H]. """ diff --git a/megatron/core/tensor_parallel/utils.py b/megatron/core/tensor_parallel/utils.py index 53f0d60de0..d7c191b411 100644 --- a/megatron/core/tensor_parallel/utils.py +++ b/megatron/core/tensor_parallel/utils.py @@ -14,18 +14,18 @@ def split_tensor_along_last_dim( - tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False, + tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False ) -> List[torch.Tensor]: - """ Split a tensor along its last dimension. + """Split a tensor along its last dimension. - Args: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. + Args: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. - Returns: - A list of Tensors + Returns: + A list of Tensors """ # Get the size and dimension. last_dim = tensor.dim() - 1 @@ -40,17 +40,17 @@ def split_tensor_along_last_dim( def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): - """ Break a tensor into equal 1D chunks across tensor parallel ranks. + """Break a tensor into equal 1D chunks across tensor parallel ranks. - Returns a Tensor or View with this rank's portion of the data. + Returns a Tensor or View with this rank's portion of the data. - Args: - tensor: The tensor to split + Args: + tensor: The tensor to split - Keyword Args: - new_buffer (bool): If True, returns a new Tensor. - If False, returns a view into the existing Tensor. - Default is False + Keyword Args: + new_buffer (bool): If True, returns a new Tensor. + If False, returns a view into the existing Tensor. + Default is False """ partition_size = torch.numel(tensor) // parallel_state.get_tensor_model_parallel_world_size() @@ -70,13 +70,13 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): def gather_split_1d_tensor(tensor): - """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor - model parallel ranks. + """Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor + model parallel ranks. - Returns a new Tensor with the gathered data. + Returns a new Tensor with the gathered data. - Args: - tensor: A Tensor or view of this rank's portion of the data. + Args: + tensor: A Tensor or view of this rank's portion of the data. """ numel_gathered = torch.numel(tensor) * parallel_state.get_tensor_model_parallel_world_size() gathered = torch.empty( @@ -94,9 +94,9 @@ def gather_split_1d_tensor(tensor): class VocabUtility: - """ Split the vocabulary into `world_size` chunks and return the first - and last index of the vocabulary belonging to the `rank` - partition: Note that indices in [fist, last) + """Split the vocabulary into `world_size` chunks and return the first + and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last) """ diff --git a/megatron/core/timers.py b/megatron/core/timers.py index b61eb4ed22..e7070e37d8 100644 --- a/megatron/core/timers.py +++ b/megatron/core/timers.py @@ -110,8 +110,7 @@ def stop(self, barrier=False): self._started = False def reset(self): - """Reset timer. 
- """ + """Reset timer.""" # Don't reset _active_time self._elapsed = 0.0 self._started = False @@ -145,14 +144,13 @@ def active_time(self): class Timers: - """Class for a group of Timers. - """ + """Class for a group of Timers.""" def __init__(self, log_level, log_option): """Initialize group of timers. Args: - log_level (int): Log level to control what timers are enabled. + log_level (int): Log level to control what timers are enabled. log_option (str): Setting for logging statistics over ranks for all the timers. Allowed: ['max', 'minmax', 'all']. """ self._log_level = log_level @@ -351,7 +349,7 @@ def log( barrier: bool = False, ): """logs the timers passed in names to stdout. Example usage is to log average per step value for timer 'foo', - this function can be called with normalizer factor set to logging interval. + this function can be called with normalizer factor set to logging interval. Args: names (List[str]): Names of the timers to log. diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 96c19d0fca..43eacf03f9 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -149,14 +149,7 @@ def custom_forward(*inputs): attn_mask_type = self.attn_mask_type attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int) hidden_states = tensor_parallel.checkpoint( - custom_forward, - False, - query, - key, - value, - attention_mask, - rotary_pos_emb, - attn_mask_type, + custom_forward, False, query, key, value, attention_mask, rotary_pos_emb, attn_mask_type ) return hidden_states @@ -289,17 +282,9 @@ def forward( else: cu_seqlens_q = cu_seqlens_kv = None query = apply_rotary_pos_emb( - query, - q_pos_emb, - config=self.config, - cu_seqlens=cu_seqlens_q, - ) - key = apply_rotary_pos_emb( - key, - k_pos_emb, - config=self.config, - cu_seqlens=cu_seqlens_kv, + query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q ) + key = apply_rotary_pos_emb(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv) # TODO, can apply positional embedding to value_layer so it has # absolute positional embedding. 
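The attention hunks above apply `apply_rotary_pos_emb` to the query and key; at its core, rotary embedding is a position-dependent pairwise rotation of channels. A minimal sketch of that math in the rotate-half form, leaving out the fused kernels and packed-sequence (`cu_seqlens`) paths the real code handles:

```python
import torch


def rotary_freqs_sketch(seq_len: int, dim: int, base: float = 10000.0) -> torch.Tensor:
    """Per-position angles, shape [seq_len, dim], duplicated across the two half-channels."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))    # [dim / 2]
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)         # [seq_len, dim / 2]
    return torch.cat((angles, angles), dim=-1)                            # [seq_len, dim]


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_sketch(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """t: [seq, batch, heads, head_dim]; freqs: [seq, head_dim]."""
    cos = freqs.cos()[:, None, None, :]
    sin = freqs.sin()[:, None, None, :]
    return t * cos + rotate_half(t) * sin


seq, batch, heads, head_dim = 12, 2, 4, 8
query = torch.randn(seq, batch, heads, head_dim)
freqs = rotary_freqs_sketch(seq, head_dim)
rotated = apply_rotary_sketch(query, freqs)

# The rotation is norm-preserving, so attention magnitudes are unchanged.
print(torch.allclose(query.norm(dim=-1), rotated.norm(dim=-1), atol=1e-5))
```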
@@ -499,19 +484,11 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): if SplitAlongDim is not None: # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] - (query, key, value) = SplitAlongDim( - mixed_qkv, - 3, - split_arg_list, - ) + (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) else: # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] - (query, key, value) = torch.split( - mixed_qkv, - split_arg_list, - dim=3, - ) + (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) diff --git a/megatron/core/transformer/custom_layers/transformer_engine.py b/megatron/core/transformer/custom_layers/transformer_engine.py index 879547fc1b..4d73995bbd 100644 --- a/megatron/core/transformer/custom_layers/transformer_engine.py +++ b/megatron/core/transformer/custom_layers/transformer_engine.py @@ -39,9 +39,7 @@ def get_te_version_str(): def _get_extra_te_kwargs(config: TransformerConfig): - extra_transformer_engine_kwargs = { - "params_dtype": config.params_dtype, - } + extra_transformer_engine_kwargs = {"params_dtype": config.params_dtype} if _te_version >= packaging.version.Version("0.12.0"): if config.use_cpu_initialization: @@ -62,12 +60,7 @@ class TENorm: """ # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm? - def __new__( - cls, - config: TransformerConfig, - hidden_size: int, - eps: float = 1e-5, - ): + def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5): if config.normalization == "LayerNorm": instance = te.pytorch.LayerNorm( hidden_size=hidden_size, @@ -559,13 +552,7 @@ def forward( **packed_seq_kwargs, ) else: - core_attn_out = super().forward( - query, - key, - value, - attention_mask, - **packed_seq_kwargs, - ) + core_attn_out = super().forward(query, key, value, attention_mask, **packed_seq_kwargs) if self.config.apply_rope_fusion and qkv_format == 'bshd': return core_attn_out.transpose(0, 1) @@ -767,12 +754,7 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): """ tp_axis_map = {} for gemm_idx in range(self.num_gemms): - tp_axis_map.update( - { - f'{gemm_idx}.weight': 0, - f'{gemm_idx}.bias': 0, - } - ) + tp_axis_map.update({f'{gemm_idx}.weight': 0, f'{gemm_idx}.bias': 0}) return super()._sharded_state_dict_grouped( tp_axis_map, prefix, sharded_offsets, metadata ) diff --git a/megatron/core/transformer/dot_product_attention.py b/megatron/core/transformer/dot_product_attention.py index 967d0ce8d8..7c28c153bc 100644 --- a/megatron/core/transformer/dot_product_attention.py +++ b/megatron/core/transformer/dot_product_attention.py @@ -120,12 +120,7 @@ def forward( ) # [b, np, sq, sk] - output_size = ( - query.size(1), - query.size(2), - query.size(0), - key.size(0), - ) + output_size = (query.size(1), query.size(2), query.size(0), key.size(0)) # [sq, b, np, hn] -> [sq, b * np, hn] # This will be a simple view when doing normal attention, but in group query attention @@ -137,7 +132,7 @@ def forward( # preallocting input tensor: [b * np, sq, sk] matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor( - (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu", + (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu" ) # Raw attention scores. 
[b * np, sq, sk] @@ -176,12 +171,7 @@ def forward( # [sk, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] - output_size = ( - value.size(1), - value.size(2), - query.size(0), - value.size(3), - ) + output_size = (value.size(1), value.size(2), query.size(0), value.size(3)) # change view [sk, b * np, hn] value = value.view(value.size(0), output_size[0] * output_size[1], -1) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index e11adf9447..d19ff6a234 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -94,9 +94,7 @@ def glu(x): ) self.weight2 = Parameter( torch.empty( - fc2_input_size_per_partition, - self.config.hidden_size, - dtype=config.params_dtype, + fc2_input_size_per_partition, self.config.hidden_size, dtype=config.params_dtype ) ) if config.perform_initialization: diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index c0c10a2c58..da3bde82f5 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -270,9 +270,7 @@ def unpermute_with_padded_tokens( # Prepare a tensor of zeros with the desired output shape empty_tokens = torch.zeros( - restore_shape, - dtype=combined_output.dtype, - device=combined_output.device, + restore_shape, dtype=combined_output.dtype, device=combined_output.device ) # Scatter the combined tokens back to their original positions @@ -325,9 +323,7 @@ def topk_softmax_with_capacity( else: # TopK with capacity expert_capacity = get_capacity( - num_tokens=num_tokens * topk, - num_experts=num_experts, - capacity_factor=capacity_factor, + num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor ) # TopK selection, Maskout unused experts topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs) @@ -418,9 +414,7 @@ def reduce_aux_losses_tracker_across_ranks(): torch.distributed.all_reduce(values, group=tracker[name].get('reduce_group')) if tracker[name].get('avg_group') is not None: torch.distributed.all_reduce( - values, - group=tracker[name]['avg_group'], - op=torch.distributed.ReduceOp.AVG, + values, group=tracker[name]['avg_group'], op=torch.distributed.ReduceOp.AVG ) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index a98959b710..817bfc0bdb 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -40,10 +40,7 @@ def __init__(self, config: TransformerConfig) -> None: # Initialize the gate weights. self.weight = torch.nn.Parameter( - torch.empty( - (self.config.num_moe_experts, self.config.hidden_size), - dtype=torch.float32, - ) + torch.empty((self.config.num_moe_experts, self.config.hidden_size), dtype=torch.float32) ) if config.perform_initialization: if get_cuda_rng_tracker().is_initialized(): @@ -99,10 +96,7 @@ def set_layer_number(self, layer_number: int): class TopKRouter(Router): """Route each token to the top-k experts.""" - def __init__( - self, - config: TransformerConfig, - ) -> None: + def __init__(self, config: TransformerConfig) -> None: """Initialize the zero token dropping router. 
Args: @@ -228,10 +222,7 @@ def apply_z_loss(self, logits): z_loss = z_loss_func(logits, moe_z_loss_coeff) logits = MoEAuxLossAutoScaler.apply(logits, z_loss) save_to_aux_losses_tracker( - "z_loss", - z_loss / moe_z_loss_coeff, - self.layer_number, - self.config.num_layers, + "z_loss", z_loss / moe_z_loss_coeff, self.layer_number, self.config.num_layers ) return logits diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 377403a5d7..c76ca6541e 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -23,11 +23,7 @@ def __init__(self, config: TransformerConfig) -> None: self.config = config @abstractmethod - def token_permutation( - self, - tokens: torch.Tensor, - indices: torch.Tensor, - ): + def token_permutation(self, tokens: torch.Tensor, indices: torch.Tensor): """Dispatch tokens to experts. Args: @@ -41,10 +37,7 @@ def token_permutation( @abstractmethod def token_unpermutation( - self, - expert_output: torch.Tensor, - probs: torch.Tensor, - indices: torch.Tensor, + self, expert_output: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor ): """Restores the expert output to its original ordering. @@ -65,10 +58,7 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher): """ def __init__( - self, - num_local_experts: int, - local_expert_indices: List[int], - config: TransformerConfig, + self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig ) -> None: """ Initialize the zero token dropping router. @@ -163,8 +153,7 @@ def token_permutation( # The indices of local_indices that give its sorted order along dim 0. self.indices = torch.argsort(local_indices, dim=0) tokens_per_expert = torch.bincount( - local_indices.view(-1), - minlength=self.config.num_moe_experts, + local_indices.view(-1), minlength=self.config.num_moe_experts ) if self.num_local_experts < self.config.num_moe_experts: tokens_per_expert = tokens_per_expert[ @@ -179,16 +168,9 @@ def token_permutation( permuted_local_hidden_states = moe_gather.apply(local_hidden_states, self.indices) else: permuted_local_hidden_states = local_hidden_states - return ( - permuted_local_hidden_states, - tokens_per_expert, - ) + return (permuted_local_hidden_states, tokens_per_expert) - def token_unpermutation( - self, - hidden_states: torch.Tensor, - bias: torch.Tensor = None, - ): + def token_unpermutation(self, hidden_states: torch.Tensor, bias: torch.Tensor = None): """ Reverse process of `dispatch()` which permutes the ouput of local experts locallay and across expert parallel rank into the original order to @@ -299,10 +281,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): """ def __init__( - self, - num_local_experts: int, - local_expert_indices: List[int], - config: TransformerConfig, + self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig ) -> None: """ Initialize the AlltoAll token dispatcher. @@ -442,10 +421,7 @@ def preprocess(self, indices: torch.Tensor) -> torch.Tensor: return num_tokens_per_local_expert def token_permutation( - self, - hidden_states: torch.Tensor, - probs: torch.Tensor, - indices: torch.Tensor, + self, hidden_states: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Dispatch tokens to local experts using AlltoAll communication. 
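The token-dispatcher hunks above reformat calls built around `argsort` and `bincount`; the dispatch itself is: pick top-k experts per token, sort the token copies by expert id so each expert receives a contiguous block, run the experts, then undo the sort and combine the copies with the router probabilities. A self-contained sketch with identity experts and no capacity or expert-parallel handling:

```python
import torch

num_tokens, hidden, num_experts, topk = 10, 6, 4, 2
tokens = torch.randn(num_tokens, hidden)
router_logits = torch.randn(num_tokens, num_experts)

# Top-k routing: probabilities and expert indices per token.
probs, indices = torch.topk(router_logits.softmax(dim=-1), k=topk, dim=-1)   # [tokens, topk]

# Permute: replicate each token once per chosen expert, then sort by expert id
# so every expert sees a contiguous block of its tokens.
flat_indices = indices.reshape(-1)                    # [tokens * topk]
sort_order = torch.argsort(flat_indices)
permuted_tokens = tokens.repeat_interleave(topk, dim=0)[sort_order]
tokens_per_expert = torch.bincount(flat_indices, minlength=num_experts)

# ... each expert would process its contiguous slice of permuted_tokens here ...
expert_output = permuted_tokens                       # identity "experts" for the sketch

# Unpermute: undo the sort, then combine the top-k copies with the (unrenormalized) probs.
unpermuted = torch.empty_like(expert_output)
unpermuted[sort_order] = expert_output
combined = (unpermuted.reshape(num_tokens, topk, hidden) * probs.unsqueeze(-1)).sum(dim=1)
print(tokens_per_expert, combined.shape)              # tokens per expert, torch.Size([10, 6])
```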
@@ -522,9 +498,7 @@ def token_permutation( return global_input_tokens, tokens_per_expert def token_unpermutation( - self, - hidden_states: torch.Tensor, - bias: torch.Tensor = None, + self, hidden_states: torch.Tensor, bias: torch.Tensor = None ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Reverse the token permutation to restore the original order. @@ -551,8 +525,7 @@ def token_unpermutation( if self.num_local_experts > 1: if not self.drop_and_pad: hidden_states = unpermute( - hidden_states, - self.reversed_global_input_permutation_mapping, + hidden_states, self.reversed_global_input_permutation_mapping ) else: hidden_states = hidden_states.reshape( diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 8904e4b86f..1e90099a21 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -90,8 +90,7 @@ class TransformerBlockSubmodules: def _get_block_submodules( - config: TransformerConfig, - spec: Union[TransformerBlockSubmodules, ModuleSpec], + config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec] ) -> TransformerBlockSubmodules: # Transformer block submodules. @@ -107,8 +106,7 @@ def _get_block_submodules( elif issubclass(spec.module, BaseTransformerLayer): num_layers = get_num_layers_to_build(config) return TransformerBlockSubmodules( - layer_specs=[spec] * num_layers, - layer_norm=LayerNormImpl, + layer_specs=[spec] * num_layers, layer_norm=LayerNormImpl ) else: raise Exception(f"specialize for {spec.module.__name__}.") @@ -146,15 +144,14 @@ def __init__( self.checkpoint_core_attention = self.config.recompute_granularity == 'selective' if get_cpu_offload_context is not None: - ( - self.offload_context, - self.group_prefetch_offload_commit_async, - ) = get_cpu_offload_context( - self.config.cpu_offloading, - self.config.cpu_offloading_num_layers, - self.config.num_layers, - self.config.cpu_offloading_activations, - self.config.cpu_offloading_weights, + (self.offload_context, self.group_prefetch_offload_commit_async) = ( + get_cpu_offload_context( + self.config.cpu_offloading, + self.config.cpu_offloading_num_layers, + self.config.num_layers, + self.config.cpu_offloading_activations, + self.config.cpu_offloading_weights, + ) ) self.config._cpu_offloading_context = ( self.offload_context if self.config.cpu_offloading else None @@ -178,11 +175,7 @@ def _build_layers(self): # coeff = self.layer_number # self.norm_factor *= coeff def build_layer(layer_spec, layer_number): - return build_module( - layer_spec, - config=self.config, - layer_number=layer_number, - ) + return build_module(layer_spec, config=self.config, layer_number=layer_number) # offset is implicit in TransformerLayer self.layers = torch.nn.ModuleList( @@ -235,11 +228,7 @@ def _checkpointed_forward( def custom(start: int, end: int): def custom_forward( - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, + hidden_states, attention_mask, context, context_mask, rotary_pos_emb ): for index in range(start, end): layer = self._get_layer(index) @@ -310,11 +299,7 @@ def checkpoint_handler(forward_func): hidden_states, context = checkpoint_handler(custom(l, l + 1)) else: hidden_states, context = custom(l, l + 1)( - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, + hidden_states, attention_mask, context, context_mask, rotary_pos_emb ) else: raise ValueError("Invalid activation recompute method.") @@ -363,11 +348,7 @@ def forward( # 
likely redundant, since p2p_communication.py (likely originator) # already creates viewless tensors. That said, make_viewless_tensor() # is called here to be future-proof and corner-case-proof. - hidden_states = make_viewless_tensor( - inp=hidden_states, - requires_grad=True, - keep_graph=True, - ) + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) if self.config.sequence_parallel: rng_context = tensor_parallel.get_cuda_rng_tracker().fork() @@ -437,8 +418,7 @@ def forward( self.current_microbatch < len(self.cuda_graphs[l_no]) ) hidden_states = self.cuda_graphs[l_no][self.current_microbatch]( - hidden_states, - is_first_microbatch=(self.current_microbatch == 0), + hidden_states, is_first_microbatch=(self.current_microbatch == 0) ) if ( @@ -455,9 +435,7 @@ def forward( # deallocate_output_tensor() throwing an error, so a viewless tensor is # created to prevent this. hidden_states = make_viewless_tensor( - inp=hidden_states, - requires_grad=True, - keep_graph=True, + inp=hidden_states, requires_grad=True, keep_graph=True ) return hidden_states diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 631179ed08..703a291e83 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -36,7 +36,7 @@ class TransformerLayerSubmodules: class BaseTransformerLayer(ABC): - """ A common parent class for `TransformerLayer` like implementations. + """A common parent class for `TransformerLayer` like implementations. A dummy class that is subclassed by similar `TransformerLayer`s e.g. the `TransformerLayer` in this file and possibly other `TransformerLayer` @@ -82,7 +82,7 @@ def __init__( ## [Module 2: SelfAttention] self.self_attention = build_module( - submodules.self_attention, config=self.config, layer_number=layer_number, + submodules.self_attention, config=self.config, layer_number=layer_number ) ## [Module 3: BiasDropoutFusion] @@ -98,11 +98,11 @@ def __init__( ## [Module 5: CrossAttention] self.cross_attention = build_module( - submodules.cross_attention, config=self.config, layer_number=layer_number, + submodules.cross_attention, config=self.config, layer_number=layer_number ) ## [Module 6: BiasDropoutFusion] - self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config,) + self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config) ## [Module 7: Pre MLP] Optional Layernorm before MLP self.pre_mlp_layernorm = build_module( diff --git a/megatron/core/transformer/utils.py b/megatron/core/transformer/utils.py index 025f7c2b1e..4781b68d2a 100644 --- a/megatron/core/transformer/utils.py +++ b/megatron/core/transformer/utils.py @@ -97,12 +97,12 @@ def make_sharded_tensors_for_checkpoint( elif layer_name in tensor_parallel_layers_axis_map: tp_axis = tensor_parallel_layers_axis_map[layer_name] sharded_state_dict[layer_key] = make_tp_sharded_tensor_for_checkpoint( - tensor, layer_key, tp_axis, prepend_offsets=sharded_offsets, + tensor, layer_key, tp_axis, prepend_offsets=sharded_offsets ) else: sharded_state_dict[layer_key] = make_sharded_tensor_for_checkpoint( - tensor, layer_key, prepend_offsets=sharded_offsets, + tensor, layer_key, prepend_offsets=sharded_offsets ) return sharded_state_dict @@ -115,7 +115,7 @@ def make_sharded_object_for_checkpoint( replica_id: Union[None, int, Tuple[int, ...]] = None, **kwargs, ): - """ Helper for instantiating a non-sharded ShardedObject (replicated across TP and DP 
group). + """Helper for instantiating a non-sharded ShardedObject (replicated across TP and DP group). Args: obj (object): any object to be sharded @@ -138,7 +138,7 @@ def make_sharded_object_for_checkpoint( def _get_extra_state_offsets( sharded_offsets: Iterable[Tuple[int, int, int]] ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: - """ Turns ShardedTensor offsets into offsets suitable for ShardedObject. """ + """Turns ShardedTensor offsets into offsets suitable for ShardedObject.""" if sharded_offsets: sharded_offsets = sorted(sharded_offsets, key=itemgetter(0)) # sort by axis axis, extra_state_offset, extra_state_shape = zip(*sharded_offsets) @@ -183,6 +183,6 @@ def sharded_state_dict_default( else: module_sd = module.state_dict(prefix='', keep_vars=True) module_sharded_sd = make_sharded_tensors_for_checkpoint( - module_sd, prefix, {}, sharded_offsets, + module_sd, prefix, {}, sharded_offsets ) return module_sharded_sd diff --git a/megatron/core/utils.py b/megatron/core/utils.py index a777770617..062372d97d 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -111,12 +111,7 @@ def _kernel_make_viewless_tensor(inp, requires_grad): data, without linking the viewed tensor, referenced via the '._base' field. ''' - out = torch.empty( - (1,), - dtype=inp.dtype, - device=inp.device, - requires_grad=requires_grad, - ) + out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad) out.data = inp.data return out @@ -908,13 +903,7 @@ def report(self, total_flops: float = 0.0, log_interval: int = 0) -> bool: et_flops = apir_flops / self.amp # Estimated TFLOPs, not tracing backward o_dt = self._min_max( - ptime, - btime, - float(temp), - float(power), - float(util), - float(clock), - et_flops, + ptime, btime, float(temp), float(power), float(util), float(clock), et_flops ) if self.rank == 0 and o_dt is not None and o_dt.aflops is not None: now = f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]" diff --git a/pyproject.toml b/pyproject.toml index 934745ec68..c707686a83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ requires = [ [tool.isort] profile = "black" # black-compatible line_length = 100 # should match black parameters -py_version = 38 # python 3.8 as a target version +py_version = 310 # python 3.8 as a target version known_first_party = ["megatron"] # FIRSTPARTY section known_third_party = ["transformer_engine"] # THIRDPARTY section sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] diff --git a/tests/functional_tests/python_test_utils/common.py b/tests/functional_tests/python_test_utils/common.py index 3ce43f095f..3a9fd359a6 100644 --- a/tests/functional_tests/python_test_utils/common.py +++ b/tests/functional_tests/python_test_utils/common.py @@ -10,10 +10,7 @@ # Since we expect every step to be there when we do our comparisons, we explicitly # set the size guidance to 0 so that we load everything. It's okay given our tests # are small/short. 
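The test-utility hunk just below sets the `event_accumulator` size guidance to zero so that no events are dropped when loading logs. For context, this is roughly how such logs are read back; the log directory is a placeholder and the `tensorboard` package is assumed to be installed:

```python
from tensorboard.backend.event_processing import event_accumulator

LOGS_DIR = "/path/to/tensorboard/logs"  # placeholder, not a path from this repo

ea = event_accumulator.EventAccumulator(
    LOGS_DIR,
    # 0 means "keep everything": no reservoir downsampling of logged events.
    size_guidance={event_accumulator.TENSORS: 0, event_accumulator.SCALARS: 0},
)
ea.Reload()  # parse the event files on disk

for tag in ea.Tags()["scalars"]:
    values = [event.value for event in ea.Scalars(tag)]
    print(tag, len(values), values[:3])
```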
-SIZE_GUIDANCE = { - event_accumulator.TENSORS: 0, - event_accumulator.SCALARS: 0, -} +SIZE_GUIDANCE = {event_accumulator.TENSORS: 0, event_accumulator.SCALARS: 0} logger = logging.getLogger() diff --git a/tests/functional_tests/python_test_utils/get_test_results_from_tensorboard_logs.py b/tests/functional_tests/python_test_utils/get_test_results_from_tensorboard_logs.py index ba3d43f9c5..e93fd2046e 100644 --- a/tests/functional_tests/python_test_utils/get_test_results_from_tensorboard_logs.py +++ b/tests/functional_tests/python_test_utils/get_test_results_from_tensorboard_logs.py @@ -9,12 +9,7 @@ @click.command() -@click.option( - "--logs-dir", - required=True, - type=str, - help="Path to Tensorboard logs", -) +@click.option("--logs-dir", required=True, type=str, help="Path to Tensorboard logs") @click.option( "--output-path", required=False, diff --git a/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py b/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py index bf14f8ef75..f0375dfb3d 100644 --- a/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py +++ b/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py @@ -16,9 +16,7 @@ def collect_train_test_metrics(logs_dir, index): train_loss_list = read_tb_logs_as_list(logs_dir, index)["lm loss"] train_loss_list = [round(elem, 3) for elem in train_loss_list] - train_metrics = { - "lm loss": train_loss_list[0 : len(train_loss_list) : STEP_INTERVAL], - } + train_metrics = {"lm loss": train_loss_list[0 : len(train_loss_list) : STEP_INTERVAL]} str_train_metrics = str(train_metrics).replace("'", '"') print("\n ----------- The following are the metrics for ----------") print(f"\n {str_train_metrics}", flush=True) diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py index 1d3c586a5d..38a9977640 100644 --- a/tests/unit_tests/__init__.py +++ b/tests/unit_tests/__init__.py @@ -1,2 +1,3 @@ import torch._dynamo -torch._dynamo.config.suppress_errors = True \ No newline at end of file + +torch._dynamo.config.suppress_errors = True diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index fb5cfc3ba4..787dd48c7a 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -13,9 +13,10 @@ @pytest.fixture(scope="session") def tmp_path_dist_ckpt(tmp_path_factory) -> Path: - """ Common directory for saving the checkpoint. + """Common directory for saving the checkpoint. - Can't use pytest `tmp_path_factory` directly because directory must be shared between processes. """ + Can't use pytest `tmp_path_factory` directly because directory must be shared between processes. 
+ """ tmp_dir = tmp_path_factory.mktemp('ignored', numbered=False) tmp_dir = tmp_dir.parent.parent / 'tmp_dist_ckpt' diff --git a/tests/unit_tests/data/test_builder.py b/tests/unit_tests/data/test_builder.py index 8f149dcffb..7f4caaa0f6 100644 --- a/tests/unit_tests/data/test_builder.py +++ b/tests/unit_tests/data/test_builder.py @@ -110,11 +110,7 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]: config = BlendedMegatronDatasetConfig( random_seed=1234, sequence_length=_SEQUENCE_LENGTH, - blend_per_split=[ - blends[Split.train], - None, - None, - ], + blend_per_split=[blends[Split.train], None, None], ) try: datasets = BlendedMegatronDatasetBuilder( @@ -127,11 +123,7 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]: config = BlendedMegatronDatasetConfig( random_seed=1234, sequence_length=_SEQUENCE_LENGTH, - blend_per_split=[ - get_blend_from_list([paths[Split.train][0]]), - None, - None, - ], + blend_per_split=[get_blend_from_list([paths[Split.train][0]]), None, None], ) datasets = BlendedMegatronDatasetBuilder( TestDataset, [1000, None, None], lambda: True, config @@ -187,11 +179,7 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]: config = BlendedMegatronDatasetConfig( random_seed=1234, sequence_length=_SEQUENCE_LENGTH, - blend_per_split=[ - blends_unweighted[Split.train], - None, - None, - ], + blend_per_split=[blends_unweighted[Split.train], None, None], ) datasets = BlendedMegatronDatasetBuilder( TestDataset, [1000, None, None], lambda: True, config @@ -245,11 +233,7 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]: config = BlendedMegatronDatasetConfig( random_seed=1234, sequence_length=_SEQUENCE_LENGTH, - blend_per_split=[ - blends[Split.train], - blends[Split.valid], - blends[Split.test], - ], + blend_per_split=[blends[Split.train], blends[Split.valid], blends[Split.test]], ) datasets = BlendedMegatronDatasetBuilder( TestDataset, [100, 100, 100], lambda: True, config diff --git a/tests/unit_tests/data/test_gpt_dataset.py b/tests/unit_tests/data/test_gpt_dataset.py index 906a5728de..f10be883bf 100644 --- a/tests/unit_tests/data/test_gpt_dataset.py +++ b/tests/unit_tests/data/test_gpt_dataset.py @@ -96,7 +96,7 @@ def test_mock_gpt_dataset(): assert torch.all(sample['labels'][argmax + 1 :] == 0) assert not torch.any( sample['loss_mask'][ - torch.logical_and(sample['labels'] == tokenizer.eod, sample['labels'] == 0,) + torch.logical_and(sample['labels'] == tokenizer.eod, sample['labels'] == 0) ] ) diff --git a/tests/unit_tests/data/test_multimodal_dataset.py b/tests/unit_tests/data/test_multimodal_dataset.py index ef5430c2da..a9a30c02ec 100644 --- a/tests/unit_tests/data/test_multimodal_dataset.py +++ b/tests/unit_tests/data/test_multimodal_dataset.py @@ -25,7 +25,7 @@ def test_mock_multimodal_dataset(): torch.distributed.barrier() else: compile_helpers() - + config = MultimodalDatasetConfig( random_seed=1234, sequence_length=1024, diff --git a/tests/unit_tests/data/test_preprocess_data.py b/tests/unit_tests/data/test_preprocess_data.py index 8d35e4c5c0..0b460f51a9 100644 --- a/tests/unit_tests/data/test_preprocess_data.py +++ b/tests/unit_tests/data/test_preprocess_data.py @@ -82,14 +82,12 @@ def do_test_preprocess_data(temp_dir, extra_args=[]): dummy_jsonl(path_to_raws) # build the datasets - build_datasets( - path_to_raws, path_to_data, extra_args=extra_args, - ) + build_datasets(path_to_raws, path_to_data, extra_args=extra_args) # merge the datasets merge_datasets(path_to_data) - sys.argv = [sys.argv[0], "--input", None, 
"--output-prefix", None,] + extra_args + sys.argv = [sys.argv[0], "--input", None, "--output-prefix", None] + extra_args encoder = Encoder(build_args()) encoder.initializer() @@ -184,6 +182,7 @@ def gpt2_merge(odir): writer.write(requests.get(PRETRAINED_MERGES_ARCHIVE_MAP['gpt2']).content) return path + @pytest.mark.skip(reason="Tests are flaky and need to be debugged") def test_preprocess_data_gpt(): with tempfile.TemporaryDirectory() as temp_dir: @@ -214,6 +213,7 @@ def bert_vocab(odir): writer.write(requests.get(__HUGGINGFACE_BERT_BASE_UNCASED_VOCAB).content) return path + @pytest.mark.skip(reason="Tests are flaky and need to be debugged") def test_preprocess_data_bert(): with tempfile.TemporaryDirectory() as temp_dir: @@ -239,4 +239,4 @@ def test_preprocess_data_bert(): if __name__ == "__main__": test_preprocess_data_gpt() - test_preprocess_data_bert() \ No newline at end of file + test_preprocess_data_bert() diff --git a/tests/unit_tests/data/test_preprocess_mmdata.py b/tests/unit_tests/data/test_preprocess_mmdata.py index 8aab96e64a..d6ad4eddc7 100644 --- a/tests/unit_tests/data/test_preprocess_mmdata.py +++ b/tests/unit_tests/data/test_preprocess_mmdata.py @@ -74,9 +74,7 @@ def do_test_preprocess_mmdata(temp_dir, extra_args=[]): dummy_img(path_to_raws_txt, path_to_raws_img) # build the datasets - build_datasets( - path_to_raws_txt, path_to_raws_img, path_to_data, extra_args=extra_args, - ) + build_datasets(path_to_raws_txt, path_to_raws_img, path_to_data, extra_args=extra_args) # merge the datasets merge_datasets(path_to_data) diff --git a/tests/unit_tests/dist_checkpointing/__init__.py b/tests/unit_tests/dist_checkpointing/__init__.py index 3b4a7896d7..d6c2701891 100644 --- a/tests/unit_tests/dist_checkpointing/__init__.py +++ b/tests/unit_tests/dist_checkpointing/__init__.py @@ -3,15 +3,15 @@ from pathlib import Path from shutil import rmtree from tempfile import TemporaryDirectory -from typing import Union, Optional +from typing import Optional, Union -from tests.unit_tests.test_utilities import Utils from tests.unit_tests.dist_checkpointing.utils import ( - setup_model_and_optimizer, init_basic_mock_args, init_checkpointing_mock_args, initialize_gpt_model, + setup_model_and_optimizer, ) +from tests.unit_tests.test_utilities import Utils def empty_dir(path: Path): @@ -25,23 +25,23 @@ def empty_dir(path: Path): class TempNamedDir(TemporaryDirectory): - """ TemporaryDirectory with a fully named directory. Empties the dir if not empty. """ - def __init__(self, name: Union[str, Path], sync=True, - ignore_cleanup_errors=False) -> None: + """TemporaryDirectory with a fully named directory. 
Empties the dir if not empty.""" + + def __init__(self, name: Union[str, Path], sync=True, ignore_cleanup_errors=False) -> None: self.name = str(name) if Utils.rank == 0: os.makedirs(name, exist_ok=True) empty_dir(Path(name)) if sync: import torch + torch.distributed.barrier() else: os.makedirs(name, exist_ok=True) self._ignore_cleanup_errors = ignore_cleanup_errors self._finalizer = weakref.finalize( - self, self._cleanup, self.name, - warn_message="Implicitly cleaning up {!r}".format(self) + self, self._cleanup, self.name, warn_message="Implicitly cleaning up {!r}".format(self) ) self.sync = sync @@ -49,6 +49,7 @@ def cleanup(self, override_sync: Optional[bool] = None) -> None: sync = self.sync if override_sync is None else override_sync if sync: import torch + torch.distributed.barrier() if Utils.rank == 0: @@ -58,6 +59,7 @@ def __enter__(self): path = Path(super().__enter__()) if self.sync: import torch + torch.distributed.barrier() return path diff --git a/tests/unit_tests/dist_checkpointing/conftest.py b/tests/unit_tests/dist_checkpointing/conftest.py index 655550d632..fed9cdb482 100644 --- a/tests/unit_tests/dist_checkpointing/conftest.py +++ b/tests/unit_tests/dist_checkpointing/conftest.py @@ -18,4 +18,3 @@ def get_pyt_dist_save_sharded_strategy(): new=get_pyt_dist_save_sharded_strategy, ) as _fixture: yield _fixture - diff --git a/tests/unit_tests/dist_checkpointing/models/common.py b/tests/unit_tests/dist_checkpointing/models/common.py index 4159a2a90c..4b908ba3fc 100644 --- a/tests/unit_tests/dist_checkpointing/models/common.py +++ b/tests/unit_tests/dist_checkpointing/models/common.py @@ -3,34 +3,45 @@ import torch -from megatron.core.dist_checkpointing import save, load, load_plain_tensors from megatron.core import parallel_state +from megatron.core.dist_checkpointing import load, load_plain_tensors, save from megatron.core.dist_checkpointing.dict_utils import diff -from megatron.core.dist_checkpointing.serialization import \ - get_default_save_sharded_strategy, get_default_load_sharded_strategy -from megatron.core.dist_checkpointing.strategies.fully_parallel import \ - FullyParallelSaveStrategyWrapper, FullyParallelLoadStrategyWrapper +from megatron.core.dist_checkpointing.serialization import ( + get_default_load_sharded_strategy, + get_default_save_sharded_strategy, +) +from megatron.core.dist_checkpointing.strategies.fully_parallel import ( + FullyParallelLoadStrategyWrapper, + FullyParallelSaveStrategyWrapper, +) from megatron.core.dist_checkpointing.validation import StrictHandling from tests.unit_tests.dist_checkpointing import TempNamedDir from tests.unit_tests.test_utilities import Utils def common_test_simple_sharded_state_dict_save_load( - initialize_model_fn, tmp_path_dist_ckpt, src_layer_spec_fn, dst_layer_spec_fn): - """ Simple save and load sanity check, without any equality tests. 
""" + initialize_model_fn, tmp_path_dist_ckpt, src_layer_spec_fn, dst_layer_spec_fn +): + """Simple save and load sanity check, without any equality tests.""" tp = 2 pp = 4 Utils.initialize_model_parallel(tp, pp) - gpt_model = initialize_model_fn(1, src_layer_spec_fn, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp) + gpt_model = initialize_model_fn( + 1, src_layer_spec_fn, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp + ) with TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model') as ckpt_dir: # Save sharded_state_dict = gpt_model.sharded_state_dict() save(sharded_state_dict, ckpt_dir) # Load - gpt_model = initialize_model_fn(2, dst_layer_spec_fn, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp) + gpt_model = initialize_model_fn( + 2, dst_layer_spec_fn, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp + ) sharded_state_dict = gpt_model.sharded_state_dict() - state_dict, missing_keys, unexpected_keys = load(sharded_state_dict, ckpt_dir, strict=StrictHandling.RETURN_ALL) + state_dict, missing_keys, unexpected_keys = load( + sharded_state_dict, ckpt_dir, strict=StrictHandling.RETURN_ALL + ) # Potential mismatch is because of extra states which is ok assert all('_extra_state' in k for k in missing_keys) assert all('_extra_state' in k for k in unexpected_keys) @@ -38,21 +49,37 @@ def common_test_simple_sharded_state_dict_save_load( Utils.destroy_model_parallel() -def common_test_parallel_reconfiguration_e2e(initialize_model_fn, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, - src_layer_spec_fn, dst_layer_spec_fn, use_fpsl, - load_order="tp-dp-pp", store_order="tp-dp-pp"): - """ Test model saving and loading with different TP/PP """ - with TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_A') as ckpt_dir_A, \ - TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_B') as ckpt_dir_B: +def common_test_parallel_reconfiguration_e2e( + initialize_model_fn, + tmp_path_dist_ckpt, + src_tp_pp, + dest_tp_pp, + src_layer_spec_fn, + dst_layer_spec_fn, + use_fpsl, + load_order="tp-dp-pp", + store_order="tp-dp-pp", +): + """Test model saving and loading with different TP/PP""" + with TempNamedDir( + tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_A' + ) as ckpt_dir_A, TempNamedDir( + tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_B' + ) as ckpt_dir_B: # Save checkpoint A Utils.initialize_model_parallel(*src_tp_pp, order=load_order) - gpt_model_A = initialize_model_fn(1, src_layer_spec_fn, tensor_model_parallel_size=src_tp_pp[0], pipeline_model_parallel_size=src_tp_pp[1]) + gpt_model_A = initialize_model_fn( + 1, + src_layer_spec_fn, + tensor_model_parallel_size=src_tp_pp[0], + pipeline_model_parallel_size=src_tp_pp[1], + ) save_strategy = get_default_save_sharded_strategy() if use_fpsl: save_strategy = FullyParallelSaveStrategyWrapper( save_strategy, parallel_state.get_data_parallel_group(with_context_parallel=True), - True + True, ) save(gpt_model_A.sharded_state_dict(), ckpt_dir_A, save_strategy) regular_state_dict_A = gpt_model_A.state_dict() @@ -61,13 +88,23 @@ def common_test_parallel_reconfiguration_e2e(initialize_model_fn, tmp_path_dist_ # Load checkpoint A with different TP/PP and save as checkpoint B # No FPS this time, only FPL Utils.initialize_model_parallel(*dest_tp_pp, order=store_order) - gpt_model_B = initialize_model_fn(2, dst_layer_spec_fn, tensor_model_parallel_size=dest_tp_pp[0], pipeline_model_parallel_size=dest_tp_pp[1]) + gpt_model_B = initialize_model_fn( + 2, + 
dst_layer_spec_fn, + tensor_model_parallel_size=dest_tp_pp[0], + pipeline_model_parallel_size=dest_tp_pp[1], + ) if use_fpsl: load_strategy = get_default_load_sharded_strategy(ckpt_dir_A) load_strategy = FullyParallelLoadStrategyWrapper(load_strategy) else: load_strategy = None - state_dict, missing_keys, unexpected_keys = load(gpt_model_B.sharded_state_dict(), ckpt_dir_A, load_strategy, strict=StrictHandling.RETURN_ALL) + state_dict, missing_keys, unexpected_keys = load( + gpt_model_B.sharded_state_dict(), + ckpt_dir_A, + load_strategy, + strict=StrictHandling.RETURN_ALL, + ) # Potential mismatch is because of extra states which is ok assert all('_extra_state' in k for k in missing_keys) assert all('_extra_state' in k for k in unexpected_keys) @@ -84,10 +121,12 @@ def common_test_parallel_reconfiguration_e2e(initialize_model_fn, tmp_path_dist_ assert not any(map(bool, diffs)), diffs # Test both regular state dicts are equal, turning FP8 states to bytes first - regular_state_dict_A = {k: v for k, v in regular_state_dict_A.items() - if not k.endswith('_extra_state')} - regular_state_dict_B = {k: v for k, v in regular_state_dict_B.items() - if not k.endswith('_extra_state')} + regular_state_dict_A = { + k: v for k, v in regular_state_dict_A.items() if not k.endswith('_extra_state') + } + regular_state_dict_B = { + k: v for k, v in regular_state_dict_B.items() if not k.endswith('_extra_state') + } diffs = diff(regular_state_dict_A, regular_state_dict_B) assert not any(map(bool, diffs)), diffs Utils.destroy_model_parallel() @@ -97,11 +136,18 @@ def common_test_state_dict_comparison(initialize_model_fn, tmp_path_dist_ckpt): tp = 2 pp = 4 Utils.initialize_model_parallel(tp, pp) - with TempNamedDir(tmp_path_dist_ckpt / 'test_state_dict_comparison_A') as ckpt_dir_A, \ - TempNamedDir(tmp_path_dist_ckpt / 'test_state_dict_comparison_B') as ckpt_dir_B: - gpt_model_A = initialize_model_fn(1, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp) + with TempNamedDir( + tmp_path_dist_ckpt / 'test_state_dict_comparison_A' + ) as ckpt_dir_A, TempNamedDir( + tmp_path_dist_ckpt / 'test_state_dict_comparison_B' + ) as ckpt_dir_B: + gpt_model_A = initialize_model_fn( + 1, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp + ) save(gpt_model_A.sharded_state_dict(), ckpt_dir_A) - gpt_model_B = initialize_model_fn(2, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp) + gpt_model_B = initialize_model_fn( + 2, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp + ) save(gpt_model_B.sharded_state_dict(), ckpt_dir_B) state_dict_A = load_plain_tensors(ckpt_dir_A) @@ -114,13 +160,16 @@ def common_test_state_dict_comparison(initialize_model_fn, tmp_path_dist_ckpt): # Test that A *keys* match B *keys*, but the tensors content is different only_left, only_right, mismatch = diff(state_dict_A, state_dict_B) - assert (not only_left and not only_right), (only_left, only_right) + assert not only_left and not only_right, (only_left, only_right) assert len(mismatch) == len(state_dict_A), (len(mismatch), (len(state_dict_A))) Utils.destroy_model_parallel() -def common_test_vocab_size_padding_change(initialize_model_fn, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp): - """ Test model loading with different vocab size (caused by TP padding). 
""" +def common_test_vocab_size_padding_change( + initialize_model_fn, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp +): + """Test model loading with different vocab size (caused by TP padding).""" + def get_test_vocab_size(make_divisible_by=128): divisor = make_divisible_by * parallel_state.get_tensor_model_parallel_world_size() return int(math.ceil(vocab_size_base / divisor)) * divisor @@ -131,17 +180,30 @@ def get_test_vocab_size(make_divisible_by=128): 'embedding.word_embeddings.weight', } - with TempNamedDir(tmp_path_dist_ckpt / 'test_vocab_size_padding_change_A') as ckpt_dir_A, \ - TempNamedDir(tmp_path_dist_ckpt / 'test_vocab_size_padding_change_B') as ckpt_dir_B: + with TempNamedDir( + tmp_path_dist_ckpt / 'test_vocab_size_padding_change_A' + ) as ckpt_dir_A, TempNamedDir( + tmp_path_dist_ckpt / 'test_vocab_size_padding_change_B' + ) as ckpt_dir_B: # Save checkpoint A Utils.initialize_model_parallel(*src_tp_pp) - gpt_model_A = initialize_model_fn(1, tensor_model_parallel_size=src_tp_pp[0], pipeline_model_parallel_size=src_tp_pp[1], vocab_size=get_test_vocab_size()) + gpt_model_A = initialize_model_fn( + 1, + tensor_model_parallel_size=src_tp_pp[0], + pipeline_model_parallel_size=src_tp_pp[1], + vocab_size=get_test_vocab_size(), + ) save(gpt_model_A.sharded_state_dict(), ckpt_dir_A) Utils.destroy_model_parallel() # Load checkpoint A with different TP/PP and save as checkpoint B Utils.initialize_model_parallel(*dest_tp_pp) - gpt_model_B = initialize_model_fn(2, tensor_model_parallel_size=dest_tp_pp[0], pipeline_model_parallel_size=dest_tp_pp[1], vocab_size=get_test_vocab_size()) + gpt_model_B = initialize_model_fn( + 2, + tensor_model_parallel_size=dest_tp_pp[0], + pipeline_model_parallel_size=dest_tp_pp[1], + vocab_size=get_test_vocab_size(), + ) state_dict = load(gpt_model_B.sharded_state_dict(), ckpt_dir_A) gpt_model_B.load_state_dict(state_dict) save(gpt_model_B.sharded_state_dict(), ckpt_dir_B) @@ -156,7 +218,9 @@ def get_test_vocab_size(make_divisible_by=128): if vocab_layer_key in plain_state_dict_A: ten_A = plain_state_dict_A.pop(vocab_layer_key) ten_B = plain_state_dict_B.pop(vocab_layer_key) - assert torch.all(ten_A[:vocab_size_base] == ten_B[:vocab_size_base]), vocab_layer_key + assert torch.all( + ten_A[:vocab_size_base] == ten_B[:vocab_size_base] + ), vocab_layer_key # Test other tensors are equal diffs = diff(plain_state_dict_A, plain_state_dict_B) diff --git a/tests/unit_tests/dist_checkpointing/models/test_bert_model.py b/tests/unit_tests/dist_checkpointing/models/test_bert_model.py index 74af0bc674..e4838faa3d 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_bert_model.py +++ b/tests/unit_tests/dist_checkpointing/models/test_bert_model.py @@ -22,20 +22,35 @@ from tests.unit_tests.test_utilities import Utils -def initialize_bert_model(seed, layer_spec_fn=bert_layer_with_transformer_engine_spec, vocab_size=128, **config_kwargs): +def initialize_bert_model( + seed, layer_spec_fn=bert_layer_with_transformer_engine_spec, vocab_size=128, **config_kwargs +): os.environ['NVTE_ALLOW_NONDETERMINISTIC_ALGO'] = '0' torch.manual_seed(seed) model_parallel_cuda_manual_seed(seed) layer_spec = layer_spec_fn() if callable(layer_spec_fn) else layer_spec_fn - default_config_kwargs=dict(num_layers=8, hidden_size=16, num_attention_heads=8, use_cpu_initialization=True, pipeline_dtype=torch.bfloat16) + default_config_kwargs = dict( + num_layers=8, + hidden_size=16, + num_attention_heads=8, + use_cpu_initialization=True, + pipeline_dtype=torch.bfloat16, + ) 
default_config_kwargs.update(**config_kwargs) transformer_config = TransformerConfig(**default_config_kwargs) pre_process = ps.is_pipeline_first_stage() post_process = ps.is_pipeline_last_stage() - model = BertModel(config=transformer_config, transformer_layer_spec=layer_spec, vocab_size=vocab_size, max_sequence_length=4, - pre_process=pre_process, post_process=post_process, num_tokentypes=0) + model = BertModel( + config=transformer_config, + transformer_layer_spec=layer_spec, + vocab_size=vocab_size, + max_sequence_length=4, + pre_process=pre_process, + post_process=post_process, + num_tokentypes=0, + ) with torch.no_grad(): for p in model.parameters(): @@ -44,53 +59,95 @@ def initialize_bert_model(seed, layer_spec_fn=bert_layer_with_transformer_engine class TestBertModel: - @pytest.mark.parametrize('src_layer_spec', [bert_layer_with_transformer_engine_spec, bert_layer_local_spec]) - @pytest.mark.parametrize('dst_layer_spec', [bert_layer_with_transformer_engine_spec, bert_layer_local_spec]) - def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, - src_layer_spec, dst_layer_spec): - common_test_simple_sharded_state_dict_save_load(initialize_bert_model, tmp_path_dist_ckpt, - src_layer_spec, dst_layer_spec) + @pytest.mark.parametrize( + 'src_layer_spec', [bert_layer_with_transformer_engine_spec, bert_layer_local_spec] + ) + @pytest.mark.parametrize( + 'dst_layer_spec', [bert_layer_with_transformer_engine_spec, bert_layer_local_spec] + ) + def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, src_layer_spec, dst_layer_spec): + common_test_simple_sharded_state_dict_save_load( + initialize_bert_model, tmp_path_dist_ckpt, src_layer_spec, dst_layer_spec + ) class TestBERTModelReconfiguration: def setup_method(self, method): pass - + def teardown_method(self, method): Utils.destroy_model_parallel() - + @pytest.mark.parametrize( ('use_fpsl', 'src_tp_pp', 'dest_tp_pp', 'src_layer_spec', 'dst_layer_spec'), [ - (False, (2, 4), (4, 2), bert_layer_with_transformer_engine_spec, bert_layer_with_transformer_engine_spec), - (False, (1, 8), (8, 1), bert_layer_with_transformer_engine_spec, bert_layer_with_transformer_engine_spec), - (True, (2, 1), (1, 8), bert_layer_with_transformer_engine_spec, bert_layer_with_transformer_engine_spec), - (False, (1, 1), (2, 2), bert_layer_with_transformer_engine_spec, bert_layer_with_transformer_engine_spec), + ( + False, + (2, 4), + (4, 2), + bert_layer_with_transformer_engine_spec, + bert_layer_with_transformer_engine_spec, + ), + ( + False, + (1, 8), + (8, 1), + bert_layer_with_transformer_engine_spec, + bert_layer_with_transformer_engine_spec, + ), + ( + True, + (2, 1), + (1, 8), + bert_layer_with_transformer_engine_spec, + bert_layer_with_transformer_engine_spec, + ), + ( + False, + (1, 1), + (2, 2), + bert_layer_with_transformer_engine_spec, + bert_layer_with_transformer_engine_spec, + ), (True, (2, 1), (1, 8), bert_layer_local_spec, bert_layer_local_spec), (True, (1, 1), (2, 4), bert_layer_with_transformer_engine_spec, bert_layer_local_spec), (False, (1, 8), (2, 1), bert_layer_local_spec, bert_layer_with_transformer_engine_spec), - ] + ], ) - def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, - src_layer_spec, dst_layer_spec, use_fpsl): - """ Test model saving and loading with different TP/PP """ + def test_parallel_reconfiguration_e2e( + self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, src_layer_spec, dst_layer_spec, use_fpsl + ): + """Test model saving and loading with different TP/PP""" 
Utils.initialize_model_parallel(src_tp_pp[0], src_tp_pp[1]) - - common_test_parallel_reconfiguration_e2e(initialize_bert_model, tmp_path_dist_ckpt, src_tp_pp, - dest_tp_pp, src_layer_spec, dst_layer_spec, use_fpsl) + + common_test_parallel_reconfiguration_e2e( + initialize_bert_model, + tmp_path_dist_ckpt, + src_tp_pp, + dest_tp_pp, + src_layer_spec, + dst_layer_spec, + use_fpsl, + ) def test_state_dict_comparison(self, tmp_path_dist_ckpt): common_test_state_dict_comparison(initialize_bert_model, tmp_path_dist_ckpt) - @pytest.mark.parametrize("vocab_size_base,src_tp_pp,dest_tp_pp", [ - (128, (2, 4), (4, 2)), - (17, (1, 8), (8, 1)), - (127, (1, 8), (8, 1)), - (31123, (1, 1), (1, 8)), - (17, (1, 1), (1, 8)), - ]) - def test_vocab_size_padding_change(self, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp): - """ Test model loading with different vocab size (caused by TP padding). """ + @pytest.mark.parametrize( + "vocab_size_base,src_tp_pp,dest_tp_pp", + [ + (128, (2, 4), (4, 2)), + (17, (1, 8), (8, 1)), + (127, (1, 8), (8, 1)), + (31123, (1, 1), (1, 8)), + (17, (1, 1), (1, 8)), + ], + ) + def test_vocab_size_padding_change( + self, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp + ): + """Test model loading with different vocab size (caused by TP padding).""" Utils.initialize_model_parallel(src_tp_pp[0], src_tp_pp[1]) - common_test_vocab_size_padding_change(initialize_bert_model, tmp_path_dist_ckpt, vocab_size_base, - src_tp_pp, dest_tp_pp) + common_test_vocab_size_padding_change( + initialize_bert_model, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp + ) diff --git a/tests/unit_tests/dist_checkpointing/models/test_gpt_model.py b/tests/unit_tests/dist_checkpointing/models/test_gpt_model.py index b044ff15c7..20699d4500 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_gpt_model.py +++ b/tests/unit_tests/dist_checkpointing/models/test_gpt_model.py @@ -23,13 +23,25 @@ def initialize_gpt_model(seed, layer_spec_fn=gpt_te_spec, vocab_size=128, **conf torch.manual_seed(seed) model_parallel_cuda_manual_seed(seed) - default_config_kwargs=dict(num_layers=8, hidden_size=16, num_attention_heads=8, use_cpu_initialization=True, pipeline_dtype=torch.bfloat16) + default_config_kwargs = dict( + num_layers=8, + hidden_size=16, + num_attention_heads=8, + use_cpu_initialization=True, + pipeline_dtype=torch.bfloat16, + ) default_config_kwargs.update(**config_kwargs) transformer_config = TransformerConfig(**default_config_kwargs) pre_process = ps.is_pipeline_first_stage() post_process = ps.is_pipeline_last_stage() - model = GPTModel(config=transformer_config, transformer_layer_spec=layer_spec_fn(), vocab_size=vocab_size, max_sequence_length=4, - pre_process=pre_process, post_process=post_process) + model = GPTModel( + config=transformer_config, + transformer_layer_spec=layer_spec_fn(), + vocab_size=vocab_size, + max_sequence_length=4, + pre_process=pre_process, + post_process=post_process, + ) with torch.no_grad(): for p in model.parameters(): @@ -40,53 +52,86 @@ def initialize_gpt_model(seed, layer_spec_fn=gpt_te_spec, vocab_size=128, **conf class TestGPTModel: @pytest.mark.parametrize('src_layer_spec_fn', [gpt_te_spec, gpt_local_spec]) @pytest.mark.parametrize('dst_layer_spec_fn', [gpt_te_spec, gpt_local_spec]) - def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, - src_layer_spec_fn, dst_layer_spec_fn): - common_test_simple_sharded_state_dict_save_load(initialize_gpt_model, tmp_path_dist_ckpt, - src_layer_spec_fn, dst_layer_spec_fn) + def 
test_sharded_state_dict_save_load( + self, tmp_path_dist_ckpt, src_layer_spec_fn, dst_layer_spec_fn + ): + common_test_simple_sharded_state_dict_save_load( + initialize_gpt_model, tmp_path_dist_ckpt, src_layer_spec_fn, dst_layer_spec_fn + ) class TestGPTModelReconfiguration: def setup_method(self, method): pass - + def teardown_method(self, method): Utils.destroy_model_parallel() @pytest.mark.parametrize( - ('use_fpsl', 'load_order', 'store_order', 'src_tp_pp', 'dest_tp_pp', 'src_layer_spec_fn', 'dst_layer_spec_fn'), + ( + 'use_fpsl', + 'load_order', + 'store_order', + 'src_tp_pp', + 'dest_tp_pp', + 'src_layer_spec_fn', + 'dst_layer_spec_fn', + ), [ (False, 'tp-dp-pp', 'tp-dp-pp', (2, 4), (4, 2), gpt_te_spec, gpt_te_spec), (False, 'tp-pp-dp', 'tp-pp-dp', (1, 8), (8, 1), gpt_te_spec, gpt_te_spec), - (True, 'tp-dp-pp', 'tp-pp-dp', (2, 1), (1, 8), gpt_te_spec, gpt_te_spec), + (True, 'tp-dp-pp', 'tp-pp-dp', (2, 1), (1, 8), gpt_te_spec, gpt_te_spec), (False, 'tp-dp-pp', 'tp-dp-pp', (1, 1), (2, 2), gpt_te_spec, gpt_te_spec), - (True, 'tp-pp-dp', 'tp-pp-dp', (2, 1), (1, 8), gpt_local_spec, gpt_local_spec), + (True, 'tp-pp-dp', 'tp-pp-dp', (2, 1), (1, 8), gpt_local_spec, gpt_local_spec), (False, 'tp-dp-pp', 'tp-pp-dp', (1, 1), (2, 4), gpt_te_spec, gpt_local_spec), - (True, 'tp-dp-pp', 'tp-dp-pp', (2, 4), (4, 2), gpt_local_spec, gpt_te_spec), + (True, 'tp-dp-pp', 'tp-dp-pp', (2, 4), (4, 2), gpt_local_spec, gpt_te_spec), (False, 'tp-pp-dp', 'tp-pp-dp', (2, 1), (1, 8), gpt_te_spec, gpt_local_spec), (False, 'tp-dp-pp', 'tp-pp-dp', (2, 4), (2, 4), gpt_local_spec, gpt_local_spec), - ] + ], ) - def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, - src_layer_spec_fn, dst_layer_spec_fn, use_fpsl, load_order, store_order): - """ Test model saving and loading with different TP/PP """ + def test_parallel_reconfiguration_e2e( + self, + tmp_path_dist_ckpt, + src_tp_pp, + dest_tp_pp, + src_layer_spec_fn, + dst_layer_spec_fn, + use_fpsl, + load_order, + store_order, + ): + """Test model saving and loading with different TP/PP""" Utils.initialize_model_parallel(src_tp_pp[0], src_tp_pp[1]) - common_test_parallel_reconfiguration_e2e(initialize_gpt_model, tmp_path_dist_ckpt, src_tp_pp, - dest_tp_pp, src_layer_spec_fn, dst_layer_spec_fn, use_fpsl, load_order, store_order) - + common_test_parallel_reconfiguration_e2e( + initialize_gpt_model, + tmp_path_dist_ckpt, + src_tp_pp, + dest_tp_pp, + src_layer_spec_fn, + dst_layer_spec_fn, + use_fpsl, + load_order, + store_order, + ) def test_state_dict_comparison(self, tmp_path_dist_ckpt): common_test_state_dict_comparison(initialize_gpt_model, tmp_path_dist_ckpt) - @pytest.mark.parametrize("vocab_size_base,src_tp_pp,dest_tp_pp", [ - (128, (2, 4), (4, 2)), - (17, (1, 8), (8, 1)), - (127, (1, 8), (8, 1)), - (31123, (1, 1), (1, 8)), - (17, (1, 1), (1, 8)), - ]) - def test_vocab_size_padding_change(self, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp): - """ Test model loading with different vocab size (caused by TP padding). 
""" + @pytest.mark.parametrize( + "vocab_size_base,src_tp_pp,dest_tp_pp", + [ + (128, (2, 4), (4, 2)), + (17, (1, 8), (8, 1)), + (127, (1, 8), (8, 1)), + (31123, (1, 1), (1, 8)), + (17, (1, 1), (1, 8)), + ], + ) + def test_vocab_size_padding_change( + self, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp + ): + """Test model loading with different vocab size (caused by TP padding).""" Utils.initialize_model_parallel(src_tp_pp[0], src_tp_pp[1]) - common_test_vocab_size_padding_change(initialize_gpt_model, tmp_path_dist_ckpt, vocab_size_base, - src_tp_pp, dest_tp_pp) + common_test_vocab_size_padding_change( + initialize_gpt_model, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp + ) diff --git a/tests/unit_tests/dist_checkpointing/models/test_grouped_mlp.py b/tests/unit_tests/dist_checkpointing/models/test_grouped_mlp.py index df0005e1a3..1bab7ce54b 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_grouped_mlp.py +++ b/tests/unit_tests/dist_checkpointing/models/test_grouped_mlp.py @@ -30,8 +30,15 @@ def initialize_grouped_mlp(seed, glu=True, **config_kwargs): pp_size = parallel_state.get_pipeline_model_parallel_world_size() num_moe_experts = 8 num_local_experts = num_moe_experts // parallel_state.get_expert_model_parallel_world_size() - default_config_kwargs = dict(num_layers=pp_size, hidden_size=12, num_attention_heads=4, num_moe_experts=num_moe_experts, use_cpu_initialization=True, - gated_linear_unit=glu, add_bias_linear=False) + default_config_kwargs = dict( + num_layers=pp_size, + hidden_size=12, + num_attention_heads=4, + num_moe_experts=num_moe_experts, + use_cpu_initialization=True, + gated_linear_unit=glu, + add_bias_linear=False, + ) default_config_kwargs.update(**config_kwargs) transformer_config = TransformerConfig(**default_config_kwargs) model = GroupedMLP(num_local_experts, transformer_config) @@ -47,36 +54,44 @@ def get_pp_offsets(): class TestGroupedMLPReconfiguration: def setup_method(self, method): pass - + def teardown_method(self, method): Utils.destroy_model_parallel() - @pytest.mark.parametrize("use_fpsl,src_tp_pp_exp,dest_tp_pp_exp,use_glu", [ - # changing PP is impossible because the number of layers must be the same - (False, (2, 4, 1), (2, 4, 1), False), - (True, (2, 4, 1), (2, 4, 1), False), - (False, (1, 1, 1), (1, 1, 1), False), - (True, (1, 1, 1), (1, 1, 4), False), - (False, (1, 1, 8), (1, 1, 2), False), - (False, (2, 2, 2), (4, 2, 1), False), - (True, (1, 1, 4), (8, 1, 1), False), - (False, (1, 8, 1), (1, 8, 1), False), - (False, (1, 1, 4), (2, 1, 1), False), - (False, (1, 1, 1), (1, 1, 1), True), - (False, (1, 1, 1), (1, 1, 4), True), - (True, (1, 1, 1), (2, 1, 1), True), - (False, (1, 1, 4), (8, 1, 1), True), - (True, (2, 1, 4), (1, 1, 8), True), - (False, (2, 1, 4), (1, 1, 8), True), - ]) - def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, use_fpsl): - """ Test model saving and loading with different TP/PP/expert parallelism """ + @pytest.mark.parametrize( + "use_fpsl,src_tp_pp_exp,dest_tp_pp_exp,use_glu", + [ + # changing PP is impossible because the number of layers must be the same + (False, (2, 4, 1), (2, 4, 1), False), + (True, (2, 4, 1), (2, 4, 1), False), + (False, (1, 1, 1), (1, 1, 1), False), + (True, (1, 1, 1), (1, 1, 4), False), + (False, (1, 1, 8), (1, 1, 2), False), + (False, (2, 2, 2), (4, 2, 1), False), + (True, (1, 1, 4), (8, 1, 1), False), + (False, (1, 8, 1), (1, 8, 1), False), + (False, (1, 1, 4), (2, 1, 1), False), + (False, (1, 1, 1), (1, 1, 1), 
True), + (False, (1, 1, 1), (1, 1, 4), True), + (True, (1, 1, 1), (2, 1, 1), True), + (False, (1, 1, 4), (8, 1, 1), True), + (True, (2, 1, 4), (1, 1, 8), True), + (False, (2, 1, 4), (1, 1, 8), True), + ], + ) + def test_parallel_reconfiguration_e2e( + self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, use_fpsl + ): + """Test model saving and loading with different TP/PP/expert parallelism""" src_tp, src_pp, src_exp = src_tp_pp_exp dest_tp, dest_pp, dest_exp = dest_tp_pp_exp Utils.initialize_model_parallel(src_tp, src_pp, expert_model_parallel_size=src_exp) - - with TempNamedDir(tmp_path_dist_ckpt / 'test_grouped_mlp_reconfiguration_model_A') as ckpt_dir_A, \ - TempNamedDir(tmp_path_dist_ckpt / 'test_grouped_mlp_reconfiguration_model_B') as ckpt_dir_B: + + with TempNamedDir( + tmp_path_dist_ckpt / 'test_grouped_mlp_reconfiguration_model_A' + ) as ckpt_dir_A, TempNamedDir( + tmp_path_dist_ckpt / 'test_grouped_mlp_reconfiguration_model_B' + ) as ckpt_dir_B: # Save checkpoint A model_A = initialize_grouped_mlp(1, use_glu) sharded_state_dict = model_A.sharded_state_dict(sharded_offsets=get_pp_offsets()) @@ -86,7 +101,7 @@ def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp_exp, d save_strategy = FullyParallelSaveStrategyWrapper( save_strategy, parallel_state.get_data_parallel_group(with_context_parallel=True), - True + True, ) save(sharded_state_dict, ckpt_dir_A, save_strategy) Utils.destroy_model_parallel() @@ -97,11 +112,17 @@ def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp_exp, d model_B = initialize_grouped_mlp(2, use_glu) if use_fpsl: load_strategy = get_default_load_sharded_strategy(ckpt_dir_A) - load_strategy = FullyParallelLoadStrategyWrapper(load_strategy, - parallel_state.get_data_parallel_group(with_context_parallel=True)) + load_strategy = FullyParallelLoadStrategyWrapper( + load_strategy, + parallel_state.get_data_parallel_group(with_context_parallel=True), + ) else: load_strategy = None - state_dict = load(model_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_A, load_strategy) + state_dict = load( + model_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), + ckpt_dir_A, + load_strategy, + ) model_B.load_state_dict(state_dict) save(model_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_B) Utils.destroy_model_parallel() @@ -114,41 +135,51 @@ def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp_exp, d assert not any(map(bool, diffs)), diffs Utils.destroy_model_parallel() - @pytest.mark.parametrize("src_module,src_tp_pp_exp,dest_tp_pp_exp,use_glu", [ - # changing PP is impossible because the number of layers must be the same - ('sequential', (2, 4, 1), (2, 4, 1), False), - ('sequential', (1, 1, 1), (1, 1, 4), False), - ('sequential', (2, 2, 2), (4, 2, 1), False), - ('sequential', (1, 1, 4), (8, 1, 1), False), - ('sequential', (2, 1, 4), (1, 1, 8), False), - ('sequential', (2, 4, 1), (2, 4, 1), True), - ('sequential', (1, 1, 1), (1, 1, 4), True), - ('sequential', (2, 2, 2), (4, 2, 1), True), - ('sequential', (1, 1, 4), (8, 1, 1), True), - ('sequential', (2, 1, 4), (1, 1, 8), True), - ('grouped', (2, 4, 1), (2, 4, 1), False), - ('grouped', (1, 1, 1), (1, 1, 4), False), - ('grouped', (2, 2, 2), (4, 2, 1), False), - ('grouped', (1, 1, 4), (8, 1, 1), False), - ('grouped', (2, 1, 4), (1, 1, 8), False), - ('grouped', (2, 4, 1), (2, 4, 1), True), - ('grouped', (1, 1, 1), (1, 1, 4), True), - ('grouped', (2, 2, 2), (4, 2, 1), True), - ('grouped', (1, 1, 4), (8, 1, 1), True), 
- ('grouped', (2, 1, 4), (1, 1, 8), True), - ]) - def test_sequential_grouped_mlp_interchangeable(self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, src_module): - """ Test model saving and loading with different TP/PP/expert parallelism """ + @pytest.mark.parametrize( + "src_module,src_tp_pp_exp,dest_tp_pp_exp,use_glu", + [ + # changing PP is impossible because the number of layers must be the same + ('sequential', (2, 4, 1), (2, 4, 1), False), + ('sequential', (1, 1, 1), (1, 1, 4), False), + ('sequential', (2, 2, 2), (4, 2, 1), False), + ('sequential', (1, 1, 4), (8, 1, 1), False), + ('sequential', (2, 1, 4), (1, 1, 8), False), + ('sequential', (2, 4, 1), (2, 4, 1), True), + ('sequential', (1, 1, 1), (1, 1, 4), True), + ('sequential', (2, 2, 2), (4, 2, 1), True), + ('sequential', (1, 1, 4), (8, 1, 1), True), + ('sequential', (2, 1, 4), (1, 1, 8), True), + ('grouped', (2, 4, 1), (2, 4, 1), False), + ('grouped', (1, 1, 1), (1, 1, 4), False), + ('grouped', (2, 2, 2), (4, 2, 1), False), + ('grouped', (1, 1, 4), (8, 1, 1), False), + ('grouped', (2, 1, 4), (1, 1, 8), False), + ('grouped', (2, 4, 1), (2, 4, 1), True), + ('grouped', (1, 1, 1), (1, 1, 4), True), + ('grouped', (2, 2, 2), (4, 2, 1), True), + ('grouped', (1, 1, 4), (8, 1, 1), True), + ('grouped', (2, 1, 4), (1, 1, 8), True), + ], + ) + def test_sequential_grouped_mlp_interchangeable( + self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, src_module + ): + """Test model saving and loading with different TP/PP/expert parallelism""" src_tp, src_pp, src_exp = src_tp_pp_exp dest_tp, dest_pp, dest_exp = dest_tp_pp_exp Utils.initialize_model_parallel(src_tp, src_pp, expert_model_parallel_size=src_exp) - with TempNamedDir(tmp_path_dist_ckpt / 'test_sequential_grouped_mlp_interchangeable_model_A') as ckpt_dir_A, \ - TempNamedDir(tmp_path_dist_ckpt / 'test_sequential_grouped_mlp_interchangeable_model_B') as ckpt_dir_B: + with TempNamedDir( + tmp_path_dist_ckpt / 'test_sequential_grouped_mlp_interchangeable_model_A' + ) as ckpt_dir_A, TempNamedDir( + tmp_path_dist_ckpt / 'test_sequential_grouped_mlp_interchangeable_model_B' + ) as ckpt_dir_B: # Save checkpoint A - + if src_module == 'sequential': - model_A = initialize_expert_layer(1, use_glu, add_bias_linear=False, moe_grouped_gemm=False) + model_A = initialize_expert_layer( + 1, use_glu, add_bias_linear=False, moe_grouped_gemm=False + ) else: model_A = initialize_grouped_mlp(1, use_glu) sharded_state_dict = model_A.sharded_state_dict(sharded_offsets=get_pp_offsets()) @@ -161,9 +192,15 @@ def test_sequential_grouped_mlp_interchangeable(self, tmp_path_dist_ckpt, src_tp if src_module == 'sequential': model_B = initialize_grouped_mlp(1, use_glu) else: - model_B = initialize_expert_layer(1, use_glu, add_bias_linear=False, moe_grouped_gemm=False) + model_B = initialize_expert_layer( + 1, use_glu, add_bias_linear=False, moe_grouped_gemm=False + ) load_strategy = None - state_dict = load(model_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_A, load_strategy) + state_dict = load( + model_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), + ckpt_dir_A, + load_strategy, + ) model_B.load_state_dict(state_dict) save(model_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_B) Utils.destroy_model_parallel() @@ -174,4 +211,4 @@ def test_sequential_grouped_mlp_interchangeable(self, tmp_path_dist_ckpt, src_tp state_dict_B = load_plain_tensors(ckpt_dir_B) diffs = diff(state_dict_A, state_dict_B) assert not any(map(bool, diffs)), diffs - 
Utils.destroy_model_parallel() \ No newline at end of file + Utils.destroy_model_parallel() diff --git a/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py b/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py index 04148a44d4..1a0851039a 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py +++ b/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py @@ -22,9 +22,16 @@ def initialize_mlp(glu=True): model_parallel_cuda_manual_seed(123) pp_size = parallel_state.get_pipeline_model_parallel_world_size() - transformer_config = TransformerConfig(num_layers=pp_size, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True, - gated_linear_unit=glu) - return MLP(transformer_config, get_gpt_layer_with_transformer_engine_spec().submodules.mlp.submodules) + transformer_config = TransformerConfig( + num_layers=pp_size, + hidden_size=12, + num_attention_heads=4, + use_cpu_initialization=True, + gated_linear_unit=glu, + ) + return MLP( + transformer_config, get_gpt_layer_with_transformer_engine_spec().submodules.mlp.submodules + ) def get_pp_offsets(): @@ -36,23 +43,29 @@ def get_pp_offsets(): class TestParallelMLPWithGLU: def setup_method(self, method): pass - + def teardown_method(self, method): Utils.destroy_model_parallel() - - @pytest.mark.parametrize("src_tp_pp,dest_tp_pp", [ - # changing PP is impossible because the number of layers must be the same - ((2, 2), (4, 2)), - ((1, 1), (8, 1)), - ((1, 8), (1, 8)), - ((1, 1), (2, 1)), - ]) + + @pytest.mark.parametrize( + "src_tp_pp,dest_tp_pp", + [ + # changing PP is impossible because the number of layers must be the same + ((2, 2), (4, 2)), + ((1, 1), (8, 1)), + ((1, 8), (1, 8)), + ((1, 1), (2, 1)), + ], + ) def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp): - """ Test module saving and loading with different TP/PP """ + """Test module saving and loading with different TP/PP""" Utils.initialize_model_parallel(*src_tp_pp) - - with TempNamedDir(tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_A') as ckpt_dir_A, \ - TempNamedDir(tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_B') as ckpt_dir_B: + + with TempNamedDir( + tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_A' + ) as ckpt_dir_A, TempNamedDir( + tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_B' + ) as ckpt_dir_B: # Save checkpoint A mlp_A = initialize_mlp() save(mlp_A.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_A) @@ -61,7 +74,9 @@ def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp, dest_ # Load checkpoint A with different TP/PP and save as checkpoint B Utils.initialize_model_parallel(*dest_tp_pp) mlp_B = initialize_mlp() - state_dict = load(mlp_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_A) + state_dict = load( + mlp_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_A + ) mlp_B.load_state_dict(state_dict) save(mlp_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_B) Utils.destroy_model_parallel() diff --git a/tests/unit_tests/dist_checkpointing/models/test_retro_model.py b/tests/unit_tests/dist_checkpointing/models/test_retro_model.py index 013543def2..cf972f0c53 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_retro_model.py +++ b/tests/unit_tests/dist_checkpointing/models/test_retro_model.py @@ -18,7 +18,7 @@ def initialize_retro_model(seed, decoder_spec_fn, spec_type, num_layers=9, **con torch.manual_seed(seed) model_parallel_cuda_manual_seed(seed) - 
default_config_kwargs=dict( + default_config_kwargs = dict( num_layers=num_layers, hidden_size=16, num_attention_heads=12, @@ -35,11 +35,17 @@ def initialize_retro_model(seed, decoder_spec_fn, spec_type, num_layers=9, **con pre_process = ps.is_pipeline_first_stage() post_process = ps.is_pipeline_last_stage() - - de_block_spec = decoder_spec_fn(retro_config, use_transformer_engine=True if spec_type=="te" else False) - model = RetroModel(config=retro_config, transformer_layer_spec=de_block_spec, - pre_process=pre_process, post_process=post_process, - vocab_size=29184, max_sequence_length=4) + de_block_spec = decoder_spec_fn( + retro_config, use_transformer_engine=True if spec_type == "te" else False + ) + model = RetroModel( + config=retro_config, + transformer_layer_spec=de_block_spec, + pre_process=pre_process, + post_process=post_process, + vocab_size=29184, + max_sequence_length=4, + ) with torch.no_grad(): for p in model.parameters(): @@ -50,14 +56,16 @@ def initialize_retro_model(seed, decoder_spec_fn, spec_type, num_layers=9, **con class TestRetroModel: def setup_method(self, method): pass - + def teardown_method(self, method): Utils.destroy_model_parallel() - + @pytest.mark.parametrize('src_spec_type', ['te', 'local']) @pytest.mark.parametrize('dst_spec_type', ['te', 'local']) @pytest.mark.parametrize('model_type', ['retro']) - def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, src_spec_type, dst_spec_type, model_type): + def test_sharded_state_dict_save_load( + self, tmp_path_dist_ckpt, src_spec_type, dst_spec_type, model_type + ): decoder_spec_fn = get_retro_decoder_block_spec Utils.initialize_model_parallel(1, 1) @@ -71,7 +79,9 @@ def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, src_spec_type, d gpt_model = initialize_retro_model(2, decoder_spec_fn, dst_spec_type) sharded_state_dict = gpt_model.sharded_state_dict() - state_dict, missing_keys, unexpected_keys = load(sharded_state_dict, ckpt_dir, strict=StrictHandling.RETURN_ALL) + state_dict, missing_keys, unexpected_keys = load( + sharded_state_dict, ckpt_dir, strict=StrictHandling.RETURN_ALL + ) # Potential mismatch is because of extra states which is ok assert all('_extra_state' in k for k in missing_keys) assert all('_extra_state' in k for k in unexpected_keys) diff --git a/tests/unit_tests/dist_checkpointing/models/test_sequential_mlp.py b/tests/unit_tests/dist_checkpointing/models/test_sequential_mlp.py index 0bc07298a4..111e982a35 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_sequential_mlp.py +++ b/tests/unit_tests/dist_checkpointing/models/test_sequential_mlp.py @@ -26,6 +26,7 @@ _te_version = packaging.version.Version(version("transformer-engine")) + def initialize_expert_layer(seed, glu=True, moe_grouped_gemm=False, **config_kwargs): torch.manual_seed(seed) model_parallel_cuda_manual_seed(seed) @@ -62,17 +63,19 @@ def get_pp_offsets(): pp_size = parallel_state.get_pipeline_model_parallel_world_size() return ((0, pp_rank, pp_size),) + moe_grouped_gemm_options = [False] if _te_version >= packaging.version.Version("1.9.0.dev0"): moe_grouped_gemm_options.append(True) + class TestExpertLayerReconfiguration: def setup_method(self, method): pass - + def teardown_method(self, method): Utils.destroy_model_parallel() - + @pytest.mark.parametrize( "use_fpsl,src_tp_pp_exp,dest_tp_pp_exp,use_glu", [ @@ -96,7 +99,7 @@ def teardown_method(self, method): def test_parallel_reconfiguration_e2e( self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, use_fpsl, moe_grouped_gemm ): - """ 
Test model saving and loading with different TP/PP/expert parallelism """ + """Test model saving and loading with different TP/PP/expert parallelism""" src_tp, src_pp, src_exp = src_tp_pp_exp dest_tp, dest_pp, dest_exp = dest_tp_pp_exp # Save checkpoint A @@ -180,7 +183,7 @@ def test_parallel_reconfiguration_e2e( def test_sequential_grouped_mlp_interchangeable( self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, src_module ): - """ Test model saving and loading with different TP/PP/expert parallelism """ + """Test model saving and loading with different TP/PP/expert parallelism""" src_tp, src_pp, src_exp = src_tp_pp_exp dest_tp, dest_pp, dest_exp = dest_tp_pp_exp # Save checkpoint A @@ -190,7 +193,7 @@ def test_sequential_grouped_mlp_interchangeable( ) as ckpt_dir_A, TempNamedDir( tmp_path_dist_ckpt / 'test_sequential_grouped_mlp_interchangeable_model_B' ) as ckpt_dir_B: - + model_A = initialize_expert_layer( 1, use_glu, moe_grouped_gemm=src_module != 'sequential' ) diff --git a/tests/unit_tests/dist_checkpointing/models/test_t5_model.py b/tests/unit_tests/dist_checkpointing/models/test_t5_model.py index da1ae4b093..07c9f8676a 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_t5_model.py +++ b/tests/unit_tests/dist_checkpointing/models/test_t5_model.py @@ -34,9 +34,14 @@ def initialize_t5_model(seed, encoder_spec_fn, decoder_spec_fn, num_layers=2, ** torch.manual_seed(seed) model_parallel_cuda_manual_seed(seed) - default_config_kwargs=dict( - num_layers=num_layers, hidden_size=16, num_attention_heads=12, kv_channels=64, ffn_hidden_size=64, - use_cpu_initialization=True, pipeline_dtype=torch.bfloat16 + default_config_kwargs = dict( + num_layers=num_layers, + hidden_size=16, + num_attention_heads=12, + kv_channels=64, + ffn_hidden_size=64, + use_cpu_initialization=True, + pipeline_dtype=torch.bfloat16, ) default_config_kwargs.update(**config_kwargs) transformer_config = TransformerConfig(**default_config_kwargs) @@ -45,10 +50,16 @@ def initialize_t5_model(seed, encoder_spec_fn, decoder_spec_fn, num_layers=2, ** en_block_spec = TransformerBlockSubmodules([encoder_spec_fn()] * num_layers) de_block_spec = TransformerBlockSubmodules([decoder_spec_fn()] * num_layers) - model = T5Model(encoder_config=transformer_config, config=transformer_config, - transformer_encoder_layer_spec=en_block_spec, transformer_decoder_layer_spec=de_block_spec, - pre_process=False, post_process=False, - vocab_size=29184, max_sequence_length=4) + model = T5Model( + encoder_config=transformer_config, + config=transformer_config, + transformer_encoder_layer_spec=en_block_spec, + transformer_decoder_layer_spec=de_block_spec, + pre_process=False, + post_process=False, + vocab_size=29184, + max_sequence_length=4, + ) with torch.no_grad(): for p in model.parameters(): @@ -59,14 +70,16 @@ def initialize_t5_model(seed, encoder_spec_fn, decoder_spec_fn, num_layers=2, ** class TestT5Model: def setup_method(self, method): pass - + def teardown_method(self, method): Utils.destroy_model_parallel() - + @pytest.mark.parametrize('src_spec_type', ['te', 'local']) @pytest.mark.parametrize('dst_spec_type', ['te', 'local']) @pytest.mark.parametrize('model_type', ['t5']) - def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, src_spec_type, dst_spec_type, model_type): + def test_sharded_state_dict_save_load( + self, tmp_path_dist_ckpt, src_spec_type, dst_spec_type, model_type + ): enc_dec_spec_fn = { 'te': { 't5': (t5_encoder_te_spec, t5_decoder_te_spec), @@ -75,7 +88,7 @@ def 
test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, src_spec_type, d 'local': { 't5': (t5_encoder_local_spec, t5_decoder_local_spec), 'retro': (get_retro_encoder_layer_local_spec, get_retro_decoder_layer_local_spec), - } + }, } src_encoder_spec_fn, src_decoder_spec_fn = enc_dec_spec_fn[src_spec_type][model_type] dst_encoder_spec_fn, dst_decoder_spec_fn = enc_dec_spec_fn[dst_spec_type][model_type] @@ -91,7 +104,9 @@ def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, src_spec_type, d gpt_model = initialize_t5_model(2, dst_encoder_spec_fn, dst_decoder_spec_fn) sharded_state_dict = gpt_model.sharded_state_dict() - state_dict, missing_keys, unexpected_keys = load(sharded_state_dict, ckpt_dir, strict=StrictHandling.RETURN_ALL) + state_dict, missing_keys, unexpected_keys = load( + sharded_state_dict, ckpt_dir, strict=StrictHandling.RETURN_ALL + ) # Potential mismatch is because of extra states which is ok assert all('_extra_state' in k for k in missing_keys) assert all('_extra_state' in k for k in unexpected_keys) diff --git a/tests/unit_tests/dist_checkpointing/test_async_save.py b/tests/unit_tests/dist_checkpointing/test_async_save.py index 9b8fe0044c..d6aa879982 100644 --- a/tests/unit_tests/dist_checkpointing/test_async_save.py +++ b/tests/unit_tests/dist_checkpointing/test_async_save.py @@ -13,7 +13,6 @@ from tests.unit_tests.test_utilities import Utils - def write_data_os_err_mock_fn(local_proc_idx, write_bucket, results_queue, count_queue, use_fsync): """Raises an error on worker #2 during storage save""" try: @@ -32,8 +31,8 @@ def setup_method(self, method): pass def teardown_method(self, method): - Utils.destroy_model_parallel() - + Utils.destroy_model_parallel() + def test_async_is_equivalent_to_sync(self, tmp_path_dist_ckpt): Utils.initialize_model_parallel(2, 4) diff --git a/tests/unit_tests/dist_checkpointing/test_cached_metadata.py b/tests/unit_tests/dist_checkpointing/test_cached_metadata.py index b1286f01f1..2733ea7a1b 100644 --- a/tests/unit_tests/dist_checkpointing/test_cached_metadata.py +++ b/tests/unit_tests/dist_checkpointing/test_cached_metadata.py @@ -2,7 +2,6 @@ import pickle from copy import deepcopy - from dataclasses import fields import torch @@ -20,8 +19,8 @@ def setup_method(self, method): pass def teardown_method(self, method): - Utils.destroy_model_parallel() - + Utils.destroy_model_parallel() + def test_cached_metadata(self, tmp_path_dist_ckpt): Utils.initialize_model_parallel(2, 4) diff --git a/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py b/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py index 0b64f36e64..fa00a20cad 100644 --- a/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py +++ b/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py @@ -27,21 +27,18 @@ def setup_method(self, method): pass def teardown_method(self, method): - Utils.destroy_model_parallel() - + Utils.destroy_model_parallel() + @pytest.mark.parametrize( - ('src_tp_pp', 'dest_tp_pp',), - [ - ((2, 4), (2, 4)), - ((2, 4), (2, 2)), - ((2, 4), (4, 2)), - ((8, 1), (1, 2)), - ] + ('src_tp_pp', 'dest_tp_pp'), + [((2, 4), (2, 4)), ((2, 4), (2, 2)), ((2, 4), (4, 2)), ((8, 1), (1, 2))], ) def test_partition_change_save_load(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp): Utils.initialize_model_parallel(*src_tp_pp) - with TempNamedDir(tmp_path_dist_ckpt / 'test_flattened_partition_change_save_load') as ckpt_dir: - + with TempNamedDir( + tmp_path_dist_ckpt / 'test_flattened_partition_change_save_load' + ) as ckpt_dir: + 
state_dict = self._build_state_dict() save(state_dict, ckpt_dir) @@ -57,30 +54,32 @@ def test_partition_change_save_load(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp Utils.destroy_model_parallel() - @pytest.mark.parametrize( ('src_tp_pp', 'dest_tp_pp', 'expected_ckpt_offsets_by_rank'), [ - ((2, 4), (2, 2), { - 0: [(0, 0, 0), (0, 0, 10)], # TP 0, DP 0, PP 0 - 1: [(4, 0, 0), (4, 0, 10)], # TP 1, DP 0, PP 0 - 2: [(0, 0, 0), (0, 0, 10)], # TP 0, DP 1, PP 0 - 3: [(4, 0, 0), (4, 0, 10)], # TP 1, DP 1, PP 0 - 4: [(0, 0, 20), (0, 0, 30)], # TP 0, DP 0, PP 1 - 5: [(4, 0, 20), (4, 0, 30)], # TP 1, DP 0, PP 1 - 6: [(0, 0, 20), (0, 0, 30)], # TP 0, DP 1, PP 1 - 7: [(4, 0, 20), (4, 0, 30)], # TP 1, DP 1, PP 1 - }), - ((8, 1), (1, 2), { - rank: [(tp, 0, 0) for tp in range(8)] - for rank in range(8) - }) - ] + ( + (2, 4), + (2, 2), + { + 0: [(0, 0, 0), (0, 0, 10)], # TP 0, DP 0, PP 0 + 1: [(4, 0, 0), (4, 0, 10)], # TP 1, DP 0, PP 0 + 2: [(0, 0, 0), (0, 0, 10)], # TP 0, DP 1, PP 0 + 3: [(4, 0, 0), (4, 0, 10)], # TP 1, DP 1, PP 0 + 4: [(0, 0, 20), (0, 0, 30)], # TP 0, DP 0, PP 1 + 5: [(4, 0, 20), (4, 0, 30)], # TP 1, DP 0, PP 1 + 6: [(0, 0, 20), (0, 0, 30)], # TP 0, DP 1, PP 1 + 7: [(4, 0, 20), (4, 0, 30)], # TP 1, DP 1, PP 1 + }, + ), + ((8, 1), (1, 2), {rank: [(tp, 0, 0) for tp in range(8)] for rank in range(8)}), + ], ) - def test_reformulate_nd_flattened_tensors(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, expected_ckpt_offsets_by_rank): + def test_reformulate_nd_flattened_tensors( + self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, expected_ckpt_offsets_by_rank + ): Utils.initialize_model_parallel(*src_tp_pp, order='tp-dp-pp') with TempNamedDir(tmp_path_dist_ckpt / 'test_reformulate_nd_flattened_tensors') as ckpt_dir: - + state_dict = self._build_state_dict() ckpt_local_shape = state_dict['sd_key_flat'].local_shape @@ -93,36 +92,38 @@ def test_reformulate_nd_flattened_tensors(self, tmp_path_dist_ckpt, src_tp_pp, d load_state_dict = self._build_state_dict(random=True) reformulation_metadata = get_reformulation_metadata(load_state_dict, ckpt_dir) - reformulated_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation(load_state_dict, reformulation_metadata) + reformulated_state_dict, formulation_restore_data = ( + apply_nd_flattened_tensors_reformulation(load_state_dict, reformulation_metadata) + ) assert isinstance(reformulated_state_dict['sd_key_unflat'], ShardedTensor) assert isinstance(reformulated_state_dict['sd_key_flat'], dict) - assert reformulated_state_dict['sd_key_flat'].keys() == set((offset, ckpt_local_shape) for offset in expected_ckpt_offsets_by_rank[Utils.rank]), \ - (reformulated_state_dict['sd_key_flat'].keys(), ckpt_local_shape, expected_ckpt_offsets_by_rank[Utils.rank]) + assert reformulated_state_dict['sd_key_flat'].keys() == set( + (offset, ckpt_local_shape) for offset in expected_ckpt_offsets_by_rank[Utils.rank] + ), ( + reformulated_state_dict['sd_key_flat'].keys(), + ckpt_local_shape, + expected_ckpt_offsets_by_rank[Utils.rank], + ) # We can even load the reformulated state dict with a high-level API - loaded_state_dict = load(reformulated_state_dict, ckpt_dir, validate_access_integrity=False) - loaded_state_dict = restore_nd_flattened_tensors_formulation(loaded_state_dict, formulation_restore_data) + loaded_state_dict = load( + reformulated_state_dict, ckpt_dir, validate_access_integrity=False + ) + loaded_state_dict = restore_nd_flattened_tensors_formulation( + loaded_state_dict, formulation_restore_data + ) expected_state_dict = {k: v.data for k, v 
in self._build_state_dict().items()} diffs = diff(expected_state_dict, loaded_state_dict) assert not any(diffs), diffs Utils.destroy_model_parallel() - - @pytest.mark.parametrize( - ('src_tp_pp',), - [ - ((2, 4),), - ((8, 1),), - ((1, 1),), - ((1, 4),), - ] - ) + @pytest.mark.parametrize(('src_tp_pp',), [((2, 4),), ((8, 1),), ((1, 1),), ((1, 4),)]) def test_load_tensor_metadata(self, tmp_path_dist_ckpt, src_tp_pp): Utils.initialize_model_parallel(*src_tp_pp, order='tp-dp-pp') with TempNamedDir(tmp_path_dist_ckpt / 'test_reformulate_nd_flattened_tensors') as ckpt_dir: - + state_dict = self._build_state_dict() save(state_dict, ckpt_dir) @@ -141,7 +142,9 @@ def test_load_tensor_metadata(self, tmp_path_dist_ckpt, src_tp_pp): for sh_ten in sharded_metadata.values(): sh_ten.replica_id = Utils.rank loaded_state_dict = load(sharded_metadata, ckpt_dir) - assert torch.all(loaded_state_dict['unflat'] == torch.arange(8 * 5 * 40).reshape(8, 5, 40)) + assert torch.all( + loaded_state_dict['unflat'] == torch.arange(8 * 5 * 40).reshape(8, 5, 40) + ) assert torch.all(loaded_state_dict['flat'] == torch.arange(8 * 5 * 40)) Utils.destroy_model_parallel() @@ -169,7 +172,7 @@ def _build_state_dict(self, random=False): end_jitter = dp_rank + 1 if dp_rank + 1 < dp_size else 0 local_dp_slice = slice( local_ten_size_by_dp * dp_rank + start_jitter, - local_ten_size_by_dp * (dp_rank + 1) + end_jitter + local_ten_size_by_dp * (dp_rank + 1) + end_jitter, ) local_flat_ten = local_ten.flatten()[local_dp_slice] if dp_rank == dp_size - 1: @@ -191,7 +194,7 @@ def _build_state_dict(self, random=False): local_ten.shape, (0, tp_rank, tp_size), (2, pp_rank, pp_size), - flattened_range=local_dp_slice + flattened_range=local_dp_slice, ), } return state_dict diff --git a/tests/unit_tests/dist_checkpointing/test_fully_parallel.py b/tests/unit_tests/dist_checkpointing/test_fully_parallel.py index f357f1b57d..42eda5d549 100644 --- a/tests/unit_tests/dist_checkpointing/test_fully_parallel.py +++ b/tests/unit_tests/dist_checkpointing/test_fully_parallel.py @@ -34,8 +34,11 @@ def __init__(self): self.save_keys = set() def save(self, sharded_state_dict, ckpt_dir): - self.save_keys = {sh_ten.key for sh_ten in nested_values(sharded_state_dict) - if is_main_replica(sh_ten.replica_id)} + self.save_keys = { + sh_ten.key + for sh_ten in nested_values(sharded_state_dict) + if is_main_replica(sh_ten.replica_id) + } class MockLoadStrategy(LoadShardedStrategy): @@ -45,8 +48,11 @@ def __init__(self, device='cpu'): self.load_keys = set() def load(self, sharded_state_dict, ckpt_dir): - self.load_keys = {sh_ten.key for sh_ten in nested_values(sharded_state_dict) - if is_main_replica(sh_ten.replica_id)} + self.load_keys = { + sh_ten.key + for sh_ten in nested_values(sharded_state_dict) + if is_main_replica(sh_ten.replica_id) + } def load_rand(x): assert isinstance(x, ShardedTensor) @@ -71,21 +77,43 @@ def setup_method(self, method): pass def teardown_method(self, method): - Utils.destroy_model_parallel() - + Utils.destroy_model_parallel() + @staticmethod def get_sharded_state_dict(): return { - 'sd_key_tp_repl1': ShardedTensor.from_rank_offsets('key_TP_repl1', torch.ones(10), - (0, parallel_state.get_tensor_model_parallel_rank(), parallel_state.get_tensor_model_parallel_world_size()), - replica_id=parallel_state.get_data_parallel_rank(with_context_parallel=True)), - 'sd_key_tp_repl2': ShardedTensor.from_rank_offsets('key_TP_repl2', torch.ones(10), - (0, parallel_state.get_tensor_model_parallel_rank(), 
parallel_state.get_tensor_model_parallel_world_size()), - replica_id=parallel_state.get_data_parallel_rank(with_context_parallel=True)), - 'sd_keyB': ShardedTensor.from_rank_offsets('keyB', torch.ones(20), (0, Utils.rank, Utils.world_size)), - 'sd_keyE_no_C': ShardedTensor.from_rank_offsets('keyC', torch.ones(100), replica_id=Utils.rank), - 'sd_keyX_no_D': ShardedTensor.from_rank_offsets('keyD', torch.ones(1000), replica_id=Utils.rank), - 'sd_keyC_no_E': ShardedTensor.from_rank_offsets('keyE', torch.ones(100), replica_id=Utils.rank), + 'sd_key_tp_repl1': ShardedTensor.from_rank_offsets( + 'key_TP_repl1', + torch.ones(10), + ( + 0, + parallel_state.get_tensor_model_parallel_rank(), + parallel_state.get_tensor_model_parallel_world_size(), + ), + replica_id=parallel_state.get_data_parallel_rank(with_context_parallel=True), + ), + 'sd_key_tp_repl2': ShardedTensor.from_rank_offsets( + 'key_TP_repl2', + torch.ones(10), + ( + 0, + parallel_state.get_tensor_model_parallel_rank(), + parallel_state.get_tensor_model_parallel_world_size(), + ), + replica_id=parallel_state.get_data_parallel_rank(with_context_parallel=True), + ), + 'sd_keyB': ShardedTensor.from_rank_offsets( + 'keyB', torch.ones(20), (0, Utils.rank, Utils.world_size) + ), + 'sd_keyE_no_C': ShardedTensor.from_rank_offsets( + 'keyC', torch.ones(100), replica_id=Utils.rank + ), + 'sd_keyX_no_D': ShardedTensor.from_rank_offsets( + 'keyD', torch.ones(1000), replica_id=Utils.rank + ), + 'sd_keyC_no_E': ShardedTensor.from_rank_offsets( + 'keyE', torch.ones(100), replica_id=Utils.rank + ), } @pytest.mark.parametrize("parallelization_along_dp", [False, True]) @@ -99,7 +127,9 @@ def test_save_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt): # 3. Shard id (key) if not parallelization_along_dp: expected_key_to_saving_ranks = { - 'keyB': list(range(Utils.world_size)), # everyone must save (disjoint shards, coverage == 1) + 'keyB': list( + range(Utils.world_size) + ), # everyone must save (disjoint shards, coverage == 1) 'key_TP_repl1': [0, 1], # lowest coverage (4), first TP domain 'key_TP_repl2': [2, 3], # lowest coverage (4), second TP domain 'keyD': [4], # largest tensor @@ -110,7 +140,11 @@ def test_save_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt): if parallel_state.get_tensor_model_parallel_rank() == 0: expected_key_to_saving_ranks = { # everyone must save (disjoint shards, coverage == 1): - 'keyB': list(range(parallel_state.get_data_parallel_world_size(with_context_parallel=True))), + 'keyB': list( + range( + parallel_state.get_data_parallel_world_size(with_context_parallel=True) + ) + ), # this time, TP sharded tensors have the same coverage as fully replicated! 
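# Editor's note: for readers following the replica_id bookkeeping above, this is the
# convention these tests rely on: a tensor sharded across the tensor-parallel group but
# replicated across data-parallel ranks is described as below. Sketch only; model-parallel
# state must already be initialized, and the key name is illustrative.
import torch

from megatron.core import parallel_state
from megatron.core.dist_checkpointing import ShardedTensor


def tp_sharded_dp_replicated(local_weight: torch.Tensor) -> ShardedTensor:
    tp_rank = parallel_state.get_tensor_model_parallel_rank()
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    return ShardedTensor.from_rank_offsets(
        'illustrative.weight',
        local_weight,
        (0, tp_rank, tp_size),  # dim 0 is split across the TP group
        # ranks whose replica_id is 0 hold the "main" copy that actually gets written
        replica_id=parallel_state.get_data_parallel_rank(with_context_parallel=True),
    )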
'keyD': [0], # largest tensor 'keyC': [1], # second largest tensor @@ -121,32 +155,59 @@ def test_save_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt): else: expected_key_to_saving_ranks = { # everyone must save (disjoint shards, coverage == 1): - 'keyB': list(range(parallel_state.get_data_parallel_world_size(with_context_parallel=True))), + 'keyB': list( + range( + parallel_state.get_data_parallel_world_size(with_context_parallel=True) + ) + ), # tensors C, D, E are absent in this DP group 'key_TP_repl1': [0], # smallest tensor 'key_TP_repl2': [1], # smallest tensor, last rank is the least occupied } - parallelization_group = parallel_state.get_data_parallel_group(with_context_parallel=True) if parallelization_along_dp else None + parallelization_group = ( + parallel_state.get_data_parallel_group(with_context_parallel=True) + if parallelization_along_dp + else None + ) dp_rank = torch.distributed.get_rank(parallelization_group) - expected_keys_saved_by_current_rank = {k for k, v in expected_key_to_saving_ranks.items() if dp_rank in v} + expected_keys_saved_by_current_rank = { + k for k, v in expected_key_to_saving_ranks.items() if dp_rank in v + } # Run save and tests mock_strategy = MockSaveStrategy() - save_strategy = FullyParallelSaveStrategyWrapper(mock_strategy, - parallelization_group, - do_cache_distribution=True) + save_strategy = FullyParallelSaveStrategyWrapper( + mock_strategy, parallelization_group, do_cache_distribution=True + ) with TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir_A: save_strategy.save(state_dict, ckpt_dir_A) - key_to_saving_rank = dict(map_reduce(save_strategy.cached_distribution.main_rank_for_shard.items(), lambda shard_rank: shard_rank[0][0], lambda shard_rank: shard_rank[1])) + key_to_saving_rank = dict( + map_reduce( + save_strategy.cached_distribution.main_rank_for_shard.items(), + lambda shard_rank: shard_rank[0][0], + lambda shard_rank: shard_rank[1], + ) + ) assert expected_key_to_saving_ranks == key_to_saving_rank for k, sh_ten in state_dict.items(): - if _sharded_tensor_shard_id(sh_ten) in save_strategy.cached_distribution.shards_in_this_group: - is_expected_to_be_saved_by_this_rank = dp_rank in expected_key_to_saving_ranks.get(sh_ten.key, []) - assert sh_ten.replica_id == int(not is_expected_to_be_saved_by_this_rank), expected_key_to_saving_ranks - - assert mock_strategy.save_keys == expected_keys_saved_by_current_rank, (Utils.rank, mock_strategy.save_keys, expected_keys_saved_by_current_rank) + if ( + _sharded_tensor_shard_id(sh_ten) + in save_strategy.cached_distribution.shards_in_this_group + ): + is_expected_to_be_saved_by_this_rank = dp_rank in expected_key_to_saving_ranks.get( + sh_ten.key, [] + ) + assert sh_ten.replica_id == int( + not is_expected_to_be_saved_by_this_rank + ), expected_key_to_saving_ranks + + assert mock_strategy.save_keys == expected_keys_saved_by_current_rank, ( + Utils.rank, + mock_strategy.save_keys, + expected_keys_saved_by_current_rank, + ) @pytest.mark.parametrize("parallelization_along_dp", [False, True]) def test_load_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt): @@ -160,7 +221,9 @@ def test_load_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt): # 3. 
Shard id (key) if not parallelization_along_dp: expected_key_to_saving_ranks = { - 'keyB': list(range(Utils.world_size)), # everyone must save (disjoint shards, coverage == 1) + 'keyB': list( + range(Utils.world_size) + ), # everyone must save (disjoint shards, coverage == 1) 'key_TP_repl1': [0, 1], # lowest coverage (4), first TP domain 'key_TP_repl2': [2, 3], # lowest coverage (4), second TP domain 'keyD': [4], # largest tensor @@ -171,7 +234,9 @@ def test_load_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt): # When loading, expected key distribution is the same across TP, because every replica needs to be loaded expected_key_to_saving_ranks = { # everyone must load (disjoint shards, coverage == 1): - 'keyB': list(range(parallel_state.get_data_parallel_world_size(with_context_parallel=True))), + 'keyB': list( + range(parallel_state.get_data_parallel_world_size(with_context_parallel=True)) + ), # this time, TP sharded tensors have the same coverage as fully replicated! 'keyD': [0], # largest tensor 'keyC': [1], # second largest tensor @@ -180,21 +245,37 @@ def test_load_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt): 'key_TP_repl2': [3], # smallest tensor, last rank is the least occupied } - parallelization_group = parallel_state.get_data_parallel_group(with_context_parallel=True) if parallelization_along_dp else None + parallelization_group = ( + parallel_state.get_data_parallel_group(with_context_parallel=True) + if parallelization_along_dp + else None + ) dp_rank = torch.distributed.get_rank(parallelization_group) - expected_keys_saved_by_current_rank = {k for k, v in expected_key_to_saving_ranks.items() if dp_rank in v} + expected_keys_saved_by_current_rank = { + k for k, v in expected_key_to_saving_ranks.items() if dp_rank in v + } # Run save and tests mock_strategy = MockLoadStrategy() - load_strategy = FullyParallelLoadStrategyWrapper(mock_strategy, - parallelization_group, - do_cache_distribution=True) + load_strategy = FullyParallelLoadStrategyWrapper( + mock_strategy, parallelization_group, do_cache_distribution=True + ) with TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir_A: loaded_state_dict = load_strategy.load(state_dict, ckpt_dir_A) - key_to_saving_rank = dict(map_reduce(load_strategy.cached_distribution.main_rank_for_shard.items(), lambda shard_rank: shard_rank[0][0], lambda shard_rank: shard_rank[1])) + key_to_saving_rank = dict( + map_reduce( + load_strategy.cached_distribution.main_rank_for_shard.items(), + lambda shard_rank: shard_rank[0][0], + lambda shard_rank: shard_rank[1], + ) + ) assert expected_key_to_saving_ranks == key_to_saving_rank - assert mock_strategy.load_keys == expected_keys_saved_by_current_rank, (Utils.rank, mock_strategy.load_keys, expected_keys_saved_by_current_rank) + assert mock_strategy.load_keys == expected_keys_saved_by_current_rank, ( + Utils.rank, + mock_strategy.load_keys, + expected_keys_saved_by_current_rank, + ) assert loaded_state_dict.keys() == state_dict.keys() @@ -220,8 +301,11 @@ def _get_empty_tensor_for_exchange(self, *args, **kwargs) -> torch.Tensor: # Each tensor is 4MB, 40MB in total. 
# We expect extra memory usage peak at ~32MB, not 1GB sharded_state_dict = { - f'ten_{i}': ShardedTensor.from_rank_offsets(f'ten_{i}', torch.rand(megabytes, dtype=torch.float, device=state_dict_device), - (0, Utils.rank, Utils.world_size)) + f'ten_{i}': ShardedTensor.from_rank_offsets( + f'ten_{i}', + torch.rand(megabytes, dtype=torch.float, device=state_dict_device), + (0, Utils.rank, Utils.world_size), + ) for i in range(10) } @@ -233,6 +317,9 @@ def _get_empty_tensor_for_exchange(self, *args, **kwargs) -> torch.Tensor: # Each rank is expected to do 7 * 10 empty allocations assert len(mem_alloc) == 7 * 10 # Peak mem usage should be within 4MB (single tensor) - assert max(mem_alloc) - mem_alloc_start < 4.01 * megabytes, (max(mem_alloc), mem_alloc_start) + assert max(mem_alloc) - mem_alloc_start < 4.01 * megabytes, ( + max(mem_alloc), + mem_alloc_start, + ) - Utils.destroy_model_parallel() \ No newline at end of file + Utils.destroy_model_parallel() diff --git a/tests/unit_tests/dist_checkpointing/test_mapping.py b/tests/unit_tests/dist_checkpointing/test_mapping.py index ebd0d1ed15..2f986ec1c2 100644 --- a/tests/unit_tests/dist_checkpointing/test_mapping.py +++ b/tests/unit_tests/dist_checkpointing/test_mapping.py @@ -1,16 +1,21 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. import pytest - import torch from megatron.core.dist_checkpointing import ShardedTensor from megatron.core.dist_checkpointing.core import CheckpointingException -from megatron.core.dist_checkpointing.mapping import is_main_replica, \ - ShardedTensorFactory, ShardedObject, apply_factories, apply_factory_merges +from megatron.core.dist_checkpointing.mapping import ( + ShardedObject, + ShardedTensorFactory, + apply_factories, + apply_factory_merges, + is_main_replica, +) from megatron.core.transformer.transformer_config import TransformerConfig from tests.unit_tests.test_utilities import Utils + class TestShardedTensor: # def setup_method(self, method): @@ -20,14 +25,11 @@ class TestShardedTensor: # # def teardown_method(self, method): # Utils.destroy_model_parallel() - + def test_from_rank_offsets_constructor(self, dtype=torch.float, device='cuda'): data = torch.ones((1, 3, 7, 9), dtype=dtype, device=device) shape = data.shape - rank_offsets = [ - (0, 0, 10), - (2, 3, 6) - ] + rank_offsets = [(0, 0, 10), (2, 3, 6)] sh_ten = ShardedTensor.from_rank_offsets('keyA', data, *rank_offsets) assert isinstance(sh_ten, ShardedTensor) @@ -40,13 +42,12 @@ def test_from_rank_offsets_constructor(self, dtype=torch.float, device='cuda'): def test_from_rank_offsets_flat_constructor(self, dtype=torch.float, device='cuda'): data = torch.arange(28, dtype=dtype, device=device).reshape((1, 4, 7)) shape = data.shape - rank_offsets = [ - (1, 0, 2), - (2, 3, 5) - ] + rank_offsets = [(1, 0, 2), (2, 3, 5)] flattened_range = slice(4, 9) flat_data = data.flatten()[flattened_range] - sh_ten = ShardedTensor.from_rank_offsets_flat('keyA', flat_data, data.shape, *rank_offsets, flattened_range=flattened_range) + sh_ten = ShardedTensor.from_rank_offsets_flat( + 'keyA', flat_data, data.shape, *rank_offsets, flattened_range=flattened_range + ) # The main attributes properties are unchanged assert isinstance(sh_ten, ShardedTensor) @@ -60,10 +61,7 @@ def test_from_rank_offsets_flat_constructor(self, dtype=torch.float, device='cud def test_metadata_integrity_violation(self): data = torch.ones((1, 3, 7, 9), device='meta') - rank_offsets = [ - (0, 0, 10), - (2, 3, 6) - ] + rank_offsets = [(0, 0, 10), (2, 3, 6)] sh_ten = 
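# Editor's note: the flat constructor exercised above registers only a slice of the
# flattened tensor while keeping the full (unflattened) local shape in the metadata.
# A condensed sketch mirroring the test values (a single process is enough here):
import torch

from megatron.core.dist_checkpointing import ShardedTensor

data = torch.arange(28, dtype=torch.float).reshape(1, 4, 7)
flattened_range = slice(4, 9)
sh_ten = ShardedTensor.from_rank_offsets_flat(
    'keyA',
    data.flatten()[flattened_range],  # only this slice is held by this rank
    data.shape,                       # metadata still records the full local shape
    (1, 0, 2),
    (2, 3, 5),
    flattened_range=flattened_range,
)
assert sh_ten.local_shape == (1, 4, 7)  # same check the unit test performs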
ShardedTensor.from_rank_offsets('keyA', data, *rank_offsets) sh_ten.validate_metadata_integrity() with pytest.raises(CheckpointingException): @@ -76,32 +74,40 @@ def test_metadata_integrity_violation(self): sh_ten.validate_metadata_integrity() with pytest.raises(CheckpointingException): - sh_ten = ShardedTensor.from_rank_offsets_flat('keyA', data, data.shape, *rank_offsets, - flattened_range=slice(4, 9)) + sh_ten = ShardedTensor.from_rank_offsets_flat( + 'keyA', data, data.shape, *rank_offsets, flattened_range=slice(4, 9) + ) - sh_ten = ShardedTensor.from_rank_offsets_flat('keyA', data.flatten()[4:9], data.shape, *rank_offsets, - flattened_range=slice(4, 9)) + sh_ten = ShardedTensor.from_rank_offsets_flat( + 'keyA', data.flatten()[4:9], data.shape, *rank_offsets, flattened_range=slice(4, 9) + ) assert sh_ten.local_shape == (1, 3, 7, 9) with pytest.raises(CheckpointingException): sh_ten.local_shape = (5,) sh_ten.validate_metadata_integrity() - class TestShardedTensorFactory: def test_build_and_merge(self): def build_fn(key, tensor, replica_id, flattened_range): assert flattened_range is None return { - 'level2_a': ShardedTensor.from_rank_offsets(key + 'part1', tensor + 1, replica_id=replica_id), - 'level2_b': ShardedTensor.from_rank_offsets(key + 'part2', tensor + 2, replica_id=replica_id) + 'level2_a': ShardedTensor.from_rank_offsets( + key + 'part1', tensor + 1, replica_id=replica_id + ), + 'level2_b': ShardedTensor.from_rank_offsets( + key + 'part2', tensor + 2, replica_id=replica_id + ), } # state_dict will be modified in-place def get_state_dict(): return { - 'level1': ShardedTensorFactory('a', torch.arange(3), build_fn, lambda x: x['level2_b']) + 'level1': ShardedTensorFactory( + 'a', torch.arange(3), build_fn, lambda x: x['level2_b'] + ) } + state_dict = get_state_dict() apply_factories(state_dict) assert torch.allclose(state_dict['level1']['level2_a'].data, torch.tensor([1, 2, 3])) diff --git a/tests/unit_tests/dist_checkpointing/test_nonpersistent.py b/tests/unit_tests/dist_checkpointing/test_nonpersistent.py index 667efddff4..d7907ead1f 100644 --- a/tests/unit_tests/dist_checkpointing/test_nonpersistent.py +++ b/tests/unit_tests/dist_checkpointing/test_nonpersistent.py @@ -2,36 +2,33 @@ import filecmp import os -import pytest from types import SimpleNamespace from unittest import mock +import pytest + from megatron.training.checkpointing import ( _NON_PERSISTENT_CKPT_SUBDIR, load_checkpoint, save_checkpoint, ) from tests.unit_tests.dist_checkpointing import ( + TempNamedDir, init_basic_mock_args, init_checkpointing_mock_args, - TempNamedDir, setup_model_and_optimizer, ) from tests.unit_tests.test_utilities import Utils + class TestNonPersistentSaveAndLoad: def setup_method(self, method): pass def teardown_method(self, method): - Utils.destroy_model_parallel() - - @pytest.mark.parametrize( - ('tp,pp'), - [ - (2, 4), - ] - ) + Utils.destroy_model_parallel() + + @pytest.mark.parametrize(('tp,pp'), [(2, 4)]) def test_basic_save_load_scenarios(self, tmp_path_dist_ckpt, tp, pp): Utils.initialize_model_parallel(tp, pp) num_floating_point_operations_so_far = 0 @@ -60,7 +57,7 @@ def test_basic_save_load_scenarios(self, tmp_path_dist_ckpt, tp, pp): non_persistent_ckpt=True, ) save_checkpoint( - 3, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {}, + 3, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {} ) save_checkpoint( 4, @@ -74,7 +71,7 @@ def test_basic_save_load_scenarios(self, tmp_path_dist_ckpt, tp, pp): iteration, _ = 
load_checkpoint(model, optimizer, opt_param_scheduler)
             assert iteration == 4
             save_checkpoint(
-                6, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {},
+                6, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {}
             )
             iteration, _ = load_checkpoint(model, optimizer, opt_param_scheduler)
             assert iteration == 6
@@ -119,12 +116,7 @@ def test_basic_save_load_scenarios(self, tmp_path_dist_ckpt, tp, pp):
 
 
 class TestLegacySaveAndLoad:
-    @pytest.mark.parametrize(
-        ('tp,pp'),
-        [
-            (2, 4),
-        ]
-    )
+    @pytest.mark.parametrize(('tp,pp'), [(2, 4)])
     def test_basic_save_load_scenario(self, tmp_path_dist_ckpt, tp, pp):
         Utils.initialize_model_parallel(tp, pp)
         num_floating_point_operations_so_far = 0
@@ -139,7 +131,7 @@ def test_basic_save_load_scenario(self, tmp_path_dist_ckpt, tp, pp):
             init_checkpointing_mock_args(mock_args, legacy_ckpt_dir)
 
             save_checkpoint(
-                2, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {},
+                2, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {}
             )
             iteration, _ = load_checkpoint(model, optimizer, opt_param_scheduler)
             assert iteration == 2
diff --git a/tests/unit_tests/dist_checkpointing/test_optimizer.py b/tests/unit_tests/dist_checkpointing/test_optimizer.py
index 87047b92b4..59577c73fa 100644
--- a/tests/unit_tests/dist_checkpointing/test_optimizer.py
+++ b/tests/unit_tests/dist_checkpointing/test_optimizer.py
@@ -62,20 +62,25 @@ def sharded_state_dict(self):
         sharded_state_dict = self.state_dict(keep_vars=True)
         # conv
         sharded_state_dict['conv.weight'] = ShardedTensor.from_rank_offsets(
-            'conv.weight', sharded_state_dict['conv.weight'],
-            (1, parallel_state.get_tensor_model_parallel_rank(), parallel_state.get_tensor_model_parallel_world_size())
+            'conv.weight',
+            sharded_state_dict['conv.weight'],
+            (
+                1,
+                parallel_state.get_tensor_model_parallel_rank(),
+                parallel_state.get_tensor_model_parallel_world_size(),
+            ),
         )
         # bias is non-sharded
-        sharded_state_dict['conv.bias'] = ShardedTensor.from_rank_offsets('conv.bias', sharded_state_dict['conv.bias'])
+        sharded_state_dict['conv.bias'] = ShardedTensor.from_rank_offsets(
+            'conv.bias', sharded_state_dict['conv.bias']
+        )
 
         # proj
         sharded_state_dict['proj.weight'] = ShardedTensor.from_rank_offsets(
-            'proj.weight', sharded_state_dict['proj.weight'],
-            (0, Utils.rank, Utils.world_size)
+            'proj.weight', sharded_state_dict['proj.weight'], (0, Utils.rank, Utils.world_size)
         )
         sharded_state_dict['proj.bias'] = ShardedTensor.from_rank_offsets(
-            'proj.bias', sharded_state_dict['proj.bias'],
-            (0, Utils.rank, Utils.world_size)
+            'proj.bias', sharded_state_dict['proj.bias'], (0, Utils.rank, Utils.world_size)
         )
         return sharded_state_dict
 
@@ -83,34 +88,68 @@ def sharded_state_dict(self):
 class SwigluFactoryModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear = torch.nn.Linear(5, 64 // parallel_state.get_tensor_model_parallel_world_size(), bias=False)
+        self.linear = torch.nn.Linear(
+            5, 64 // parallel_state.get_tensor_model_parallel_world_size(), bias=False
+        )
         self.config = TransformerConfig(hidden_size=8, num_attention_heads=1, num_layers=1)
 
     def sharded_state_dict(self):
         sharded_state_dict = self.state_dict(keep_vars=True)
         sharded_state_dict['linear.weight'] = ShardedTensor.from_rank_offsets(
-            'linear.weight', sharded_state_dict['linear.weight'],
-            ((0, parallel_state.get_tensor_model_parallel_rank(), parallel_state.get_tensor_model_parallel_world_size())),
-
replica_id=((parallel_state.get_pipeline_model_parallel_rank(), 0, parallel_state.get_data_parallel_rank(with_context_parallel=True))) + 'linear.weight', + sharded_state_dict['linear.weight'], + ( + ( + 0, + parallel_state.get_tensor_model_parallel_rank(), + parallel_state.get_tensor_model_parallel_world_size(), + ) + ), + replica_id=( + ( + parallel_state.get_pipeline_model_parallel_rank(), + 0, + parallel_state.get_data_parallel_rank(with_context_parallel=True), + ) + ), + ) + sharded_state_dict['linear.weight'] = apply_swiglu_sharded_factory( + sharded_state_dict['linear.weight'], () ) - sharded_state_dict['linear.weight'] = apply_swiglu_sharded_factory(sharded_state_dict['linear.weight'], ()) return sharded_state_dict class SwigluFactoryModel(torch.nn.Module): def __init__(self): super().__init__() - self.linear = torch.nn.Linear(5, 64 // parallel_state.get_tensor_model_parallel_world_size(), bias=False) + self.linear = torch.nn.Linear( + 5, 64 // parallel_state.get_tensor_model_parallel_world_size(), bias=False + ) self.config = TransformerConfig(hidden_size=8, num_attention_heads=1, num_layers=1) def sharded_state_dict(self): sharded_state_dict = self.state_dict(keep_vars=True) sharded_state_dict['linear.weight'] = ShardedTensor.from_rank_offsets( - 'linear.weight', sharded_state_dict['linear.weight'], - ((0, parallel_state.get_tensor_model_parallel_rank(), parallel_state.get_tensor_model_parallel_world_size())), - replica_id=((parallel_state.get_pipeline_model_parallel_rank(), 0, parallel_state.get_data_parallel_rank(with_context_parallel=True))) + 'linear.weight', + sharded_state_dict['linear.weight'], + ( + ( + 0, + parallel_state.get_tensor_model_parallel_rank(), + parallel_state.get_tensor_model_parallel_world_size(), + ) + ), + replica_id=( + ( + parallel_state.get_pipeline_model_parallel_rank(), + 0, + parallel_state.get_data_parallel_rank(with_context_parallel=True), + ) + ), + ) + sharded_state_dict['linear.weight'] = apply_swiglu_sharded_factory( + sharded_state_dict['linear.weight'], () ) - sharded_state_dict['linear.weight'] = apply_swiglu_sharded_factory(sharded_state_dict['linear.weight'], ()) return sharded_state_dict @@ -119,10 +158,10 @@ def setup_method(self, method): pass def teardown_method(self, method): - Utils.destroy_model_parallel() + Utils.destroy_model_parallel() def test_optimizer_params(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(1,1) + Utils.initialize_model_parallel(1, 1) model = Model() # Force optimizer state initialization for p in model.parameters(): @@ -131,18 +170,22 @@ def test_optimizer_params(self, tmp_path_dist_ckpt): optim.step() model_state_dict = model.sharded_state_dict() - param_map = get_param_id_to_sharded_param_map(model_state_dict, optim.param_groups[0]['params']) + param_map = get_param_id_to_sharded_param_map( + model_state_dict, optim.param_groups[0]['params'] + ) optim_state_dict = optim.state_dict() optim_state_to_sharding_state(optim_state_dict, param_map, exclude_keys=('step',)) optim_sharded_tensors = nested_values(extract_sharded_tensors(optim_state_dict)[0]) optim_sharded_keys = {sh_ten.key for sh_ten in optim_sharded_tensors} assert len(optim_sharded_keys) == 2 * len(model_state_dict) - assert optim_sharded_keys == set([ - f'optimizer.state.{state_key}.{layer_name}' - for state_key in ['exp_avg', 'exp_avg_sq'] - for layer_name in model_state_dict - ]) + assert optim_sharded_keys == set( + [ + f'optimizer.state.{state_key}.{layer_name}' + for state_key in ['exp_avg', 'exp_avg_sq'] + for layer_name in 
model_state_dict + ] + ) def initialize_small_model(pre_process=True, post_process=True, seed=0, **config_kwargs): @@ -163,17 +206,20 @@ def setup_method(self, method): pass def teardown_method(self, method): - Utils.destroy_model_parallel() + Utils.destroy_model_parallel() @pytest.mark.parametrize("initialize_fn", [initialize_small_model, initialize_gpt_model]) @pytest.mark.parametrize("use_fpsl", [False, True]) - @pytest.mark.parametrize("tp_pp,src_dp,dest_dp", [ - ((4, 1), 2, 2), - # ((1, 1), 8, 1), # TODO: changing DP doesn't work in unit tests because of NCCL crashes - # ((1, 1), 1, 8), - # ((2, 1), 2, 1), - # ((2, 1), 2, 2), - ]) + @pytest.mark.parametrize( + "tp_pp,src_dp,dest_dp", + [ + ((4, 1), 2, 2), + # ((1, 1), 8, 1), # TODO: changing DP doesn't work in unit tests because of NCCL crashes + # ((1, 1), 1, 8), + # ((2, 1), 2, 1), + # ((2, 1), 2, 2), + ], + ) def test_dp_sharding(self, tmp_path_dist_ckpt, tp_pp, src_dp, dest_dp, use_fpsl, initialize_fn): src_world_size = tp_pp[0] * tp_pp[1] * src_dp dest_world_size = tp_pp[0] * tp_pp[1] * dest_dp @@ -190,16 +236,24 @@ def test_dp_sharding(self, tmp_path_dist_ckpt, tp_pp, src_dp, dest_dp, use_fpsl, Utils.set_world_size(src_world_size) if Utils.rank >= 0: # Save checkpoint A - model, optimizer_A = setup_model_and_optimizer(seed=2, tp=tp_pp[0], pp=tp_pp[1], initialize_fn=initialize_fn) + model, optimizer_A = setup_model_and_optimizer( + seed=2, tp=tp_pp[0], pp=tp_pp[1], initialize_fn=initialize_fn + ) save_strategy = get_default_save_sharded_strategy() if use_fpsl: save_strategy = FullyParallelSaveStrategyWrapper( save_strategy, parallel_state.get_data_parallel_group(with_context_parallel=True), - True + True, ) - save(optimizer_A.sharded_state_dict(model[0].sharded_state_dict(), sharding_type=sharding_type), ckpt_dir, save_strategy) + save( + optimizer_A.sharded_state_dict( + model[0].sharded_state_dict(), sharding_type=sharding_type + ), + ckpt_dir, + save_strategy, + ) optim_param_state_A = optimizer_A.get_parameter_state_dp_zero() Utils.destroy_model_parallel() else: @@ -213,7 +267,9 @@ def test_dp_sharding(self, tmp_path_dist_ckpt, tp_pp, src_dp, dest_dp, use_fpsl, if Utils.rank >= 0: Utils.initialize_model_parallel(*tp_pp) - model, optimizer_B = setup_model_and_optimizer(seed=3, tp=tp_pp[0], pp=tp_pp[1], initialize_fn=initialize_fn) + model, optimizer_B = setup_model_and_optimizer( + seed=3, tp=tp_pp[0], pp=tp_pp[1], initialize_fn=initialize_fn + ) optim_param_state_B = optimizer_B.get_parameter_state_dp_zero() diffs = diff(optim_param_state_A, optim_param_state_B) # Expect a mismatch in values - diffs[2] nonempty @@ -221,9 +277,7 @@ def test_dp_sharding(self, tmp_path_dist_ckpt, tp_pp, src_dp, dest_dp, use_fpsl, assert not diffs[0] and not diffs[1] and diffs[2], diffs sharded_state_dict = optimizer_B.sharded_state_dict( - model[0].sharded_state_dict(), - is_loading=True, - sharding_type=sharding_type, + model[0].sharded_state_dict(), is_loading=True, sharding_type=sharding_type ) optim_state_dict = load(sharded_state_dict, ckpt_dir) optimizer_B.load_state_dict(optim_state_dict) @@ -241,23 +295,26 @@ def test_dp_sharding(self, tmp_path_dist_ckpt, tp_pp, src_dp, dest_dp, use_fpsl, @pytest.mark.parametrize( ('src_tp_pp', 'dest_tp_pp', 'use_glu'), - [ - ((2, 2), (2, 4), False,), - ((1, 8), (4, 1), True), - ((2, 4), (4, 2), False), - ] + [((2, 2), (2, 4), False), ((1, 8), (4, 1), True), ((2, 4), (4, 2), False)], ) - def test_finetune_doesnt_load_optimizer(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, use_glu): + def 
test_finetune_doesnt_load_optimizer( + self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, use_glu + ): # sync=True to make sure other ranks wait for rank 0 to finish creating directory. Utils.initialize_model_parallel(*src_tp_pp) - with TempNamedDir(tmp_path_dist_ckpt / 'test_finetune_doesnt_load_optimizer', sync=True) as ckpt_dir: + with TempNamedDir( + tmp_path_dist_ckpt / 'test_finetune_doesnt_load_optimizer', sync=True + ) as ckpt_dir: mock_args = SimpleNamespace() with mock.patch('megatron.training.checkpointing.get_args', new=lambda: mock_args): init_basic_mock_args(mock_args, tp=src_tp_pp[0], pp=src_tp_pp[1]) init_checkpointing_mock_args(mock_args, ckpt_dir, False) model, optimizer = setup_model_and_optimizer( - seed=2, tp=src_tp_pp[0], pp=src_tp_pp[1], initialize_fn=partial(initialize_gpt_model, use_glu=use_glu) + seed=2, + tp=src_tp_pp[0], + pp=src_tp_pp[1], + initialize_fn=partial(initialize_gpt_model, use_glu=use_glu), ) save_checkpoint(10, model, optimizer, None, 0) @@ -265,7 +322,10 @@ def test_finetune_doesnt_load_optimizer(self, tmp_path_dist_ckpt, src_tp_pp, des Utils.initialize_model_parallel(*dest_tp_pp) model, optimizer = setup_model_and_optimizer( - seed=3, tp=dest_tp_pp[0], pp=dest_tp_pp[1], initialize_fn=partial(initialize_gpt_model, use_glu=use_glu) + seed=3, + tp=dest_tp_pp[0], + pp=dest_tp_pp[1], + initialize_fn=partial(initialize_gpt_model, use_glu=use_glu), ) model_unloaded_state_dict = deepcopy(model[0].state_dict()) optim_unloaded_state_dict = deepcopy(optimizer.state_dict()) @@ -291,7 +351,10 @@ def test_finetune_doesnt_load_optimizer(self, tmp_path_dist_ckpt, src_tp_pp, des # ... or `no_load_optim` flag model, optimizer = setup_model_and_optimizer( - seed=3, tp=dest_tp_pp[0], pp=dest_tp_pp[1], initialize_fn=partial(initialize_gpt_model, use_glu=use_glu) + seed=3, + tp=dest_tp_pp[0], + pp=dest_tp_pp[1], + initialize_fn=partial(initialize_gpt_model, use_glu=use_glu), ) mock_args.finetune = False mock_args.no_load_optim = True @@ -299,33 +362,43 @@ def test_finetune_doesnt_load_optimizer(self, tmp_path_dist_ckpt, src_tp_pp, des load_checkpoint_no_arg_checks(model, optimizer, None) ## Model weights should be different, but optimizer state is unchanged - diffs = (diff(model[0].state_dict(), model_unloaded_state_dict)) + diffs = diff(model[0].state_dict(), model_unloaded_state_dict) # diffs[0] and diffs[1] is structural diff, diffs[2] is values diff - we expect only values diff assert not diffs[0] and not diffs[1] and diffs[2] assert not any(diff(optimizer.state_dict(), optim_unloaded_state_dict)) - def test_can_load_deprecated_bucket_space_format(self, tmp_path_dist_ckpt): # sync=True to make sure other ranks wait for rank 0 to finish creating directory. 
tp = 4 pp = 2 Utils.initialize_model_parallel(tp, pp) - with TempNamedDir(tmp_path_dist_ckpt / 'test_can_load_deprecated_bucket_space_format', sync=True) as ckpt_dir: + with TempNamedDir( + tmp_path_dist_ckpt / 'test_can_load_deprecated_bucket_space_format', sync=True + ) as ckpt_dir: mock_args = SimpleNamespace() with mock.patch('megatron.training.checkpointing.get_args', new=lambda: mock_args): - + init_basic_mock_args(mock_args, tp=tp, pp=pp) init_checkpointing_mock_args(mock_args, ckpt_dir, True) - - model, optimizer = setup_model_and_optimizer(seed=2, tp=tp, pp=pp, initialize_fn=initialize_gpt_model) + + model, optimizer = setup_model_and_optimizer( + seed=2, tp=tp, pp=pp, initialize_fn=initialize_gpt_model + ) # Mock optimizer sharded_state_dict so that it ignores the externally passed sharding_type and uses 'fully_sharded_bucket_space' instead orig_optim_sharded_state_dict_fn = optimizer.sharded_state_dict - def sharded_state_dict_bucket_space(self, *args, sharding_type: str = 'fully_sharded_model_space', **kwargs): - return orig_optim_sharded_state_dict_fn(*args, sharding_type='fully_sharded_bucket_space', **kwargs) - optimizer.sharded_state_dict = MethodType(sharded_state_dict_bucket_space, optimizer) + def sharded_state_dict_bucket_space( + self, *args, sharding_type: str = 'fully_sharded_model_space', **kwargs + ): + return orig_optim_sharded_state_dict_fn( + *args, sharding_type='fully_sharded_bucket_space', **kwargs + ) + + optimizer.sharded_state_dict = MethodType( + sharded_state_dict_bucket_space, optimizer + ) save_checkpoint(10, model, optimizer, None, 0) flag = 0 @@ -348,30 +421,32 @@ def sharded_state_dict_bucket_space(self, *args, sharding_type: str = 'fully_sha load_checkpoint_no_arg_checks(model, optimizer, None) - class TestFP32Optimizer: def setup_method(self, method): pass def teardown_method(self, method): - Utils.destroy_model_parallel() + Utils.destroy_model_parallel() @pytest.mark.parametrize( - ('src_tp_pp', 'dest_tp_pp'), - [ - ((2, 4), (2, 4)), - ((2, 4), (4, 2)), - ((8, 1), (1, 2)), - ] + ('src_tp_pp', 'dest_tp_pp'), [((2, 4), (2, 4)), ((2, 4), (4, 2)), ((8, 1), (1, 2))] ) def test_fp32_optimizer_resharding(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp): # sync=True to make sure other ranks wait for rank 0 to finish creating directory. 
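# Editor's note: the deprecated-format test above works by rebinding the optimizer's
# sharded_state_dict method so every call silently uses the legacy
# 'fully_sharded_bucket_space' layout. A stripped-down sketch of that MethodType pattern
# (here `optimizer` stands for any object whose bound method you want to wrap):
from types import MethodType


def force_legacy_sharding(optimizer):
    original = optimizer.sharded_state_dict  # keep the original bound method

    def patched(self, *args, sharding_type: str = 'fully_sharded_model_space', **kwargs):
        # Ignore the caller-provided sharding_type and always use the legacy layout.
        return original(*args, sharding_type='fully_sharded_bucket_space', **kwargs)

    optimizer.sharded_state_dict = MethodType(patched, optimizer)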
Utils.initialize_model_parallel(*src_tp_pp) - with TempNamedDir(tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_A', sync=True) as ckpt_dir_A: - with TempNamedDir(tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_B', sync=True) as ckpt_dir_B: - + with TempNamedDir( + tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_A', sync=True + ) as ckpt_dir_A: + with TempNamedDir( + tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_B', sync=True + ) as ckpt_dir_B: + model_A, optimizer_A = setup_model_and_optimizer( - seed=2, tp=src_tp_pp[0], pp=src_tp_pp[1], initialize_fn=initialize_small_model, bf16=False + seed=2, + tp=src_tp_pp[0], + pp=src_tp_pp[1], + initialize_fn=initialize_small_model, + bf16=False, ) save(optimizer_A.sharded_state_dict(model_A[0].sharded_state_dict()), ckpt_dir_A) @@ -380,9 +455,15 @@ def test_fp32_optimizer_resharding(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_ # Load checkpoint A with different TP/PP and save as checkpoint B Utils.initialize_model_parallel(*dest_tp_pp) model_B, optimizer_B = setup_model_and_optimizer( - seed=3, tp=dest_tp_pp[0], pp=dest_tp_pp[1], initialize_fn=initialize_small_model, bf16=False + seed=3, + tp=dest_tp_pp[0], + pp=dest_tp_pp[1], + initialize_fn=initialize_small_model, + bf16=False, + ) + load_sharded_state_dict = optimizer_B.sharded_state_dict( + model_B[0].sharded_state_dict() ) - load_sharded_state_dict = optimizer_B.sharded_state_dict(model_B[0].sharded_state_dict()) state_dict = load(load_sharded_state_dict, ckpt_dir_A) optimizer_B.load_state_dict(state_dict) @@ -402,40 +483,47 @@ def setup_method(self, method): pass def teardown_method(self, method): - Utils.destroy_model_parallel() - + Utils.destroy_model_parallel() + @pytest.mark.parametrize( ('use_dist_opt', 'bf16'), ( (False, True), # regular BF16 - (True, True), # DistOpt BF16 + (True, True), # DistOpt BF16 # (False, False), # FP32 - ) + ), ) @pytest.mark.parametrize( - ('src_tp_pp', 'dest_tp_pp',), - [ - ((2, 4), (2, 4)), - ((2, 4), (2, 2)), - ((2, 4), (4, 2)), - ((8, 1), (1, 2)), - ] + ('src_tp_pp', 'dest_tp_pp'), + [((2, 4), (2, 4)), ((2, 4), (2, 2)), ((2, 4), (4, 2)), ((8, 1), (1, 2))], ) @pytest.mark.skip(reason="Tests are flaky and need to be debugged") - def test_optimizer_resharding(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, use_dist_opt, bf16): + def test_optimizer_resharding( + self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, use_dist_opt, bf16 + ): Utils.initialize_model_parallel(*src_tp_pp) - with TempNamedDir(tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_A', sync=False) as ckpt_dir_A: - with TempNamedDir(tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_B', sync=False) as ckpt_dir_B: - - model_A, optimizer_A = setup_model_and_optimizer(seed=2, tp=src_tp_pp[0], pp=src_tp_pp[1], bf16=bf16, dist_opt=use_dist_opt) + with TempNamedDir( + tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_A', sync=False + ) as ckpt_dir_A: + with TempNamedDir( + tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_B', sync=False + ) as ckpt_dir_B: + + model_A, optimizer_A = setup_model_and_optimizer( + seed=2, tp=src_tp_pp[0], pp=src_tp_pp[1], bf16=bf16, dist_opt=use_dist_opt + ) save(optimizer_A.sharded_state_dict(model_A[0].sharded_state_dict()), ckpt_dir_A) Utils.destroy_model_parallel() # Load checkpoint A with different TP/PP and save as checkpoint B Utils.initialize_model_parallel(*dest_tp_pp) - model_B, optimizer_B = setup_model_and_optimizer(seed=3, tp=dest_tp_pp[0], pp=dest_tp_pp[1], bf16=bf16, dist_opt=use_dist_opt) - load_sharded_state_dict = 
optimizer_B.sharded_state_dict(model_B[0].sharded_state_dict()) + model_B, optimizer_B = setup_model_and_optimizer( + seed=3, tp=dest_tp_pp[0], pp=dest_tp_pp[1], bf16=bf16, dist_opt=use_dist_opt + ) + load_sharded_state_dict = optimizer_B.sharded_state_dict( + model_B[0].sharded_state_dict() + ) state_dict = load(load_sharded_state_dict, ckpt_dir_A) optimizer_B.load_state_dict(state_dict) diff --git a/tests/unit_tests/dist_checkpointing/test_serialization.py b/tests/unit_tests/dist_checkpointing/test_serialization.py index 6c625f11d3..19e99de553 100644 --- a/tests/unit_tests/dist_checkpointing/test_serialization.py +++ b/tests/unit_tests/dist_checkpointing/test_serialization.py @@ -9,18 +9,16 @@ from torch.distributed.checkpoint import CheckpointException as PyTCheckpointingException from megatron.core import parallel_state -from megatron.core.dist_checkpointing import ShardedTensor, save, load -from megatron.core.dist_checkpointing.core import CheckpointingException, \ - maybe_load_config +from megatron.core.dist_checkpointing import ShardedTensor, load, save +from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config from megatron.core.dist_checkpointing.dict_utils import diff -from megatron.core.dist_checkpointing.mapping import ShardedTensorFactory, \ - ShardedObject -from megatron.core.dist_checkpointing.serialization import \ - load_tensors_metadata, load_sharded_metadata -from megatron.core.dist_checkpointing.strategies.base import StrategyAction, \ - get_default_strategy +from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensorFactory +from megatron.core.dist_checkpointing.serialization import ( + load_sharded_metadata, + load_tensors_metadata, +) +from megatron.core.dist_checkpointing.strategies.base import StrategyAction, get_default_strategy from megatron.core.dist_checkpointing.validation import StrictHandling - from tests.unit_tests.dist_checkpointing import TempNamedDir from tests.unit_tests.test_utilities import Utils @@ -30,18 +28,24 @@ def setup_method(self, method): pass def teardown_method(self, method): - Utils.destroy_model_parallel() + Utils.destroy_model_parallel() def test_single_process_save_load(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(1,1) + Utils.initialize_model_parallel(1, 1) sharded_state_dict = { - 'sd_keyA': ShardedTensor.from_rank_offsets('keyA', torch.ones(2, 4), replica_id=Utils.rank), - 'sd_keyB': ShardedTensor.from_rank_offsets('keyB', torch.ones(3, 5, 7), replica_id=Utils.rank), + 'sd_keyA': ShardedTensor.from_rank_offsets( + 'keyA', torch.ones(2, 4), replica_id=Utils.rank + ), + 'sd_keyB': ShardedTensor.from_rank_offsets( + 'keyB', torch.ones(3, 5, 7), replica_id=Utils.rank + ), } # sync=True to make sure other ranks wait for rank 0 to finish creating directory. 
- with TempNamedDir(tmp_path_dist_ckpt / 'test_single_process_save_load', sync=True) as ckpt_dir: + with TempNamedDir( + tmp_path_dist_ckpt / 'test_single_process_save_load', sync=True + ) as ckpt_dir: save(sharded_state_dict, ckpt_dir) torch.distributed.barrier() @@ -53,23 +57,28 @@ def test_single_process_save_load(self, tmp_path_dist_ckpt): assert not (ckpt_dir / 'sd_keyA').is_dir() load_ssd = { - 'load_sd_keyA': ShardedTensor.from_rank_offsets('keyA', torch.ones(2, 4), replica_id=Utils.rank), + 'load_sd_keyA': ShardedTensor.from_rank_offsets( + 'keyA', torch.ones(2, 4), replica_id=Utils.rank + ) } loaded_state_dict = load(load_ssd, ckpt_dir) - + assert set(loaded_state_dict.keys()) == {'load_sd_keyA'} assert isinstance(loaded_state_dict['load_sd_keyA'], torch.Tensor) assert loaded_state_dict['load_sd_keyA'].shape == (2, 4) Utils.destroy_model_parallel() - def test_multi_process_save(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(2,4) + Utils.initialize_model_parallel(2, 4) state_dict = { - 'sd_keyA': ShardedTensor.from_rank_offsets('keyA', torch.ones(2, 4), (0, Utils.rank, Utils.world_size)), - 'sd_keyB': ShardedTensor.from_rank_offsets('keyB', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size)), + 'sd_keyA': ShardedTensor.from_rank_offsets( + 'keyA', torch.ones(2, 4), (0, Utils.rank, Utils.world_size) + ), + 'sd_keyB': ShardedTensor.from_rank_offsets( + 'keyB', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size) + ), } # sync=True to make sure other ranks wait for rank 0 to finish creating directory. @@ -85,13 +94,16 @@ def test_multi_process_save(self, tmp_path_dist_ckpt): Utils.destroy_model_parallel() - def test_partition_change_save_load(self, tmp_path_dist_ckpt, strategy=None): - Utils.initialize_model_parallel(2,4) + Utils.initialize_model_parallel(2, 4) # ten_a: global shape (2, 4): ten_a_global = torch.tensor([[0, 1, 2, 3], [10, 11, 12, 13]]) - ten_a = torch.zeros(1, 1) + 10 * parallel_state.get_tensor_model_parallel_rank() + parallel_state.get_pipeline_model_parallel_rank() + ten_a = ( + torch.zeros(1, 1) + + 10 * parallel_state.get_tensor_model_parallel_rank() + + parallel_state.get_pipeline_model_parallel_rank() + ) assert ten_a.shape == (1, 1) # ten_b: global shape (4, 5, 80), where (x, y, z) is (100x + z) @@ -100,11 +112,24 @@ def test_partition_change_save_load(self, tmp_path_dist_ckpt, strategy=None): assert ten_b.shape == (4, 5, 10) state_dict = { - 'sd_keyA': ShardedTensor.from_rank_offsets('keyA', ten_a, - (0, parallel_state.get_tensor_model_parallel_rank(), parallel_state.get_tensor_model_parallel_world_size()), - (1, parallel_state.get_pipeline_model_parallel_rank(), parallel_state.get_pipeline_model_parallel_world_size()), - replica_id=0), - 'sd_keyB': ShardedTensor.from_rank_offsets('keyB', ten_b, (2, Utils.rank, Utils.world_size)), + 'sd_keyA': ShardedTensor.from_rank_offsets( + 'keyA', + ten_a, + ( + 0, + parallel_state.get_tensor_model_parallel_rank(), + parallel_state.get_tensor_model_parallel_world_size(), + ), + ( + 1, + parallel_state.get_pipeline_model_parallel_rank(), + parallel_state.get_pipeline_model_parallel_world_size(), + ), + replica_id=0, + ), + 'sd_keyB': ShardedTensor.from_rank_offsets( + 'keyB', ten_b, (2, Utils.rank, Utils.world_size) + ), } ten_a_global_shape = ten_a_global.shape @@ -115,19 +140,21 @@ def test_partition_change_save_load(self, tmp_path_dist_ckpt, strategy=None): assert state_dict['sd_keyB'].global_shape == ten_b_global_shape # sync=True to make sure other ranks wait for rank 0 to finish creating 
directory. - with TempNamedDir(tmp_path_dist_ckpt / 'test_partition_change_save_load', sync=True) as ckpt_dir: + with TempNamedDir( + tmp_path_dist_ckpt / 'test_partition_change_save_load', sync=True + ) as ckpt_dir: save(state_dict, ckpt_dir, strategy) del ten_a, ten_b # without changing TPxPP, load tensors without any sharding load_sd = { - 'sd_keyA': ShardedTensor.from_rank_offsets('keyA', - torch.empty(ten_a_global_shape), - replica_id=Utils.rank), - 'sd_keyB': ShardedTensor.from_rank_offsets('keyB', - torch.empty(ten_b_global_shape), - replica_id=Utils.rank), + 'sd_keyA': ShardedTensor.from_rank_offsets( + 'keyA', torch.empty(ten_a_global_shape), replica_id=Utils.rank + ), + 'sd_keyB': ShardedTensor.from_rank_offsets( + 'keyB', torch.empty(ten_b_global_shape), replica_id=Utils.rank + ), } loaded_state_dict = load(load_sd, ckpt_dir) @@ -139,27 +166,39 @@ def test_partition_change_save_load(self, tmp_path_dist_ckpt, strategy=None): assert isinstance(ten_b, torch.Tensor) assert ten_b.shape == ten_b_global_shape - assert np.all([ - val == 100 * x + z - for x, x_row in enumerate(ten_b) - for y, y_row in enumerate(x_row) - for z, val in enumerate(y_row) - ]) + assert np.all( + [ + val == 100 * x + z + for x, x_row in enumerate(ten_b) + for y, y_row in enumerate(x_row) + for z, val in enumerate(y_row) + ] + ) del ten_a, ten_b # change TPxPP Utils.destroy_model_parallel() - Utils.initialize_model_parallel(1,2) + Utils.initialize_model_parallel(1, 2) load_sd = { - 'sd_keyA': ShardedTensor.from_rank_offsets('keyA', torch.empty(2, 1), - (1, parallel_state.get_data_parallel_rank(), parallel_state.get_data_parallel_world_size()), - replica_id=parallel_state.get_pipeline_model_parallel_rank()), - 'sd_keyB': ShardedTensor.from_rank_offsets('keyB', torch.empty(5, 80), - (0, Utils.rank // 2, 4), - prepend_axis_num=1, - replica_id=Utils.rank % 2), + 'sd_keyA': ShardedTensor.from_rank_offsets( + 'keyA', + torch.empty(2, 1), + ( + 1, + parallel_state.get_data_parallel_rank(), + parallel_state.get_data_parallel_world_size(), + ), + replica_id=parallel_state.get_pipeline_model_parallel_rank(), + ), + 'sd_keyB': ShardedTensor.from_rank_offsets( + 'keyB', + torch.empty(5, 80), + (0, Utils.rank // 2, 4), + prepend_axis_num=1, + replica_id=Utils.rank % 2, + ), } loaded_state_dict = load(load_sd, ckpt_dir) @@ -168,18 +207,26 @@ def test_partition_change_save_load(self, tmp_path_dist_ckpt, strategy=None): assert isinstance(ten_a, torch.Tensor) assert ten_a.shape == (2, 1) - assert torch.all(ten_a[:, 0] == ten_a_global[:, parallel_state.get_data_parallel_rank()]) + assert torch.all( + ten_a[:, 0] == ten_a_global[:, parallel_state.get_data_parallel_rank()] + ) assert isinstance(ten_b, torch.Tensor) assert ten_b.shape == (5, 10 * 8) - assert torch.all(ten_b == torch.arange(80).unsqueeze(0).expand(5, 80) + Utils.rank // 2 * 100) + assert torch.all( + ten_b == torch.arange(80).unsqueeze(0).expand(5, 80) + Utils.rank // 2 * 100 + ) def test_load_tensors_metadata(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(2,4) + Utils.initialize_model_parallel(2, 4) state_dict = { - 'sd_keyA': ShardedTensor.from_rank_offsets('keyA', torch.arange(10) + Utils.rank * 10, (0, Utils.rank, Utils.world_size)), - 'sd_keyB': ShardedTensor.from_rank_offsets('keyB', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size)), + 'sd_keyA': ShardedTensor.from_rank_offsets( + 'keyA', torch.arange(10) + Utils.rank * 10, (0, Utils.rank, Utils.world_size) + ), + 'sd_keyB': ShardedTensor.from_rank_offsets( + 'keyB', torch.ones(3, 5, 7), 
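# Editor's note: the resharding test above demonstrates the key property of the format:
# the checkpoint stores the global tensor, so it can be read back under a different
# partitioning, or with no sharding at all. A sketch of the "load a full copy on every
# rank" case (distributed init assumed; key and global shape are illustrative and must
# match what was saved):
import torch
import torch.distributed as dist

from megatron.core.dist_checkpointing import ShardedTensor, load


def load_full_copy(ckpt_dir, key='keyA', global_shape=(2, 4)):
    rank = dist.get_rank()
    # No rank offsets: every rank requests the whole tensor; distinct replica_ids mark all
    # but one copy as replicas so access-integrity validation is satisfied.
    template = {
        key: ShardedTensor.from_rank_offsets(key, torch.empty(global_shape), replica_id=rank)
    }
    return load(template, ckpt_dir)[key]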
(2, Utils.rank, Utils.world_size) + ), } # sync=True to make sure other ranks wait for rank 0 to finish creating directory. @@ -223,15 +270,27 @@ def _build_fn(key, tensor, replica_id, flattened_range): # state dict can be modified by dist_checkpointing.save, so two copies def get_sharded_state_dict(base=0): - return {'all': [ - ShardedTensor.from_rank_offsets('A', torch.arange(2) + base, replica_id=Utils.rank), - ShardedTensor.from_rank_offsets('B', torch.arange(3) + base, replica_id=Utils.rank), - ShardedTensor.from_rank_offsets('C', torch.arange(4) + base, replica_id=Utils.rank), - ShardedTensorFactory('D', torch.arange(5) + base, _build_fn, sum, replica_id=Utils.rank), - ]} + return { + 'all': [ + ShardedTensor.from_rank_offsets( + 'A', torch.arange(2) + base, replica_id=Utils.rank + ), + ShardedTensor.from_rank_offsets( + 'B', torch.arange(3) + base, replica_id=Utils.rank + ), + ShardedTensor.from_rank_offsets( + 'C', torch.arange(4) + base, replica_id=Utils.rank + ), + ShardedTensorFactory( + 'D', torch.arange(5) + base, _build_fn, sum, replica_id=Utils.rank + ), + ] + } # sync=True to make sure other ranks wait for rank 0 to finish creating directory. - with TempNamedDir(tmp_path_dist_ckpt / 'test_can_mix_sharded_tensors_and_factories', sync=True) as ckpt_dir: + with TempNamedDir( + tmp_path_dist_ckpt / 'test_can_mix_sharded_tensors_and_factories', sync=True + ) as ckpt_dir: save(get_sharded_state_dict(0), ckpt_dir) loaded_state_dict = load(get_sharded_state_dict(10), ckpt_dir) @@ -282,16 +341,22 @@ def test_sharded_object_serialization(self, tmp_path_dist_ckpt): state = {'some': 'dict'} state_serialized = io.BytesIO() torch.save(state, state_serialized) - state_dict = {'some_key': ShardedObject('sh_obj_A', state_serialized, (1,), (0,), - replica_id=Utils.rank)} + state_dict = { + 'some_key': ShardedObject( + 'sh_obj_A', state_serialized, (1,), (0,), replica_id=Utils.rank + ) + } save(state_dict, ckpt_dir) del state, state_serialized, state_dict other_state = {'other': 'dictionary'} other_serialized = io.BytesIO() torch.save(other_state, other_serialized) - state_dict = {'other_key': ShardedObject('sh_obj_A', other_serialized, (1,), (0,), - replica_id=Utils.rank)} + state_dict = { + 'other_key': ShardedObject( + 'sh_obj_A', other_serialized, (1,), (0,), replica_id=Utils.rank + ) + } load_state_dict = load(state_dict, ckpt_dir) assert 'other_key' in load_state_dict load_state_dict['other_key'].seek(0) @@ -302,15 +367,18 @@ def test_sharded_object_serialization(self, tmp_path_dist_ckpt): Utils.destroy_model_parallel() def test_tensor_shape_mismatch(self, tmp_path_dist_ckpt): - Utils.initialize_model_parallel(2,4) + Utils.initialize_model_parallel(2, 4) # Global tensor is just a range(32) repeated twice over the first dimension local_tensor = torch.arange(4).unsqueeze(0).expand(2, 4) + Utils.rank * 4 state_dict = { - 'rigid': ShardedTensor.from_rank_offsets('keyA', local_tensor, (1, Utils.rank, Utils.world_size)), - 'flexible': ShardedTensor.from_rank_offsets('keyB', local_tensor, (1, Utils.rank, Utils.world_size), - allow_shape_mismatch=True), + 'rigid': ShardedTensor.from_rank_offsets( + 'keyA', local_tensor, (1, Utils.rank, Utils.world_size) + ), + 'flexible': ShardedTensor.from_rank_offsets( + 'keyB', local_tensor, (1, Utils.rank, Utils.world_size), allow_shape_mismatch=True + ), } assert state_dict['rigid'].global_shape == (2, 32) assert state_dict['flexible'].global_shape == (2, 32) @@ -325,28 +393,45 @@ def test_tensor_shape_mismatch(self, tmp_path_dist_ckpt): # Smaller coverage 
than expected (28 < 32) state_dict = { - 'rigid': ShardedTensor.from_rank_offsets('keyA', torch.ones(2, 7), (1, pp_rank, pp_size), replica_id=tp_rank), + 'rigid': ShardedTensor.from_rank_offsets( + 'keyA', torch.ones(2, 7), (1, pp_rank, pp_size), replica_id=tp_rank + ) } with pytest.raises((CheckpointingException, PyTCheckpointingException)): load(state_dict, ckpt_dir) state_dict = { - 'flexible': ShardedTensor.from_rank_offsets('keyB', torch.ones(2, 7), (1, pp_rank, pp_size), replica_id=tp_rank, - allow_shape_mismatch=True), + 'flexible': ShardedTensor.from_rank_offsets( + 'keyB', + torch.ones(2, 7), + (1, pp_rank, pp_size), + replica_id=tp_rank, + allow_shape_mismatch=True, + ) } loaded_state_dict = load(state_dict, ckpt_dir) - assert torch.all(loaded_state_dict['flexible'] == torch.arange(7).unsqueeze(0).expand(2, 7) + pp_rank * 7) + assert torch.all( + loaded_state_dict['flexible'] + == torch.arange(7).unsqueeze(0).expand(2, 7) + pp_rank * 7 + ) # Larger coverage than expected (36 > 32) state_dict = { - 'rigid': ShardedTensor.from_rank_offsets('keyA', torch.ones(2, 9), (1, pp_rank, pp_size), replica_id=tp_rank), + 'rigid': ShardedTensor.from_rank_offsets( + 'keyA', torch.ones(2, 9), (1, pp_rank, pp_size), replica_id=tp_rank + ) } with pytest.raises((CheckpointingException, PyTCheckpointingException)): load(state_dict, ckpt_dir) state_dict = { - 'flexible': ShardedTensor.from_rank_offsets('keyB', torch.ones(2, 9), (1, pp_rank, pp_size), replica_id=tp_rank, - allow_shape_mismatch=True), + 'flexible': ShardedTensor.from_rank_offsets( + 'keyB', + torch.ones(2, 9), + (1, pp_rank, pp_size), + replica_id=tp_rank, + allow_shape_mismatch=True, + ) } loaded_state_dict = load(state_dict, ckpt_dir) expected_tensor = torch.arange(9).unsqueeze(0).expand(2, 9) + pp_rank * 9 @@ -369,25 +454,44 @@ def teardown_method(self, method): def _get_base_state_dict(self): return { 'TenA': ShardedTensor.from_rank_offsets('TenA', torch.arange(2), replica_id=Utils.rank), - 'TenB': ShardedTensor.from_rank_offsets('TenB', torch.arange(3), (0, Utils.rank, Utils.world_size), replica_id=0), - 'TenC': ShardedTensor.from_rank_offsets('TenC', torch.arange(3), replica_id=Utils.world_size - Utils.rank - 1), + 'TenB': ShardedTensor.from_rank_offsets( + 'TenB', torch.arange(3), (0, Utils.rank, Utils.world_size), replica_id=0 + ), + 'TenC': ShardedTensor.from_rank_offsets( + 'TenC', torch.arange(3), replica_id=Utils.world_size - Utils.rank - 1 + ), 'ObjA': ShardedObject('ObjA', list(range(10)), (1,), (0,), replica_id=Utils.rank), - 'ObjB': ShardedObject('ObjB', {Utils.rank + 7}, (1, Utils.world_size), (0, Utils.rank), replica_id=0), + 'ObjB': ShardedObject( + 'ObjB', {Utils.rank + 7}, (1, Utils.world_size), (0, Utils.rank), replica_id=0 + ), } @pytest.mark.parametrize('save_format', ['zarr', 'torch_dist']) @pytest.mark.parametrize('validate_integrity', [True, False]) - def test_unexpected_keys_handling_during_validation(self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format): + def test_unexpected_keys_handling_during_validation( + self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format + ): sharded_state_dict = self._get_base_state_dict() - with TempNamedDir(tmp_path_dist_ckpt / 'test_unexpected_keys_raises_error_during_validation') as ckpt_dir: + with TempNamedDir( + tmp_path_dist_ckpt / 'test_unexpected_keys_raises_error_during_validation' + ) as ckpt_dir: save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1) save(sharded_state_dict, ckpt_dir, save_strategy) def 
load_with_flag(strict): sharded_state_dict = self._get_base_state_dict() - sharded_state_dict['TenD'] = ShardedTensor.from_rank_offsets('UnexpectedTenD', torch.arange(3), replica_id=Utils.rank) - sharded_state_dict['ObjD'] = ShardedObject('UnexpectedObjD', None, (1,), (0,), replica_id=Utils.rank) - return load(sharded_state_dict, ckpt_dir, validate_access_integrity=validate_integrity, strict=strict) + sharded_state_dict['TenD'] = ShardedTensor.from_rank_offsets( + 'UnexpectedTenD', torch.arange(3), replica_id=Utils.rank + ) + sharded_state_dict['ObjD'] = ShardedObject( + 'UnexpectedObjD', None, (1,), (0,), replica_id=Utils.rank + ) + return load( + sharded_state_dict, + ckpt_dir, + validate_access_integrity=validate_integrity, + strict=strict, + ) def test_error(error_msg): assert 'Unexpected keys' in error_msg @@ -396,7 +500,9 @@ def test_error(error_msg): assert 'Missing keys' not in error_msg # ASSUME_OK_UNEXPECTED results in an exception raised by the underlying strategy - with pytest.raises(PyTCheckpointingException if save_format == 'torch_dist' else CheckpointingException) as exc_info: + with pytest.raises( + PyTCheckpointingException if save_format == 'torch_dist' else CheckpointingException + ) as exc_info: load_with_flag(StrictHandling.ASSUME_OK_UNEXPECTED) # Informative exceptions with `RAISE_*` options: with pytest.raises(CheckpointingException) as exc_info: @@ -417,11 +523,15 @@ def test_error(error_msg): test_error(caplog.text) # Returned mismatches - loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(StrictHandling.RETURN_UNEXPECTED) + loaded_state_dict, missing_keys, unexpected_keys = load_with_flag( + StrictHandling.RETURN_UNEXPECTED + ) assert 'TenA' in loaded_state_dict assert unexpected_keys == {'UnexpectedTenD', 'UnexpectedObjD'} assert missing_keys == set() - loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(StrictHandling.RETURN_ALL) + loaded_state_dict, missing_keys, unexpected_keys = load_with_flag( + StrictHandling.RETURN_ALL + ) assert 'TenA' in loaded_state_dict assert unexpected_keys == {'UnexpectedTenD', 'UnexpectedObjD'} assert missing_keys == set() @@ -432,9 +542,13 @@ def test_error(error_msg): @pytest.mark.parametrize('save_format', ['zarr', 'torch_dist']) @pytest.mark.parametrize('validate_integrity', [True, False]) - def test_missing_keys_raises_error_during_validation(self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format): + def test_missing_keys_raises_error_during_validation( + self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format + ): sharded_state_dict = self._get_base_state_dict() - with TempNamedDir(tmp_path_dist_ckpt / 'test_missing_keys_raises_error_during_validation') as ckpt_dir: + with TempNamedDir( + tmp_path_dist_ckpt / 'test_missing_keys_raises_error_during_validation' + ) as ckpt_dir: save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1) save(sharded_state_dict, ckpt_dir, save_strategy) @@ -442,7 +556,12 @@ def load_with_flag(strict): sharded_state_dict = self._get_base_state_dict() del sharded_state_dict['TenA'] del sharded_state_dict['ObjB'] - return load(sharded_state_dict, ckpt_dir, validate_access_integrity=validate_integrity, strict=strict) + return load( + sharded_state_dict, + ckpt_dir, + validate_access_integrity=validate_integrity, + strict=strict, + ) def test_error(error_msg): assert 'Unexpected keys' not in error_msg @@ -459,10 +578,15 @@ def test_error(error_msg): with caplog.at_level(logging.WARNING): loaded_state_dict = 
load_with_flag(StrictHandling.LOG_UNEXPECTED) - assert caplog.text == '' or '`zarr` distributed checkpoint backend is deprecated' in caplog.text + assert ( + caplog.text == '' + or '`zarr` distributed checkpoint backend is deprecated' in caplog.text + ) assert 'TenB' in loaded_state_dict - loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(StrictHandling.RETURN_UNEXPECTED) + loaded_state_dict, missing_keys, unexpected_keys = load_with_flag( + StrictHandling.RETURN_UNEXPECTED + ) assert 'TenB' in loaded_state_dict assert missing_keys == set() assert unexpected_keys == set() @@ -482,7 +606,9 @@ def test_error(error_msg): test_error(caplog.text) # Returned mismatches - loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(StrictHandling.RETURN_ALL) + loaded_state_dict, missing_keys, unexpected_keys = load_with_flag( + StrictHandling.RETURN_ALL + ) assert 'TenB' in loaded_state_dict assert unexpected_keys == set() assert missing_keys == {'TenA', 'ObjB'} @@ -497,7 +623,12 @@ def test_exact_load_handling(self, caplog, tmp_path_dist_ckpt, validate_integrit def load_with_flag(strict): sharded_state_dict = self._get_base_state_dict() - return load(sharded_state_dict, ckpt_dir, validate_access_integrity=validate_integrity, strict=strict) + return load( + sharded_state_dict, + ckpt_dir, + validate_access_integrity=validate_integrity, + strict=strict, + ) for strict in ( StrictHandling.ASSUME_OK_UNEXPECTED, @@ -509,17 +640,20 @@ def load_with_flag(strict): ): with caplog.at_level(logging.WARNING): loaded_state_dict = load_with_flag(strict) - assert caplog.text == '' or '`zarr` distributed checkpoint backend is deprecated' in caplog.text + assert ( + caplog.text == '' + or '`zarr` distributed checkpoint backend is deprecated' in caplog.text + ) assert 'TenB' in loaded_state_dict assert 'ObjB' in loaded_state_dict - for strict in ( - StrictHandling.RETURN_UNEXPECTED, - StrictHandling.RETURN_ALL, - ): + for strict in (StrictHandling.RETURN_UNEXPECTED, StrictHandling.RETURN_ALL): with caplog.at_level(logging.WARNING): loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(strict) - assert caplog.text == '' or '`zarr` distributed checkpoint backend is deprecated' in caplog.text + assert ( + caplog.text == '' + or '`zarr` distributed checkpoint backend is deprecated' in caplog.text + ) assert 'TenB' in loaded_state_dict assert 'ObjB' in loaded_state_dict assert missing_keys == set() @@ -534,9 +668,17 @@ def test_sharded_metadata(self, tmp_path_dist_ckpt, save_format): save(sharded_state_dict, ckpt_dir, save_strategy) torch.distributed.barrier() sharded_metadata = load_sharded_metadata(ckpt_dir) - assert set(sh_base.key for sh_base in sharded_metadata.values()) == {'TenA', 'TenB', 'TenC', 'ObjA', 'ObjB'} + assert set(sh_base.key for sh_base in sharded_metadata.values()) == { + 'TenA', + 'TenB', + 'TenC', + 'ObjA', + 'ObjB', + } assert set(sharded_metadata.keys()) == { - 'TenA', 'TenB', 'TenC', + 'TenA', + 'TenB', + 'TenC', 'ObjA/shard_0_1', *(f'ObjB/shard_0.{i}_1.8' for i in range(8)), } diff --git a/tests/unit_tests/dist_checkpointing/utils.py b/tests/unit_tests/dist_checkpointing/utils.py index 5b2b4aa3eb..c4532b7f4a 100644 --- a/tests/unit_tests/dist_checkpointing/utils.py +++ b/tests/unit_tests/dist_checkpointing/utils.py @@ -3,6 +3,7 @@ from unittest import mock import torch + from megatron.core.models.gpt import GPTModel from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.optimizer import OptimizerConfig, 
get_megatron_optimizer @@ -16,7 +17,9 @@ NUM_ATTENTION_HEADS = 8 -def initialize_gpt_model(pre_process=True, post_process=True, seed=0, use_glu=True, **config_kwargs): +def initialize_gpt_model( + pre_process=True, post_process=True, seed=0, use_glu=True, **config_kwargs +): torch.manual_seed(seed) model_parallel_cuda_manual_seed(seed) @@ -59,6 +62,7 @@ def init_basic_mock_args(args, tp, pp, bf16=True): args.pipeline_model_parallel_size = pp return args + def init_checkpointing_mock_args(args, ckpt_dir, fully_parallel=False): args.non_persistent_global_ckpt_dir = None args.non_persistent_ckpt_type = None @@ -90,15 +94,28 @@ def init_checkpointing_mock_args(args, ckpt_dir, fully_parallel=False): args.hidden_size = HIDDEN_SIZE args.num_attention_heads = NUM_ATTENTION_HEADS -def setup_model_and_optimizer(seed, tp, pp, initialize_fn=initialize_gpt_model, bf16=True, dist_opt=True): + +def setup_model_and_optimizer( + seed, tp, pp, initialize_fn=initialize_gpt_model, bf16=True, dist_opt=True +): mock_args = SimpleNamespace() with mock.patch('megatron.training.training.get_args', new=lambda: mock_args): init_basic_mock_args(mock_args, tp, pp, bf16=bf16) - model = get_model(partial( - initialize_fn, seed=seed, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp, pipeline_dtype=torch.bfloat16 - )) + model = get_model( + partial( + initialize_fn, + seed=seed, + tensor_model_parallel_size=tp, + pipeline_model_parallel_size=pp, + pipeline_dtype=torch.bfloat16, + ) + ) - config = OptimizerConfig(bf16=bf16, params_dtype=torch.bfloat16 if bf16 else torch.float, use_distributed_optimizer=dist_opt) + config = OptimizerConfig( + bf16=bf16, + params_dtype=torch.bfloat16 if bf16 else torch.float, + use_distributed_optimizer=dist_opt, + ) optimizer = get_megatron_optimizer(config, model) torch.manual_seed(seed + 1) diff --git a/tests/unit_tests/distributed/test_param_and_grad_buffer.py b/tests/unit_tests/distributed/test_param_and_grad_buffer.py index 14d3be7071..f070303177 100644 --- a/tests/unit_tests/distributed/test_param_and_grad_buffer.py +++ b/tests/unit_tests/distributed/test_param_and_grad_buffer.py @@ -1,11 +1,12 @@ import contextlib import math + import pytest import torch from megatron.core import parallel_state from megatron.core.distributed import DistributedDataParallelConfig, ParamAndGradBuffer -from tests.unit_tests.test_utilities import Utils, TestModel +from tests.unit_tests.test_utilities import TestModel, Utils def get_model_and_buffers( diff --git a/tests/unit_tests/fusions/test_torch_softmax.py b/tests/unit_tests/fusions/test_torch_softmax.py index 504bb0b48d..63b0bc7b5d 100644 --- a/tests/unit_tests/fusions/test_torch_softmax.py +++ b/tests/unit_tests/fusions/test_torch_softmax.py @@ -19,10 +19,10 @@ def setup_method(self, method): softmax_in_fp32=True, scale=None, ) - + def teardown_method(self): - get_default_causal_mask.cache_clear() - + get_default_causal_mask.cache_clear() + def test_output_shape(self): x = torch.randn(8, 2, 4, 4, device="cuda") y = self.softmax(x, None) diff --git a/tests/unit_tests/inference/engines/test_mcore_engine.py b/tests/unit_tests/inference/engines/test_mcore_engine.py index 1c8568feea..161284ceeb 100644 --- a/tests/unit_tests/inference/engines/test_mcore_engine.py +++ b/tests/unit_tests/inference/engines/test_mcore_engine.py @@ -1,52 +1,72 @@ +import random +import string from typing import List -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig +from unittest import mock + import 
torch -import random -import string from megatron.core.inference.common_inference_params import CommonInferenceParams from megatron.core.inference.engines.mcore_engine import MCoreEngine -from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper from megatron.core.inference.inference_request import InferenceRequest, Status -from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import SimpleTextGenerationController +from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( + GPTInferenceWrapper, +) +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import ( + SimpleTextGenerationController, +) from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_config import TransformerConfig from tests.unit_tests.test_utilities import Utils -from unittest import mock + class TestMCoreEngine: def setup_method(self, method): - Utils.initialize_model_parallel(tensor_model_parallel_size=1,pipeline_model_parallel_size=1) - model_parallel_cuda_manual_seed(123) + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + model_parallel_cuda_manual_seed(123) self.batch_size = 4 self.hidden_size = 12 self.vocab_size = 100 self.sequence_length = 64 - transformer_config = TransformerConfig(num_layers=4, hidden_size=self.hidden_size, num_attention_heads=4, use_cpu_initialization=True) - + transformer_config = TransformerConfig( + num_layers=4, + hidden_size=self.hidden_size, + num_attention_heads=4, + use_cpu_initialization=True, + ) + gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_local_spec(), - vocab_size=self.vocab_size, - max_sequence_length=self.sequence_length, - parallel_output = True).cuda() + config=transformer_config, + transformer_layer_spec=get_gpt_layer_local_spec(), + vocab_size=self.vocab_size, + max_sequence_length=self.sequence_length, + parallel_output=True, + ).cuda() inference_wrapper_config = InferenceWrapperConfig( hidden_size=self.hidden_size, inference_batch_times_seqlen_threshold=400, fp32_residual_connection=False, params_dtype=torch.float, - padded_vocab_size=self.vocab_size + padded_vocab_size=self.vocab_size, ) inference_wrapped_model = GPTInferenceWrapper(gpt_model, inference_wrapper_config) self.mock_tokenizer = mock.Mock() - text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer) + text_generation_controller = SimpleTextGenerationController( + inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer + ) + + self.mcore_engine = MCoreEngine( + text_generation_controller=text_generation_controller, max_batch_size=4 + ) - self.mcore_engine = MCoreEngine(text_generation_controller=text_generation_controller, max_batch_size=4) - def teardown_method(self, method): Utils.destroy_model_parallel() @@ -54,14 +74,22 @@ def test_generate(self): self.mock_tokenizer.vocab_size = self.vocab_size self.mock_tokenizer.eod = self.vocab_size - 1 # Generating random length integer prompts - self.mock_tokenizer.tokenize.return_value = 
[random.randint(0, self.vocab_size -1) for _ in range(random.randint(5,10))] + self.mock_tokenizer.tokenize.return_value = [ + random.randint(0, self.vocab_size - 1) for _ in range(random.randint(5, 10)) + ] # Generates some random string - self.mock_tokenizer.detokenize.return_value = ''.join(random.choices(string.ascii_letters, k=random.randint(4,10))) + self.mock_tokenizer.detokenize.return_value = ''.join( + random.choices(string.ascii_letters, k=random.randint(4, 10)) + ) - prompts = ["sample"*(i+1) for i in range(self.batch_size)] - results : List[InferenceRequest] = self.mcore_engine.generate(prompts, common_inference_params=CommonInferenceParams(num_tokens_to_generate=10)) + prompts = ["sample" * (i + 1) for i in range(self.batch_size)] + results: List[InferenceRequest] = self.mcore_engine.generate( + prompts, common_inference_params=CommonInferenceParams(num_tokens_to_generate=10) + ) for result in results: - assert result.status == Status.COMPLETED, f"Status should be completed but its {result.status}" - assert result.generated_length > 0 , f"Generated length should be greater than zero" - assert result.generated_text is not None , f'Generated text should not be None' + assert ( + result.status == Status.COMPLETED + ), f"Status should be completed but its {result.status}" + assert result.generated_length > 0, f"Generated length should be greater than zero" + assert result.generated_text is not None, f'Generated text should not be None' diff --git a/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py b/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py index 1f7fb478a3..e01c3f4d17 100644 --- a/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py +++ b/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py @@ -1,83 +1,124 @@ from argparse import Namespace -from megatron.core import parallel_state -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig + import torch -from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec, get_gpt_layer_with_transformer_engine_spec -from megatron.core.transformer.transformer_config import TransformerConfig + +from megatron.core import parallel_state +from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( + GPTInferenceWrapper, +) +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) from megatron.core.models.gpt.gpt_model import GPTModel -from tests.unit_tests.test_utilities import Utils from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.test_utilities import Utils + class TestGPTInferenceWrapper: def setup_model(self, tensor_parallel_size, pipeline_parallel_size): - Utils.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size,pipeline_model_parallel_size=pipeline_parallel_size) + Utils.initialize_model_parallel( + tensor_model_parallel_size=tensor_parallel_size, + pipeline_model_parallel_size=pipeline_parallel_size, + ) model_parallel_cuda_manual_seed(123) 
self.vocab_size = 100 self.batch_size = 4 self.sequence_length = 32 hidden_size = 12 - transformer_config = TransformerConfig(num_layers=4, hidden_size=hidden_size, num_attention_heads=4, use_cpu_initialization=True) - + transformer_config = TransformerConfig( + num_layers=4, + hidden_size=hidden_size, + num_attention_heads=4, + use_cpu_initialization=True, + ) + gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_local_spec(), - vocab_size=self.vocab_size, - max_sequence_length=self.sequence_length, - parallel_output = True).cuda() + config=transformer_config, + transformer_layer_spec=get_gpt_layer_local_spec(), + vocab_size=self.vocab_size, + max_sequence_length=self.sequence_length, + parallel_output=True, + ).cuda() inference_wrapper_config = InferenceWrapperConfig( hidden_size=hidden_size, inference_batch_times_seqlen_threshold=20, fp32_residual_connection=False, params_dtype=torch.float, - padded_vocab_size=self.vocab_size + padded_vocab_size=self.vocab_size, ) self.inference_wrapped_model = GPTInferenceWrapper(gpt_model, inference_wrapper_config) + def teardown_method(self, method): Utils.destroy_model_parallel() - - # This will call the inference_wrapped_model.forward_pass_with_pipeline_parallel_small_input_batch() + + # This will call the inference_wrapped_model.forward_pass_with_pipeline_parallel_small_input_batch() def test_inference_pipeline_parallel_small_size(self): self.setup_model(tensor_parallel_size=2, pipeline_parallel_size=2) - - batch_prompt_tokens = torch.randint(low = 0, high = self.vocab_size, size=(self.batch_size, self.sequence_length)).int().cuda() + + batch_prompt_tokens = ( + torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.sequence_length)) + .int() + .cuda() + ) self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=batch_prompt_tokens) - + inference_input = self.inference_wrapped_model.get_batch_for_context_window(0, 5) - + logits = self.inference_wrapped_model.run_one_forward_step(inference_input) # Logits are not returned in all ranks in PP if parallel_state.is_pipeline_last_stage(): - assert logits.shape == (self.batch_size, 5, self.vocab_size), f"Shape mismatch . Expected {(self.batch_size, 5, self.vocab_size)}, but got {logits.shape}" - + assert logits.shape == ( + self.batch_size, + 5, + self.vocab_size, + ), f"Shape mismatch . Expected {(self.batch_size, 5, self.vocab_size)}, but got {logits.shape}" # This will call the inference_wrapped_model.forward_pass_with_pipeline_parallel_large_input_batch() def test_inference_pipeline_parallel_large__size(self): self.setup_model(tensor_parallel_size=2, pipeline_parallel_size=2) - - batch_prompt_tokens = torch.randint(low = 0, high = self.vocab_size, size=(self.batch_size, self.sequence_length)).int().cuda() + + batch_prompt_tokens = ( + torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.sequence_length)) + .int() + .cuda() + ) self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=batch_prompt_tokens) inference_input = self.inference_wrapped_model.get_batch_for_context_window(0, 10) - + logits = self.inference_wrapped_model.run_one_forward_step(inference_input) if parallel_state.is_pipeline_last_stage(): - assert logits.shape == (self.batch_size, 10, self.vocab_size), f"Shape mismatch . Expected {(self.batch_size,10, self.vocab_size)}, but got {logits.shape}" - + assert logits.shape == ( + self.batch_size, + 10, + self.vocab_size, + ), f"Shape mismatch . 
Expected {(self.batch_size,10, self.vocab_size)}, but got {logits.shape}" def test_inference_only_tensor_parallel(self): self.setup_model(tensor_parallel_size=4, pipeline_parallel_size=1) - - batch_prompt_tokens = torch.randint(low = 0, high = self.vocab_size, size=(self.batch_size, self.sequence_length)).int().cuda() + + batch_prompt_tokens = ( + torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.sequence_length)) + .int() + .cuda() + ) self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=batch_prompt_tokens) inference_input = self.inference_wrapped_model.get_batch_for_context_window(0, 5) logits = self.inference_wrapped_model.run_one_forward_step(inference_input) - - assert logits.shape == (self.batch_size, 5, self.vocab_size), f"Shape mismatch . Expected {(self.batch_size, 5, self.vocab_size)}, but got {logits.shape}" + assert logits.shape == ( + self.batch_size, + 5, + self.vocab_size, + ), f"Shape mismatch . Expected {(self.batch_size, 5, self.vocab_size)}, but got {logits.shape}" diff --git a/tests/unit_tests/inference/model_inference_wrappers/test_model_inference_wrapper_config.py b/tests/unit_tests/inference/model_inference_wrappers/test_model_inference_wrapper_config.py index 5c6f4229c0..e3da997cd4 100644 --- a/tests/unit_tests/inference/model_inference_wrappers/test_model_inference_wrapper_config.py +++ b/tests/unit_tests/inference/model_inference_wrappers/test_model_inference_wrapper_config.py @@ -1,5 +1,9 @@ import torch -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig + +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) + class TestModelInferenceWrapperConfig: @@ -9,7 +13,9 @@ def test_inference_params(self): inference_batch_times_seqlen_threshold=10, padded_vocab_size=10, params_dtype=torch.float, - fp32_residual_connection=False + fp32_residual_connection=False, ) inference_parameters.add_attributes({"abc": 45}) - assert inference_parameters.abc == 45, f"min tokens not set correctly. it is {inference_parameters.min_tokens}" \ No newline at end of file + assert ( + inference_parameters.abc == 45 + ), f"min tokens not set correctly. it is {inference_parameters.min_tokens}" diff --git a/tests/unit_tests/inference/test_common_inference_params.py b/tests/unit_tests/inference/test_common_inference_params.py index c22a72d326..af51e433df 100644 --- a/tests/unit_tests/inference/test_common_inference_params.py +++ b/tests/unit_tests/inference/test_common_inference_params.py @@ -1,8 +1,11 @@ from megatron.core.inference.common_inference_params import CommonInferenceParams + class TestCommonInferenceParams: def test_inference_params(self): inference_parameters = CommonInferenceParams() inference_parameters.add_attributes({"min_tokens": 45}) - assert inference_parameters.min_tokens == 45, f"min tokens not set correctly. it is {inference_parameters.min_tokens}" \ No newline at end of file + assert ( + inference_parameters.min_tokens == 45 + ), f"min tokens not set correctly. 
it is {inference_parameters.min_tokens}" diff --git a/tests/unit_tests/inference/test_inference_utils.py b/tests/unit_tests/inference/test_inference_utils.py index 7f0061963e..fc4e69018d 100644 --- a/tests/unit_tests/inference/test_inference_utils.py +++ b/tests/unit_tests/inference/test_inference_utils.py @@ -1,5 +1,6 @@ from megatron.core.inference.utils import Counter + class TestInferenceUtils: def test_counter(self): diff --git a/tests/unit_tests/inference/test_modelopt_gpt_model.py b/tests/unit_tests/inference/test_modelopt_gpt_model.py index 953052c732..380ac7fa16 100644 --- a/tests/unit_tests/inference/test_modelopt_gpt_model.py +++ b/tests/unit_tests/inference/test_modelopt_gpt_model.py @@ -7,7 +7,6 @@ from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_config import TransformerConfig - from tests.unit_tests.test_utilities import Utils @@ -17,10 +16,7 @@ def setup_method(self, method): Utils.initialize_model_parallel(1, 1) model_parallel_cuda_manual_seed(123) transformer_config = TransformerConfig( - num_layers=2, - hidden_size=12, - num_attention_heads=4, - use_cpu_initialization=True, + num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True ) self.gpt_model = GPTModel( config=transformer_config, diff --git a/tests/unit_tests/inference/test_scheduler.py b/tests/unit_tests/inference/test_scheduler.py index 57e08106d3..b1f0ea184e 100644 --- a/tests/unit_tests/inference/test_scheduler.py +++ b/tests/unit_tests/inference/test_scheduler.py @@ -1,17 +1,26 @@ from typing import Dict + import torch + from megatron.core.inference.common_inference_params import CommonInferenceParams from megatron.core.inference.inference_request import InferenceRequest, Status from megatron.core.inference.scheduler import Scheduler + class TestScheduler: def setup_method(self, method): self.max_batch_size = 4 self.scheduler = Scheduler(max_batch_size=self.max_batch_size) - assert len(self.scheduler.active_request_pool) == 0, "Active request pool should be empty on initalization" - assert len(self.scheduler.waiting_request_pool) == 0, "Waiting request pool should be empty on initalization" - assert len(self.scheduler.completed_request_pool) == 0, "Completed request pool should be empty on initalization" + assert ( + len(self.scheduler.active_request_pool) == 0 + ), "Active request pool should be empty on initalization" + assert ( + len(self.scheduler.waiting_request_pool) == 0 + ), "Waiting request pool should be empty on initalization" + assert ( + len(self.scheduler.completed_request_pool) == 0 + ), "Completed request pool should be empty on initalization" def test_scheduler(self): prompt = "sample prompt" @@ -20,15 +29,23 @@ def test_scheduler(self): for i in range(self.max_batch_size): self.scheduler.add_request(prompt, prompt_tokens, inference_parameters) - assert len(self.scheduler.active_request_pool) == i + 1, f"Active request pool should have {i+1} requests, but it has only {len(self.scheduler.active_request_pool)}" + assert ( + len(self.scheduler.active_request_pool) == i + 1 + ), f"Active request pool should have {i+1} requests, but it has only {len(self.scheduler.active_request_pool)}" self.scheduler.add_request(prompt, prompt_tokens, inference_parameters) - assert len(self.scheduler.waiting_request_pool) == 1, f"Waiting request pool should have 1 request but it has {len(self.scheduler.waiting_request_pool)} requests" - + assert ( + 
len(self.scheduler.waiting_request_pool) == 1 + ), f"Waiting request pool should have 1 request but it has {len(self.scheduler.waiting_request_pool)} requests" + waiting_request: InferenceRequest = list(self.scheduler.waiting_request_pool.values())[0] - assert waiting_request.status == Status.WAITING_IN_QUEUE, f"Status should be WAITING_IN_QUEUE, but its {waiting_request.status} for the waiting request" + assert ( + waiting_request.status == Status.WAITING_IN_QUEUE + ), f"Status should be WAITING_IN_QUEUE, but its {waiting_request.status} for the waiting request" - assert self.scheduler.have_requests_pending(), "Scheduler should have requests pending, but it seems to be having no requests" + assert ( + self.scheduler.have_requests_pending() + ), "Scheduler should have requests pending, but it seems to be having no requests" active_request_dict: Dict[int, InferenceRequest] = self.scheduler.active_request_pool for request_id, request in active_request_dict.items(): @@ -37,11 +54,17 @@ def test_scheduler(self): request.status = Status.COMPLETED self.scheduler.update_requests_pools(active_request_dict) - assert len(self.scheduler.active_request_pool) == 3, f"Active request pool should have 3 requests, but it has {len(self.scheduler.active_request_pool)}" + assert ( + len(self.scheduler.active_request_pool) == 3 + ), f"Active request pool should have 3 requests, but it has {len(self.scheduler.active_request_pool)}" - assert len(self.scheduler.waiting_request_pool) == 0, f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests" + assert ( + len(self.scheduler.waiting_request_pool) == 0 + ), f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests" - assert len(self.scheduler.completed_request_pool) == 2, f"Completed request pool should have 2 requests but it has {len(self.scheduler.completed_request_pool)} requests " + assert ( + len(self.scheduler.completed_request_pool) == 2 + ), f"Completed request pool should have 2 requests but it has {len(self.scheduler.completed_request_pool)} requests " active_request_dict: Dict[int, InferenceRequest] = self.scheduler.active_request_pool for request_id, request in active_request_dict.items(): @@ -49,15 +72,18 @@ def test_scheduler(self): request.status = Status.COMPLETED self.scheduler.update_requests_pools(active_request_dict) - assert len(self.scheduler.active_request_pool) == 0, f"Active request pool should be empty, but it has {len(self.scheduler.active_request_pool)}" - - assert len(self.scheduler.waiting_request_pool) == 0, f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests" - - assert len(self.scheduler.completed_request_pool) == 5, f"Completed request pool should have 5 requests but it has {len(self.scheduler.completed_request_pool)} requests " - - assert self.scheduler.have_requests_pending() == False, "Scheduler should not have any requests pending" + assert ( + len(self.scheduler.active_request_pool) == 0 + ), f"Active request pool should be empty, but it has {len(self.scheduler.active_request_pool)}" + assert ( + len(self.scheduler.waiting_request_pool) == 0 + ), f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests" + assert ( + len(self.scheduler.completed_request_pool) == 5 + ), f"Completed request pool should have 5 requests but it has {len(self.scheduler.completed_request_pool)} requests " - - \ No newline at end of file + assert ( + 
self.scheduler.have_requests_pending() == False + ), "Scheduler should not have any requests pending" diff --git a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py index 35b820edd6..a9f15faf80 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py @@ -1,118 +1,172 @@ - +import random +import string +import time from collections import OrderedDict from typing import Dict -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig +from unittest import mock + +import pytest import torch -import random -import string + from megatron.core.inference.common_inference_params import CommonInferenceParams -from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper from megatron.core.inference.inference_request import InferenceRequest, Status -from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import SimpleTextGenerationController +from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( + GPTInferenceWrapper, +) +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import ( + SimpleTextGenerationController, +) from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_config import TransformerConfig -from unittest import mock -import pytest -import time +from tests.unit_tests.test_utilities import Utils -from tests.unit_tests.test_utilities import Utils class TestTextGenerationController: def setup_method(self, method): - Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=2) - model_parallel_cuda_manual_seed(123) + Utils.initialize_model_parallel( + tensor_model_parallel_size=2, pipeline_model_parallel_size=2 + ) + model_parallel_cuda_manual_seed(123) self.batch_size = 4 self.hidden_size = 12 self.vocab_size = 100 self.sequence_length = 64 - transformer_config = TransformerConfig(num_layers=4, hidden_size=self.hidden_size, num_attention_heads=4, use_cpu_initialization=True) - + transformer_config = TransformerConfig( + num_layers=4, + hidden_size=self.hidden_size, + num_attention_heads=4, + use_cpu_initialization=True, + ) + gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=get_gpt_layer_local_spec(), - vocab_size=self.vocab_size, - max_sequence_length=self.sequence_length, - parallel_output = True).cuda() - + config=transformer_config, + transformer_layer_spec=get_gpt_layer_local_spec(), + vocab_size=self.vocab_size, + max_sequence_length=self.sequence_length, + parallel_output=True, + ).cuda() + inference_wrapper_config = InferenceWrapperConfig( hidden_size=self.hidden_size, inference_batch_times_seqlen_threshold=20, fp32_residual_connection=False, params_dtype=torch.float, - padded_vocab_size=self.vocab_size + padded_vocab_size=self.vocab_size, ) inference_wrapped_model = GPTInferenceWrapper(gpt_model, 
inference_wrapper_config) self.mock_tokenizer = mock.Mock() - self.text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer) - + self.text_generation_controller = SimpleTextGenerationController( + inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer + ) + def teardown_method(self, method): Utils.destroy_model_parallel() def test_sample_from_logits(self): with pytest.raises(AssertionError) as aerror: - self.text_generation_controller.sample_from_logits(last_token_logits=None, common_inference_params=CommonInferenceParams(top_k=2, top_p=0.4), vocab_size=self.vocab_size ) + self.text_generation_controller.sample_from_logits( + last_token_logits=None, + common_inference_params=CommonInferenceParams(top_k=2, top_p=0.4), + vocab_size=self.vocab_size, + ) assert str(aerror.value) == 'Cannot have top-p and top-k both greater than zero' with pytest.raises(AssertionError) as aerror: - self.text_generation_controller.sample_from_logits(last_token_logits=None, common_inference_params=CommonInferenceParams(top_p=1.4, top_k=0), vocab_size=self.vocab_size ) + self.text_generation_controller.sample_from_logits( + last_token_logits=None, + common_inference_params=CommonInferenceParams(top_p=1.4, top_k=0), + vocab_size=self.vocab_size, + ) assert str(aerror.value) == 'top-p should be in (0,1]' with pytest.raises(AssertionError) as aerror: - self.text_generation_controller.sample_from_logits(last_token_logits=torch.randn(self.batch_size, 1), common_inference_params=CommonInferenceParams(top_k = self.vocab_size + 10), vocab_size=self.vocab_size) + self.text_generation_controller.sample_from_logits( + last_token_logits=torch.randn(self.batch_size, 1), + common_inference_params=CommonInferenceParams(top_k=self.vocab_size + 10), + vocab_size=self.vocab_size, + ) assert str(aerror.value) == 'top-k is larger than logit size.' 
- - last_token_logits = torch.arange(0, self.vocab_size).repeat(self.batch_size,1).float().cuda() - sampled_logits = self.text_generation_controller.sample_from_logits(last_token_logits, CommonInferenceParams(top_k=1), self.vocab_size) - assert torch.all(sampled_logits.cpu() == torch.ones(self.batch_size) * self.vocab_size - 1), f"The sampled logits should all be {self.vocab_size} but its {sampled_logits}" + last_token_logits = ( + torch.arange(0, self.vocab_size).repeat(self.batch_size, 1).float().cuda() + ) + sampled_logits = self.text_generation_controller.sample_from_logits( + last_token_logits, CommonInferenceParams(top_k=1), self.vocab_size + ) + assert torch.all( + sampled_logits.cpu() == torch.ones(self.batch_size) * self.vocab_size - 1 + ), f"The sampled logits should all be {self.vocab_size} but its {sampled_logits}" - sampled_logits = self.text_generation_controller.sample_from_logits(last_token_logits, CommonInferenceParams(top_k=2), self.vocab_size) - assert torch.all(sampled_logits >= self.vocab_size - 2), f"The sampled logits should all be greater than {self.vocab_size-2} but its {sampled_logits}" + sampled_logits = self.text_generation_controller.sample_from_logits( + last_token_logits, CommonInferenceParams(top_k=2), self.vocab_size + ) + assert torch.all( + sampled_logits >= self.vocab_size - 2 + ), f"The sampled logits should all be greater than {self.vocab_size-2} but its {sampled_logits}" l = last_token_logits[0] top_p = 0.3 expected_min_value = l[l.softmax(dim=-1).cumsum(dim=-1) > top_p][0].item() - sampled_logits = self.text_generation_controller.sample_from_logits(last_token_logits, CommonInferenceParams(top_p=top_p, top_k=0), self.vocab_size) - assert torch.all(sampled_logits >= expected_min_value), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}" + sampled_logits = self.text_generation_controller.sample_from_logits( + last_token_logits, CommonInferenceParams(top_p=top_p, top_k=0), self.vocab_size + ) + assert torch.all( + sampled_logits >= expected_min_value + ), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}" top_p = 0.95 - temperature=2 + temperature = 2 expected_min_value = l[l.div_(temperature).softmax(dim=-1).cumsum(dim=-1) > top_p][0].item() - sampled_logits = self.text_generation_controller.sample_from_logits(last_token_logits, CommonInferenceParams(top_p=top_p, temperature=temperature, top_k=0), self.vocab_size) - assert torch.all(sampled_logits >= expected_min_value), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}" - + sampled_logits = self.text_generation_controller.sample_from_logits( + last_token_logits, + CommonInferenceParams(top_p=top_p, temperature=temperature, top_k=0), + self.vocab_size, + ) + assert torch.all( + sampled_logits >= expected_min_value + ), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}" + def test_generate_all_output_tokens_static_batch(self): self.mock_tokenizer.vocab_size = self.vocab_size self.mock_tokenizer.eod = self.vocab_size - 1 - self.mock_tokenizer.detokenize.return_value = ''.join(random.choices(string.ascii_letters, k=random.randint(4,10))) + self.mock_tokenizer.detokenize.return_value = ''.join( + random.choices(string.ascii_letters, k=random.randint(4, 10)) + ) active_requests: Dict[int, InferenceRequest] = OrderedDict() for i in range(self.batch_size): - prompt = "sample" * (i+1) - self.mock_tokenizer.tokenize.return_value = 
torch.randn(self.batch_size, self.vocab_size).cuda() + prompt = "sample" * (i + 1) + self.mock_tokenizer.tokenize.return_value = torch.randn( + self.batch_size, self.vocab_size + ).cuda() inference_request = InferenceRequest( request_id=i, prompt=prompt, inference_parameters=CommonInferenceParams(num_tokens_to_generate=10), arrival_time=time.time(), - prompt_tokens=torch.randint(low=0, high=self.vocab_size - 1, size=(len(prompt),)).tolist(), - status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS + prompt_tokens=torch.randint( + low=0, high=self.vocab_size - 1, size=(len(prompt),) + ).tolist(), + status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS, ) active_requests[i] = inference_request - requests = self.text_generation_controller.generate_all_output_tokens_static_batch(active_requests) - + requests = self.text_generation_controller.generate_all_output_tokens_static_batch( + active_requests + ) + for request_id, request in requests.items(): - assert request.status == Status.COMPLETED, f"Status should be completed but its {request.status}" - assert request.generated_length > 0 , f"Generated length should be greater than zero" + assert ( + request.status == Status.COMPLETED + ), f"Status should be completed but its {request.status}" + assert request.generated_length > 0, f"Generated length should be greater than zero" assert request.generated_text is not None, "Generated text should not be None" - - - - \ No newline at end of file diff --git a/tests/unit_tests/models/test_base_embedding.py b/tests/unit_tests/models/test_base_embedding.py index 511b0262fa..0ce18b3843 100644 --- a/tests/unit_tests/models/test_base_embedding.py +++ b/tests/unit_tests/models/test_base_embedding.py @@ -1,11 +1,10 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. import pytest - import torch -from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.transformer.transformer_config import TransformerConfig from tests.unit_tests.test_utilities import Utils @@ -14,17 +13,21 @@ class TestBaseEmbedding: def setup_method(self, method): Utils.initialize_model_parallel(1, 1) transformer_config = TransformerConfig( - num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True) + num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True + ) self.base_embedding = LanguageModelEmbedding( - config=transformer_config, vocab_size=100, max_sequence_length=4, position_embedding_type='learned_absolute') + config=transformer_config, + vocab_size=100, + max_sequence_length=4, + position_embedding_type='learned_absolute', + ) def teardown_method(self, method): Utils.destroy_model_parallel() def test_constructor(self): assert isinstance(self.base_embedding, LanguageModelEmbedding) - num_weights = sum([p.numel() - for p in self.base_embedding.parameters()]) + num_weights = sum([p.numel() for p in self.base_embedding.parameters()]) assert num_weights == 1248 def test_zero_parameters(self): @@ -35,10 +38,8 @@ def test_zero_parameters(self): assert sum_weights == 0 def test_cpu_forward(self): - input_ids = torch.tensor( - [0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)) - position_ids = torch.tensor( - [0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)) + input_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)) + position_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)) embeddings = self.base_embedding(input_ids, position_ids) 
assert embeddings.device.type == 'cpu' assert embeddings.shape[0] == self.base_embedding.max_sequence_length @@ -47,10 +48,8 @@ def test_cpu_forward(self): def test_gpu_forward(self): self.base_embedding.cuda() - input_ids = torch.tensor( - [0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)).cuda() - position_ids = torch.tensor( - [0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)).cuda() + input_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)).cuda() + position_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)).cuda() embeddings = self.base_embedding(input_ids, position_ids) assert embeddings.device.type == 'cuda' assert embeddings.shape[0] == self.base_embedding.max_sequence_length diff --git a/tests/unit_tests/models/test_bert_model.py b/tests/unit_tests/models/test_bert_model.py index f6722f66a3..b1b544698b 100644 --- a/tests/unit_tests/models/test_bert_model.py +++ b/tests/unit_tests/models/test_bert_model.py @@ -1,33 +1,45 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -import pytest +import os +import pytest import torch -import os from pkg_resources import packaging -from megatron.core.transformer.transformer_config import TransformerConfig +from pytest_mock import mocker + +from megatron.core.models.bert.bert_layer_specs import bert_layer_with_transformer_engine_spec from megatron.core.models.bert.bert_model import BertModel -from tests.unit_tests.test_utilities import Utils from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.models.bert.bert_layer_specs import bert_layer_with_transformer_engine_spec -from pytest_mock import mocker +from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.test_utilities import Utils + class TestBertModel: def setup_method(self, method): - os.environ['NVTE_ALLOW_NONDETERMINISTIC_ALGO'] = '0' #Bert does not support flash attention + os.environ['NVTE_ALLOW_NONDETERMINISTIC_ALGO'] = ( + '0' # Bert does not support flash attention + ) tp = 1 pp = 1 Utils.initialize_model_parallel(tp, pp) model_parallel_cuda_manual_seed(123) transformer_config = TransformerConfig( - num_layers=2, hidden_size=12, num_attention_heads=4, - use_cpu_initialization=True, perform_initialization=True, - tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp, pipeline_dtype=torch.bfloat16 + num_layers=2, + hidden_size=12, + num_attention_heads=4, + use_cpu_initialization=True, + perform_initialization=True, + tensor_model_parallel_size=tp, + pipeline_model_parallel_size=pp, + pipeline_dtype=torch.bfloat16, ) self.bert_model = BertModel( - config=transformer_config, num_tokentypes=0, - transformer_layer_spec=bert_layer_with_transformer_engine_spec, vocab_size=100, max_sequence_length=4 + config=transformer_config, + num_tokentypes=0, + transformer_layer_spec=bert_layer_with_transformer_engine_spec, + vocab_size=100, + max_sequence_length=4, ) def teardown_method(self, method): @@ -77,66 +89,105 @@ def test_post_process_forward(self): class TestBertModelAssertions: def test_te_assertions_te_less_than_1_7(self, mocker): - os.environ.pop('NVTE_ALLOW_NONDETERMINISTIC_ALGO',None) - os.environ.pop('NVTE_FLASH_ATTN',None) - os.environ.pop('NVTE_FUSED_ATTN',None) + os.environ.pop('NVTE_ALLOW_NONDETERMINISTIC_ALGO', None) + os.environ.pop('NVTE_FLASH_ATTN', None) + os.environ.pop('NVTE_FUSED_ATTN', None) tp = 1 pp = 1 - Utils.initialize_model_parallel(tp, pp) + Utils.initialize_model_parallel(tp, pp) model_parallel_cuda_manual_seed(123) 
transformer_config = TransformerConfig( - num_layers=2, hidden_size=12, num_attention_heads=4, - use_cpu_initialization=True, perform_initialization=True, - tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp, pipeline_dtype=torch.bfloat16 + num_layers=2, + hidden_size=12, + num_attention_heads=4, + use_cpu_initialization=True, + perform_initialization=True, + tensor_model_parallel_size=tp, + pipeline_model_parallel_size=pp, + pipeline_dtype=torch.bfloat16, ) with pytest.raises(Exception) as exc_info: - mocker.patch("megatron.core.models.bert.bert_model.get_te_version", return_value = packaging.version.Version("1.4")) + mocker.patch( + "megatron.core.models.bert.bert_model.get_te_version", + return_value=packaging.version.Version("1.4"), + ) self.bert_model = BertModel( - config=transformer_config, num_tokentypes=0, - transformer_layer_spec=bert_layer_with_transformer_engine_spec, vocab_size=100, max_sequence_length=4 + config=transformer_config, + num_tokentypes=0, + transformer_layer_spec=bert_layer_with_transformer_engine_spec, + vocab_size=100, + max_sequence_length=4, ) - assert str(exc_info.value) == "Flash and fused attention is not supported with transformer engine version < 1.7. Set NVTE_FLASH_ATTN=0 and NVTE_FUSED_ATTN=0 or upgrade transformer engine >= 1.7 or set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0" + assert ( + str(exc_info.value) + == "Flash and fused attention is not supported with transformer engine version < 1.7. Set NVTE_FLASH_ATTN=0 and NVTE_FUSED_ATTN=0 or upgrade transformer engine >= 1.7 or set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0" + ) def test_te_assertions_te_equal_to_1_7_exception(self, mocker): - os.environ.pop('NVTE_ALLOW_NONDETERMINISTIC_ALGO',None) + os.environ.pop('NVTE_ALLOW_NONDETERMINISTIC_ALGO', None) os.environ['NVTE_FLASH_ATTN'] = '0' os.environ['NVTE_FUSED_ATTN'] = '0' tp = 1 pp = 1 - Utils.initialize_model_parallel(tp, pp) + Utils.initialize_model_parallel(tp, pp) model_parallel_cuda_manual_seed(123) transformer_config = TransformerConfig( - num_layers=2, hidden_size=12, num_attention_heads=4, - use_cpu_initialization=True, perform_initialization=True, - tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp, pipeline_dtype=torch.bfloat16 + num_layers=2, + hidden_size=12, + num_attention_heads=4, + use_cpu_initialization=True, + perform_initialization=True, + tensor_model_parallel_size=tp, + pipeline_model_parallel_size=pp, + pipeline_dtype=torch.bfloat16, ) with pytest.raises(Exception) as exc_info: - mocker.patch("megatron.core.models.bert.bert_model.get_te_version", return_value = packaging.version.Version("1.7")) + mocker.patch( + "megatron.core.models.bert.bert_model.get_te_version", + return_value=packaging.version.Version("1.7"), + ) self.bert_model = BertModel( - config=transformer_config, num_tokentypes=0, - transformer_layer_spec=bert_layer_with_transformer_engine_spec, vocab_size=100, max_sequence_length=4 + config=transformer_config, + num_tokentypes=0, + transformer_layer_spec=bert_layer_with_transformer_engine_spec, + vocab_size=100, + max_sequence_length=4, ) - assert str(exc_info.value) == "Set env variable NVTE_FLASH_ATTN to 1 or NVTE_FUSED_ATTN to 1 to use a more optimized attention kernal. Currently using unfused attention path. If you want to proceed with this path set AttnMaskType in module spec to be arbitrary" + assert ( + str(exc_info.value) + == "Set env variable NVTE_FLASH_ATTN to 1 or NVTE_FUSED_ATTN to 1 to use a more optimized attention kernal. Currently using unfused attention path. 
If you want to proceed with this path set AttnMaskType in module spec to be arbitrary" + ) def test_te_assertions_te_equal_to_1_7_no_exception(self, mocker): - os.environ.pop('NVTE_ALLOW_NONDETERMINISTIC_ALGO',None) - os.environ.pop('NVTE_FLASH_ATTN',None) - os.environ.pop('NVTE_FUSED_ATTN',None) + os.environ.pop('NVTE_ALLOW_NONDETERMINISTIC_ALGO', None) + os.environ.pop('NVTE_FLASH_ATTN', None) + os.environ.pop('NVTE_FUSED_ATTN', None) tp = 1 pp = 1 - Utils.initialize_model_parallel(tp, pp) + Utils.initialize_model_parallel(tp, pp) model_parallel_cuda_manual_seed(123) transformer_config = TransformerConfig( - num_layers=2, hidden_size=12, num_attention_heads=4, - use_cpu_initialization=True, perform_initialization=True, - tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp, pipeline_dtype=torch.bfloat16 + num_layers=2, + hidden_size=12, + num_attention_heads=4, + use_cpu_initialization=True, + perform_initialization=True, + tensor_model_parallel_size=tp, + pipeline_model_parallel_size=pp, + pipeline_dtype=torch.bfloat16, ) - mocker.patch("megatron.core.models.bert.bert_model.get_te_version", return_value = packaging.version.Version("1.7")) + mocker.patch( + "megatron.core.models.bert.bert_model.get_te_version", + return_value=packaging.version.Version("1.7"), + ) self.bert_model = BertModel( - config=transformer_config, num_tokentypes=0, - transformer_layer_spec=bert_layer_with_transformer_engine_spec, vocab_size=100, max_sequence_length=4 + config=transformer_config, + num_tokentypes=0, + transformer_layer_spec=bert_layer_with_transformer_engine_spec, + vocab_size=100, + max_sequence_length=4, ) - Utils.destroy_model_parallel() \ No newline at end of file + Utils.destroy_model_parallel() diff --git a/tests/unit_tests/models/test_clip_vit_model.py b/tests/unit_tests/models/test_clip_vit_model.py index bc29f943af..fcbf2ad440 100644 --- a/tests/unit_tests/models/test_clip_vit_model.py +++ b/tests/unit_tests/models/test_clip_vit_model.py @@ -16,12 +16,11 @@ def setup_method(self, method): Utils.initialize_model_parallel(1, 1) model_parallel_cuda_manual_seed(123) transformer_config = TransformerConfig( - num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True, + num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True ) transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec() self.model = CLIPViTModel( - transformer_config, transformer_layer_spec, - img_h=336, img_w=336, patch_dim=14, + transformer_config, transformer_layer_spec, img_h=336, img_w=336, patch_dim=14 ) def teardown_method(self, method): diff --git a/tests/unit_tests/models/test_llava_model.py b/tests/unit_tests/models/test_llava_model.py index f5681fc154..c65f2d3b87 100644 --- a/tests/unit_tests/models/test_llava_model.py +++ b/tests/unit_tests/models/test_llava_model.py @@ -21,7 +21,7 @@ def setup_method(self, method): num_layers=3, hidden_size=128, num_attention_heads=8, use_cpu_initialization=True ) vision_config = TransformerConfig( - num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True, + num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True ) vision_projection_config = TransformerConfig( num_layers=2, @@ -101,7 +101,7 @@ def test_forward(self): kv_dict = inference_params.key_value_memory_dict assert kv_dict["image_tokens_count"] == 577 - for layer_no in range(1, 4): # 3 layers in the model. + for layer_no in range(1, 4): # 3 layers in the model. 
layer_kv = kv_dict[layer_no] # Expected shape is [sequence_len, batch_size, num_heads, hidden_size_per_head] assert layer_kv[0].shape == layer_kv[1].shape == torch.Size((1601, 2, 8, 16)) diff --git a/tests/unit_tests/models/test_mamba_model.py b/tests/unit_tests/models/test_mamba_model.py index db9277f028..913adb538c 100644 --- a/tests/unit_tests/models/test_mamba_model.py +++ b/tests/unit_tests/models/test_mamba_model.py @@ -71,9 +71,7 @@ def test_forward(self): ).cuda() logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, + input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask ) assert logits.shape[0] == micro_batch_size diff --git a/tests/unit_tests/models/test_multimodal_projector.py b/tests/unit_tests/models/test_multimodal_projector.py index f5ef29c6e8..976dc489da 100644 --- a/tests/unit_tests/models/test_multimodal_projector.py +++ b/tests/unit_tests/models/test_multimodal_projector.py @@ -1,32 +1,40 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. import pytest - import torch -from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec from megatron.core.models.vision.multimodal_projector import MultimodalProjector -from tests.unit_tests.test_utilities import Utils +from megatron.core.tensor_parallel.layers import ColumnParallelLinear from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec from megatron.core.transformer.mlp import MLPSubmodules -from megatron.core.tensor_parallel.layers import ColumnParallelLinear +from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.test_utilities import Utils class TestMultimodalProjector: def setup_method(self, method): - Utils.initialize_model_parallel(1,1) + Utils.initialize_model_parallel(1, 1) model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig(num_layers=1, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True) + transformer_config = TransformerConfig( + num_layers=1, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True + ) mlp_layer_spec = _get_mlp_module_spec().submodules - - affine_layer_spec = MLPSubmodules( - linear_fc1=ColumnParallelLinear, - linear_fc2=None, - ) - self.mlp = MultimodalProjector(config = transformer_config, submodules = mlp_layer_spec, projector_type = "mlp", input_size = 1024) - self.affine = MultimodalProjector(config = transformer_config, submodules = affine_layer_spec, projector_type = "affine", input_size = 1024) + + affine_layer_spec = MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=None) + self.mlp = MultimodalProjector( + config=transformer_config, + submodules=mlp_layer_spec, + projector_type="mlp", + input_size=1024, + ) + self.affine = MultimodalProjector( + config=transformer_config, + submodules=affine_layer_spec, + projector_type="affine", + input_size=1024, + ) def teardown_method(self, method): Utils.destroy_model_parallel() @@ -65,4 +73,3 @@ def test_save_load(self, tmp_path): torch.save(self.affine.state_dict(), path) self.affine.load_state_dict(torch.load(path)) - diff --git a/tests/unit_tests/models/test_t5_model.py b/tests/unit_tests/models/test_t5_model.py index 75d2286960..efe12b78f4 100644 --- a/tests/unit_tests/models/test_t5_model.py +++ b/tests/unit_tests/models/test_t5_model.py @@ -1,19 +1,22 @@ # 
Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. from copy import deepcopy -import pytest +import pytest import torch -import megatron.core.parallel_state as ps -from megatron.core.transformer.transformer_config import TransformerConfig +import megatron.core.parallel_state as ps from megatron.core.models.T5.t5_model import T5Model -from tests.unit_tests.test_utilities import Utils +from megatron.core.models.T5.t5_spec import ( + get_t5_decoder_with_local_block_spec, + get_t5_decoder_with_transformer_engine_block_spec, + get_t5_encoder_with_local_block_spec, + get_t5_encoder_with_transformer_engine_block_spec, +) from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.models.T5.t5_spec import (get_t5_encoder_with_transformer_engine_block_spec, - get_t5_decoder_with_transformer_engine_block_spec, - get_t5_encoder_with_local_block_spec, - get_t5_decoder_with_local_block_spec) +from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.test_utilities import Utils + class TestT5Model: @@ -27,9 +30,15 @@ def setup_method(self, method): ) model_parallel_cuda_manual_seed(123) transformer_config = TransformerConfig( - num_layers=12, hidden_size=768, num_attention_heads=12, kv_channels=64, ffn_hidden_size=3072, - use_cpu_initialization=True, pipeline_dtype=torch.bfloat16, - tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp, + num_layers=12, + hidden_size=768, + num_attention_heads=12, + kv_channels=64, + ffn_hidden_size=3072, + use_cpu_initialization=True, + pipeline_dtype=torch.bfloat16, + tensor_model_parallel_size=tp, + pipeline_model_parallel_size=pp, ) rank = ps.get_pipeline_model_parallel_rank() world_size = ps.get_pipeline_model_parallel_world_size() @@ -38,15 +47,21 @@ def setup_method(self, method): first_decoder_rank = pp pre_process = rank == 0 or rank == first_decoder_rank - post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size-1)) + post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size - 1)) add_encoder = ps.is_inside_encoder(rank) add_decoder = ps.is_inside_decoder(rank) self.t5_model = T5Model( - encoder_config=transformer_config, config=transformer_config, transformer_encoder_layer_spec=en_block_spec, - transformer_decoder_layer_spec=de_block_spec, vocab_size=29184, max_sequence_length=4, - pre_process=pre_process, post_process=post_process, - add_encoder=add_encoder, add_decoder=add_decoder, + encoder_config=transformer_config, + config=transformer_config, + transformer_encoder_layer_spec=en_block_spec, + transformer_decoder_layer_spec=de_block_spec, + vocab_size=29184, + max_sequence_length=4, + pre_process=pre_process, + post_process=post_process, + add_encoder=add_encoder, + add_decoder=add_decoder, ) def teardown_method(self, method): @@ -96,14 +111,22 @@ def test_post_process_forward(self): self.t5_model.cuda() data = list(range(sequence_length)) - encoder_input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - decoder_input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + encoder_input_ids = ( + torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + ) + decoder_input_ids = ( + torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + ) encoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda() decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda() - 
encoder_decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda() + encoder_decoder_attn_mask = torch.ones( + (1, sequence_length, sequence_length), dtype=bool + ).cuda() if self.t5_model.add_decoder: - encoder_hidden_states = torch.zeros((sequence_length, micro_batch_size, config.hidden_size), dtype=torch.float32).cuda() + encoder_hidden_states = torch.zeros( + (sequence_length, micro_batch_size, config.hidden_size), dtype=torch.float32 + ).cuda() else: encoder_hidden_states = None @@ -113,20 +136,22 @@ def test_post_process_forward(self): encoder_attn_mask=encoder_attn_mask, decoder_attn_mask=decoder_attn_mask, encoder_decoder_attn_mask=encoder_decoder_attn_mask, - encoder_hidden_states=encoder_hidden_states + encoder_hidden_states=encoder_hidden_states, ) if self.t5_model.add_decoder: logits = output assert logits.shape[0] == micro_batch_size assert logits.shape[1] == sequence_length - assert logits.shape[2] == self.t5_model.vocab_size // ps.get_tensor_model_parallel_world_size() + assert ( + logits.shape[2] + == self.t5_model.vocab_size // ps.get_tensor_model_parallel_world_size() + ) else: encoder_hidden_states = output assert encoder_hidden_states.shape[0] == sequence_length assert encoder_hidden_states.shape[1] == micro_batch_size assert encoder_hidden_states.shape[2] == config.hidden_size - def test_forward_output_encoder_hidden_only(self): config: TransformerConfig = self.t5_model.config sequence_length = self.t5_model.max_sequence_length @@ -135,11 +160,17 @@ def test_forward_output_encoder_hidden_only(self): self.t5_model.cuda() data = list(range(sequence_length)) - encoder_input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - decoder_input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + encoder_input_ids = ( + torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + ) + decoder_input_ids = ( + torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + ) encoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda() decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda() - encoder_decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda() + encoder_decoder_attn_mask = torch.ones( + (1, sequence_length, sequence_length), dtype=bool + ).cuda() encoder_hidden_states = self.t5_model.forward( encoder_input_ids=encoder_input_ids, @@ -147,7 +178,7 @@ def test_forward_output_encoder_hidden_only(self): encoder_attn_mask=encoder_attn_mask, decoder_attn_mask=decoder_attn_mask, encoder_decoder_attn_mask=encoder_decoder_attn_mask, - output_encoder_hidden_only=True + output_encoder_hidden_only=True, ) if self.t5_model.add_decoder: assert encoder_hidden_states is None @@ -164,12 +195,20 @@ def test_forward_with_encoder_hidden_states(self): self.t5_model.cuda() data = list(range(sequence_length)) - encoder_input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() - decoder_input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + encoder_input_ids = ( + torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + ) + decoder_input_ids = ( + torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda() + ) encoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda() decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda() - 
encoder_decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda() - encoder_hidden_states = torch.zeros((sequence_length, micro_batch_size, config.hidden_size), dtype=torch.float32).cuda() + encoder_decoder_attn_mask = torch.ones( + (1, sequence_length, sequence_length), dtype=bool + ).cuda() + encoder_hidden_states = torch.zeros( + (sequence_length, micro_batch_size, config.hidden_size), dtype=torch.float32 + ).cuda() output = self.t5_model.forward( encoder_input_ids=None, @@ -177,13 +216,16 @@ def test_forward_with_encoder_hidden_states(self): encoder_attn_mask=encoder_attn_mask, decoder_attn_mask=decoder_attn_mask, encoder_decoder_attn_mask=encoder_decoder_attn_mask, - encoder_hidden_states=encoder_hidden_states + encoder_hidden_states=encoder_hidden_states, ) if self.t5_model.add_decoder: logits = output assert logits.shape[0] == micro_batch_size assert logits.shape[1] == sequence_length - assert logits.shape[2] == self.t5_model.vocab_size // ps.get_tensor_model_parallel_world_size() + assert ( + logits.shape[2] + == self.t5_model.vocab_size // ps.get_tensor_model_parallel_world_size() + ) else: encoder_hidden_states = output assert encoder_hidden_states.shape[0] == sequence_length @@ -201,4 +243,3 @@ def test_state_dict_for_save_checkpoint(self): def test_load_state_dict(self): pass - diff --git a/tests/unit_tests/pipeline_parallel/test_schedules.py b/tests/unit_tests/pipeline_parallel/test_schedules.py index 5dd6605d68..06994094fc 100644 --- a/tests/unit_tests/pipeline_parallel/test_schedules.py +++ b/tests/unit_tests/pipeline_parallel/test_schedules.py @@ -1,30 +1,51 @@ +import pytest import torch -from tests.unit_tests.test_utilities import Utils -from megatron.core import ModelParallelConfig +from pytest_mock import mocker + import megatron.core.pipeline_parallel.schedules as schedule -from pytest_mock import mocker -import pytest +from megatron.core import ModelParallelConfig +from tests.unit_tests.test_utilities import Utils rank = Utils.rank - + + def test_get_forward_backward_func(): Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1) - assert(schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining) + assert schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining Utils.destroy_model_parallel() Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4) - assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_without_interleaving) + assert ( + schedule.get_forward_backward_func() + == schedule.forward_backward_pipelining_without_interleaving + ) Utils.destroy_model_parallel() - Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4, virtual_pipeline_model_parallel_size=2) - assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_with_interleaving) + Utils.initialize_model_parallel( + tensor_model_parallel_size=2, + pipeline_model_parallel_size=4, + virtual_pipeline_model_parallel_size=2, + ) + assert ( + schedule.get_forward_backward_func() + == schedule.forward_backward_pipelining_with_interleaving + ) Utils.destroy_model_parallel() - Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=2, virtual_pipeline_model_parallel_size=4) - assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_with_interleaving) + Utils.initialize_model_parallel( + 
tensor_model_parallel_size=2, + pipeline_model_parallel_size=2, + virtual_pipeline_model_parallel_size=4, + ) + assert ( + schedule.get_forward_backward_func() + == schedule.forward_backward_pipelining_with_interleaving + ) Utils.destroy_model_parallel() + def test_deallocate_output_tensor(): out = torch.tensor([[1, 2, 3], [4, 5, 6]]) schedule.deallocate_output_tensor(out) - assert(out.nelement() == 6) + assert out.nelement() == 6 + def test_forward_backward_func_without_pipeline_parallel(mocker): from megatron.core.pipeline_parallel import get_forward_backward_func @@ -33,43 +54,51 @@ def test_forward_backward_func_without_pipeline_parallel(mocker): def forward_step_func(data_iterator, model): import os + rank = int(os.environ['LOCAL_RANK']) - dummy_data = torch.ones(1,4) + dummy_data = torch.ones(1, 4) + def loss_func(output_tensor): - return rank, {'loss_reduced':rank} + return rank, {'loss_reduced': rank} + return model(dummy_data), loss_func - model = torch.nn.Linear(4,1) + model = torch.nn.Linear(4, 1) model.model_type = 'unit-test' + def set_input_tensor(input_tensor): return None + model.set_input_tensor = set_input_tensor forward_backward_func = get_forward_backward_func() - assert(schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining) + assert schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2) - config = ModelParallelConfig( - pipeline_model_parallel_size = 1 - ) + config = ModelParallelConfig(pipeline_model_parallel_size=1) model.config = config losses_reduced = forward_backward_func( forward_step_func=forward_step_func, - data_iterator=range(0,100), + data_iterator=range(0, 100), model=[model], num_microbatches=4, seq_length=None, micro_batch_size=None, - forward_only=True) - + forward_only=True, + ) + + loss_reduced_expected = [ + {'loss_reduced': rank}, + {'loss_reduced': rank}, + {'loss_reduced': rank}, + {'loss_reduced': rank}, + ] - loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}] - - for i,j in zip(losses_reduced, loss_reduced_expected): + for i, j in zip(losses_reduced, loss_reduced_expected): print(losses_reduced) - assert(i['loss_reduced'] == j['loss_reduced']) - Utils.destroy_model_parallel() + assert i['loss_reduced'] == j['loss_reduced'] + Utils.destroy_model_parallel() def test_forward_backward_func_with_pipeline_parallel(mocker): @@ -79,77 +108,99 @@ def test_forward_backward_func_with_pipeline_parallel(mocker): def forward_step_func(data_iterator, model): import os + rank = int(os.environ['LOCAL_RANK']) + def loss_func(output_tensor): - return rank, {'loss_reduced':rank} - return torch.rand(512,8,256).cuda(), loss_func + return rank, {'loss_reduced': rank} - model = torch.nn.Linear(4,1) + return torch.rand(512, 8, 256).cuda(), loss_func + + model = torch.nn.Linear(4, 1) model.model_type = 'unit-test' + def set_input_tensor(input_tensor): return None + model.set_input_tensor = set_input_tensor forward_backward_func = get_forward_backward_func() - assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_without_interleaving) + assert ( + schedule.get_forward_backward_func() + == schedule.forward_backward_pipelining_without_interleaving + ) sequence_length = 512 micro_batch_size = 8 hidden_size = 256 config = ModelParallelConfig( - pipeline_model_parallel_size = 4, - sequence_parallel = False, - pipeline_dtype=torch.float, + 
pipeline_model_parallel_size=4, sequence_parallel=False, pipeline_dtype=torch.float ) config.hidden_size = hidden_size model.config = config - + losses_reduced = forward_backward_func( forward_step_func=forward_step_func, data_iterator=None, model=[model], - num_microbatches= micro_batch_size, + num_microbatches=micro_batch_size, seq_length=sequence_length, micro_batch_size=micro_batch_size, - forward_only=True) - - loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}] - for i,j in zip(losses_reduced, loss_reduced_expected): + forward_only=True, + ) + + loss_reduced_expected = [ + {'loss_reduced': rank}, + {'loss_reduced': rank}, + {'loss_reduced': rank}, + {'loss_reduced': rank}, + ] + for i, j in zip(losses_reduced, loss_reduced_expected): print(losses_reduced) - assert(i['loss_reduced'] == j['loss_reduced']) - Utils.destroy_model_parallel() + assert i['loss_reduced'] == j['loss_reduced'] + Utils.destroy_model_parallel() def test_forward_backward_func_with_interleaving(mocker): - from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.enums import ModelType + from megatron.core.pipeline_parallel import get_forward_backward_func - Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=4, virtual_pipeline_model_parallel_size=2) + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=4, + virtual_pipeline_model_parallel_size=2, + ) def forward_step_func(data_iterator, model): import os + rank = int(os.environ['LOCAL_RANK']) + def loss_func(output_tensor): - return rank, {'loss_reduced':rank} - return torch.rand(512,8,256).cuda(), loss_func + return rank, {'loss_reduced': rank} + + return torch.rand(512, 8, 256).cuda(), loss_func + + model = torch.nn.Linear(4, 1) - model = torch.nn.Linear(4,1) def set_input_tensor(input_tensor): return None + model.set_input_tensor = set_input_tensor forward_backward_func = get_forward_backward_func() - assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_with_interleaving) + assert ( + schedule.get_forward_backward_func() + == schedule.forward_backward_pipelining_with_interleaving + ) sequence_length = 512 micro_batch_size = 8 hidden_size = 256 config = ModelParallelConfig( - pipeline_model_parallel_size = 4, - sequence_parallel = False, - pipeline_dtype=torch.float, + pipeline_model_parallel_size=4, sequence_parallel=False, pipeline_dtype=torch.float ) config.hidden_size = hidden_size model.config = config @@ -160,53 +211,61 @@ def set_input_tensor(input_tensor): model.model_type = ModelType.encoder_and_decoder forward_backward_func( forward_step_func=forward_step_func, - data_iterator=[range(0,100)], + data_iterator=[range(0, 100)], model=[model, model], - num_microbatches= micro_batch_size, + num_microbatches=micro_batch_size, seq_length=sequence_length, - micro_batch_size=micro_batch_size, + micro_batch_size=micro_batch_size, decoder_seq_length=sequence_length, - forward_only=True) - + forward_only=True, + ) + with pytest.raises(RuntimeError): model.model_type = ModelType.encoder_or_decoder forward_backward_func( forward_step_func=forward_step_func, - data_iterator=[range(0,100)], + data_iterator=[range(0, 100)], model=[model, model], - num_microbatches= micro_batch_size, + num_microbatches=micro_batch_size, seq_length=sequence_length, - micro_batch_size=micro_batch_size, + micro_batch_size=micro_batch_size, decoder_seq_length=256, - 
forward_only=True) - + forward_only=True, + ) + with pytest.raises(RuntimeError): model.model_type = ModelType.encoder_or_decoder forward_backward_func( forward_step_func=forward_step_func, - data_iterator=[range(0,100)], + data_iterator=[range(0, 100)], model=[model, model], - num_microbatches= 7, + num_microbatches=7, seq_length=sequence_length, - micro_batch_size=micro_batch_size, + micro_batch_size=micro_batch_size, decoder_seq_length=512, - forward_only=True) + forward_only=True, + ) - model.model_type = ModelType.encoder_or_decoder losses_reduced = forward_backward_func( forward_step_func=forward_step_func, - data_iterator=[range(0,100), range(0,100)], + data_iterator=[range(0, 100), range(0, 100)], model=[model, model], - num_microbatches= micro_batch_size, + num_microbatches=micro_batch_size, seq_length=sequence_length, - micro_batch_size=micro_batch_size, + micro_batch_size=micro_batch_size, decoder_seq_length=sequence_length, - forward_only=True) - - loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}] - for i,j in zip(losses_reduced, loss_reduced_expected): + forward_only=True, + ) + + loss_reduced_expected = [ + {'loss_reduced': rank}, + {'loss_reduced': rank}, + {'loss_reduced': rank}, + {'loss_reduced': rank}, + ] + for i, j in zip(losses_reduced, loss_reduced_expected): print(losses_reduced) - assert(i['loss_reduced'] == j['loss_reduced']) + assert i['loss_reduced'] == j['loss_reduced'] - Utils.destroy_model_parallel() + Utils.destroy_model_parallel() diff --git a/tests/unit_tests/tensor_parallel/test_cross_entropy.py b/tests/unit_tests/tensor_parallel/test_cross_entropy.py index a29365ee43..66982fd234 100644 --- a/tests/unit_tests/tensor_parallel/test_cross_entropy.py +++ b/tests/unit_tests/tensor_parallel/test_cross_entropy.py @@ -1,14 +1,34 @@ -from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy +import numpy as np import torch + +from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy from tests.unit_tests.test_utilities import Utils -import numpy as np + def test_vocab_parallel_cross_entropy(): - Utils.initialize_model_parallel(4,2) - vocab_parallel_logits = torch.range(0,7).repeat(16,4).cuda() - target = torch.arange(0,32,2).cuda() + Utils.initialize_model_parallel(4, 2) + vocab_parallel_logits = torch.range(0, 7).repeat(16, 4).cuda() + target = torch.arange(0, 32, 2).cuda() output = vocab_parallel_cross_entropy(vocab_parallel_logits, target) - expected_output = torch.tensor([10.2309, 8.2309, 6.2309, 4.2309, 10.2309, 8.2309, 6.2309, 4.2309, - 10.2309, 8.2309, 6.2309, 4.2309, 10.2309, 8.2309, 6.2309, 4.2309]).cuda() - assert(torch.equal(torch.round(expected_output), torch.round(output))) - Utils.destroy_model_parallel() \ No newline at end of file + expected_output = torch.tensor( + [ + 10.2309, + 8.2309, + 6.2309, + 4.2309, + 10.2309, + 8.2309, + 6.2309, + 4.2309, + 10.2309, + 8.2309, + 6.2309, + 4.2309, + 10.2309, + 8.2309, + 6.2309, + 4.2309, + ] + ).cuda() + assert torch.equal(torch.round(expected_output), torch.round(output)) + Utils.destroy_model_parallel() diff --git a/tests/unit_tests/tensor_parallel/test_data.py b/tests/unit_tests/tensor_parallel/test_data.py index 38a39ce37f..211d48b4fd 100644 --- a/tests/unit_tests/tensor_parallel/test_data.py +++ b/tests/unit_tests/tensor_parallel/test_data.py @@ -1,21 +1,23 @@ -from megatron.core.tensor_parallel.data import broadcast_data import torch + +from megatron.core.tensor_parallel.data 
import broadcast_data from tests.unit_tests.test_utilities import Utils + def test_broadcast_data(): - Utils.initialize_model_parallel(2,4) + Utils.initialize_model_parallel(2, 4) input_data = { - 0 : torch.ones((8,8)).cuda() * 0.0, - 1 : torch.ones((8,8)).cuda() * 1.0, - 2 : torch.ones((8,8)).cuda() * 2.0, - 3 : torch.ones((8,8)).cuda() * 3.0, - 4 : torch.ones((8,8)).cuda() * 4.0, - 5 : torch.ones((8,8)).cuda() * 5.0, - 6 : torch.ones((8,8)).cuda() * 6.0, - 7 : torch.ones((8,8)).cuda() * 7.0 - } + 0: torch.ones((8, 8)).cuda() * 0.0, + 1: torch.ones((8, 8)).cuda() * 1.0, + 2: torch.ones((8, 8)).cuda() * 2.0, + 3: torch.ones((8, 8)).cuda() * 3.0, + 4: torch.ones((8, 8)).cuda() * 4.0, + 5: torch.ones((8, 8)).cuda() * 5.0, + 6: torch.ones((8, 8)).cuda() * 6.0, + 7: torch.ones((8, 8)).cuda() * 7.0, + } dtype = torch.float32 - actual_output = broadcast_data([0,1],input_data, dtype) - assert(torch.equal(actual_output[0], input_data[0])) - assert(torch.equal(actual_output[1], input_data[1])) - Utils.destroy_model_parallel() \ No newline at end of file + actual_output = broadcast_data([0, 1], input_data, dtype) + assert torch.equal(actual_output[0], input_data[0]) + assert torch.equal(actual_output[1], input_data[1]) + Utils.destroy_model_parallel() diff --git a/tests/unit_tests/tensor_parallel/test_initialization.py b/tests/unit_tests/tensor_parallel/test_initialization.py index 346ae241e0..9fcc38c259 100644 --- a/tests/unit_tests/tensor_parallel/test_initialization.py +++ b/tests/unit_tests/tensor_parallel/test_initialization.py @@ -1,20 +1,25 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. import pytest - import torch import megatron.core.parallel_state as ps -from megatron.core.tensor_parallel.layers import VocabParallelEmbedding, RowParallelLinear, ColumnParallelLinear -from tests.unit_tests.test_utilities import Utils +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec +from megatron.core.tensor_parallel.layers import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, +) from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec +from tests.unit_tests.test_utilities import Utils + class Test: - transformer_config = TransformerConfig(num_layers=1, hidden_size=12, - num_attention_heads=4, use_cpu_initialization=True) + transformer_config = TransformerConfig( + num_layers=1, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True + ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_embedding_init(self): @@ -23,22 +28,27 @@ def test_embedding_init(self): torch.manual_seed(42) model_parallel_cuda_manual_seed(42) - - tp1 = VocabParallelEmbedding(num_embeddings=16, embedding_dim=4, - init_method=self.transformer_config.init_method, - config=self.transformer_config).weight + tp1 = VocabParallelEmbedding( + num_embeddings=16, + embedding_dim=4, + init_method=self.transformer_config.init_method, + config=self.transformer_config, + ).weight Utils.destroy_model_parallel() Utils.initialize_model_parallel(4, 1) torch.manual_seed(42) model_parallel_cuda_manual_seed(41) # intentionally different. 
- tp4 = VocabParallelEmbedding(num_embeddings=16, embedding_dim=4, - init_method=self.transformer_config.init_method, - config=self.transformer_config).weight + tp4 = VocabParallelEmbedding( + num_embeddings=16, + embedding_dim=4, + init_method=self.transformer_config.init_method, + config=self.transformer_config, + ).weight rank = ps.get_tensor_model_parallel_rank() assert tp4.shape[0] * 4 == tp1.shape[0] - assert torch.equal(tp1[rank*4:(rank+1)*4], tp4) + assert torch.equal(tp1[rank * 4 : (rank + 1) * 4], tp4) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_row_init(self): @@ -47,26 +57,33 @@ def test_row_init(self): torch.manual_seed(42) model_parallel_cuda_manual_seed(42) - tp1 = RowParallelLinear(input_size=16, output_size=16, - init_method=self.transformer_config.init_method, - bias=True, input_is_parallel=False, - config=self.transformer_config, - skip_bias_add=False).weight + tp1 = RowParallelLinear( + input_size=16, + output_size=16, + init_method=self.transformer_config.init_method, + bias=True, + input_is_parallel=False, + config=self.transformer_config, + skip_bias_add=False, + ).weight Utils.destroy_model_parallel() Utils.initialize_model_parallel(4, 1) torch.manual_seed(42) model_parallel_cuda_manual_seed(41) # intentionally different. - tp4 = RowParallelLinear(input_size=16, output_size=16, - init_method=self.transformer_config.init_method, - bias=True, - input_is_parallel=False, - config=self.transformer_config, - skip_bias_add=False).weight + tp4 = RowParallelLinear( + input_size=16, + output_size=16, + init_method=self.transformer_config.init_method, + bias=True, + input_is_parallel=False, + config=self.transformer_config, + skip_bias_add=False, + ).weight rank = ps.get_tensor_model_parallel_rank() assert tp4.shape[1] * 4 == tp1.shape[1] - assert torch.equal(tp1[:, rank*4:(rank+1)*4], tp4) + assert torch.equal(tp1[:, rank * 4 : (rank + 1) * 4], tp4) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_col_init(self): @@ -75,20 +92,28 @@ def test_col_init(self): torch.manual_seed(42) model_parallel_cuda_manual_seed(42) - tp1 = ColumnParallelLinear(input_size=16, output_size=16, - init_method=self.transformer_config.init_method, - bias=True, config=self.transformer_config, - skip_bias_add=False).weight + tp1 = ColumnParallelLinear( + input_size=16, + output_size=16, + init_method=self.transformer_config.init_method, + bias=True, + config=self.transformer_config, + skip_bias_add=False, + ).weight Utils.destroy_model_parallel() Utils.initialize_model_parallel(4, 1) torch.manual_seed(42) model_parallel_cuda_manual_seed(41) # intentionally different. 
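(The slicing assertions in this test_initialization.py diff, such as tp4.shape[0] * 4 == tp1.shape[0] and torch.equal(tp1[rank * 4 : (rank + 1) * 4], tp4), encode the fact that a tensor-parallel rank holds one contiguous slice of the weight a tp=1 build would produce. The following standalone sketch is illustrative only, plain PyTorch rather than Megatron code, and the sizes 16, 4 and the rank value are assumed for the example:

    import torch

    full_weight = torch.arange(16 * 4, dtype=torch.float32).reshape(16, 4)  # stands in for the tp=1 weight
    tp_size, rank = 4, 2                                                     # assumed values for illustration
    rows_per_rank = full_weight.shape[0] // tp_size                          # 16 rows over 4 ranks -> 4 rows each
    shard = full_weight[rank * rows_per_rank : (rank + 1) * rows_per_rank]   # what a tp=4 rank would hold
    assert shard.shape[0] * tp_size == full_weight.shape[0]                  # mirrors: tp4.shape[0] * 4 == tp1.shape[0]
    assert torch.equal(full_weight[rank * 4 : (rank + 1) * 4], shard)        # mirrors the tp1[rank*4:(rank+1)*4] check

The same idea applies along the second weight dimension for RowParallelLinear, which is why that test slices tp1[:, rank * 4 : (rank + 1) * 4] instead.)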
- tp4 = ColumnParallelLinear(input_size=16, output_size=16, - init_method=self.transformer_config.init_method, - bias=True, config=self.transformer_config, - skip_bias_add=False).weight + tp4 = ColumnParallelLinear( + input_size=16, + output_size=16, + init_method=self.transformer_config.init_method, + bias=True, + config=self.transformer_config, + skip_bias_add=False, + ).weight rank = ps.get_tensor_model_parallel_rank() assert tp4.shape[0] * 4 == tp1.shape[0] - assert torch.equal(tp1[rank*4:(rank+1)*4], tp4) + assert torch.equal(tp1[rank * 4 : (rank + 1) * 4], tp4) diff --git a/tests/unit_tests/tensor_parallel/test_mappings.py b/tests/unit_tests/tensor_parallel/test_mappings.py index 6be486ef3c..c6a789410c 100644 --- a/tests/unit_tests/tensor_parallel/test_mappings.py +++ b/tests/unit_tests/tensor_parallel/test_mappings.py @@ -1,135 +1,139 @@ +import torch + from megatron.core.tensor_parallel import mappings from tests.unit_tests.test_utilities import Utils -import torch + def test_CopyToModelParallelRegion(): - Utils.initialize_model_parallel(4,2) - input_data = torch.ones((1)).cuda()*Utils.rank + Utils.initialize_model_parallel(4, 2) + input_data = torch.ones((1)).cuda() * Utils.rank output_data = mappings._CopyToModelParallelRegion.backward(None, input_data) result = torch.ones(1).cuda() result = result * 22 if Utils.rank >= 4 else result * 6 - assert(torch.equal(output_data, result)) - assert(torch.equal(input_data, mappings.copy_to_tensor_model_parallel_region(input_data))) - assert(torch.equal(input_data, mappings._CopyToModelParallelRegion.symbolic(None, input_data))) + assert torch.equal(output_data, result) + assert torch.equal(input_data, mappings.copy_to_tensor_model_parallel_region(input_data)) + assert torch.equal(input_data, mappings._CopyToModelParallelRegion.symbolic(None, input_data)) Utils.destroy_model_parallel() + def test_ReduceFromModelParallelRegion(): - Utils.initialize_model_parallel(4,2) - input_data = torch.ones((1)).cuda()*Utils.rank + Utils.initialize_model_parallel(4, 2) + input_data = torch.ones((1)).cuda() * Utils.rank output_data = mappings._ReduceFromModelParallelRegion.symbolic(None, input_data) result = torch.ones(1).cuda() result = result * 22 if Utils.rank >= 4 else result * 6 - assert(torch.equal(output_data, result)) - input_data = torch.ones((1)).cuda()*Utils.rank - assert(torch.equal(mappings.reduce_from_tensor_model_parallel_region(input_data), result)) - assert(torch.equal(input_data, mappings._ReduceFromModelParallelRegion.backward(None, input_data))) + assert torch.equal(output_data, result) + input_data = torch.ones((1)).cuda() * Utils.rank + assert torch.equal(mappings.reduce_from_tensor_model_parallel_region(input_data), result) + assert torch.equal( + input_data, mappings._ReduceFromModelParallelRegion.backward(None, input_data) + ) Utils.destroy_model_parallel() + def test_ScatterToModelParallelRegion(): - Utils.initialize_model_parallel(4,2) - input_data = torch.rand((8,4)).cuda() + Utils.initialize_model_parallel(4, 2) + input_data = torch.rand((8, 4)).cuda() output_data = mappings.scatter_to_tensor_model_parallel_region(input_data) - req_dim = int(Utils.rank%(Utils.world_size/2)) - assert(torch.equal(output_data, input_data[:,req_dim].reshape((8,1)))) + req_dim = int(Utils.rank % (Utils.world_size / 2)) + assert torch.equal(output_data, input_data[:, req_dim].reshape((8, 1))) output_data = mappings._ScatterToModelParallelRegion.symbolic(None, input_data) - assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1)))) + 
assert torch.equal(output_data, input_data[:, req_dim].reshape((8, 1))) input_data = torch.ones(8).cuda() * Utils.rank actual_output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data) - expected_output = torch.cat(( - torch.ones(8)*0, - torch.ones(8)*1, - torch.ones(8)*2, - torch.ones(8)*3)).cuda() - if (Utils.rank >= 4): + expected_output = torch.cat( + (torch.ones(8) * 0, torch.ones(8) * 1, torch.ones(8) * 2, torch.ones(8) * 3) + ).cuda() + if Utils.rank >= 4: expected_output = expected_output + 4 - assert(torch.equal(actual_output_data, expected_output)) + assert torch.equal(actual_output_data, expected_output) Utils.destroy_model_parallel() + def test_GatherFromModelParallelRegion(): - Utils.initialize_model_parallel(4,2) - input_data = torch.rand((8,4)).cuda() - req_dim = int(Utils.rank%(Utils.world_size/2)) + Utils.initialize_model_parallel(4, 2) + input_data = torch.rand((8, 4)).cuda() + req_dim = int(Utils.rank % (Utils.world_size / 2)) output_data = mappings._GatherFromModelParallelRegion.backward(None, input_data) - assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1)))) + assert torch.equal(output_data, input_data[:, req_dim].reshape((8, 1))) input_data = torch.ones(8).cuda() * Utils.rank actual_output_data = mappings.gather_from_tensor_model_parallel_region(input_data) - expected_output = torch.cat(( - torch.ones(8)*0, - torch.ones(8)*1, - torch.ones(8)*2, - torch.ones(8)*3)).cuda() - if (Utils.rank >= 4): + expected_output = torch.cat( + (torch.ones(8) * 0, torch.ones(8) * 1, torch.ones(8) * 2, torch.ones(8) * 3) + ).cuda() + if Utils.rank >= 4: expected_output = expected_output + 4 - assert(torch.equal(actual_output_data, expected_output)) - assert(torch.equal(mappings._GatherFromModelParallelRegion.symbolic(None, input_data), expected_output)) + assert torch.equal(actual_output_data, expected_output) + assert torch.equal( + mappings._GatherFromModelParallelRegion.symbolic(None, input_data), expected_output + ) Utils.destroy_model_parallel() - + + def test_ScatterToSequenceParallelRegion(): - Utils.initialize_model_parallel(4,2) - input_data = torch.rand((8,4)).cuda() - req_dim = int(Utils.rank%(Utils.world_size/2))*2 + Utils.initialize_model_parallel(4, 2) + input_data = torch.rand((8, 4)).cuda() + req_dim = int(Utils.rank % (Utils.world_size / 2)) * 2 output_data = mappings._ScatterToSequenceParallelRegion.symbolic(None, input_data) - assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :])) + assert torch.equal(output_data, input_data[req_dim : req_dim + 2, :]) output_data = mappings.scatter_to_sequence_parallel_region(input_data) - assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :])) + assert torch.equal(output_data, input_data[req_dim : req_dim + 2, :]) input_data = torch.ones(4).cuda() * Utils.rank output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data) - expected_output = torch.concat(( - torch.ones(4)*0, - torch.ones(4)*1, - torch.ones(4)*2, - torch.ones(4)*3)).cuda() - if (Utils.rank >= 4): + expected_output = torch.concat( + (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3) + ).cuda() + if Utils.rank >= 4: expected_output = expected_output + 4 - assert(torch.equal(output_data, expected_output)) + assert torch.equal(output_data, expected_output) Utils.destroy_model_parallel() + def test_GatherFromSequenceParallelRegion(): - Utils.initialize_model_parallel(4,2) + Utils.initialize_model_parallel(4, 2) input_data = torch.ones(4).cuda() * Utils.rank output_data 
= mappings.gather_from_sequence_parallel_region(input_data) - expected_output = torch.concat(( - torch.ones(4)*0, - torch.ones(4)*1, - torch.ones(4)*2, - torch.ones(4)*3)).cuda() - if (Utils.rank >= 4): + expected_output = torch.concat( + (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3) + ).cuda() + if Utils.rank >= 4: expected_output = expected_output + 4 - assert(torch.equal(output_data, expected_output)) - assert(torch.equal(mappings._GatherFromSequenceParallelRegion.symbolic(None, input_data), expected_output)) - input_data = torch.vstack(( - torch.ones(4)*0, - torch.ones(4)*1, - torch.ones(4)*2, - torch.ones(4)*3)).cuda() + assert torch.equal(output_data, expected_output) + assert torch.equal( + mappings._GatherFromSequenceParallelRegion.symbolic(None, input_data), expected_output + ) + input_data = torch.vstack( + (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3) + ).cuda() + class Ctx: tensor_parallel_output_grad = True + output_data = mappings._GatherFromSequenceParallelRegion.backward(Ctx(), input_data) - expected_output = torch.ones((1,4)).cuda() * 4 * int(Utils.rank % 4) - assert(torch.equal(output_data[0], expected_output)) + expected_output = torch.ones((1, 4)).cuda() * 4 * int(Utils.rank % 4) + assert torch.equal(output_data[0], expected_output) Utils.destroy_model_parallel() + def test_ReduceScatterToSequenceParallelRegion(): - Utils.initialize_model_parallel(4,2) - input_data = torch.vstack(( - torch.ones(4)*0, - torch.ones(4)*1, - torch.ones(4)*2, - torch.ones(4)*3)).cuda() + Utils.initialize_model_parallel(4, 2) + input_data = torch.vstack( + (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3) + ).cuda() output_data = mappings.reduce_scatter_to_sequence_parallel_region(input_data) expected_output = torch.ones(4).cuda() * 4 * int(Utils.rank % 4) - assert(torch.equal(output_data[0], expected_output)) - assert(torch.equal(mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, input_data) , expected_output.reshape((1,4)))) + assert torch.equal(output_data[0], expected_output) + assert torch.equal( + mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, input_data), + expected_output.reshape((1, 4)), + ) input_data = torch.ones(4).cuda() * Utils.rank - output_data = mappings._ReduceScatterToSequenceParallelRegion.backward(None,input_data) - expected_output = torch.concat(( - torch.ones(4)*0, - torch.ones(4)*1, - torch.ones(4)*2, - torch.ones(4)*3)).cuda() - if (Utils.rank >= 4): + output_data = mappings._ReduceScatterToSequenceParallelRegion.backward(None, input_data) + expected_output = torch.concat( + (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3) + ).cuda() + if Utils.rank >= 4: expected_output = expected_output + 4 - assert(torch.equal(output_data, expected_output)) + assert torch.equal(output_data, expected_output) Utils.destroy_model_parallel() - diff --git a/tests/unit_tests/tensor_parallel/test_random.py b/tests/unit_tests/tensor_parallel/test_random.py index e2f35cf341..ace500839d 100644 --- a/tests/unit_tests/tensor_parallel/test_random.py +++ b/tests/unit_tests/tensor_parallel/test_random.py @@ -1,44 +1,54 @@ -from megatron.core.tensor_parallel.random import CudaRNGStatesTracker -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed,get_cuda_rng_tracker -from megatron.core.tensor_parallel.random import checkpoint -from tests.unit_tests.test_utilities import Utils import pytest import torch +from 
megatron.core.tensor_parallel.random import ( + CudaRNGStatesTracker, + checkpoint, + get_cuda_rng_tracker, + model_parallel_cuda_manual_seed, +) +from tests.unit_tests.test_utilities import Utils + + def test_cuda_rng_states_tracker(): rng_tracker = CudaRNGStatesTracker() - rng_tracker.set_states({"state1":1234}) - assert(rng_tracker.get_states()["state1"] == 1234) + rng_tracker.set_states({"state1": 1234}) + assert rng_tracker.get_states()["state1"] == 1234 rng_tracker.reset() - assert(rng_tracker.get_states() == {}) + assert rng_tracker.get_states() == {} seed = 1111 - rng_tracker.add("state2",seed) + rng_tracker.add("state2", seed) with pytest.raises(Exception): - assert(rng_tracker.add("state3",seed)) + assert rng_tracker.add("state3", seed) with pytest.raises(Exception): - assert(rng_tracker.add("state2",111)) - assert(rng_tracker.get_states()['state2'] is not None) + assert rng_tracker.add("state2", 111) + assert rng_tracker.get_states()['state2'] is not None with pytest.raises(Exception): - assert() - + assert () + rng_tracker.fork("state2") torch.cuda.manual_seed(seed) rng_state = torch.cuda.get_rng_state() assert torch.equal(rng_tracker.get_states()['state2'], rng_state) + def test_model_parallel_cuda_manual_seed(): - Utils.initialize_model_parallel(4,2) + Utils.initialize_model_parallel(4, 2) model_parallel_cuda_manual_seed(0) rng_tracker = get_cuda_rng_tracker() - assert(rng_tracker.get_states()['model-parallel-rng'] is not None) + assert rng_tracker.get_states()['model-parallel-rng'] is not None Utils.destroy_model_parallel() + def test_checkpoint(): def test_forward(*input): - return input[0]+input[1] - assert(torch.equal(torch.ones(16)*3,checkpoint(test_forward, None, torch.ones(16), torch.ones(16)*2))) + return input[0] + input[1] + + assert torch.equal( + torch.ones(16) * 3, checkpoint(test_forward, None, torch.ones(16), torch.ones(16) * 2) + ) Utils.initialize_model_parallel() - input1 = torch.ones((4,4)) - checkpoint(test_forward, True, input1, torch.ones((4,4))*2) - assert(torch.equal(torch.ones(input1.numel()).cuda(), input1)) + input1 = torch.ones((4, 4)) + checkpoint(test_forward, True, input1, torch.ones((4, 4)) * 2) + assert torch.equal(torch.ones(input1.numel()).cuda(), input1) Utils.destroy_model_parallel() diff --git a/tests/unit_tests/tensor_parallel/test_tensor_parallel_utils.py b/tests/unit_tests/tensor_parallel/test_tensor_parallel_utils.py index f82e5fa693..5df774e5ff 100644 --- a/tests/unit_tests/tensor_parallel/test_tensor_parallel_utils.py +++ b/tests/unit_tests/tensor_parallel/test_tensor_parallel_utils.py @@ -1,43 +1,55 @@ import torch -import megatron.core.tensor_parallel.utils as util + import megatron.core.parallel_state as ps +import megatron.core.tensor_parallel.utils as util from tests.unit_tests.test_utilities import Utils rank = Utils.rank + def test_split_tensor_along_last_dim(): - input_tensor = torch.rand((3,4)) - torch.equal(input_tensor[0:2,0:2], util.split_tensor_along_last_dim(input_tensor,2)[0]) - torch.equal(input_tensor[2:,2:], util.split_tensor_along_last_dim(input_tensor,2)[1]) + input_tensor = torch.rand((3, 4)) + torch.equal(input_tensor[0:2, 0:2], util.split_tensor_along_last_dim(input_tensor, 2)[0]) + torch.equal(input_tensor[2:, 2:], util.split_tensor_along_last_dim(input_tensor, 2)[1]) + def test_split_tensor_into_1d_equal_chunks(): Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4) - input_tensor = torch.rand((3,4)) + input_tensor = torch.rand((3, 4)) output_tensor = 
util.split_tensor_into_1d_equal_chunks(input_tensor) - if rank % 2 == 0 : + if rank % 2 == 0: start = 0 - end = int(input_tensor.numel()/2) - else : - start = int(input_tensor.numel()/2) + end = int(input_tensor.numel() / 2) + else: + start = int(input_tensor.numel() / 2) end = input_tensor.numel() - + assert torch.equal(output_tensor, input_tensor.flatten()[start:end]) Utils.destroy_model_parallel() + def test_gather_split_1d_tensor(): Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4) - input_tensor = torch.ones((2,4)).cuda() * rank + input_tensor = torch.ones((2, 4)).cuda() * rank actual_output_tensor = util.gather_split_1d_tensor(input_tensor) - if rank %2 == 0: + if rank % 2 == 0: expected_output_tensor = torch.concat((input_tensor.flatten(), input_tensor.flatten() + 1)) - else : + else: expected_output_tensor = torch.concat((input_tensor.flatten() - 1, input_tensor.flatten())) - assert(torch.equal(actual_output_tensor, expected_output_tensor)) + assert torch.equal(actual_output_tensor, expected_output_tensor) Utils.destroy_model_parallel() + def test_vocab(): global_vocab_size = 1600 per_partition_vocab_size = 1600 / Utils.world_size - assert((rank * per_partition_vocab_size, (rank + 1)* per_partition_vocab_size) == (util.VocabUtility.vocab_range_from_per_partition_vocab_size(global_vocab_size // Utils.world_size, rank, Utils.world_size))) - assert((rank * per_partition_vocab_size, (rank + 1)* per_partition_vocab_size) == (util.VocabUtility.vocab_range_from_global_vocab_size(global_vocab_size, rank, Utils.world_size))) - \ No newline at end of file + assert (rank * per_partition_vocab_size, (rank + 1) * per_partition_vocab_size) == ( + util.VocabUtility.vocab_range_from_per_partition_vocab_size( + global_vocab_size // Utils.world_size, rank, Utils.world_size + ) + ) + assert (rank * per_partition_vocab_size, (rank + 1) * per_partition_vocab_size) == ( + util.VocabUtility.vocab_range_from_global_vocab_size( + global_vocab_size, rank, Utils.world_size + ) + ) diff --git a/tests/unit_tests/test_basic.py b/tests/unit_tests/test_basic.py index 915d2c1001..d2a60f92c8 100644 --- a/tests/unit_tests/test_basic.py +++ b/tests/unit_tests/test_basic.py @@ -1,3 +1,2 @@ def test_import(): import megatron - diff --git a/tests/unit_tests/test_imports.py b/tests/unit_tests/test_imports.py index 49e7c77b55..bad67cd8d5 100644 --- a/tests/unit_tests/test_imports.py +++ b/tests/unit_tests/test_imports.py @@ -81,8 +81,7 @@ def _test_domain_module_imports(module, subdomains: list): if error is None: for imp in dir(module): - class_, result, error = _get_class_from_path( - subdomains, imp) + class_, result, error = _get_class_from_path(subdomains, imp) if result is not None: module_list.append(class_) @@ -99,7 +98,8 @@ def _test_domain_module_imports(module, subdomains: list): print() for module in failed_list: print( - "Module did not match a valid signature of Megatron core Model (hence ignored):", module) + "Module did not match a valid signature of Megatron core Model (hence ignored):", module + ) print() if len(error_list) > 0: @@ -125,29 +125,21 @@ def _test_domain_module_imports(module, subdomains: list): def test_domain_mcore(): import megatron.core as mcore - all_passed = _test_domain_module_imports( - mcore, subdomains=['models']) + all_passed = _test_domain_module_imports(mcore, subdomains=['models']) - all_passed = _test_domain_module_imports( - mcore, subdomains=['pipeline_parallel']) + all_passed = _test_domain_module_imports(mcore, 
subdomains=['pipeline_parallel']) - all_passed = _test_domain_module_imports( - mcore, subdomains=['tensor_parallel']) + all_passed = _test_domain_module_imports(mcore, subdomains=['tensor_parallel']) - all_passed = _test_domain_module_imports( - mcore, subdomains=['transformer']) + all_passed = _test_domain_module_imports(mcore, subdomains=['transformer']) - all_passed = _test_domain_module_imports( - mcore, subdomains=['fusions']) + all_passed = _test_domain_module_imports(mcore, subdomains=['fusions']) - all_passed = _test_domain_module_imports( - mcore, subdomains=['distributed']) + all_passed = _test_domain_module_imports(mcore, subdomains=['distributed']) - all_passed = _test_domain_module_imports( - mcore, subdomains=['datasets']) + all_passed = _test_domain_module_imports(mcore, subdomains=['datasets']) - all_passed = _test_domain_module_imports( - mcore, subdomains=['dist_checkpointing']) + all_passed = _test_domain_module_imports(mcore, subdomains=['dist_checkpointing']) if not all_passed: exit(1) diff --git a/tests/unit_tests/test_local_multi_tensor_fns.py b/tests/unit_tests/test_local_multi_tensor_fns.py index f47d549f98..086de6f6d0 100644 --- a/tests/unit_tests/test_local_multi_tensor_fns.py +++ b/tests/unit_tests/test_local_multi_tensor_fns.py @@ -1,11 +1,14 @@ import copy + +import pytest +import torch + from megatron.core.utils import ( local_multi_tensor_applier, local_multi_tensor_l2_norm, - local_multi_tensor_scale + local_multi_tensor_scale, ) -import pytest -import torch + def test_local_multi_tensor_l2_norm_and_scale(): amp_C = pytest.importorskip("amp_C") @@ -13,24 +16,55 @@ def test_local_multi_tensor_l2_norm_and_scale(): torch.manual_seed(42) - tensor_list = [torch.rand(5,5).cuda() for _ in range(10)] + tensor_list = [torch.rand(5, 5).cuda() for _ in range(10)] tensor_list_copy = copy.deepcopy(tensor_list) - norm_apex, _ = multi_tensor_apply.multi_tensor_applier(amp_C.multi_tensor_l2norm, torch.tensor([0], dtype=torch.int, device='cuda'), [tensor_list], False) - norm_local, _ = multi_tensor_apply.multi_tensor_applier(local_multi_tensor_l2_norm, torch.tensor([0], dtype=torch.int, device='cuda'), [tensor_list_copy], False) + norm_apex, _ = multi_tensor_apply.multi_tensor_applier( + amp_C.multi_tensor_l2norm, + torch.tensor([0], dtype=torch.int, device='cuda'), + [tensor_list], + False, + ) + norm_local, _ = multi_tensor_apply.multi_tensor_applier( + local_multi_tensor_l2_norm, + torch.tensor([0], dtype=torch.int, device='cuda'), + [tensor_list_copy], + False, + ) torch.testing.assert_close(norm_apex, norm_local) clip_coeff = 0.05 - multi_tensor_apply.multi_tensor_applier(amp_C.multi_tensor_scale, torch.tensor([0], dtype=torch.int, device='cuda'), [tensor_list, tensor_list], clip_coeff) - multi_tensor_apply.multi_tensor_applier(local_multi_tensor_scale, torch.tensor([0], dtype=torch.int, device='cuda'), [tensor_list_copy, tensor_list_copy], clip_coeff) + multi_tensor_apply.multi_tensor_applier( + amp_C.multi_tensor_scale, + torch.tensor([0], dtype=torch.int, device='cuda'), + [tensor_list, tensor_list], + clip_coeff, + ) + multi_tensor_apply.multi_tensor_applier( + local_multi_tensor_scale, + torch.tensor([0], dtype=torch.int, device='cuda'), + [tensor_list_copy, tensor_list_copy], + clip_coeff, + ) torch.testing.assert_close(tensor_list, tensor_list_copy) + def test_local_multi_tensor_apply(): amp_C = pytest.importorskip("amp_C") multi_tensor_apply = pytest.importorskip("apex.multi_tensor_apply") - tensor_list = [torch.rand(5,5).cuda() for _ in range(10)] + 
tensor_list = [torch.rand(5, 5).cuda() for _ in range(10)] - norm_apex, _ = multi_tensor_apply.multi_tensor_applier(amp_C.multi_tensor_l2norm, torch.tensor([0], dtype=torch.int, device='cuda'), [tensor_list], False) - norm_local, _ = local_multi_tensor_applier(amp_C.multi_tensor_l2norm, torch.tensor([0], dtype=torch.int, device='cuda'), [tensor_list], False) + norm_apex, _ = multi_tensor_apply.multi_tensor_applier( + amp_C.multi_tensor_l2norm, + torch.tensor([0], dtype=torch.int, device='cuda'), + [tensor_list], + False, + ) + norm_local, _ = local_multi_tensor_applier( + amp_C.multi_tensor_l2norm, + torch.tensor([0], dtype=torch.int, device='cuda'), + [tensor_list], + False, + ) torch.testing.assert_close(norm_apex, norm_local) diff --git a/tests/unit_tests/test_optimizer.py b/tests/unit_tests/test_optimizer.py index 247da4aeb9..732a68cfa6 100644 --- a/tests/unit_tests/test_optimizer.py +++ b/tests/unit_tests/test_optimizer.py @@ -28,8 +28,8 @@ def forward(self, x): def test_chained_optimizer(): net = Net() - optimizer_1 = Adam(list(net.parameters())[:2], lr=0.01,) - optimizer_2 = SGD(list(net.parameters())[2:], lr=0.1, momentum=0.9,) + optimizer_1 = Adam(list(net.parameters())[:2], lr=0.01) + optimizer_2 = SGD(list(net.parameters())[2:], lr=0.1, momentum=0.9) chained_optimizer = ChainedOptimizer([optimizer_1, optimizer_2]) # Test the chained optimizer's param groups is a reference of the underlying optimizers' param groups diff --git a/tests/unit_tests/test_parallel_state.py b/tests/unit_tests/test_parallel_state.py index af58872ac0..abe3ea3d2e 100644 --- a/tests/unit_tests/test_parallel_state.py +++ b/tests/unit_tests/test_parallel_state.py @@ -1,114 +1,132 @@ +import os + +import pytest import torch + import megatron.core.parallel_state as ps -import pytest from tests.unit_tests.test_utilities import Utils -import os rank = Utils.rank world_size = Utils.world_size test_parallel_order = ['tp-cp-ep-dp-pp', 'tp-cp-pp-ep-dp'] + @pytest.mark.parametrize('order', test_parallel_order) def test_initialize_and_destroy_model_parallel(order): with pytest.raises(AssertionError): - assert(ps.initialize_model_parallel(order=order)) + assert ps.initialize_model_parallel(order=order) Utils.initialize_distributed() with pytest.raises(RuntimeError): - assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size, order=order)) + assert ps.initialize_model_parallel(tensor_model_parallel_size=2 * world_size, order=order) with pytest.raises(RuntimeError): - assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size, order=order)) + assert ps.initialize_model_parallel( + pipeline_model_parallel_size=2 * world_size, order=order + ) with pytest.raises(RuntimeError): - assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size, order=order)) + assert ps.initialize_model_parallel( + pipeline_model_parallel_size=world_size, + tensor_model_parallel_size=world_size, + order=order, + ) with pytest.raises(RuntimeError): - assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2, order=order)) - Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4, order=order) - - assert(ps.model_parallel_is_initialized()) - assert(ps.get_model_parallel_group() is not None) - assert(ps.get_tensor_model_parallel_group() is not None) - assert(ps.get_pipeline_model_parallel_group() is not None) - assert(ps.get_data_parallel_group() is not None) + assert 
ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2, order=order) + Utils.initialize_model_parallel( + tensor_model_parallel_size=2, pipeline_model_parallel_size=4, order=order + ) + + assert ps.model_parallel_is_initialized() + assert ps.get_model_parallel_group() is not None + assert ps.get_tensor_model_parallel_group() is not None + assert ps.get_pipeline_model_parallel_group() is not None + assert ps.get_data_parallel_group() is not None Utils.destroy_model_parallel() - assert(ps._MODEL_PARALLEL_GROUP is None) + assert ps._MODEL_PARALLEL_GROUP is None + @pytest.mark.parametrize('order', test_parallel_order) def test_pipeline_parallel_initializations(order): - Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4, order=order) - assert(ps.get_pipeline_model_parallel_first_rank() == rank % 2 ) - assert(ps.get_data_parallel_src_rank() == rank) - assert(ps.get_pipeline_model_parallel_next_rank() == ((rank + 2) % world_size)) - assert(ps.get_pipeline_model_parallel_prev_rank() == ((rank - 2) % world_size)) + Utils.initialize_model_parallel( + tensor_model_parallel_size=2, pipeline_model_parallel_size=4, order=order + ) + assert ps.get_pipeline_model_parallel_first_rank() == rank % 2 + assert ps.get_data_parallel_src_rank() == rank + assert ps.get_pipeline_model_parallel_next_rank() == ((rank + 2) % world_size) + assert ps.get_pipeline_model_parallel_prev_rank() == ((rank - 2) % world_size) Utils.destroy_model_parallel() + @pytest.mark.parametrize('order', test_parallel_order) def test_data_parallel_initializations(order): Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order) - assert(ps.get_data_parallel_src_rank() == rank) - assert(ps.get_data_parallel_world_size() == 1) - assert(ps.get_data_parallel_rank() == 0) + assert ps.get_data_parallel_src_rank() == rank + assert ps.get_data_parallel_world_size() == 1 + assert ps.get_data_parallel_rank() == 0 Utils.destroy_model_parallel() + @pytest.mark.parametrize('order', test_parallel_order) def test_tensor_model_parellel_world_size(order): Utils.initialize_model_parallel(tensor_model_parallel_size=world_size, order=order) - assert(ps.get_tensor_model_parallel_world_size() == world_size) + assert ps.get_tensor_model_parallel_world_size() == world_size ps.set_tensor_model_parallel_world_size(None) - assert(ps.get_tensor_model_parallel_world_size() == world_size) + assert ps.get_tensor_model_parallel_world_size() == world_size Utils.destroy_model_parallel() @pytest.mark.parametrize('order', test_parallel_order) def test_pipeline_model_parallel_world_size(order): Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order) - assert(ps.get_pipeline_model_parallel_world_size() == world_size) + assert ps.get_pipeline_model_parallel_world_size() == world_size ps.set_pipeline_model_parallel_world_size(None) - assert(ps.get_pipeline_model_parallel_world_size() == world_size) + assert ps.get_pipeline_model_parallel_world_size() == world_size Utils.destroy_model_parallel() @pytest.mark.parametrize('order', test_parallel_order) def test_tensor_model_parallel_rank(order): Utils.initialize_model_parallel(tensor_model_parallel_size=world_size, order=order) - assert(ps.get_tensor_model_parallel_rank() == rank) + assert ps.get_tensor_model_parallel_rank() == rank ps.set_tensor_model_parallel_rank(None) - assert(ps.get_tensor_model_parallel_rank() == rank) + assert ps.get_tensor_model_parallel_rank() == rank Utils.destroy_model_parallel() 
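(The literal constants in the test_parallel_state.py pipeline assertions, rank % 2 for the first stage and (rank + 2) % world_size / (rank - 2) % world_size for the next and previous stages with tp=2, pp=4 on 8 ranks, follow from grouping every tp_size-th rank into one pipeline. This small sketch assumes that grouping for illustration; it is plain arithmetic, not Megatron's parallel_state implementation:

    tp_size, pp_size, world_size = 2, 4, 8                 # assumed layout matching the test: tp=2, pp=4
    for rank in range(world_size):
        # each pipeline group holds every tp_size-th rank: [r, r+2, r+4, r+6] for r in {0, 1}
        group = [rank % tp_size + stage * tp_size for stage in range(pp_size)]
        idx = group.index(rank)
        assert group[0] == rank % tp_size                                   # first rank of this pipeline
        assert group[(idx + 1) % pp_size] == (rank + tp_size) % world_size  # next stage, as asserted above
        assert group[(idx - 1) % pp_size] == (rank - tp_size) % world_size  # previous stage, as asserted above
)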
@pytest.mark.parametrize('order', test_parallel_order) def test_pipeline_model_parallel_rank(order): Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order) - assert(ps.get_pipeline_model_parallel_rank() == rank) + assert ps.get_pipeline_model_parallel_rank() == rank ps.set_pipeline_model_parallel_rank(None) - assert(ps.get_pipeline_model_parallel_rank() == rank) + assert ps.get_pipeline_model_parallel_rank() == rank Utils.destroy_model_parallel() + def test_context_parallel_rank(): Utils.initialize_model_parallel(context_parallel_size=world_size) - assert(ps.get_context_parallel_rank() == rank) + assert ps.get_context_parallel_rank() == rank Utils.destroy_model_parallel() + def test_expert_model_parallel_rank(): Utils.initialize_model_parallel(expert_model_parallel_size=world_size) - assert(ps.get_expert_model_parallel_rank() == rank) + assert ps.get_expert_model_parallel_rank() == rank ps.set_expert_model_parallel_rank(None) - assert(ps.get_expert_model_parallel_rank() == rank) + assert ps.get_expert_model_parallel_rank() == rank Utils.destroy_model_parallel() @pytest.mark.parametrize('order', test_parallel_order) def test_is_pipeline_first_stage(order): Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order) - assert(ps.is_pipeline_first_stage(ignore_virtual=True) == (rank == 0)) - assert(ps.is_pipeline_first_stage() == (rank == 0)) + assert ps.is_pipeline_first_stage(ignore_virtual=True) == (rank == 0) + assert ps.is_pipeline_first_stage() == (rank == 0) Utils.destroy_model_parallel() @pytest.mark.parametrize('order', test_parallel_order) def test_is_pipeline_last_stage(order): Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order) - assert(ps.is_pipeline_last_stage(ignore_virtual=True) == (rank == world_size-1)) - assert(ps.is_pipeline_last_stage() == (rank == world_size-1)) + assert ps.is_pipeline_last_stage(ignore_virtual=True) == (rank == world_size - 1) + assert ps.is_pipeline_last_stage() == (rank == world_size - 1) Utils.destroy_model_parallel() @@ -116,14 +134,14 @@ def test_is_pipeline_last_stage(order): def test_virtual_pipeline_model_parallel_rank(order): Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order) ps.set_virtual_pipeline_model_parallel_rank(rank) - assert(ps.get_virtual_pipeline_model_parallel_rank() == rank) + assert ps.get_virtual_pipeline_model_parallel_rank() == rank Utils.destroy_model_parallel() @pytest.mark.parametrize('order', test_parallel_order) def test_get_tensor_model_parallel_src_rank(order): Utils.initialize_model_parallel(tensor_model_parallel_size=world_size, order=order) - assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size)) + assert ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size) Utils.destroy_model_parallel() @@ -215,7 +233,7 @@ def test_different_initialize_order_consistency(src_tp_pp, ep_size): @pytest.mark.parametrize( 'src_tp_pp, ep_size', - [((1, 2), 1), ((1, 4), 1), ((2, 2), 1), ((1, 2), 2), ((1, 4), 2), ((2, 2), 2),], + [((1, 2), 1), ((1, 4), 1), ((2, 2), 1), ((1, 2), 2), ((1, 4), 2), ((2, 2), 2)], ) def test_different_initialize_order_unconsistency(src_tp_pp, ep_size): Utils.initialize_model_parallel( @@ -350,7 +368,9 @@ def golden_rank_result_from_past_code( tp_dp_group = [] tp_dp_cp_group = [] - tensor_and_data_group_size_with_cp: int = tensor_model_parallel_size * data_parallel_size * context_parallel_size + tensor_and_data_group_size_with_cp: int = 
( + tensor_model_parallel_size * data_parallel_size * context_parallel_size + ) num_tensor_and_data_groups_with_cp: int = world_size // tensor_and_data_group_size_with_cp for i in range(num_tensor_and_data_groups_with_cp): start_rank = i * tensor_and_data_group_size_with_cp @@ -374,16 +394,20 @@ def golden_rank_result_from_past_code( dp_no_ep_group = [] dp_no_ep_group_with_cp = [] - all_ranks = torch.arange(world_size).reshape(( - pipeline_model_parallel_size, - data_parallel_size // expert_model_parallel_size, - expert_model_parallel_size, - context_parallel_size, - tensor_model_parallel_size - )) + all_ranks = torch.arange(world_size).reshape( + ( + pipeline_model_parallel_size, + data_parallel_size // expert_model_parallel_size, + expert_model_parallel_size, + context_parallel_size, + tensor_model_parallel_size, + ) + ) # 'pp edp ep cp tp -> (pp edp cp) (ep tp)' tp_ep_rearrange = torch.transpose(all_ranks, 2, 3) - tp_ep_rearrange = torch.reshape(tp_ep_rearrange, (-1, expert_model_parallel_size * tensor_model_parallel_size)) + tp_ep_rearrange = torch.reshape( + tp_ep_rearrange, (-1, expert_model_parallel_size * tensor_model_parallel_size) + ) tp_ep_rearrange = tp_ep_rearrange.tolist() tp_ep_rearrange.sort() for tensor_and_expert_parallel_ranks in tp_ep_rearrange: @@ -392,7 +416,9 @@ def golden_rank_result_from_past_code( tp_ep_group.append(tensor_and_expert_parallel_ranks) # 'pp edp ep cp tp -> (pp ep cp tp) edp' edp_rearrange = torch.transpose(all_ranks, 1, 4) - edp_rearrange = torch.reshape(edp_rearrange, (-1, data_parallel_size // expert_model_parallel_size)) + edp_rearrange = torch.reshape( + edp_rearrange, (-1, data_parallel_size // expert_model_parallel_size) + ) edp_rearrange = edp_rearrange.tolist() edp_rearrange.sort() for expert_data_parallel_ranks in edp_rearrange: @@ -404,7 +430,7 @@ def golden_rank_result_from_past_code( edp_cp_rearrange = torch.transpose(edp_cp_rearrange, 2, 4) edp_cp_rearrange = torch.reshape( edp_cp_rearrange, - (-1, context_parallel_size * data_parallel_size // expert_model_parallel_size) + (-1, context_parallel_size * data_parallel_size // expert_model_parallel_size), ) edp_cp_rearrange = edp_cp_rearrange.tolist() edp_cp_rearrange.sort() @@ -452,7 +478,7 @@ def golden_rank_result_from_past_code( context_parallel_size=cp, expert_model_parallel_size=ep, ) - rank_generator = ps.RankGenerator(tp=tp, ep=ep, dp=dp, pp=pp, cp=cp, order="tp-cp-ep-dp-pp",) + rank_generator = ps.RankGenerator(tp=tp, ep=ep, dp=dp, pp=pp, cp=cp, order="tp-cp-ep-dp-pp") assert dp_groups == rank_generator.get_ranks( "dp" ), f"{dp_groups} != {rank_generator.get_ranks('dp')}" diff --git a/tests/unit_tests/test_training.py b/tests/unit_tests/test_training.py index 7ac6ff360a..a23496f981 100644 --- a/tests/unit_tests/test_training.py +++ b/tests/unit_tests/test_training.py @@ -1,8 +1,8 @@ from types import SimpleNamespace from megatron.training.global_vars import set_args -from megatron.training.training import build_train_valid_test_data_iterators from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding +from megatron.training.training import build_train_valid_test_data_iterators from tests.unit_tests.test_utilities import Utils @@ -40,7 +40,6 @@ def test_build_train_valid_test_data_iterators(self): assert (train_iter, valid_iter, test_iter) == (1, 2, 3) - def test_closed_formula_vocab_size_with_padding(self): def old_round_impl(after, multiple): while (after % multiple) != 0: @@ -54,12 +53,16 @@ def old_round_impl(after, multiple): for vocab in range(1, 600000, 1000): 
for mult in [1, 17, 32, 64, 128]: args.make_vocab_size_divisible_by = mult - assert old_round_impl(vocab, mult) == _vocab_size_with_padding(vocab, args, False), (vocab, mult) + assert old_round_impl(vocab, mult) == _vocab_size_with_padding( + vocab, args, False + ), (vocab, mult) for vocab in range(1, 10_000, 500): - for mult in range(1, 1024+1): + for mult in range(1, 1024 + 1): args.make_vocab_size_divisible_by = mult - assert old_round_impl(vocab, mult) == _vocab_size_with_padding(vocab, args, False), (vocab, mult) + assert old_round_impl(vocab, mult) == _vocab_size_with_padding( + vocab, args, False + ), (vocab, mult) def teardown_method(self, method): Utils.destroy_model_parallel() diff --git a/tests/unit_tests/test_utilities.py b/tests/unit_tests/test_utilities.py index 1de1fbe9f9..27e87378ba 100644 --- a/tests/unit_tests/test_utilities.py +++ b/tests/unit_tests/test_utilities.py @@ -47,10 +47,7 @@ def initialize_distributed(): Utils.store = store torch.distributed.init_process_group( - backend='nccl', - world_size=Utils.world_size, - rank=Utils.rank, - store=store, + backend='nccl', world_size=Utils.world_size, rank=Utils.rank, store=store ) torch.distributed.barrier() diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py index e0a0c2d07d..b2095e3506 100644 --- a/tests/unit_tests/test_utils.py +++ b/tests/unit_tests/test_utils.py @@ -11,36 +11,42 @@ def test_divide_properly(): - assert util.divide(4,2) == 2 + assert util.divide(4, 2) == 2 + def test_divide_improperly(): with pytest.raises(AssertionError): - util.divide(4,5) + util.divide(4, 5) + def test_global_memory_buffer(): global_memory_buffer = util.GlobalMemoryBuffer() - obtained_tensor = global_memory_buffer.get_tensor((3,2), torch.float32, "test_tensor") - expected_tensor = torch.empty((3,2), dtype=torch.float32, device=torch.cuda.current_device()) + obtained_tensor = global_memory_buffer.get_tensor((3, 2), torch.float32, "test_tensor") + expected_tensor = torch.empty((3, 2), dtype=torch.float32, device=torch.cuda.current_device()) assert obtained_tensor.shape == expected_tensor.shape + def test_make_viewless_tensor(): - inp = torch.rand((3,4)) - assert(torch.equal(inp, util.make_viewless_tensor(inp, True, True))) - assert(torch.equal(inp, util.make_viewless_tensor(inp, True, False))) + inp = torch.rand((3, 4)) + assert torch.equal(inp, util.make_viewless_tensor(inp, True, True)) + assert torch.equal(inp, util.make_viewless_tensor(inp, True, False)) + def test_safely_set_viewless_tensor_data(): - tensor = torch.zeros((3,4)) - new_data_tensor = torch.tensor(np.random.rand(3,4)) + tensor = torch.zeros((3, 4)) + new_data_tensor = torch.tensor(np.random.rand(3, 4)) util.safely_set_viewless_tensor_data(tensor, new_data_tensor) - assert(torch.equal(tensor, new_data_tensor)) + assert torch.equal(tensor, new_data_tensor) + def test_assert_viewless_tensor(): - tensor = torch.rand((3,4)) - assert(torch.equal(util.assert_viewless_tensor(tensor), tensor)) - input_tensor_list=[tensor,tensor,tensor] + tensor = torch.rand((3, 4)) + assert torch.equal(util.assert_viewless_tensor(tensor), tensor) + input_tensor_list = [tensor, tensor, tensor] output_tensor_list = util.assert_viewless_tensor(input_tensor_list) - for inp,out in zip(input_tensor_list, output_tensor_list): - assert(torch.equal(inp,out)) + for inp, out in zip(input_tensor_list, output_tensor_list): + assert torch.equal(inp, out) + # Initialize torch.distributed; do not call init_process_group here, call # Utils.initialize_distributed() instead. 
@@ -51,12 +57,14 @@ def _init_distributed(world, rank): assert torch.cuda.device_count() == world torch.distributed.barrier() + # Deinitialization and cleanup. # Do not call torch.distributed.destroy_process_group, may be needed by other tests. def _deinit_distributed(): assert torch.distributed.is_initialized() == True torch.distributed.barrier() + def test_check_param_hashes_across_dp_replicas(): world = int(os.getenv('WORLD_SIZE', '1')) rank = int(os.getenv('RANK', '0')) @@ -74,7 +82,7 @@ def test_check_param_hashes_across_dp_replicas(): if rank == 0: model.weight.data.fill_(0.0) param_hashes_match = util.check_param_hashes_across_dp_replicas([model]) - expected_param_hashes_match = (rank == 0) + expected_param_hashes_match = rank == 0 assert param_hashes_match == expected_param_hashes_match # Teardown. @@ -117,7 +125,7 @@ def straggler_detector_timeit(): # GEMM. with stimer: res = torch.matmul(mat1, mat2) - delta, batch_delta, _, _, _, _, = stimer.elapsed() + delta, batch_delta, _, _, _, _ = stimer.elapsed() assert delta > 0.0 assert batch_delta >= s diff --git a/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py index 38eb9aa15e..68b12b36f5 100644 --- a/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py +++ b/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py @@ -7,6 +7,7 @@ from tests.unit_tests.test_utilities import Utils from tests.unit_tests.transformer.moe.test_token_dispatcher import MoEModelTestContainer + class TestAlltoAllDispatcher: def setup_method(self, method): pass @@ -16,12 +17,7 @@ def teardown_method(self, method): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.timeout(120) - @pytest.mark.parametrize("tp_size,ep_size", [ - (1, 8), - (8, 1), - (4, 2), - (1, 1), - ]) + @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2), (1, 1)]) def test_forward_backward(self, tp_size, ep_size): container = MoEModelTestContainer( tp_size=tp_size, @@ -36,12 +32,7 @@ def test_forward_backward(self, tp_size, ep_size): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.timeout(120) - @pytest.mark.parametrize("tp_size,ep_size", [ - (1, 8), - (8, 1), - (4, 2), - (1, 1), - ]) + @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2), (1, 1)]) def test_capacity_forward_backward(self, tp_size, ep_size): container = MoEModelTestContainer( tp_size=tp_size, @@ -59,14 +50,10 @@ def test_capacity_forward_backward(self, tp_size, ep_size): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.timeout(120) - @pytest.mark.parametrize("tp_size,ep_size", [ - (1, 8), - (8, 1), - (4, 2), - (1, 1) - ]) + @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2), (1, 1)]) def test_capacity_padding_forward_backward(self, tp_size, ep_size): import time + time.sleep(5) container = MoEModelTestContainer( tp_size=tp_size, @@ -81,4 +68,3 @@ def test_capacity_padding_forward_backward(self, tp_size, ep_size): moe_pad_expert_input_to_capacity=True, ) container.dispatcher_drop_and_pad_test() - diff --git a/tests/unit_tests/transformer/moe/test_aux_loss.py b/tests/unit_tests/transformer/moe/test_aux_loss.py index 217a0a2711..2e26f01551 100644 --- a/tests/unit_tests/transformer/moe/test_aux_loss.py +++ b/tests/unit_tests/transformer/moe/test_aux_loss.py @@ -2,15 +2,18 @@ import pytest import torch -from megatron.core.transformer.moe.moe_utils import 
clear_aux_losses_tracker +from megatron.core import parallel_state +from megatron.core.transformer.moe.moe_utils import clear_aux_losses_tracker from tests.unit_tests.test_utilities import Utils from tests.unit_tests.transformer.moe.test_token_dispatcher import MoEModelTestContainer -from megatron.core import parallel_state + class AuxlossTestContainer(MoEModelTestContainer): def partition_input(self, input): - partitioned_input = input.chunk(parallel_state.get_tensor_and_context_parallel_world_size(), dim=1)[parallel_state.get_tensor_and_context_parallel_rank()] + partitioned_input = input.chunk( + parallel_state.get_tensor_and_context_parallel_world_size(), dim=1 + )[parallel_state.get_tensor_and_context_parallel_rank()] output = partitioned_input.clone().detach() output.requires_grad = True return output @@ -27,6 +30,7 @@ def aux_loss_test(self, input, baseline_grad): loss = parallel_state.get_moe_layer_wise_logging_tracker()['load_balancing_loss'] clear_aux_losses_tracker() + class TestAuxLoss: def setup_method(self, method): baseline_container = AuxlossTestContainer( @@ -44,7 +48,7 @@ def setup_method(self, method): self.input = torch.randn((32, 8, moe_layer.config.hidden_size)).cuda() self.input.requires_grad = True probs, indices = moe_layer.router(self.input) - probs.sum().mul_(0).backward() # zero out the main gradients + probs.sum().mul_(0).backward() # zero out the main gradients self.baseline_grad = self.input.grad self.input.grad = None clear_aux_losses_tracker() @@ -53,13 +57,9 @@ def teardown_method(self, method): Utils.destroy_model_parallel() @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.parametrize("tp_size,ep_size,cp_size", [ - (8, 1, 1), - (4, 2, 1), - (1, 1, 8), - (2, 1, 4), - (2, 2, 2), - ]) + @pytest.mark.parametrize( + "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)] + ) def test_allgather_dispatcher(self, tp_size, ep_size, cp_size): container = AuxlossTestContainer( tp_size=tp_size, @@ -75,13 +75,9 @@ def test_allgather_dispatcher(self, tp_size, ep_size, cp_size): container.aux_loss_test(self.input, self.baseline_grad) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.parametrize("tp_size,ep_size,cp_size", [ - (8, 1, 1), - (4, 2, 1), - (1, 1, 8), - (2, 1, 4), - (2, 2, 2), - ]) + @pytest.mark.parametrize( + "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)] + ) def test_a2a_dispatcher(self, tp_size, ep_size, cp_size): container = AuxlossTestContainer( tp_size=tp_size, diff --git a/tests/unit_tests/transformer/moe/test_grouped_mlp.py b/tests/unit_tests/transformer/moe/test_grouped_mlp.py index b86edde68d..757be59232 100644 --- a/tests/unit_tests/transformer/moe/test_grouped_mlp.py +++ b/tests/unit_tests/transformer/moe/test_grouped_mlp.py @@ -1,20 +1,20 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
-import pytest -from pkg_resources import packaging from importlib.metadata import version +import pytest import torch import torch.nn.functional as F +from pkg_resources import packaging -from megatron.training.arguments import parse_args from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.transformer.moe import grouped_gemm_util as gg -from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.moe.experts import TEGroupedMLP +from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.training.initialize import _set_random_seed from megatron.legacy.model import Float16Module +from megatron.training.arguments import parse_args +from megatron.training.initialize import _set_random_seed from tests.unit_tests.test_utilities import Utils DEVICE_CAPABILITY = None @@ -28,23 +28,37 @@ class TestParallelGroupedMLP: def setup_method(self, method, use_cpu_initialization=False, swiglu=True): print("============") - print("Test for use_cpu_initilization={} and swiglu={}.".format(use_cpu_initialization, swiglu)) + print( + "Test for use_cpu_initilization={} and swiglu={}.".format( + use_cpu_initialization, swiglu + ) + ) print("============") - Utils.initialize_model_parallel(1,1) - num_layers = 1 # 2 - self.hidden_size = 16 # must be an multiple of 16, otherwise trigger CUTLASS misaligned issue + Utils.initialize_model_parallel(1, 1) + num_layers = 1 # 2 + self.hidden_size = ( + 16 # must be an multiple of 16, otherwise trigger CUTLASS misaligned issue + ) self.num_experts = 2 self.gated_linear_unit = swiglu self.activation_func = F.silu if swiglu else F.gelu self.use_cpu_initialization = use_cpu_initialization tf_config = TransformerConfig( - num_layers=num_layers, hidden_size=self.hidden_size, num_attention_heads=4, - num_moe_experts=self.num_experts, use_cpu_initialization=self.use_cpu_initialization, - add_bias_linear=False, gated_linear_unit=self.gated_linear_unit, + num_layers=num_layers, + hidden_size=self.hidden_size, + num_attention_heads=4, + num_moe_experts=self.num_experts, + use_cpu_initialization=self.use_cpu_initialization, + add_bias_linear=False, + gated_linear_unit=self.gated_linear_unit, activation_func=self.activation_func, bias_activation_fusion=False, - bf16=True, params_dtype=torch.bfloat16, moe_router_load_balancing_type="sinkhorn", moe_router_topk=1) + bf16=True, + params_dtype=torch.bfloat16, + moe_router_load_balancing_type="sinkhorn", + moe_router_topk=1, + ) self.fc1_ffn_hidden_size = tf_config.ffn_hidden_size self.fc2_ffn_hidden_size = tf_config.ffn_hidden_size @@ -56,15 +70,15 @@ def setup_method(self, method, use_cpu_initialization=False, swiglu=True): # Set random seed for reproducability _set_random_seed(seed_=123, data_parallel_random_init=False) transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - self.num_experts, moe_grouped_gemm=False) - self.sequential_mlp = MoELayer(tf_config, - transformer_layer_spec.submodules.mlp.submodules) + self.num_experts, moe_grouped_gemm=False + ) + self.sequential_mlp = MoELayer(tf_config, transformer_layer_spec.submodules.mlp.submodules) self.args = parse_args(ignore_unknown_args=True) - self.args.bf16=True + self.args.bf16 = True # Bias is not supported in grouped gemm currently, thus we disable the # bias in the linear layer. 
- self.args.add_bias_linear=False + self.args.add_bias_linear = False self.sequential_mlp = Float16Module(self.sequential_mlp, self.args).module print("done intializing for sequential gemm") @@ -89,9 +103,12 @@ def test_constructor(self): # GroupedGEMM and sequential GEMMs should hold the same number of parms. assert num_weights_smm == num_weights_gmm # expected num weights: router linear weights+bias + MLP weights(no bias) of all experts - expected_num_weights = \ - self.hidden_size * self.num_experts + \ - self.hidden_size * (self.fc1_ffn_hidden_size + self.fc2_ffn_hidden_size) * self.num_experts + expected_num_weights = ( + self.hidden_size * self.num_experts + + self.hidden_size + * (self.fc1_ffn_hidden_size + self.fc2_ffn_hidden_size) + * self.num_experts + ) assert num_weights_smm == expected_num_weights assert torch.equal(self.sequential_mlp.router.weight, self.grouped_mlp.router.weight) @@ -99,12 +116,19 @@ def test_constructor(self): # weight1: [h, num_experts*4h] # weight2: [num_experts*4h, h] assert self.grouped_mlp.experts.weight1.shape[0] == self.hidden_size - assert self.grouped_mlp.experts.weight1.shape[1] == self.num_experts * self.fc1_ffn_hidden_size + assert ( + self.grouped_mlp.experts.weight1.shape[1] == self.num_experts * self.fc1_ffn_hidden_size + ) if self.gated_linear_unit: - assert self.grouped_mlp.experts.weight2.shape[0] == self.num_experts * self.fc2_ffn_hidden_size + assert ( + self.grouped_mlp.experts.weight2.shape[0] + == self.num_experts * self.fc2_ffn_hidden_size + ) assert self.grouped_mlp.experts.weight2.shape[1] == self.hidden_size else: - assert self.grouped_mlp.experts.weight1.shape == self.grouped_mlp.experts.weight2.t().shape + assert ( + self.grouped_mlp.experts.weight1.shape == self.grouped_mlp.experts.weight2.t().shape + ) def test_weight_init_value_the_same(self): gmm_w1 = self.grouped_mlp.experts.weight1.view(self.num_experts, -1, self.hidden_size) @@ -130,17 +154,18 @@ def test_weight_init_value_the_same(self): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( - not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, reason='GroupedGEMM kernels are not supported on this device.' + not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, + reason='GroupedGEMM kernels are not supported on this device.', ) def test_gpu_forward(self): self.sequential_mlp.cuda() self.grouped_mlp.cuda() # [sequence length, batch size, hidden size] - seq_len = 3 #32 + seq_len = 3 # 32 batch_size = 2 hidden_states = torch.rand( - (seq_len, batch_size, self.sequential_mlp.config.hidden_size), - dtype=torch.bfloat16) + (seq_len, batch_size, self.sequential_mlp.config.hidden_size), dtype=torch.bfloat16 + ) hidden_states = hidden_states.cuda() output_smm, _ = self.sequential_mlp(hidden_states) output_gmm, _ = self.grouped_mlp(hidden_states) @@ -151,7 +176,8 @@ def test_gpu_forward(self): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( - not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, reason='GroupedGEMM kernels are not supported on this device.' 
+ not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, + reason='GroupedGEMM kernels are not supported on this device.', ) def test_gpu_forward_with_no_tokens_allocated(self): """Test the case when no token is allocated for groupedGEMM kernels.""" @@ -168,7 +194,8 @@ def test_gpu_forward_with_no_tokens_allocated(self): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( - not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, reason='GroupedGEMM kernels are not supported on this device.' + not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, + reason='GroupedGEMM kernels are not supported on this device.', ) def test_gradient_with_no_tokens_allocated(self): """Test that when no token is passed in, the parameters of the grouped MLP will also have gradients.""" @@ -177,10 +204,7 @@ def test_gradient_with_no_tokens_allocated(self): tokens_per_expert = torch.zeros(self.num_experts) hidden_states = torch.rand((num_allocated_tokens, self.hidden_size), dtype=torch.bfloat16) hidden_states = hidden_states.cuda() - output_gmm, _ = self.grouped_mlp.experts( - hidden_states, - tokens_per_expert=tokens_per_expert, - ) + output_gmm, _ = self.grouped_mlp.experts(hidden_states, tokens_per_expert=tokens_per_expert) output_gmm.mean().backward() assert self.grouped_mlp.experts.weight1.grad is not None @@ -193,7 +217,7 @@ class TestTEGroupedMLP: def setup_method(self, method, use_cpu_initialization=False, swiglu=True): Utils.initialize_model_parallel(1, 1) - num_layers = 1 + num_layers = 1 self.hidden_size = 16 self.num_experts = 2 self.gated_linear_unit = swiglu @@ -348,9 +372,8 @@ def test_gpu_forward_backward_with_no_tokens_allocated(self): for swiglu in [True, False]: GMLP_test = TestParallelGroupedMLP() GMLP_test.setup_method( - method=None, - use_cpu_initialization=use_cpu_unitilization, - swiglu=swiglu) + method=None, use_cpu_initialization=use_cpu_unitilization, swiglu=swiglu + ) GMLP_test.test_constructor() GMLP_test.test_weight_init_value_the_same() GMLP_test.test_gpu_forward() diff --git a/tests/unit_tests/transformer/moe/test_routers.py b/tests/unit_tests/transformer/moe/test_routers.py index fbeb744f1e..ef4c9d4aed 100644 --- a/tests/unit_tests/transformer/moe/test_routers.py +++ b/tests/unit_tests/transformer/moe/test_routers.py @@ -1,15 +1,14 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
import pytest - import torch +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.moe.router import Router +from megatron.core.transformer.transformer_config import TransformerConfig from megatron.training.initialize import _set_random_seed from tests.unit_tests.test_utilities import Utils -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.moe.moe_layer import MoELayer -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec class TestTop2Router: @@ -46,10 +45,7 @@ def test_constructor(self): assert num_weights == 12 * 4, num_weights @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.parametrize("moe_router_pre_softmax", [ - (True), - (False), - ]) + @pytest.mark.parametrize("moe_router_pre_softmax", [(True), (False)]) def test_router_forward(self, moe_router_pre_softmax): with torch.no_grad(): self.router = self.router.cuda() @@ -62,30 +58,33 @@ def test_router_forward(self, moe_router_pre_softmax): assert scores.shape == (64, 2) assert indices.shape == (64, 2) print( - (indices == 0).sum(), (indices == 1).sum(), (indices == 2).sum(), (indices == 3).sum() + (indices == 0).sum(), + (indices == 1).sum(), + (indices == 2).sum(), + (indices == 3).sum(), ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_aux_loss(self): self.sequential_mlp = self.sequential_mlp.cuda() - + # Without aux loss hidden_states = torch.randn((32, 2, self.router.config.hidden_size)) hidden_states = hidden_states.cuda() out = self.sequential_mlp(hidden_states)[0] out.sum().mul_(0).backward() assert self.sequential_mlp.router.weight.grad.abs().sum() == 0 - + # With aux loss self.transformer_config.moe_aux_loss_coeff = 1 out = self.sequential_mlp(hidden_states)[0] out.sum().mul_(0).backward() assert self.sequential_mlp.router.weight.grad.abs().sum() > 0 - + # With Z loss self.transformer_config.moe_aux_loss_coeff = 0 self.transformer_config.moe_z_loss_coeff = 1 self.sequential_mlp.router.weight.grad.fill_(0) out = self.sequential_mlp(hidden_states)[0] out.sum().mul_(0).backward() - assert self.sequential_mlp.router.weight.grad.abs().sum() > 0 \ No newline at end of file + assert self.sequential_mlp.router.weight.grad.abs().sum() > 0 diff --git a/tests/unit_tests/transformer/moe/test_sequential_mlp.py b/tests/unit_tests/transformer/moe/test_sequential_mlp.py index 0ebb85333e..21fcc23ca2 100644 --- a/tests/unit_tests/transformer/moe/test_sequential_mlp.py +++ b/tests/unit_tests/transformer/moe/test_sequential_mlp.py @@ -1,19 +1,19 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
import pytest - import torch -from megatron.core.transformer.moe.moe_layer import MoELayer -from tests.unit_tests.test_utilities import Utils +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from tests.unit_tests.test_utilities import Utils + class TestParallelSequentialMLP: def setup_method(self, method): - Utils.initialize_model_parallel(1,1) + Utils.initialize_model_parallel(1, 1) model_parallel_cuda_manual_seed(123) print("done intializing") num_moe_experts = 2 @@ -27,11 +27,14 @@ def setup_method(self, method): gated_linear_unit=True, bias_activation_fusion=True, moe_router_load_balancing_type="sinkhorn", - moe_router_topk=1 + moe_router_topk=1, ) transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=num_moe_experts, moe_grouped_gemm=False) - self.sequential_mlp = MoELayer(transformer_config, transformer_layer_spec.submodules.mlp.submodules) + num_experts=num_moe_experts, moe_grouped_gemm=False + ) + self.sequential_mlp = MoELayer( + transformer_config, transformer_layer_spec.submodules.mlp.submodules + ) def teardown_method(self, method): Utils.destroy_model_parallel() @@ -42,7 +45,6 @@ def test_constructor(self): num_weights = sum([p.numel() for p in self.sequential_mlp.parameters()]) assert num_weights == 3696 - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_gpu_forward(self): sequential_mlp = self.sequential_mlp @@ -58,4 +60,3 @@ def test_gpu_forward(self): assert output.dtype == torch.float32 assert output.device.type == 'cuda' assert output_bias.device.type == 'cuda' - diff --git a/tests/unit_tests/transformer/moe/test_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_token_dispatcher.py index f5384143ce..f2c6d3c307 100644 --- a/tests/unit_tests/transformer/moe/test_token_dispatcher.py +++ b/tests/unit_tests/transformer/moe/test_token_dispatcher.py @@ -2,8 +2,8 @@ import pytest import torch -from megatron.core import parallel_state +from megatron.core import parallel_state from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.moe.moe_utils import permute, unpermute @@ -34,7 +34,7 @@ def __init__( tensor_model_parallel_size=tp_size, pipeline_model_parallel_size=pp_size, expert_model_parallel_size=ep_size, - context_parallel_size=cp_size + context_parallel_size=cp_size, ) _set_random_seed(seed_=123, data_parallel_random_init=data_parallel_random_init) local_expert_indices_offset = ( @@ -74,7 +74,7 @@ def __init__( self.config, transformer_layer_spec.submodules.mlp.submodules ).cuda() self.moe_layer.set_layer_number(0) - + def __del__(self): torch.distributed.barrier() torch.cuda.synchronize() @@ -96,11 +96,8 @@ def dispatcher_dropless_test(self): # indices = torch.ones_like(indices) * torch.distributed.get_rank() # print(permuted_local_hidden_states) - ( - permuted_local_hidden_states, - tokens_per_expert, - ) = moe_layer.token_dispatcher.token_permutation( - hidden_states, probs, indices + (permuted_local_hidden_states, tokens_per_expert) = ( + moe_layer.token_dispatcher.token_permutation(hidden_states, probs, 
indices) ) permuted_local_hidden_states /= moe_layer.config.tensor_model_parallel_size @@ -136,11 +133,8 @@ def dispacher_capacity_test(self): ] restored_hidden_states_answer = hidden_states * local_probss.sum(dim=1).unsqueeze(1) - ( - permuted_local_hidden_states, - tokens_per_expert, - ) = moe_layer.token_dispatcher.token_permutation( - hidden_states, probs, indices + (permuted_local_hidden_states, tokens_per_expert) = ( + moe_layer.token_dispatcher.token_permutation(hidden_states, probs, indices) ) print(f"Dispatched tokens per expert: {tokens_per_expert}") @@ -181,7 +175,7 @@ def dispatcher_drop_and_pad_test(self): # num_local_tokens_per_expert = torch.tensor([2, 2, 2, 2, 2, 2, 2, 2]).cuda() probs_1, indices_1 = moe_layer.router(hidden_states) - (permuted_input_1, tokens_per_expert,) = moe_layer.token_dispatcher.token_permutation( + (permuted_input_1, tokens_per_expert) = moe_layer.token_dispatcher.token_permutation( hidden_states, probs_1, indices_1 ) torch.distributed.barrier() @@ -197,7 +191,7 @@ def dispatcher_drop_and_pad_test(self): # End probs_2, indices_2 = moe_layer.router(hidden_states) - (permuted_input_2, tokens_per_expert,) = moe_layer.token_dispatcher.token_permutation( + (permuted_input_2, tokens_per_expert) = moe_layer.token_dispatcher.token_permutation( hidden_states, probs_2, indices_2 ) restored_hidden_states, restored_bias = moe_layer.token_dispatcher.token_unpermutation( @@ -230,9 +224,7 @@ def teardown_method(self, method): Utils.destroy_model_parallel() @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @pytest.mark.parametrize("tp_size,ep_size", [ - (8, 1), - ]) + @pytest.mark.parametrize("tp_size,ep_size", [(8, 1)]) def test_forward_backward(self, tp_size, ep_size): container = MoEModelTestContainer( tp_size=tp_size, @@ -269,13 +261,15 @@ def test_extended_tp_forward_backward(self): assert scores.shape == (256, moe_layer.router.topk), "Scores shape is not correct" assert indices.shape == (256, moe_layer.router.topk), "Indices shape is not correct" scores = torch.ones_like(scores) / 2 - ( - permuted_local_hidden_states, - tokens_per_expert, - ) = moe_layer.token_dispatcher.token_permutation(hidden_states, scores, indices) - permuted_local_hidden_states /= moe_layer.config.tensor_model_parallel_size * moe_layer.config.expert_model_parallel_size + (permuted_local_hidden_states, tokens_per_expert) = ( + moe_layer.token_dispatcher.token_permutation(hidden_states, scores, indices) + ) + permuted_local_hidden_states /= ( + moe_layer.config.tensor_model_parallel_size + * moe_layer.config.expert_model_parallel_size + ) restored_hidden_states, restored_bias = moe_layer.token_dispatcher.token_unpermutation( - permuted_local_hidden_states, bias=torch.zeros_like(permuted_local_hidden_states), + permuted_local_hidden_states, bias=torch.zeros_like(permuted_local_hidden_states) ) assert torch.allclose( diff --git a/tests/unit_tests/transformer/test_attention.py b/tests/unit_tests/transformer/test_attention.py index 4a5680ea05..8c13ff3f8c 100644 --- a/tests/unit_tests/transformer/test_attention.py +++ b/tests/unit_tests/transformer/test_attention.py @@ -1,25 +1,28 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
import pytest - import torch -from megatron.core.transformer.attention import SelfAttention -from tests.unit_tests.test_utilities import Utils +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from tests.unit_tests.test_utilities import Utils + class TestParallelAttention: def setup_method(self, method): - Utils.initialize_model_parallel(1,1) + Utils.initialize_model_parallel(1, 1) model_parallel_cuda_manual_seed(123) - self.transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True) - self.parallel_attention = SelfAttention(self.transformer_config, - get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules, - layer_number=1) - + self.transformer_config = TransformerConfig( + num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True + ) + self.parallel_attention = SelfAttention( + self.transformer_config, + get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules, + layer_number=1, + ) def teardown_method(self, method): Utils.destroy_model_parallel() @@ -44,7 +47,9 @@ def test_gpu_forward(self): self.parallel_attention.cuda() # [sequence length, batch size, hidden size] - hidden_states = torch.ones((sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size)) + hidden_states = torch.ones( + (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size) + ) hidden_states = hidden_states.cuda() attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() @@ -66,12 +71,18 @@ def test_fused_rope_gpu_forward(self): self.parallel_attention.cuda() # [sequence length, batch size, hidden size] - hidden_states = torch.ones((sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size)) + hidden_states = torch.ones( + (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size) + ) hidden_states = hidden_states.cuda() attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - rotary_pos_emb = torch.ones(sequence_length, 1, 1, self.parallel_attention.config.kv_channels).cuda() - output, bias = self.parallel_attention(hidden_states, attention_mask, rotary_pos_emb=rotary_pos_emb) + rotary_pos_emb = torch.ones( + sequence_length, 1, 1, self.parallel_attention.config.kv_channels + ).cuda() + output, bias = self.parallel_attention( + hidden_states, attention_mask, rotary_pos_emb=rotary_pos_emb + ) assert config.recompute_granularity is None assert output.shape[0] == sequence_length @@ -80,13 +91,14 @@ def test_fused_rope_gpu_forward(self): assert bias.shape[0] == config.hidden_size self.parallel_attention.config.apply_rope_fusion = False - def test_checkpointed_gpu_forward(self): transformer_config = self.transformer_config - transformer_config.recompute_granularity='selective' - checkpointed_parallel_attention = SelfAttention(transformer_config, - get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules, - layer_number=1) + transformer_config.recompute_granularity = 'selective' + checkpointed_parallel_attention = SelfAttention( + transformer_config, + 
get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules, + layer_number=1, + ) config = checkpointed_parallel_attention.config sequence_length = 32 diff --git a/tests/unit_tests/transformer/test_attention_packed_seq.py b/tests/unit_tests/transformer/test_attention_packed_seq.py index c8be7dba3d..54c8787579 100644 --- a/tests/unit_tests/transformer/test_attention_packed_seq.py +++ b/tests/unit_tests/transformer/test_attention_packed_seq.py @@ -1,16 +1,15 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. import pytest - import torch +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.enums import AttnMaskType -from tests.unit_tests.test_utilities import Utils -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from tests.unit_tests.test_utilities import Utils # Note: this test requires TE >= 0.13 as well as Flash Attention to run # FIXME this unit test doesn't work in the current test container. to be fixed soon @@ -128,4 +127,4 @@ def test_checkpointed_gpu_forward(self): assert output.shape[1] == micro_batch_size assert output.shape[2] == config.hidden_size assert bias.shape[0] == config.hidden_size -""" \ No newline at end of file +""" diff --git a/tests/unit_tests/transformer/test_core_attention.py b/tests/unit_tests/transformer/test_core_attention.py index 2966b98f89..d8710e2242 100644 --- a/tests/unit_tests/transformer/test_core_attention.py +++ b/tests/unit_tests/transformer/test_core_attention.py @@ -2,10 +2,10 @@ import pytest - import torch from megatron.core.transformer.attention import CrossAttention + """ @pytest.fixture @@ -61,4 +61,4 @@ def test_gpu_forward(self, core_attention): assert context_layer.device.type == 'cuda' assert context_layer.dtype == torch.float32 -""" \ No newline at end of file +""" diff --git a/tests/unit_tests/transformer/test_mlp.py b/tests/unit_tests/transformer/test_mlp.py index 8e3f14688c..d2c25e0cc5 100644 --- a/tests/unit_tests/transformer/test_mlp.py +++ b/tests/unit_tests/transformer/test_mlp.py @@ -1,23 +1,24 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
import pytest - import torch -from megatron.core.transformer.mlp import MLP -from tests.unit_tests.test_utilities import Utils +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.mlp import MLP from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec +from tests.unit_tests.test_utilities import Utils + class TestParallelMLP: def setup_method(self, method): - Utils.initialize_model_parallel(1,1) + Utils.initialize_model_parallel(1, 1) model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True) - self.mlp = MLP(transformer_config, - get_gpt_layer_local_spec().submodules.mlp.submodules) + transformer_config = TransformerConfig( + num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True + ) + self.mlp = MLP(transformer_config, get_gpt_layer_local_spec().submodules.mlp.submodules) def teardown_method(self, method): Utils.destroy_model_parallel() @@ -55,4 +56,3 @@ def test_gpu_forward(self): assert output.dtype == torch.float32 assert output.device.type == 'cuda' assert output_bias.device.type == 'cuda' - diff --git a/tests/unit_tests/transformer/test_module.py b/tests/unit_tests/transformer/test_module.py index b530709915..64826a0ee5 100644 --- a/tests/unit_tests/transformer/test_module.py +++ b/tests/unit_tests/transformer/test_module.py @@ -1,13 +1,12 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. import pytest - import torch +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.module import Float16Module, MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig from tests.unit_tests.test_utilities import Utils -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed DEVICE_CAPABILITY = None if torch.cuda.is_available(): @@ -24,16 +23,19 @@ def __init__(self, config: TransformerConfig): def forward(self, x): return self.linear(x) + class TestMegatronModule: def setup_method(self, method): - Utils.initialize_model_parallel(1,1) + Utils.initialize_model_parallel(1, 1) model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True) + transformer_config = TransformerConfig( + num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True + ) self.megatron_module = DummyModule(config=transformer_config).cuda() def teardown_method(self, method): - Utils.destroy_model_parallel() + Utils.destroy_model_parallel() def test_megatron_module(self): megatron_module = self.megatron_module @@ -54,14 +56,16 @@ def test_megatron_module(self): class TestFloat16Module: def setup_method(self, method): - Utils.initialize_model_parallel(1,1) + Utils.initialize_model_parallel(1, 1) model_parallel_cuda_manual_seed(123) - self.transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True) + self.transformer_config = TransformerConfig( + num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True + ) self.megatron_module = DummyModule(config=self.transformer_config).cuda() def teardown_method(self, method): - Utils.destroy_model_parallel() - 
+ Utils.destroy_model_parallel() + def test_fp16_module(self): transformer_config = self.transformer_config megatron_module = self.megatron_module @@ -78,7 +82,8 @@ def test_fp16_module(self): assert fp16_module(x).dtype == torch.float32 pytest.mark.skipif( - not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, reason='bfloat16 is not supported on this device' + not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, + reason='bfloat16 is not supported on this device', ) def test_bf16_module(self): @@ -95,4 +100,3 @@ def test_bf16_module(self): x = torch.ones((2, 2)).cuda() # inputs are converted to bf16 then outputs are converted to fp32 assert bf16_module(x).dtype == torch.float32 - diff --git a/tests/unit_tests/transformer/test_retro_attention.py b/tests/unit_tests/transformer/test_retro_attention.py index 11ec7d5faa..d7c5a5f155 100644 --- a/tests/unit_tests/transformer/test_retro_attention.py +++ b/tests/unit_tests/transformer/test_retro_attention.py @@ -1,16 +1,17 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -import torch import types +import torch + from megatron.core.models.retro import RetroConfig, get_retro_decoder_block_spec from megatron.core.models.retro.decoder_attention import ( - RetroDecoderCrossAttention, RetroDecoderBiasDropoutAdd, + RetroDecoderCrossAttention, ) from megatron.core.models.retro.encoder_attention import ( - RetroEncoderCrossAttention, RetroEncoderBiasDropoutAdd, + RetroEncoderCrossAttention, RetroEncoderLayerNorm, ) from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed @@ -38,33 +39,42 @@ def get_modules(cls, config, use_transformer_engine, use_gpu): # Retro decoder layer. decoder_block_spec = get_retro_decoder_block_spec( - config, use_transformer_engine=use_transformer_engine) + config, use_transformer_engine=use_transformer_engine + ) decoder_block = TransformerBlock(config=config, spec=decoder_block_spec) - decoder_layers = [ layer for layer in decoder_block.layers if isinstance(layer.cross_attention, RetroDecoderCrossAttention) ] + decoder_layers = [ + layer + for layer in decoder_block.layers + if isinstance(layer.cross_attention, RetroDecoderCrossAttention) + ] decoder_layer = decoder_layers[0] # Retro encoder layer. encoder_block = decoder_layer.cross_attention.encoder - encoder_layers = [ layer for layer in encoder_block.layers if isinstance(layer.cross_attention, RetroEncoderCrossAttention) ] + encoder_layers = [ + layer + for layer in encoder_block.layers + if isinstance(layer.cross_attention, RetroEncoderCrossAttention) + ] encoder_layer = encoder_layers[0] # Modules. modules = types.SimpleNamespace( - decoder_attn = decoder_layer.cross_attention, - decoder_bda = decoder_layer.cross_attn_bda, - encoder_attn = encoder_layer.cross_attention, - encoder_bda = encoder_layer.cross_attn_bda, - encoder_norm = encoder_layer.pre_mlp_layernorm, + decoder_attn=decoder_layer.cross_attention, + decoder_bda=decoder_layer.cross_attn_bda, + encoder_attn=encoder_layer.cross_attention, + encoder_bda=encoder_layer.cross_attn_bda, + encoder_norm=encoder_layer.pre_mlp_layernorm, ) # GPU. 
if use_gpu: - [ m.cuda() for m in vars(modules).values() ] + [m.cuda() for m in vars(modules).values()] return modules def setup_method(self, method): - Utils.initialize_model_parallel(1,1) + Utils.initialize_model_parallel(1, 1) model_parallel_cuda_manual_seed(123) def teardown_method(self, method): @@ -73,11 +83,7 @@ def teardown_method(self, method): def test_constructor(self): config = self.get_config() - modules = self.get_modules( - config, - use_transformer_engine=True, - use_gpu=False, - ) + modules = self.get_modules(config, use_transformer_engine=True, use_gpu=False) assert isinstance(modules.decoder_attn, RetroDecoderCrossAttention) assert isinstance(modules.decoder_bda, RetroDecoderBiasDropoutAdd) @@ -88,7 +94,7 @@ def test_constructor(self): assert modules.decoder_attn.attn.layer_number == 6 assert modules.encoder_attn.attn.layer_number == 1 - get_nparams = lambda m : sum(p.numel() for p in m.parameters()) + get_nparams = lambda m: sum(p.numel() for p in m.parameters()) assert get_nparams(modules.decoder_attn) == 8768 assert get_nparams(modules.decoder_bda) == 0 assert get_nparams(modules.encoder_attn) == 1088 @@ -110,52 +116,38 @@ def run_gpu_forward(self, recompute_granularity, use_transformer_engine): n_chunks_per_sample = seq_length // config.retro_chunk_length # Init tensors. - hidden_states = torch.ones(( - seq_length, - micro_batch_size, - config.hidden_size, - )).cuda() + hidden_states = torch.ones((seq_length, micro_batch_size, config.hidden_size)).cuda() attention_mask = None - decoder_context = torch.ones(( - config.retro_retrieved_length, - config.retro_num_neighbors * micro_batch_size * n_chunks_per_sample, - config.hidden_size, - )).cuda() - encoder_context = torch.ones(( - config.retro_chunk_length, - micro_batch_size * n_chunks_per_sample, - config.hidden_size, - )).cuda() + decoder_context = torch.ones( + ( + config.retro_retrieved_length, + config.retro_num_neighbors * micro_batch_size * n_chunks_per_sample, + config.hidden_size, + ) + ).cuda() + encoder_context = torch.ones( + (config.retro_chunk_length, micro_batch_size * n_chunks_per_sample, config.hidden_size) + ).cuda() # Forward decoder. - decoder_attn_output = modules.decoder_attn( - hidden_states, - attention_mask, - decoder_context, - ) + decoder_attn_output = modules.decoder_attn(hidden_states, attention_mask, decoder_context) with torch.enable_grad(): decoder_bda_output = modules.decoder_bda(True, True)( - decoder_attn_output, - hidden_states, - config.hidden_dropout, + decoder_attn_output, hidden_states, config.hidden_dropout ) # Forward encoder. - encoder_attn_output_tuples = modules.encoder_attn( - decoder_context, - None, - encoder_context, - ) + encoder_attn_output_tuples = modules.encoder_attn(decoder_context, None, encoder_context) with torch.enable_grad(): encoder_bda_output = modules.encoder_bda(True, True)( - encoder_attn_output_tuples, - decoder_context, - config.retro_encoder_hidden_dropout, + encoder_attn_output_tuples, decoder_context, config.retro_encoder_hidden_dropout ) encoder_norm_output = modules.encoder_norm(encoder_bda_output) # Verify decoder. 
- assert set(decoder_attn_output.keys()) == set([ "ns", "bs", "d", "l", "pad", "attention_output", "attention_bias", "context"]) + assert set(decoder_attn_output.keys()) == set( + ["ns", "bs", "d", "l", "pad", "attention_output", "attention_bias", "context"] + ) assert decoder_attn_output["ns"] == seq_length assert decoder_attn_output["bs"] == micro_batch_size assert decoder_attn_output["d"] == config.hidden_size @@ -166,9 +158,7 @@ def run_gpu_forward(self, recompute_granularity, use_transformer_engine): micro_batch_size * n_chunks_per_sample, config.hidden_size, ) - assert tuple(decoder_attn_output["attention_bias"].shape) == ( - config.hidden_size, - ) + assert tuple(decoder_attn_output["attention_bias"].shape) == (config.hidden_size,) assert decoder_attn_output["context"].shape == ( config.retro_retrieved_length * config.retro_num_neighbors, micro_batch_size * n_chunks_per_sample, diff --git a/tests/unit_tests/transformer/test_spec_customization.py b/tests/unit_tests/transformer/test_spec_customization.py index f0ee9e79af..e6b1fc04b7 100755 --- a/tests/unit_tests/transformer/test_spec_customization.py +++ b/tests/unit_tests/transformer/test_spec_customization.py @@ -55,7 +55,7 @@ def setup_method(self, method): # specify layernorm spec with module path to test dynamic importing self.layernorm_spec = ModuleSpec( - module=("megatron.core.transformer.custom_layers.transformer_engine", "TENorm"), + module=("megatron.core.transformer.custom_layers.transformer_engine", "TENorm") ) # specify bias dropout add with module path @@ -97,7 +97,7 @@ def test_build_module(self): assert x == random_input # Check SelfAttention - self_attention = build_module(self.attention_spec, config=self.config, layer_number=1,) + self_attention = build_module(self.attention_spec, config=self.config, layer_number=1) assert isinstance(self_attention, SelfAttention) assert self_attention.layer_number == 1 assert self_attention.attn_mask_type == self.attention_spec.params['attn_mask_type'] diff --git a/tests/unit_tests/transformer/test_transformer_block.py b/tests/unit_tests/transformer/test_transformer_block.py index 6a2227b52c..02702a9ff7 100644 --- a/tests/unit_tests/transformer/test_transformer_block.py +++ b/tests/unit_tests/transformer/test_transformer_block.py @@ -1,26 +1,31 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
import os -import pytest +import pytest import torch + from megatron.core import dist_checkpointing +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import TransformerLayer -from megatron.core.transformer.transformer_block import TransformerBlock from tests.unit_tests.test_utilities import Utils -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec + class TestParallelTransformerBlock: def setup_method(self, method): - Utils.initialize_model_parallel(1,1) + Utils.initialize_model_parallel(1, 1) model_parallel_cuda_manual_seed(123) - self.transformer_config = TransformerConfig(num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True) - self.parallel_transformer_block = TransformerBlock(self.transformer_config, - get_gpt_layer_with_transformer_engine_spec()) + self.transformer_config = TransformerConfig( + num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True + ) + self.parallel_transformer_block = TransformerBlock( + self.transformer_config, get_gpt_layer_with_transformer_engine_spec() + ) def teardown_method(self, method): Utils.destroy_model_parallel() @@ -51,7 +56,9 @@ def test_gpu_forward(self): attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - hidden_states = parallel_transformer_block(hidden_states=hidden_states, attention_mask=attention_mask) + hidden_states = parallel_transformer_block( + hidden_states=hidden_states, attention_mask=attention_mask + ) assert hidden_states.shape[0] == sequence_length assert hidden_states.shape[1] == micro_batch_size assert hidden_states.shape[2] == config.hidden_size @@ -75,8 +82,9 @@ def _run_full_checkpoint_test(self, fp8): config.recompute_method = 'block' config.fp8 = fp8 config.recompute_num_layers = config.num_layers - full_transformer_block = TransformerBlock(config, - get_gpt_layer_with_transformer_engine_spec()) + full_transformer_block = TransformerBlock( + config, get_gpt_layer_with_transformer_engine_spec() + ) assert full_transformer_block.config.recompute_granularity == 'full' assert full_transformer_block.config.recompute_method == 'block' assert full_transformer_block.config.fp8 == fp8 @@ -91,7 +99,9 @@ def _run_full_checkpoint_test(self, fp8): attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - hidden_states = full_transformer_block(hidden_states=hidden_states, attention_mask=attention_mask) + hidden_states = full_transformer_block( + hidden_states=hidden_states, attention_mask=attention_mask + ) assert hidden_states.shape[0] == sequence_length assert hidden_states.shape[1] == micro_batch_size assert hidden_states.shape[2] == config.hidden_size @@ -101,8 +111,9 @@ def _run_selective_checkpoint_test(self, fp8): config = transformer_config config.recompute_granularity = 'selective' config.fp8 = fp8 - selective_transformer_block = TransformerBlock(config, - get_gpt_layer_with_transformer_engine_spec()) + selective_transformer_block = TransformerBlock( + config, get_gpt_layer_with_transformer_engine_spec() 
+ ) assert selective_transformer_block.config.recompute_granularity == 'selective' assert selective_transformer_block.checkpoint_core_attention assert selective_transformer_block.config.fp8 == fp8 @@ -117,7 +128,9 @@ def _run_selective_checkpoint_test(self, fp8): attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - hidden_states = selective_transformer_block(hidden_states=hidden_states, attention_mask=attention_mask) + hidden_states = selective_transformer_block( + hidden_states=hidden_states, attention_mask=attention_mask + ) assert hidden_states.shape[0] == sequence_length assert hidden_states.shape[1] == micro_batch_size assert hidden_states.shape[2] == config.hidden_size diff --git a/tests/unit_tests/transformer/test_transformer_layer.py b/tests/unit_tests/transformer/test_transformer_layer.py index 31792dbe5c..ad8d3ea0f2 100644 --- a/tests/unit_tests/transformer/test_transformer_layer.py +++ b/tests/unit_tests/transformer/test_transformer_layer.py @@ -2,26 +2,28 @@ import pytest - import torch from megatron.core import parallel_state from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensor -from megatron.core.transformer.transformer_layer import TransformerLayer +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.transformer.transformer_layer import TransformerLayer from tests.unit_tests.test_utilities import Utils class TestParallelTransformerLayer: def setup_method(self, method): - Utils.initialize_model_parallel(1,1) + Utils.initialize_model_parallel(1, 1) model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True) - self.parallel_transformer_layer = TransformerLayer(transformer_config, - get_gpt_layer_with_transformer_engine_spec().submodules) + transformer_config = TransformerConfig( + num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True + ) + self.parallel_transformer_layer = TransformerLayer( + transformer_config, get_gpt_layer_with_transformer_engine_spec().submodules + ) def teardown_method(self, method): Utils.destroy_model_parallel() @@ -47,7 +49,9 @@ def test_gpu_forward(self): attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() - hidden_states, context = parallel_transformer_layer(hidden_states=hidden_states, attention_mask=attention_mask) + hidden_states, context = parallel_transformer_layer( + hidden_states=hidden_states, attention_mask=attention_mask + ) assert hidden_states.shape[0] == sequence_length assert hidden_states.shape[1] == micro_batch_size assert hidden_states.shape[2] == config.hidden_size @@ -59,14 +63,19 @@ def test_sharded_state_dict(self, tp_pp, order): Utils.initialize_model_parallel(*tp_pp, order=order) model_parallel_cuda_manual_seed(123) - transformer_config = TransformerConfig(num_layers=2, hidden_size=128, num_attention_heads=8, use_cpu_initialization=True) - parallel_transformer_layer = TransformerLayer(transformer_config, - get_gpt_layer_with_transformer_engine_spec().submodules) + transformer_config = TransformerConfig( + num_layers=2, hidden_size=128, num_attention_heads=8, use_cpu_initialization=True + ) + 
parallel_transformer_layer = TransformerLayer( + transformer_config, get_gpt_layer_with_transformer_engine_spec().submodules + ) sharded_state_dict = parallel_transformer_layer.sharded_state_dict() extra_states = {k: v for k, v in sharded_state_dict.items() if k.endswith('extra_state')} - sharded_tensors = {k: v for k, v in sharded_state_dict.items() if not k.endswith('extra_state')} + sharded_tensors = { + k: v for k, v in sharded_state_dict.items() if not k.endswith('extra_state') + } assert all(isinstance(t, ShardedObject) for t in extra_states.values()) assert all(isinstance(t, ShardedTensor) for t in sharded_tensors.values()) diff --git a/tools/autoformat.sh b/tools/autoformat.sh index 784a7846e2..bb5473bcfa 100755 --- a/tools/autoformat.sh +++ b/tools/autoformat.sh @@ -3,7 +3,7 @@ set -euox pipefail SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) CHECK_ONLY=${CHECK_ONLY:-false} -CHANGED_FILES=$(git diff --name-only --diff-filter=d --merge-base origin/main megatron/core | grep '\.py$' || true) +CHANGED_FILES=$(git diff --name-only --diff-filter=d --merge-base origin/main megatron/core tests/ | grep '\.py$' || true) ADDITIONAL_ARGS="" ADDITIONAL_BLACK_ARGS="" @@ -12,9 +12,8 @@ if [[ $CHECK_ONLY == true ]]; then ADDITIONAL_BLACK_ARGS="--diff" fi -# for now we just format core if [[ -n "$CHANGED_FILES" ]]; then - black $ADDITIONAL_ARGS $ADDITIONAL_BLACK_ARGS --verbose $CHANGED_FILES + black --skip-magic-trailing-comma $ADDITIONAL_ARGS $ADDITIONAL_BLACK_ARGS --verbose $CHANGED_FILES isort $ADDITIONAL_ARGS $CHANGED_FILES else echo Changeset is empty, all good.
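A note on the tooling change in `tools/autoformat.sh` above, since it explains most of the churn in these test hunks. `black --skip-magic-trailing-comma` makes black ignore a trailing comma inside brackets, so any call or literal that fits within the configured line length (100 columns, judging by the reflowed hunks) is collapsed back onto a single line; that is why the `parametrize` argument lists and calls such as `RankGenerator(..., order="tp-cp-ep-dp-pp")` lose their trailing commas and fold up. The import reshuffling is isort's default grouping (standard library, then third-party, then first-party repository modules), now applied to `tests/` as well because the script's `CHANGED_FILES` path list was widened. Below is a minimal sketch of the trailing-comma behaviour; `run_test` and its arguments are invented for illustration and are not part of the patch:

    def run_test(tp_size: int, ep_size: int) -> None:
        # Hypothetical helper, defined only so the example is runnable.
        print(f"tp={tp_size}, ep={ep_size}")

    # Written with a "magic" trailing comma, default black keeps the call exploded:
    #
    #     run_test(
    #         tp_size=1,
    #         ep_size=8,
    #     )
    #
    # With --skip-magic-trailing-comma, black ignores the trailing comma and
    # collapses the call onto one line because it fits within the line length:
    run_test(tp_size=1, ep_size=8)

To reproduce the check locally, running the script with `CHECK_ONLY=true` (for example `CHECK_ONLY=true tools/autoformat.sh`, assuming black and isort are installed) switches black to `--diff` mode per the hunk above, so it reports what would change instead of rewriting files in place.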