From 703cc88a87a3ccfa84e14ae575932cbfa3fa9b7c Mon Sep 17 00:00:00 2001
From: Oliver Koenig <okoenig@nvidia.com>
Date: Thu, 8 Aug 2024 13:25:44 -0700
Subject: [PATCH] ADLR/megatron-lm!1788 - chore: Reformat all documents

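Adds a .flake8 config (max line length 100, E203 ignored, F401 ignored in
__init__.py files) and a .pylintrc that skips tests/ and enables only the
missing-docstring checks (C0115/C0116), pins isort and installs flake8 and
pylint alongside black in Dockerfile.linting, and reformats megatron/core
and the unit tests accordingly (collapsing argument lists that now fit
within the line limit, dropping trailing commas, tightening docstrings,
and using 2**32-style exponents).

For reference, a local run roughly equivalent to the linting image is
sketched below. The in-repo entry point is tools/autoformat.sh (also
updated by this patch), so treat these commands and target paths as
illustrative only, using the pinned tool versions:

    pip install black==24.4.2 isort==5.13.2 flake8==7.1.0 pylint==3.2.6
    black megatron/core    # formatting; black picks up its config from pyproject.toml
    isort megatron/core    # import ordering
    flake8 megatron/core   # style checks, configured via .flake8
    pylint megatron/core   # docstring checks only, per .pylintrc
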
---
 .flake8                                       |   4 +
 .pylintrc                                     |   7 +
 Dockerfile.linting                            |   4 +-
 megatron/core/datasets/bert_dataset.py        |  19 +-
 megatron/core/datasets/blended_dataset.py     |   9 +-
 .../blended_megatron_dataset_builder.py       |  21 +-
 megatron/core/datasets/gpt_dataset.py         |   8 +-
 megatron/core/datasets/indexed_dataset.py     |  11 +-
 megatron/core/datasets/masked_dataset.py      |  10 +-
 megatron/core/datasets/retro/db/build.py      |  48 +--
 megatron/core/datasets/retro/db/dataset.py    |   7 +-
 megatron/core/datasets/retro/db/utils.py      |   6 +-
 megatron/core/datasets/retro/external_libs.py |   6 +-
 megatron/core/datasets/retro/index/build.py   |   2 +-
 megatron/core/datasets/retro/index/factory.py |   2 +-
 megatron/core/datasets/retro/index/index.py   |   5 +-
 .../retro/index/indexes/faiss_base.py         |   2 +-
 .../retro/index/indexes/faiss_par_add.py      |   8 +-
 megatron/core/datasets/retro/index/utils.py   |   2 +-
 .../core/datasets/retro/index/validate.py     |   4 +-
 .../datasets/retro/query/gpt_chunk_dataset.py |  25 +-
 megatron/core/datasets/retro/query/query.py   |  25 +-
 .../datasets/retro/query/retro_dataset.py     |  12 +-
 megatron/core/datasets/retro/query/utils.py   |   2 +-
 megatron/core/datasets/retro/utils.py         |   8 +-
 megatron/core/datasets/t5_dataset.py          |  15 +-
 megatron/core/datasets/utils.py               |   5 +-
 megatron/core/dist_checkpointing/core.py      |  10 +-
 .../core/dist_checkpointing/serialization.py  |   3 +-
 .../strategies/async_utils.py                 |  11 +-
 .../strategies/filesystem_async.py            |   6 +-
 .../strategies/fully_parallel.py              |  30 +-
 .../strategies/state_dict_saver.py            |   4 +-
 .../strategies/tensorstore.py                 |   5 +-
 .../dist_checkpointing/strategies/torch.py    |  28 +-
 .../strategies/two_stage.py                   |   7 +-
 megatron/core/dist_checkpointing/utils.py     |   8 +-
 .../core/dist_checkpointing/validation.py     |   8 +-
 .../distributed/distributed_data_parallel.py  |   4 +-
 .../core/distributed/finalize_model_grads.py  |   6 +-
 .../core/distributed/param_and_grad_buffer.py |   6 +-
 megatron/core/fusions/fused_bias_dropout.py   |   4 +-
 megatron/core/fusions/fused_cross_entropy.py  |  32 +-
 .../modelopt_support/gpt/model_specs.py       |   3 +-
 .../modelopt_support/gpt/state_dict_hooks.py  |  16 +-
 megatron/core/inference/scheduler.py          |   7 +-
 .../simple_text_generation_controller.py      |  23 +-
 megatron/core/models/T5/t5_model.py           |  10 +-
 megatron/core/models/T5/t5_spec.py            |  12 +-
 megatron/core/models/bert/bert_layer_specs.py |   8 +-
 megatron/core/models/bert/bert_lm_head.py     |  10 +-
 megatron/core/models/bert/bert_model.py       |   5 +-
 .../common/embeddings/rotary_pos_embedding.py |   5 +-
 .../core/models/mamba/mamba_layer_specs.py    |   6 +-
 megatron/core/models/retro/base_attention.py  |   1 -
 megatron/core/models/retro/config.py          |   2 +-
 .../core/models/retro/decoder_attention.py    |  14 +-
 megatron/core/models/retro/decoder_spec.py    |   8 +-
 .../core/models/retro/encoder_attention.py    |  15 +-
 megatron/core/models/retro/encoder_spec.py    |  34 +-
 megatron/core/models/retro/model.py           |   3 +-
 .../models/vision/multimodal_projector.py     |   4 +-
 .../core/models/vision/vit_layer_specs.py     |   4 +-
 megatron/core/optimizer/__init__.py           |  13 +-
 megatron/core/optimizer/distrib_optimizer.py  |  69 +---
 megatron/core/optimizer/optimizer.py          |  27 +-
 megatron/core/parallel_state.py               |  13 +-
 .../pipeline_parallel/p2p_communication.py    |  52 +--
 megatron/core/pipeline_parallel/schedules.py  |  70 ++--
 megatron/core/ssm/mamba_block.py              |   7 +-
 .../core/tensor_parallel/cross_entropy.py     |  24 +-
 megatron/core/tensor_parallel/data.py         |   7 +-
 megatron/core/tensor_parallel/layers.py       |  26 +-
 megatron/core/tensor_parallel/mappings.py     |   6 +-
 megatron/core/tensor_parallel/utils.py        |  50 +--
 megatron/core/timers.py                       |  10 +-
 megatron/core/transformer/attention.py        |  33 +-
 .../custom_layers/transformer_engine.py       |  26 +-
 .../core/transformer/dot_product_attention.py |  16 +-
 megatron/core/transformer/moe/experts.py      |   4 +-
 megatron/core/transformer/moe/moe_utils.py    |  12 +-
 megatron/core/transformer/moe/router.py       |  15 +-
 .../core/transformer/moe/token_dispatcher.py  |  47 +--
 .../core/transformer/transformer_block.py     |  54 +--
 .../core/transformer/transformer_layer.py     |   8 +-
 megatron/core/transformer/utils.py            |  10 +-
 megatron/core/utils.py                        |  15 +-
 pyproject.toml                                |   2 +-
 .../python_test_utils/common.py               |   5 +-
 .../get_test_results_from_tensorboard_logs.py |   7 +-
 .../test_resume_checkpoint_pipeline.py        |   4 +-
 tests/unit_tests/__init__.py                  |   3 +-
 tests/unit_tests/conftest.py                  |   5 +-
 tests/unit_tests/data/test_builder.py         |  24 +-
 tests/unit_tests/data/test_gpt_dataset.py     |   2 +-
 .../data/test_multimodal_dataset.py           |   2 +-
 tests/unit_tests/data/test_preprocess_data.py |  10 +-
 .../unit_tests/data/test_preprocess_mmdata.py |   4 +-
 .../unit_tests/dist_checkpointing/__init__.py |  18 +-
 .../unit_tests/dist_checkpointing/conftest.py |   1 -
 .../dist_checkpointing/models/common.py       | 136 +++++--
 .../models/test_bert_model.py                 | 125 +++++--
 .../models/test_gpt_model.py                  | 105 ++++--
 .../models/test_grouped_mlp.py                | 161 ++++----
 .../dist_checkpointing/models/test_mlp_glu.py |  49 ++-
 .../models/test_retro_model.py                |  30 +-
 .../models/test_sequential_mlp.py             |  13 +-
 .../models/test_t5_model.py                   |  39 +-
 .../dist_checkpointing/test_async_save.py     |   5 +-
 .../test_cached_metadata.py                   |   5 +-
 .../test_flattened_resharding.py              |  99 ++---
 .../dist_checkpointing/test_fully_parallel.py | 173 ++++++---
 .../dist_checkpointing/test_mapping.py        |  56 +--
 .../dist_checkpointing/test_nonpersistent.py  |  30 +-
 .../dist_checkpointing/test_optimizer.py      | 270 +++++++++-----
 .../dist_checkpointing/test_serialization.py  | 348 ++++++++++++------
 tests/unit_tests/dist_checkpointing/utils.py  |  29 +-
 .../distributed/test_param_and_grad_buffer.py |   3 +-
 .../unit_tests/fusions/test_torch_softmax.py  |   6 +-
 .../inference/engines/test_mcore_engine.py    |  80 ++--
 .../gpt/test_gpt_inference_wrapper.py         | 105 ++++--
 .../test_model_inference_wrapper_config.py    |  12 +-
 .../inference/test_common_inference_params.py |   5 +-
 .../inference/test_inference_utils.py         |   1 +
 .../inference/test_modelopt_gpt_model.py      |   6 +-
 tests/unit_tests/inference/test_scheduler.py  |  66 +++-
 .../test_simple_text_generation_controller.py | 156 +++++---
 .../unit_tests/models/test_base_embedding.py  |  27 +-
 tests/unit_tests/models/test_bert_model.py    | 137 ++++---
 .../unit_tests/models/test_clip_vit_model.py  |   5 +-
 tests/unit_tests/models/test_llava_model.py   |   4 +-
 tests/unit_tests/models/test_mamba_model.py   |   4 +-
 .../models/test_multimodal_projector.py       |  37 +-
 tests/unit_tests/models/test_t5_model.py      | 109 ++++--
 .../pipeline_parallel/test_schedules.py       | 209 +++++++----
 .../tensor_parallel/test_cross_entropy.py     |  38 +-
 tests/unit_tests/tensor_parallel/test_data.py |  32 +-
 .../tensor_parallel/test_initialization.py    |  95 +++--
 .../tensor_parallel/test_mappings.py          | 168 ++++-----
 .../unit_tests/tensor_parallel/test_random.py |  50 ++-
 .../test_tensor_parallel_utils.py             |  46 ++-
 tests/unit_tests/test_basic.py                |   1 -
 tests/unit_tests/test_imports.py              |  30 +-
 .../unit_tests/test_local_multi_tensor_fns.py |  56 ++-
 tests/unit_tests/test_optimizer.py            |   4 +-
 tests/unit_tests/test_parallel_state.py       | 132 ++++---
 tests/unit_tests/test_training.py             |  13 +-
 tests/unit_tests/test_utilities.py            |   5 +-
 tests/unit_tests/test_utils.py                |  42 ++-
 .../moe/test_a2a_token_dispatcher.py          |  24 +-
 .../transformer/moe/test_aux_loss.py          |  32 +-
 .../transformer/moe/test_grouped_mlp.py       |  99 +++--
 .../transformer/moe/test_routers.py           |  25 +-
 .../transformer/moe/test_sequential_mlp.py    |  21 +-
 .../transformer/moe/test_token_dispatcher.py  |  42 +--
 .../unit_tests/transformer/test_attention.py  |  50 ++-
 .../transformer/test_attention_packed_seq.py  |   9 +-
 .../transformer/test_core_attention.py        |   4 +-
 tests/unit_tests/transformer/test_mlp.py      |  18 +-
 tests/unit_tests/transformer/test_module.py   |  26 +-
 .../transformer/test_retro_attention.py       |  98 +++--
 .../transformer/test_spec_customization.py    |   4 +-
 .../transformer/test_transformer_block.py     |  43 ++-
 .../transformer/test_transformer_layer.py     |  33 +-
 tools/autoformat.sh                           |   5 +-
 165 files changed, 2878 insertions(+), 2352 deletions(-)
 create mode 100644 .flake8
 create mode 100644 .pylintrc

diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000000..261f59bc24
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,4 @@
+[flake8]
+max-line-length = 100
+extend-ignore = E203
+per-file-ignores = __init__.py:F401
\ No newline at end of file
diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 0000000000..5e550f1703
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,7 @@
+[MASTER]
+ignore=tests
+
+[MESSAGES CONTROL]
+disable=all
+
+enable=C0115,C0116
\ No newline at end of file
diff --git a/Dockerfile.linting b/Dockerfile.linting
index 910df314f8..b0670af9d1 100644
--- a/Dockerfile.linting
+++ b/Dockerfile.linting
@@ -10,7 +10,9 @@ RUN sed -i -e 's/^APT/# APT/' -e 's/^DPkg/# DPkg/' \
 
 RUN pip3 install --no-cache-dir \
       black==24.4.2 \
-      isort
+      isort==5.13.2 \
+      flake8==7.1.0 \
+      pylint==3.2.6
 
 COPY . /opt/megatron-lm
 
diff --git a/megatron/core/datasets/bert_dataset.py b/megatron/core/datasets/bert_dataset.py
index 657cc6a78a..78ae2edf62 100644
--- a/megatron/core/datasets/bert_dataset.py
+++ b/megatron/core/datasets/bert_dataset.py
@@ -21,8 +21,7 @@ class BERTMaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
     """Option to perform the next sequence prediction during sampling"""
 
     def __post_init__(self) -> None:
-        """Do asserts and set fields post init
-        """
+        """Do asserts and set fields post init"""
         super().__post_init__()
 
         assert self.classification_head is not None
@@ -73,22 +72,20 @@ def _key_config_attributes() -> List[str]:
         """
         return super(
             BERTMaskedWordPieceDataset, BERTMaskedWordPieceDataset
-        )._key_config_attributes() + ["classification_head",]
+        )._key_config_attributes() + ["classification_head"]
 
     def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
         """Abstract method implementation
- 
+
         Args:
             idx (int): The index into the dataset
 
         Returns:
-            Dict[str, Union[int, numpy.ndarray]]: The 
+            Dict[str, Union[int, numpy.ndarray]]: The
         """
         idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
         sample = [self.dataset[i] for i in range(idx_beg, idx_end)]
-        numpy_random_state = numpy.random.RandomState(
-            seed=(self.config.random_seed + idx) % 2 ** 32
-        )
+        numpy_random_state = numpy.random.RandomState(seed=(self.config.random_seed + idx) % 2**32)
 
         assert target_sequence_length <= self.config.sequence_length
 
@@ -127,11 +124,7 @@ def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
             truncated = True
 
         # Merge the subsegments and create the token assignment labels
-        tokens = [
-            self.config.tokenizer.cls,
-            *split_A,
-            self.config.tokenizer.sep,
-        ]
+        tokens = [self.config.tokenizer.cls, *split_A, self.config.tokenizer.sep]
         assignments = [0 for _ in range(1 + len(split_A) + 1)]
         if split_B:
             tokens += [*split_B, self.config.tokenizer.sep]
diff --git a/megatron/core/datasets/blended_dataset.py b/megatron/core/datasets/blended_dataset.py
index f7883d9b14..be0b7a4a08 100644
--- a/megatron/core/datasets/blended_dataset.py
+++ b/megatron/core/datasets/blended_dataset.py
@@ -93,10 +93,7 @@ def __len__(self) -> int:
     def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
         dataset_id = self.dataset_index[idx]
         dataset_sample_id = self.dataset_sample_index[idx]
-        return {
-            "dataset_id": dataset_id,
-            **self.datasets[dataset_id][dataset_sample_id],
-        }
+        return {"dataset_id": dataset_id, **self.datasets[dataset_id][dataset_sample_id]}
 
     def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]:
         """Build and optionally cache the dataset index and the dataset sample index
@@ -129,9 +126,7 @@ def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]:
 
         if not path_to_cache or (not cache_hit and torch.distributed.get_rank() == 0):
             log_single_rank(
-                logger,
-                logging.INFO,
-                f"Build and save the {type(self).__name__} indices",
+                logger, logging.INFO, f"Build and save the {type(self).__name__} indices"
             )
             self.built_anew_on_cache_miss = True
 
diff --git a/megatron/core/datasets/blended_megatron_dataset_builder.py b/megatron/core/datasets/blended_megatron_dataset_builder.py
index 0230faf5e0..c9cf4abf63 100644
--- a/megatron/core/datasets/blended_megatron_dataset_builder.py
+++ b/megatron/core/datasets/blended_megatron_dataset_builder.py
@@ -156,9 +156,7 @@ def build(self) -> List[Optional[TopLevelDataset]]:
 
         return datasets
 
-    def _build_blended_dataset_splits(
-        self,
-    ) -> List[Optional[TopLevelDataset]]:
+    def _build_blended_dataset_splits(self) -> List[Optional[TopLevelDataset]]:
         """Build all dataset splits according to the provided blend(s)
 
         See the BlendedMegatronDatasetBuilder.build alias for more information.
@@ -306,10 +304,7 @@ def _build_blended_dataset_splits(
             return blended_datasets
 
     def _build_megatron_datasets_parallel(
-        self,
-        prefixes: List[str],
-        split: List[float],
-        sizes_per_dataset: List[List[int]],
+        self, prefixes: List[str], split: List[float], sizes_per_dataset: List[List[int]]
     ) -> List[List[Optional[MegatronDataset]]]:
         """Build the megatron datasets for a list of prefixes in parallel
 
@@ -369,11 +364,7 @@ def _threading_helper(
                     # i.e. meant for serial build, do not scale up.
                     num_workers *= min(2, max(1, torch.cuda.device_count()))
                 _threading_helper(
-                    megatron_datasets,
-                    num_workers,
-                    prefixes,
-                    split,
-                    sizes_per_dataset,
+                    megatron_datasets, num_workers, prefixes, split, sizes_per_dataset
                 )
 
             torch.distributed.barrier()
@@ -389,11 +380,7 @@ def _threading_helper(
                 )
         else:
             _threading_helper(
-                megatron_datasets,
-                num_dataset_builder_threads,
-                prefixes,
-                split,
-                sizes_per_dataset,
+                megatron_datasets, num_dataset_builder_threads, prefixes, split, sizes_per_dataset
             )
 
         return megatron_datasets
diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py
index c5b2bbe7b4..115727de92 100644
--- a/megatron/core/datasets/gpt_dataset.py
+++ b/megatron/core/datasets/gpt_dataset.py
@@ -108,11 +108,9 @@ def __init__(
         except Exception:
             self._pad_token_id = _PAD_TOKEN_ID
 
-        (
-            self.document_index,
-            self.sample_index,
-            self.shuffle_index,
-        ) = self._build_document_sample_shuffle_indices()
+        (self.document_index, self.sample_index, self.shuffle_index) = (
+            self._build_document_sample_shuffle_indices()
+        )
 
     @staticmethod
     def numel_low_level_dataset(low_level_dataset: IndexedDataset) -> int:
diff --git a/megatron/core/datasets/indexed_dataset.py b/megatron/core/datasets/indexed_dataset.py
index ae05bcbc6a..29975336f1 100644
--- a/megatron/core/datasets/indexed_dataset.py
+++ b/megatron/core/datasets/indexed_dataset.py
@@ -385,12 +385,7 @@ def read(self, dtype: Type[numpy.number], count: int, offset: int) -> numpy.ndar
         Returns:
             numpy.ndarray: An array with `count` items and data-type `dtype` constructed from reading bytes from the data file starting at `offset`.
         """
-        return numpy.frombuffer(
-            self._bin_buffer,
-            dtype=dtype,
-            count=count,
-            offset=offset,
-        )
+        return numpy.frombuffer(self._bin_buffer, dtype=dtype, count=count, offset=offset)
 
     def __del__(self) -> None:
         """Clean up the object."""
@@ -633,9 +628,7 @@ def __getitem__(
         if isinstance(idx, (int, numpy.integer)):
             sequence_pointer, sequence_length, sequence_mode = self.index[idx]
             sequence = self.bin_reader.read(
-                dtype=self.index.dtype,
-                count=sequence_length,
-                offset=sequence_pointer,
+                dtype=self.index.dtype, count=sequence_length, offset=sequence_pointer
             )
             return (sequence, sequence_mode) if sequence_mode is not None else sequence
         elif isinstance(idx, slice):
diff --git a/megatron/core/datasets/masked_dataset.py b/megatron/core/datasets/masked_dataset.py
index 081d58525b..9db6c67eb1 100644
--- a/megatron/core/datasets/masked_dataset.py
+++ b/megatron/core/datasets/masked_dataset.py
@@ -154,15 +154,7 @@ def _build_sample_index(
         )
         path_to_description = get_path_to("description.txt")
         path_to_sample_index = get_path_to("sample_index.npy")
-        cache_hit = all(
-            map(
-                os.path.isfile,
-                [
-                    path_to_description,
-                    path_to_sample_index,
-                ],
-            )
-        )
+        cache_hit = all(map(os.path.isfile, [path_to_description, path_to_sample_index]))
 
         if self.num_samples is not None:
             num_epochs = numpy.iinfo(numpy.int32).max - 1
diff --git a/megatron/core/datasets/retro/db/build.py b/megatron/core/datasets/retro/db/build.py
index 780cc9e503..44b9038230 100644
--- a/megatron/core/datasets/retro/db/build.py
+++ b/megatron/core/datasets/retro/db/build.py
@@ -95,23 +95,13 @@ def build_partial_db(
     if proc_id in progress_proc_ids:
         log_retro_rank_0(
             " > building partial chunk db, proc %d / %d, docs %d:%d / %d."
-            % (
-                proc_id,
-                n_procs,
-                doc_start_id,
-                doc_end_id,
-                n_docs,
-            )
+            % (proc_id, n_procs, doc_start_id, doc_end_id, n_docs)
         )
 
     # Progress bars (snapshot of overall progress).
     doc_id_iter = range(doc_start_id, doc_end_id)
     pbar = (
-        tqdm(
-            doc_id_iter,
-            "parse doc chunks",
-            miniters=len(doc_id_iter) // 20,
-        )
+        tqdm(doc_id_iter, "parse doc chunks", miniters=len(doc_id_iter) // 20)
         if proc_id in progress_proc_ids
         else doc_id_iter
     )
@@ -156,9 +146,7 @@ def build_partial_db(
             # Re-tokenize.
             chunk_end_idx = chunk_end_idxs[i]
             gpt_token_ids = indexed_dataset.get(
-                idx=doc_id,
-                offset=chunk_start_idx,
-                length=chunk_end_idx - chunk_start_idx,
+                idx=doc_id, offset=chunk_start_idx, length=chunk_end_idx - chunk_start_idx
             )
             text = config.gpt_detokenize(gpt_token_ids.tolist())
             bert_token_ids = config.bert_tokenize(text)
@@ -169,14 +157,7 @@ def build_partial_db(
             else:
                 _chunk_db = chunk_db_valid
                 doc_size_map[doc_id] += 1
-            _chunk_db.append(
-                (
-                    doc_id,
-                    chunk_start_idx,
-                    chunk_end_idx,
-                    len(bert_token_ids),
-                )
-            )
+            _chunk_db.append((doc_id, chunk_start_idx, chunk_end_idx, len(bert_token_ids)))
 
     return proc_id, chunk_db_valid, chunk_db_invalid, doc_size_map
 
@@ -269,10 +250,7 @@ def build_block_db(
 
 
 def save_block_db(
-    block: dict,
-    chunk_db_valid: np.ndarray,
-    chunk_db_invalid: np.ndarray,
-    doc_offsets: np.ndarray,
+    block: dict, chunk_db_valid: np.ndarray, chunk_db_invalid: np.ndarray, doc_offsets: np.ndarray
 ) -> None:
     """Save block of chunked tokens to disk. These blocks are later used for
     training and adding to the vector index.
@@ -291,10 +269,7 @@ def save_block_db(
 
 
 def build_individual_db(
-    config: RetroPreprocessingConfig,
-    dataset_idx: int,
-    n_datasets: int,
-    dataset_info: dict,
+    config: RetroPreprocessingConfig, dataset_idx: int, n_datasets: int, dataset_info: dict
 ) -> None:
     """Process a single indexed dataset & extract chunks.
 
@@ -395,8 +370,7 @@ def build_individual_db(
 
 
 def build_individual_dbs(
-    config: RetroPreprocessingConfig,
-    indexed_dataset_infos: List[Dict],
+    config: RetroPreprocessingConfig, indexed_dataset_infos: List[Dict]
 ) -> None:
     """Iterate each indexed dataset & process its chunks.
 
@@ -412,11 +386,7 @@ def build_individual_dbs(
         # Progress.
         log_retro_rank_0(
             " > building individual db, dataset %d / %d ... '%s'."
-            % (
-                ds_idx,
-                len(indexed_dataset_infos),
-                ds_info["prefix"],
-            )
+            % (ds_idx, len(indexed_dataset_infos), ds_info["prefix"])
         )
 
         # Process single dataset.
@@ -562,7 +532,7 @@ def merge_dbs(project_dir: str, indexed_dataset_infos: List[Dict], db_type: str)
         for ds_idx, ds_info in enumerate(indexed_dataset_infos):
             log_retro_rank_0(
                 " > merging dbs; '%s', dataset %d / %d ... '%s'."
-                % (db_type, ds_idx, len(indexed_dataset_infos), ds_info["prefix"]),
+                % (db_type, ds_idx, len(indexed_dataset_infos), ds_info["prefix"])
             )
             individual_chunk_db: np.ndarray = get_individual_chunk_db(project_dir, ds_idx, ds_info)
             individual_doc_offsets: np.ndarray = (
diff --git a/megatron/core/datasets/retro/db/dataset.py b/megatron/core/datasets/retro/db/dataset.py
index 1de6e02b10..f9053622ab 100644
--- a/megatron/core/datasets/retro/db/dataset.py
+++ b/megatron/core/datasets/retro/db/dataset.py
@@ -17,7 +17,7 @@
 
 class DBDataset(torch.utils.data.Dataset):
     """Dataset for iterating chunks.
-    
+
     Args:
         db_path (str): Path of HDF5-format chunk database.
         indexed_datasets (List[IndexedDataset]): Indexed datasets used to build database.
@@ -85,10 +85,7 @@ def __getitem__(self, chunk_id: int) -> dict:
             token_ids = token_ids.tolist()
             token_ids += [self.eod_token_id] * (self.max_chunk_length - chunk_length)
 
-        return {
-            "doc_id": doc_id,
-            "text": np.array(token_ids, dtype=np.int64),
-        }
+        return {"doc_id": doc_id, "text": np.array(token_ids, dtype=np.int64)}
 
     def load_doc_tuples(self) -> None:
         """Load the dataset & document ids.
diff --git a/megatron/core/datasets/retro/db/utils.py b/megatron/core/datasets/retro/db/utils.py
index df13089840..e8578a09d5 100644
--- a/megatron/core/datasets/retro/db/utils.py
+++ b/megatron/core/datasets/retro/db/utils.py
@@ -22,7 +22,7 @@ def get_db_dir(project_dir: str) -> str:
 
     Args:
         project_dir (str): Path to Retro project dir.
-    
+
     Returns:
         Path of the DB sub-directory within the project.
     """
@@ -55,9 +55,7 @@ def init_indexed_dataset_infos(config: RetroPreprocessingConfig) -> List[Dict]:
         prefix = data_blend[i + 1]
         path = os.path.join(data_dir, prefix + ".bin")
         assert os.path.exists(path), "couldn't find '%s'." % path
-        infos.append(
-            {"ratio": ratio, "prefix": prefix,}
-        )
+        infos.append({"ratio": ratio, "prefix": prefix})
 
     # Load indexed datasets.
     load_indexed_datasets(config.retro_project_dir, infos)
diff --git a/megatron/core/datasets/retro/external_libs.py b/megatron/core/datasets/retro/external_libs.py
index 98b28728d4..c057eba25c 100644
--- a/megatron/core/datasets/retro/external_libs.py
+++ b/megatron/core/datasets/retro/external_libs.py
@@ -4,11 +4,7 @@
 
 import importlib
 
-required_libs = [
-    "faiss",
-    "h5py",
-    "transformers",  # for huggingface bert
-]
+required_libs = ["faiss", "h5py", "transformers"]  # for huggingface bert
 
 for lib in required_libs:
     try:
diff --git a/megatron/core/datasets/retro/index/build.py b/megatron/core/datasets/retro/index/build.py
index a5659e92db..1f310d89c3 100644
--- a/megatron/core/datasets/retro/index/build.py
+++ b/megatron/core/datasets/retro/index/build.py
@@ -41,7 +41,7 @@ def get_empty_index_path(config: RetroPreprocessingConfig) -> str:
 
     Args:
         config (RetroPreprocessingConfig): Retro preprocessing config.
-    
+
     Returns:
         Path to the empty (trained, but without added samples) vector index.
     """
diff --git a/megatron/core/datasets/retro/index/factory.py b/megatron/core/datasets/retro/index/factory.py
index 293d58c678..f88084ddb1 100644
--- a/megatron/core/datasets/retro/index/factory.py
+++ b/megatron/core/datasets/retro/index/factory.py
@@ -23,7 +23,7 @@ def get_index_class(cls, index_type: str) -> type:
         Returns:
             An `Index` sub-type corresponding to the `index_type`.
         """
-        return {"faiss-base": FaissBaseIndex, "faiss-par-add": FaissParallelAddIndex,}[index_type]
+        return {"faiss-base": FaissBaseIndex, "faiss-par-add": FaissParallelAddIndex}[index_type]
 
     @classmethod
     def get_index(cls, index_type: str) -> Index:
diff --git a/megatron/core/datasets/retro/index/index.py b/megatron/core/datasets/retro/index/index.py
index a8c086fb94..c6bd13fbee 100644
--- a/megatron/core/datasets/retro/index/index.py
+++ b/megatron/core/datasets/retro/index/index.py
@@ -27,7 +27,6 @@
 
 
 class Index(abc.ABC):
-
     """Abstract base class for indexes.
 
     *Note* : While currently only Faiss-based classes are implemented, in the
@@ -60,7 +59,7 @@ def get_empty_index_path(self, config: RetroPreprocessingConfig) -> str:
             File path to empty index (i.e., this index has had index.train() called, but not yet index.add()).
         """
         return os.path.join(
-            get_index_dir(config), "empty_%.3f.faissindex" % config.retro_index_train_load_fraction,
+            get_index_dir(config), "empty_%.3f.faissindex" % config.retro_index_train_load_fraction
         )
 
     def get_empty_index(self, config: RetroPreprocessingConfig) -> faiss.Index:
@@ -86,7 +85,7 @@ def get_added_index_path(self, config: RetroPreprocessingConfig) -> str:
         return os.path.join(
             get_index_dir(config),
             "added_%.3f_%.3f.faissindex"
-            % (config.retro_index_train_load_fraction, config.retro_index_add_load_fraction,),
+            % (config.retro_index_train_load_fraction, config.retro_index_add_load_fraction),
         )
 
     def get_added_index(self, config: RetroPreprocessingConfig) -> faiss.Index:
diff --git a/megatron/core/datasets/retro/index/indexes/faiss_base.py b/megatron/core/datasets/retro/index/indexes/faiss_base.py
index 1ffc72528c..c1daf3f533 100644
--- a/megatron/core/datasets/retro/index/indexes/faiss_base.py
+++ b/megatron/core/datasets/retro/index/indexes/faiss_base.py
@@ -52,7 +52,7 @@ def _train(self, config: RetroPreprocessingConfig) -> None:
 
         # Load data.
         merged_path = get_training_data_merged_path(config)
-        inp = np.memmap(merged_path, dtype="f4", mode="r",).reshape((-1, config.hidden_size))
+        inp = np.memmap(merged_path, dtype="f4", mode="r").reshape((-1, config.hidden_size))
 
         # Init index.
         index = faiss.index_factory(config.hidden_size, config.retro_index_str)
diff --git a/megatron/core/datasets/retro/index/indexes/faiss_par_add.py b/megatron/core/datasets/retro/index/indexes/faiss_par_add.py
index 6d9d68f821..e014217262 100644
--- a/megatron/core/datasets/retro/index/indexes/faiss_par_add.py
+++ b/megatron/core/datasets/retro/index/indexes/faiss_par_add.py
@@ -58,7 +58,7 @@ def encode_block(
         """
 
         # Embed block.
-        embeddings = self.embed_text_dataset_block(embedder, text_dataset, block["range"],)
+        embeddings = self.embed_text_dataset_block(embedder, text_dataset, block["range"])
 
         # Encode block.
         log_retro_rank_0("encode.")
@@ -108,7 +108,7 @@ def validate(f: h5py.File) -> None:
             assert len(f["data"].shape) == 2
 
         blocks = get_blocks_by_rank(
-            codes_dir, len(text_dataset), config.retro_block_size, validate=validate,
+            codes_dir, len(text_dataset), config.retro_block_size, validate=validate
         )
 
         # Encode each block.
@@ -119,7 +119,7 @@ def validate(f: h5py.File) -> None:
                 # Progress.
                 log_retro_rank_0(
                     "encode block %d / %d ... %s."
-                    % (block_index, len(blocks.missing), block["path"],)
+                    % (block_index, len(blocks.missing), block["path"])
                 )
 
                 # Encode and save.
@@ -156,7 +156,7 @@ def add_codes(self, config: RetroPreprocessingConfig) -> None:
         for code_path in pbar:
             pbar.set_description(
                 "add codes, mem %.3f gb, %.1f%%"
-                % (psutil.virtual_memory()[3] / 1024 ** 3, psutil.virtual_memory()[2],)
+                % (psutil.virtual_memory()[3] / 1024**3, psutil.virtual_memory()[2])
             )
             with h5py.File(code_path) as f:
 
diff --git a/megatron/core/datasets/retro/index/utils.py b/megatron/core/datasets/retro/index/utils.py
index 321cd659d8..58229439ae 100644
--- a/megatron/core/datasets/retro/index/utils.py
+++ b/megatron/core/datasets/retro/index/utils.py
@@ -22,7 +22,7 @@ def get_index_dir(config: RetroPreprocessingConfig) -> str:
 
     # Directory path.
     index_dir_path = os.path.join(
-        config.retro_project_dir, "index", config.retro_index_type, config.retro_index_str,
+        config.retro_project_dir, "index", config.retro_index_type, config.retro_index_str
     )
 
     # Make directory.
diff --git a/megatron/core/datasets/retro/index/validate.py b/megatron/core/datasets/retro/index/validate.py
index 6783df6492..57306707c4 100644
--- a/megatron/core/datasets/retro/index/validate.py
+++ b/megatron/core/datasets/retro/index/validate.py
@@ -74,7 +74,7 @@ def validate_training_embeddings(config: RetroPreprocessingConfig) -> None:
             # Progress. (*note*: move world progress to here.)
             log_retro_rank_0(
                 "embed training block %d / %d ... %s."
-                % (block_idx, len(blocks.existing), block["path"],)
+                % (block_idx, len(blocks.existing), block["path"])
             )
 
             # Load existing block embeddings.
@@ -147,7 +147,7 @@ def validate(f: h5py.File) -> None:
 
             # Progress.
             log_retro_rank_0(
-                "encode block %d / %d ... %s." % (block_idx, len(blocks.existing), block["path"],)
+                "encode block %d / %d ... %s." % (block_idx, len(blocks.existing), block["path"])
             )
 
             # Load existing codes.
diff --git a/megatron/core/datasets/retro/query/gpt_chunk_dataset.py b/megatron/core/datasets/retro/query/gpt_chunk_dataset.py
index 34a2ee6c87..6191a30a31 100644
--- a/megatron/core/datasets/retro/query/gpt_chunk_dataset.py
+++ b/megatron/core/datasets/retro/query/gpt_chunk_dataset.py
@@ -73,14 +73,11 @@ def __getitem__(self, idx: int) -> dict:
         chunk_token_ids = sample_token_ids[token_start_idx:token_end_idx]
 
         # Sample.
-        return {
-            "doc_ids": sample_doc_ids,
-            "text": chunk_token_ids,
-        }
+        return {"doc_ids": sample_doc_ids, "text": chunk_token_ids}
 
 
 def build_gpt_chunk_datasets_from_gpt_datasets(
-    project_dir: str, gpt_datasets: dict, sample_length: int, chunk_length: int,
+    project_dir: str, gpt_datasets: dict, sample_length: int, chunk_length: int
 ) -> dict:
     """Get train, valid, test GPT chunk datasets.
 
@@ -96,14 +93,16 @@ def build_gpt_chunk_datasets_from_gpt_datasets(
 
     # GPT chunk datasets.
     chunk_datasets = {
-        key: {
-            "dataset": GPTChunkDataset(sample_ds, sample_length, chunk_length),
-            "neighbor_dir": get_neighbor_dir(project_dir, key, sample_ds),
-            "num_active_chunks": num_active_samples
-            * get_num_chunks_per_sample(sample_length, chunk_length),
-        }
-        if sample_ds
-        else None
+        key: (
+            {
+                "dataset": GPTChunkDataset(sample_ds, sample_length, chunk_length),
+                "neighbor_dir": get_neighbor_dir(project_dir, key, sample_ds),
+                "num_active_chunks": num_active_samples
+                * get_num_chunks_per_sample(sample_length, chunk_length),
+            }
+            if sample_ds
+            else None
+        )
         for key, (sample_ds, num_active_samples) in gpt_datasets.items()
     }
 
diff --git a/megatron/core/datasets/retro/query/query.py b/megatron/core/datasets/retro/query/query.py
index 165792f9a0..9da3381712 100644
--- a/megatron/core/datasets/retro/query/query.py
+++ b/megatron/core/datasets/retro/query/query.py
@@ -39,7 +39,7 @@
 from .gpt_chunk_dataset import build_gpt_chunk_datasets_from_gpt_datasets
 
 
-def get_index(config: RetroPreprocessingConfig, ondisk: bool = False,) -> faiss.Index:
+def get_index(config: RetroPreprocessingConfig, ondisk: bool = False) -> faiss.Index:
     """Read index from disk.
 
     Args:
@@ -67,7 +67,7 @@ def get_index(config: RetroPreprocessingConfig, ondisk: bool = False,) -> faiss.
 
 
 def embed_block(
-    config: RetroPreprocessingConfig, gpt_dataset: GPTChunkDataset, block: dict,
+    config: RetroPreprocessingConfig, gpt_dataset: GPTChunkDataset, block: dict
 ) -> np.ndarray:
     """Embed block of chunks.
 
@@ -80,7 +80,7 @@ def embed_block(
         Embeddings array, with shape (len(block["range"]), dimension(embedder)).
     """
     text_block_dataset = torch.utils.data.Subset(
-        GPTToTextDataset(gpt_dataset, config.retro_tokenizers.gpt), range(*block["range"]),
+        GPTToTextDataset(gpt_dataset, config.retro_tokenizers.gpt), range(*block["range"])
     )
     return config.retro_bert_embedders.mem.embed_text_dataset(text_block_dataset)
 
@@ -248,17 +248,14 @@ def query_block_neighbors(
     sample_map = {}
     for i in sample_ids:
         sample = query_dataset.sample_dataset[i]
-        sample_map[i] = {
-            "dataset_idx": sample["dataset_id"],
-            "doc_ids": sample["document_ids"],
-        }
+        sample_map[i] = {"dataset_idx": sample["dataset_id"], "doc_ids": sample["document_ids"]}
 
     # Embed block.
     embeddings = embed_block(config, query_dataset, block)
 
     # Query embeddings.
     _, filtered_neighbor_ids = query_embedding_block(
-        config, db_dataset, index, embeddings, block["range"], sample_map, n_chunks_per_sample,
+        config, db_dataset, index, embeddings, block["range"], sample_map, n_chunks_per_sample
     )
 
     if config.retro_task_validate is None:
@@ -303,15 +300,17 @@ def validate(f: h5py.File) -> None:
         Args:
             f (h5py.File): File containing save neighbor IDs.
         """
-        assert f["neighbors"].shape[1] == config.retro_query_num_neighbors_save, (
-            "neighbors.shape == %s; num_neighbors_target == %d."
-            % (str(f["neighbors"].shape), config.retro_num_neighbors_target,)
+        assert (
+            f["neighbors"].shape[1] == config.retro_query_num_neighbors_save
+        ), "neighbors.shape == %s; num_neighbors_target == %d." % (
+            str(f["neighbors"].shape),
+            config.retro_num_neighbors_target,
         )
 
     if config.retro_task_validate is None:
         retro_makedir(config, neighbor_dir)
         blocks = get_blocks_by_rank(
-            neighbor_dir, num_active_chunks, config.retro_block_size, validate=validate,
+            neighbor_dir, num_active_chunks, config.retro_block_size, validate=validate
         )
         active_blocks = blocks.missing
     else:
@@ -339,7 +338,7 @@ def validate(f: h5py.File) -> None:
                     block_index,
                     len(active_blocks),
                     os.path.basename(block["path"]),
-                    psutil.virtual_memory()[3] / 1024 ** 3,
+                    psutil.virtual_memory()[3] / 1024**3,
                     psutil.virtual_memory()[2],
                 )
             )
diff --git a/megatron/core/datasets/retro/query/retro_dataset.py b/megatron/core/datasets/retro/query/retro_dataset.py
index 07af161693..6c3b9ae60c 100644
--- a/megatron/core/datasets/retro/query/retro_dataset.py
+++ b/megatron/core/datasets/retro/query/retro_dataset.py
@@ -94,7 +94,7 @@ def __getitem__(self, sample_idx: int) -> dict:
 
         # Sample idx to chunk idxs.
         chunk_idxs = list(
-            range(sample_idx * n_chunks_per_sample, (sample_idx + 1) * n_chunks_per_sample,)
+            range(sample_idx * n_chunks_per_sample, (sample_idx + 1) * n_chunks_per_sample)
         )
 
         # Collect retrieved tokens.
@@ -144,7 +144,7 @@ def __getitem__(self, sample_idx: int) -> dict:
 
 
 def get_retro_datasets(
-    config: RetroConfig, gpt_datasets: dict, sample_length: int, eod_token_id: int,
+    config: RetroConfig, gpt_datasets: dict, sample_length: int, eod_token_id: int
 ) -> Tuple[Optional[RetroDataset], Optional[RetroDataset], Optional[RetroDataset]]:
     """Get train, valid, test retro datasets.
 
@@ -190,7 +190,7 @@ def get_retro_datasets(
         # preprocessing and pretraining.
         chunk_dataset = chunk_ds_info["dataset"]
         chunk_ds_info["neighbor_dir"] = os.path.join(
-            query_dir, config.retro_neighbor_dirs[data_key],
+            query_dir, config.retro_neighbor_dirs[data_key]
         )
         neighbor_dir = chunk_ds_info["neighbor_dir"]
         neighbor_path_map = BlockPathMap.from_dir(
@@ -235,8 +235,4 @@ def get_retro_datasets(
             neighbor_path_map=neighbor_path_map,
         )
 
-    return (
-        retro_dataset_map["train"],
-        retro_dataset_map["valid"],
-        retro_dataset_map["test"],
-    )
+    return (retro_dataset_map["train"], retro_dataset_map["valid"], retro_dataset_map["test"])
diff --git a/megatron/core/datasets/retro/query/utils.py b/megatron/core/datasets/retro/query/utils.py
index f07920d48c..b4e0c67009 100644
--- a/megatron/core/datasets/retro/query/utils.py
+++ b/megatron/core/datasets/retro/query/utils.py
@@ -31,5 +31,5 @@ def get_neighbor_dir(project_dir: str, key: str, dataset: MegatronDataset) -> st
         Path to directory containing this dataset's neighbors within Retro project.
     """
     return os.path.join(
-        get_query_dir(project_dir), os.path.basename(f"{key}_{dataset.unique_description_hash}"),
+        get_query_dir(project_dir), os.path.basename(f"{key}_{dataset.unique_description_hash}")
     )
diff --git a/megatron/core/datasets/retro/utils.py b/megatron/core/datasets/retro/utils.py
index dbef86a38d..31c0be14c8 100644
--- a/megatron/core/datasets/retro/utils.py
+++ b/megatron/core/datasets/retro/utils.py
@@ -110,10 +110,7 @@ def __getitem__(self, idx: int) -> dict:
 
 
 def get_blocks(
-    dirname: str,
-    n_samples: int,
-    block_size: int,
-    validate: Callable = None,
+    dirname: str, n_samples: int, block_size: int, validate: Callable = None
 ) -> SimpleNamespace:
     """Divide range [0, num_samples) to sequence of block ranges.
 
@@ -147,8 +144,7 @@ def get_blocks(
         {
             "range": r,
             "path": os.path.join(
-                dirname,
-                "%s-%s.hdf5" % tuple([str(i).zfill(n_digits) for i in r]),
+                dirname, "%s-%s.hdf5" % tuple([str(i).zfill(n_digits) for i in r])
             ),
         }
         for r in block_ranges
diff --git a/megatron/core/datasets/t5_dataset.py b/megatron/core/datasets/t5_dataset.py
index 33792c8636..b54e4f5315 100644
--- a/megatron/core/datasets/t5_dataset.py
+++ b/megatron/core/datasets/t5_dataset.py
@@ -30,8 +30,7 @@ class T5MaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
     """The sequence length for the decoder"""
 
     def __post_init__(self) -> None:
-        """Do asserts and set fields post init
-        """
+        """Do asserts and set fields post init"""
         super().__post_init__()
 
         self.sequence_length_encoder = self.sequence_length
@@ -85,23 +84,21 @@ def _key_config_attributes() -> List[str]:
         """
         return super(
             T5MaskedWordPieceDataset, T5MaskedWordPieceDataset
-        )._key_config_attributes() + ["sequence_length_decoder",]
+        )._key_config_attributes() + ["sequence_length_decoder"]
 
     def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
         """Abstract method implementation
- 
+
         Args:
             idx (int): The index into the dataset
 
         Returns:
-            Dict[str, Union[int, numpy.ndarray]]: The 
+            Dict[str, Union[int, numpy.ndarray]]: The
         """
         idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
         sample = [self.dataset[i] for i in range(idx_beg, idx_end)]
 
-        numpy_random_state = numpy.random.RandomState(
-            seed=(self.config.random_seed + idx) % 2 ** 32
-        )
+        numpy_random_state = numpy.random.RandomState(seed=(self.config.random_seed + idx) % 2**32)
 
         assert target_sequence_length <= self.config.sequence_length
 
@@ -113,7 +110,7 @@ def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
         tokens = tokens[:target_sequence_length]
 
         # Masking
-        (tokens, _, _, _, masked_spans,) = self._create_masked_lm_predictions(
+        (tokens, _, _, _, masked_spans) = self._create_masked_lm_predictions(
             tokens, target_sequence_length, numpy_random_state
         )
 
diff --git a/megatron/core/datasets/utils.py b/megatron/core/datasets/utils.py
index 45203c256a..8d887d4a4a 100644
--- a/megatron/core/datasets/utils.py
+++ b/megatron/core/datasets/utils.py
@@ -19,8 +19,7 @@ class Split(Enum):
 
 
 def compile_helpers():
-    """Compile C++ helper functions at runtime. Make sure this is invoked on a single process.
-    """
+    """Compile C++ helper functions at runtime. Make sure this is invoked on a single process."""
     import os
     import subprocess
 
@@ -51,7 +50,7 @@ def get_blend_from_list(
     blend: Optional[List[str]],
 ) -> Optional[Tuple[List[str], Optional[List[float]]]]:
     """Get the megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig blend from the blend list
-    
+
     Args:
         blend (Optional[List[str]]): The blend list, which can be either (1) a list of prefixes, e.g. ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], or (2) a flattened, zipped list of weights and prefixes, e.g. ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"]
 
diff --git a/megatron/core/dist_checkpointing/core.py b/megatron/core/dist_checkpointing/core.py
index 50384e661b..af6ebff6ec 100644
--- a/megatron/core/dist_checkpointing/core.py
+++ b/megatron/core/dist_checkpointing/core.py
@@ -11,14 +11,14 @@
 
 
 class CheckpointingException(Exception):
-    """ Base checkpointing related exception  """
+    """Base checkpointing related exception"""
 
     pass
 
 
 @dataclass
 class CheckpointingConfig:
-    """ Documents backends used in the checkpoint.
+    """Documents backends used in the checkpoint.
 
     Checkpoint config keeps track of formats used for storing the sharded tensors
     (sharded_backend) and other objects (common_backend).
@@ -34,7 +34,7 @@ class CheckpointingConfig:
 
 
 def check_is_distributed_checkpoint(checkpoint_dir):
-    """ Checks if `metadata.json` exists in the checkpoint and is a valid config.
+    """Checks if `metadata.json` exists in the checkpoint and is a valid config.
 
     Args:
         checkpoint_dir: checkpoint directory
@@ -46,7 +46,7 @@ def check_is_distributed_checkpoint(checkpoint_dir):
 
 
 def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
-    """ Returns checkpoint config if `checkpoint_dir` is a distributed checkpoint and None otherwise
+    """Returns checkpoint config if `checkpoint_dir` is a distributed checkpoint and None otherwise
 
     Args:
         checkpoint_dir: checkpoint directory
@@ -63,7 +63,7 @@ def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
 
 
 def save_config(config: CheckpointingConfig, checkpoint_dir: str):
-    """ Save given config to checkpoint directory.
+    """Save given config to checkpoint directory.
 
     Args:
         config: checkpoint config
diff --git a/megatron/core/dist_checkpointing/serialization.py b/megatron/core/dist_checkpointing/serialization.py
index f37aadc913..43ad3bc49e 100644
--- a/megatron/core/dist_checkpointing/serialization.py
+++ b/megatron/core/dist_checkpointing/serialization.py
@@ -182,8 +182,7 @@ def load_common_state_dict(checkpoint_dir: Path) -> StateDict:
 
 
 def load_tensors_metadata(
-    checkpoint_dir: str,
-    sharded_strategy: Union[LoadShardedStrategy, None] = None,
+    checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None
 ) -> CkptShardedMetadata:
     """Load tensors metadata from the checkpoint.
 
diff --git a/megatron/core/dist_checkpointing/strategies/async_utils.py b/megatron/core/dist_checkpointing/strategies/async_utils.py
index 24ee43d7e0..7cdda8ac32 100644
--- a/megatron/core/dist_checkpointing/strategies/async_utils.py
+++ b/megatron/core/dist_checkpointing/strategies/async_utils.py
@@ -76,11 +76,7 @@ def __init__(self):
         self.process: Optional[mp.Process] = None
         self.start_time: Optional[float] = None
 
-    def schedule_async_call(
-        self,
-        async_fn: Optional[Callable],
-        save_args: Tuple,
-    ) -> None:
+    def schedule_async_call(self, async_fn: Optional[Callable], save_args: Tuple) -> None:
         """Spawn a process with `async_fn` as the target.
 
         This method must be called on all ranks.
@@ -101,10 +97,7 @@ def schedule_async_call(
 
         ctx = mp.get_context('fork')
         self.start_time = time()
-        self.process = ctx.Process(
-            target=async_fn,
-            args=save_args,
-        )
+        self.process = ctx.Process(target=async_fn, args=save_args)
         self.process.start()
         init_time = time()
         logger.debug(
diff --git a/megatron/core/dist_checkpointing/strategies/filesystem_async.py b/megatron/core/dist_checkpointing/strategies/filesystem_async.py
index bfa609128a..9d0be4d6e7 100644
--- a/megatron/core/dist_checkpointing/strategies/filesystem_async.py
+++ b/megatron/core/dist_checkpointing/strategies/filesystem_async.py
@@ -284,11 +284,7 @@ def write_preloaded_data(
             f"{local_proc_idx} consumed: {mem_after - mem_before}, before: {mem_before}, after: {mem_after}"
         )
 
-    def write_data(
-        self,
-        plan: SavePlan,
-        planner: SavePlanner,
-    ) -> Future[List[WriteResult]]:
+    def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]:
         raise NotImplementedError('write_data not implemented for FileSystemWriterAsync')
 
     def retrieve_write_results(self) -> List[WriteResult]:
diff --git a/megatron/core/dist_checkpointing/strategies/fully_parallel.py b/megatron/core/dist_checkpointing/strategies/fully_parallel.py
index 0b004e2bce..238c381378 100644
--- a/megatron/core/dist_checkpointing/strategies/fully_parallel.py
+++ b/megatron/core/dist_checkpointing/strategies/fully_parallel.py
@@ -97,11 +97,7 @@ def __init__(
 
         self.cached_distribution: Optional[SaveLoadDistribution] = None
 
-    def async_save(
-        self,
-        sharded_state_dict: ShardedStateDict,
-        checkpoint_dir: Path,
-    ):
+    def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
         if not isinstance(self.base_strategy, AsyncSaveShardedStrategy):
             raise CheckpointingException(
                 f'Cannot apply async_save to non-async base strategy {self.base_strategy}'
@@ -109,11 +105,7 @@ def async_save(
         self.apply_saving_parallelization(sharded_state_dict)
         return self.base_strategy.async_save(sharded_state_dict, checkpoint_dir)
 
-    def save(
-        self,
-        sharded_state_dict: ShardedStateDict,
-        checkpoint_dir: Path,
-    ):
+    def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
         self.apply_saving_parallelization(sharded_state_dict)
         return self.base_strategy.save(sharded_state_dict, checkpoint_dir)
 
@@ -248,12 +240,9 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
         # Step 3: load part of the checkpoint.
         # Load only sharded objects first. ShardedTensors will be loaded separately
         # so that we can keep track of sharded tensors loaded by this rank
-        (
-            sharded_tensors,
-            sharded_state_dict,
-            to_load_shards,
-            unloaded_shards,
-        ) = self._defer_loading_sharded_tensors(sharded_state_dict)
+        (sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards) = (
+            self._defer_loading_sharded_tensors(sharded_state_dict)
+        )
         loaded_state_dict = self.base_strategy.load(sharded_state_dict, checkpoint_dir)
 
         end = time()
@@ -279,10 +268,7 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
             raise NotImplementedError(f'Unrecognized gather algorithm: {self.exchange_algo}')
 
         all_loaded_tensors = exchange_fn(
-            loaded_tensors,
-            unloaded_shards,
-            precomputed_distribution,
-            self.parallelization_group,
+            loaded_tensors, unloaded_shards, precomputed_distribution, self.parallelization_group
         )
         if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()):
             missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys()
@@ -300,7 +286,9 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
         merge(loaded_state_dict, sharded_tensors)
         return loaded_state_dict
 
-    def _defer_loading_sharded_tensors(self, sharded_state_dict: ShardedStateDict) -> Tuple[
+    def _defer_loading_sharded_tensors(
+        self, sharded_state_dict: ShardedStateDict
+    ) -> Tuple[
         ShardedStateDict,
         ShardedStateDict,
         Dict[_ShardId, ShardedTensor],
diff --git a/megatron/core/dist_checkpointing/strategies/state_dict_saver.py b/megatron/core/dist_checkpointing/strategies/state_dict_saver.py
index 092e91d2f8..8e1d2c5523 100644
--- a/megatron/core/dist_checkpointing/strategies/state_dict_saver.py
+++ b/megatron/core/dist_checkpointing/strategies/state_dict_saver.py
@@ -124,9 +124,7 @@ def global_step(all_local_plans):
 
 
 def save_state_dict_async_finalize(
-    storage_writer: 'FileSystemWriterAsync',
-    global_metadata: Metadata,
-    dist_wrapper: _DistWrapper,
+    storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper
 ) -> None:
     """
     Finalization of save_state_dict_async_plan.
diff --git a/megatron/core/dist_checkpointing/strategies/tensorstore.py b/megatron/core/dist_checkpointing/strategies/tensorstore.py
index 61972ec95b..9b4eeb3185 100644
--- a/megatron/core/dist_checkpointing/strategies/tensorstore.py
+++ b/megatron/core/dist_checkpointing/strategies/tensorstore.py
@@ -115,10 +115,7 @@ def open_ts_array(arr_path: Path):
         arr_path (Path): path to a Zarr (Tensorstore) array
     """
     spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}}
-    spec['kvstore'] = {
-        'driver': 'file',
-        'path': str(arr_path),
-    }
+    spec['kvstore'] = {'driver': 'file', 'path': str(arr_path)}
     try:
         arr = ts.open(ts.Spec(spec), open=True).result()
     except Exception as e:
diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py
index d42d3ccda0..2fccba1f8d 100644
--- a/megatron/core/dist_checkpointing/strategies/torch.py
+++ b/megatron/core/dist_checkpointing/strategies/torch.py
@@ -524,8 +524,7 @@ def resolve_tensor(self, read_item: ReadItem):
         ):
             self._intermediate_read_item_and_target = (read_item, target_tensor)
             target_tensor = Float8Tensor.make_like(
-                target_tensor,
-                data=target_tensor._data.contiguous(),
+                target_tensor, data=target_tensor._data.contiguous()
             )
         return target_tensor
 
@@ -588,9 +587,7 @@ def __init__(
         self.use_cached_ckpt_structure: bool = cached_metadata
 
     def async_save(
-        self,
-        sharded_state_dict: ShardedStateDict,
-        checkpoint_dir: Path,
+        self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
     ) -> AsyncRequest:
         """Translates MCore ShardedTensors to PyT ShardedTensors and saves in PyT Distributed format.
 
@@ -601,12 +598,10 @@ def async_save(
         Returns: None
         """
         # Translate the state dict
-        (
-            sharded_state_dict,
-            flat_mapping,
-            rename_mapping,
-        ) = _replace_state_dict_keys_with_sharded_keys(
-            sharded_state_dict, self.keep_only_main_replica
+        (sharded_state_dict, flat_mapping, rename_mapping) = (
+            _replace_state_dict_keys_with_sharded_keys(
+                sharded_state_dict, self.keep_only_main_replica
+            )
         )
         pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False)
         # Use PyT saving mechanism
@@ -716,11 +711,9 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
 
         orig_sharded_state_dict = sharded_state_dict
         # MCore state dict to PyT Distributed compatible
-        (
-            sharded_state_dict,
-            flat_mapping,
-            rename_mapping,
-        ) = _replace_state_dict_keys_with_sharded_keys(sharded_state_dict)
+        (sharded_state_dict, flat_mapping, rename_mapping) = (
+            _replace_state_dict_keys_with_sharded_keys(sharded_state_dict)
+        )
         pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True)
         # Load PyT Distributed format
         checkpoint.load_state_dict(
@@ -764,8 +757,7 @@ def load_tensors_metadata(self, checkpoint_dir: Path, metadata: Metadata = None)
             if nd_orig_global_shape is None:
                 # Regular tensor
                 sharded_metadata[k] = ShardedTensor.from_rank_offsets(
-                    k,
-                    torch.empty(tp.size, **tp.properties.__dict__, device='meta'),
+                    k, torch.empty(tp.size, **tp.properties.__dict__, device='meta')
                 ).without_data()
             else:
                 # N-D flattened tensor
diff --git a/megatron/core/dist_checkpointing/strategies/two_stage.py b/megatron/core/dist_checkpointing/strategies/two_stage.py
index 8d20c32bbb..72e60bc79b 100644
--- a/megatron/core/dist_checkpointing/strategies/two_stage.py
+++ b/megatron/core/dist_checkpointing/strategies/two_stage.py
@@ -59,10 +59,7 @@ class _ShardedTensorMetadata:
 
 
 def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
-    return (
-        sharded_tensor.key,
-        sharded_tensor.global_offset,
-    )
+    return (sharded_tensor.key, sharded_tensor.global_offset)
 
 
 class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
@@ -177,7 +174,7 @@ def _build_load_plan(
 
     @timed()
     def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]):
-        """ Group tensors by chunk and then pick the tensor with the lowest rank.
+        """Group tensors by chunk and then pick the tensor with the lowest rank.
 
         NOTE: with proper loading overlap, loading from randomized ranks
          (instead of the smallest one) could be beneficial here.
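
The docstring above states the deduplication rule in words; a generic sketch of the same idea, where `chunk_id` and `rank` are placeholder attribute names rather than the module's actual fields:

    from collections import defaultdict

    def dedup_by_lowest_rank(metas):
        # Group entries by chunk id, then keep the lowest-rank entry per group.
        groups = defaultdict(list)
        for meta in metas:
            groups[meta.chunk_id].append(meta)
        return [min(group, key=lambda m: m.rank) for group in groups.values()]
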
diff --git a/megatron/core/dist_checkpointing/utils.py b/megatron/core/dist_checkpointing/utils.py
index 98ce01dd37..ff12b32662 100644
--- a/megatron/core/dist_checkpointing/utils.py
+++ b/megatron/core/dist_checkpointing/utils.py
@@ -73,18 +73,14 @@ def extract_sharded_tensors_or_nonpersistent(
 def extract_sharded_base(
     sharded_state_dict: ShardedStateDict,
 ) -> Tuple[ShardedStateDict, StateDict]:
-    return extract_matching_values(
-        sharded_state_dict,
-        lambda v: isinstance(v, ShardedBase),
-    )
+    return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedBase))
 
 
 def extract_nonpersistent(
     sharded_state_dict: ShardedStateDict,
 ) -> Tuple[ShardedStateDict, StateDict]:
     return extract_matching_values(
-        sharded_state_dict,
-        lambda v: isinstance(v, LocalNonpersistentObject),
+        sharded_state_dict, lambda v: isinstance(v, LocalNonpersistentObject)
     )
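
Both helpers above delegate to `extract_matching_values` with a predicate; a simplified flat-dict analogue of that split (the real helper recurses into nested containers, so this only sketches the contract):

    def extract_matching_values_flat(state_dict, predicate):
        # Split entries into (matching, non-matching) dicts by the predicate.
        matching = {k: v for k, v in state_dict.items() if predicate(v)}
        rest = {k: v for k, v in state_dict.items() if not predicate(v)}
        return matching, rest
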
 
 
diff --git a/megatron/core/dist_checkpointing/validation.py b/megatron/core/dist_checkpointing/validation.py
index c45245b2e5..cd11b82ed6 100644
--- a/megatron/core/dist_checkpointing/validation.py
+++ b/megatron/core/dist_checkpointing/validation.py
@@ -100,10 +100,7 @@ def requires_global_app_metadata(val: 'StrictHandling') -> bool:
     @staticmethod
     def requires_returning_mismatch_keys(val: 'StrictHandling') -> bool:
         """Whether a given strict option results in extra return value from the `load` function."""
-        return val in (
-            StrictHandling.RETURN_UNEXPECTED,
-            StrictHandling.RETURN_ALL,
-        )
+        return val in (StrictHandling.RETURN_UNEXPECTED, StrictHandling.RETURN_ALL)
 
 
 def parse_strict_flag(strict: Union[str, StrictHandling]) -> StrictHandling:
@@ -253,8 +250,7 @@ def verify_checkpoint_and_load_strategy(
 
 
 def adjust_non_strict_load(
-    sharded_state_dict: ShardedStateDict,
-    sharded_keys_to_remove: Set[str],
+    sharded_state_dict: ShardedStateDict, sharded_keys_to_remove: Set[str]
 ) -> ShardedStateDict:
     """Adjusts sharded state dict removing keys not existing in the checkpoint.
 
diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py
index 2c02e5f7d1..0451a6e4fb 100644
--- a/megatron/core/distributed/distributed_data_parallel.py
+++ b/megatron/core/distributed/distributed_data_parallel.py
@@ -97,9 +97,7 @@ def __init__(
                 expert_parallel_params.append(param)
 
         def allocate_buffers_for_parameters(
-            input_params,
-            data_parallel_group,
-            gradient_scaling_factor,
+            input_params, data_parallel_group, gradient_scaling_factor
         ):
             param_and_grad_dtype_to_params = {}
 
diff --git a/megatron/core/distributed/finalize_model_grads.py b/megatron/core/distributed/finalize_model_grads.py
index f1a1c2b88c..ff5046afa5 100644
--- a/megatron/core/distributed/finalize_model_grads.py
+++ b/megatron/core/distributed/finalize_model_grads.py
@@ -150,11 +150,7 @@ def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torc
         # need to do a broadcast for every pp group, even though num_tokens should be the same.
         num_tokens_list = []
         for lr, group in zip(last_rank, pp_group):
-            torch.distributed.broadcast(
-                num_tokens,
-                src=lr,
-                group=group,
-            )
+            torch.distributed.broadcast(num_tokens, src=lr, group=group)
             num_tokens_list.append(torch.clone(num_tokens))
         assert all(x.item() == num_tokens_list[0] for x in num_tokens_list)
 
diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py
index efed47c5ba..65c8eeb1be 100644
--- a/megatron/core/distributed/param_and_grad_buffer.py
+++ b/megatron/core/distributed/param_and_grad_buffer.py
@@ -324,11 +324,7 @@ def _does_param_require_new_bucket(param):
                     assert data_start_index % self.data_parallel_world_size == 0
                 _create_new_bucket(data_start_index)
 
-            self.param_index_map[param] = (
-                data_start_index,
-                data_end_index,
-                bucket_id,
-            )
+            self.param_index_map[param] = (data_start_index, data_end_index, bucket_id)
             bucket_params.add(param)
 
             # If we have enough elements already or the current param is part of the shared embedding
diff --git a/megatron/core/fusions/fused_bias_dropout.py b/megatron/core/fusions/fused_bias_dropout.py
index 08af02b099..c7fa8419a0 100644
--- a/megatron/core/fusions/fused_bias_dropout.py
+++ b/megatron/core/fusions/fused_bias_dropout.py
@@ -47,14 +47,14 @@ def _bias_dropout_add(x_with_bias, residual, prob):
 
 @jit_fuser
 def bias_dropout_add_fused_train(
-    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float,
+    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float
 ) -> torch.Tensor:
     return _bias_dropout_add_func(x_with_bias, residual, prob, True)
 
 
 @jit_fuser
 def bias_dropout_add_fused_inference(
-    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float,
+    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float
 ) -> torch.Tensor:
     return _bias_dropout_add_func(x_with_bias, residual, prob, False)
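
For reference, the un-fused computation that both jit-fused wrappers above specialize; this is a sketch of the standard bias-dropout-add pattern, not the exact `_bias_dropout_add` implementation:

    import torch

    def bias_dropout_add_reference(x_with_bias, residual, prob, training):
        x, bias = x_with_bias  # bias may be None
        out = x if bias is None else x + bias
        out = torch.nn.functional.dropout(out, p=prob, training=training)
        return residual + out
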
 
diff --git a/megatron/core/fusions/fused_cross_entropy.py b/megatron/core/fusions/fused_cross_entropy.py
index e10c04c23b..909cc403cf 100644
--- a/megatron/core/fusions/fused_cross_entropy.py
+++ b/megatron/core/fusions/fused_cross_entropy.py
@@ -33,14 +33,10 @@ def calculate_predicted_logits(
     vocab_end_index: int,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
 
-    (
-        target_mask,
-        masked_target_1d,
-        predicted_logits,
-        sum_exp_logits,
-        exp_logits,
-    ) = VocabParallelCrossEntropy.calculate_predicted_logits(
-        vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
+    (target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits) = (
+        VocabParallelCrossEntropy.calculate_predicted_logits(
+            vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
+        )
     )
 
     predicted_logits_sum_exp_logits = torch.cat((predicted_logits, sum_exp_logits))
@@ -71,12 +67,9 @@ def calculate_gradients(
     masked_target_1d: torch.Tensor,
 ) -> torch.Tensor:
 
-    (
-        grad_2d,
-        arange_1d,
-        softmax_update,
-        grad_input,
-    ) = VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask)
+    (grad_2d, arange_1d, softmax_update, grad_input) = (
+        VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask)
+    )
 
     grad_input = VocabParallelCrossEntropy.calculate_gradients(
         grad_2d, arange_1d, masked_target_1d, softmax_update, grad_input, grad_output
@@ -103,13 +96,10 @@ def forward(ctx, vocab_parallel_logits, target):
         world_size = get_tensor_model_parallel_world_size()
         vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
 
-        (
-            target_mask,
-            masked_target_1d,
-            predicted_logits_sum_exp_logits,
-            exp_logits,
-        ) = calculate_predicted_logits(
-            vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
+        (target_mask, masked_target_1d, predicted_logits_sum_exp_logits, exp_logits) = (
+            calculate_predicted_logits(
+                vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
+            )
         )
 
         # All reduce is needed to get the chunks from other GPUs.
diff --git a/megatron/core/inference/modelopt_support/gpt/model_specs.py b/megatron/core/inference/modelopt_support/gpt/model_specs.py
index e3d8e08d30..50415ac006 100644
--- a/megatron/core/inference/modelopt_support/gpt/model_specs.py
+++ b/megatron/core/inference/modelopt_support/gpt/model_specs.py
@@ -47,8 +47,7 @@ def get_gpt_layer_modelopt_spec(
             mlp=ModuleSpec(
                 module=MLP,
                 submodules=MLPSubmodules(
-                    linear_fc1=ColumnParallelLinear,
-                    linear_fc2=RowParallelLinear,
+                    linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
                 ),
             ),
             mlp_bda=get_bias_dropout_add,
diff --git a/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py b/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py
index f81c4f5e03..15c3527c94 100644
--- a/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py
+++ b/megatron/core/inference/modelopt_support/gpt/state_dict_hooks.py
@@ -8,13 +8,7 @@
 
 
 def mcore_gpt_load_legacy_state_dict_pre_hook(
-    state_dict,
-    prefix,
-    local_metadata,
-    strict,
-    missing_keys,
-    unexpected_keys,
-    error_msgs,
+    state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
 ):
     """Register a pre-hook to fix the state_dict key difference.
 
@@ -87,13 +81,7 @@ def mcore_gpt_load_legacy_state_dict_pre_hook(
 
 
 def mcore_gpt_load_te_state_dict_pre_hook(
-    state_dict,
-    prefix,
-    local_metadata,
-    strict,
-    missing_keys,
-    unexpected_keys,
-    error_msgs,
+    state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
 ):
     """Register a pre-hook to fix the state_dict key difference of.
 
diff --git a/megatron/core/inference/scheduler.py b/megatron/core/inference/scheduler.py
index 35efb935f0..abcb325185 100644
--- a/megatron/core/inference/scheduler.py
+++ b/megatron/core/inference/scheduler.py
@@ -85,10 +85,9 @@ def add_earliest_waiting_request_to_active_pool(self):
             len(self.active_request_pool) < self.max_batch_size
         ), "Active request pool is already full. Cant add any more requests"
         if len(self.waiting_request_pool) > 0:
-            (
-                earliest_waiting_request_request_id,
-                earliest_waiting_request,
-            ) = self.waiting_request_pool.popitem(last=False)
+            (earliest_waiting_request_request_id, earliest_waiting_request) = (
+                self.waiting_request_pool.popitem(last=False)
+            )
             earliest_waiting_request.status = Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
             self.active_request_pool[earliest_waiting_request_request_id] = earliest_waiting_request
 
diff --git a/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py b/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py
index b5eed123bc..e4db83f6b3 100644
--- a/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py
+++ b/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py
@@ -189,8 +189,7 @@ def pad_input_prompt_tokens(
         return torch.tensor(batch_prompt_tokens_list).cuda()
 
     def generate_output_tokens_dynamic_batch(
-        self,
-        active_requests: OrderedDict[int, InferenceRequest],
+        self, active_requests: OrderedDict[int, InferenceRequest]
     ) -> OrderedDict[int, InferenceRequest]:
         """Utility to generate the output tokens and probabilities for the prompts
 
@@ -205,8 +204,7 @@ def generate_output_tokens_dynamic_batch(
         raise Exception("Not implemented yet")
 
     def generate_all_output_tokens_static_batch(
-        self,
-        active_requests: OrderedDict[int, InferenceRequest],
+        self, active_requests: OrderedDict[int, InferenceRequest]
     ) -> OrderedDict[int, InferenceRequest]:
         """Utility to generate the all the output tokens and probabilities for the prompts .
 
@@ -305,15 +303,14 @@ def generate_all_output_tokens_static_batch(
                 context_start_position = context_end_position
 
                 # Check end of generation status for each tensor and update generated sequence lengths
-                (
-                    is_generation_done_tensor,
-                    generated_sequence_lengths,
-                ) = self.update_generation_status(
-                    updated_prompts_tokens=batch_prompt_tokens,
-                    generation_started=generation_started,
-                    current_context_end_position=context_end_position,
-                    is_generation_done_tensor=is_generation_done_tensor,
-                    generated_sequence_lengths=generated_sequence_lengths,
+                (is_generation_done_tensor, generated_sequence_lengths) = (
+                    self.update_generation_status(
+                        updated_prompts_tokens=batch_prompt_tokens,
+                        generation_started=generation_started,
+                        current_context_end_position=context_end_position,
+                        is_generation_done_tensor=is_generation_done_tensor,
+                        generated_sequence_lengths=generated_sequence_lengths,
+                    )
                 )
                 # Boolean flag indicating if all prompts are finished
                 all_prompts_done = torch.all(is_generation_done_tensor)
diff --git a/megatron/core/models/T5/t5_model.py b/megatron/core/models/T5/t5_model.py
index 37a395ea47..8266757433 100644
--- a/megatron/core/models/T5/t5_model.py
+++ b/megatron/core/models/T5/t5_model.py
@@ -247,12 +247,10 @@ def forward(
             Tensor: loss tensor
         """
 
-        (
-            encoder_attn_mask,
-            decoder_attn_mask,
-            encoder_decoder_attn_mask,
-        ) = t5_extended_attention_mask(
-            [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask]
+        (encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask) = (
+            t5_extended_attention_mask(
+                [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask]
+            )
         )
 
         ## Encoder forward
diff --git a/megatron/core/models/T5/t5_spec.py b/megatron/core/models/T5/t5_spec.py
index f195dcac35..520c3c5c8a 100644
--- a/megatron/core/models/T5/t5_spec.py
+++ b/megatron/core/models/T5/t5_spec.py
@@ -69,8 +69,7 @@ def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
             mlp=ModuleSpec(
                 module=MLP,
                 submodules=MLPSubmodules(
-                    linear_fc1=TELayerNormColumnParallelLinear,
-                    linear_fc2=TERowParallelLinear,
+                    linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
                 ),
             ),
             mlp_bda=get_bias_dropout_add,
@@ -110,8 +109,7 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
             mlp=ModuleSpec(
                 module=MLP,
                 submodules=MLPSubmodules(
-                    linear_fc1=TELayerNormColumnParallelLinear,
-                    linear_fc2=TERowParallelLinear,
+                    linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
                 ),
             ),
             mlp_bda=get_bias_dropout_add,
@@ -142,8 +140,7 @@ def encoder_model_with_local_spec() -> ModuleSpec:
             mlp=ModuleSpec(
                 module=MLP,
                 submodules=MLPSubmodules(
-                    linear_fc1=ColumnParallelLinear,
-                    linear_fc2=RowParallelLinear,
+                    linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
                 ),
             ),
             mlp_bda=get_bias_dropout_add,
@@ -189,8 +186,7 @@ def decoder_model_with_local_spec() -> ModuleSpec:
             mlp=ModuleSpec(
                 module=MLP,
                 submodules=MLPSubmodules(
-                    linear_fc1=ColumnParallelLinear,
-                    linear_fc2=RowParallelLinear,
+                    linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
                 ),
             ),
             mlp_bda=get_bias_dropout_add,
diff --git a/megatron/core/models/bert/bert_layer_specs.py b/megatron/core/models/bert/bert_layer_specs.py
index 1eb965c299..b5b117b498 100644
--- a/megatron/core/models/bert/bert_layer_specs.py
+++ b/megatron/core/models/bert/bert_layer_specs.py
@@ -54,8 +54,7 @@
         mlp=ModuleSpec(
             module=MLP,
             submodules=MLPSubmodules(
-                linear_fc1=TELayerNormColumnParallelLinear,
-                linear_fc2=TERowParallelLinear,
+                linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
             ),
         ),
         mlp_bda=get_bias_dropout_add,
@@ -82,10 +81,7 @@
         pre_mlp_layernorm=LNImpl,
         mlp=ModuleSpec(
             module=MLP,
-            submodules=MLPSubmodules(
-                linear_fc1=ColumnParallelLinear,
-                linear_fc2=RowParallelLinear,
-            ),
+            submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear),
         ),
         mlp_bda=get_bias_dropout_add,
         sharded_state_dict_keys_map={
diff --git a/megatron/core/models/bert/bert_lm_head.py b/megatron/core/models/bert/bert_lm_head.py
index ff0411dc59..fd26ebd16f 100644
--- a/megatron/core/models/bert/bert_lm_head.py
+++ b/megatron/core/models/bert/bert_lm_head.py
@@ -30,11 +30,7 @@ class BertLMHead(MegatronModule):
         config (TransformerConfig): TransformerConfig object
     """
 
-    def __init__(
-        self,
-        hidden_size: int,
-        config: TransformerConfig,
-    ):
+    def __init__(self, hidden_size: int, config: TransformerConfig):
         super().__init__(config=config)
 
         # TODO: Should switch this to TE ?
@@ -46,9 +42,7 @@ def __init__(
         setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel)
 
         self.layer_norm = LNImpl(
-            config=config,
-            hidden_size=hidden_size,
-            eps=config.layernorm_epsilon,
+            config=config, hidden_size=hidden_size, eps=config.layernorm_epsilon
         )
 
         self.gelu = torch.nn.functional.gelu
diff --git a/megatron/core/models/bert/bert_model.py b/megatron/core/models/bert/bert_model.py
index eb94ebbb9f..0b571ca68d 100644
--- a/megatron/core/models/bert/bert_model.py
+++ b/megatron/core/models/bert/bert_model.py
@@ -122,10 +122,7 @@ def __init__(
         # Output
         if post_process:
             # TODO: Make sure you are passing in the mpu_vocab_size properly
-            self.lm_head = BertLMHead(
-                config.hidden_size,
-                config,
-            )
+            self.lm_head = BertLMHead(config.hidden_size, config)
 
             self.output_layer = tensor_parallel.ColumnParallelLinear(
                 config.hidden_size,
diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py
index 207706d0be..0a4e5bf6de 100644
--- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py
+++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py
@@ -223,10 +223,7 @@ def apply_rotary_pos_emb_thd(
 
 
 def apply_rotary_pos_emb(
-    t: Tensor,
-    freqs: Tensor,
-    config: TransformerConfig,
-    cu_seqlens: Optional[Tensor] = None,
+    t: Tensor, freqs: Tensor, config: TransformerConfig, cu_seqlens: Optional[Tensor] = None
 ):
     """
     Reroute to the appropriate apply_rotary_pos_emb function depending on
diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py
index 91224bf6b3..8fcfc424e6 100755
--- a/megatron/core/models/mamba/mamba_layer_specs.py
+++ b/megatron/core/models/mamba/mamba_layer_specs.py
@@ -24,8 +24,7 @@
                 mixer=ModuleSpec(
                     module=MambaMixer,
                     submodules=MambaMixerSubmodules(
-                        in_proj=TELayerNormColumnParallelLinear,
-                        out_proj=TERowParallelLinear,
+                        in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear
                     ),
                 ),
                 mamba_bda=get_bias_dropout_add,
@@ -58,8 +57,7 @@
                 mlp=ModuleSpec(
                     module=MLP,
                     submodules=MLPSubmodules(
-                        linear_fc1=TELayerNormColumnParallelLinear,
-                        linear_fc2=TERowParallelLinear,
+                        linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
                     ),
                 ),
                 mlp_bda=get_bias_dropout_add,
diff --git a/megatron/core/models/retro/base_attention.py b/megatron/core/models/retro/base_attention.py
index 741f712b72..ee8656d96a 100644
--- a/megatron/core/models/retro/base_attention.py
+++ b/megatron/core/models/retro/base_attention.py
@@ -9,7 +9,6 @@
 
 
 class BaseRetroCrossAttention(MegatronModule):
-
     """Base class for Retro cross attention, for both encoder & decoder layers.
 
     This class collects the retro arguments below (i.e., num neighbors, chunk
diff --git a/megatron/core/models/retro/config.py b/megatron/core/models/retro/config.py
index b9a5eb9648..3e3d0b538a 100644
--- a/megatron/core/models/retro/config.py
+++ b/megatron/core/models/retro/config.py
@@ -14,7 +14,7 @@
 
 @dataclass
 class RetroConfig(TransformerConfig):
-    """Configuration object for Retro models. """
+    """Configuration object for Retro models."""
 
     # Retro.
     retro_project_dir: str = None
diff --git a/megatron/core/models/retro/decoder_attention.py b/megatron/core/models/retro/decoder_attention.py
index f459163ccc..6b7a04d884 100644
--- a/megatron/core/models/retro/decoder_attention.py
+++ b/megatron/core/models/retro/decoder_attention.py
@@ -22,7 +22,6 @@
 
 
 class RetroDecoderCrossAttention(BaseRetroCrossAttention):
-
     """Retro decoder's chunked cross attention operator.
 
     See this paper for more details: https://arxiv.org/abs/2112.04426.
@@ -69,7 +68,7 @@ def __init__(
 
         if encoder_block_spec:
             self.encoder = TransformerBlock(
-                config=config, spec=encoder_block_spec, pre_process=True, post_process=False,
+                config=config, spec=encoder_block_spec, pre_process=True, post_process=False
             )
             # self._encoder_key = 'encoder' # ... necessary?
         else:
@@ -124,7 +123,7 @@ def forward(
 
                 # Pad partial chunk with zeros.
                 first_chunk = torch.nn.functional.pad(
-                    first_chunk, (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), 'constant', 0,
+                    first_chunk, (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), 'constant', 0
                 )
 
                 # Concatenate padded chunk with remaining chunks.
@@ -169,7 +168,7 @@ def forward(
 
         # Pad attending tokens to sequence length.
         padded_chunks = torch.nn.functional.pad(
-            attending_chunks, (0, 0, 0, 0, 0, self.retro_chunk_length - 1), 'constant', 0,
+            attending_chunks, (0, 0, 0, 0, 0, self.retro_chunk_length - 1), 'constant', 0
         )
 
         # Permute attending chunks.
@@ -210,7 +209,6 @@ def forward(
 
 
 class RetroDecoderBiasDropoutAdd(MegatronModule):
-
     """Retro decoder's bias-dropout-add operator.
 
     This operator takes care of reshaping and permuting the output from the
@@ -220,9 +218,7 @@ class RetroDecoderBiasDropoutAdd(MegatronModule):
         config (RetroConfig): Retro config.
     """
 
-    def __init__(
-        self, config: RetroConfig,
-    ):
+    def __init__(self, config: RetroConfig):
         super().__init__(config=config)
         self.retro_chunk_length = config.retro_chunk_length
 
@@ -282,7 +278,7 @@ def _forward(
             )
 
             # Prepend zeros for non-attending tokens.
-            x = torch.nn.functional.pad(x, (0, 0, 0, 0, pad, 0), 'constant', 0,)[
+            x = torch.nn.functional.pad(x, (0, 0, 0, 0, pad, 0), 'constant', 0)[
                 :ns
             ]  # [ ns, bs, d ]
 
diff --git a/megatron/core/models/retro/decoder_spec.py b/megatron/core/models/retro/decoder_spec.py
index 0c16ccc8cb..d9cc69eacd 100644
--- a/megatron/core/models/retro/decoder_spec.py
+++ b/megatron/core/models/retro/decoder_spec.py
@@ -73,9 +73,7 @@ def get_retro_decoder_layer_te_spec(
     spec.submodules.pre_cross_attn_layernorm = TENorm
     spec.submodules.cross_attention = ModuleSpec(
         module=RetroDecoderCrossAttention,
-        params={
-            "encoder_block_spec": encoder_block_spec,
-        },
+        params={"encoder_block_spec": encoder_block_spec},
         submodules=CrossAttentionSubmodules(
             linear_q=TEColumnParallelLinear,
             linear_kv=TEColumnParallelLinear,
@@ -108,9 +106,7 @@ def get_retro_decoder_layer_local_spec(
     spec.submodules.pre_cross_attn_layernorm = LNImpl
     spec.submodules.cross_attention = ModuleSpec(
         module=RetroDecoderCrossAttention,
-        params={
-            "encoder_block_spec": encoder_block_spec,
-        },
+        params={"encoder_block_spec": encoder_block_spec},
         submodules=CrossAttentionSubmodules(
             linear_q=ColumnParallelLinear,
             linear_kv=ColumnParallelLinear,
diff --git a/megatron/core/models/retro/encoder_attention.py b/megatron/core/models/retro/encoder_attention.py
index a2226c08da..76625abe33 100644
--- a/megatron/core/models/retro/encoder_attention.py
+++ b/megatron/core/models/retro/encoder_attention.py
@@ -17,7 +17,6 @@
 
 
 class RetroEncoderCrossAttention(BaseRetroCrossAttention):
-
     """Retro encoder's cross attention operator.
 
     See this paper for more details: https://arxiv.org/abs/2112.04426.
@@ -96,14 +95,13 @@ def forward(
             residual = chunked_output
 
             # Collect tensors.
-            attention_output_tuples.append((attention_output, attention_bias, residual,))
+            attention_output_tuples.append((attention_output, attention_bias, residual))
 
         # Output. (List[Tuple[( [ r, bs*l, d ], [ d ] )]])
         return attention_output_tuples
 
 
 class RetroEncoderBiasDropoutAdd(MegatronModule):
-
     """Retro encoder's bias-dropout-add operator.
 
     This operator applies bias-dropout-add individually on each neighboring
@@ -113,9 +111,7 @@ class RetroEncoderBiasDropoutAdd(MegatronModule):
         config (RetroConfig): Retro config.
     """
 
-    def __init__(
-        self, config: RetroConfig,
-    ):
+    def __init__(self, config: RetroConfig):
         super().__init__(config=config)
         self.retro_num_neighbors = config.retro_num_neighbors
 
@@ -186,7 +182,6 @@ def forward(self, training: bool, fused: bool) -> partial:
 
 
 class RetroEncoderLayerNorm(MegatronModule):
-
     """Retro encoder's layernorm operator.
 
     This operator applies layernorm individually on each neighboring chunk that
@@ -198,9 +193,7 @@ class RetroEncoderLayerNorm(MegatronModule):
         submodules (Type): Layer norm class. (Named 'submodules' to fit external interface.)
     """
 
-    def __init__(
-        self, config: RetroConfig, submodules: Type, **kwargs: dict,
-    ):
+    def __init__(self, config: RetroConfig, submodules: Type, **kwargs: dict):
         super().__init__(config=config)
         norm_class = submodules
         self.norm = norm_class(config=config, **kwargs)
@@ -211,7 +204,7 @@ def forward(self, input: Tensor) -> Tensor:
 
         Args:
             input (Tensor): Input chunks, concatenated into a single tensor.
-        
+
         Returns:
             Output of the layer norm.
         """
diff --git a/megatron/core/models/retro/encoder_spec.py b/megatron/core/models/retro/encoder_spec.py
index ac0eb15598..777b5324d8 100644
--- a/megatron/core/models/retro/encoder_spec.py
+++ b/megatron/core/models/retro/encoder_spec.py
@@ -63,9 +63,7 @@ def get_retro_encoder_layer_te_spec() -> ModuleSpec:
     spec.submodules.pre_cross_attn_layernorm = TENorm
     spec.submodules.cross_attention = ModuleSpec(
         module=RetroEncoderCrossAttention,
-        params={
-            "attn_mask_type": AttnMaskType.padding,
-        },
+        params={"attn_mask_type": AttnMaskType.padding},
         submodules=CrossAttentionSubmodules(
             linear_q=TEColumnParallelLinear,
             linear_kv=TEColumnParallelLinear,
@@ -74,16 +72,10 @@ def get_retro_encoder_layer_te_spec() -> ModuleSpec:
         ),
     )
     spec.submodules.cross_attn_bda = ModuleSpec(module=RetroEncoderBiasDropoutAdd)
-    spec.submodules.pre_mlp_layernorm = ModuleSpec(
-        module=RetroEncoderLayerNorm,
-        submodules=TENorm,
-    )
+    spec.submodules.pre_mlp_layernorm = ModuleSpec(module=RetroEncoderLayerNorm, submodules=TENorm)
     spec.submodules.mlp = ModuleSpec(
         module=MLP,
-        submodules=MLPSubmodules(
-            linear_fc1=TEColumnParallelLinear,
-            linear_fc2=TERowParallelLinear,
-        ),
+        submodules=MLPSubmodules(linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear),
     )
     return spec
 
@@ -103,9 +95,7 @@ def get_retro_encoder_layer_local_spec() -> ModuleSpec:
     spec.submodules.pre_cross_attn_layernorm = LNImpl
     spec.submodules.cross_attention = ModuleSpec(
         module=RetroEncoderCrossAttention,
-        params={
-            "attn_mask_type": AttnMaskType.padding,
-        },
+        params={"attn_mask_type": AttnMaskType.padding},
         submodules=CrossAttentionSubmodules(
             linear_q=ColumnParallelLinear,
             linear_kv=ColumnParallelLinear,
@@ -114,19 +104,13 @@ def get_retro_encoder_layer_local_spec() -> ModuleSpec:
         ),
     )
     spec.submodules.cross_attn_bda = ModuleSpec(module=RetroEncoderBiasDropoutAdd)
-    spec.submodules.pre_mlp_layernorm = ModuleSpec(
-        module=RetroEncoderLayerNorm,
-        submodules=LNImpl,
-    )
+    spec.submodules.pre_mlp_layernorm = ModuleSpec(module=RetroEncoderLayerNorm, submodules=LNImpl)
     spec.submodules.mlp = ModuleSpec(
         module=MLP,
-        submodules=MLPSubmodules(
-            linear_fc1=ColumnParallelLinear,
-            linear_fc2=RowParallelLinear,
-        ),
+        submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear),
     )
     spec.submodules.sharded_state_dict_keys_map = {
-        'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
+        'input_layernorm.': 'self_attention.linear_qkv.layer_norm_'
     }  # pre_mlp_layernorm doesn't need remapping
     return spec
 
@@ -168,9 +152,7 @@ def get_retro_encoder_block_spec(
         spec.submodules.self_attention.params["attn_mask_type"] = AttnMaskType.padding
         spec.submodules.self_attention.submodules.core_attention = ModuleSpec(
             module=TEDotProductAttention if use_transformer_engine else DotProductAttention,
-            params={
-                "attention_dropout": config.retro_encoder_attention_dropout,
-            },
+            params={"attention_dropout": config.retro_encoder_attention_dropout},
         )
 
     layer_specs = []
diff --git a/megatron/core/models/retro/model.py b/megatron/core/models/retro/model.py
index 32c6d26a62..8142c91f7a 100644
--- a/megatron/core/models/retro/model.py
+++ b/megatron/core/models/retro/model.py
@@ -11,7 +11,6 @@
 
 
 class RetroModel(GPTModel):
-
     """Retro Model.
 
     A Retro model mostly re-uses the GPTModel interface, with the only difference
@@ -79,7 +78,7 @@ def forward(
             decoder_input=decoder_input,
             labels=labels,
             inference_params=inference_params,
-            extra_block_kwargs={"context": context, "context_mask": context_mask,},
+            extra_block_kwargs={"context": context, "context_mask": context_mask},
         )
 
     def sharded_state_dict(
diff --git a/megatron/core/models/vision/multimodal_projector.py b/megatron/core/models/vision/multimodal_projector.py
index a5363ac45d..18e62c68a5 100644
--- a/megatron/core/models/vision/multimodal_projector.py
+++ b/megatron/core/models/vision/multimodal_projector.py
@@ -61,9 +61,7 @@ def forward(self, hidden_states):
         # deallocate_output_tensor() throwing an error, so a viewless tensor is
         # created to prevent this.
         encoder_output = make_viewless_tensor(
-            inp=encoder_output,
-            requires_grad=True,
-            keep_graph=True,
+            inp=encoder_output, requires_grad=True, keep_graph=True
         )
 
         return encoder_output
diff --git a/megatron/core/models/vision/vit_layer_specs.py b/megatron/core/models/vision/vit_layer_specs.py
index a879d25398..876c14dce4 100644
--- a/megatron/core/models/vision/vit_layer_specs.py
+++ b/megatron/core/models/vision/vit_layer_specs.py
@@ -80,9 +80,7 @@ def get_vit_layer_with_local_spec() -> ModuleSpec:
 
 
 # Helper function to get module spec for MLP/MoE
-def _get_mlp_module_spec(
-    use_te: bool = True,
-) -> ModuleSpec:
+def _get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
     # Dense MLP w/ or w/o TE modules.
     return ModuleSpec(
         module=MLP,
diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py
index 04bffc8ff5..65f72ec8c8 100644
--- a/megatron/core/optimizer/__init__.py
+++ b/megatron/core/optimizer/__init__.py
@@ -247,12 +247,7 @@ def init_state_fn(opt):
                     hysteresis=config.hysteresis,
                 )
 
-        optimizer_args = [
-            optimizer,
-            config,
-            grad_scaler,
-            init_state_fn,
-        ]
+        optimizer_args = [optimizer, config, grad_scaler, init_state_fn]
         if config.use_distributed_optimizer:
             optimizer = DistributedOptimizer(
                 *optimizer_args,
@@ -266,11 +261,7 @@ def init_state_fn(opt):
             setattr(optimizer, 'model_parallel_group', model_parallel_group)
     else:
         # FP32 optimizer.
-        optimizer = FP32Optimizer(
-            optimizer,
-            config,
-            init_state_fn,
-        )
+        optimizer = FP32Optimizer(optimizer, config, init_state_fn)
         setattr(optimizer, 'model_parallel_group', model_parallel_group)
 
     return optimizer
diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py
index ee5551d616..8eee169c7b 100644
--- a/megatron/core/optimizer/distrib_optimizer.py
+++ b/megatron/core/optimizer/distrib_optimizer.py
@@ -168,9 +168,7 @@ def _build_model_gbuf_range(cls, param_and_grad_buffer: ParamAndGradBuffer, buck
         )
 
         # Group into dict.
-        data = {
-            "param_map": param_range_map,
-        }
+        data = {"param_map": param_range_map}
 
         return data
 
@@ -417,12 +415,7 @@ def __init__(
             HAVE_APEX_OR_TE
         ), f'Please install Apex or Transformer Engine to use DistributedOptimizer.'
 
-        super().__init__(
-            optimizer,
-            config,
-            grad_scaler,
-            init_state_fn,
-        )
+        super().__init__(optimizer, config, grad_scaler, init_state_fn)
 
         assert isinstance(
             optimizer, Adam
@@ -464,10 +457,9 @@ def __init__(
         self.model_param_gbuf_map = self._build_model_param_gbuf_map(self.gbuf_ranges)
 
         # Optimizer ranges.
-        (
-            self.model_param_group_index_map,
-            self.opt_group_ranges,
-        ) = self._build_optimizer_group_ranges(self.optimizer.param_groups, self.gbuf_ranges)
+        (self.model_param_group_index_map, self.opt_group_ranges) = (
+            self._build_optimizer_group_ranges(self.optimizer.param_groups, self.gbuf_ranges)
+        )
 
         # Allocate main param shards.
         (
@@ -626,10 +618,7 @@ def load_state_dict(self, state_dict):
         #   list.
         inner_state_dict = self.optimizer.state_dict()
         state_dict_param_groups = [
-            {
-                **group,
-                "params": list(inner_state_dict["param_groups"][idx]["params"]),
-            }
+            {**group, "params": list(inner_state_dict["param_groups"][idx]["params"])}
             for idx, group in enumerate(state_dict["optimizer"]["param_groups"])
         ]
 
@@ -655,13 +644,7 @@ def load_state_dict(self, state_dict):
                         )
 
                         state_dict_state.append(
-                            (
-                                state_order,
-                                {
-                                    "exp_avg": init_shard(),
-                                    "exp_avg_sq": init_shard(),
-                                },
-                            )
+                            (state_order, {"exp_avg": init_shard(), "exp_avg_sq": init_shard()})
                         )
 
         # Sort by state order (see method docstring for details).
@@ -680,10 +663,7 @@ def load_state_dict(self, state_dict):
 
         # Optimizer.
         self.optimizer.load_state_dict(
-            {
-                "state": state_dict_state,
-                "param_groups": state_dict_param_groups,
-            }
+            {"state": state_dict_state, "param_groups": state_dict_param_groups}
         )
 
         # Grad scaler.
@@ -776,9 +756,7 @@ def get_parameter_state_dp_zero(self):
         )
 
         # Collect param states.
-        state = {
-            "buckets_coalesced": True,
-        }
+        state = {"buckets_coalesced": True}
         for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
 
             # Iterate grad buffers (by data type).
@@ -822,10 +800,7 @@ def get_parameter_state_dp_zero(self):
                         main_param = self.optimizer.param_groups[group_index]["params"][group_order]
                         optim_state = self.optimizer.state[main_param]
 
-                        tensors = {
-                            "param": main_param,
-                            **optim_state,
-                        }
+                        tensors = {"param": main_param, **optim_state}
 
                         # Copy states into contiguous shard.
                         gbuf_local_start = param_range_map["gbuf_local"].start
@@ -1012,9 +987,7 @@ def sharded_param_state_fs_bucket_space(
                         if next_param_start != cur_param_end:
                             pad_tensors = {
                                 k: torch.empty(
-                                    next_param_start - cur_param_end,
-                                    dtype=v.dtype,
-                                    device=v.device,
+                                    next_param_start - cur_param_end, dtype=v.dtype, device=v.device
                                 )
                                 for k, v in bucket_state[i].items()
                                 if isinstance(v, torch.Tensor)
@@ -1112,10 +1085,7 @@ def sharded_param_state_fs_model_space(
                         main_param = self.optimizer.param_groups[group_index]["params"][group_order]
                         optim_state = self.optimizer.state[main_param]
 
-                        tensors = {
-                            "fp32_param": main_param,
-                            **optim_state,
-                        }
+                        tensors = {"fp32_param": main_param, **optim_state}
                         # Match optimizer parameter with model ShardedTensor (or ShardedTensorFactory)
                         try:
                             sharded_metadata = param_to_sharded_metadata[model_param]
@@ -1188,10 +1158,7 @@ def load_parameter_state_from_fs_bucket_space(self, state_dict):
                         main_param = self.optimizer.param_groups[group_index]["params"][group_order]
                         optim_state = self.optimizer.state[main_param]
 
-                        dst_tensors = {
-                            "param": main_param,
-                            **optim_state,
-                        }
+                        dst_tensors = {"param": main_param, **optim_state}
                         for key in dst_tensors:
                             dst_tensors[key].copy_(src_tensors[key])
 
@@ -1211,10 +1178,7 @@ def load_parameter_state_from_fs_model_space(self, state_dict):
                         optim_state = self.optimizer.state[main_param]
 
                         src_tensors = state_dict[param_idx]
-                        dst_tensors = {
-                            "fp32_param": main_param,
-                            **optim_state,
-                        }
+                        dst_tensors = {"fp32_param": main_param, **optim_state}
                         for key in dst_tensors:
                             dst_tensors[key].copy_(src_tensors[key])
 
@@ -1561,10 +1525,7 @@ def _dispatch_gather_model_params(self, all_gather_handle_index: int, force_sync
             ]
             assert all_gather_handle_index < len(self.all_gather_handles)
             all_gather_handle = torch.distributed._all_gather_base(
-                pbuf,
-                pbuf_views[data_parallel_rank],
-                group=data_parallel_group,
-                async_op=async_op,
+                pbuf, pbuf_views[data_parallel_rank], group=data_parallel_group, async_op=async_op
             )
             self.all_gather_handles[all_gather_handle_index] = all_gather_handle
             assert self.all_gather_handle_index_to_bucket_index_map[all_gather_handle_index] == (
diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py
index 3d6142d207..2a48c12d37 100644
--- a/megatron/core/optimizer/optimizer.py
+++ b/megatron/core/optimizer/optimizer.py
@@ -156,8 +156,7 @@ def step_with_ready_grads(self) -> bool:
     def get_grad_norm(self):
         grads_for_norm = self.get_main_grads_for_grad_norm()
         total_norm = get_grad_norm_fp32(
-            grads_for_norm,
-            model_parallel_group=self.get_model_parallel_group(),
+            grads_for_norm, model_parallel_group=self.get_model_parallel_group()
         )
         return total_norm
 
@@ -301,11 +300,7 @@ def __init__(
         if has_config_logger_enabled(config):
             log_config_to_disk(config, locals(), prefix=type(self).__name__)
 
-        super().__init__(
-            optimizer,
-            config,
-            init_state_fn,
-        )
+        super().__init__(optimizer, config, init_state_fn)
         self.grad_scaler = grad_scaler
 
         # None grad scaler is only supported for bf16.
@@ -477,12 +472,7 @@ def __init__(
         init_state_fn: Callable,
     ):
 
-        super().__init__(
-            optimizer,
-            config,
-            grad_scaler,
-            init_state_fn,
-        )
+        super().__init__(optimizer, config, grad_scaler, init_state_fn)
 
         # Handle main parameters.
 
@@ -713,19 +703,12 @@ class FP32Optimizer(MegatronOptimizer):
     """
 
     def __init__(
-        self,
-        optimizer: torch.optim.Optimizer,
-        config: OptimizerConfig,
-        init_state_fn: Callable,
+        self, optimizer: torch.optim.Optimizer, config: OptimizerConfig, init_state_fn: Callable
     ):
         if has_config_logger_enabled(config):
             log_config_to_disk(config, locals(), prefix=type(self).__name__)
 
-        super(FP32Optimizer, self).__init__(
-            optimizer,
-            config,
-            init_state_fn,
-        )
+        super(FP32Optimizer, self).__init__(optimizer, config, init_state_fn)
 
         self._scale = torch.tensor([1.0], dtype=torch.float, device='cuda')
 
diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index d271fab225..19c19ff5a1 100644
--- a/megatron/core/parallel_state.py
+++ b/megatron/core/parallel_state.py
@@ -118,9 +118,7 @@ def get_nccl_options(pg_name, nccl_comm_cfgs):
 
 
 def generate_masked_orthogonal_rank_groups(
-    world_size: int,
-    parallel_size: List[int],
-    mask: List[bool],
+    world_size: int, parallel_size: List[int], mask: List[bool]
 ) -> List[List[int]]:
     """Generate orthogonal parallel groups based on the parallel size and mask.
 
@@ -748,9 +746,7 @@ def generator_wrapper(group_type, **kwargs):
 
         embedding_ranks = get_embedding_ranks(ranks)
         group = torch.distributed.new_group(
-            embedding_ranks,
-            timeout=timeout,
-            pg_options=get_nccl_options('embd', nccl_comm_cfgs),
+            embedding_ranks, timeout=timeout, pg_options=get_nccl_options('embd', nccl_comm_cfgs)
         )
         if rank in embedding_ranks:
             _EMBEDDING_GROUP = group
@@ -871,10 +867,7 @@ def is_unitialized() -> bool:
     Deprecated. Use is_initialized instead.
 
     """
-    warnings.warn(
-        "is_unitialized is deprecated, use is_initialized instead",
-        DeprecationWarning,
-    )
+    warnings.warn("is_unitialized is deprecated, use is_initialized instead", DeprecationWarning)
     return not is_initialized()
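
`generate_masked_orthogonal_rank_groups`, reformatted earlier in this file, builds rank groups from a world size and per-dimension parallel sizes. The simplest special case is grouping consecutive ranks, sketched below with illustrative sizes (this is not the general masked algorithm):

    def consecutive_rank_groups(world_size, group_size):
        # e.g. consecutive_rank_groups(8, 2) -> [[0, 1], [2, 3], [4, 5], [6, 7]]
        return [
            list(range(start, start + group_size))
            for start in range(0, world_size, group_size)
        ]
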
 
 
diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py
index 137929a13e..3e33e7c2f8 100644
--- a/megatron/core/pipeline_parallel/p2p_communication.py
+++ b/megatron/core/pipeline_parallel/p2p_communication.py
@@ -131,34 +131,22 @@ def _batched_p2p_ops(
     ops = []
     if tensor_send_prev is not None:
         send_prev_op = torch.distributed.P2POp(
-            torch.distributed.isend,
-            tensor_send_prev,
-            prev_pipeline_rank,
-            group,
+            torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group
         )
         ops.append(send_prev_op)
     if tensor_recv_prev is not None:
         recv_prev_op = torch.distributed.P2POp(
-            torch.distributed.irecv,
-            tensor_recv_prev,
-            prev_pipeline_rank,
-            group,
+            torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group
         )
         ops.append(recv_prev_op)
     if tensor_send_next is not None:
         send_next_op = torch.distributed.P2POp(
-            torch.distributed.isend,
-            tensor_send_next,
-            next_pipeline_rank,
-            group,
+            torch.distributed.isend, tensor_send_next, next_pipeline_rank, group
         )
         ops.append(send_next_op)
     if tensor_recv_next is not None:
         recv_next_op = torch.distributed.P2POp(
-            torch.distributed.irecv,
-            tensor_recv_next,
-            next_pipeline_rank,
-            group,
+            torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group
         )
         ops.append(recv_next_op)
     if len(ops) > 0:
@@ -193,66 +181,50 @@ def _p2p_ops(
     if get_pipeline_model_parallel_rank() % 2 == 0:
         if tensor_send_next is not None:
             send_next_req = torch.distributed.isend(
-                tensor=tensor_send_next,
-                dst=next_pipeline_rank,
-                group=even_send_odd_recv_group,
+                tensor=tensor_send_next, dst=next_pipeline_rank, group=even_send_odd_recv_group
             )
             reqs.append(send_next_req)
 
         if tensor_recv_prev is not None:
             recv_prev_req = torch.distributed.irecv(
-                tensor=tensor_recv_prev,
-                src=prev_pipeline_rank,
-                group=even_recv_odd_send_group,
+                tensor=tensor_recv_prev, src=prev_pipeline_rank, group=even_recv_odd_send_group
             )
             reqs.append(recv_prev_req)
 
         if tensor_send_prev is not None:
             send_prev_req = torch.distributed.isend(
-                tensor=tensor_send_prev,
-                dst=prev_pipeline_rank,
-                group=even_send_odd_recv_group,
+                tensor=tensor_send_prev, dst=prev_pipeline_rank, group=even_send_odd_recv_group
             )
             reqs.append(send_prev_req)
 
         if tensor_recv_next is not None:
             recv_next_req = torch.distributed.irecv(
-                tensor=tensor_recv_next,
-                src=next_pipeline_rank,
-                group=even_recv_odd_send_group,
+                tensor=tensor_recv_next, src=next_pipeline_rank, group=even_recv_odd_send_group
             )
             reqs.append(recv_next_req)
 
     else:
         if tensor_recv_prev is not None:
             recv_prev_req = torch.distributed.irecv(
-                tensor=tensor_recv_prev,
-                src=prev_pipeline_rank,
-                group=even_send_odd_recv_group,
+                tensor=tensor_recv_prev, src=prev_pipeline_rank, group=even_send_odd_recv_group
             )
             reqs.append(recv_prev_req)
 
         if tensor_send_next is not None:
             send_next_req = torch.distributed.isend(
-                tensor=tensor_send_next,
-                dst=next_pipeline_rank,
-                group=even_recv_odd_send_group,
+                tensor=tensor_send_next, dst=next_pipeline_rank, group=even_recv_odd_send_group
             )
             reqs.append(send_next_req)
 
         if tensor_recv_next is not None:
             recv_next_req = torch.distributed.irecv(
-                tensor=tensor_recv_next,
-                src=next_pipeline_rank,
-                group=even_send_odd_recv_group,
+                tensor=tensor_recv_next, src=next_pipeline_rank, group=even_send_odd_recv_group
             )
             reqs.append(recv_next_req)
 
         if tensor_send_prev is not None:
             send_prev_req = torch.distributed.isend(
-                tensor=tensor_send_prev,
-                dst=prev_pipeline_rank,
-                group=even_recv_odd_send_group,
+                tensor=tensor_send_prev, dst=prev_pipeline_rank, group=even_recv_odd_send_group
             )
             reqs.append(send_prev_req)
     return reqs
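
The `_batched_p2p_ops` hunk earlier in this file builds `P2POp` objects and submits them together; a minimal sketch of that pattern for a single peer, assuming an initialized default process group and pre-allocated send/receive buffers:

    import torch.distributed as dist

    def exchange_with_peer(send_buf, recv_buf, peer_rank):
        ops = [
            dist.P2POp(dist.isend, send_buf, peer_rank),
            dist.P2POp(dist.irecv, recv_buf, peer_rank),
        ]
        # batch_isend_irecv returns one request per op; wait on all of them.
        for req in dist.batch_isend_irecv(ops):
            req.wait()
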
diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py
index 432420f63e..b7669ccb45 100644
--- a/megatron/core/pipeline_parallel/schedules.py
+++ b/megatron/core/pipeline_parallel/schedules.py
@@ -121,11 +121,7 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
         return
     assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
     assert out._base is None, "counter-productive to free a view of another tensor."
-    out.data = torch.empty(
-        (1,),
-        device=out.device,
-        dtype=out.dtype,
-    )
+    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)
 
 
 def custom_backward(output, grad_output):
@@ -146,10 +142,7 @@ def custom_backward(output, grad_output):
     # Handle scalar output
     if grad_output is None:
         assert output.numel() == 1, "implicit grad requires scalar output."
-        grad_output = torch.ones_like(
-            output,
-            memory_format=torch.preserve_format,
-        )
+        grad_output = torch.ones_like(output, memory_format=torch.preserve_format)
 
     # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
     Variable._execution_engine.run_backward(
@@ -752,9 +745,7 @@ def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activation
             collect_non_loss_data,
             checkpoint_activations_microbatch,
             check_first_val_step(
-                first_val_step,
-                forward_only,
-                is_first_microbatch_for_model_chunk(microbatch_id),
+                first_val_step, forward_only, is_first_microbatch_for_model_chunk(microbatch_id)
             ),
             current_microbatch=current_microbatch,
         )
@@ -863,16 +854,15 @@ def backward_step_helper(microbatch_id):
                 recv_next = True
                 if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                     recv_next = False
-                (
-                    input_tensor,
-                    output_tensor_grad,
-                ) = p2p_communication.send_forward_backward_recv_forward_backward(
-                    output_tensor,
-                    input_tensor_grad,
-                    recv_prev=recv_prev,
-                    recv_next=recv_next,
-                    tensor_shape=tensor_shape,
-                    config=config,
+                (input_tensor, output_tensor_grad) = (
+                    p2p_communication.send_forward_backward_recv_forward_backward(
+                        output_tensor,
+                        input_tensor_grad,
+                        recv_prev=recv_prev,
+                        recv_next=recv_next,
+                        tensor_shape=tensor_shape,
+                        config=config,
+                    )
                 )
                 output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
             else:
@@ -899,15 +889,14 @@ def backward_step_helper(microbatch_id):
                 if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                     recv_next = False
 
-                (
-                    output_tensor_grad,
-                    bwd_wait_handles,
-                ) = p2p_communication.send_backward_recv_backward(
-                    input_tensor_grad,
-                    recv_next=recv_next,
-                    tensor_shape=tensor_shape,
-                    config=config,
-                    overlap_p2p_comm=True,
+                (output_tensor_grad, bwd_wait_handles) = (
+                    p2p_communication.send_backward_recv_backward(
+                        input_tensor_grad,
+                        recv_next=recv_next,
+                        tensor_shape=tensor_shape,
+                        config=config,
+                        overlap_p2p_comm=True,
+                    )
                 )
 
                 output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
@@ -1073,16 +1062,15 @@ def backward_step_helper(microbatch_id):
                 recv_prev = False
 
             # Communicate tensors.
-            (
-                input_tensor,
-                output_tensor_grad,
-            ) = p2p_communication.send_forward_backward_recv_forward_backward(
-                output_tensor,
-                input_tensor_grad,
-                recv_prev=recv_prev,
-                recv_next=recv_next,
-                tensor_shape=tensor_shape,
-                config=config,
+            (input_tensor, output_tensor_grad) = (
+                p2p_communication.send_forward_backward_recv_forward_backward(
+                    output_tensor,
+                    input_tensor_grad,
+                    recv_prev=recv_prev,
+                    recv_next=recv_next,
+                    tensor_shape=tensor_shape,
+                    config=config,
+                )
             )
             deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
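
The `deallocate_output_tensor` hunk at the start of this file keeps the output tensor object (and its node in the autograd graph) alive while releasing its storage; a tiny sketch of that trick in isolation:

    import torch

    t = torch.randn(1024, 1024, requires_grad=True)
    out = (t * 2).clone()  # a non-view output, as the asserts above require
    # Swap in a 1-element buffer: the Python object survives, the large storage is freed.
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)
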
 
diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py
index ef444e8d2c..0bb9acce8d 100644
--- a/megatron/core/ssm/mamba_block.py
+++ b/megatron/core/ssm/mamba_block.py
@@ -146,12 +146,7 @@ def __init__(
                 eps=self.config.layernorm_epsilon,
             )
 
-        self.apply(
-            partial(
-                _init_weights,
-                n_layer=self.config.num_layers,
-            )
-        )
+        self.apply(partial(_init_weights, n_layer=self.config.num_layers))
 
     def _select_layers_for_pipeline_parallel(self, layer_type_list):
         pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
diff --git a/megatron/core/tensor_parallel/cross_entropy.py b/megatron/core/tensor_parallel/cross_entropy.py
index 45fa07515d..0066d126fd 100644
--- a/megatron/core/tensor_parallel/cross_entropy.py
+++ b/megatron/core/tensor_parallel/cross_entropy.py
@@ -80,8 +80,7 @@ def calculate_cross_entropy_loss(
 
     @staticmethod
     def prepare_gradient_calculation_operands(
-        softmax: torch.Tensor,
-        target_mask: torch.Tensor,
+        softmax: torch.Tensor, target_mask: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
 
         # All the inputs have softmax as their gradient.
@@ -133,14 +132,10 @@ def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
         world_size = get_tensor_model_parallel_world_size()
         vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
 
-        (
-            target_mask,
-            masked_target_1d,
-            predicted_logits,
-            sum_exp_logits,
-            exp_logits,
-        ) = VocabParallelCrossEntropy.calculate_predicted_logits(
-            vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
+        (target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits) = (
+            VocabParallelCrossEntropy.calculate_predicted_logits(
+                vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
+            )
         )
 
         # All reduce is needed to get the chunks from other GPUs.
@@ -193,12 +188,9 @@ def backward(ctx, grad_output):
         softmax, target_mask, masked_target_1d = ctx.saved_tensors
         label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size
 
-        (
-            grad_2d,
-            arange_1d,
-            softmax_update,
-            grad_input,
-        ) = VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask)
+        (grad_2d, arange_1d, softmax_update, grad_input) = (
+            VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask)
+        )
 
         if label_smoothing > 0:
             smoothing = label_smoothing * vocab_size / (vocab_size - 1)
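
As background for the `VocabParallelCrossEntropy` hunks above, here is a minimal, self-contained sketch (my illustration, not the Megatron-LM implementation; the helper name is hypothetical) of how targets outside a rank's vocabulary shard `[vocab_start_index, vocab_end_index)` can be masked and shifted to shard-local indices:

import torch

# Hypothetical helper for illustration only; not part of megatron.core.
def mask_targets_for_shard(target, vocab_start_index, vocab_end_index):
    # True where the target token belongs to another tensor-parallel rank's shard.
    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
    # Shift to shard-local indices; masked positions get a harmless placeholder.
    masked_target = target - vocab_start_index
    masked_target[target_mask] = 0
    return target_mask, masked_target

target = torch.tensor([3, 12, 7])
mask, local = mask_targets_for_shard(target, 8, 16)
# mask  -> tensor([ True, False,  True]);  local -> tensor([0, 4, 0])
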
diff --git a/megatron/core/tensor_parallel/data.py b/megatron/core/tensor_parallel/data.py
index 01dd90de51..c549f74d73 100644
--- a/megatron/core/tensor_parallel/data.py
+++ b/megatron/core/tensor_parallel/data.py
@@ -14,9 +14,10 @@
 def _check_data_types(keys, data, target_dtype):
     """Check that all the keys have the same target data type."""
     for key in keys:
-        assert data[key].dtype == target_dtype, (
-            '{} has data type {} which '
-            'is different than {}'.format(key, data[key].dtype, target_dtype)
+        assert (
+            data[key].dtype == target_dtype
+        ), '{} has data type {} which ' 'is different than {}'.format(
+            key, data[key].dtype, target_dtype
         )
 
 
diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py
index d644eb89ef..5707a0b529 100644
--- a/megatron/core/tensor_parallel/layers.py
+++ b/megatron/core/tensor_parallel/layers.py
@@ -179,11 +179,12 @@ def __init__(
         self.reduce_scatter_embeddings = reduce_scatter_embeddings
         self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
         # Divide the weight matrix along the vocabulary dimension.
-        (
-            self.vocab_start_index,
-            self.vocab_end_index,
-        ) = VocabUtility.vocab_range_from_global_vocab_size(
-            self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size
+        (self.vocab_start_index, self.vocab_end_index) = (
+            VocabUtility.vocab_range_from_global_vocab_size(
+                self.num_embeddings,
+                get_tensor_model_parallel_rank(),
+                self.tensor_model_parallel_size,
+            )
         )
         self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
         self.deterministic_mode = config.deterministic_mode
@@ -276,13 +277,7 @@ class LinearWithFrozenWeight(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd
-    def forward(
-        ctx,
-        input,
-        weight,
-        bias,
-        allreduce_dgrad,
-    ):
+    def forward(ctx, input, weight, bias, allreduce_dgrad):
         ctx.save_for_backward(weight)
         ctx.allreduce_dgrad = allreduce_dgrad
         output = torch.matmul(input, weight.t())
@@ -372,12 +367,7 @@ def linear_with_frozen_weight(
         )
         allreduce_dgrad = async_grad_allreduce
 
-    args = [
-        input,
-        weight,
-        bias,
-        allreduce_dgrad,
-    ]
+    args = [input, weight, bias, allreduce_dgrad]
 
     return LinearWithFrozenWeight.apply(*args)
 
diff --git a/megatron/core/tensor_parallel/mappings.py b/megatron/core/tensor_parallel/mappings.py
index 88e77541d1..3eed700ceb 100644
--- a/megatron/core/tensor_parallel/mappings.py
+++ b/megatron/core/tensor_parallel/mappings.py
@@ -368,7 +368,7 @@ def symbolic(graph, input_):
 
     @staticmethod
     def forward(ctx, input_):
-        return _gather_along_last_dim(input_,)
+        return _gather_along_last_dim(input_)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -384,7 +384,7 @@ def symbolic(graph, input_):
 
     @staticmethod
     def forward(ctx, input_):
-        return _reduce_scatter_along_last_dim(input_,)
+        return _reduce_scatter_along_last_dim(input_)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -514,7 +514,7 @@ def all_to_all_hp2sp(input_):
 
     Args:
         input_ (torch.Tensor): The input tensor which has been distributed along the hidden dimension.
-        
+
     Returns:
         torch.Tensor: The output tensor with shape [num_tokens/TP, H].
     """
diff --git a/megatron/core/tensor_parallel/utils.py b/megatron/core/tensor_parallel/utils.py
index 53f0d60de0..d7c191b411 100644
--- a/megatron/core/tensor_parallel/utils.py
+++ b/megatron/core/tensor_parallel/utils.py
@@ -14,18 +14,18 @@
 
 
 def split_tensor_along_last_dim(
-    tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False,
+    tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False
 ) -> List[torch.Tensor]:
-    """ Split a tensor along its last dimension.
+    """Split a tensor along its last dimension.
 
-        Args:
-            tensor: input tensor.
-            num_partitions: number of partitions to split the tensor
-            contiguous_split_chunks: If True, make each chunk contiguous
-                                     in memory.
+    Args:
+        tensor: input tensor.
+        num_partitions: number of partitions to split the tensor
+        contiguous_split_chunks: If True, make each chunk contiguous
+                                 in memory.
 
-        Returns:
-            A list of Tensors
+    Returns:
+        A list of Tensors
     """
     # Get the size and dimension.
     last_dim = tensor.dim() - 1
@@ -40,17 +40,17 @@ def split_tensor_along_last_dim(
 
 
 def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
-    """ Break a tensor into equal 1D chunks across tensor parallel ranks.
+    """Break a tensor into equal 1D chunks across tensor parallel ranks.
 
-        Returns a Tensor or View with this rank's portion of the data.
+    Returns a Tensor or View with this rank's portion of the data.
 
-        Args:
-            tensor: The tensor to split
+    Args:
+        tensor: The tensor to split
 
-        Keyword Args:
-            new_buffer (bool): If True, returns a new Tensor.
-                               If False, returns a view into the existing Tensor.
-                               Default is False
+    Keyword Args:
+        new_buffer (bool): If True, returns a new Tensor.
+                           If False, returns a view into the existing Tensor.
+                           Default is False
 
     """
     partition_size = torch.numel(tensor) // parallel_state.get_tensor_model_parallel_world_size()
@@ -70,13 +70,13 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
 
 
 def gather_split_1d_tensor(tensor):
-    """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
-        model parallel ranks.
+    """Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
+    model parallel ranks.
 
-        Returns a new Tensor with the gathered data.
+    Returns a new Tensor with the gathered data.
 
-        Args:
-            tensor: A Tensor or view of this rank's portion of the data.
+    Args:
+        tensor: A Tensor or view of this rank's portion of the data.
     """
     numel_gathered = torch.numel(tensor) * parallel_state.get_tensor_model_parallel_world_size()
     gathered = torch.empty(
@@ -94,9 +94,9 @@ def gather_split_1d_tensor(tensor):
 
 
 class VocabUtility:
-    """ Split the vocabulary into `world_size` chunks and return the first
-        and last index of the vocabulary belonging to the `rank`
-        partition: Note that indices in [fist, last)
+    """Split the vocabulary into `world_size` chunks and return the first
+    and last index of the vocabulary belonging to the `rank`
+    partition: Note that indices in [first, last)
 
     """
 
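
To make the half-open `[first, last)` convention in the `VocabUtility` docstring above concrete, here is a small sketch (an illustration that assumes the vocabulary divides evenly across tensor-parallel ranks; the function name is made up and is not the class's API):

def vocab_range_for_rank(global_vocab_size: int, rank: int, world_size: int):
    # Equal-sized contiguous chunk per rank; indices are half-open [first, last).
    per_partition = global_vocab_size // world_size
    first = rank * per_partition
    last = first + per_partition
    return first, last

# Example: a 50304-token padded vocabulary split across 8 ranks.
assert vocab_range_for_rank(50304, rank=3, world_size=8) == (18864, 25152)
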
diff --git a/megatron/core/timers.py b/megatron/core/timers.py
index b61eb4ed22..e7070e37d8 100644
--- a/megatron/core/timers.py
+++ b/megatron/core/timers.py
@@ -110,8 +110,7 @@ def stop(self, barrier=False):
         self._started = False
 
     def reset(self):
-        """Reset timer.
-        """
+        """Reset timer."""
         # Don't reset _active_time
         self._elapsed = 0.0
         self._started = False
@@ -145,14 +144,13 @@ def active_time(self):
 
 
 class Timers:
-    """Class for a group of Timers.
-    """
+    """Class for a group of Timers."""
 
     def __init__(self, log_level, log_option):
         """Initialize group of timers.
 
         Args:
-            log_level (int): Log level to control what timers are enabled.            
+            log_level (int): Log level to control what timers are enabled.
             log_option (str): Setting for logging statistics over ranks for all the timers. Allowed: ['max', 'minmax', 'all'].
         """
         self._log_level = log_level
@@ -351,7 +349,7 @@ def log(
         barrier: bool = False,
     ):
         """logs the timers passed in names to stdout. Example usage is to log average per step value for timer 'foo',
-          this function can be called with normalizer factor set to logging interval. 
+          this function can be called with normalizer factor set to logging interval.
 
         Args:
             names (List[str]): Names of the timers to log.
diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py
index 96c19d0fca..43eacf03f9 100644
--- a/megatron/core/transformer/attention.py
+++ b/megatron/core/transformer/attention.py
@@ -149,14 +149,7 @@ def custom_forward(*inputs):
             attn_mask_type = self.attn_mask_type
         attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int)
         hidden_states = tensor_parallel.checkpoint(
-            custom_forward,
-            False,
-            query,
-            key,
-            value,
-            attention_mask,
-            rotary_pos_emb,
-            attn_mask_type,
+            custom_forward, False, query, key, value, attention_mask, rotary_pos_emb, attn_mask_type
         )
 
         return hidden_states
@@ -289,17 +282,9 @@ def forward(
             else:
                 cu_seqlens_q = cu_seqlens_kv = None
             query = apply_rotary_pos_emb(
-                query,
-                q_pos_emb,
-                config=self.config,
-                cu_seqlens=cu_seqlens_q,
-            )
-            key = apply_rotary_pos_emb(
-                key,
-                k_pos_emb,
-                config=self.config,
-                cu_seqlens=cu_seqlens_kv,
+                query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q
             )
+            key = apply_rotary_pos_emb(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv)
 
             # TODO, can apply positional embedding to value_layer so it has
             # absolute positional embedding.
@@ -499,19 +484,11 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
         if SplitAlongDim is not None:
 
             # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-            (query, key, value) = SplitAlongDim(
-                mixed_qkv,
-                3,
-                split_arg_list,
-            )
+            (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list)
         else:
 
             # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-            (query, key, value) = torch.split(
-                mixed_qkv,
-                split_arg_list,
-                dim=3,
-            )
+            (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3)
 
         # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
         query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
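
For the QKV split reformatted above, a standalone illustration of `torch.split` with an uneven `split_arg_list` along dim 3 may help; the tensor sizes below are toy values chosen for the example, not the model's real dimensions:

import torch

# Toy fused-QKV tensor shaped [sq, b, ng, (np/ng + 2) * hn] with sq=2, b=1, ng=4, np/ng=2, hn=8.
mixed_qkv = torch.randn(2, 1, 4, (2 + 2) * 8)
split_arg_list = [2 * 8, 8, 8]  # query gets (np/ng) * hn columns, key and value get hn each

query, key, value = torch.split(mixed_qkv, split_arg_list, dim=3)
print(query.shape, key.shape, value.shape)
# torch.Size([2, 1, 4, 16]) torch.Size([2, 1, 4, 8]) torch.Size([2, 1, 4, 8])
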
diff --git a/megatron/core/transformer/custom_layers/transformer_engine.py b/megatron/core/transformer/custom_layers/transformer_engine.py
index 879547fc1b..4d73995bbd 100644
--- a/megatron/core/transformer/custom_layers/transformer_engine.py
+++ b/megatron/core/transformer/custom_layers/transformer_engine.py
@@ -39,9 +39,7 @@ def get_te_version_str():
 
 
 def _get_extra_te_kwargs(config: TransformerConfig):
-    extra_transformer_engine_kwargs = {
-        "params_dtype": config.params_dtype,
-    }
+    extra_transformer_engine_kwargs = {"params_dtype": config.params_dtype}
 
     if _te_version >= packaging.version.Version("0.12.0"):
         if config.use_cpu_initialization:
@@ -62,12 +60,7 @@ class TENorm:
     """
 
     # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
-    def __new__(
-        cls,
-        config: TransformerConfig,
-        hidden_size: int,
-        eps: float = 1e-5,
-    ):
+    def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5):
         if config.normalization == "LayerNorm":
             instance = te.pytorch.LayerNorm(
                 hidden_size=hidden_size,
@@ -559,13 +552,7 @@ def forward(
                 **packed_seq_kwargs,
             )
         else:
-            core_attn_out = super().forward(
-                query,
-                key,
-                value,
-                attention_mask,
-                **packed_seq_kwargs,
-            )
+            core_attn_out = super().forward(query, key, value, attention_mask, **packed_seq_kwargs)
 
         if self.config.apply_rope_fusion and qkv_format == 'bshd':
             return core_attn_out.transpose(0, 1)
@@ -767,12 +754,7 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
             """
             tp_axis_map = {}
             for gemm_idx in range(self.num_gemms):
-                tp_axis_map.update(
-                    {
-                        f'{gemm_idx}.weight': 0,
-                        f'{gemm_idx}.bias': 0,
-                    }
-                )
+                tp_axis_map.update({f'{gemm_idx}.weight': 0, f'{gemm_idx}.bias': 0})
             return super()._sharded_state_dict_grouped(
                 tp_axis_map, prefix, sharded_offsets, metadata
             )
diff --git a/megatron/core/transformer/dot_product_attention.py b/megatron/core/transformer/dot_product_attention.py
index 967d0ce8d8..7c28c153bc 100644
--- a/megatron/core/transformer/dot_product_attention.py
+++ b/megatron/core/transformer/dot_product_attention.py
@@ -120,12 +120,7 @@ def forward(
             )
 
         # [b, np, sq, sk]
-        output_size = (
-            query.size(1),
-            query.size(2),
-            query.size(0),
-            key.size(0),
-        )
+        output_size = (query.size(1), query.size(2), query.size(0), key.size(0))
 
         # [sq, b, np, hn] -> [sq, b * np, hn]
         # This will be a simple view when doing normal attention, but in group query attention
@@ -137,7 +132,7 @@ def forward(
 
         # preallocting input tensor: [b * np, sq, sk]
         matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
-            (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu",
+            (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu"
         )
 
         # Raw attention scores. [b * np, sq, sk]
@@ -176,12 +171,7 @@ def forward(
         # [sk, b, np, hn] --> [b, np, sq, hn]
 
         # context layer shape: [b, np, sq, hn]
-        output_size = (
-            value.size(1),
-            value.size(2),
-            query.size(0),
-            value.size(3),
-        )
+        output_size = (value.size(1), value.size(2), query.size(0), value.size(3))
 
         # change view [sk, b * np, hn]
         value = value.view(value.size(0), output_size[0] * output_size[1], -1)
diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py
index e11adf9447..d19ff6a234 100644
--- a/megatron/core/transformer/moe/experts.py
+++ b/megatron/core/transformer/moe/experts.py
@@ -94,9 +94,7 @@ def glu(x):
             )
             self.weight2 = Parameter(
                 torch.empty(
-                    fc2_input_size_per_partition,
-                    self.config.hidden_size,
-                    dtype=config.params_dtype,
+                    fc2_input_size_per_partition, self.config.hidden_size, dtype=config.params_dtype
                 )
             )
             if config.perform_initialization:
diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
index c0c10a2c58..da3bde82f5 100644
--- a/megatron/core/transformer/moe/moe_utils.py
+++ b/megatron/core/transformer/moe/moe_utils.py
@@ -270,9 +270,7 @@ def unpermute_with_padded_tokens(
 
     # Prepare a tensor of zeros with the desired output shape
     empty_tokens = torch.zeros(
-        restore_shape,
-        dtype=combined_output.dtype,
-        device=combined_output.device,
+        restore_shape, dtype=combined_output.dtype, device=combined_output.device
     )
 
     # Scatter the combined tokens back to their original positions
@@ -325,9 +323,7 @@ def topk_softmax_with_capacity(
     else:
         # TopK with capacity
         expert_capacity = get_capacity(
-            num_tokens=num_tokens * topk,
-            num_experts=num_experts,
-            capacity_factor=capacity_factor,
+            num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor
         )
         # TopK selection, Maskout unused experts
         topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
@@ -418,9 +414,7 @@ def reduce_aux_losses_tracker_across_ranks():
             torch.distributed.all_reduce(values, group=tracker[name].get('reduce_group'))
         if tracker[name].get('avg_group') is not None:
             torch.distributed.all_reduce(
-                values,
-                group=tracker[name]['avg_group'],
-                op=torch.distributed.ReduceOp.AVG,
+                values, group=tracker[name]['avg_group'], op=torch.distributed.ReduceOp.AVG
             )
 
 
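Relating to the `get_capacity` call reformatted above, a typical token-capacity rule can be sketched as follows; this is an assumption for illustration only (the exact formula, rounding, and any minimum cap used in Megatron-LM may differ):

import math

def capacity(num_tokens: int, num_experts: int, capacity_factor: float) -> int:
    # Each expert accepts at most its even share of tokens, scaled by the capacity factor.
    return math.ceil(num_tokens / num_experts * capacity_factor)

# Example: 4096 routed token slots, 8 experts, capacity factor 1.25.
assert capacity(4096, 8, 1.25) == 640
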
diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py
index a98959b710..817bfc0bdb 100644
--- a/megatron/core/transformer/moe/router.py
+++ b/megatron/core/transformer/moe/router.py
@@ -40,10 +40,7 @@ def __init__(self, config: TransformerConfig) -> None:
 
         # Initialize the gate weights.
         self.weight = torch.nn.Parameter(
-            torch.empty(
-                (self.config.num_moe_experts, self.config.hidden_size),
-                dtype=torch.float32,
-            )
+            torch.empty((self.config.num_moe_experts, self.config.hidden_size), dtype=torch.float32)
         )
         if config.perform_initialization:
             if get_cuda_rng_tracker().is_initialized():
@@ -99,10 +96,7 @@ def set_layer_number(self, layer_number: int):
 class TopKRouter(Router):
     """Route each token to the top-k experts."""
 
-    def __init__(
-        self,
-        config: TransformerConfig,
-    ) -> None:
+    def __init__(self, config: TransformerConfig) -> None:
         """Initialize the zero token dropping router.
 
         Args:
@@ -228,10 +222,7 @@ def apply_z_loss(self, logits):
             z_loss = z_loss_func(logits, moe_z_loss_coeff)
             logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
             save_to_aux_losses_tracker(
-                "z_loss",
-                z_loss / moe_z_loss_coeff,
-                self.layer_number,
-                self.config.num_layers,
+                "z_loss", z_loss / moe_z_loss_coeff, self.layer_number, self.config.num_layers
             )
         return logits
 
diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py
index 377403a5d7..c76ca6541e 100644
--- a/megatron/core/transformer/moe/token_dispatcher.py
+++ b/megatron/core/transformer/moe/token_dispatcher.py
@@ -23,11 +23,7 @@ def __init__(self, config: TransformerConfig) -> None:
         self.config = config
 
     @abstractmethod
-    def token_permutation(
-        self,
-        tokens: torch.Tensor,
-        indices: torch.Tensor,
-    ):
+    def token_permutation(self, tokens: torch.Tensor, indices: torch.Tensor):
         """Dispatch tokens to experts.
 
         Args:
@@ -41,10 +37,7 @@ def token_permutation(
 
     @abstractmethod
     def token_unpermutation(
-        self,
-        expert_output: torch.Tensor,
-        probs: torch.Tensor,
-        indices: torch.Tensor,
+        self, expert_output: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor
     ):
         """Restores the expert output to its original ordering.
 
@@ -65,10 +58,7 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
     """
 
     def __init__(
-        self,
-        num_local_experts: int,
-        local_expert_indices: List[int],
-        config: TransformerConfig,
+        self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig
     ) -> None:
         """
         Initialize the zero token dropping router.
@@ -163,8 +153,7 @@ def token_permutation(
             # The indices of local_indices that give its sorted order along dim 0.
             self.indices = torch.argsort(local_indices, dim=0)
             tokens_per_expert = torch.bincount(
-                local_indices.view(-1),
-                minlength=self.config.num_moe_experts,
+                local_indices.view(-1), minlength=self.config.num_moe_experts
             )
             if self.num_local_experts < self.config.num_moe_experts:
                 tokens_per_expert = tokens_per_expert[
@@ -179,16 +168,9 @@ def token_permutation(
             permuted_local_hidden_states = moe_gather.apply(local_hidden_states, self.indices)
         else:
             permuted_local_hidden_states = local_hidden_states
-        return (
-            permuted_local_hidden_states,
-            tokens_per_expert,
-        )
+        return (permuted_local_hidden_states, tokens_per_expert)
 
-    def token_unpermutation(
-        self,
-        hidden_states: torch.Tensor,
-        bias: torch.Tensor = None,
-    ):
+    def token_unpermutation(self, hidden_states: torch.Tensor, bias: torch.Tensor = None):
         """
         Reverse process of `dispatch()` which permutes the output of local
         experts locally and across expert parallel ranks into the original order to
@@ -299,10 +281,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
     """
 
     def __init__(
-        self,
-        num_local_experts: int,
-        local_expert_indices: List[int],
-        config: TransformerConfig,
+        self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig
     ) -> None:
         """
         Initialize the AlltoAll token dispatcher.
@@ -442,10 +421,7 @@ def preprocess(self, indices: torch.Tensor) -> torch.Tensor:
         return num_tokens_per_local_expert
 
     def token_permutation(
-        self,
-        hidden_states: torch.Tensor,
-        probs: torch.Tensor,
-        indices: torch.Tensor,
+        self, hidden_states: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Dispatch tokens to local experts using AlltoAll communication.
@@ -522,9 +498,7 @@ def token_permutation(
         return global_input_tokens, tokens_per_expert
 
     def token_unpermutation(
-        self,
-        hidden_states: torch.Tensor,
-        bias: torch.Tensor = None,
+        self, hidden_states: torch.Tensor, bias: torch.Tensor = None
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """
         Reverse the token permutation to restore the original order.
@@ -551,8 +525,7 @@ def token_unpermutation(
         if self.num_local_experts > 1:
             if not self.drop_and_pad:
                 hidden_states = unpermute(
-                    hidden_states,
-                    self.reversed_global_input_permutation_mapping,
+                    hidden_states, self.reversed_global_input_permutation_mapping
                 )
             else:
                 hidden_states = hidden_states.reshape(
diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py
index 8904e4b86f..1e90099a21 100755
--- a/megatron/core/transformer/transformer_block.py
+++ b/megatron/core/transformer/transformer_block.py
@@ -90,8 +90,7 @@ class TransformerBlockSubmodules:
 
 
 def _get_block_submodules(
-    config: TransformerConfig,
-    spec: Union[TransformerBlockSubmodules, ModuleSpec],
+    config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec]
 ) -> TransformerBlockSubmodules:
 
     # Transformer block submodules.
@@ -107,8 +106,7 @@ def _get_block_submodules(
         elif issubclass(spec.module, BaseTransformerLayer):
             num_layers = get_num_layers_to_build(config)
             return TransformerBlockSubmodules(
-                layer_specs=[spec] * num_layers,
-                layer_norm=LayerNormImpl,
+                layer_specs=[spec] * num_layers, layer_norm=LayerNormImpl
             )
         else:
             raise Exception(f"specialize for {spec.module.__name__}.")
@@ -146,15 +144,14 @@ def __init__(
         self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
 
         if get_cpu_offload_context is not None:
-            (
-                self.offload_context,
-                self.group_prefetch_offload_commit_async,
-            ) = get_cpu_offload_context(
-                self.config.cpu_offloading,
-                self.config.cpu_offloading_num_layers,
-                self.config.num_layers,
-                self.config.cpu_offloading_activations,
-                self.config.cpu_offloading_weights,
+            (self.offload_context, self.group_prefetch_offload_commit_async) = (
+                get_cpu_offload_context(
+                    self.config.cpu_offloading,
+                    self.config.cpu_offloading_num_layers,
+                    self.config.num_layers,
+                    self.config.cpu_offloading_activations,
+                    self.config.cpu_offloading_weights,
+                )
             )
             self.config._cpu_offloading_context = (
                 self.offload_context if self.config.cpu_offloading else None
@@ -178,11 +175,7 @@ def _build_layers(self):
         #     coeff = self.layer_number
         #     self.norm_factor *= coeff
         def build_layer(layer_spec, layer_number):
-            return build_module(
-                layer_spec,
-                config=self.config,
-                layer_number=layer_number,
-            )
+            return build_module(layer_spec, config=self.config, layer_number=layer_number)
 
         # offset is implicit in TransformerLayer
         self.layers = torch.nn.ModuleList(
@@ -235,11 +228,7 @@ def _checkpointed_forward(
 
         def custom(start: int, end: int):
             def custom_forward(
-                hidden_states,
-                attention_mask,
-                context,
-                context_mask,
-                rotary_pos_emb,
+                hidden_states, attention_mask, context, context_mask, rotary_pos_emb
             ):
                 for index in range(start, end):
                     layer = self._get_layer(index)
@@ -310,11 +299,7 @@ def checkpoint_handler(forward_func):
                     hidden_states, context = checkpoint_handler(custom(l, l + 1))
                 else:
                     hidden_states, context = custom(l, l + 1)(
-                        hidden_states,
-                        attention_mask,
-                        context,
-                        context_mask,
-                        rotary_pos_emb,
+                        hidden_states, attention_mask, context, context_mask, rotary_pos_emb
                     )
         else:
             raise ValueError("Invalid activation recompute method.")
@@ -363,11 +348,7 @@ def forward(
         #   likely redundant, since p2p_communication.py (likely originator)
         #   already creates viewless tensors. That said, make_viewless_tensor()
         #   is called here to be future-proof and corner-case-proof.
-        hidden_states = make_viewless_tensor(
-            inp=hidden_states,
-            requires_grad=True,
-            keep_graph=True,
-        )
+        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
 
         if self.config.sequence_parallel:
             rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
@@ -437,8 +418,7 @@ def forward(
                                 self.current_microbatch < len(self.cuda_graphs[l_no])
                             )
                             hidden_states = self.cuda_graphs[l_no][self.current_microbatch](
-                                hidden_states,
-                                is_first_microbatch=(self.current_microbatch == 0),
+                                hidden_states, is_first_microbatch=(self.current_microbatch == 0)
                             )
 
                     if (
@@ -455,9 +435,7 @@ def forward(
             # deallocate_output_tensor() throwing an error, so a viewless tensor is
             # created to prevent this.
             hidden_states = make_viewless_tensor(
-                inp=hidden_states,
-                requires_grad=True,
-                keep_graph=True,
+                inp=hidden_states, requires_grad=True, keep_graph=True
             )
 
         return hidden_states
diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py
index 631179ed08..703a291e83 100644
--- a/megatron/core/transformer/transformer_layer.py
+++ b/megatron/core/transformer/transformer_layer.py
@@ -36,7 +36,7 @@ class TransformerLayerSubmodules:
 
 
 class BaseTransformerLayer(ABC):
-    """ A common parent class for `TransformerLayer` like implementations.
+    """A common parent class for `TransformerLayer` like implementations.
 
     A dummy class that is subclassed by similar `TransformerLayer`s e.g. the
     `TransformerLayer` in this file and possibly other `TransformerLayer`
@@ -82,7 +82,7 @@ def __init__(
 
         ## [Module 2: SelfAttention]
         self.self_attention = build_module(
-            submodules.self_attention, config=self.config, layer_number=layer_number,
+            submodules.self_attention, config=self.config, layer_number=layer_number
         )
 
         ## [Module 3: BiasDropoutFusion]
@@ -98,11 +98,11 @@ def __init__(
 
         ## [Module 5: CrossAttention]
         self.cross_attention = build_module(
-            submodules.cross_attention, config=self.config, layer_number=layer_number,
+            submodules.cross_attention, config=self.config, layer_number=layer_number
         )
 
         ## [Module 6: BiasDropoutFusion]
-        self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config,)
+        self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config)
 
         ## [Module 7: Pre MLP] Optional Layernorm before MLP
         self.pre_mlp_layernorm = build_module(
diff --git a/megatron/core/transformer/utils.py b/megatron/core/transformer/utils.py
index 025f7c2b1e..4781b68d2a 100644
--- a/megatron/core/transformer/utils.py
+++ b/megatron/core/transformer/utils.py
@@ -97,12 +97,12 @@ def make_sharded_tensors_for_checkpoint(
         elif layer_name in tensor_parallel_layers_axis_map:
             tp_axis = tensor_parallel_layers_axis_map[layer_name]
             sharded_state_dict[layer_key] = make_tp_sharded_tensor_for_checkpoint(
-                tensor, layer_key, tp_axis, prepend_offsets=sharded_offsets,
+                tensor, layer_key, tp_axis, prepend_offsets=sharded_offsets
             )
 
         else:
             sharded_state_dict[layer_key] = make_sharded_tensor_for_checkpoint(
-                tensor, layer_key, prepend_offsets=sharded_offsets,
+                tensor, layer_key, prepend_offsets=sharded_offsets
             )
 
     return sharded_state_dict
@@ -115,7 +115,7 @@ def make_sharded_object_for_checkpoint(
     replica_id: Union[None, int, Tuple[int, ...]] = None,
     **kwargs,
 ):
-    """ Helper for instantiating a non-sharded ShardedObject (replicated across TP and DP group).
+    """Helper for instantiating a non-sharded ShardedObject (replicated across TP and DP group).
 
     Args:
         obj (object): any object to be sharded
@@ -138,7 +138,7 @@ def make_sharded_object_for_checkpoint(
 def _get_extra_state_offsets(
     sharded_offsets: Iterable[Tuple[int, int, int]]
 ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
-    """ Turns ShardedTensor offsets into offsets suitable for ShardedObject. """
+    """Turns ShardedTensor offsets into offsets suitable for ShardedObject."""
     if sharded_offsets:
         sharded_offsets = sorted(sharded_offsets, key=itemgetter(0))  # sort by axis
         axis, extra_state_offset, extra_state_shape = zip(*sharded_offsets)
@@ -183,6 +183,6 @@ def sharded_state_dict_default(
     else:
         module_sd = module.state_dict(prefix='', keep_vars=True)
         module_sharded_sd = make_sharded_tensors_for_checkpoint(
-            module_sd, prefix, {}, sharded_offsets,
+            module_sd, prefix, {}, sharded_offsets
         )
     return module_sharded_sd
diff --git a/megatron/core/utils.py b/megatron/core/utils.py
index a777770617..062372d97d 100644
--- a/megatron/core/utils.py
+++ b/megatron/core/utils.py
@@ -111,12 +111,7 @@ def _kernel_make_viewless_tensor(inp, requires_grad):
     data, without linking the viewed tensor, referenced via the '._base'
     field.
     '''
-    out = torch.empty(
-        (1,),
-        dtype=inp.dtype,
-        device=inp.device,
-        requires_grad=requires_grad,
-    )
+    out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad)
     out.data = inp.data
     return out
 
@@ -908,13 +903,7 @@ def report(self, total_flops: float = 0.0, log_interval: int = 0) -> bool:
             et_flops = apir_flops / self.amp  # Estimated TFLOPs, not tracing backward
 
             o_dt = self._min_max(
-                ptime,
-                btime,
-                float(temp),
-                float(power),
-                float(util),
-                float(clock),
-                et_flops,
+                ptime, btime, float(temp), float(power), float(util), float(clock), et_flops
             )
             if self.rank == 0 and o_dt is not None and o_dt.aflops is not None:
                 now = f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
diff --git a/pyproject.toml b/pyproject.toml
index 934745ec68..c707686a83 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ requires = [
 [tool.isort]
 profile = "black"  # black-compatible
 line_length = 100  # should match black parameters
-py_version = 38  # python 3.8 as a target version
+py_version = 310  # python 3.10 as a target version
 known_first_party = ["megatron"]  # FIRSTPARTY section
 known_third_party = ["transformer_engine"]  # THIRDPARTY section
 sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
diff --git a/tests/functional_tests/python_test_utils/common.py b/tests/functional_tests/python_test_utils/common.py
index 3ce43f095f..3a9fd359a6 100644
--- a/tests/functional_tests/python_test_utils/common.py
+++ b/tests/functional_tests/python_test_utils/common.py
@@ -10,10 +10,7 @@
 # Since we expect every step to be there when we do our comparisons, we explicitly
 # set the size guidance to 0 so that we load everything. It's okay given our tests
 # are small/short.
-SIZE_GUIDANCE = {
-    event_accumulator.TENSORS: 0,
-    event_accumulator.SCALARS: 0,
-}
+SIZE_GUIDANCE = {event_accumulator.TENSORS: 0, event_accumulator.SCALARS: 0}
 
 logger = logging.getLogger()
 
diff --git a/tests/functional_tests/python_test_utils/get_test_results_from_tensorboard_logs.py b/tests/functional_tests/python_test_utils/get_test_results_from_tensorboard_logs.py
index ba3d43f9c5..e93fd2046e 100644
--- a/tests/functional_tests/python_test_utils/get_test_results_from_tensorboard_logs.py
+++ b/tests/functional_tests/python_test_utils/get_test_results_from_tensorboard_logs.py
@@ -9,12 +9,7 @@
 
 
 @click.command()
-@click.option(
-    "--logs-dir",
-    required=True,
-    type=str,
-    help="Path to Tensorboard logs",
-)
+@click.option("--logs-dir", required=True, type=str, help="Path to Tensorboard logs")
 @click.option(
     "--output-path",
     required=False,
diff --git a/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py b/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py
index bf14f8ef75..f0375dfb3d 100644
--- a/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py
+++ b/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py
@@ -16,9 +16,7 @@
 def collect_train_test_metrics(logs_dir, index):
     train_loss_list = read_tb_logs_as_list(logs_dir, index)["lm loss"]
     train_loss_list = [round(elem, 3) for elem in train_loss_list]
-    train_metrics = {
-        "lm loss": train_loss_list[0 : len(train_loss_list) : STEP_INTERVAL],
-    }
+    train_metrics = {"lm loss": train_loss_list[0 : len(train_loss_list) : STEP_INTERVAL]}
     str_train_metrics = str(train_metrics).replace("'", '"')
     print("\n ----------- The following are the metrics for ----------")
     print(f"\n {str_train_metrics}", flush=True)
diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py
index 1d3c586a5d..38a9977640 100644
--- a/tests/unit_tests/__init__.py
+++ b/tests/unit_tests/__init__.py
@@ -1,2 +1,3 @@
 import torch._dynamo
-torch._dynamo.config.suppress_errors = True
\ No newline at end of file
+
+torch._dynamo.config.suppress_errors = True
diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py
index fb5cfc3ba4..787dd48c7a 100644
--- a/tests/unit_tests/conftest.py
+++ b/tests/unit_tests/conftest.py
@@ -13,9 +13,10 @@
 
 @pytest.fixture(scope="session")
 def tmp_path_dist_ckpt(tmp_path_factory) -> Path:
-    """ Common directory for saving the checkpoint.
+    """Common directory for saving the checkpoint.
 
-    Can't use pytest `tmp_path_factory` directly because directory must be shared between processes. """
+    Can't use pytest `tmp_path_factory` directly because directory must be shared between processes.
+    """
 
     tmp_dir = tmp_path_factory.mktemp('ignored', numbered=False)
     tmp_dir = tmp_dir.parent.parent / 'tmp_dist_ckpt'
diff --git a/tests/unit_tests/data/test_builder.py b/tests/unit_tests/data/test_builder.py
index 8f149dcffb..7f4caaa0f6 100644
--- a/tests/unit_tests/data/test_builder.py
+++ b/tests/unit_tests/data/test_builder.py
@@ -110,11 +110,7 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
         config = BlendedMegatronDatasetConfig(
             random_seed=1234,
             sequence_length=_SEQUENCE_LENGTH,
-            blend_per_split=[
-                blends[Split.train],
-                None,
-                None,
-            ],
+            blend_per_split=[blends[Split.train], None, None],
         )
         try:
             datasets = BlendedMegatronDatasetBuilder(
@@ -127,11 +123,7 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
         config = BlendedMegatronDatasetConfig(
             random_seed=1234,
             sequence_length=_SEQUENCE_LENGTH,
-            blend_per_split=[
-                get_blend_from_list([paths[Split.train][0]]),
-                None,
-                None,
-            ],
+            blend_per_split=[get_blend_from_list([paths[Split.train][0]]), None, None],
         )
         datasets = BlendedMegatronDatasetBuilder(
             TestDataset, [1000, None, None], lambda: True, config
@@ -187,11 +179,7 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
         config = BlendedMegatronDatasetConfig(
             random_seed=1234,
             sequence_length=_SEQUENCE_LENGTH,
-            blend_per_split=[
-                blends_unweighted[Split.train],
-                None,
-                None,
-            ],
+            blend_per_split=[blends_unweighted[Split.train], None, None],
         )
         datasets = BlendedMegatronDatasetBuilder(
             TestDataset, [1000, None, None], lambda: True, config
@@ -245,11 +233,7 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
             config = BlendedMegatronDatasetConfig(
                 random_seed=1234,
                 sequence_length=_SEQUENCE_LENGTH,
-                blend_per_split=[
-                    blends[Split.train],
-                    blends[Split.valid],
-                    blends[Split.test],
-                ],
+                blend_per_split=[blends[Split.train], blends[Split.valid], blends[Split.test]],
             )
             datasets = BlendedMegatronDatasetBuilder(
                 TestDataset, [100, 100, 100], lambda: True, config
diff --git a/tests/unit_tests/data/test_gpt_dataset.py b/tests/unit_tests/data/test_gpt_dataset.py
index 906a5728de..f10be883bf 100644
--- a/tests/unit_tests/data/test_gpt_dataset.py
+++ b/tests/unit_tests/data/test_gpt_dataset.py
@@ -96,7 +96,7 @@ def test_mock_gpt_dataset():
     assert torch.all(sample['labels'][argmax + 1 :] == 0)
     assert not torch.any(
         sample['loss_mask'][
-            torch.logical_and(sample['labels'] == tokenizer.eod, sample['labels'] == 0,)
+            torch.logical_and(sample['labels'] == tokenizer.eod, sample['labels'] == 0)
         ]
     )
 
diff --git a/tests/unit_tests/data/test_multimodal_dataset.py b/tests/unit_tests/data/test_multimodal_dataset.py
index ef5430c2da..a9a30c02ec 100644
--- a/tests/unit_tests/data/test_multimodal_dataset.py
+++ b/tests/unit_tests/data/test_multimodal_dataset.py
@@ -25,7 +25,7 @@ def test_mock_multimodal_dataset():
         torch.distributed.barrier()
     else:
         compile_helpers()
-        
+
     config = MultimodalDatasetConfig(
         random_seed=1234,
         sequence_length=1024,
diff --git a/tests/unit_tests/data/test_preprocess_data.py b/tests/unit_tests/data/test_preprocess_data.py
index 8d35e4c5c0..0b460f51a9 100644
--- a/tests/unit_tests/data/test_preprocess_data.py
+++ b/tests/unit_tests/data/test_preprocess_data.py
@@ -82,14 +82,12 @@ def do_test_preprocess_data(temp_dir, extra_args=[]):
     dummy_jsonl(path_to_raws)
 
     # build the datasets
-    build_datasets(
-        path_to_raws, path_to_data, extra_args=extra_args,
-    )
+    build_datasets(path_to_raws, path_to_data, extra_args=extra_args)
 
     # merge the datasets
     merge_datasets(path_to_data)
 
-    sys.argv = [sys.argv[0], "--input", None, "--output-prefix", None,] + extra_args
+    sys.argv = [sys.argv[0], "--input", None, "--output-prefix", None] + extra_args
     encoder = Encoder(build_args())
     encoder.initializer()
 
@@ -184,6 +182,7 @@ def gpt2_merge(odir):
         writer.write(requests.get(PRETRAINED_MERGES_ARCHIVE_MAP['gpt2']).content)
     return path
 
+
 @pytest.mark.skip(reason="Tests are flaky and need to be debugged")
 def test_preprocess_data_gpt():
     with tempfile.TemporaryDirectory() as temp_dir:
@@ -214,6 +213,7 @@ def bert_vocab(odir):
         writer.write(requests.get(__HUGGINGFACE_BERT_BASE_UNCASED_VOCAB).content)
     return path
 
+
 @pytest.mark.skip(reason="Tests are flaky and need to be debugged")
 def test_preprocess_data_bert():
     with tempfile.TemporaryDirectory() as temp_dir:
@@ -239,4 +239,4 @@ def test_preprocess_data_bert():
 
 if __name__ == "__main__":
     test_preprocess_data_gpt()
-    test_preprocess_data_bert()
\ No newline at end of file
+    test_preprocess_data_bert()
diff --git a/tests/unit_tests/data/test_preprocess_mmdata.py b/tests/unit_tests/data/test_preprocess_mmdata.py
index 8aab96e64a..d6ad4eddc7 100644
--- a/tests/unit_tests/data/test_preprocess_mmdata.py
+++ b/tests/unit_tests/data/test_preprocess_mmdata.py
@@ -74,9 +74,7 @@ def do_test_preprocess_mmdata(temp_dir, extra_args=[]):
     dummy_img(path_to_raws_txt, path_to_raws_img)
 
     # build the datasets
-    build_datasets(
-        path_to_raws_txt, path_to_raws_img, path_to_data, extra_args=extra_args,
-    )
+    build_datasets(path_to_raws_txt, path_to_raws_img, path_to_data, extra_args=extra_args)
 
     # merge the datasets
     merge_datasets(path_to_data)
diff --git a/tests/unit_tests/dist_checkpointing/__init__.py b/tests/unit_tests/dist_checkpointing/__init__.py
index 3b4a7896d7..d6c2701891 100644
--- a/tests/unit_tests/dist_checkpointing/__init__.py
+++ b/tests/unit_tests/dist_checkpointing/__init__.py
@@ -3,15 +3,15 @@
 from pathlib import Path
 from shutil import rmtree
 from tempfile import TemporaryDirectory
-from typing import Union, Optional
+from typing import Optional, Union
 
-from tests.unit_tests.test_utilities import Utils
 from tests.unit_tests.dist_checkpointing.utils import (
-    setup_model_and_optimizer,
     init_basic_mock_args,
     init_checkpointing_mock_args,
     initialize_gpt_model,
+    setup_model_and_optimizer,
 )
+from tests.unit_tests.test_utilities import Utils
 
 
 def empty_dir(path: Path):
@@ -25,23 +25,23 @@ def empty_dir(path: Path):
 
 
 class TempNamedDir(TemporaryDirectory):
-    """ TemporaryDirectory with a fully named directory. Empties the dir if not empty. """
-    def __init__(self, name: Union[str, Path], sync=True,
-                 ignore_cleanup_errors=False) -> None:
+    """TemporaryDirectory with a fully named directory. Empties the dir if not empty."""
+
+    def __init__(self, name: Union[str, Path], sync=True, ignore_cleanup_errors=False) -> None:
         self.name = str(name)
         if Utils.rank == 0:
             os.makedirs(name, exist_ok=True)
             empty_dir(Path(name))
         if sync:
             import torch
+
             torch.distributed.barrier()
         else:
             os.makedirs(name, exist_ok=True)
 
         self._ignore_cleanup_errors = ignore_cleanup_errors
         self._finalizer = weakref.finalize(
-            self, self._cleanup, self.name,
-            warn_message="Implicitly cleaning up {!r}".format(self)
+            self, self._cleanup, self.name, warn_message="Implicitly cleaning up {!r}".format(self)
         )
         self.sync = sync
 
@@ -49,6 +49,7 @@ def cleanup(self, override_sync: Optional[bool] = None) -> None:
         sync = self.sync if override_sync is None else override_sync
         if sync:
             import torch
+
             torch.distributed.barrier()
 
         if Utils.rank == 0:
@@ -58,6 +59,7 @@ def __enter__(self):
         path = Path(super().__enter__())
         if self.sync:
             import torch
+
             torch.distributed.barrier()
         return path
 
diff --git a/tests/unit_tests/dist_checkpointing/conftest.py b/tests/unit_tests/dist_checkpointing/conftest.py
index 655550d632..fed9cdb482 100644
--- a/tests/unit_tests/dist_checkpointing/conftest.py
+++ b/tests/unit_tests/dist_checkpointing/conftest.py
@@ -18,4 +18,3 @@ def get_pyt_dist_save_sharded_strategy():
         new=get_pyt_dist_save_sharded_strategy,
     ) as _fixture:
         yield _fixture
-
diff --git a/tests/unit_tests/dist_checkpointing/models/common.py b/tests/unit_tests/dist_checkpointing/models/common.py
index 4159a2a90c..4b908ba3fc 100644
--- a/tests/unit_tests/dist_checkpointing/models/common.py
+++ b/tests/unit_tests/dist_checkpointing/models/common.py
@@ -3,34 +3,45 @@
 
 import torch
 
-from megatron.core.dist_checkpointing import save, load, load_plain_tensors
 from megatron.core import parallel_state
+from megatron.core.dist_checkpointing import load, load_plain_tensors, save
 from megatron.core.dist_checkpointing.dict_utils import diff
-from megatron.core.dist_checkpointing.serialization import \
-    get_default_save_sharded_strategy, get_default_load_sharded_strategy
-from megatron.core.dist_checkpointing.strategies.fully_parallel import \
-    FullyParallelSaveStrategyWrapper, FullyParallelLoadStrategyWrapper
+from megatron.core.dist_checkpointing.serialization import (
+    get_default_load_sharded_strategy,
+    get_default_save_sharded_strategy,
+)
+from megatron.core.dist_checkpointing.strategies.fully_parallel import (
+    FullyParallelLoadStrategyWrapper,
+    FullyParallelSaveStrategyWrapper,
+)
 from megatron.core.dist_checkpointing.validation import StrictHandling
 from tests.unit_tests.dist_checkpointing import TempNamedDir
 from tests.unit_tests.test_utilities import Utils
 
 
 def common_test_simple_sharded_state_dict_save_load(
-    initialize_model_fn, tmp_path_dist_ckpt, src_layer_spec_fn, dst_layer_spec_fn):
-    """ Simple save and load sanity check, without any equality tests. """
+    initialize_model_fn, tmp_path_dist_ckpt, src_layer_spec_fn, dst_layer_spec_fn
+):
+    """Simple save and load sanity check, without any equality tests."""
     tp = 2
     pp = 4
     Utils.initialize_model_parallel(tp, pp)
-    gpt_model = initialize_model_fn(1, src_layer_spec_fn, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp)
+    gpt_model = initialize_model_fn(
+        1, src_layer_spec_fn, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp
+    )
     with TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model') as ckpt_dir:
         # Save
         sharded_state_dict = gpt_model.sharded_state_dict()
         save(sharded_state_dict, ckpt_dir)
 
         # Load
-        gpt_model = initialize_model_fn(2, dst_layer_spec_fn, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp)
+        gpt_model = initialize_model_fn(
+            2, dst_layer_spec_fn, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp
+        )
         sharded_state_dict = gpt_model.sharded_state_dict()
-        state_dict, missing_keys, unexpected_keys = load(sharded_state_dict, ckpt_dir, strict=StrictHandling.RETURN_ALL)
+        state_dict, missing_keys, unexpected_keys = load(
+            sharded_state_dict, ckpt_dir, strict=StrictHandling.RETURN_ALL
+        )
         # Potential mismatch is because of extra states which is ok
         assert all('_extra_state' in k for k in missing_keys)
         assert all('_extra_state' in k for k in unexpected_keys)
@@ -38,21 +49,37 @@ def common_test_simple_sharded_state_dict_save_load(
     Utils.destroy_model_parallel()
 
 
-def common_test_parallel_reconfiguration_e2e(initialize_model_fn, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp,
-                                      src_layer_spec_fn, dst_layer_spec_fn, use_fpsl,
-                                      load_order="tp-dp-pp", store_order="tp-dp-pp"):
-    """ Test model saving and loading with different TP/PP """
-    with TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_A') as ckpt_dir_A, \
-         TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_B') as ckpt_dir_B:
+def common_test_parallel_reconfiguration_e2e(
+    initialize_model_fn,
+    tmp_path_dist_ckpt,
+    src_tp_pp,
+    dest_tp_pp,
+    src_layer_spec_fn,
+    dst_layer_spec_fn,
+    use_fpsl,
+    load_order="tp-dp-pp",
+    store_order="tp-dp-pp",
+):
+    """Test model saving and loading with different TP/PP"""
+    with TempNamedDir(
+        tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_A'
+    ) as ckpt_dir_A, TempNamedDir(
+        tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_B'
+    ) as ckpt_dir_B:
         # Save checkpoint A
         Utils.initialize_model_parallel(*src_tp_pp, order=load_order)
-        gpt_model_A = initialize_model_fn(1, src_layer_spec_fn, tensor_model_parallel_size=src_tp_pp[0], pipeline_model_parallel_size=src_tp_pp[1])
+        gpt_model_A = initialize_model_fn(
+            1,
+            src_layer_spec_fn,
+            tensor_model_parallel_size=src_tp_pp[0],
+            pipeline_model_parallel_size=src_tp_pp[1],
+        )
         save_strategy = get_default_save_sharded_strategy()
         if use_fpsl:
             save_strategy = FullyParallelSaveStrategyWrapper(
                 save_strategy,
                 parallel_state.get_data_parallel_group(with_context_parallel=True),
-                True
+                True,
             )
         save(gpt_model_A.sharded_state_dict(), ckpt_dir_A, save_strategy)
         regular_state_dict_A = gpt_model_A.state_dict()
@@ -61,13 +88,23 @@ def common_test_parallel_reconfiguration_e2e(initialize_model_fn, tmp_path_dist_
         # Load checkpoint A with different TP/PP and save as checkpoint B
         # No FPS this time, only FPL
         Utils.initialize_model_parallel(*dest_tp_pp, order=store_order)
-        gpt_model_B = initialize_model_fn(2, dst_layer_spec_fn, tensor_model_parallel_size=dest_tp_pp[0], pipeline_model_parallel_size=dest_tp_pp[1])
+        gpt_model_B = initialize_model_fn(
+            2,
+            dst_layer_spec_fn,
+            tensor_model_parallel_size=dest_tp_pp[0],
+            pipeline_model_parallel_size=dest_tp_pp[1],
+        )
         if use_fpsl:
             load_strategy = get_default_load_sharded_strategy(ckpt_dir_A)
             load_strategy = FullyParallelLoadStrategyWrapper(load_strategy)
         else:
             load_strategy = None
-        state_dict, missing_keys, unexpected_keys = load(gpt_model_B.sharded_state_dict(), ckpt_dir_A, load_strategy, strict=StrictHandling.RETURN_ALL)
+        state_dict, missing_keys, unexpected_keys = load(
+            gpt_model_B.sharded_state_dict(),
+            ckpt_dir_A,
+            load_strategy,
+            strict=StrictHandling.RETURN_ALL,
+        )
         # Potential mismatch is because of extra states which is ok
         assert all('_extra_state' in k for k in missing_keys)
         assert all('_extra_state' in k for k in unexpected_keys)
@@ -84,10 +121,12 @@ def common_test_parallel_reconfiguration_e2e(initialize_model_fn, tmp_path_dist_
         assert not any(map(bool, diffs)), diffs
 
         # Test both regular state dicts are equal, turning FP8 states to bytes first
-        regular_state_dict_A = {k: v for k, v in regular_state_dict_A.items()
-                                if not k.endswith('_extra_state')}
-        regular_state_dict_B = {k: v for k, v in regular_state_dict_B.items()
-                                if not k.endswith('_extra_state')}
+        regular_state_dict_A = {
+            k: v for k, v in regular_state_dict_A.items() if not k.endswith('_extra_state')
+        }
+        regular_state_dict_B = {
+            k: v for k, v in regular_state_dict_B.items() if not k.endswith('_extra_state')
+        }
         diffs = diff(regular_state_dict_A, regular_state_dict_B)
         assert not any(map(bool, diffs)), diffs
         Utils.destroy_model_parallel()
@@ -97,11 +136,18 @@ def common_test_state_dict_comparison(initialize_model_fn, tmp_path_dist_ckpt):
     tp = 2
     pp = 4
     Utils.initialize_model_parallel(tp, pp)
-    with TempNamedDir(tmp_path_dist_ckpt / 'test_state_dict_comparison_A') as ckpt_dir_A, \
-         TempNamedDir(tmp_path_dist_ckpt / 'test_state_dict_comparison_B') as ckpt_dir_B:
-        gpt_model_A = initialize_model_fn(1, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp)
+    with TempNamedDir(
+        tmp_path_dist_ckpt / 'test_state_dict_comparison_A'
+    ) as ckpt_dir_A, TempNamedDir(
+        tmp_path_dist_ckpt / 'test_state_dict_comparison_B'
+    ) as ckpt_dir_B:
+        gpt_model_A = initialize_model_fn(
+            1, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp
+        )
         save(gpt_model_A.sharded_state_dict(), ckpt_dir_A)
-        gpt_model_B = initialize_model_fn(2, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp)
+        gpt_model_B = initialize_model_fn(
+            2, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp
+        )
         save(gpt_model_B.sharded_state_dict(), ckpt_dir_B)
 
         state_dict_A = load_plain_tensors(ckpt_dir_A)
@@ -114,13 +160,16 @@ def common_test_state_dict_comparison(initialize_model_fn, tmp_path_dist_ckpt):
 
         # Test that A *keys* match B *keys*, but the tensor contents differ
         only_left, only_right, mismatch = diff(state_dict_A, state_dict_B)
-        assert (not only_left and not only_right), (only_left, only_right)
+        assert not only_left and not only_right, (only_left, only_right)
         assert len(mismatch) == len(state_dict_A), (len(mismatch), (len(state_dict_A)))
     Utils.destroy_model_parallel()
 
 
-def common_test_vocab_size_padding_change(initialize_model_fn, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp):
-    """ Test model loading with different vocab size (caused by TP padding). """
+def common_test_vocab_size_padding_change(
+    initialize_model_fn, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp
+):
+    """Test model loading with different vocab size (caused by TP padding)."""
+
     def get_test_vocab_size(make_divisible_by=128):
         divisor = make_divisible_by * parallel_state.get_tensor_model_parallel_world_size()
         return int(math.ceil(vocab_size_base / divisor)) * divisor
@@ -131,17 +180,30 @@ def get_test_vocab_size(make_divisible_by=128):
         'embedding.word_embeddings.weight',
     }
 
-    with TempNamedDir(tmp_path_dist_ckpt / 'test_vocab_size_padding_change_A') as ckpt_dir_A, \
-         TempNamedDir(tmp_path_dist_ckpt / 'test_vocab_size_padding_change_B') as ckpt_dir_B:
+    with TempNamedDir(
+        tmp_path_dist_ckpt / 'test_vocab_size_padding_change_A'
+    ) as ckpt_dir_A, TempNamedDir(
+        tmp_path_dist_ckpt / 'test_vocab_size_padding_change_B'
+    ) as ckpt_dir_B:
         # Save checkpoint A
         Utils.initialize_model_parallel(*src_tp_pp)
-        gpt_model_A = initialize_model_fn(1, tensor_model_parallel_size=src_tp_pp[0], pipeline_model_parallel_size=src_tp_pp[1], vocab_size=get_test_vocab_size())
+        gpt_model_A = initialize_model_fn(
+            1,
+            tensor_model_parallel_size=src_tp_pp[0],
+            pipeline_model_parallel_size=src_tp_pp[1],
+            vocab_size=get_test_vocab_size(),
+        )
         save(gpt_model_A.sharded_state_dict(), ckpt_dir_A)
         Utils.destroy_model_parallel()
 
         # Load checkpoint A with different TP/PP and save as checkpoint B
         Utils.initialize_model_parallel(*dest_tp_pp)
-        gpt_model_B = initialize_model_fn(2, tensor_model_parallel_size=dest_tp_pp[0], pipeline_model_parallel_size=dest_tp_pp[1], vocab_size=get_test_vocab_size())
+        gpt_model_B = initialize_model_fn(
+            2,
+            tensor_model_parallel_size=dest_tp_pp[0],
+            pipeline_model_parallel_size=dest_tp_pp[1],
+            vocab_size=get_test_vocab_size(),
+        )
         state_dict = load(gpt_model_B.sharded_state_dict(), ckpt_dir_A)
         gpt_model_B.load_state_dict(state_dict)
         save(gpt_model_B.sharded_state_dict(), ckpt_dir_B)
@@ -156,7 +218,9 @@ def get_test_vocab_size(make_divisible_by=128):
             if vocab_layer_key in plain_state_dict_A:
                 ten_A = plain_state_dict_A.pop(vocab_layer_key)
                 ten_B = plain_state_dict_B.pop(vocab_layer_key)
-                assert torch.all(ten_A[:vocab_size_base] == ten_B[:vocab_size_base]), vocab_layer_key
+                assert torch.all(
+                    ten_A[:vocab_size_base] == ten_B[:vocab_size_base]
+                ), vocab_layer_key
 
         # Test other tensors are equal
         diffs = diff(plain_state_dict_A, plain_state_dict_B)
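The padding behaviour this test relies on is purely arithmetic: the vocabulary is rounded up to a multiple of make_divisible_by times the tensor-model-parallel size, so checkpoints written under different TP degrees only agree on the first vocab_size_base rows. A self-contained sketch of that rounding (the helper name is illustrative):

    import math

    def padded_vocab_size(vocab_size_base, tp_size, make_divisible_by=128):
        # Mirrors get_test_vocab_size() above: round the base vocab up to a
        # multiple of make_divisible_by * tensor-model-parallel world size.
        divisor = make_divisible_by * tp_size
        return int(math.ceil(vocab_size_base / divisor)) * divisor

    # E.g. a 17-token vocab pads to 128 rows under TP=1 but to 1024 rows under
    # TP=8; only the first 17 rows are expected to match across checkpoints.
    assert padded_vocab_size(17, 1) == 128
    assert padded_vocab_size(17, 8) == 1024
    assert padded_vocab_size(128, 2) == 256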
diff --git a/tests/unit_tests/dist_checkpointing/models/test_bert_model.py b/tests/unit_tests/dist_checkpointing/models/test_bert_model.py
index 74af0bc674..e4838faa3d 100644
--- a/tests/unit_tests/dist_checkpointing/models/test_bert_model.py
+++ b/tests/unit_tests/dist_checkpointing/models/test_bert_model.py
@@ -22,20 +22,35 @@
 from tests.unit_tests.test_utilities import Utils
 
 
-def initialize_bert_model(seed, layer_spec_fn=bert_layer_with_transformer_engine_spec, vocab_size=128, **config_kwargs):
+def initialize_bert_model(
+    seed, layer_spec_fn=bert_layer_with_transformer_engine_spec, vocab_size=128, **config_kwargs
+):
     os.environ['NVTE_ALLOW_NONDETERMINISTIC_ALGO'] = '0'
     torch.manual_seed(seed)
     model_parallel_cuda_manual_seed(seed)
 
     layer_spec = layer_spec_fn() if callable(layer_spec_fn) else layer_spec_fn
 
-    default_config_kwargs=dict(num_layers=8, hidden_size=16, num_attention_heads=8, use_cpu_initialization=True, pipeline_dtype=torch.bfloat16)
+    default_config_kwargs = dict(
+        num_layers=8,
+        hidden_size=16,
+        num_attention_heads=8,
+        use_cpu_initialization=True,
+        pipeline_dtype=torch.bfloat16,
+    )
     default_config_kwargs.update(**config_kwargs)
     transformer_config = TransformerConfig(**default_config_kwargs)
     pre_process = ps.is_pipeline_first_stage()
     post_process = ps.is_pipeline_last_stage()
-    model = BertModel(config=transformer_config, transformer_layer_spec=layer_spec, vocab_size=vocab_size, max_sequence_length=4,
-                     pre_process=pre_process, post_process=post_process, num_tokentypes=0)
+    model = BertModel(
+        config=transformer_config,
+        transformer_layer_spec=layer_spec,
+        vocab_size=vocab_size,
+        max_sequence_length=4,
+        pre_process=pre_process,
+        post_process=post_process,
+        num_tokentypes=0,
+    )
 
     with torch.no_grad():
         for p in model.parameters():
@@ -44,53 +59,95 @@ def initialize_bert_model(seed, layer_spec_fn=bert_layer_with_transformer_engine
 
 
 class TestBertModel:
-    @pytest.mark.parametrize('src_layer_spec', [bert_layer_with_transformer_engine_spec, bert_layer_local_spec])
-    @pytest.mark.parametrize('dst_layer_spec', [bert_layer_with_transformer_engine_spec, bert_layer_local_spec])
-    def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt,
-                                          src_layer_spec, dst_layer_spec):
-        common_test_simple_sharded_state_dict_save_load(initialize_bert_model, tmp_path_dist_ckpt,
-                                                        src_layer_spec, dst_layer_spec)
+    @pytest.mark.parametrize(
+        'src_layer_spec', [bert_layer_with_transformer_engine_spec, bert_layer_local_spec]
+    )
+    @pytest.mark.parametrize(
+        'dst_layer_spec', [bert_layer_with_transformer_engine_spec, bert_layer_local_spec]
+    )
+    def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, src_layer_spec, dst_layer_spec):
+        common_test_simple_sharded_state_dict_save_load(
+            initialize_bert_model, tmp_path_dist_ckpt, src_layer_spec, dst_layer_spec
+        )
 
 
 class TestBERTModelReconfiguration:
     def setup_method(self, method):
         pass
-    
+
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
-        
+
     @pytest.mark.parametrize(
         ('use_fpsl', 'src_tp_pp', 'dest_tp_pp', 'src_layer_spec', 'dst_layer_spec'),
         [
-            (False, (2, 4), (4, 2), bert_layer_with_transformer_engine_spec, bert_layer_with_transformer_engine_spec),
-            (False, (1, 8), (8, 1), bert_layer_with_transformer_engine_spec, bert_layer_with_transformer_engine_spec),
-            (True, (2, 1), (1, 8), bert_layer_with_transformer_engine_spec, bert_layer_with_transformer_engine_spec),
-            (False, (1, 1), (2, 2), bert_layer_with_transformer_engine_spec, bert_layer_with_transformer_engine_spec),
+            (
+                False,
+                (2, 4),
+                (4, 2),
+                bert_layer_with_transformer_engine_spec,
+                bert_layer_with_transformer_engine_spec,
+            ),
+            (
+                False,
+                (1, 8),
+                (8, 1),
+                bert_layer_with_transformer_engine_spec,
+                bert_layer_with_transformer_engine_spec,
+            ),
+            (
+                True,
+                (2, 1),
+                (1, 8),
+                bert_layer_with_transformer_engine_spec,
+                bert_layer_with_transformer_engine_spec,
+            ),
+            (
+                False,
+                (1, 1),
+                (2, 2),
+                bert_layer_with_transformer_engine_spec,
+                bert_layer_with_transformer_engine_spec,
+            ),
             (True, (2, 1), (1, 8), bert_layer_local_spec, bert_layer_local_spec),
             (True, (1, 1), (2, 4), bert_layer_with_transformer_engine_spec, bert_layer_local_spec),
             (False, (1, 8), (2, 1), bert_layer_local_spec, bert_layer_with_transformer_engine_spec),
-        ]
+        ],
     )
-    def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp,
-                                          src_layer_spec, dst_layer_spec, use_fpsl):
-        """ Test model saving and loading with different TP/PP """
+    def test_parallel_reconfiguration_e2e(
+        self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, src_layer_spec, dst_layer_spec, use_fpsl
+    ):
+        """Test model saving and loading with different TP/PP"""
         Utils.initialize_model_parallel(src_tp_pp[0], src_tp_pp[1])
-                                        
-        common_test_parallel_reconfiguration_e2e(initialize_bert_model, tmp_path_dist_ckpt, src_tp_pp,
-                                                 dest_tp_pp, src_layer_spec, dst_layer_spec, use_fpsl)
+
+        common_test_parallel_reconfiguration_e2e(
+            initialize_bert_model,
+            tmp_path_dist_ckpt,
+            src_tp_pp,
+            dest_tp_pp,
+            src_layer_spec,
+            dst_layer_spec,
+            use_fpsl,
+        )
 
     def test_state_dict_comparison(self, tmp_path_dist_ckpt):
         common_test_state_dict_comparison(initialize_bert_model, tmp_path_dist_ckpt)
 
-    @pytest.mark.parametrize("vocab_size_base,src_tp_pp,dest_tp_pp", [
-        (128, (2, 4), (4, 2)),
-        (17, (1, 8), (8, 1)),
-        (127, (1, 8), (8, 1)),
-        (31123, (1, 1), (1, 8)),
-        (17, (1, 1), (1, 8)),
-    ])
-    def test_vocab_size_padding_change(self, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp):
-        """ Test model loading with different vocab size (caused by TP padding). """
+    @pytest.mark.parametrize(
+        "vocab_size_base,src_tp_pp,dest_tp_pp",
+        [
+            (128, (2, 4), (4, 2)),
+            (17, (1, 8), (8, 1)),
+            (127, (1, 8), (8, 1)),
+            (31123, (1, 1), (1, 8)),
+            (17, (1, 1), (1, 8)),
+        ],
+    )
+    def test_vocab_size_padding_change(
+        self, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp
+    ):
+        """Test model loading with different vocab size (caused by TP padding)."""
         Utils.initialize_model_parallel(src_tp_pp[0], src_tp_pp[1])
-        common_test_vocab_size_padding_change(initialize_bert_model, tmp_path_dist_ckpt, vocab_size_base,
-                                              src_tp_pp, dest_tp_pp)
+        common_test_vocab_size_padding_change(
+            initialize_bert_model, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp
+        )
diff --git a/tests/unit_tests/dist_checkpointing/models/test_gpt_model.py b/tests/unit_tests/dist_checkpointing/models/test_gpt_model.py
index b044ff15c7..20699d4500 100644
--- a/tests/unit_tests/dist_checkpointing/models/test_gpt_model.py
+++ b/tests/unit_tests/dist_checkpointing/models/test_gpt_model.py
@@ -23,13 +23,25 @@ def initialize_gpt_model(seed, layer_spec_fn=gpt_te_spec, vocab_size=128, **conf
     torch.manual_seed(seed)
     model_parallel_cuda_manual_seed(seed)
 
-    default_config_kwargs=dict(num_layers=8, hidden_size=16, num_attention_heads=8, use_cpu_initialization=True, pipeline_dtype=torch.bfloat16)
+    default_config_kwargs = dict(
+        num_layers=8,
+        hidden_size=16,
+        num_attention_heads=8,
+        use_cpu_initialization=True,
+        pipeline_dtype=torch.bfloat16,
+    )
     default_config_kwargs.update(**config_kwargs)
     transformer_config = TransformerConfig(**default_config_kwargs)
     pre_process = ps.is_pipeline_first_stage()
     post_process = ps.is_pipeline_last_stage()
-    model = GPTModel(config=transformer_config, transformer_layer_spec=layer_spec_fn(), vocab_size=vocab_size, max_sequence_length=4,
-                     pre_process=pre_process, post_process=post_process)
+    model = GPTModel(
+        config=transformer_config,
+        transformer_layer_spec=layer_spec_fn(),
+        vocab_size=vocab_size,
+        max_sequence_length=4,
+        pre_process=pre_process,
+        post_process=post_process,
+    )
 
     with torch.no_grad():
         for p in model.parameters():
@@ -40,53 +52,86 @@ def initialize_gpt_model(seed, layer_spec_fn=gpt_te_spec, vocab_size=128, **conf
 class TestGPTModel:
     @pytest.mark.parametrize('src_layer_spec_fn', [gpt_te_spec, gpt_local_spec])
     @pytest.mark.parametrize('dst_layer_spec_fn', [gpt_te_spec, gpt_local_spec])
-    def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt,
-                                          src_layer_spec_fn, dst_layer_spec_fn):
-        common_test_simple_sharded_state_dict_save_load(initialize_gpt_model, tmp_path_dist_ckpt,
-                                                        src_layer_spec_fn, dst_layer_spec_fn)
+    def test_sharded_state_dict_save_load(
+        self, tmp_path_dist_ckpt, src_layer_spec_fn, dst_layer_spec_fn
+    ):
+        common_test_simple_sharded_state_dict_save_load(
+            initialize_gpt_model, tmp_path_dist_ckpt, src_layer_spec_fn, dst_layer_spec_fn
+        )
 
 
 class TestGPTModelReconfiguration:
     def setup_method(self, method):
         pass
-    
+
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
 
     @pytest.mark.parametrize(
-        ('use_fpsl', 'load_order', 'store_order', 'src_tp_pp', 'dest_tp_pp', 'src_layer_spec_fn', 'dst_layer_spec_fn'),
+        (
+            'use_fpsl',
+            'load_order',
+            'store_order',
+            'src_tp_pp',
+            'dest_tp_pp',
+            'src_layer_spec_fn',
+            'dst_layer_spec_fn',
+        ),
         [
             (False, 'tp-dp-pp', 'tp-dp-pp', (2, 4), (4, 2), gpt_te_spec, gpt_te_spec),
             (False, 'tp-pp-dp', 'tp-pp-dp', (1, 8), (8, 1), gpt_te_spec, gpt_te_spec),
-            (True,  'tp-dp-pp', 'tp-pp-dp', (2, 1), (1, 8), gpt_te_spec, gpt_te_spec),
+            (True, 'tp-dp-pp', 'tp-pp-dp', (2, 1), (1, 8), gpt_te_spec, gpt_te_spec),
             (False, 'tp-dp-pp', 'tp-dp-pp', (1, 1), (2, 2), gpt_te_spec, gpt_te_spec),
-            (True,  'tp-pp-dp', 'tp-pp-dp', (2, 1), (1, 8), gpt_local_spec, gpt_local_spec),
+            (True, 'tp-pp-dp', 'tp-pp-dp', (2, 1), (1, 8), gpt_local_spec, gpt_local_spec),
             (False, 'tp-dp-pp', 'tp-pp-dp', (1, 1), (2, 4), gpt_te_spec, gpt_local_spec),
-            (True,  'tp-dp-pp', 'tp-dp-pp', (2, 4), (4, 2), gpt_local_spec, gpt_te_spec),
+            (True, 'tp-dp-pp', 'tp-dp-pp', (2, 4), (4, 2), gpt_local_spec, gpt_te_spec),
             (False, 'tp-pp-dp', 'tp-pp-dp', (2, 1), (1, 8), gpt_te_spec, gpt_local_spec),
             (False, 'tp-dp-pp', 'tp-pp-dp', (2, 4), (2, 4), gpt_local_spec, gpt_local_spec),
-        ]
+        ],
     )
-    def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp,
-                                          src_layer_spec_fn, dst_layer_spec_fn, use_fpsl, load_order, store_order):
-        """ Test model saving and loading with different TP/PP """
+    def test_parallel_reconfiguration_e2e(
+        self,
+        tmp_path_dist_ckpt,
+        src_tp_pp,
+        dest_tp_pp,
+        src_layer_spec_fn,
+        dst_layer_spec_fn,
+        use_fpsl,
+        load_order,
+        store_order,
+    ):
+        """Test model saving and loading with different TP/PP"""
         Utils.initialize_model_parallel(src_tp_pp[0], src_tp_pp[1])
-        common_test_parallel_reconfiguration_e2e(initialize_gpt_model, tmp_path_dist_ckpt, src_tp_pp,
-                                                 dest_tp_pp, src_layer_spec_fn, dst_layer_spec_fn, use_fpsl, load_order, store_order)
-
+        common_test_parallel_reconfiguration_e2e(
+            initialize_gpt_model,
+            tmp_path_dist_ckpt,
+            src_tp_pp,
+            dest_tp_pp,
+            src_layer_spec_fn,
+            dst_layer_spec_fn,
+            use_fpsl,
+            load_order,
+            store_order,
+        )
 
     def test_state_dict_comparison(self, tmp_path_dist_ckpt):
         common_test_state_dict_comparison(initialize_gpt_model, tmp_path_dist_ckpt)
 
-    @pytest.mark.parametrize("vocab_size_base,src_tp_pp,dest_tp_pp", [
-        (128, (2, 4), (4, 2)),
-        (17, (1, 8), (8, 1)),
-        (127, (1, 8), (8, 1)),
-        (31123, (1, 1), (1, 8)),
-        (17, (1, 1), (1, 8)),
-    ])
-    def test_vocab_size_padding_change(self, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp):
-        """ Test model loading with different vocab size (caused by TP padding). """
+    @pytest.mark.parametrize(
+        "vocab_size_base,src_tp_pp,dest_tp_pp",
+        [
+            (128, (2, 4), (4, 2)),
+            (17, (1, 8), (8, 1)),
+            (127, (1, 8), (8, 1)),
+            (31123, (1, 1), (1, 8)),
+            (17, (1, 1), (1, 8)),
+        ],
+    )
+    def test_vocab_size_padding_change(
+        self, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp
+    ):
+        """Test model loading with different vocab size (caused by TP padding)."""
         Utils.initialize_model_parallel(src_tp_pp[0], src_tp_pp[1])
-        common_test_vocab_size_padding_change(initialize_gpt_model, tmp_path_dist_ckpt, vocab_size_base,
-                                              src_tp_pp, dest_tp_pp)
+        common_test_vocab_size_padding_change(
+            initialize_gpt_model, tmp_path_dist_ckpt, vocab_size_base, src_tp_pp, dest_tp_pp
+        )
diff --git a/tests/unit_tests/dist_checkpointing/models/test_grouped_mlp.py b/tests/unit_tests/dist_checkpointing/models/test_grouped_mlp.py
index df0005e1a3..1bab7ce54b 100644
--- a/tests/unit_tests/dist_checkpointing/models/test_grouped_mlp.py
+++ b/tests/unit_tests/dist_checkpointing/models/test_grouped_mlp.py
@@ -30,8 +30,15 @@ def initialize_grouped_mlp(seed, glu=True, **config_kwargs):
     pp_size = parallel_state.get_pipeline_model_parallel_world_size()
     num_moe_experts = 8
     num_local_experts = num_moe_experts // parallel_state.get_expert_model_parallel_world_size()
-    default_config_kwargs = dict(num_layers=pp_size, hidden_size=12, num_attention_heads=4, num_moe_experts=num_moe_experts, use_cpu_initialization=True,
-                                 gated_linear_unit=glu, add_bias_linear=False)
+    default_config_kwargs = dict(
+        num_layers=pp_size,
+        hidden_size=12,
+        num_attention_heads=4,
+        num_moe_experts=num_moe_experts,
+        use_cpu_initialization=True,
+        gated_linear_unit=glu,
+        add_bias_linear=False,
+    )
     default_config_kwargs.update(**config_kwargs)
     transformer_config = TransformerConfig(**default_config_kwargs)
     model = GroupedMLP(num_local_experts, transformer_config)
@@ -47,36 +54,44 @@ def get_pp_offsets():
 class TestGroupedMLPReconfiguration:
     def setup_method(self, method):
         pass
-    
+
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
 
-    @pytest.mark.parametrize("use_fpsl,src_tp_pp_exp,dest_tp_pp_exp,use_glu", [
-        # changing PP is impossible because the number of layers must be the same
-        (False, (2, 4, 1), (2, 4, 1), False),
-        (True,  (2, 4, 1), (2, 4, 1), False),
-        (False, (1, 1, 1), (1, 1, 1), False),
-        (True,  (1, 1, 1), (1, 1, 4), False),
-        (False, (1, 1, 8), (1, 1, 2), False),
-        (False, (2, 2, 2), (4, 2, 1), False),
-        (True,  (1, 1, 4), (8, 1, 1), False),
-        (False, (1, 8, 1), (1, 8, 1), False),
-        (False, (1, 1, 4), (2, 1, 1), False),
-        (False, (1, 1, 1), (1, 1, 1), True),
-        (False, (1, 1, 1), (1, 1, 4), True),
-        (True,  (1, 1, 1), (2, 1, 1), True),
-        (False, (1, 1, 4), (8, 1, 1), True),
-        (True,  (2, 1, 4), (1, 1, 8), True),
-        (False, (2, 1, 4), (1, 1, 8), True),
-    ])
-    def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, use_fpsl):
-        """ Test model saving and loading with different TP/PP/expert parallelism """
+    @pytest.mark.parametrize(
+        "use_fpsl,src_tp_pp_exp,dest_tp_pp_exp,use_glu",
+        [
+            # changing PP is impossible because the number of layers must be the same
+            (False, (2, 4, 1), (2, 4, 1), False),
+            (True, (2, 4, 1), (2, 4, 1), False),
+            (False, (1, 1, 1), (1, 1, 1), False),
+            (True, (1, 1, 1), (1, 1, 4), False),
+            (False, (1, 1, 8), (1, 1, 2), False),
+            (False, (2, 2, 2), (4, 2, 1), False),
+            (True, (1, 1, 4), (8, 1, 1), False),
+            (False, (1, 8, 1), (1, 8, 1), False),
+            (False, (1, 1, 4), (2, 1, 1), False),
+            (False, (1, 1, 1), (1, 1, 1), True),
+            (False, (1, 1, 1), (1, 1, 4), True),
+            (True, (1, 1, 1), (2, 1, 1), True),
+            (False, (1, 1, 4), (8, 1, 1), True),
+            (True, (2, 1, 4), (1, 1, 8), True),
+            (False, (2, 1, 4), (1, 1, 8), True),
+        ],
+    )
+    def test_parallel_reconfiguration_e2e(
+        self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, use_fpsl
+    ):
+        """Test model saving and loading with different TP/PP/expert parallelism"""
         src_tp, src_pp, src_exp = src_tp_pp_exp
         dest_tp, dest_pp, dest_exp = dest_tp_pp_exp
         Utils.initialize_model_parallel(src_tp, src_pp, expert_model_parallel_size=src_exp)
-        
-        with TempNamedDir(tmp_path_dist_ckpt / 'test_grouped_mlp_reconfiguration_model_A') as ckpt_dir_A, \
-             TempNamedDir(tmp_path_dist_ckpt / 'test_grouped_mlp_reconfiguration_model_B') as ckpt_dir_B:
+
+        with TempNamedDir(
+            tmp_path_dist_ckpt / 'test_grouped_mlp_reconfiguration_model_A'
+        ) as ckpt_dir_A, TempNamedDir(
+            tmp_path_dist_ckpt / 'test_grouped_mlp_reconfiguration_model_B'
+        ) as ckpt_dir_B:
             # Save checkpoint A
             model_A = initialize_grouped_mlp(1, use_glu)
             sharded_state_dict = model_A.sharded_state_dict(sharded_offsets=get_pp_offsets())
@@ -86,7 +101,7 @@ def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp_exp, d
                 save_strategy = FullyParallelSaveStrategyWrapper(
                     save_strategy,
                     parallel_state.get_data_parallel_group(with_context_parallel=True),
-                    True
+                    True,
                 )
             save(sharded_state_dict, ckpt_dir_A, save_strategy)
             Utils.destroy_model_parallel()
@@ -97,11 +112,17 @@ def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp_exp, d
             model_B = initialize_grouped_mlp(2, use_glu)
             if use_fpsl:
                 load_strategy = get_default_load_sharded_strategy(ckpt_dir_A)
-                load_strategy = FullyParallelLoadStrategyWrapper(load_strategy,
-                                                                 parallel_state.get_data_parallel_group(with_context_parallel=True))
+                load_strategy = FullyParallelLoadStrategyWrapper(
+                    load_strategy,
+                    parallel_state.get_data_parallel_group(with_context_parallel=True),
+                )
             else:
                 load_strategy = None
-            state_dict = load(model_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_A, load_strategy)
+            state_dict = load(
+                model_B.sharded_state_dict(sharded_offsets=get_pp_offsets()),
+                ckpt_dir_A,
+                load_strategy,
+            )
             model_B.load_state_dict(state_dict)
             save(model_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_B)
             Utils.destroy_model_parallel()
@@ -114,41 +135,51 @@ def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp_exp, d
             assert not any(map(bool, diffs)), diffs
             Utils.destroy_model_parallel()
 
-    @pytest.mark.parametrize("src_module,src_tp_pp_exp,dest_tp_pp_exp,use_glu", [
-        # changing PP is impossible because the number of layers must be the same
-        ('sequential', (2, 4, 1), (2, 4, 1), False),
-        ('sequential', (1, 1, 1), (1, 1, 4), False),
-        ('sequential', (2, 2, 2), (4, 2, 1), False),
-        ('sequential', (1, 1, 4), (8, 1, 1), False),
-        ('sequential', (2, 1, 4), (1, 1, 8), False),
-        ('sequential', (2, 4, 1), (2, 4, 1), True),
-        ('sequential', (1, 1, 1), (1, 1, 4), True),
-        ('sequential', (2, 2, 2), (4, 2, 1), True),
-        ('sequential', (1, 1, 4), (8, 1, 1), True),
-        ('sequential', (2, 1, 4), (1, 1, 8), True),
-        ('grouped', (2, 4, 1), (2, 4, 1), False),
-        ('grouped', (1, 1, 1), (1, 1, 4), False),
-        ('grouped', (2, 2, 2), (4, 2, 1), False),
-        ('grouped', (1, 1, 4), (8, 1, 1), False),
-        ('grouped', (2, 1, 4), (1, 1, 8), False),
-        ('grouped', (2, 4, 1), (2, 4, 1), True),
-        ('grouped', (1, 1, 1), (1, 1, 4), True),
-        ('grouped', (2, 2, 2), (4, 2, 1), True),
-        ('grouped', (1, 1, 4), (8, 1, 1), True),
-        ('grouped', (2, 1, 4), (1, 1, 8), True),
-    ])
-    def test_sequential_grouped_mlp_interchangeable(self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, src_module):
-        """ Test model saving and loading with different TP/PP/expert parallelism """
+    @pytest.mark.parametrize(
+        "src_module,src_tp_pp_exp,dest_tp_pp_exp,use_glu",
+        [
+            # changing PP is impossible because the number of layers must be the same
+            ('sequential', (2, 4, 1), (2, 4, 1), False),
+            ('sequential', (1, 1, 1), (1, 1, 4), False),
+            ('sequential', (2, 2, 2), (4, 2, 1), False),
+            ('sequential', (1, 1, 4), (8, 1, 1), False),
+            ('sequential', (2, 1, 4), (1, 1, 8), False),
+            ('sequential', (2, 4, 1), (2, 4, 1), True),
+            ('sequential', (1, 1, 1), (1, 1, 4), True),
+            ('sequential', (2, 2, 2), (4, 2, 1), True),
+            ('sequential', (1, 1, 4), (8, 1, 1), True),
+            ('sequential', (2, 1, 4), (1, 1, 8), True),
+            ('grouped', (2, 4, 1), (2, 4, 1), False),
+            ('grouped', (1, 1, 1), (1, 1, 4), False),
+            ('grouped', (2, 2, 2), (4, 2, 1), False),
+            ('grouped', (1, 1, 4), (8, 1, 1), False),
+            ('grouped', (2, 1, 4), (1, 1, 8), False),
+            ('grouped', (2, 4, 1), (2, 4, 1), True),
+            ('grouped', (1, 1, 1), (1, 1, 4), True),
+            ('grouped', (2, 2, 2), (4, 2, 1), True),
+            ('grouped', (1, 1, 4), (8, 1, 1), True),
+            ('grouped', (2, 1, 4), (1, 1, 8), True),
+        ],
+    )
+    def test_sequential_grouped_mlp_interchangeable(
+        self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, src_module
+    ):
+        """Test model saving and loading with different TP/PP/expert parallelism"""
         src_tp, src_pp, src_exp = src_tp_pp_exp
         dest_tp, dest_pp, dest_exp = dest_tp_pp_exp
         Utils.initialize_model_parallel(src_tp, src_pp, expert_model_parallel_size=src_exp)
 
-        with TempNamedDir(tmp_path_dist_ckpt / 'test_sequential_grouped_mlp_interchangeable_model_A') as ckpt_dir_A, \
-             TempNamedDir(tmp_path_dist_ckpt / 'test_sequential_grouped_mlp_interchangeable_model_B') as ckpt_dir_B:
+        with TempNamedDir(
+            tmp_path_dist_ckpt / 'test_sequential_grouped_mlp_interchangeable_model_A'
+        ) as ckpt_dir_A, TempNamedDir(
+            tmp_path_dist_ckpt / 'test_sequential_grouped_mlp_interchangeable_model_B'
+        ) as ckpt_dir_B:
             # Save checkpoint A
-            
+
             if src_module == 'sequential':
-                model_A = initialize_expert_layer(1, use_glu, add_bias_linear=False, moe_grouped_gemm=False)
+                model_A = initialize_expert_layer(
+                    1, use_glu, add_bias_linear=False, moe_grouped_gemm=False
+                )
             else:
                 model_A = initialize_grouped_mlp(1, use_glu)
             sharded_state_dict = model_A.sharded_state_dict(sharded_offsets=get_pp_offsets())
@@ -161,9 +192,15 @@ def test_sequential_grouped_mlp_interchangeable(self, tmp_path_dist_ckpt, src_tp
             if src_module == 'sequential':
                 model_B = initialize_grouped_mlp(1, use_glu)
             else:
-                model_B = initialize_expert_layer(1, use_glu, add_bias_linear=False, moe_grouped_gemm=False)
+                model_B = initialize_expert_layer(
+                    1, use_glu, add_bias_linear=False, moe_grouped_gemm=False
+                )
             load_strategy = None
-            state_dict = load(model_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_A, load_strategy)
+            state_dict = load(
+                model_B.sharded_state_dict(sharded_offsets=get_pp_offsets()),
+                ckpt_dir_A,
+                load_strategy,
+            )
             model_B.load_state_dict(state_dict)
             save(model_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_B)
             Utils.destroy_model_parallel()
@@ -174,4 +211,4 @@ def test_sequential_grouped_mlp_interchangeable(self, tmp_path_dist_ckpt, src_tp
             state_dict_B = load_plain_tensors(ckpt_dir_B)
             diffs = diff(state_dict_A, state_dict_B)
             assert not any(map(bool, diffs)), diffs
-            Utils.destroy_model_parallel()
\ No newline at end of file
+            Utils.destroy_model_parallel()
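The mirror image of the fully parallel save used above is the load path, where the default load strategy is wrapped together with the data-parallel group so each rank reads only a subset of shards and the rest is exchanged. A brief sketch under the same assumptions as before (import locations assumed, distributed setup required):

    # Sketch only: assumes the same environment as the tests (torch.distributed
    # and Megatron model parallelism initialized) and that ckpt_dir holds a
    # checkpoint saved with dist_checkpointing.save.
    from megatron.core import parallel_state
    from megatron.core.dist_checkpointing import load
    from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
    from megatron.core.dist_checkpointing.strategies.fully_parallel import (
        FullyParallelLoadStrategyWrapper,
    )

    def load_fully_parallel(sharded_state_dict, ckpt_dir):
        # Each data-parallel rank loads a disjoint subset of the shards, then the
        # results are exchanged, instead of every rank reading everything itself.
        strategy = FullyParallelLoadStrategyWrapper(
            get_default_load_sharded_strategy(ckpt_dir),
            parallel_state.get_data_parallel_group(with_context_parallel=True),
        )
        return load(sharded_state_dict, ckpt_dir, strategy)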
diff --git a/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py b/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py
index 04148a44d4..1a0851039a 100644
--- a/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py
+++ b/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py
@@ -22,9 +22,16 @@
 def initialize_mlp(glu=True):
     model_parallel_cuda_manual_seed(123)
     pp_size = parallel_state.get_pipeline_model_parallel_world_size()
-    transformer_config = TransformerConfig(num_layers=pp_size, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True,
-                                           gated_linear_unit=glu)
-    return MLP(transformer_config, get_gpt_layer_with_transformer_engine_spec().submodules.mlp.submodules)
+    transformer_config = TransformerConfig(
+        num_layers=pp_size,
+        hidden_size=12,
+        num_attention_heads=4,
+        use_cpu_initialization=True,
+        gated_linear_unit=glu,
+    )
+    return MLP(
+        transformer_config, get_gpt_layer_with_transformer_engine_spec().submodules.mlp.submodules
+    )
 
 
 def get_pp_offsets():
@@ -36,23 +43,29 @@ def get_pp_offsets():
 class TestParallelMLPWithGLU:
     def setup_method(self, method):
         pass
-    
+
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
-        
-    @pytest.mark.parametrize("src_tp_pp,dest_tp_pp", [
-        # changing PP is impossible because the number of layers must be the same
-        ((2, 2), (4, 2)),
-        ((1, 1), (8, 1)),
-        ((1, 8), (1, 8)),
-        ((1, 1), (2, 1)),
-    ])
+
+    @pytest.mark.parametrize(
+        "src_tp_pp,dest_tp_pp",
+        [
+            # changing PP is impossible because the number of layers must be the same
+            ((2, 2), (4, 2)),
+            ((1, 1), (8, 1)),
+            ((1, 8), (1, 8)),
+            ((1, 1), (2, 1)),
+        ],
+    )
     def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp):
-        """ Test module saving and loading with different TP/PP """
+        """Test module saving and loading with different TP/PP"""
         Utils.initialize_model_parallel(*src_tp_pp)
-        
-        with TempNamedDir(tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_A') as ckpt_dir_A, \
-             TempNamedDir(tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_B') as ckpt_dir_B:
+
+        with TempNamedDir(
+            tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_A'
+        ) as ckpt_dir_A, TempNamedDir(
+            tmp_path_dist_ckpt / 'test_mlp_glu_reconfiguration_model_B'
+        ) as ckpt_dir_B:
             # Save checkpoint A
             mlp_A = initialize_mlp()
             save(mlp_A.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_A)
@@ -61,7 +74,9 @@ def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp, dest_
             # Load checkpoint A with different TP/PP and save as checkpoint B
             Utils.initialize_model_parallel(*dest_tp_pp)
             mlp_B = initialize_mlp()
-            state_dict = load(mlp_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_A)
+            state_dict = load(
+                mlp_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_A
+            )
             mlp_B.load_state_dict(state_dict)
             save(mlp_B.sharded_state_dict(sharded_offsets=get_pp_offsets()), ckpt_dir_B)
             Utils.destroy_model_parallel()
diff --git a/tests/unit_tests/dist_checkpointing/models/test_retro_model.py b/tests/unit_tests/dist_checkpointing/models/test_retro_model.py
index 013543def2..cf972f0c53 100644
--- a/tests/unit_tests/dist_checkpointing/models/test_retro_model.py
+++ b/tests/unit_tests/dist_checkpointing/models/test_retro_model.py
@@ -18,7 +18,7 @@ def initialize_retro_model(seed, decoder_spec_fn, spec_type, num_layers=9, **con
     torch.manual_seed(seed)
     model_parallel_cuda_manual_seed(seed)
 
-    default_config_kwargs=dict(
+    default_config_kwargs = dict(
         num_layers=num_layers,
         hidden_size=16,
         num_attention_heads=12,
@@ -35,11 +35,17 @@ def initialize_retro_model(seed, decoder_spec_fn, spec_type, num_layers=9, **con
     pre_process = ps.is_pipeline_first_stage()
     post_process = ps.is_pipeline_last_stage()
 
-
-    de_block_spec = decoder_spec_fn(retro_config, use_transformer_engine=True if spec_type=="te" else False)
-    model = RetroModel(config=retro_config, transformer_layer_spec=de_block_spec,
-                       pre_process=pre_process, post_process=post_process,
-                       vocab_size=29184, max_sequence_length=4)
+    de_block_spec = decoder_spec_fn(
+        retro_config, use_transformer_engine=(spec_type == "te")
+    )
+    model = RetroModel(
+        config=retro_config,
+        transformer_layer_spec=de_block_spec,
+        pre_process=pre_process,
+        post_process=post_process,
+        vocab_size=29184,
+        max_sequence_length=4,
+    )
 
     with torch.no_grad():
         for p in model.parameters():
@@ -50,14 +56,16 @@ def initialize_retro_model(seed, decoder_spec_fn, spec_type, num_layers=9, **con
 class TestRetroModel:
     def setup_method(self, method):
         pass
-    
+
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
-        
+
     @pytest.mark.parametrize('src_spec_type', ['te', 'local'])
     @pytest.mark.parametrize('dst_spec_type', ['te', 'local'])
     @pytest.mark.parametrize('model_type', ['retro'])
-    def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, src_spec_type, dst_spec_type, model_type):
+    def test_sharded_state_dict_save_load(
+        self, tmp_path_dist_ckpt, src_spec_type, dst_spec_type, model_type
+    ):
         decoder_spec_fn = get_retro_decoder_block_spec
 
         Utils.initialize_model_parallel(1, 1)
@@ -71,7 +79,9 @@ def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, src_spec_type, d
             gpt_model = initialize_retro_model(2, decoder_spec_fn, dst_spec_type)
             sharded_state_dict = gpt_model.sharded_state_dict()
 
-            state_dict, missing_keys, unexpected_keys = load(sharded_state_dict, ckpt_dir, strict=StrictHandling.RETURN_ALL)
+            state_dict, missing_keys, unexpected_keys = load(
+                sharded_state_dict, ckpt_dir, strict=StrictHandling.RETURN_ALL
+            )
             # Potential mismatches are due to extra states, which is OK
             assert all('_extra_state' in k for k in missing_keys)
             assert all('_extra_state' in k for k in unexpected_keys)
diff --git a/tests/unit_tests/dist_checkpointing/models/test_sequential_mlp.py b/tests/unit_tests/dist_checkpointing/models/test_sequential_mlp.py
index 0bc07298a4..111e982a35 100644
--- a/tests/unit_tests/dist_checkpointing/models/test_sequential_mlp.py
+++ b/tests/unit_tests/dist_checkpointing/models/test_sequential_mlp.py
@@ -26,6 +26,7 @@
 
 _te_version = packaging.version.Version(version("transformer-engine"))
 
+
 def initialize_expert_layer(seed, glu=True, moe_grouped_gemm=False, **config_kwargs):
     torch.manual_seed(seed)
     model_parallel_cuda_manual_seed(seed)
@@ -62,17 +63,19 @@ def get_pp_offsets():
     pp_size = parallel_state.get_pipeline_model_parallel_world_size()
     return ((0, pp_rank, pp_size),)
 
+
 moe_grouped_gemm_options = [False]
 if _te_version >= packaging.version.Version("1.9.0.dev0"):
     moe_grouped_gemm_options.append(True)
 
+
 class TestExpertLayerReconfiguration:
     def setup_method(self, method):
         pass
-    
+
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
-        
+
     @pytest.mark.parametrize(
         "use_fpsl,src_tp_pp_exp,dest_tp_pp_exp,use_glu",
         [
@@ -96,7 +99,7 @@ def teardown_method(self, method):
     def test_parallel_reconfiguration_e2e(
         self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, use_fpsl, moe_grouped_gemm
     ):
-        """ Test model saving and loading with different TP/PP/expert parallelism """
+        """Test model saving and loading with different TP/PP/expert parallelism"""
         src_tp, src_pp, src_exp = src_tp_pp_exp
         dest_tp, dest_pp, dest_exp = dest_tp_pp_exp
         # Save checkpoint A
@@ -180,7 +183,7 @@ def test_parallel_reconfiguration_e2e(
     def test_sequential_grouped_mlp_interchangeable(
         self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, src_module
     ):
-        """ Test model saving and loading with different TP/PP/expert parallelism """
+        """Test model saving and loading with different TP/PP/expert parallelism"""
         src_tp, src_pp, src_exp = src_tp_pp_exp
         dest_tp, dest_pp, dest_exp = dest_tp_pp_exp
         # Save checkpoint A
@@ -190,7 +193,7 @@ def test_sequential_grouped_mlp_interchangeable(
         ) as ckpt_dir_A, TempNamedDir(
             tmp_path_dist_ckpt / 'test_sequential_grouped_mlp_interchangeable_model_B'
         ) as ckpt_dir_B:
-            
+
             model_A = initialize_expert_layer(
                 1, use_glu, moe_grouped_gemm=src_module != 'sequential'
             )
diff --git a/tests/unit_tests/dist_checkpointing/models/test_t5_model.py b/tests/unit_tests/dist_checkpointing/models/test_t5_model.py
index da1ae4b093..07c9f8676a 100644
--- a/tests/unit_tests/dist_checkpointing/models/test_t5_model.py
+++ b/tests/unit_tests/dist_checkpointing/models/test_t5_model.py
@@ -34,9 +34,14 @@ def initialize_t5_model(seed, encoder_spec_fn, decoder_spec_fn, num_layers=2, **
     torch.manual_seed(seed)
     model_parallel_cuda_manual_seed(seed)
 
-    default_config_kwargs=dict(
-        num_layers=num_layers, hidden_size=16, num_attention_heads=12, kv_channels=64, ffn_hidden_size=64,
-        use_cpu_initialization=True, pipeline_dtype=torch.bfloat16
+    default_config_kwargs = dict(
+        num_layers=num_layers,
+        hidden_size=16,
+        num_attention_heads=12,
+        kv_channels=64,
+        ffn_hidden_size=64,
+        use_cpu_initialization=True,
+        pipeline_dtype=torch.bfloat16,
     )
     default_config_kwargs.update(**config_kwargs)
     transformer_config = TransformerConfig(**default_config_kwargs)
@@ -45,10 +50,16 @@ def initialize_t5_model(seed, encoder_spec_fn, decoder_spec_fn, num_layers=2, **
 
     en_block_spec = TransformerBlockSubmodules([encoder_spec_fn()] * num_layers)
     de_block_spec = TransformerBlockSubmodules([decoder_spec_fn()] * num_layers)
-    model = T5Model(encoder_config=transformer_config, config=transformer_config,
-                    transformer_encoder_layer_spec=en_block_spec, transformer_decoder_layer_spec=de_block_spec,
-                    pre_process=False, post_process=False,
-                    vocab_size=29184, max_sequence_length=4)
+    model = T5Model(
+        encoder_config=transformer_config,
+        config=transformer_config,
+        transformer_encoder_layer_spec=en_block_spec,
+        transformer_decoder_layer_spec=de_block_spec,
+        pre_process=False,
+        post_process=False,
+        vocab_size=29184,
+        max_sequence_length=4,
+    )
 
     with torch.no_grad():
         for p in model.parameters():
@@ -59,14 +70,16 @@ def initialize_t5_model(seed, encoder_spec_fn, decoder_spec_fn, num_layers=2, **
 class TestT5Model:
     def setup_method(self, method):
         pass
-    
+
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
-        
+
     @pytest.mark.parametrize('src_spec_type', ['te', 'local'])
     @pytest.mark.parametrize('dst_spec_type', ['te', 'local'])
     @pytest.mark.parametrize('model_type', ['t5'])
-    def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, src_spec_type, dst_spec_type, model_type):
+    def test_sharded_state_dict_save_load(
+        self, tmp_path_dist_ckpt, src_spec_type, dst_spec_type, model_type
+    ):
         enc_dec_spec_fn = {
             'te': {
                 't5': (t5_encoder_te_spec, t5_decoder_te_spec),
@@ -75,7 +88,7 @@ def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, src_spec_type, d
             'local': {
                 't5': (t5_encoder_local_spec, t5_decoder_local_spec),
                 'retro': (get_retro_encoder_layer_local_spec, get_retro_decoder_layer_local_spec),
-            }
+            },
         }
         src_encoder_spec_fn, src_decoder_spec_fn = enc_dec_spec_fn[src_spec_type][model_type]
         dst_encoder_spec_fn, dst_decoder_spec_fn = enc_dec_spec_fn[dst_spec_type][model_type]
@@ -91,7 +104,9 @@ def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt, src_spec_type, d
             gpt_model = initialize_t5_model(2, dst_encoder_spec_fn, dst_decoder_spec_fn)
             sharded_state_dict = gpt_model.sharded_state_dict()
 
-            state_dict, missing_keys, unexpected_keys = load(sharded_state_dict, ckpt_dir, strict=StrictHandling.RETURN_ALL)
+            state_dict, missing_keys, unexpected_keys = load(
+                sharded_state_dict, ckpt_dir, strict=StrictHandling.RETURN_ALL
+            )
             # Potential mismatches are due to extra states, which is OK
             assert all('_extra_state' in k for k in missing_keys)
             assert all('_extra_state' in k for k in unexpected_keys)
diff --git a/tests/unit_tests/dist_checkpointing/test_async_save.py b/tests/unit_tests/dist_checkpointing/test_async_save.py
index 9b8fe0044c..d6aa879982 100644
--- a/tests/unit_tests/dist_checkpointing/test_async_save.py
+++ b/tests/unit_tests/dist_checkpointing/test_async_save.py
@@ -13,7 +13,6 @@
 from tests.unit_tests.test_utilities import Utils
 
 
-
 def write_data_os_err_mock_fn(local_proc_idx, write_bucket, results_queue, count_queue, use_fsync):
     """Raises an error on worker #2 during storage save"""
     try:
@@ -32,8 +31,8 @@ def setup_method(self, method):
         pass
 
     def teardown_method(self, method):
-        Utils.destroy_model_parallel()   
-        
+        Utils.destroy_model_parallel()
+
     def test_async_is_equivalent_to_sync(self, tmp_path_dist_ckpt):
         Utils.initialize_model_parallel(2, 4)
 
diff --git a/tests/unit_tests/dist_checkpointing/test_cached_metadata.py b/tests/unit_tests/dist_checkpointing/test_cached_metadata.py
index b1286f01f1..2733ea7a1b 100644
--- a/tests/unit_tests/dist_checkpointing/test_cached_metadata.py
+++ b/tests/unit_tests/dist_checkpointing/test_cached_metadata.py
@@ -2,7 +2,6 @@
 
 import pickle
 from copy import deepcopy
-
 from dataclasses import fields
 
 import torch
@@ -20,8 +19,8 @@ def setup_method(self, method):
         pass
 
     def teardown_method(self, method):
-        Utils.destroy_model_parallel()   
-        
+        Utils.destroy_model_parallel()
+
     def test_cached_metadata(self, tmp_path_dist_ckpt):
         Utils.initialize_model_parallel(2, 4)
 
diff --git a/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py b/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py
index 0b64f36e64..fa00a20cad 100644
--- a/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py
+++ b/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py
@@ -27,21 +27,18 @@ def setup_method(self, method):
         pass
 
     def teardown_method(self, method):
-        Utils.destroy_model_parallel()   
-        
+        Utils.destroy_model_parallel()
+
     @pytest.mark.parametrize(
-        ('src_tp_pp', 'dest_tp_pp',),
-        [
-            ((2, 4), (2, 4)),
-            ((2, 4), (2, 2)),
-            ((2, 4), (4, 2)),
-            ((8, 1), (1, 2)),
-        ]
+        ('src_tp_pp', 'dest_tp_pp'),
+        [((2, 4), (2, 4)), ((2, 4), (2, 2)), ((2, 4), (4, 2)), ((8, 1), (1, 2))],
     )
     def test_partition_change_save_load(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp):
         Utils.initialize_model_parallel(*src_tp_pp)
-        with TempNamedDir(tmp_path_dist_ckpt / 'test_flattened_partition_change_save_load') as ckpt_dir:
-            
+        with TempNamedDir(
+            tmp_path_dist_ckpt / 'test_flattened_partition_change_save_load'
+        ) as ckpt_dir:
+
             state_dict = self._build_state_dict()
 
             save(state_dict, ckpt_dir)
@@ -57,30 +54,32 @@ def test_partition_change_save_load(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp
 
         Utils.destroy_model_parallel()
 
-
     @pytest.mark.parametrize(
         ('src_tp_pp', 'dest_tp_pp', 'expected_ckpt_offsets_by_rank'),
         [
-            ((2, 4), (2, 2), {
-                0: [(0, 0, 0), (0, 0, 10)],  # TP 0, DP 0, PP 0
-                1: [(4, 0, 0), (4, 0, 10)],  # TP 1, DP 0, PP 0
-                2: [(0, 0, 0), (0, 0, 10)],  # TP 0, DP 1, PP 0
-                3: [(4, 0, 0), (4, 0, 10)],  # TP 1, DP 1, PP 0
-                4: [(0, 0, 20), (0, 0, 30)],  # TP 0, DP 0, PP 1
-                5: [(4, 0, 20), (4, 0, 30)],  # TP 1, DP 0, PP 1
-                6: [(0, 0, 20), (0, 0, 30)],  # TP 0, DP 1, PP 1
-                7: [(4, 0, 20), (4, 0, 30)],  # TP 1, DP 1, PP 1
-            }),
-            ((8, 1), (1, 2), {
-                rank: [(tp, 0, 0) for tp in range(8)]
-                for rank in range(8)
-            })
-        ]
+            (
+                (2, 4),
+                (2, 2),
+                {
+                    0: [(0, 0, 0), (0, 0, 10)],  # TP 0, DP 0, PP 0
+                    1: [(4, 0, 0), (4, 0, 10)],  # TP 1, DP 0, PP 0
+                    2: [(0, 0, 0), (0, 0, 10)],  # TP 0, DP 1, PP 0
+                    3: [(4, 0, 0), (4, 0, 10)],  # TP 1, DP 1, PP 0
+                    4: [(0, 0, 20), (0, 0, 30)],  # TP 0, DP 0, PP 1
+                    5: [(4, 0, 20), (4, 0, 30)],  # TP 1, DP 0, PP 1
+                    6: [(0, 0, 20), (0, 0, 30)],  # TP 0, DP 1, PP 1
+                    7: [(4, 0, 20), (4, 0, 30)],  # TP 1, DP 1, PP 1
+                },
+            ),
+            ((8, 1), (1, 2), {rank: [(tp, 0, 0) for tp in range(8)] for rank in range(8)}),
+        ],
     )
-    def test_reformulate_nd_flattened_tensors(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, expected_ckpt_offsets_by_rank):
+    def test_reformulate_nd_flattened_tensors(
+        self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, expected_ckpt_offsets_by_rank
+    ):
         Utils.initialize_model_parallel(*src_tp_pp, order='tp-dp-pp')
         with TempNamedDir(tmp_path_dist_ckpt / 'test_reformulate_nd_flattened_tensors') as ckpt_dir:
-            
+
             state_dict = self._build_state_dict()
 
             ckpt_local_shape = state_dict['sd_key_flat'].local_shape
@@ -93,36 +92,38 @@ def test_reformulate_nd_flattened_tensors(self, tmp_path_dist_ckpt, src_tp_pp, d
             load_state_dict = self._build_state_dict(random=True)
 
             reformulation_metadata = get_reformulation_metadata(load_state_dict, ckpt_dir)
-            reformulated_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation(load_state_dict, reformulation_metadata)
+            reformulated_state_dict, formulation_restore_data = (
+                apply_nd_flattened_tensors_reformulation(load_state_dict, reformulation_metadata)
+            )
             assert isinstance(reformulated_state_dict['sd_key_unflat'], ShardedTensor)
             assert isinstance(reformulated_state_dict['sd_key_flat'], dict)
 
-            assert reformulated_state_dict['sd_key_flat'].keys() == set((offset, ckpt_local_shape) for offset in expected_ckpt_offsets_by_rank[Utils.rank]), \
-                (reformulated_state_dict['sd_key_flat'].keys(), ckpt_local_shape, expected_ckpt_offsets_by_rank[Utils.rank])
+            assert reformulated_state_dict['sd_key_flat'].keys() == set(
+                (offset, ckpt_local_shape) for offset in expected_ckpt_offsets_by_rank[Utils.rank]
+            ), (
+                reformulated_state_dict['sd_key_flat'].keys(),
+                ckpt_local_shape,
+                expected_ckpt_offsets_by_rank[Utils.rank],
+            )
 
             # We can even load the reformulated state dict with a high-level API
-            loaded_state_dict = load(reformulated_state_dict, ckpt_dir, validate_access_integrity=False)
-            loaded_state_dict = restore_nd_flattened_tensors_formulation(loaded_state_dict, formulation_restore_data)
+            loaded_state_dict = load(
+                reformulated_state_dict, ckpt_dir, validate_access_integrity=False
+            )
+            loaded_state_dict = restore_nd_flattened_tensors_formulation(
+                loaded_state_dict, formulation_restore_data
+            )
             expected_state_dict = {k: v.data for k, v in self._build_state_dict().items()}
             diffs = diff(expected_state_dict, loaded_state_dict)
             assert not any(diffs), diffs
 
         Utils.destroy_model_parallel()
 
-
-    @pytest.mark.parametrize(
-        ('src_tp_pp',),
-        [
-            ((2, 4),),
-            ((8, 1),),
-            ((1, 1),),
-            ((1, 4),),
-        ]
-    )
+    @pytest.mark.parametrize(('src_tp_pp',), [((2, 4),), ((8, 1),), ((1, 1),), ((1, 4),)])
     def test_load_tensor_metadata(self, tmp_path_dist_ckpt, src_tp_pp):
         Utils.initialize_model_parallel(*src_tp_pp, order='tp-dp-pp')
         with TempNamedDir(tmp_path_dist_ckpt / 'test_reformulate_nd_flattened_tensors') as ckpt_dir:
-            
+
             state_dict = self._build_state_dict()
 
             save(state_dict, ckpt_dir)
@@ -141,7 +142,9 @@ def test_load_tensor_metadata(self, tmp_path_dist_ckpt, src_tp_pp):
             for sh_ten in sharded_metadata.values():
                 sh_ten.replica_id = Utils.rank
             loaded_state_dict = load(sharded_metadata, ckpt_dir)
-            assert torch.all(loaded_state_dict['unflat'] == torch.arange(8 * 5 * 40).reshape(8, 5, 40))
+            assert torch.all(
+                loaded_state_dict['unflat'] == torch.arange(8 * 5 * 40).reshape(8, 5, 40)
+            )
             assert torch.all(loaded_state_dict['flat'] == torch.arange(8 * 5 * 40))
 
         Utils.destroy_model_parallel()
@@ -169,7 +172,7 @@ def _build_state_dict(self, random=False):
         end_jitter = dp_rank + 1 if dp_rank + 1 < dp_size else 0
         local_dp_slice = slice(
             local_ten_size_by_dp * dp_rank + start_jitter,
-            local_ten_size_by_dp * (dp_rank + 1) + end_jitter
+            local_ten_size_by_dp * (dp_rank + 1) + end_jitter,
         )
         local_flat_ten = local_ten.flatten()[local_dp_slice]
         if dp_rank == dp_size - 1:
@@ -191,7 +194,7 @@ def _build_state_dict(self, random=False):
                 local_ten.shape,
                 (0, tp_rank, tp_size),
                 (2, pp_rank, pp_size),
-                flattened_range=local_dp_slice
+                flattened_range=local_dp_slice,
             ),
         }
         return state_dict
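The flattened-resharding test above walks a four-step API round trip that is easy to lose among the assertions: read the checkpoint's reformulation metadata, rewrite the requested flattened tensors into the checkpoint's own shard layout, load with access-integrity validation disabled, then restore the original formulation. A condensed sketch using only the calls visible in this file (a prepared sharded_state_dict and a checkpoint saved under a different TP/PP layout are assumed):

    # Condensed round trip mirroring the test above; call names come straight from it.
    reformulation_metadata = get_reformulation_metadata(sharded_state_dict, ckpt_dir)
    reformulated_state_dict, restore_data = apply_nd_flattened_tensors_reformulation(
        sharded_state_dict, reformulation_metadata
    )
    # The flattened tensors now follow the checkpoint's shard layout, so integrity
    # validation against the current job's layout is skipped.
    loaded = load(reformulated_state_dict, ckpt_dir, validate_access_integrity=False)
    # Convert the loaded tensors back to the formulation the current job requested.
    loaded = restore_nd_flattened_tensors_formulation(loaded, restore_data)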
diff --git a/tests/unit_tests/dist_checkpointing/test_fully_parallel.py b/tests/unit_tests/dist_checkpointing/test_fully_parallel.py
index f357f1b57d..42eda5d549 100644
--- a/tests/unit_tests/dist_checkpointing/test_fully_parallel.py
+++ b/tests/unit_tests/dist_checkpointing/test_fully_parallel.py
@@ -34,8 +34,11 @@ def __init__(self):
         self.save_keys = set()
 
     def save(self, sharded_state_dict, ckpt_dir):
-        self.save_keys = {sh_ten.key for sh_ten in nested_values(sharded_state_dict)
-                          if is_main_replica(sh_ten.replica_id)}
+        self.save_keys = {
+            sh_ten.key
+            for sh_ten in nested_values(sharded_state_dict)
+            if is_main_replica(sh_ten.replica_id)
+        }
 
 
 class MockLoadStrategy(LoadShardedStrategy):
@@ -45,8 +48,11 @@ def __init__(self, device='cpu'):
         self.load_keys = set()
 
     def load(self, sharded_state_dict, ckpt_dir):
-        self.load_keys = {sh_ten.key for sh_ten in nested_values(sharded_state_dict)
-                          if is_main_replica(sh_ten.replica_id)}
+        self.load_keys = {
+            sh_ten.key
+            for sh_ten in nested_values(sharded_state_dict)
+            if is_main_replica(sh_ten.replica_id)
+        }
 
         def load_rand(x):
             assert isinstance(x, ShardedTensor)
@@ -71,21 +77,43 @@ def setup_method(self, method):
         pass
 
     def teardown_method(self, method):
-        Utils.destroy_model_parallel()   
-        
+        Utils.destroy_model_parallel()
+
     @staticmethod
     def get_sharded_state_dict():
         return {
-            'sd_key_tp_repl1': ShardedTensor.from_rank_offsets('key_TP_repl1', torch.ones(10),
-                                                               (0, parallel_state.get_tensor_model_parallel_rank(), parallel_state.get_tensor_model_parallel_world_size()),
-                                                               replica_id=parallel_state.get_data_parallel_rank(with_context_parallel=True)),
-            'sd_key_tp_repl2': ShardedTensor.from_rank_offsets('key_TP_repl2', torch.ones(10),
-                                                               (0, parallel_state.get_tensor_model_parallel_rank(), parallel_state.get_tensor_model_parallel_world_size()),
-                                                               replica_id=parallel_state.get_data_parallel_rank(with_context_parallel=True)),
-            'sd_keyB': ShardedTensor.from_rank_offsets('keyB', torch.ones(20), (0, Utils.rank, Utils.world_size)),
-            'sd_keyE_no_C': ShardedTensor.from_rank_offsets('keyC', torch.ones(100), replica_id=Utils.rank),
-            'sd_keyX_no_D': ShardedTensor.from_rank_offsets('keyD', torch.ones(1000), replica_id=Utils.rank),
-            'sd_keyC_no_E': ShardedTensor.from_rank_offsets('keyE', torch.ones(100), replica_id=Utils.rank),
+            'sd_key_tp_repl1': ShardedTensor.from_rank_offsets(
+                'key_TP_repl1',
+                torch.ones(10),
+                (
+                    0,
+                    parallel_state.get_tensor_model_parallel_rank(),
+                    parallel_state.get_tensor_model_parallel_world_size(),
+                ),
+                replica_id=parallel_state.get_data_parallel_rank(with_context_parallel=True),
+            ),
+            'sd_key_tp_repl2': ShardedTensor.from_rank_offsets(
+                'key_TP_repl2',
+                torch.ones(10),
+                (
+                    0,
+                    parallel_state.get_tensor_model_parallel_rank(),
+                    parallel_state.get_tensor_model_parallel_world_size(),
+                ),
+                replica_id=parallel_state.get_data_parallel_rank(with_context_parallel=True),
+            ),
+            'sd_keyB': ShardedTensor.from_rank_offsets(
+                'keyB', torch.ones(20), (0, Utils.rank, Utils.world_size)
+            ),
+            'sd_keyE_no_C': ShardedTensor.from_rank_offsets(
+                'keyC', torch.ones(100), replica_id=Utils.rank
+            ),
+            'sd_keyX_no_D': ShardedTensor.from_rank_offsets(
+                'keyD', torch.ones(1000), replica_id=Utils.rank
+            ),
+            'sd_keyC_no_E': ShardedTensor.from_rank_offsets(
+                'keyE', torch.ones(100), replica_id=Utils.rank
+            ),
         }
 
     @pytest.mark.parametrize("parallelization_along_dp", [False, True])
@@ -99,7 +127,9 @@ def test_save_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt):
         # 3. Shard id (key)
         if not parallelization_along_dp:
             expected_key_to_saving_ranks = {
-                'keyB': list(range(Utils.world_size)), # everyone must save (disjoint shards, coverage == 1)
+                'keyB': list(
+                    range(Utils.world_size)
+                ),  # everyone must save (disjoint shards, coverage == 1)
                 'key_TP_repl1': [0, 1],  # lowest coverage (4), first TP domain
                 'key_TP_repl2': [2, 3],  # lowest coverage (4), second TP domain
                 'keyD': [4],  # largest tensor
@@ -110,7 +140,11 @@ def test_save_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt):
             if parallel_state.get_tensor_model_parallel_rank() == 0:
                 expected_key_to_saving_ranks = {
                     # everyone must save (disjoint shards, coverage == 1):
-                    'keyB': list(range(parallel_state.get_data_parallel_world_size(with_context_parallel=True))),
+                    'keyB': list(
+                        range(
+                            parallel_state.get_data_parallel_world_size(with_context_parallel=True)
+                        )
+                    ),
                     # this time, TP sharded tensors have the same coverage as fully replicated!
                     'keyD': [0],  # largest tensor
                     'keyC': [1],  # second largest tensor
@@ -121,32 +155,59 @@ def test_save_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt):
             else:
                 expected_key_to_saving_ranks = {
                     # everyone must save (disjoint shards, coverage == 1):
-                    'keyB': list(range(parallel_state.get_data_parallel_world_size(with_context_parallel=True))),
+                    'keyB': list(
+                        range(
+                            parallel_state.get_data_parallel_world_size(with_context_parallel=True)
+                        )
+                    ),
                     # tensors C, D, E are absent in this DP group
                     'key_TP_repl1': [0],  # smallest tensor
                     'key_TP_repl2': [1],  # smallest tensor, last rank is the least occupied
                 }
 
-        parallelization_group = parallel_state.get_data_parallel_group(with_context_parallel=True) if parallelization_along_dp else None
+        parallelization_group = (
+            parallel_state.get_data_parallel_group(with_context_parallel=True)
+            if parallelization_along_dp
+            else None
+        )
         dp_rank = torch.distributed.get_rank(parallelization_group)
-        expected_keys_saved_by_current_rank = {k for k, v in expected_key_to_saving_ranks.items() if dp_rank in v}
+        expected_keys_saved_by_current_rank = {
+            k for k, v in expected_key_to_saving_ranks.items() if dp_rank in v
+        }
 
         # Run save and tests
         mock_strategy = MockSaveStrategy()
-        save_strategy = FullyParallelSaveStrategyWrapper(mock_strategy,
-                                                         parallelization_group,
-                                                         do_cache_distribution=True)
+        save_strategy = FullyParallelSaveStrategyWrapper(
+            mock_strategy, parallelization_group, do_cache_distribution=True
+        )
         with TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir_A:
             save_strategy.save(state_dict, ckpt_dir_A)
-        key_to_saving_rank = dict(map_reduce(save_strategy.cached_distribution.main_rank_for_shard.items(), lambda shard_rank: shard_rank[0][0], lambda shard_rank: shard_rank[1]))
+        key_to_saving_rank = dict(
+            map_reduce(
+                save_strategy.cached_distribution.main_rank_for_shard.items(),
+                lambda shard_rank: shard_rank[0][0],
+                lambda shard_rank: shard_rank[1],
+            )
+        )
         assert expected_key_to_saving_ranks == key_to_saving_rank
 
         for k, sh_ten in state_dict.items():
-            if _sharded_tensor_shard_id(sh_ten) in save_strategy.cached_distribution.shards_in_this_group:
-                is_expected_to_be_saved_by_this_rank = dp_rank in expected_key_to_saving_ranks.get(sh_ten.key, [])
-                assert sh_ten.replica_id == int(not is_expected_to_be_saved_by_this_rank), expected_key_to_saving_ranks
-
-        assert mock_strategy.save_keys == expected_keys_saved_by_current_rank, (Utils.rank, mock_strategy.save_keys, expected_keys_saved_by_current_rank)
+            if (
+                _sharded_tensor_shard_id(sh_ten)
+                in save_strategy.cached_distribution.shards_in_this_group
+            ):
+                is_expected_to_be_saved_by_this_rank = dp_rank in expected_key_to_saving_ranks.get(
+                    sh_ten.key, []
+                )
+                assert sh_ten.replica_id == int(
+                    not is_expected_to_be_saved_by_this_rank
+                ), expected_key_to_saving_ranks
+
+        assert mock_strategy.save_keys == expected_keys_saved_by_current_rank, (
+            Utils.rank,
+            mock_strategy.save_keys,
+            expected_keys_saved_by_current_rank,
+        )
 
     @pytest.mark.parametrize("parallelization_along_dp", [False, True])
     def test_load_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt):
@@ -160,7 +221,9 @@ def test_load_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt):
         # 3. Shard id (key)
         if not parallelization_along_dp:
             expected_key_to_saving_ranks = {
-                'keyB': list(range(Utils.world_size)), # everyone must save (disjoint shards, coverage == 1)
+                'keyB': list(
+                    range(Utils.world_size)
+                ),  # everyone must save (disjoint shards, coverage == 1)
                 'key_TP_repl1': [0, 1],  # lowest coverage (4), first TP domain
                 'key_TP_repl2': [2, 3],  # lowest coverage (4), second TP domain
                 'keyD': [4],  # largest tensor
@@ -171,7 +234,9 @@ def test_load_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt):
             # When loading, expected key distribution is the same across TP, because every replica needs to be loaded
             expected_key_to_saving_ranks = {
                 # everyone must load (disjoint shards, coverage == 1):
-                'keyB': list(range(parallel_state.get_data_parallel_world_size(with_context_parallel=True))),
+                'keyB': list(
+                    range(parallel_state.get_data_parallel_world_size(with_context_parallel=True))
+                ),
                 # this time, TP sharded tensors have the same coverage as fully replicated!
                 'keyD': [0],  # largest tensor
                 'keyC': [1],  # second largest tensor
@@ -180,21 +245,37 @@ def test_load_distribution(self, parallelization_along_dp, tmp_path_dist_ckpt):
                 'key_TP_repl2': [3],  # smallest tensor, last rank is the least occupied
             }
 
-        parallelization_group = parallel_state.get_data_parallel_group(with_context_parallel=True) if parallelization_along_dp else None
+        parallelization_group = (
+            parallel_state.get_data_parallel_group(with_context_parallel=True)
+            if parallelization_along_dp
+            else None
+        )
         dp_rank = torch.distributed.get_rank(parallelization_group)
-        expected_keys_saved_by_current_rank = {k for k, v in expected_key_to_saving_ranks.items() if dp_rank in v}
+        expected_keys_saved_by_current_rank = {
+            k for k, v in expected_key_to_saving_ranks.items() if dp_rank in v
+        }
 
         # Run save and tests
         mock_strategy = MockLoadStrategy()
-        load_strategy = FullyParallelLoadStrategyWrapper(mock_strategy,
-                                                         parallelization_group,
-                                                         do_cache_distribution=True)
+        load_strategy = FullyParallelLoadStrategyWrapper(
+            mock_strategy, parallelization_group, do_cache_distribution=True
+        )
         with TempNamedDir(tmp_path_dist_ckpt / 'mock_dir') as ckpt_dir_A:
             loaded_state_dict = load_strategy.load(state_dict, ckpt_dir_A)
-        key_to_saving_rank = dict(map_reduce(load_strategy.cached_distribution.main_rank_for_shard.items(), lambda shard_rank: shard_rank[0][0], lambda shard_rank: shard_rank[1]))
+        key_to_saving_rank = dict(
+            map_reduce(
+                load_strategy.cached_distribution.main_rank_for_shard.items(),
+                lambda shard_rank: shard_rank[0][0],
+                lambda shard_rank: shard_rank[1],
+            )
+        )
         assert expected_key_to_saving_ranks == key_to_saving_rank
 
-        assert mock_strategy.load_keys == expected_keys_saved_by_current_rank, (Utils.rank, mock_strategy.load_keys, expected_keys_saved_by_current_rank)
+        assert mock_strategy.load_keys == expected_keys_saved_by_current_rank, (
+            Utils.rank,
+            mock_strategy.load_keys,
+            expected_keys_saved_by_current_rank,
+        )
 
         assert loaded_state_dict.keys() == state_dict.keys()
 
@@ -220,8 +301,11 @@ def _get_empty_tensor_for_exchange(self, *args, **kwargs) -> torch.Tensor:
         # Each tensor is 4MB, 40MB in total.
         # We expect extra memory usage peak at ~32MB, not 1GB
         sharded_state_dict = {
-            f'ten_{i}': ShardedTensor.from_rank_offsets(f'ten_{i}', torch.rand(megabytes, dtype=torch.float, device=state_dict_device),
-                                                        (0, Utils.rank, Utils.world_size))
+            f'ten_{i}': ShardedTensor.from_rank_offsets(
+                f'ten_{i}',
+                torch.rand(megabytes, dtype=torch.float, device=state_dict_device),
+                (0, Utils.rank, Utils.world_size),
+            )
             for i in range(10)
         }
 
@@ -233,6 +317,9 @@ def _get_empty_tensor_for_exchange(self, *args, **kwargs) -> torch.Tensor:
         # Each rank is expected to do 7 * 10 empty allocations
         assert len(mem_alloc) == 7 * 10
         # Peak mem usage should be within 4MB (single tensor)
-        assert max(mem_alloc) - mem_alloc_start < 4.01 * megabytes, (max(mem_alloc), mem_alloc_start)
+        assert max(mem_alloc) - mem_alloc_start < 4.01 * megabytes, (
+            max(mem_alloc),
+            mem_alloc_start,
+        )
 
-        Utils.destroy_model_parallel()
\ No newline at end of file
+        Utils.destroy_model_parallel()
diff --git a/tests/unit_tests/dist_checkpointing/test_mapping.py b/tests/unit_tests/dist_checkpointing/test_mapping.py
index ebd0d1ed15..2f986ec1c2 100644
--- a/tests/unit_tests/dist_checkpointing/test_mapping.py
+++ b/tests/unit_tests/dist_checkpointing/test_mapping.py
@@ -1,16 +1,21 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 import pytest
-
 import torch
 
 from megatron.core.dist_checkpointing import ShardedTensor
 from megatron.core.dist_checkpointing.core import CheckpointingException
-from megatron.core.dist_checkpointing.mapping import is_main_replica, \
-    ShardedTensorFactory, ShardedObject, apply_factories, apply_factory_merges
+from megatron.core.dist_checkpointing.mapping import (
+    ShardedObject,
+    ShardedTensorFactory,
+    apply_factories,
+    apply_factory_merges,
+    is_main_replica,
+)
 from megatron.core.transformer.transformer_config import TransformerConfig
 from tests.unit_tests.test_utilities import Utils
 
+
 class TestShardedTensor:
 
     # def setup_method(self, method):
@@ -20,14 +25,11 @@ class TestShardedTensor:
     #
     # def teardown_method(self, method):
     #     Utils.destroy_model_parallel()
-    
+
     def test_from_rank_offsets_constructor(self, dtype=torch.float, device='cuda'):
         data = torch.ones((1, 3, 7, 9), dtype=dtype, device=device)
         shape = data.shape
-        rank_offsets = [
-            (0, 0, 10),
-            (2, 3, 6)
-        ]
+        rank_offsets = [(0, 0, 10), (2, 3, 6)]
         sh_ten = ShardedTensor.from_rank_offsets('keyA', data, *rank_offsets)
 
         assert isinstance(sh_ten, ShardedTensor)
@@ -40,13 +42,12 @@ def test_from_rank_offsets_constructor(self, dtype=torch.float, device='cuda'):
     def test_from_rank_offsets_flat_constructor(self, dtype=torch.float, device='cuda'):
         data = torch.arange(28, dtype=dtype, device=device).reshape((1, 4, 7))
         shape = data.shape
-        rank_offsets = [
-            (1, 0, 2),
-            (2, 3, 5)
-        ]
+        rank_offsets = [(1, 0, 2), (2, 3, 5)]
         flattened_range = slice(4, 9)
         flat_data = data.flatten()[flattened_range]
-        sh_ten = ShardedTensor.from_rank_offsets_flat('keyA', flat_data, data.shape, *rank_offsets, flattened_range=flattened_range)
+        sh_ten = ShardedTensor.from_rank_offsets_flat(
+            'keyA', flat_data, data.shape, *rank_offsets, flattened_range=flattened_range
+        )
 
         # The main attributes properties are unchanged
         assert isinstance(sh_ten, ShardedTensor)
@@ -60,10 +61,7 @@ def test_from_rank_offsets_flat_constructor(self, dtype=torch.float, device='cud
 
     def test_metadata_integrity_violation(self):
         data = torch.ones((1, 3, 7, 9), device='meta')
-        rank_offsets = [
-            (0, 0, 10),
-            (2, 3, 6)
-        ]
+        rank_offsets = [(0, 0, 10), (2, 3, 6)]
         sh_ten = ShardedTensor.from_rank_offsets('keyA', data, *rank_offsets)
         sh_ten.validate_metadata_integrity()
         with pytest.raises(CheckpointingException):
@@ -76,32 +74,40 @@ def test_metadata_integrity_violation(self):
             sh_ten.validate_metadata_integrity()
 
         with pytest.raises(CheckpointingException):
-            sh_ten = ShardedTensor.from_rank_offsets_flat('keyA', data, data.shape, *rank_offsets,
-                                                          flattened_range=slice(4, 9))
+            sh_ten = ShardedTensor.from_rank_offsets_flat(
+                'keyA', data, data.shape, *rank_offsets, flattened_range=slice(4, 9)
+            )
 
-        sh_ten = ShardedTensor.from_rank_offsets_flat('keyA', data.flatten()[4:9], data.shape, *rank_offsets,
-                                                      flattened_range=slice(4, 9))
+        sh_ten = ShardedTensor.from_rank_offsets_flat(
+            'keyA', data.flatten()[4:9], data.shape, *rank_offsets, flattened_range=slice(4, 9)
+        )
         assert sh_ten.local_shape == (1, 3, 7, 9)
         with pytest.raises(CheckpointingException):
             sh_ten.local_shape = (5,)
             sh_ten.validate_metadata_integrity()
 
 
-
 class TestShardedTensorFactory:
     def test_build_and_merge(self):
         def build_fn(key, tensor, replica_id, flattened_range):
             assert flattened_range is None
             return {
-                'level2_a': ShardedTensor.from_rank_offsets(key + 'part1', tensor + 1, replica_id=replica_id),
-                'level2_b': ShardedTensor.from_rank_offsets(key + 'part2', tensor + 2, replica_id=replica_id)
+                'level2_a': ShardedTensor.from_rank_offsets(
+                    key + 'part1', tensor + 1, replica_id=replica_id
+                ),
+                'level2_b': ShardedTensor.from_rank_offsets(
+                    key + 'part2', tensor + 2, replica_id=replica_id
+                ),
             }
 
         # state_dict will be modified in-place
         def get_state_dict():
             return {
-                'level1': ShardedTensorFactory('a', torch.arange(3), build_fn, lambda x: x['level2_b'])
+                'level1': ShardedTensorFactory(
+                    'a', torch.arange(3), build_fn, lambda x: x['level2_b']
+                )
             }
+
         state_dict = get_state_dict()
         apply_factories(state_dict)
         assert torch.allclose(state_dict['level1']['level2_a'].data, torch.tensor([1, 2, 3]))
diff --git a/tests/unit_tests/dist_checkpointing/test_nonpersistent.py b/tests/unit_tests/dist_checkpointing/test_nonpersistent.py
index 667efddff4..d7907ead1f 100644
--- a/tests/unit_tests/dist_checkpointing/test_nonpersistent.py
+++ b/tests/unit_tests/dist_checkpointing/test_nonpersistent.py
@@ -2,36 +2,33 @@
 
 import filecmp
 import os
-import pytest
 from types import SimpleNamespace
 from unittest import mock
 
+import pytest
+
 from megatron.training.checkpointing import (
     _NON_PERSISTENT_CKPT_SUBDIR,
     load_checkpoint,
     save_checkpoint,
 )
 from tests.unit_tests.dist_checkpointing import (
+    TempNamedDir,
     init_basic_mock_args,
     init_checkpointing_mock_args,
-    TempNamedDir,
     setup_model_and_optimizer,
 )
 from tests.unit_tests.test_utilities import Utils
 
+
 class TestNonPersistentSaveAndLoad:
     def setup_method(self, method):
         pass
 
     def teardown_method(self, method):
-        Utils.destroy_model_parallel()   
-        
-    @pytest.mark.parametrize(
-        ('tp,pp'),
-        [
-            (2, 4),
-        ]
-    )
+        Utils.destroy_model_parallel()
+
+    @pytest.mark.parametrize(('tp,pp'), [(2, 4)])
     def test_basic_save_load_scenarios(self, tmp_path_dist_ckpt, tp, pp):
         Utils.initialize_model_parallel(tp, pp)
         num_floating_point_operations_so_far = 0
@@ -60,7 +57,7 @@ def test_basic_save_load_scenarios(self, tmp_path_dist_ckpt, tp, pp):
                 non_persistent_ckpt=True,
             )
             save_checkpoint(
-                3, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {},
+                3, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {}
             )
             save_checkpoint(
                 4,
@@ -74,7 +71,7 @@ def test_basic_save_load_scenarios(self, tmp_path_dist_ckpt, tp, pp):
             iteration, _ = load_checkpoint(model, optimizer, opt_param_scheduler)
             assert iteration == 4
             save_checkpoint(
-                6, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {},
+                6, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {}
             )
             iteration, _ = load_checkpoint(model, optimizer, opt_param_scheduler)
             assert iteration == 6
@@ -119,12 +116,7 @@ def test_basic_save_load_scenarios(self, tmp_path_dist_ckpt, tp, pp):
 
 
 class TestLegacySaveAndLoad:
-    @pytest.mark.parametrize(
-        ('tp,pp'),
-        [
-            (2, 4),
-        ]
-    )
+    @pytest.mark.parametrize(('tp,pp'), [(2, 4)])
     def test_basic_save_load_scenario(self, tmp_path_dist_ckpt, tp, pp):
         Utils.initialize_model_parallel(tp, pp)
         num_floating_point_operations_so_far = 0
@@ -139,7 +131,7 @@ def test_basic_save_load_scenario(self, tmp_path_dist_ckpt, tp, pp):
             init_checkpointing_mock_args(mock_args, legacy_ckpt_dir)
 
             save_checkpoint(
-                2, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {},
+                2, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, {}
             )
             iteration, _ = load_checkpoint(model, optimizer, opt_param_scheduler)
             assert iteration == 2
diff --git a/tests/unit_tests/dist_checkpointing/test_optimizer.py b/tests/unit_tests/dist_checkpointing/test_optimizer.py
index 87047b92b4..59577c73fa 100644
--- a/tests/unit_tests/dist_checkpointing/test_optimizer.py
+++ b/tests/unit_tests/dist_checkpointing/test_optimizer.py
@@ -62,20 +62,25 @@ def sharded_state_dict(self):
         sharded_state_dict = self.state_dict(keep_vars=True)
         # conv
         sharded_state_dict['conv.weight'] = ShardedTensor.from_rank_offsets(
-            'conv.weight', sharded_state_dict['conv.weight'],
-            (1, parallel_state.get_tensor_model_parallel_rank(), parallel_state.get_tensor_model_parallel_world_size())
+            'conv.weight',
+            sharded_state_dict['conv.weight'],
+            (
+                1,
+                parallel_state.get_tensor_model_parallel_rank(),
+                parallel_state.get_tensor_model_parallel_world_size(),
+            ),
         )
         # bias is non-sharded
-        sharded_state_dict['conv.bias'] = ShardedTensor.from_rank_offsets('conv.bias', sharded_state_dict['conv.bias'])
+        sharded_state_dict['conv.bias'] = ShardedTensor.from_rank_offsets(
+            'conv.bias', sharded_state_dict['conv.bias']
+        )
 
         # proj
         sharded_state_dict['proj.weight'] = ShardedTensor.from_rank_offsets(
-            'proj.weight', sharded_state_dict['proj.weight'],
-            (0, Utils.rank, Utils.world_size)
+            'proj.weight', sharded_state_dict['proj.weight'], (0, Utils.rank, Utils.world_size)
         )
         sharded_state_dict['proj.bias'] = ShardedTensor.from_rank_offsets(
-            'proj.bias', sharded_state_dict['proj.bias'],
-            (0, Utils.rank, Utils.world_size)
+            'proj.bias', sharded_state_dict['proj.bias'], (0, Utils.rank, Utils.world_size)
         )
         return sharded_state_dict
 
@@ -83,34 +88,68 @@ def sharded_state_dict(self):
 class SwigluFactoryModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear = torch.nn.Linear(5, 64 // parallel_state.get_tensor_model_parallel_world_size(), bias=False)
+        self.linear = torch.nn.Linear(
+            5, 64 // parallel_state.get_tensor_model_parallel_world_size(), bias=False
+        )
         self.config = TransformerConfig(hidden_size=8, num_attention_heads=1, num_layers=1)
 
     def sharded_state_dict(self):
         sharded_state_dict = self.state_dict(keep_vars=True)
         sharded_state_dict['linear.weight'] = ShardedTensor.from_rank_offsets(
-            'linear.weight', sharded_state_dict['linear.weight'],
-            ((0, parallel_state.get_tensor_model_parallel_rank(), parallel_state.get_tensor_model_parallel_world_size())),
-            replica_id=((parallel_state.get_pipeline_model_parallel_rank(), 0, parallel_state.get_data_parallel_rank(with_context_parallel=True)))
+            'linear.weight',
+            sharded_state_dict['linear.weight'],
+            (
+                (
+                    0,
+                    parallel_state.get_tensor_model_parallel_rank(),
+                    parallel_state.get_tensor_model_parallel_world_size(),
+                )
+            ),
+            replica_id=(
+                (
+                    parallel_state.get_pipeline_model_parallel_rank(),
+                    0,
+                    parallel_state.get_data_parallel_rank(with_context_parallel=True),
+                )
+            ),
+        )
+        sharded_state_dict['linear.weight'] = apply_swiglu_sharded_factory(
+            sharded_state_dict['linear.weight'], ()
         )
-        sharded_state_dict['linear.weight'] = apply_swiglu_sharded_factory(sharded_state_dict['linear.weight'], ())
         return sharded_state_dict
 
 
 class SwigluFactoryModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear = torch.nn.Linear(5, 64 // parallel_state.get_tensor_model_parallel_world_size(), bias=False)
+        self.linear = torch.nn.Linear(
+            5, 64 // parallel_state.get_tensor_model_parallel_world_size(), bias=False
+        )
         self.config = TransformerConfig(hidden_size=8, num_attention_heads=1, num_layers=1)
 
     def sharded_state_dict(self):
         sharded_state_dict = self.state_dict(keep_vars=True)
         sharded_state_dict['linear.weight'] = ShardedTensor.from_rank_offsets(
-            'linear.weight', sharded_state_dict['linear.weight'],
-            ((0, parallel_state.get_tensor_model_parallel_rank(), parallel_state.get_tensor_model_parallel_world_size())),
-            replica_id=((parallel_state.get_pipeline_model_parallel_rank(), 0, parallel_state.get_data_parallel_rank(with_context_parallel=True)))
+            'linear.weight',
+            sharded_state_dict['linear.weight'],
+            (
+                (
+                    0,
+                    parallel_state.get_tensor_model_parallel_rank(),
+                    parallel_state.get_tensor_model_parallel_world_size(),
+                )
+            ),
+            replica_id=(
+                (
+                    parallel_state.get_pipeline_model_parallel_rank(),
+                    0,
+                    parallel_state.get_data_parallel_rank(with_context_parallel=True),
+                )
+            ),
+        )
+        sharded_state_dict['linear.weight'] = apply_swiglu_sharded_factory(
+            sharded_state_dict['linear.weight'], ()
         )
-        sharded_state_dict['linear.weight'] = apply_swiglu_sharded_factory(sharded_state_dict['linear.weight'], ())
         return sharded_state_dict
 
 
@@ -119,10 +158,10 @@ def setup_method(self, method):
         pass
 
     def teardown_method(self, method):
-        Utils.destroy_model_parallel()   
+        Utils.destroy_model_parallel()
 
     def test_optimizer_params(self, tmp_path_dist_ckpt):
-        Utils.initialize_model_parallel(1,1)
+        Utils.initialize_model_parallel(1, 1)
         model = Model()
         # Force optimizer state initialization
         for p in model.parameters():
@@ -131,18 +170,22 @@ def test_optimizer_params(self, tmp_path_dist_ckpt):
         optim.step()
 
         model_state_dict = model.sharded_state_dict()
-        param_map = get_param_id_to_sharded_param_map(model_state_dict, optim.param_groups[0]['params'])
+        param_map = get_param_id_to_sharded_param_map(
+            model_state_dict, optim.param_groups[0]['params']
+        )
         optim_state_dict = optim.state_dict()
         optim_state_to_sharding_state(optim_state_dict, param_map, exclude_keys=('step',))
 
         optim_sharded_tensors = nested_values(extract_sharded_tensors(optim_state_dict)[0])
         optim_sharded_keys = {sh_ten.key for sh_ten in optim_sharded_tensors}
         assert len(optim_sharded_keys) == 2 * len(model_state_dict)
-        assert optim_sharded_keys == set([
-            f'optimizer.state.{state_key}.{layer_name}'
-            for state_key in ['exp_avg', 'exp_avg_sq']
-            for layer_name in model_state_dict
-        ])
+        assert optim_sharded_keys == set(
+            [
+                f'optimizer.state.{state_key}.{layer_name}'
+                for state_key in ['exp_avg', 'exp_avg_sq']
+                for layer_name in model_state_dict
+            ]
+        )
 
 
 def initialize_small_model(pre_process=True, post_process=True, seed=0, **config_kwargs):
@@ -163,17 +206,20 @@ def setup_method(self, method):
         pass
 
     def teardown_method(self, method):
-        Utils.destroy_model_parallel()   
+        Utils.destroy_model_parallel()
 
     @pytest.mark.parametrize("initialize_fn", [initialize_small_model, initialize_gpt_model])
     @pytest.mark.parametrize("use_fpsl", [False, True])
-    @pytest.mark.parametrize("tp_pp,src_dp,dest_dp", [
-        ((4, 1), 2, 2),
-        # ((1, 1), 8, 1),  # TODO: changing DP doesn't work in unit tests because of NCCL crashes
-        # ((1, 1), 1, 8),
-        # ((2, 1), 2, 1),
-        # ((2, 1), 2, 2),
-    ])
+    @pytest.mark.parametrize(
+        "tp_pp,src_dp,dest_dp",
+        [
+            ((4, 1), 2, 2),
+            # ((1, 1), 8, 1),  # TODO: changing DP doesn't work in unit tests because of NCCL crashes
+            # ((1, 1), 1, 8),
+            # ((2, 1), 2, 1),
+            # ((2, 1), 2, 2),
+        ],
+    )
     def test_dp_sharding(self, tmp_path_dist_ckpt, tp_pp, src_dp, dest_dp, use_fpsl, initialize_fn):
         src_world_size = tp_pp[0] * tp_pp[1] * src_dp
         dest_world_size = tp_pp[0] * tp_pp[1] * dest_dp
@@ -190,16 +236,24 @@ def test_dp_sharding(self, tmp_path_dist_ckpt, tp_pp, src_dp, dest_dp, use_fpsl,
                 Utils.set_world_size(src_world_size)
                 if Utils.rank >= 0:
                     # Save checkpoint A
-                    model, optimizer_A = setup_model_and_optimizer(seed=2, tp=tp_pp[0], pp=tp_pp[1], initialize_fn=initialize_fn)
+                    model, optimizer_A = setup_model_and_optimizer(
+                        seed=2, tp=tp_pp[0], pp=tp_pp[1], initialize_fn=initialize_fn
+                    )
 
                     save_strategy = get_default_save_sharded_strategy()
                     if use_fpsl:
                         save_strategy = FullyParallelSaveStrategyWrapper(
                             save_strategy,
                             parallel_state.get_data_parallel_group(with_context_parallel=True),
-                            True
+                            True,
                         )
-                    save(optimizer_A.sharded_state_dict(model[0].sharded_state_dict(), sharding_type=sharding_type), ckpt_dir, save_strategy)
+                    save(
+                        optimizer_A.sharded_state_dict(
+                            model[0].sharded_state_dict(), sharding_type=sharding_type
+                        ),
+                        ckpt_dir,
+                        save_strategy,
+                    )
                     optim_param_state_A = optimizer_A.get_parameter_state_dp_zero()
                     Utils.destroy_model_parallel()
                 else:
@@ -213,7 +267,9 @@ def test_dp_sharding(self, tmp_path_dist_ckpt, tp_pp, src_dp, dest_dp, use_fpsl,
                 if Utils.rank >= 0:
                     Utils.initialize_model_parallel(*tp_pp)
 
-                    model, optimizer_B = setup_model_and_optimizer(seed=3, tp=tp_pp[0], pp=tp_pp[1], initialize_fn=initialize_fn)
+                    model, optimizer_B = setup_model_and_optimizer(
+                        seed=3, tp=tp_pp[0], pp=tp_pp[1], initialize_fn=initialize_fn
+                    )
                     optim_param_state_B = optimizer_B.get_parameter_state_dp_zero()
                     diffs = diff(optim_param_state_A, optim_param_state_B)
                     # Expect a mismatch in values - diffs[2] nonempty
@@ -221,9 +277,7 @@ def test_dp_sharding(self, tmp_path_dist_ckpt, tp_pp, src_dp, dest_dp, use_fpsl,
                         assert not diffs[0] and not diffs[1] and diffs[2], diffs
 
                     sharded_state_dict = optimizer_B.sharded_state_dict(
-                        model[0].sharded_state_dict(),
-                        is_loading=True,
-                        sharding_type=sharding_type,
+                        model[0].sharded_state_dict(), is_loading=True, sharding_type=sharding_type
                     )
                     optim_state_dict = load(sharded_state_dict, ckpt_dir)
                     optimizer_B.load_state_dict(optim_state_dict)
@@ -241,23 +295,26 @@ def test_dp_sharding(self, tmp_path_dist_ckpt, tp_pp, src_dp, dest_dp, use_fpsl,
 
     @pytest.mark.parametrize(
         ('src_tp_pp', 'dest_tp_pp', 'use_glu'),
-        [
-            ((2, 2), (2, 4), False,),
-            ((1, 8), (4, 1), True),
-            ((2, 4), (4, 2), False),
-        ]
+        [((2, 2), (2, 4), False), ((1, 8), (4, 1), True), ((2, 4), (4, 2), False)],
     )
-    def test_finetune_doesnt_load_optimizer(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, use_glu):
+    def test_finetune_doesnt_load_optimizer(
+        self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, use_glu
+    ):
         # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
         Utils.initialize_model_parallel(*src_tp_pp)
-        with TempNamedDir(tmp_path_dist_ckpt / 'test_finetune_doesnt_load_optimizer', sync=True) as ckpt_dir:
+        with TempNamedDir(
+            tmp_path_dist_ckpt / 'test_finetune_doesnt_load_optimizer', sync=True
+        ) as ckpt_dir:
             mock_args = SimpleNamespace()
             with mock.patch('megatron.training.checkpointing.get_args', new=lambda: mock_args):
                 init_basic_mock_args(mock_args, tp=src_tp_pp[0], pp=src_tp_pp[1])
                 init_checkpointing_mock_args(mock_args, ckpt_dir, False)
 
                 model, optimizer = setup_model_and_optimizer(
-                    seed=2, tp=src_tp_pp[0], pp=src_tp_pp[1], initialize_fn=partial(initialize_gpt_model, use_glu=use_glu)
+                    seed=2,
+                    tp=src_tp_pp[0],
+                    pp=src_tp_pp[1],
+                    initialize_fn=partial(initialize_gpt_model, use_glu=use_glu),
                 )
 
                 save_checkpoint(10, model, optimizer, None, 0)
@@ -265,7 +322,10 @@ def test_finetune_doesnt_load_optimizer(self, tmp_path_dist_ckpt, src_tp_pp, des
 
                 Utils.initialize_model_parallel(*dest_tp_pp)
                 model, optimizer = setup_model_and_optimizer(
-                    seed=3, tp=dest_tp_pp[0], pp=dest_tp_pp[1], initialize_fn=partial(initialize_gpt_model, use_glu=use_glu)
+                    seed=3,
+                    tp=dest_tp_pp[0],
+                    pp=dest_tp_pp[1],
+                    initialize_fn=partial(initialize_gpt_model, use_glu=use_glu),
                 )
                 model_unloaded_state_dict = deepcopy(model[0].state_dict())
                 optim_unloaded_state_dict = deepcopy(optimizer.state_dict())
@@ -291,7 +351,10 @@ def test_finetune_doesnt_load_optimizer(self, tmp_path_dist_ckpt, src_tp_pp, des
 
                 # ... or `no_load_optim` flag
                 model, optimizer = setup_model_and_optimizer(
-                    seed=3, tp=dest_tp_pp[0], pp=dest_tp_pp[1], initialize_fn=partial(initialize_gpt_model, use_glu=use_glu)
+                    seed=3,
+                    tp=dest_tp_pp[0],
+                    pp=dest_tp_pp[1],
+                    initialize_fn=partial(initialize_gpt_model, use_glu=use_glu),
                 )
                 mock_args.finetune = False
                 mock_args.no_load_optim = True
@@ -299,33 +362,43 @@ def test_finetune_doesnt_load_optimizer(self, tmp_path_dist_ckpt, src_tp_pp, des
                 load_checkpoint_no_arg_checks(model, optimizer, None)
 
                 ## Model weights should be different, but optimizer state is unchanged
-                diffs = (diff(model[0].state_dict(), model_unloaded_state_dict))
+                diffs = diff(model[0].state_dict(), model_unloaded_state_dict)
                 # diffs[0] and diffs[1] is structural diff, diffs[2] is values diff - we expect only values diff
                 assert not diffs[0] and not diffs[1] and diffs[2]
                 assert not any(diff(optimizer.state_dict(), optim_unloaded_state_dict))
 
-
     def test_can_load_deprecated_bucket_space_format(self, tmp_path_dist_ckpt):
         # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
         tp = 4
         pp = 2
 
         Utils.initialize_model_parallel(tp, pp)
-        with TempNamedDir(tmp_path_dist_ckpt / 'test_can_load_deprecated_bucket_space_format', sync=True) as ckpt_dir:
+        with TempNamedDir(
+            tmp_path_dist_ckpt / 'test_can_load_deprecated_bucket_space_format', sync=True
+        ) as ckpt_dir:
             mock_args = SimpleNamespace()
             with mock.patch('megatron.training.checkpointing.get_args', new=lambda: mock_args):
-                
+
                 init_basic_mock_args(mock_args, tp=tp, pp=pp)
                 init_checkpointing_mock_args(mock_args, ckpt_dir, True)
-                
-                model, optimizer = setup_model_and_optimizer(seed=2, tp=tp, pp=pp, initialize_fn=initialize_gpt_model)
+
+                model, optimizer = setup_model_and_optimizer(
+                    seed=2, tp=tp, pp=pp, initialize_fn=initialize_gpt_model
+                )
 
                 # Mock optimizer sharded_state_dict so that it ignores the externally passed sharding_type and uses 'fully_sharded_bucket_space' instead
                 orig_optim_sharded_state_dict_fn = optimizer.sharded_state_dict
-                def sharded_state_dict_bucket_space(self, *args, sharding_type: str = 'fully_sharded_model_space', **kwargs):
-                    return orig_optim_sharded_state_dict_fn(*args, sharding_type='fully_sharded_bucket_space', **kwargs)
 
-                optimizer.sharded_state_dict = MethodType(sharded_state_dict_bucket_space, optimizer)
+                def sharded_state_dict_bucket_space(
+                    self, *args, sharding_type: str = 'fully_sharded_model_space', **kwargs
+                ):
+                    return orig_optim_sharded_state_dict_fn(
+                        *args, sharding_type='fully_sharded_bucket_space', **kwargs
+                    )
+
+                optimizer.sharded_state_dict = MethodType(
+                    sharded_state_dict_bucket_space, optimizer
+                )
                 save_checkpoint(10, model, optimizer, None, 0)
 
                 flag = 0
@@ -348,30 +421,32 @@ def sharded_state_dict_bucket_space(self, *args, sharding_type: str = 'fully_sha
                 load_checkpoint_no_arg_checks(model, optimizer, None)
 
 
-
 class TestFP32Optimizer:
     def setup_method(self, method):
         pass
 
     def teardown_method(self, method):
-        Utils.destroy_model_parallel()   
+        Utils.destroy_model_parallel()
 
     @pytest.mark.parametrize(
-        ('src_tp_pp', 'dest_tp_pp'),
-        [
-            ((2, 4), (2, 4)),
-            ((2, 4), (4, 2)),
-            ((8, 1), (1, 2)),
-        ]
+        ('src_tp_pp', 'dest_tp_pp'), [((2, 4), (2, 4)), ((2, 4), (4, 2)), ((8, 1), (1, 2))]
     )
     def test_fp32_optimizer_resharding(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp):
         # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
         Utils.initialize_model_parallel(*src_tp_pp)
-        with TempNamedDir(tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_A', sync=True) as ckpt_dir_A:
-            with TempNamedDir(tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_B', sync=True) as ckpt_dir_B:
-                
+        with TempNamedDir(
+            tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_A', sync=True
+        ) as ckpt_dir_A:
+            with TempNamedDir(
+                tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_B', sync=True
+            ) as ckpt_dir_B:
+
                 model_A, optimizer_A = setup_model_and_optimizer(
-                    seed=2, tp=src_tp_pp[0], pp=src_tp_pp[1], initialize_fn=initialize_small_model, bf16=False
+                    seed=2,
+                    tp=src_tp_pp[0],
+                    pp=src_tp_pp[1],
+                    initialize_fn=initialize_small_model,
+                    bf16=False,
                 )
 
                 save(optimizer_A.sharded_state_dict(model_A[0].sharded_state_dict()), ckpt_dir_A)
@@ -380,9 +455,15 @@ def test_fp32_optimizer_resharding(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_
                 # Load checkpoint A with different TP/PP and save as checkpoint B
                 Utils.initialize_model_parallel(*dest_tp_pp)
                 model_B, optimizer_B = setup_model_and_optimizer(
-                    seed=3, tp=dest_tp_pp[0], pp=dest_tp_pp[1], initialize_fn=initialize_small_model, bf16=False
+                    seed=3,
+                    tp=dest_tp_pp[0],
+                    pp=dest_tp_pp[1],
+                    initialize_fn=initialize_small_model,
+                    bf16=False,
+                )
+                load_sharded_state_dict = optimizer_B.sharded_state_dict(
+                    model_B[0].sharded_state_dict()
                 )
-                load_sharded_state_dict = optimizer_B.sharded_state_dict(model_B[0].sharded_state_dict())
                 state_dict = load(load_sharded_state_dict, ckpt_dir_A)
 
                 optimizer_B.load_state_dict(state_dict)
@@ -402,40 +483,47 @@ def setup_method(self, method):
         pass
 
     def teardown_method(self, method):
-        Utils.destroy_model_parallel()   
-    
+        Utils.destroy_model_parallel()
+
     @pytest.mark.parametrize(
         ('use_dist_opt', 'bf16'),
         (
             (False, True),  # regular BF16
-            (True, True),   # DistOpt BF16
+            (True, True),  # DistOpt BF16
             # (False, False), # FP32
-        )
+        ),
     )
     @pytest.mark.parametrize(
-        ('src_tp_pp', 'dest_tp_pp',),
-        [
-            ((2, 4), (2, 4)),
-            ((2, 4), (2, 2)),
-            ((2, 4), (4, 2)),
-            ((8, 1), (1, 2)),
-        ]
+        ('src_tp_pp', 'dest_tp_pp'),
+        [((2, 4), (2, 4)), ((2, 4), (2, 2)), ((2, 4), (4, 2)), ((8, 1), (1, 2))],
     )
     @pytest.mark.skip(reason="Tests are flaky and need to be debugged")
-    def test_optimizer_resharding(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, use_dist_opt, bf16):
+    def test_optimizer_resharding(
+        self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, use_dist_opt, bf16
+    ):
         Utils.initialize_model_parallel(*src_tp_pp)
-        with TempNamedDir(tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_A', sync=False) as ckpt_dir_A:
-            with TempNamedDir(tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_B', sync=False) as ckpt_dir_B:
-                
-                model_A, optimizer_A = setup_model_and_optimizer(seed=2, tp=src_tp_pp[0], pp=src_tp_pp[1], bf16=bf16, dist_opt=use_dist_opt)
+        with TempNamedDir(
+            tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_A', sync=False
+        ) as ckpt_dir_A:
+            with TempNamedDir(
+                tmp_path_dist_ckpt / 'test_fp32_optimizer_state_dict_B', sync=False
+            ) as ckpt_dir_B:
+
+                model_A, optimizer_A = setup_model_and_optimizer(
+                    seed=2, tp=src_tp_pp[0], pp=src_tp_pp[1], bf16=bf16, dist_opt=use_dist_opt
+                )
 
                 save(optimizer_A.sharded_state_dict(model_A[0].sharded_state_dict()), ckpt_dir_A)
                 Utils.destroy_model_parallel()
 
                 # Load checkpoint A with different TP/PP and save as checkpoint B
                 Utils.initialize_model_parallel(*dest_tp_pp)
-                model_B, optimizer_B = setup_model_and_optimizer(seed=3, tp=dest_tp_pp[0], pp=dest_tp_pp[1], bf16=bf16, dist_opt=use_dist_opt)
-                load_sharded_state_dict = optimizer_B.sharded_state_dict(model_B[0].sharded_state_dict())
+                model_B, optimizer_B = setup_model_and_optimizer(
+                    seed=3, tp=dest_tp_pp[0], pp=dest_tp_pp[1], bf16=bf16, dist_opt=use_dist_opt
+                )
+                load_sharded_state_dict = optimizer_B.sharded_state_dict(
+                    model_B[0].sharded_state_dict()
+                )
                 state_dict = load(load_sharded_state_dict, ckpt_dir_A)
 
                 optimizer_B.load_state_dict(state_dict)
diff --git a/tests/unit_tests/dist_checkpointing/test_serialization.py b/tests/unit_tests/dist_checkpointing/test_serialization.py
index 6c625f11d3..19e99de553 100644
--- a/tests/unit_tests/dist_checkpointing/test_serialization.py
+++ b/tests/unit_tests/dist_checkpointing/test_serialization.py
@@ -9,18 +9,16 @@
 from torch.distributed.checkpoint import CheckpointException as PyTCheckpointingException
 
 from megatron.core import parallel_state
-from megatron.core.dist_checkpointing import ShardedTensor, save, load
-from megatron.core.dist_checkpointing.core import CheckpointingException, \
-    maybe_load_config
+from megatron.core.dist_checkpointing import ShardedTensor, load, save
+from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config
 from megatron.core.dist_checkpointing.dict_utils import diff
-from megatron.core.dist_checkpointing.mapping import ShardedTensorFactory, \
-    ShardedObject
-from megatron.core.dist_checkpointing.serialization import \
-    load_tensors_metadata, load_sharded_metadata
-from megatron.core.dist_checkpointing.strategies.base import StrategyAction, \
-    get_default_strategy
+from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensorFactory
+from megatron.core.dist_checkpointing.serialization import (
+    load_sharded_metadata,
+    load_tensors_metadata,
+)
+from megatron.core.dist_checkpointing.strategies.base import StrategyAction, get_default_strategy
 from megatron.core.dist_checkpointing.validation import StrictHandling
-
 from tests.unit_tests.dist_checkpointing import TempNamedDir
 from tests.unit_tests.test_utilities import Utils
 
@@ -30,18 +28,24 @@ def setup_method(self, method):
         pass
 
     def teardown_method(self, method):
-        Utils.destroy_model_parallel()  
+        Utils.destroy_model_parallel()
 
     def test_single_process_save_load(self, tmp_path_dist_ckpt):
-        Utils.initialize_model_parallel(1,1)
+        Utils.initialize_model_parallel(1, 1)
 
         sharded_state_dict = {
-            'sd_keyA': ShardedTensor.from_rank_offsets('keyA', torch.ones(2, 4), replica_id=Utils.rank),
-            'sd_keyB': ShardedTensor.from_rank_offsets('keyB', torch.ones(3, 5, 7), replica_id=Utils.rank),
+            'sd_keyA': ShardedTensor.from_rank_offsets(
+                'keyA', torch.ones(2, 4), replica_id=Utils.rank
+            ),
+            'sd_keyB': ShardedTensor.from_rank_offsets(
+                'keyB', torch.ones(3, 5, 7), replica_id=Utils.rank
+            ),
         }
 
         # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
-        with TempNamedDir(tmp_path_dist_ckpt / 'test_single_process_save_load', sync=True) as ckpt_dir:
+        with TempNamedDir(
+            tmp_path_dist_ckpt / 'test_single_process_save_load', sync=True
+        ) as ckpt_dir:
             save(sharded_state_dict, ckpt_dir)
             torch.distributed.barrier()
 
@@ -53,23 +57,28 @@ def test_single_process_save_load(self, tmp_path_dist_ckpt):
                 assert not (ckpt_dir / 'sd_keyA').is_dir()
 
             load_ssd = {
-                'load_sd_keyA': ShardedTensor.from_rank_offsets('keyA', torch.ones(2, 4), replica_id=Utils.rank),
+                'load_sd_keyA': ShardedTensor.from_rank_offsets(
+                    'keyA', torch.ones(2, 4), replica_id=Utils.rank
+                )
             }
             loaded_state_dict = load(load_ssd, ckpt_dir)
-            
+
             assert set(loaded_state_dict.keys()) == {'load_sd_keyA'}
             assert isinstance(loaded_state_dict['load_sd_keyA'], torch.Tensor)
             assert loaded_state_dict['load_sd_keyA'].shape == (2, 4)
 
         Utils.destroy_model_parallel()
 
-
     def test_multi_process_save(self, tmp_path_dist_ckpt):
-        Utils.initialize_model_parallel(2,4)
+        Utils.initialize_model_parallel(2, 4)
 
         state_dict = {
-            'sd_keyA': ShardedTensor.from_rank_offsets('keyA', torch.ones(2, 4), (0, Utils.rank, Utils.world_size)),
-            'sd_keyB': ShardedTensor.from_rank_offsets('keyB', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size)),
+            'sd_keyA': ShardedTensor.from_rank_offsets(
+                'keyA', torch.ones(2, 4), (0, Utils.rank, Utils.world_size)
+            ),
+            'sd_keyB': ShardedTensor.from_rank_offsets(
+                'keyB', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size)
+            ),
         }
 
         # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
@@ -85,13 +94,16 @@ def test_multi_process_save(self, tmp_path_dist_ckpt):
 
         Utils.destroy_model_parallel()
 
-
     def test_partition_change_save_load(self, tmp_path_dist_ckpt, strategy=None):
-        Utils.initialize_model_parallel(2,4)
+        Utils.initialize_model_parallel(2, 4)
 
         # ten_a: global shape (2, 4):
         ten_a_global = torch.tensor([[0, 1, 2, 3], [10, 11, 12, 13]])
-        ten_a = torch.zeros(1, 1) + 10 * parallel_state.get_tensor_model_parallel_rank() + parallel_state.get_pipeline_model_parallel_rank()
+        ten_a = (
+            torch.zeros(1, 1)
+            + 10 * parallel_state.get_tensor_model_parallel_rank()
+            + parallel_state.get_pipeline_model_parallel_rank()
+        )
         assert ten_a.shape == (1, 1)
 
         # ten_b: global shape (4, 5, 80), where (x, y, z) is (100x + z)
@@ -100,11 +112,24 @@ def test_partition_change_save_load(self, tmp_path_dist_ckpt, strategy=None):
         assert ten_b.shape == (4, 5, 10)
 
         state_dict = {
-            'sd_keyA': ShardedTensor.from_rank_offsets('keyA', ten_a,
-                                                       (0, parallel_state.get_tensor_model_parallel_rank(), parallel_state.get_tensor_model_parallel_world_size()),
-                                                       (1, parallel_state.get_pipeline_model_parallel_rank(), parallel_state.get_pipeline_model_parallel_world_size()),
-                                                       replica_id=0),
-            'sd_keyB': ShardedTensor.from_rank_offsets('keyB', ten_b, (2, Utils.rank, Utils.world_size)),
+            'sd_keyA': ShardedTensor.from_rank_offsets(
+                'keyA',
+                ten_a,
+                (
+                    0,
+                    parallel_state.get_tensor_model_parallel_rank(),
+                    parallel_state.get_tensor_model_parallel_world_size(),
+                ),
+                (
+                    1,
+                    parallel_state.get_pipeline_model_parallel_rank(),
+                    parallel_state.get_pipeline_model_parallel_world_size(),
+                ),
+                replica_id=0,
+            ),
+            'sd_keyB': ShardedTensor.from_rank_offsets(
+                'keyB', ten_b, (2, Utils.rank, Utils.world_size)
+            ),
         }
 
         ten_a_global_shape = ten_a_global.shape
@@ -115,19 +140,21 @@ def test_partition_change_save_load(self, tmp_path_dist_ckpt, strategy=None):
         assert state_dict['sd_keyB'].global_shape == ten_b_global_shape
 
         # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
-        with TempNamedDir(tmp_path_dist_ckpt / 'test_partition_change_save_load', sync=True) as ckpt_dir:
+        with TempNamedDir(
+            tmp_path_dist_ckpt / 'test_partition_change_save_load', sync=True
+        ) as ckpt_dir:
             save(state_dict, ckpt_dir, strategy)
 
             del ten_a, ten_b
 
             # without changing TPxPP, load tensors without any sharding
             load_sd = {
-                'sd_keyA': ShardedTensor.from_rank_offsets('keyA',
-                                                           torch.empty(ten_a_global_shape),
-                                                           replica_id=Utils.rank),
-                'sd_keyB': ShardedTensor.from_rank_offsets('keyB',
-                                                           torch.empty(ten_b_global_shape),
-                                                           replica_id=Utils.rank),
+                'sd_keyA': ShardedTensor.from_rank_offsets(
+                    'keyA', torch.empty(ten_a_global_shape), replica_id=Utils.rank
+                ),
+                'sd_keyB': ShardedTensor.from_rank_offsets(
+                    'keyB', torch.empty(ten_b_global_shape), replica_id=Utils.rank
+                ),
             }
             loaded_state_dict = load(load_sd, ckpt_dir)
 
@@ -139,27 +166,39 @@ def test_partition_change_save_load(self, tmp_path_dist_ckpt, strategy=None):
 
             assert isinstance(ten_b, torch.Tensor)
             assert ten_b.shape == ten_b_global_shape
-            assert np.all([
-                val == 100 * x + z
-                for x, x_row in enumerate(ten_b)
-                for y, y_row in enumerate(x_row)
-                for z, val in enumerate(y_row)
-            ])
+            assert np.all(
+                [
+                    val == 100 * x + z
+                    for x, x_row in enumerate(ten_b)
+                    for y, y_row in enumerate(x_row)
+                    for z, val in enumerate(y_row)
+                ]
+            )
 
             del ten_a, ten_b
 
             # change TPxPP
             Utils.destroy_model_parallel()
-            Utils.initialize_model_parallel(1,2)
+            Utils.initialize_model_parallel(1, 2)
 
             load_sd = {
-                'sd_keyA': ShardedTensor.from_rank_offsets('keyA', torch.empty(2, 1),
-                                                           (1, parallel_state.get_data_parallel_rank(), parallel_state.get_data_parallel_world_size()),
-                                                           replica_id=parallel_state.get_pipeline_model_parallel_rank()),
-                'sd_keyB': ShardedTensor.from_rank_offsets('keyB', torch.empty(5, 80),
-                                                           (0, Utils.rank // 2, 4),
-                                                           prepend_axis_num=1,
-                                                           replica_id=Utils.rank % 2),
+                'sd_keyA': ShardedTensor.from_rank_offsets(
+                    'keyA',
+                    torch.empty(2, 1),
+                    (
+                        1,
+                        parallel_state.get_data_parallel_rank(),
+                        parallel_state.get_data_parallel_world_size(),
+                    ),
+                    replica_id=parallel_state.get_pipeline_model_parallel_rank(),
+                ),
+                'sd_keyB': ShardedTensor.from_rank_offsets(
+                    'keyB',
+                    torch.empty(5, 80),
+                    (0, Utils.rank // 2, 4),
+                    prepend_axis_num=1,
+                    replica_id=Utils.rank % 2,
+                ),
             }
 
             loaded_state_dict = load(load_sd, ckpt_dir)
@@ -168,18 +207,26 @@ def test_partition_change_save_load(self, tmp_path_dist_ckpt, strategy=None):
 
             assert isinstance(ten_a, torch.Tensor)
             assert ten_a.shape == (2, 1)
-            assert torch.all(ten_a[:, 0] == ten_a_global[:, parallel_state.get_data_parallel_rank()])
+            assert torch.all(
+                ten_a[:, 0] == ten_a_global[:, parallel_state.get_data_parallel_rank()]
+            )
 
             assert isinstance(ten_b, torch.Tensor)
             assert ten_b.shape == (5, 10 * 8)
-            assert torch.all(ten_b == torch.arange(80).unsqueeze(0).expand(5, 80) + Utils.rank // 2 * 100)
+            assert torch.all(
+                ten_b == torch.arange(80).unsqueeze(0).expand(5, 80) + Utils.rank // 2 * 100
+            )
 
     def test_load_tensors_metadata(self, tmp_path_dist_ckpt):
-        Utils.initialize_model_parallel(2,4)
+        Utils.initialize_model_parallel(2, 4)
 
         state_dict = {
-            'sd_keyA': ShardedTensor.from_rank_offsets('keyA', torch.arange(10) + Utils.rank * 10, (0, Utils.rank, Utils.world_size)),
-            'sd_keyB': ShardedTensor.from_rank_offsets('keyB', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size)),
+            'sd_keyA': ShardedTensor.from_rank_offsets(
+                'keyA', torch.arange(10) + Utils.rank * 10, (0, Utils.rank, Utils.world_size)
+            ),
+            'sd_keyB': ShardedTensor.from_rank_offsets(
+                'keyB', torch.ones(3, 5, 7), (2, Utils.rank, Utils.world_size)
+            ),
         }
 
         # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
@@ -223,15 +270,27 @@ def _build_fn(key, tensor, replica_id, flattened_range):
 
         # state dict can be modified by dist_checkpointing.save, so two copies
         def get_sharded_state_dict(base=0):
-            return {'all': [
-                ShardedTensor.from_rank_offsets('A', torch.arange(2) + base, replica_id=Utils.rank),
-                ShardedTensor.from_rank_offsets('B', torch.arange(3) + base, replica_id=Utils.rank),
-                ShardedTensor.from_rank_offsets('C', torch.arange(4) + base, replica_id=Utils.rank),
-                ShardedTensorFactory('D', torch.arange(5) + base, _build_fn, sum, replica_id=Utils.rank),
-            ]}
+            return {
+                'all': [
+                    ShardedTensor.from_rank_offsets(
+                        'A', torch.arange(2) + base, replica_id=Utils.rank
+                    ),
+                    ShardedTensor.from_rank_offsets(
+                        'B', torch.arange(3) + base, replica_id=Utils.rank
+                    ),
+                    ShardedTensor.from_rank_offsets(
+                        'C', torch.arange(4) + base, replica_id=Utils.rank
+                    ),
+                    ShardedTensorFactory(
+                        'D', torch.arange(5) + base, _build_fn, sum, replica_id=Utils.rank
+                    ),
+                ]
+            }
 
         # sync=True to make sure other ranks wait for rank 0 to finish creating directory.
-        with TempNamedDir(tmp_path_dist_ckpt / 'test_can_mix_sharded_tensors_and_factories', sync=True) as ckpt_dir:
+        with TempNamedDir(
+            tmp_path_dist_ckpt / 'test_can_mix_sharded_tensors_and_factories', sync=True
+        ) as ckpt_dir:
             save(get_sharded_state_dict(0), ckpt_dir)
             loaded_state_dict = load(get_sharded_state_dict(10), ckpt_dir)
 
@@ -282,16 +341,22 @@ def test_sharded_object_serialization(self, tmp_path_dist_ckpt):
             state = {'some': 'dict'}
             state_serialized = io.BytesIO()
             torch.save(state, state_serialized)
-            state_dict = {'some_key': ShardedObject('sh_obj_A', state_serialized, (1,), (0,),
-                                                    replica_id=Utils.rank)}
+            state_dict = {
+                'some_key': ShardedObject(
+                    'sh_obj_A', state_serialized, (1,), (0,), replica_id=Utils.rank
+                )
+            }
 
             save(state_dict, ckpt_dir)
             del state, state_serialized, state_dict
             other_state = {'other': 'dictionary'}
             other_serialized = io.BytesIO()
             torch.save(other_state, other_serialized)
-            state_dict = {'other_key': ShardedObject('sh_obj_A', other_serialized, (1,), (0,),
-                                                     replica_id=Utils.rank)}
+            state_dict = {
+                'other_key': ShardedObject(
+                    'sh_obj_A', other_serialized, (1,), (0,), replica_id=Utils.rank
+                )
+            }
             load_state_dict = load(state_dict, ckpt_dir)
             assert 'other_key' in load_state_dict
             load_state_dict['other_key'].seek(0)
@@ -302,15 +367,18 @@ def test_sharded_object_serialization(self, tmp_path_dist_ckpt):
         Utils.destroy_model_parallel()
 
     def test_tensor_shape_mismatch(self, tmp_path_dist_ckpt):
-        Utils.initialize_model_parallel(2,4)
+        Utils.initialize_model_parallel(2, 4)
 
         # Global tensor is just a range(32) repeated twice over the first dimension
         local_tensor = torch.arange(4).unsqueeze(0).expand(2, 4) + Utils.rank * 4
 
         state_dict = {
-            'rigid': ShardedTensor.from_rank_offsets('keyA', local_tensor, (1, Utils.rank, Utils.world_size)),
-            'flexible': ShardedTensor.from_rank_offsets('keyB', local_tensor, (1, Utils.rank, Utils.world_size),
-                                                        allow_shape_mismatch=True),
+            'rigid': ShardedTensor.from_rank_offsets(
+                'keyA', local_tensor, (1, Utils.rank, Utils.world_size)
+            ),
+            'flexible': ShardedTensor.from_rank_offsets(
+                'keyB', local_tensor, (1, Utils.rank, Utils.world_size), allow_shape_mismatch=True
+            ),
         }
         assert state_dict['rigid'].global_shape == (2, 32)
         assert state_dict['flexible'].global_shape == (2, 32)
@@ -325,28 +393,45 @@ def test_tensor_shape_mismatch(self, tmp_path_dist_ckpt):
 
             # Smaller coverage than expected (28 < 32)
             state_dict = {
-                'rigid': ShardedTensor.from_rank_offsets('keyA', torch.ones(2, 7), (1, pp_rank, pp_size), replica_id=tp_rank),
+                'rigid': ShardedTensor.from_rank_offsets(
+                    'keyA', torch.ones(2, 7), (1, pp_rank, pp_size), replica_id=tp_rank
+                )
             }
             with pytest.raises((CheckpointingException, PyTCheckpointingException)):
                 load(state_dict, ckpt_dir)
 
             state_dict = {
-                'flexible': ShardedTensor.from_rank_offsets('keyB', torch.ones(2, 7), (1, pp_rank, pp_size), replica_id=tp_rank,
-                                                            allow_shape_mismatch=True),
+                'flexible': ShardedTensor.from_rank_offsets(
+                    'keyB',
+                    torch.ones(2, 7),
+                    (1, pp_rank, pp_size),
+                    replica_id=tp_rank,
+                    allow_shape_mismatch=True,
+                )
             }
             loaded_state_dict = load(state_dict, ckpt_dir)
-            assert torch.all(loaded_state_dict['flexible'] == torch.arange(7).unsqueeze(0).expand(2, 7) + pp_rank * 7)
+            assert torch.all(
+                loaded_state_dict['flexible']
+                == torch.arange(7).unsqueeze(0).expand(2, 7) + pp_rank * 7
+            )
 
             # Larger coverage than expected (36 > 32)
             state_dict = {
-                'rigid': ShardedTensor.from_rank_offsets('keyA', torch.ones(2, 9), (1, pp_rank, pp_size), replica_id=tp_rank),
+                'rigid': ShardedTensor.from_rank_offsets(
+                    'keyA', torch.ones(2, 9), (1, pp_rank, pp_size), replica_id=tp_rank
+                )
             }
             with pytest.raises((CheckpointingException, PyTCheckpointingException)):
                 load(state_dict, ckpt_dir)
 
             state_dict = {
-                'flexible': ShardedTensor.from_rank_offsets('keyB', torch.ones(2, 9), (1, pp_rank, pp_size), replica_id=tp_rank,
-                                                            allow_shape_mismatch=True),
+                'flexible': ShardedTensor.from_rank_offsets(
+                    'keyB',
+                    torch.ones(2, 9),
+                    (1, pp_rank, pp_size),
+                    replica_id=tp_rank,
+                    allow_shape_mismatch=True,
+                )
             }
             loaded_state_dict = load(state_dict, ckpt_dir)
             expected_tensor = torch.arange(9).unsqueeze(0).expand(2, 9) + pp_rank * 9
@@ -369,25 +454,44 @@ def teardown_method(self, method):
     def _get_base_state_dict(self):
         return {
             'TenA': ShardedTensor.from_rank_offsets('TenA', torch.arange(2), replica_id=Utils.rank),
-            'TenB': ShardedTensor.from_rank_offsets('TenB', torch.arange(3), (0, Utils.rank, Utils.world_size), replica_id=0),
-            'TenC': ShardedTensor.from_rank_offsets('TenC', torch.arange(3), replica_id=Utils.world_size - Utils.rank - 1),
+            'TenB': ShardedTensor.from_rank_offsets(
+                'TenB', torch.arange(3), (0, Utils.rank, Utils.world_size), replica_id=0
+            ),
+            'TenC': ShardedTensor.from_rank_offsets(
+                'TenC', torch.arange(3), replica_id=Utils.world_size - Utils.rank - 1
+            ),
             'ObjA': ShardedObject('ObjA', list(range(10)), (1,), (0,), replica_id=Utils.rank),
-            'ObjB': ShardedObject('ObjB', {Utils.rank + 7}, (1, Utils.world_size), (0, Utils.rank), replica_id=0),
+            'ObjB': ShardedObject(
+                'ObjB', {Utils.rank + 7}, (1, Utils.world_size), (0, Utils.rank), replica_id=0
+            ),
         }
 
     @pytest.mark.parametrize('save_format', ['zarr', 'torch_dist'])
     @pytest.mark.parametrize('validate_integrity', [True, False])
-    def test_unexpected_keys_handling_during_validation(self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format):
+    def test_unexpected_keys_handling_during_validation(
+        self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format
+    ):
         sharded_state_dict = self._get_base_state_dict()
-        with TempNamedDir(tmp_path_dist_ckpt / 'test_unexpected_keys_raises_error_during_validation') as ckpt_dir:
+        with TempNamedDir(
+            tmp_path_dist_ckpt / 'test_unexpected_keys_raises_error_during_validation'
+        ) as ckpt_dir:
             save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1)
             save(sharded_state_dict, ckpt_dir, save_strategy)
 
             def load_with_flag(strict):
                 sharded_state_dict = self._get_base_state_dict()
-                sharded_state_dict['TenD'] = ShardedTensor.from_rank_offsets('UnexpectedTenD', torch.arange(3), replica_id=Utils.rank)
-                sharded_state_dict['ObjD'] = ShardedObject('UnexpectedObjD', None, (1,), (0,), replica_id=Utils.rank)
-                return load(sharded_state_dict, ckpt_dir, validate_access_integrity=validate_integrity, strict=strict)
+                sharded_state_dict['TenD'] = ShardedTensor.from_rank_offsets(
+                    'UnexpectedTenD', torch.arange(3), replica_id=Utils.rank
+                )
+                sharded_state_dict['ObjD'] = ShardedObject(
+                    'UnexpectedObjD', None, (1,), (0,), replica_id=Utils.rank
+                )
+                return load(
+                    sharded_state_dict,
+                    ckpt_dir,
+                    validate_access_integrity=validate_integrity,
+                    strict=strict,
+                )
 
             def test_error(error_msg):
                 assert 'Unexpected keys' in error_msg
@@ -396,7 +500,9 @@ def test_error(error_msg):
                 assert 'Missing keys' not in error_msg
 
             # ASSUME_OK_UNEXPECTED results in an exception raised by the underlying strategy
-            with pytest.raises(PyTCheckpointingException if save_format == 'torch_dist' else CheckpointingException) as exc_info:
+            with pytest.raises(
+                PyTCheckpointingException if save_format == 'torch_dist' else CheckpointingException
+            ) as exc_info:
                 load_with_flag(StrictHandling.ASSUME_OK_UNEXPECTED)
             # Informative exceptions with `RAISE_*` options:
             with pytest.raises(CheckpointingException) as exc_info:
@@ -417,11 +523,15 @@ def test_error(error_msg):
             test_error(caplog.text)
 
             # Returned mismatches
-            loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(StrictHandling.RETURN_UNEXPECTED)
+            loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(
+                StrictHandling.RETURN_UNEXPECTED
+            )
             assert 'TenA' in loaded_state_dict
             assert unexpected_keys == {'UnexpectedTenD', 'UnexpectedObjD'}
             assert missing_keys == set()
-            loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(StrictHandling.RETURN_ALL)
+            loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(
+                StrictHandling.RETURN_ALL
+            )
             assert 'TenA' in loaded_state_dict
             assert unexpected_keys == {'UnexpectedTenD', 'UnexpectedObjD'}
             assert missing_keys == set()
@@ -432,9 +542,13 @@ def test_error(error_msg):
 
     @pytest.mark.parametrize('save_format', ['zarr', 'torch_dist'])
     @pytest.mark.parametrize('validate_integrity', [True, False])
-    def test_missing_keys_raises_error_during_validation(self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format):
+    def test_missing_keys_raises_error_during_validation(
+        self, caplog, tmp_path_dist_ckpt, validate_integrity, save_format
+    ):
         sharded_state_dict = self._get_base_state_dict()
-        with TempNamedDir(tmp_path_dist_ckpt / 'test_missing_keys_raises_error_during_validation') as ckpt_dir:
+        with TempNamedDir(
+            tmp_path_dist_ckpt / 'test_missing_keys_raises_error_during_validation'
+        ) as ckpt_dir:
             save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, save_format, 1)
             save(sharded_state_dict, ckpt_dir, save_strategy)
 
@@ -442,7 +556,12 @@ def load_with_flag(strict):
                 sharded_state_dict = self._get_base_state_dict()
                 del sharded_state_dict['TenA']
                 del sharded_state_dict['ObjB']
-                return load(sharded_state_dict, ckpt_dir, validate_access_integrity=validate_integrity, strict=strict)
+                return load(
+                    sharded_state_dict,
+                    ckpt_dir,
+                    validate_access_integrity=validate_integrity,
+                    strict=strict,
+                )
 
             def test_error(error_msg):
                 assert 'Unexpected keys' not in error_msg
@@ -459,10 +578,15 @@ def test_error(error_msg):
 
             with caplog.at_level(logging.WARNING):
                 loaded_state_dict = load_with_flag(StrictHandling.LOG_UNEXPECTED)
-            assert caplog.text == '' or '`zarr` distributed checkpoint backend is deprecated' in caplog.text
+            assert (
+                caplog.text == ''
+                or '`zarr` distributed checkpoint backend is deprecated' in caplog.text
+            )
             assert 'TenB' in loaded_state_dict
 
-            loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(StrictHandling.RETURN_UNEXPECTED)
+            loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(
+                StrictHandling.RETURN_UNEXPECTED
+            )
             assert 'TenB' in loaded_state_dict
             assert missing_keys == set()
             assert unexpected_keys == set()
@@ -482,7 +606,9 @@ def test_error(error_msg):
             test_error(caplog.text)
 
             # Returned mismatches
-            loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(StrictHandling.RETURN_ALL)
+            loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(
+                StrictHandling.RETURN_ALL
+            )
             assert 'TenB' in loaded_state_dict
             assert unexpected_keys == set()
             assert missing_keys == {'TenA', 'ObjB'}
@@ -497,7 +623,12 @@ def test_exact_load_handling(self, caplog, tmp_path_dist_ckpt, validate_integrit
 
             def load_with_flag(strict):
                 sharded_state_dict = self._get_base_state_dict()
-                return load(sharded_state_dict, ckpt_dir, validate_access_integrity=validate_integrity, strict=strict)
+                return load(
+                    sharded_state_dict,
+                    ckpt_dir,
+                    validate_access_integrity=validate_integrity,
+                    strict=strict,
+                )
 
             for strict in (
                 StrictHandling.ASSUME_OK_UNEXPECTED,
@@ -509,17 +640,20 @@ def load_with_flag(strict):
             ):
                 with caplog.at_level(logging.WARNING):
                     loaded_state_dict = load_with_flag(strict)
-                assert caplog.text == '' or '`zarr` distributed checkpoint backend is deprecated' in caplog.text
+                assert (
+                    caplog.text == ''
+                    or '`zarr` distributed checkpoint backend is deprecated' in caplog.text
+                )
                 assert 'TenB' in loaded_state_dict
                 assert 'ObjB' in loaded_state_dict
 
-            for strict in (
-                StrictHandling.RETURN_UNEXPECTED,
-                StrictHandling.RETURN_ALL,
-            ):
+            for strict in (StrictHandling.RETURN_UNEXPECTED, StrictHandling.RETURN_ALL):
                 with caplog.at_level(logging.WARNING):
                     loaded_state_dict, missing_keys, unexpected_keys = load_with_flag(strict)
-                assert caplog.text == '' or '`zarr` distributed checkpoint backend is deprecated' in caplog.text
+                assert (
+                    caplog.text == ''
+                    or '`zarr` distributed checkpoint backend is deprecated' in caplog.text
+                )
                 assert 'TenB' in loaded_state_dict
                 assert 'ObjB' in loaded_state_dict
                 assert missing_keys == set()
@@ -534,9 +668,17 @@ def test_sharded_metadata(self, tmp_path_dist_ckpt, save_format):
             save(sharded_state_dict, ckpt_dir, save_strategy)
             torch.distributed.barrier()
             sharded_metadata = load_sharded_metadata(ckpt_dir)
-            assert set(sh_base.key for sh_base in sharded_metadata.values()) == {'TenA', 'TenB', 'TenC', 'ObjA', 'ObjB'}
+            assert set(sh_base.key for sh_base in sharded_metadata.values()) == {
+                'TenA',
+                'TenB',
+                'TenC',
+                'ObjA',
+                'ObjB',
+            }
             assert set(sharded_metadata.keys()) == {
-                'TenA', 'TenB', 'TenC',
+                'TenA',
+                'TenB',
+                'TenC',
                 'ObjA/shard_0_1',
                 *(f'ObjB/shard_0.{i}_1.8' for i in range(8)),
             }
diff --git a/tests/unit_tests/dist_checkpointing/utils.py b/tests/unit_tests/dist_checkpointing/utils.py
index 5b2b4aa3eb..c4532b7f4a 100644
--- a/tests/unit_tests/dist_checkpointing/utils.py
+++ b/tests/unit_tests/dist_checkpointing/utils.py
@@ -3,6 +3,7 @@
 from unittest import mock
 
 import torch
+
 from megatron.core.models.gpt import GPTModel
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
 from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
@@ -16,7 +17,9 @@
 NUM_ATTENTION_HEADS = 8
 
 
-def initialize_gpt_model(pre_process=True, post_process=True, seed=0, use_glu=True, **config_kwargs):
+def initialize_gpt_model(
+    pre_process=True, post_process=True, seed=0, use_glu=True, **config_kwargs
+):
     torch.manual_seed(seed)
     model_parallel_cuda_manual_seed(seed)
 
@@ -59,6 +62,7 @@ def init_basic_mock_args(args, tp, pp, bf16=True):
     args.pipeline_model_parallel_size = pp
     return args
 
+
 def init_checkpointing_mock_args(args, ckpt_dir, fully_parallel=False):
     args.non_persistent_global_ckpt_dir = None
     args.non_persistent_ckpt_type = None
@@ -90,15 +94,28 @@ def init_checkpointing_mock_args(args, ckpt_dir, fully_parallel=False):
     args.hidden_size = HIDDEN_SIZE
     args.num_attention_heads = NUM_ATTENTION_HEADS
 
-def setup_model_and_optimizer(seed, tp, pp, initialize_fn=initialize_gpt_model, bf16=True, dist_opt=True):
+
+def setup_model_and_optimizer(
+    seed, tp, pp, initialize_fn=initialize_gpt_model, bf16=True, dist_opt=True
+):
     mock_args = SimpleNamespace()
     with mock.patch('megatron.training.training.get_args', new=lambda: mock_args):
         init_basic_mock_args(mock_args, tp, pp, bf16=bf16)
-        model = get_model(partial(
-            initialize_fn, seed=seed, tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp, pipeline_dtype=torch.bfloat16
-        ))
+        model = get_model(
+            partial(
+                initialize_fn,
+                seed=seed,
+                tensor_model_parallel_size=tp,
+                pipeline_model_parallel_size=pp,
+                pipeline_dtype=torch.bfloat16,
+            )
+        )
 
-    config = OptimizerConfig(bf16=bf16, params_dtype=torch.bfloat16 if bf16 else torch.float, use_distributed_optimizer=dist_opt)
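+    # Optimizer dtype follows the bf16 flag; dist_opt toggles the distributed optimizer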
+    config = OptimizerConfig(
+        bf16=bf16,
+        params_dtype=torch.bfloat16 if bf16 else torch.float,
+        use_distributed_optimizer=dist_opt,
+    )
     optimizer = get_megatron_optimizer(config, model)
 
     torch.manual_seed(seed + 1)
diff --git a/tests/unit_tests/distributed/test_param_and_grad_buffer.py b/tests/unit_tests/distributed/test_param_and_grad_buffer.py
index 14d3be7071..f070303177 100644
--- a/tests/unit_tests/distributed/test_param_and_grad_buffer.py
+++ b/tests/unit_tests/distributed/test_param_and_grad_buffer.py
@@ -1,11 +1,12 @@
 import contextlib
 import math
+
 import pytest
 import torch
 
 from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallelConfig, ParamAndGradBuffer
-from tests.unit_tests.test_utilities import Utils, TestModel
+from tests.unit_tests.test_utilities import TestModel, Utils
 
 
 def get_model_and_buffers(
diff --git a/tests/unit_tests/fusions/test_torch_softmax.py b/tests/unit_tests/fusions/test_torch_softmax.py
index 504bb0b48d..63b0bc7b5d 100644
--- a/tests/unit_tests/fusions/test_torch_softmax.py
+++ b/tests/unit_tests/fusions/test_torch_softmax.py
@@ -19,10 +19,10 @@ def setup_method(self, method):
             softmax_in_fp32=True,
             scale=None,
         )
-    
+
     def teardown_method(self):
-        get_default_causal_mask.cache_clear() 
-    
+        get_default_causal_mask.cache_clear()
+
     def test_output_shape(self):
         x = torch.randn(8, 2, 4, 4, device="cuda")
         y = self.softmax(x, None)
diff --git a/tests/unit_tests/inference/engines/test_mcore_engine.py b/tests/unit_tests/inference/engines/test_mcore_engine.py
index 1c8568feea..161284ceeb 100644
--- a/tests/unit_tests/inference/engines/test_mcore_engine.py
+++ b/tests/unit_tests/inference/engines/test_mcore_engine.py
@@ -1,52 +1,72 @@
+import random
+import string
 from typing import List
-from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
+from unittest import mock
+
 import torch
-import random 
-import string
 
 from megatron.core.inference.common_inference_params import CommonInferenceParams
 from megatron.core.inference.engines.mcore_engine import MCoreEngine
-from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
 from megatron.core.inference.inference_request import InferenceRequest, Status
-from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import SimpleTextGenerationController
+from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
+    GPTInferenceWrapper,
+)
+from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
+    InferenceWrapperConfig,
+)
+from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
+    SimpleTextGenerationController,
+)
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.transformer_config import TransformerConfig
 from tests.unit_tests.test_utilities import Utils
-from unittest import mock
+
 
 class TestMCoreEngine:
     def setup_method(self, method):
-        Utils.initialize_model_parallel(tensor_model_parallel_size=1,pipeline_model_parallel_size=1)
-        model_parallel_cuda_manual_seed(123)          
+        Utils.initialize_model_parallel(
+            tensor_model_parallel_size=1, pipeline_model_parallel_size=1
+        )
+        model_parallel_cuda_manual_seed(123)
         self.batch_size = 4
         self.hidden_size = 12
         self.vocab_size = 100
         self.sequence_length = 64
-        transformer_config = TransformerConfig(num_layers=4, hidden_size=self.hidden_size, num_attention_heads=4, use_cpu_initialization=True)
-                                                    
+        transformer_config = TransformerConfig(
+            num_layers=4,
+            hidden_size=self.hidden_size,
+            num_attention_heads=4,
+            use_cpu_initialization=True,
+        )
+
         gpt_model = GPTModel(
-            config=transformer_config, 
-            transformer_layer_spec=get_gpt_layer_local_spec(), 
-            vocab_size=self.vocab_size, 
-            max_sequence_length=self.sequence_length, 
-            parallel_output = True).cuda()
+            config=transformer_config,
+            transformer_layer_spec=get_gpt_layer_local_spec(),
+            vocab_size=self.vocab_size,
+            max_sequence_length=self.sequence_length,
+            parallel_output=True,
+        ).cuda()
 
         inference_wrapper_config = InferenceWrapperConfig(
             hidden_size=self.hidden_size,
             inference_batch_times_seqlen_threshold=400,
             fp32_residual_connection=False,
             params_dtype=torch.float,
-            padded_vocab_size=self.vocab_size
+            padded_vocab_size=self.vocab_size,
         )
 
         inference_wrapped_model = GPTInferenceWrapper(gpt_model, inference_wrapper_config)
         self.mock_tokenizer = mock.Mock()
-        text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer)       
+        text_generation_controller = SimpleTextGenerationController(
+            inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer
+        )
+
+        self.mcore_engine = MCoreEngine(
+            text_generation_controller=text_generation_controller, max_batch_size=4
+        )
 
-        self.mcore_engine = MCoreEngine(text_generation_controller=text_generation_controller, max_batch_size=4)
-        
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
 
@@ -54,14 +74,22 @@ def test_generate(self):
         self.mock_tokenizer.vocab_size = self.vocab_size
         self.mock_tokenizer.eod = self.vocab_size - 1
         # Generating random length integer prompts
-        self.mock_tokenizer.tokenize.return_value = [random.randint(0, self.vocab_size -1) for _ in range(random.randint(5,10))]
+        self.mock_tokenizer.tokenize.return_value = [
+            random.randint(0, self.vocab_size - 1) for _ in range(random.randint(5, 10))
+        ]
         # Generates some random string
-        self.mock_tokenizer.detokenize.return_value = ''.join(random.choices(string.ascii_letters, k=random.randint(4,10)))
+        self.mock_tokenizer.detokenize.return_value = ''.join(
+            random.choices(string.ascii_letters, k=random.randint(4, 10))
+        )
 
-        prompts = ["sample"*(i+1) for i in range(self.batch_size)]
-        results : List[InferenceRequest] = self.mcore_engine.generate(prompts, common_inference_params=CommonInferenceParams(num_tokens_to_generate=10))
+        prompts = ["sample" * (i + 1) for i in range(self.batch_size)]
+        results: List[InferenceRequest] = self.mcore_engine.generate(
+            prompts, common_inference_params=CommonInferenceParams(num_tokens_to_generate=10)
+        )
 
         for result in results:
-            assert result.status == Status.COMPLETED, f"Status should be completed but its {result.status}"
-            assert result.generated_length > 0 , f"Generated length should be greater than zero"
-            assert result.generated_text is not None , f'Generated text should not be None'
+            assert (
+                result.status == Status.COMPLETED
+            ), f"Status should be completed but it's {result.status}"
+            assert result.generated_length > 0, "Generated length should be greater than zero"
+            assert result.generated_text is not None, 'Generated text should not be None'
diff --git a/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py b/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py
index 1f7fb478a3..e01c3f4d17 100644
--- a/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py
+++ b/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py
@@ -1,83 +1,124 @@
 from argparse import Namespace
-from megatron.core import parallel_state
-from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
+
 import torch
-from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec, get_gpt_layer_with_transformer_engine_spec
-from megatron.core.transformer.transformer_config import TransformerConfig
+
+from megatron.core import parallel_state
+from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
+    GPTInferenceWrapper,
+)
+from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
+    InferenceWrapperConfig,
+)
+from megatron.core.models.gpt.gpt_layer_specs import (
+    get_gpt_layer_local_spec,
+    get_gpt_layer_with_transformer_engine_spec,
+)
 from megatron.core.models.gpt.gpt_model import GPTModel
-from tests.unit_tests.test_utilities import Utils
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+from megatron.core.transformer.transformer_config import TransformerConfig
+from tests.unit_tests.test_utilities import Utils
+
 
 class TestGPTInferenceWrapper:
 
     def setup_model(self, tensor_parallel_size, pipeline_parallel_size):
-        Utils.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size,pipeline_model_parallel_size=pipeline_parallel_size)
+        Utils.initialize_model_parallel(
+            tensor_model_parallel_size=tensor_parallel_size,
+            pipeline_model_parallel_size=pipeline_parallel_size,
+        )
         model_parallel_cuda_manual_seed(123)
         self.vocab_size = 100
         self.batch_size = 4
         self.sequence_length = 32
         hidden_size = 12
 
-        transformer_config = TransformerConfig(num_layers=4, hidden_size=hidden_size, num_attention_heads=4, use_cpu_initialization=True)
-                                                    
+        transformer_config = TransformerConfig(
+            num_layers=4,
+            hidden_size=hidden_size,
+            num_attention_heads=4,
+            use_cpu_initialization=True,
+        )
+
         gpt_model = GPTModel(
-            config=transformer_config, 
-            transformer_layer_spec=get_gpt_layer_local_spec(), 
-            vocab_size=self.vocab_size, 
-            max_sequence_length=self.sequence_length, 
-            parallel_output = True).cuda()
+            config=transformer_config,
+            transformer_layer_spec=get_gpt_layer_local_spec(),
+            vocab_size=self.vocab_size,
+            max_sequence_length=self.sequence_length,
+            parallel_output=True,
+        ).cuda()
 
         inference_wrapper_config = InferenceWrapperConfig(
             hidden_size=hidden_size,
             inference_batch_times_seqlen_threshold=20,
             fp32_residual_connection=False,
             params_dtype=torch.float,
-            padded_vocab_size=self.vocab_size
+            padded_vocab_size=self.vocab_size,
         )
 
         self.inference_wrapped_model = GPTInferenceWrapper(gpt_model, inference_wrapper_config)
+
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
-        
-    # This will call the inference_wrapped_model.forward_pass_with_pipeline_parallel_small_input_batch()    
+
+    # This will call the inference_wrapped_model.forward_pass_with_pipeline_parallel_small_input_batch()
     def test_inference_pipeline_parallel_small_size(self):
         self.setup_model(tensor_parallel_size=2, pipeline_parallel_size=2)
-        
-        batch_prompt_tokens = torch.randint(low = 0, high = self.vocab_size, size=(self.batch_size, self.sequence_length)).int().cuda()
+
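+        # Random prompt tokens with shape (batch_size, sequence_length), cast to int32 and moved to GPU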
+        batch_prompt_tokens = (
+            torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.sequence_length))
+            .int()
+            .cuda()
+        )
         self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=batch_prompt_tokens)
- 
+
         inference_input = self.inference_wrapped_model.get_batch_for_context_window(0, 5)
-        
+
         logits = self.inference_wrapped_model.run_one_forward_step(inference_input)
         # Logits are not returned in all ranks in PP
         if parallel_state.is_pipeline_last_stage():
-            assert logits.shape == (self.batch_size, 5, self.vocab_size), f"Shape mismatch . Expected {(self.batch_size, 5, self.vocab_size)}, but got {logits.shape}"
- 
+            assert logits.shape == (
+                self.batch_size,
+                5,
+                self.vocab_size,
+            ), f"Shape mismatch . Expected {(self.batch_size, 5, self.vocab_size)}, but got {logits.shape}"
 
     # This will call the inference_wrapped_model.forward_pass_with_pipeline_parallel_large_input_batch()
     def test_inference_pipeline_parallel_large__size(self):
         self.setup_model(tensor_parallel_size=2, pipeline_parallel_size=2)
-        
-        batch_prompt_tokens = torch.randint(low = 0, high = self.vocab_size, size=(self.batch_size, self.sequence_length)).int().cuda()
+
+        batch_prompt_tokens = (
+            torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.sequence_length))
+            .int()
+            .cuda()
+        )
         self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=batch_prompt_tokens)
 
         inference_input = self.inference_wrapped_model.get_batch_for_context_window(0, 10)
-        
+
         logits = self.inference_wrapped_model.run_one_forward_step(inference_input)
 
         if parallel_state.is_pipeline_last_stage():
-            assert logits.shape == (self.batch_size, 10, self.vocab_size), f"Shape mismatch . Expected {(self.batch_size,10, self.vocab_size)}, but got {logits.shape}"
-   
+            assert logits.shape == (
+                self.batch_size,
+                10,
+                self.vocab_size,
+            ), f"Shape mismatch . Expected {(self.batch_size,10, self.vocab_size)}, but got {logits.shape}"
 
     def test_inference_only_tensor_parallel(self):
         self.setup_model(tensor_parallel_size=4, pipeline_parallel_size=1)
-    
-        batch_prompt_tokens = torch.randint(low = 0, high = self.vocab_size, size=(self.batch_size, self.sequence_length)).int().cuda()
+
+        batch_prompt_tokens = (
+            torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.sequence_length))
+            .int()
+            .cuda()
+        )
         self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=batch_prompt_tokens)
 
         inference_input = self.inference_wrapped_model.get_batch_for_context_window(0, 5)
         logits = self.inference_wrapped_model.run_one_forward_step(inference_input)
-        
-        assert logits.shape == (self.batch_size, 5, self.vocab_size), f"Shape mismatch . Expected {(self.batch_size, 5, self.vocab_size)}, but got {logits.shape}"
 
+        assert logits.shape == (
+            self.batch_size,
+            5,
+            self.vocab_size,
+        ), f"Shape mismatch . Expected {(self.batch_size, 5, self.vocab_size)}, but got {logits.shape}"
diff --git a/tests/unit_tests/inference/model_inference_wrappers/test_model_inference_wrapper_config.py b/tests/unit_tests/inference/model_inference_wrappers/test_model_inference_wrapper_config.py
index 5c6f4229c0..e3da997cd4 100644
--- a/tests/unit_tests/inference/model_inference_wrappers/test_model_inference_wrapper_config.py
+++ b/tests/unit_tests/inference/model_inference_wrappers/test_model_inference_wrapper_config.py
@@ -1,5 +1,9 @@
 import torch
-from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
+
+from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
+    InferenceWrapperConfig,
+)
+
 
 class TestModelInferenceWrapperConfig:
 
@@ -9,7 +13,9 @@ def test_inference_params(self):
             inference_batch_times_seqlen_threshold=10,
             padded_vocab_size=10,
             params_dtype=torch.float,
-            fp32_residual_connection=False
+            fp32_residual_connection=False,
         )
         inference_parameters.add_attributes({"abc": 45})
-        assert inference_parameters.abc == 45, f"min tokens not set correctly. it is {inference_parameters.min_tokens}"
\ No newline at end of file
+        assert (
+            inference_parameters.abc == 45
+        ), f"min tokens not set correctly. it is {inference_parameters.min_tokens}"
diff --git a/tests/unit_tests/inference/test_common_inference_params.py b/tests/unit_tests/inference/test_common_inference_params.py
index c22a72d326..af51e433df 100644
--- a/tests/unit_tests/inference/test_common_inference_params.py
+++ b/tests/unit_tests/inference/test_common_inference_params.py
@@ -1,8 +1,11 @@
 from megatron.core.inference.common_inference_params import CommonInferenceParams
 
+
 class TestCommonInferenceParams:
 
     def test_inference_params(self):
         inference_parameters = CommonInferenceParams()
         inference_parameters.add_attributes({"min_tokens": 45})
-        assert inference_parameters.min_tokens == 45, f"min tokens not set correctly. it is {inference_parameters.min_tokens}"
\ No newline at end of file
+        assert (
+            inference_parameters.min_tokens == 45
+        ), f"min tokens not set correctly. it is {inference_parameters.min_tokens}"
diff --git a/tests/unit_tests/inference/test_inference_utils.py b/tests/unit_tests/inference/test_inference_utils.py
index 7f0061963e..fc4e69018d 100644
--- a/tests/unit_tests/inference/test_inference_utils.py
+++ b/tests/unit_tests/inference/test_inference_utils.py
@@ -1,5 +1,6 @@
 from megatron.core.inference.utils import Counter
 
+
 class TestInferenceUtils:
 
     def test_counter(self):
diff --git a/tests/unit_tests/inference/test_modelopt_gpt_model.py b/tests/unit_tests/inference/test_modelopt_gpt_model.py
index 953052c732..380ac7fa16 100644
--- a/tests/unit_tests/inference/test_modelopt_gpt_model.py
+++ b/tests/unit_tests/inference/test_modelopt_gpt_model.py
@@ -7,7 +7,6 @@
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.transformer_config import TransformerConfig
-
 from tests.unit_tests.test_utilities import Utils
 
 
@@ -17,10 +16,7 @@ def setup_method(self, method):
         Utils.initialize_model_parallel(1, 1)
         model_parallel_cuda_manual_seed(123)
         transformer_config = TransformerConfig(
-            num_layers=2,
-            hidden_size=12,
-            num_attention_heads=4,
-            use_cpu_initialization=True,
+            num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
         )
         self.gpt_model = GPTModel(
             config=transformer_config,
diff --git a/tests/unit_tests/inference/test_scheduler.py b/tests/unit_tests/inference/test_scheduler.py
index 57e08106d3..b1f0ea184e 100644
--- a/tests/unit_tests/inference/test_scheduler.py
+++ b/tests/unit_tests/inference/test_scheduler.py
@@ -1,17 +1,26 @@
 from typing import Dict
+
 import torch
+
 from megatron.core.inference.common_inference_params import CommonInferenceParams
 from megatron.core.inference.inference_request import InferenceRequest, Status
 from megatron.core.inference.scheduler import Scheduler
 
+
 class TestScheduler:
 
     def setup_method(self, method):
         self.max_batch_size = 4
         self.scheduler = Scheduler(max_batch_size=self.max_batch_size)
-        assert len(self.scheduler.active_request_pool) == 0, "Active request pool should be empty on initalization"
-        assert len(self.scheduler.waiting_request_pool) == 0, "Waiting request pool should be empty on initalization"
-        assert len(self.scheduler.completed_request_pool) == 0, "Completed request pool should be empty on initalization"
+        assert (
+            len(self.scheduler.active_request_pool) == 0
+        ), "Active request pool should be empty on initialization"
+        assert (
+            len(self.scheduler.waiting_request_pool) == 0
+        ), "Waiting request pool should be empty on initialization"
+        assert (
+            len(self.scheduler.completed_request_pool) == 0
+        ), "Completed request pool should be empty on initialization"
 
     def test_scheduler(self):
         prompt = "sample prompt"
@@ -20,15 +29,23 @@ def test_scheduler(self):
 
         for i in range(self.max_batch_size):
             self.scheduler.add_request(prompt, prompt_tokens, inference_parameters)
-            assert len(self.scheduler.active_request_pool) == i + 1, f"Active request pool should have {i+1} requests, but it has only {len(self.scheduler.active_request_pool)}"
+            assert (
+                len(self.scheduler.active_request_pool) == i + 1
+            ), f"Active request pool should have {i+1} requests, but it has only {len(self.scheduler.active_request_pool)}"
 
         self.scheduler.add_request(prompt, prompt_tokens, inference_parameters)
-        assert len(self.scheduler.waiting_request_pool) == 1, f"Waiting request pool should have 1 request but it has {len(self.scheduler.waiting_request_pool)} requests"
-        
+        assert (
+            len(self.scheduler.waiting_request_pool) == 1
+        ), f"Waiting request pool should have 1 request but it has {len(self.scheduler.waiting_request_pool)} requests"
+
         waiting_request: InferenceRequest = list(self.scheduler.waiting_request_pool.values())[0]
-        assert waiting_request.status == Status.WAITING_IN_QUEUE, f"Status should be WAITING_IN_QUEUE, but its {waiting_request.status} for the waiting request"
+        assert (
+            waiting_request.status == Status.WAITING_IN_QUEUE
+        ), f"Status should be WAITING_IN_QUEUE, but its {waiting_request.status} for the waiting request"
 
-        assert self.scheduler.have_requests_pending(), "Scheduler should have requests pending, but it seems to be having no requests"
+        assert (
+            self.scheduler.have_requests_pending()
+        ), "Scheduler should have requests pending, but it seems to be having no requests"
 
         active_request_dict: Dict[int, InferenceRequest] = self.scheduler.active_request_pool
         for request_id, request in active_request_dict.items():
@@ -37,11 +54,17 @@ def test_scheduler(self):
                 request.status = Status.COMPLETED
 
         self.scheduler.update_requests_pools(active_request_dict)
-        assert len(self.scheduler.active_request_pool) == 3, f"Active request pool should have 3 requests, but it has {len(self.scheduler.active_request_pool)}"
+        assert (
+            len(self.scheduler.active_request_pool) == 3
+        ), f"Active request pool should have 3 requests, but it has {len(self.scheduler.active_request_pool)}"
 
-        assert len(self.scheduler.waiting_request_pool) == 0, f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests"
+        assert (
+            len(self.scheduler.waiting_request_pool) == 0
+        ), f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests"
 
-        assert len(self.scheduler.completed_request_pool) == 2, f"Completed request pool should have 2 requests but it has {len(self.scheduler.completed_request_pool)} requests "
+        assert (
+            len(self.scheduler.completed_request_pool) == 2
+        ), f"Completed request pool should have 2 requests but it has {len(self.scheduler.completed_request_pool)} requests "
 
         active_request_dict: Dict[int, InferenceRequest] = self.scheduler.active_request_pool
         for request_id, request in active_request_dict.items():
@@ -49,15 +72,18 @@ def test_scheduler(self):
             request.status = Status.COMPLETED
 
         self.scheduler.update_requests_pools(active_request_dict)
-        assert len(self.scheduler.active_request_pool) == 0, f"Active request pool should be empty, but it has {len(self.scheduler.active_request_pool)}"
-
-        assert len(self.scheduler.waiting_request_pool) == 0, f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests"
-
-        assert len(self.scheduler.completed_request_pool) == 5, f"Completed request pool should have 5 requests but it has {len(self.scheduler.completed_request_pool)} requests "
-
-        assert self.scheduler.have_requests_pending() == False, "Scheduler should not have any requests pending"
+        assert (
+            len(self.scheduler.active_request_pool) == 0
+        ), f"Active request pool should be empty, but it has {len(self.scheduler.active_request_pool)}"
 
+        assert (
+            len(self.scheduler.waiting_request_pool) == 0
+        ), f"Waiting request pool should be empty but it has {len(self.scheduler.waiting_request_pool)} requests"
 
+        assert (
+            len(self.scheduler.completed_request_pool) == 5
+        ), f"Completed request pool should have 5 requests but it has {len(self.scheduler.completed_request_pool)} requests "
 
-        
-    
\ No newline at end of file
+        assert (
+            not self.scheduler.have_requests_pending()
+        ), "Scheduler should not have any requests pending"
diff --git a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py
index 35b820edd6..a9f15faf80 100644
--- a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py
+++ b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py
@@ -1,118 +1,172 @@
-
+import random
+import string
+import time
 from collections import OrderedDict
 from typing import Dict
-from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
+from unittest import mock
+
+import pytest
 import torch
-import random
-import string 
+
 from megatron.core.inference.common_inference_params import CommonInferenceParams
-from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
 from megatron.core.inference.inference_request import InferenceRequest, Status
-from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import SimpleTextGenerationController
+from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
+    GPTInferenceWrapper,
+)
+from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
+    InferenceWrapperConfig,
+)
+from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
+    SimpleTextGenerationController,
+)
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.transformer_config import TransformerConfig
-from unittest import mock
-import pytest
-import time
+from tests.unit_tests.test_utilities import Utils
 
-from tests.unit_tests.test_utilities import Utils 
 
 class TestTextGenerationController:
 
     def setup_method(self, method):
-        Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=2)
-        model_parallel_cuda_manual_seed(123)        
+        Utils.initialize_model_parallel(
+            tensor_model_parallel_size=2, pipeline_model_parallel_size=2
+        )
+        model_parallel_cuda_manual_seed(123)
         self.batch_size = 4
         self.hidden_size = 12
         self.vocab_size = 100
         self.sequence_length = 64
-        transformer_config = TransformerConfig(num_layers=4, hidden_size=self.hidden_size, num_attention_heads=4, use_cpu_initialization=True)
-                                                    
+        transformer_config = TransformerConfig(
+            num_layers=4,
+            hidden_size=self.hidden_size,
+            num_attention_heads=4,
+            use_cpu_initialization=True,
+        )
+
         gpt_model = GPTModel(
-            config=transformer_config, 
-            transformer_layer_spec=get_gpt_layer_local_spec(), 
-            vocab_size=self.vocab_size, 
-            max_sequence_length=self.sequence_length, 
-            parallel_output = True).cuda()
-        
+            config=transformer_config,
+            transformer_layer_spec=get_gpt_layer_local_spec(),
+            vocab_size=self.vocab_size,
+            max_sequence_length=self.sequence_length,
+            parallel_output=True,
+        ).cuda()
+
         inference_wrapper_config = InferenceWrapperConfig(
             hidden_size=self.hidden_size,
             inference_batch_times_seqlen_threshold=20,
             fp32_residual_connection=False,
             params_dtype=torch.float,
-            padded_vocab_size=self.vocab_size
+            padded_vocab_size=self.vocab_size,
         )
 
         inference_wrapped_model = GPTInferenceWrapper(gpt_model, inference_wrapper_config)
 
         self.mock_tokenizer = mock.Mock()
 
-        self.text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer)
-    
+        self.text_generation_controller = SimpleTextGenerationController(
+            inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer
+        )
+
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
 
     def test_sample_from_logits(self):
         with pytest.raises(AssertionError) as aerror:
-            self.text_generation_controller.sample_from_logits(last_token_logits=None, common_inference_params=CommonInferenceParams(top_k=2, top_p=0.4), vocab_size=self.vocab_size )
+            self.text_generation_controller.sample_from_logits(
+                last_token_logits=None,
+                common_inference_params=CommonInferenceParams(top_k=2, top_p=0.4),
+                vocab_size=self.vocab_size,
+            )
         assert str(aerror.value) == 'Cannot have top-p and top-k both greater than zero'
 
         with pytest.raises(AssertionError) as aerror:
-            self.text_generation_controller.sample_from_logits(last_token_logits=None, common_inference_params=CommonInferenceParams(top_p=1.4, top_k=0), vocab_size=self.vocab_size )
+            self.text_generation_controller.sample_from_logits(
+                last_token_logits=None,
+                common_inference_params=CommonInferenceParams(top_p=1.4, top_k=0),
+                vocab_size=self.vocab_size,
+            )
         assert str(aerror.value) == 'top-p should be in (0,1]'
 
         with pytest.raises(AssertionError) as aerror:
-            self.text_generation_controller.sample_from_logits(last_token_logits=torch.randn(self.batch_size, 1), common_inference_params=CommonInferenceParams(top_k = self.vocab_size + 10), vocab_size=self.vocab_size)
+            self.text_generation_controller.sample_from_logits(
+                last_token_logits=torch.randn(self.batch_size, 1),
+                common_inference_params=CommonInferenceParams(top_k=self.vocab_size + 10),
+                vocab_size=self.vocab_size,
+            )
         assert str(aerror.value) == 'top-k is larger than logit size.'
 
-    
-        last_token_logits = torch.arange(0, self.vocab_size).repeat(self.batch_size,1).float().cuda()
-        sampled_logits = self.text_generation_controller.sample_from_logits(last_token_logits, CommonInferenceParams(top_k=1), self.vocab_size)
-        assert torch.all(sampled_logits.cpu() == torch.ones(self.batch_size) * self.vocab_size - 1), f"The sampled logits should all be {self.vocab_size} but its {sampled_logits}"
+        last_token_logits = (
+            torch.arange(0, self.vocab_size).repeat(self.batch_size, 1).float().cuda()
+        )
+        sampled_logits = self.text_generation_controller.sample_from_logits(
+            last_token_logits, CommonInferenceParams(top_k=1), self.vocab_size
+        )
+        assert torch.all(
+            sampled_logits.cpu() == torch.ones(self.batch_size) * self.vocab_size - 1
+        ), f"The sampled logits should all be {self.vocab_size} but its {sampled_logits}"
 
-        sampled_logits = self.text_generation_controller.sample_from_logits(last_token_logits, CommonInferenceParams(top_k=2), self.vocab_size)
-        assert torch.all(sampled_logits >= self.vocab_size - 2), f"The sampled logits should all be greater than {self.vocab_size-2} but its {sampled_logits}"
+        sampled_logits = self.text_generation_controller.sample_from_logits(
+            last_token_logits, CommonInferenceParams(top_k=2), self.vocab_size
+        )
+        assert torch.all(
+            sampled_logits >= self.vocab_size - 2
+        ), f"The sampled logits should all be greater than {self.vocab_size-2} but its {sampled_logits}"
 
         l = last_token_logits[0]
         top_p = 0.3
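+        # Expected lower bound: the first logit whose cumulative softmax mass (in ascending order) exceeds top_p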
         expected_min_value = l[l.softmax(dim=-1).cumsum(dim=-1) > top_p][0].item()
-        sampled_logits = self.text_generation_controller.sample_from_logits(last_token_logits, CommonInferenceParams(top_p=top_p, top_k=0), self.vocab_size)
-        assert torch.all(sampled_logits >= expected_min_value), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}"
+        sampled_logits = self.text_generation_controller.sample_from_logits(
+            last_token_logits, CommonInferenceParams(top_p=top_p, top_k=0), self.vocab_size
+        )
+        assert torch.all(
+            sampled_logits >= expected_min_value
+        ), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}"
 
         top_p = 0.95
-        temperature=2
+        temperature = 2
         expected_min_value = l[l.div_(temperature).softmax(dim=-1).cumsum(dim=-1) > top_p][0].item()
-        sampled_logits = self.text_generation_controller.sample_from_logits(last_token_logits, CommonInferenceParams(top_p=top_p, temperature=temperature, top_k=0), self.vocab_size)
-        assert torch.all(sampled_logits >= expected_min_value), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}"
- 
+        sampled_logits = self.text_generation_controller.sample_from_logits(
+            last_token_logits,
+            CommonInferenceParams(top_p=top_p, temperature=temperature, top_k=0),
+            self.vocab_size,
+        )
+        assert torch.all(
+            sampled_logits >= expected_min_value
+        ), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}"
+
     def test_generate_all_output_tokens_static_batch(self):
         self.mock_tokenizer.vocab_size = self.vocab_size
         self.mock_tokenizer.eod = self.vocab_size - 1
-        self.mock_tokenizer.detokenize.return_value = ''.join(random.choices(string.ascii_letters, k=random.randint(4,10)))
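+        # detokenize() returns the same random 4-10 character string for every call in this test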
+        self.mock_tokenizer.detokenize.return_value = ''.join(
+            random.choices(string.ascii_letters, k=random.randint(4, 10))
+        )
 
         active_requests: Dict[int, InferenceRequest] = OrderedDict()
         for i in range(self.batch_size):
-            prompt = "sample" * (i+1)
-            self.mock_tokenizer.tokenize.return_value = torch.randn(self.batch_size, self.vocab_size).cuda()   
+            prompt = "sample" * (i + 1)
+            self.mock_tokenizer.tokenize.return_value = torch.randn(
+                self.batch_size, self.vocab_size
+            ).cuda()
             inference_request = InferenceRequest(
                 request_id=i,
                 prompt=prompt,
                 inference_parameters=CommonInferenceParams(num_tokens_to_generate=10),
                 arrival_time=time.time(),
-                prompt_tokens=torch.randint(low=0, high=self.vocab_size - 1, size=(len(prompt),)).tolist(),
-                status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
+                prompt_tokens=torch.randint(
+                    low=0, high=self.vocab_size - 1, size=(len(prompt),)
+                ).tolist(),
+                status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS,
             )
             active_requests[i] = inference_request
 
-        requests = self.text_generation_controller.generate_all_output_tokens_static_batch(active_requests)
-        
+        requests = self.text_generation_controller.generate_all_output_tokens_static_batch(
+            active_requests
+        )
+
         for request_id, request in requests.items():
-            assert request.status == Status.COMPLETED, f"Status should be completed but its {request.status}"
-            assert request.generated_length > 0 , f"Generated length should be greater than zero"
+            assert (
+                request.status == Status.COMPLETED
+            ), f"Status should be completed but its {request.status}"
+            assert request.generated_length > 0, f"Generated length should be greater than zero"
             assert request.generated_text is not None, "Generated text should not be None"
-
-
-        
-    
\ No newline at end of file
diff --git a/tests/unit_tests/models/test_base_embedding.py b/tests/unit_tests/models/test_base_embedding.py
index 511b0262fa..0ce18b3843 100644
--- a/tests/unit_tests/models/test_base_embedding.py
+++ b/tests/unit_tests/models/test_base_embedding.py
@@ -1,11 +1,10 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 import pytest
-
 import torch
 
-from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
+from megatron.core.transformer.transformer_config import TransformerConfig
 from tests.unit_tests.test_utilities import Utils
 
 
@@ -14,17 +13,21 @@ class TestBaseEmbedding:
     def setup_method(self, method):
         Utils.initialize_model_parallel(1, 1)
         transformer_config = TransformerConfig(
-            num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
+            num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
+        )
         self.base_embedding = LanguageModelEmbedding(
-            config=transformer_config, vocab_size=100, max_sequence_length=4, position_embedding_type='learned_absolute')
+            config=transformer_config,
+            vocab_size=100,
+            max_sequence_length=4,
+            position_embedding_type='learned_absolute',
+        )
 
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
 
     def test_constructor(self):
         assert isinstance(self.base_embedding, LanguageModelEmbedding)
-        num_weights = sum([p.numel()
-                          for p in self.base_embedding.parameters()])
+        num_weights = sum([p.numel() for p in self.base_embedding.parameters()])
         assert num_weights == 1248
 
     def test_zero_parameters(self):
@@ -35,10 +38,8 @@ def test_zero_parameters(self):
         assert sum_weights == 0
 
     def test_cpu_forward(self):
-        input_ids = torch.tensor(
-            [0, 1, 2, 3], dtype=torch.int64).repeat((2, 1))
-        position_ids = torch.tensor(
-            [0, 1, 2, 3], dtype=torch.int64).repeat((2, 1))
+        input_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1))
+        position_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1))
         embeddings = self.base_embedding(input_ids, position_ids)
         assert embeddings.device.type == 'cpu'
         assert embeddings.shape[0] == self.base_embedding.max_sequence_length
@@ -47,10 +48,8 @@ def test_cpu_forward(self):
 
     def test_gpu_forward(self):
         self.base_embedding.cuda()
-        input_ids = torch.tensor(
-            [0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)).cuda()
-        position_ids = torch.tensor(
-            [0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)).cuda()
+        input_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)).cuda()
+        position_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64).repeat((2, 1)).cuda()
         embeddings = self.base_embedding(input_ids, position_ids)
         assert embeddings.device.type == 'cuda'
         assert embeddings.shape[0] == self.base_embedding.max_sequence_length
diff --git a/tests/unit_tests/models/test_bert_model.py b/tests/unit_tests/models/test_bert_model.py
index f6722f66a3..b1b544698b 100644
--- a/tests/unit_tests/models/test_bert_model.py
+++ b/tests/unit_tests/models/test_bert_model.py
@@ -1,33 +1,45 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
-import pytest
+import os
 
+import pytest
 import torch
-import os
 from pkg_resources import packaging
-from megatron.core.transformer.transformer_config import TransformerConfig
+from pytest_mock import mocker
+
+from megatron.core.models.bert.bert_layer_specs import bert_layer_with_transformer_engine_spec
 from megatron.core.models.bert.bert_model import BertModel
-from tests.unit_tests.test_utilities import Utils
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
-from megatron.core.models.bert.bert_layer_specs import bert_layer_with_transformer_engine_spec
-from pytest_mock import mocker 
+from megatron.core.transformer.transformer_config import TransformerConfig
+from tests.unit_tests.test_utilities import Utils
+
 
 class TestBertModel:
 
     def setup_method(self, method):
-        os.environ['NVTE_ALLOW_NONDETERMINISTIC_ALGO'] = '0' #Bert does not support flash attention
+        os.environ['NVTE_ALLOW_NONDETERMINISTIC_ALGO'] = (
+            '0'  # Bert does not support flash attention
+        )
         tp = 1
         pp = 1
         Utils.initialize_model_parallel(tp, pp)
         model_parallel_cuda_manual_seed(123)
         transformer_config = TransformerConfig(
-            num_layers=2, hidden_size=12, num_attention_heads=4,
-            use_cpu_initialization=True, perform_initialization=True,
-            tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp, pipeline_dtype=torch.bfloat16
+            num_layers=2,
+            hidden_size=12,
+            num_attention_heads=4,
+            use_cpu_initialization=True,
+            perform_initialization=True,
+            tensor_model_parallel_size=tp,
+            pipeline_model_parallel_size=pp,
+            pipeline_dtype=torch.bfloat16,
         )
         self.bert_model = BertModel(
-            config=transformer_config, num_tokentypes=0,
-            transformer_layer_spec=bert_layer_with_transformer_engine_spec, vocab_size=100, max_sequence_length=4
+            config=transformer_config,
+            num_tokentypes=0,
+            transformer_layer_spec=bert_layer_with_transformer_engine_spec,
+            vocab_size=100,
+            max_sequence_length=4,
         )
 
     def teardown_method(self, method):
@@ -77,66 +89,105 @@ def test_post_process_forward(self):
 class TestBertModelAssertions:
 
     def test_te_assertions_te_less_than_1_7(self, mocker):
-        os.environ.pop('NVTE_ALLOW_NONDETERMINISTIC_ALGO',None)
-        os.environ.pop('NVTE_FLASH_ATTN',None)
-        os.environ.pop('NVTE_FUSED_ATTN',None)
+        os.environ.pop('NVTE_ALLOW_NONDETERMINISTIC_ALGO', None)
+        os.environ.pop('NVTE_FLASH_ATTN', None)
+        os.environ.pop('NVTE_FUSED_ATTN', None)
         tp = 1
         pp = 1
-        Utils.initialize_model_parallel(tp, pp)        
+        Utils.initialize_model_parallel(tp, pp)
         model_parallel_cuda_manual_seed(123)
         transformer_config = TransformerConfig(
-            num_layers=2, hidden_size=12, num_attention_heads=4,
-            use_cpu_initialization=True, perform_initialization=True,
-            tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp, pipeline_dtype=torch.bfloat16
+            num_layers=2,
+            hidden_size=12,
+            num_attention_heads=4,
+            use_cpu_initialization=True,
+            perform_initialization=True,
+            tensor_model_parallel_size=tp,
+            pipeline_model_parallel_size=pp,
+            pipeline_dtype=torch.bfloat16,
         )
 
         with pytest.raises(Exception) as exc_info:
-            mocker.patch("megatron.core.models.bert.bert_model.get_te_version", return_value = packaging.version.Version("1.4"))
+            mocker.patch(
+                "megatron.core.models.bert.bert_model.get_te_version",
+                return_value=packaging.version.Version("1.4"),
+            )
             self.bert_model = BertModel(
-                config=transformer_config, num_tokentypes=0,
-                transformer_layer_spec=bert_layer_with_transformer_engine_spec, vocab_size=100, max_sequence_length=4
+                config=transformer_config,
+                num_tokentypes=0,
+                transformer_layer_spec=bert_layer_with_transformer_engine_spec,
+                vocab_size=100,
+                max_sequence_length=4,
             )
-        assert str(exc_info.value) == "Flash and fused attention is not supported with transformer engine version < 1.7. Set NVTE_FLASH_ATTN=0 and NVTE_FUSED_ATTN=0 or upgrade transformer engine >= 1.7 or set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0"
+        assert (
+            str(exc_info.value)
+            == "Flash and fused attention is not supported with transformer engine version < 1.7. Set NVTE_FLASH_ATTN=0 and NVTE_FUSED_ATTN=0 or upgrade transformer engine >= 1.7 or set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0"
+        )
 
     def test_te_assertions_te_equal_to_1_7_exception(self, mocker):
-        os.environ.pop('NVTE_ALLOW_NONDETERMINISTIC_ALGO',None)
+        os.environ.pop('NVTE_ALLOW_NONDETERMINISTIC_ALGO', None)
         os.environ['NVTE_FLASH_ATTN'] = '0'
         os.environ['NVTE_FUSED_ATTN'] = '0'
         tp = 1
         pp = 1
-        Utils.initialize_model_parallel(tp, pp)        
+        Utils.initialize_model_parallel(tp, pp)
         model_parallel_cuda_manual_seed(123)
         transformer_config = TransformerConfig(
-            num_layers=2, hidden_size=12, num_attention_heads=4,
-            use_cpu_initialization=True, perform_initialization=True,
-            tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp, pipeline_dtype=torch.bfloat16
+            num_layers=2,
+            hidden_size=12,
+            num_attention_heads=4,
+            use_cpu_initialization=True,
+            perform_initialization=True,
+            tensor_model_parallel_size=tp,
+            pipeline_model_parallel_size=pp,
+            pipeline_dtype=torch.bfloat16,
         )
 
         with pytest.raises(Exception) as exc_info:
-            mocker.patch("megatron.core.models.bert.bert_model.get_te_version", return_value = packaging.version.Version("1.7"))
+            mocker.patch(
+                "megatron.core.models.bert.bert_model.get_te_version",
+                return_value=packaging.version.Version("1.7"),
+            )
             self.bert_model = BertModel(
-                config=transformer_config, num_tokentypes=0,
-                transformer_layer_spec=bert_layer_with_transformer_engine_spec, vocab_size=100, max_sequence_length=4
+                config=transformer_config,
+                num_tokentypes=0,
+                transformer_layer_spec=bert_layer_with_transformer_engine_spec,
+                vocab_size=100,
+                max_sequence_length=4,
             )
-        assert str(exc_info.value) == "Set env variable NVTE_FLASH_ATTN to 1 or NVTE_FUSED_ATTN to 1 to use a more optimized attention kernal. Currently using unfused attention path. If you want to proceed with this path set AttnMaskType in module spec to be arbitrary"
+        assert (
+            str(exc_info.value)
+            == "Set env variable NVTE_FLASH_ATTN to 1 or NVTE_FUSED_ATTN to 1 to use a more optimized attention kernal. Currently using unfused attention path. If you want to proceed with this path set AttnMaskType in module spec to be arbitrary"
+        )
 
     def test_te_assertions_te_equal_to_1_7_no_exception(self, mocker):
-        os.environ.pop('NVTE_ALLOW_NONDETERMINISTIC_ALGO',None)
-        os.environ.pop('NVTE_FLASH_ATTN',None)
-        os.environ.pop('NVTE_FUSED_ATTN',None)
+        os.environ.pop('NVTE_ALLOW_NONDETERMINISTIC_ALGO', None)
+        os.environ.pop('NVTE_FLASH_ATTN', None)
+        os.environ.pop('NVTE_FUSED_ATTN', None)
         tp = 1
         pp = 1
-        Utils.initialize_model_parallel(tp, pp)        
+        Utils.initialize_model_parallel(tp, pp)
         model_parallel_cuda_manual_seed(123)
         transformer_config = TransformerConfig(
-            num_layers=2, hidden_size=12, num_attention_heads=4,
-            use_cpu_initialization=True, perform_initialization=True,
-            tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp, pipeline_dtype=torch.bfloat16
+            num_layers=2,
+            hidden_size=12,
+            num_attention_heads=4,
+            use_cpu_initialization=True,
+            perform_initialization=True,
+            tensor_model_parallel_size=tp,
+            pipeline_model_parallel_size=pp,
+            pipeline_dtype=torch.bfloat16,
         )
 
-        mocker.patch("megatron.core.models.bert.bert_model.get_te_version", return_value = packaging.version.Version("1.7"))
+        mocker.patch(
+            "megatron.core.models.bert.bert_model.get_te_version",
+            return_value=packaging.version.Version("1.7"),
+        )
         self.bert_model = BertModel(
-            config=transformer_config, num_tokentypes=0,
-            transformer_layer_spec=bert_layer_with_transformer_engine_spec, vocab_size=100, max_sequence_length=4
+            config=transformer_config,
+            num_tokentypes=0,
+            transformer_layer_spec=bert_layer_with_transformer_engine_spec,
+            vocab_size=100,
+            max_sequence_length=4,
         )
-        Utils.destroy_model_parallel()
\ No newline at end of file
+        Utils.destroy_model_parallel()
diff --git a/tests/unit_tests/models/test_clip_vit_model.py b/tests/unit_tests/models/test_clip_vit_model.py
index bc29f943af..fcbf2ad440 100644
--- a/tests/unit_tests/models/test_clip_vit_model.py
+++ b/tests/unit_tests/models/test_clip_vit_model.py
@@ -16,12 +16,11 @@ def setup_method(self, method):
         Utils.initialize_model_parallel(1, 1)
         model_parallel_cuda_manual_seed(123)
         transformer_config = TransformerConfig(
-            num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True,
+            num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True
         )
         transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec()
         self.model = CLIPViTModel(
-            transformer_config, transformer_layer_spec,
-            img_h=336, img_w=336, patch_dim=14,
+            transformer_config, transformer_layer_spec, img_h=336, img_w=336, patch_dim=14
         )
 
     def teardown_method(self, method):
diff --git a/tests/unit_tests/models/test_llava_model.py b/tests/unit_tests/models/test_llava_model.py
index f5681fc154..c65f2d3b87 100644
--- a/tests/unit_tests/models/test_llava_model.py
+++ b/tests/unit_tests/models/test_llava_model.py
@@ -21,7 +21,7 @@ def setup_method(self, method):
             num_layers=3, hidden_size=128, num_attention_heads=8, use_cpu_initialization=True
         )
         vision_config = TransformerConfig(
-            num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True,
+            num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True
         )
         vision_projection_config = TransformerConfig(
             num_layers=2,
@@ -101,7 +101,7 @@ def test_forward(self):
         kv_dict = inference_params.key_value_memory_dict
 
         assert kv_dict["image_tokens_count"] == 577
-        for layer_no in range(1, 4):    # 3 layers in the model.
+        for layer_no in range(1, 4):  # 3 layers in the model.
             layer_kv = kv_dict[layer_no]
             # Expected shape is [sequence_len, batch_size, num_heads, hidden_size_per_head]
             assert layer_kv[0].shape == layer_kv[1].shape == torch.Size((1601, 2, 8, 16))
diff --git a/tests/unit_tests/models/test_mamba_model.py b/tests/unit_tests/models/test_mamba_model.py
index db9277f028..913adb538c 100644
--- a/tests/unit_tests/models/test_mamba_model.py
+++ b/tests/unit_tests/models/test_mamba_model.py
@@ -71,9 +71,7 @@ def test_forward(self):
         ).cuda()
 
         logits = self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            attention_mask=attention_mask,
+            input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask
         )
 
         assert logits.shape[0] == micro_batch_size
diff --git a/tests/unit_tests/models/test_multimodal_projector.py b/tests/unit_tests/models/test_multimodal_projector.py
index f5ef29c6e8..976dc489da 100644
--- a/tests/unit_tests/models/test_multimodal_projector.py
+++ b/tests/unit_tests/models/test_multimodal_projector.py
@@ -1,32 +1,40 @@
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 
 import pytest
-
 import torch
 
-from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
 from megatron.core.models.vision.multimodal_projector import MultimodalProjector
-from tests.unit_tests.test_utilities import Utils
+from megatron.core.tensor_parallel.layers import ColumnParallelLinear
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
-from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
 from megatron.core.transformer.mlp import MLPSubmodules
-from megatron.core.tensor_parallel.layers import ColumnParallelLinear
+from megatron.core.transformer.transformer_config import TransformerConfig
+from tests.unit_tests.test_utilities import Utils
 
 
 class TestMultimodalProjector:
 
     def setup_method(self, method):
-        Utils.initialize_model_parallel(1,1)
+        Utils.initialize_model_parallel(1, 1)
         model_parallel_cuda_manual_seed(123)
-        transformer_config = TransformerConfig(num_layers=1, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True)
+        transformer_config = TransformerConfig(
+            num_layers=1, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True
+        )
         mlp_layer_spec = _get_mlp_module_spec().submodules
-        
-        affine_layer_spec = MLPSubmodules(
-                linear_fc1=ColumnParallelLinear,
-                linear_fc2=None,
-            )
-        self.mlp = MultimodalProjector(config = transformer_config, submodules = mlp_layer_spec, projector_type = "mlp", input_size = 1024)
-        self.affine = MultimodalProjector(config = transformer_config, submodules = affine_layer_spec, projector_type = "affine", input_size = 1024)
+
+        affine_layer_spec = MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=None)
+        self.mlp = MultimodalProjector(
+            config=transformer_config,
+            submodules=mlp_layer_spec,
+            projector_type="mlp",
+            input_size=1024,
+        )
+        self.affine = MultimodalProjector(
+            config=transformer_config,
+            submodules=affine_layer_spec,
+            projector_type="affine",
+            input_size=1024,
+        )
 
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
@@ -65,4 +73,3 @@ def test_save_load(self, tmp_path):
         torch.save(self.affine.state_dict(), path)
 
         self.affine.load_state_dict(torch.load(path))
-
diff --git a/tests/unit_tests/models/test_t5_model.py b/tests/unit_tests/models/test_t5_model.py
index 75d2286960..efe12b78f4 100644
--- a/tests/unit_tests/models/test_t5_model.py
+++ b/tests/unit_tests/models/test_t5_model.py
@@ -1,19 +1,22 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 from copy import deepcopy
-import pytest
 
+import pytest
 import torch
-import megatron.core.parallel_state as ps
 
-from megatron.core.transformer.transformer_config import TransformerConfig
+import megatron.core.parallel_state as ps
 from megatron.core.models.T5.t5_model import T5Model
-from tests.unit_tests.test_utilities import Utils
+from megatron.core.models.T5.t5_spec import (
+    get_t5_decoder_with_local_block_spec,
+    get_t5_decoder_with_transformer_engine_block_spec,
+    get_t5_encoder_with_local_block_spec,
+    get_t5_encoder_with_transformer_engine_block_spec,
+)
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
-from megatron.core.models.T5.t5_spec import (get_t5_encoder_with_transformer_engine_block_spec,
-                                            get_t5_decoder_with_transformer_engine_block_spec,
-                                            get_t5_encoder_with_local_block_spec,
-                                            get_t5_decoder_with_local_block_spec)
+from megatron.core.transformer.transformer_config import TransformerConfig
+from tests.unit_tests.test_utilities import Utils
+
 
 class TestT5Model:
 
@@ -27,9 +30,15 @@ def setup_method(self, method):
         )
         model_parallel_cuda_manual_seed(123)
         transformer_config = TransformerConfig(
-            num_layers=12, hidden_size=768, num_attention_heads=12, kv_channels=64, ffn_hidden_size=3072,
-            use_cpu_initialization=True, pipeline_dtype=torch.bfloat16,
-            tensor_model_parallel_size=tp, pipeline_model_parallel_size=pp,
+            num_layers=12,
+            hidden_size=768,
+            num_attention_heads=12,
+            kv_channels=64,
+            ffn_hidden_size=3072,
+            use_cpu_initialization=True,
+            pipeline_dtype=torch.bfloat16,
+            tensor_model_parallel_size=tp,
+            pipeline_model_parallel_size=pp,
         )
         rank = ps.get_pipeline_model_parallel_rank()
         world_size = ps.get_pipeline_model_parallel_world_size()
@@ -38,15 +47,21 @@ def setup_method(self, method):
 
         first_decoder_rank = pp
         pre_process = rank == 0 or rank == first_decoder_rank
-        post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size-1))
+        post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size - 1))
         add_encoder = ps.is_inside_encoder(rank)
         add_decoder = ps.is_inside_decoder(rank)
 
         self.t5_model = T5Model(
-            encoder_config=transformer_config, config=transformer_config, transformer_encoder_layer_spec=en_block_spec,
-            transformer_decoder_layer_spec=de_block_spec,  vocab_size=29184, max_sequence_length=4,
-            pre_process=pre_process, post_process=post_process,
-            add_encoder=add_encoder, add_decoder=add_decoder,
+            encoder_config=transformer_config,
+            config=transformer_config,
+            transformer_encoder_layer_spec=en_block_spec,
+            transformer_decoder_layer_spec=de_block_spec,
+            vocab_size=29184,
+            max_sequence_length=4,
+            pre_process=pre_process,
+            post_process=post_process,
+            add_encoder=add_encoder,
+            add_decoder=add_decoder,
         )
 
     def teardown_method(self, method):
@@ -96,14 +111,22 @@ def test_post_process_forward(self):
         self.t5_model.cuda()
 
         data = list(range(sequence_length))
-        encoder_input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
-        decoder_input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
+        encoder_input_ids = (
+            torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
+        )
+        decoder_input_ids = (
+            torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
+        )
         encoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
         decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
-        encoder_decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
+        encoder_decoder_attn_mask = torch.ones(
+            (1, sequence_length, sequence_length), dtype=bool
+        ).cuda()
 
         if self.t5_model.add_decoder:
-            encoder_hidden_states = torch.zeros((sequence_length, micro_batch_size, config.hidden_size), dtype=torch.float32).cuda()
+            encoder_hidden_states = torch.zeros(
+                (sequence_length, micro_batch_size, config.hidden_size), dtype=torch.float32
+            ).cuda()
         else:
             encoder_hidden_states = None
 
@@ -113,20 +136,22 @@ def test_post_process_forward(self):
             encoder_attn_mask=encoder_attn_mask,
             decoder_attn_mask=decoder_attn_mask,
             encoder_decoder_attn_mask=encoder_decoder_attn_mask,
-            encoder_hidden_states=encoder_hidden_states
+            encoder_hidden_states=encoder_hidden_states,
         )
         if self.t5_model.add_decoder:
             logits = output
             assert logits.shape[0] == micro_batch_size
             assert logits.shape[1] == sequence_length
-            assert logits.shape[2] == self.t5_model.vocab_size // ps.get_tensor_model_parallel_world_size()
+            assert (
+                logits.shape[2]
+                == self.t5_model.vocab_size // ps.get_tensor_model_parallel_world_size()
+            )
         else:
             encoder_hidden_states = output
             assert encoder_hidden_states.shape[0] == sequence_length
             assert encoder_hidden_states.shape[1] == micro_batch_size
             assert encoder_hidden_states.shape[2] == config.hidden_size
 
-
     def test_forward_output_encoder_hidden_only(self):
         config: TransformerConfig = self.t5_model.config
         sequence_length = self.t5_model.max_sequence_length
@@ -135,11 +160,17 @@ def test_forward_output_encoder_hidden_only(self):
         self.t5_model.cuda()
 
         data = list(range(sequence_length))
-        encoder_input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
-        decoder_input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
+        encoder_input_ids = (
+            torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
+        )
+        decoder_input_ids = (
+            torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
+        )
         encoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
         decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
-        encoder_decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
+        encoder_decoder_attn_mask = torch.ones(
+            (1, sequence_length, sequence_length), dtype=bool
+        ).cuda()
 
         encoder_hidden_states = self.t5_model.forward(
             encoder_input_ids=encoder_input_ids,
@@ -147,7 +178,7 @@ def test_forward_output_encoder_hidden_only(self):
             encoder_attn_mask=encoder_attn_mask,
             decoder_attn_mask=decoder_attn_mask,
             encoder_decoder_attn_mask=encoder_decoder_attn_mask,
-            output_encoder_hidden_only=True
+            output_encoder_hidden_only=True,
         )
         if self.t5_model.add_decoder:
             assert encoder_hidden_states is None
@@ -164,12 +195,20 @@ def test_forward_with_encoder_hidden_states(self):
         self.t5_model.cuda()
 
         data = list(range(sequence_length))
-        encoder_input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
-        decoder_input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
+        encoder_input_ids = (
+            torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
+        )
+        decoder_input_ids = (
+            torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
+        )
         encoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
         decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
-        encoder_decoder_attn_mask = torch.ones((1, sequence_length, sequence_length), dtype=bool).cuda()
-        encoder_hidden_states = torch.zeros((sequence_length, micro_batch_size, config.hidden_size), dtype=torch.float32).cuda()
+        encoder_decoder_attn_mask = torch.ones(
+            (1, sequence_length, sequence_length), dtype=bool
+        ).cuda()
+        encoder_hidden_states = torch.zeros(
+            (sequence_length, micro_batch_size, config.hidden_size), dtype=torch.float32
+        ).cuda()
 
         output = self.t5_model.forward(
             encoder_input_ids=None,
@@ -177,13 +216,16 @@ def test_forward_with_encoder_hidden_states(self):
             encoder_attn_mask=encoder_attn_mask,
             decoder_attn_mask=decoder_attn_mask,
             encoder_decoder_attn_mask=encoder_decoder_attn_mask,
-            encoder_hidden_states=encoder_hidden_states
+            encoder_hidden_states=encoder_hidden_states,
         )
         if self.t5_model.add_decoder:
             logits = output
             assert logits.shape[0] == micro_batch_size
             assert logits.shape[1] == sequence_length
-            assert logits.shape[2] == self.t5_model.vocab_size // ps.get_tensor_model_parallel_world_size()
+            assert (
+                logits.shape[2]
+                == self.t5_model.vocab_size // ps.get_tensor_model_parallel_world_size()
+            )
         else:
             encoder_hidden_states = output
             assert encoder_hidden_states.shape[0] == sequence_length
@@ -201,4 +243,3 @@ def test_state_dict_for_save_checkpoint(self):
 
     def test_load_state_dict(self):
         pass
-
diff --git a/tests/unit_tests/pipeline_parallel/test_schedules.py b/tests/unit_tests/pipeline_parallel/test_schedules.py
index 5dd6605d68..06994094fc 100644
--- a/tests/unit_tests/pipeline_parallel/test_schedules.py
+++ b/tests/unit_tests/pipeline_parallel/test_schedules.py
@@ -1,30 +1,51 @@
+import pytest
 import torch
-from tests.unit_tests.test_utilities import Utils
-from megatron.core import ModelParallelConfig
+from pytest_mock import mocker
+
 import megatron.core.pipeline_parallel.schedules as schedule
-from pytest_mock import mocker 
-import pytest
+from megatron.core import ModelParallelConfig
+from tests.unit_tests.test_utilities import Utils
 
 rank = Utils.rank
- 
+
+
 def test_get_forward_backward_func():
     Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)
-    assert(schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining)
+    assert schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining
     Utils.destroy_model_parallel()
     Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
-    assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_without_interleaving)
+    assert (
+        schedule.get_forward_backward_func()
+        == schedule.forward_backward_pipelining_without_interleaving
+    )
     Utils.destroy_model_parallel()
-    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4, virtual_pipeline_model_parallel_size=2)
-    assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_with_interleaving)
+    Utils.initialize_model_parallel(
+        tensor_model_parallel_size=2,
+        pipeline_model_parallel_size=4,
+        virtual_pipeline_model_parallel_size=2,
+    )
+    assert (
+        schedule.get_forward_backward_func()
+        == schedule.forward_backward_pipelining_with_interleaving
+    )
     Utils.destroy_model_parallel()
-    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=2, virtual_pipeline_model_parallel_size=4)
-    assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_with_interleaving)
+    Utils.initialize_model_parallel(
+        tensor_model_parallel_size=2,
+        pipeline_model_parallel_size=2,
+        virtual_pipeline_model_parallel_size=4,
+    )
+    assert (
+        schedule.get_forward_backward_func()
+        == schedule.forward_backward_pipelining_with_interleaving
+    )
     Utils.destroy_model_parallel()
 
+
 def test_deallocate_output_tensor():
     out = torch.tensor([[1, 2, 3], [4, 5, 6]])
     schedule.deallocate_output_tensor(out)
-    assert(out.nelement() == 6) 
+    assert out.nelement() == 6
+
 
 def test_forward_backward_func_without_pipeline_parallel(mocker):
     from megatron.core.pipeline_parallel import get_forward_backward_func
@@ -33,43 +54,51 @@ def test_forward_backward_func_without_pipeline_parallel(mocker):
 
     def forward_step_func(data_iterator, model):
         import os
+
         rank = int(os.environ['LOCAL_RANK'])
-        dummy_data = torch.ones(1,4)
+        dummy_data = torch.ones(1, 4)
+
         def loss_func(output_tensor):
-            return rank, {'loss_reduced':rank}
+            return rank, {'loss_reduced': rank}
+
         return model(dummy_data), loss_func
 
-    model = torch.nn.Linear(4,1)
+    model = torch.nn.Linear(4, 1)
     model.model_type = 'unit-test'
+
     def set_input_tensor(input_tensor):
         return None
+
     model.set_input_tensor = set_input_tensor
 
     forward_backward_func = get_forward_backward_func()
-    assert(schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining)
+    assert schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining
 
     mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2)
-    config = ModelParallelConfig(
-        pipeline_model_parallel_size = 1
-    )
+    config = ModelParallelConfig(pipeline_model_parallel_size=1)
     model.config = config
 
     losses_reduced = forward_backward_func(
         forward_step_func=forward_step_func,
-        data_iterator=range(0,100),
+        data_iterator=range(0, 100),
         model=[model],
         num_microbatches=4,
         seq_length=None,
         micro_batch_size=None,
-        forward_only=True) 
-    
+        forward_only=True,
+    )
+
+    loss_reduced_expected = [
+        {'loss_reduced': rank},
+        {'loss_reduced': rank},
+        {'loss_reduced': rank},
+        {'loss_reduced': rank},
+    ]
 
-    loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}]
-    
-    for i,j in zip(losses_reduced, loss_reduced_expected):
+    for i, j in zip(losses_reduced, loss_reduced_expected):
         print(losses_reduced)
-        assert(i['loss_reduced'] == j['loss_reduced'])
-    Utils.destroy_model_parallel() 
+        assert i['loss_reduced'] == j['loss_reduced']
+    Utils.destroy_model_parallel()
 
 
 def test_forward_backward_func_with_pipeline_parallel(mocker):
@@ -79,77 +108,99 @@ def test_forward_backward_func_with_pipeline_parallel(mocker):
 
     def forward_step_func(data_iterator, model):
         import os
+
         rank = int(os.environ['LOCAL_RANK'])
+
         def loss_func(output_tensor):
-            return rank, {'loss_reduced':rank}
-        return torch.rand(512,8,256).cuda(), loss_func
+            return rank, {'loss_reduced': rank}
 
-    model = torch.nn.Linear(4,1)
+        return torch.rand(512, 8, 256).cuda(), loss_func
+
+    model = torch.nn.Linear(4, 1)
     model.model_type = 'unit-test'
+
     def set_input_tensor(input_tensor):
         return None
+
     model.set_input_tensor = set_input_tensor
 
     forward_backward_func = get_forward_backward_func()
-    assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_without_interleaving)
+    assert (
+        schedule.get_forward_backward_func()
+        == schedule.forward_backward_pipelining_without_interleaving
+    )
 
     sequence_length = 512
     micro_batch_size = 8
     hidden_size = 256
 
     config = ModelParallelConfig(
-        pipeline_model_parallel_size = 4,
-        sequence_parallel = False,
-        pipeline_dtype=torch.float,
+        pipeline_model_parallel_size=4, sequence_parallel=False, pipeline_dtype=torch.float
     )
     config.hidden_size = hidden_size
     model.config = config
-    
+
     losses_reduced = forward_backward_func(
         forward_step_func=forward_step_func,
         data_iterator=None,
         model=[model],
-        num_microbatches= micro_batch_size,
+        num_microbatches=micro_batch_size,
         seq_length=sequence_length,
         micro_batch_size=micro_batch_size,
-        forward_only=True) 
-    
-    loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}]
-    for i,j in zip(losses_reduced, loss_reduced_expected):
+        forward_only=True,
+    )
+
+    loss_reduced_expected = [
+        {'loss_reduced': rank},
+        {'loss_reduced': rank},
+        {'loss_reduced': rank},
+        {'loss_reduced': rank},
+    ]
+    for i, j in zip(losses_reduced, loss_reduced_expected):
         print(losses_reduced)
-        assert(i['loss_reduced'] == j['loss_reduced'])
-    Utils.destroy_model_parallel()  
+        assert i['loss_reduced'] == j['loss_reduced']
+    Utils.destroy_model_parallel()
 
 
 def test_forward_backward_func_with_interleaving(mocker):
-    from megatron.core.pipeline_parallel import get_forward_backward_func
     from megatron.core.enums import ModelType
+    from megatron.core.pipeline_parallel import get_forward_backward_func
 
-    Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=4, virtual_pipeline_model_parallel_size=2)
+    Utils.initialize_model_parallel(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=4,
+        virtual_pipeline_model_parallel_size=2,
+    )
 
     def forward_step_func(data_iterator, model):
         import os
+
         rank = int(os.environ['LOCAL_RANK'])
+
         def loss_func(output_tensor):
-            return rank, {'loss_reduced':rank}
-        return torch.rand(512,8,256).cuda(), loss_func
+            return rank, {'loss_reduced': rank}
+
+        return torch.rand(512, 8, 256).cuda(), loss_func
+
+    model = torch.nn.Linear(4, 1)
 
-    model = torch.nn.Linear(4,1)
     def set_input_tensor(input_tensor):
         return None
+
     model.set_input_tensor = set_input_tensor
 
     forward_backward_func = get_forward_backward_func()
-    assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_with_interleaving)
+    assert (
+        schedule.get_forward_backward_func()
+        == schedule.forward_backward_pipelining_with_interleaving
+    )
 
     sequence_length = 512
     micro_batch_size = 8
     hidden_size = 256
 
     config = ModelParallelConfig(
-        pipeline_model_parallel_size = 4,
-        sequence_parallel = False,
-        pipeline_dtype=torch.float,
+        pipeline_model_parallel_size=4, sequence_parallel=False, pipeline_dtype=torch.float
     )
     config.hidden_size = hidden_size
     model.config = config
@@ -160,53 +211,61 @@ def set_input_tensor(input_tensor):
         model.model_type = ModelType.encoder_and_decoder
         forward_backward_func(
             forward_step_func=forward_step_func,
-            data_iterator=[range(0,100)],
+            data_iterator=[range(0, 100)],
             model=[model, model],
-            num_microbatches= micro_batch_size,
+            num_microbatches=micro_batch_size,
             seq_length=sequence_length,
-            micro_batch_size=micro_batch_size, 
+            micro_batch_size=micro_batch_size,
             decoder_seq_length=sequence_length,
-            forward_only=True)
- 
+            forward_only=True,
+        )
+
     with pytest.raises(RuntimeError):
         model.model_type = ModelType.encoder_or_decoder
         forward_backward_func(
             forward_step_func=forward_step_func,
-            data_iterator=[range(0,100)],
+            data_iterator=[range(0, 100)],
             model=[model, model],
-            num_microbatches= micro_batch_size,
+            num_microbatches=micro_batch_size,
             seq_length=sequence_length,
-            micro_batch_size=micro_batch_size, 
+            micro_batch_size=micro_batch_size,
             decoder_seq_length=256,
-            forward_only=True)
-     
+            forward_only=True,
+        )
+
     with pytest.raises(RuntimeError):
         model.model_type = ModelType.encoder_or_decoder
         forward_backward_func(
             forward_step_func=forward_step_func,
-            data_iterator=[range(0,100)],
+            data_iterator=[range(0, 100)],
             model=[model, model],
-            num_microbatches= 7,
+            num_microbatches=7,
             seq_length=sequence_length,
-            micro_batch_size=micro_batch_size, 
+            micro_batch_size=micro_batch_size,
             decoder_seq_length=512,
-            forward_only=True)    
+            forward_only=True,
+        )
 
-    
     model.model_type = ModelType.encoder_or_decoder
     losses_reduced = forward_backward_func(
         forward_step_func=forward_step_func,
-        data_iterator=[range(0,100), range(0,100)],
+        data_iterator=[range(0, 100), range(0, 100)],
         model=[model, model],
-        num_microbatches= micro_batch_size,
+        num_microbatches=micro_batch_size,
         seq_length=sequence_length,
-        micro_batch_size=micro_batch_size, 
+        micro_batch_size=micro_batch_size,
         decoder_seq_length=sequence_length,
-        forward_only=True) 
-    
-    loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}]
-    for i,j in zip(losses_reduced, loss_reduced_expected):
+        forward_only=True,
+    )
+
+    loss_reduced_expected = [
+        {'loss_reduced': rank},
+        {'loss_reduced': rank},
+        {'loss_reduced': rank},
+        {'loss_reduced': rank},
+    ]
+    for i, j in zip(losses_reduced, loss_reduced_expected):
         print(losses_reduced)
-        assert(i['loss_reduced'] == j['loss_reduced'])
+        assert i['loss_reduced'] == j['loss_reduced']
 
-    Utils.destroy_model_parallel()    
+    Utils.destroy_model_parallel()
diff --git a/tests/unit_tests/tensor_parallel/test_cross_entropy.py b/tests/unit_tests/tensor_parallel/test_cross_entropy.py
index a29365ee43..66982fd234 100644
--- a/tests/unit_tests/tensor_parallel/test_cross_entropy.py
+++ b/tests/unit_tests/tensor_parallel/test_cross_entropy.py
@@ -1,14 +1,34 @@
-from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
+import numpy as np
 import torch
+
+from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
 from tests.unit_tests.test_utilities import Utils
-import numpy as np
+
 
 def test_vocab_parallel_cross_entropy():
-    Utils.initialize_model_parallel(4,2)
-    vocab_parallel_logits = torch.range(0,7).repeat(16,4).cuda()
-    target = torch.arange(0,32,2).cuda()
+    Utils.initialize_model_parallel(4, 2)
+    vocab_parallel_logits = torch.arange(0, 8, dtype=torch.float).repeat(16, 4).cuda()
+    target = torch.arange(0, 32, 2).cuda()
     output = vocab_parallel_cross_entropy(vocab_parallel_logits, target)
-    expected_output = torch.tensor([10.2309,  8.2309,  6.2309,  4.2309, 10.2309,  8.2309,  6.2309,  4.2309,
-        10.2309,  8.2309,  6.2309,  4.2309, 10.2309,  8.2309,  6.2309,  4.2309]).cuda()
-    assert(torch.equal(torch.round(expected_output), torch.round(output)))
-    Utils.destroy_model_parallel()
\ No newline at end of file
+    expected_output = torch.tensor(
+        [
+            10.2309,
+            8.2309,
+            6.2309,
+            4.2309,
+            10.2309,
+            8.2309,
+            6.2309,
+            4.2309,
+            10.2309,
+            8.2309,
+            6.2309,
+            4.2309,
+            10.2309,
+            8.2309,
+            6.2309,
+            4.2309,
+        ]
+    ).cuda()
+    assert torch.equal(torch.round(expected_output), torch.round(output))
+    Utils.destroy_model_parallel()
diff --git a/tests/unit_tests/tensor_parallel/test_data.py b/tests/unit_tests/tensor_parallel/test_data.py
index 38a39ce37f..211d48b4fd 100644
--- a/tests/unit_tests/tensor_parallel/test_data.py
+++ b/tests/unit_tests/tensor_parallel/test_data.py
@@ -1,21 +1,23 @@
-from megatron.core.tensor_parallel.data import broadcast_data
 import torch
+
+from megatron.core.tensor_parallel.data import broadcast_data
 from tests.unit_tests.test_utilities import Utils
 
+
 def test_broadcast_data():
-    Utils.initialize_model_parallel(2,4)
+    Utils.initialize_model_parallel(2, 4)
     input_data = {
-        0 : torch.ones((8,8)).cuda() * 0.0,
-        1 : torch.ones((8,8)).cuda() * 1.0,
-        2 : torch.ones((8,8)).cuda() * 2.0,
-        3 : torch.ones((8,8)).cuda() * 3.0,
-        4 : torch.ones((8,8)).cuda() * 4.0,
-        5 : torch.ones((8,8)).cuda() * 5.0,
-        6 : torch.ones((8,8)).cuda() * 6.0,
-        7 : torch.ones((8,8)).cuda() * 7.0
-        }
+        0: torch.ones((8, 8)).cuda() * 0.0,
+        1: torch.ones((8, 8)).cuda() * 1.0,
+        2: torch.ones((8, 8)).cuda() * 2.0,
+        3: torch.ones((8, 8)).cuda() * 3.0,
+        4: torch.ones((8, 8)).cuda() * 4.0,
+        5: torch.ones((8, 8)).cuda() * 5.0,
+        6: torch.ones((8, 8)).cuda() * 6.0,
+        7: torch.ones((8, 8)).cuda() * 7.0,
+    }
     dtype = torch.float32
-    actual_output = broadcast_data([0,1],input_data, dtype)
-    assert(torch.equal(actual_output[0], input_data[0]))
-    assert(torch.equal(actual_output[1], input_data[1]))
-    Utils.destroy_model_parallel()
\ No newline at end of file
+    actual_output = broadcast_data([0, 1], input_data, dtype)
+    assert torch.equal(actual_output[0], input_data[0])
+    assert torch.equal(actual_output[1], input_data[1])
+    Utils.destroy_model_parallel()
diff --git a/tests/unit_tests/tensor_parallel/test_initialization.py b/tests/unit_tests/tensor_parallel/test_initialization.py
index 346ae241e0..9fcc38c259 100644
--- a/tests/unit_tests/tensor_parallel/test_initialization.py
+++ b/tests/unit_tests/tensor_parallel/test_initialization.py
@@ -1,20 +1,25 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 import pytest
-
 import torch
 
 import megatron.core.parallel_state as ps
-from megatron.core.tensor_parallel.layers import VocabParallelEmbedding, RowParallelLinear, ColumnParallelLinear
-from tests.unit_tests.test_utilities import Utils
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
+from megatron.core.tensor_parallel.layers import (
+    ColumnParallelLinear,
+    RowParallelLinear,
+    VocabParallelEmbedding,
+)
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
+from tests.unit_tests.test_utilities import Utils
+
 
 class Test:
 
-    transformer_config = TransformerConfig(num_layers=1, hidden_size=12,
-                                           num_attention_heads=4, use_cpu_initialization=True)
+    transformer_config = TransformerConfig(
+        num_layers=1, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
+    )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_embedding_init(self):
@@ -23,22 +28,27 @@ def test_embedding_init(self):
         torch.manual_seed(42)
         model_parallel_cuda_manual_seed(42)
 
-
-        tp1 = VocabParallelEmbedding(num_embeddings=16, embedding_dim=4,
-                                     init_method=self.transformer_config.init_method,
-                                     config=self.transformer_config).weight
+        tp1 = VocabParallelEmbedding(
+            num_embeddings=16,
+            embedding_dim=4,
+            init_method=self.transformer_config.init_method,
+            config=self.transformer_config,
+        ).weight
         Utils.destroy_model_parallel()
 
         Utils.initialize_model_parallel(4, 1)
         torch.manual_seed(42)
         model_parallel_cuda_manual_seed(41)  # intentionally different.
-        tp4 = VocabParallelEmbedding(num_embeddings=16, embedding_dim=4,
-                                     init_method=self.transformer_config.init_method,
-                                     config=self.transformer_config).weight
+        tp4 = VocabParallelEmbedding(
+            num_embeddings=16,
+            embedding_dim=4,
+            init_method=self.transformer_config.init_method,
+            config=self.transformer_config,
+        ).weight
 
         rank = ps.get_tensor_model_parallel_rank()
         assert tp4.shape[0] * 4 == tp1.shape[0]
-        assert torch.equal(tp1[rank*4:(rank+1)*4], tp4)
+        assert torch.equal(tp1[rank * 4 : (rank + 1) * 4], tp4)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_row_init(self):
@@ -47,26 +57,33 @@ def test_row_init(self):
         torch.manual_seed(42)
         model_parallel_cuda_manual_seed(42)
 
-        tp1 = RowParallelLinear(input_size=16, output_size=16,
-                                init_method=self.transformer_config.init_method,
-                                bias=True, input_is_parallel=False,
-                                config=self.transformer_config,
-                                skip_bias_add=False).weight
+        tp1 = RowParallelLinear(
+            input_size=16,
+            output_size=16,
+            init_method=self.transformer_config.init_method,
+            bias=True,
+            input_is_parallel=False,
+            config=self.transformer_config,
+            skip_bias_add=False,
+        ).weight
         Utils.destroy_model_parallel()
 
         Utils.initialize_model_parallel(4, 1)
         torch.manual_seed(42)
         model_parallel_cuda_manual_seed(41)  # intentionally different.
-        tp4 = RowParallelLinear(input_size=16, output_size=16,
-                                init_method=self.transformer_config.init_method,
-                                bias=True,
-                                input_is_parallel=False,
-                                config=self.transformer_config,
-                                skip_bias_add=False).weight
+        tp4 = RowParallelLinear(
+            input_size=16,
+            output_size=16,
+            init_method=self.transformer_config.init_method,
+            bias=True,
+            input_is_parallel=False,
+            config=self.transformer_config,
+            skip_bias_add=False,
+        ).weight
 
         rank = ps.get_tensor_model_parallel_rank()
         assert tp4.shape[1] * 4 == tp1.shape[1]
-        assert torch.equal(tp1[:, rank*4:(rank+1)*4], tp4)
+        assert torch.equal(tp1[:, rank * 4 : (rank + 1) * 4], tp4)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_col_init(self):
@@ -75,20 +92,28 @@ def test_col_init(self):
         torch.manual_seed(42)
         model_parallel_cuda_manual_seed(42)
 
-        tp1 = ColumnParallelLinear(input_size=16, output_size=16,
-                                   init_method=self.transformer_config.init_method,
-                                   bias=True, config=self.transformer_config,
-                                   skip_bias_add=False).weight
+        tp1 = ColumnParallelLinear(
+            input_size=16,
+            output_size=16,
+            init_method=self.transformer_config.init_method,
+            bias=True,
+            config=self.transformer_config,
+            skip_bias_add=False,
+        ).weight
         Utils.destroy_model_parallel()
 
         Utils.initialize_model_parallel(4, 1)
         torch.manual_seed(42)
         model_parallel_cuda_manual_seed(41)  # intentionally different.
-        tp4 = ColumnParallelLinear(input_size=16, output_size=16,
-                                   init_method=self.transformer_config.init_method,
-                                   bias=True, config=self.transformer_config,
-                                   skip_bias_add=False).weight
+        tp4 = ColumnParallelLinear(
+            input_size=16,
+            output_size=16,
+            init_method=self.transformer_config.init_method,
+            bias=True,
+            config=self.transformer_config,
+            skip_bias_add=False,
+        ).weight
 
         rank = ps.get_tensor_model_parallel_rank()
         assert tp4.shape[0] * 4 == tp1.shape[0]
-        assert torch.equal(tp1[rank*4:(rank+1)*4], tp4)
+        assert torch.equal(tp1[rank * 4 : (rank + 1) * 4], tp4)
diff --git a/tests/unit_tests/tensor_parallel/test_mappings.py b/tests/unit_tests/tensor_parallel/test_mappings.py
index 6be486ef3c..c6a789410c 100644
--- a/tests/unit_tests/tensor_parallel/test_mappings.py
+++ b/tests/unit_tests/tensor_parallel/test_mappings.py
@@ -1,135 +1,139 @@
+import torch
+
 from megatron.core.tensor_parallel import mappings
 from tests.unit_tests.test_utilities import Utils
-import torch
+
 
 def test_CopyToModelParallelRegion():
-    Utils.initialize_model_parallel(4,2)
-    input_data = torch.ones((1)).cuda()*Utils.rank
+    Utils.initialize_model_parallel(4, 2)
+    input_data = torch.ones((1)).cuda() * Utils.rank
     output_data = mappings._CopyToModelParallelRegion.backward(None, input_data)
     result = torch.ones(1).cuda()
     result = result * 22 if Utils.rank >= 4 else result * 6
-    assert(torch.equal(output_data, result))
-    assert(torch.equal(input_data, mappings.copy_to_tensor_model_parallel_region(input_data)))
-    assert(torch.equal(input_data, mappings._CopyToModelParallelRegion.symbolic(None, input_data)))
+    assert torch.equal(output_data, result)
+    assert torch.equal(input_data, mappings.copy_to_tensor_model_parallel_region(input_data))
+    assert torch.equal(input_data, mappings._CopyToModelParallelRegion.symbolic(None, input_data))
     Utils.destroy_model_parallel()
 
+
 def test_ReduceFromModelParallelRegion():
-    Utils.initialize_model_parallel(4,2)
-    input_data = torch.ones((1)).cuda()*Utils.rank
+    Utils.initialize_model_parallel(4, 2)
+    input_data = torch.ones((1)).cuda() * Utils.rank
     output_data = mappings._ReduceFromModelParallelRegion.symbolic(None, input_data)
     result = torch.ones(1).cuda()
     result = result * 22 if Utils.rank >= 4 else result * 6
-    assert(torch.equal(output_data, result))
-    input_data = torch.ones((1)).cuda()*Utils.rank
-    assert(torch.equal(mappings.reduce_from_tensor_model_parallel_region(input_data), result))
-    assert(torch.equal(input_data, mappings._ReduceFromModelParallelRegion.backward(None, input_data)))
+    assert torch.equal(output_data, result)
+    input_data = torch.ones((1)).cuda() * Utils.rank
+    assert torch.equal(mappings.reduce_from_tensor_model_parallel_region(input_data), result)
+    assert torch.equal(
+        input_data, mappings._ReduceFromModelParallelRegion.backward(None, input_data)
+    )
     Utils.destroy_model_parallel()
 
+
 def test_ScatterToModelParallelRegion():
-    Utils.initialize_model_parallel(4,2)
-    input_data = torch.rand((8,4)).cuda()
+    Utils.initialize_model_parallel(4, 2)
+    input_data = torch.rand((8, 4)).cuda()
     output_data = mappings.scatter_to_tensor_model_parallel_region(input_data)
-    req_dim = int(Utils.rank%(Utils.world_size/2))
-    assert(torch.equal(output_data, input_data[:,req_dim].reshape((8,1))))
+    req_dim = int(Utils.rank % (Utils.world_size / 2))
+    assert torch.equal(output_data, input_data[:, req_dim].reshape((8, 1)))
     output_data = mappings._ScatterToModelParallelRegion.symbolic(None, input_data)
-    assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1))))
+    assert torch.equal(output_data, input_data[:, req_dim].reshape((8, 1)))
 
     input_data = torch.ones(8).cuda() * Utils.rank
     actual_output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data)
-    expected_output = torch.cat((
-        torch.ones(8)*0,
-        torch.ones(8)*1,
-        torch.ones(8)*2,
-        torch.ones(8)*3)).cuda()
-    if (Utils.rank >= 4):
+    expected_output = torch.cat(
+        (torch.ones(8) * 0, torch.ones(8) * 1, torch.ones(8) * 2, torch.ones(8) * 3)
+    ).cuda()
+    if Utils.rank >= 4:
         expected_output = expected_output + 4
-    assert(torch.equal(actual_output_data, expected_output))
+    assert torch.equal(actual_output_data, expected_output)
     Utils.destroy_model_parallel()
 
+
 def test_GatherFromModelParallelRegion():
-    Utils.initialize_model_parallel(4,2)
-    input_data = torch.rand((8,4)).cuda()
-    req_dim = int(Utils.rank%(Utils.world_size/2))
+    Utils.initialize_model_parallel(4, 2)
+    input_data = torch.rand((8, 4)).cuda()
+    req_dim = int(Utils.rank % (Utils.world_size / 2))
     output_data = mappings._GatherFromModelParallelRegion.backward(None, input_data)
-    assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1))))
+    assert torch.equal(output_data, input_data[:, req_dim].reshape((8, 1)))
     input_data = torch.ones(8).cuda() * Utils.rank
     actual_output_data = mappings.gather_from_tensor_model_parallel_region(input_data)
-    expected_output = torch.cat((
-        torch.ones(8)*0,
-        torch.ones(8)*1,
-        torch.ones(8)*2,
-        torch.ones(8)*3)).cuda()
-    if (Utils.rank >= 4):
+    expected_output = torch.cat(
+        (torch.ones(8) * 0, torch.ones(8) * 1, torch.ones(8) * 2, torch.ones(8) * 3)
+    ).cuda()
+    if Utils.rank >= 4:
         expected_output = expected_output + 4
-    assert(torch.equal(actual_output_data, expected_output))
-    assert(torch.equal(mappings._GatherFromModelParallelRegion.symbolic(None, input_data), expected_output))
+    assert torch.equal(actual_output_data, expected_output)
+    assert torch.equal(
+        mappings._GatherFromModelParallelRegion.symbolic(None, input_data), expected_output
+    )
     Utils.destroy_model_parallel()
- 
+
+
 def test_ScatterToSequenceParallelRegion():
-    Utils.initialize_model_parallel(4,2)
-    input_data = torch.rand((8,4)).cuda()
-    req_dim = int(Utils.rank%(Utils.world_size/2))*2
+    Utils.initialize_model_parallel(4, 2)
+    input_data = torch.rand((8, 4)).cuda()
+    req_dim = int(Utils.rank % (Utils.world_size / 2)) * 2
     output_data = mappings._ScatterToSequenceParallelRegion.symbolic(None, input_data)
-    assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :]))
+    assert torch.equal(output_data, input_data[req_dim : req_dim + 2, :])
     output_data = mappings.scatter_to_sequence_parallel_region(input_data)
-    assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :]))
+    assert torch.equal(output_data, input_data[req_dim : req_dim + 2, :])
     input_data = torch.ones(4).cuda() * Utils.rank
     output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data)
-    expected_output = torch.concat((
-        torch.ones(4)*0,
-        torch.ones(4)*1,
-        torch.ones(4)*2,
-        torch.ones(4)*3)).cuda()
-    if (Utils.rank >= 4):
+    expected_output = torch.concat(
+        (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3)
+    ).cuda()
+    if Utils.rank >= 4:
         expected_output = expected_output + 4
-    assert(torch.equal(output_data, expected_output))
+    assert torch.equal(output_data, expected_output)
     Utils.destroy_model_parallel()
 
+
 def test_GatherFromSequenceParallelRegion():
-    Utils.initialize_model_parallel(4,2)
+    Utils.initialize_model_parallel(4, 2)
     input_data = torch.ones(4).cuda() * Utils.rank
     output_data = mappings.gather_from_sequence_parallel_region(input_data)
-    expected_output = torch.concat((
-        torch.ones(4)*0,
-        torch.ones(4)*1,
-        torch.ones(4)*2,
-        torch.ones(4)*3)).cuda()
-    if (Utils.rank >= 4):
+    expected_output = torch.concat(
+        (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3)
+    ).cuda()
+    if Utils.rank >= 4:
         expected_output = expected_output + 4
-    assert(torch.equal(output_data, expected_output))
-    assert(torch.equal(mappings._GatherFromSequenceParallelRegion.symbolic(None, input_data), expected_output))
-    input_data = torch.vstack((
-        torch.ones(4)*0,
-        torch.ones(4)*1,
-        torch.ones(4)*2,
-        torch.ones(4)*3)).cuda()
+    assert torch.equal(output_data, expected_output)
+    assert torch.equal(
+        mappings._GatherFromSequenceParallelRegion.symbolic(None, input_data), expected_output
+    )
+    input_data = torch.vstack(
+        (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3)
+    ).cuda()
+
     class Ctx:
         tensor_parallel_output_grad = True
+
     output_data = mappings._GatherFromSequenceParallelRegion.backward(Ctx(), input_data)
-    expected_output = torch.ones((1,4)).cuda() * 4 * int(Utils.rank % 4)
-    assert(torch.equal(output_data[0], expected_output))
+    expected_output = torch.ones((1, 4)).cuda() * 4 * int(Utils.rank % 4)
+    assert torch.equal(output_data[0], expected_output)
     Utils.destroy_model_parallel()
 
+
 def test_ReduceScatterToSequenceParallelRegion():
-    Utils.initialize_model_parallel(4,2)
-    input_data = torch.vstack((
-        torch.ones(4)*0,
-        torch.ones(4)*1,
-        torch.ones(4)*2,
-        torch.ones(4)*3)).cuda()
+    Utils.initialize_model_parallel(4, 2)
+    input_data = torch.vstack(
+        (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3)
+    ).cuda()
     output_data = mappings.reduce_scatter_to_sequence_parallel_region(input_data)
     expected_output = torch.ones(4).cuda() * 4 * int(Utils.rank % 4)
-    assert(torch.equal(output_data[0], expected_output))
-    assert(torch.equal(mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, input_data) , expected_output.reshape((1,4))))
+    assert torch.equal(output_data[0], expected_output)
+    assert torch.equal(
+        mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, input_data),
+        expected_output.reshape((1, 4)),
+    )
     input_data = torch.ones(4).cuda() * Utils.rank
-    output_data = mappings._ReduceScatterToSequenceParallelRegion.backward(None,input_data)
-    expected_output = torch.concat((
-        torch.ones(4)*0,
-        torch.ones(4)*1,
-        torch.ones(4)*2,
-        torch.ones(4)*3)).cuda()
-    if (Utils.rank >= 4):
+    output_data = mappings._ReduceScatterToSequenceParallelRegion.backward(None, input_data)
+    expected_output = torch.concat(
+        (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3)
+    ).cuda()
+    if Utils.rank >= 4:
         expected_output = expected_output + 4
-    assert(torch.equal(output_data, expected_output))
+    assert torch.equal(output_data, expected_output)
     Utils.destroy_model_parallel()
-
diff --git a/tests/unit_tests/tensor_parallel/test_random.py b/tests/unit_tests/tensor_parallel/test_random.py
index e2f35cf341..ace500839d 100644
--- a/tests/unit_tests/tensor_parallel/test_random.py
+++ b/tests/unit_tests/tensor_parallel/test_random.py
@@ -1,44 +1,54 @@
-from megatron.core.tensor_parallel.random import CudaRNGStatesTracker
-from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed,get_cuda_rng_tracker
-from megatron.core.tensor_parallel.random import checkpoint
-from tests.unit_tests.test_utilities import Utils
 import pytest
 import torch
 
+from megatron.core.tensor_parallel.random import (
+    CudaRNGStatesTracker,
+    checkpoint,
+    get_cuda_rng_tracker,
+    model_parallel_cuda_manual_seed,
+)
+from tests.unit_tests.test_utilities import Utils
+
+
 def test_cuda_rng_states_tracker():
     rng_tracker = CudaRNGStatesTracker()
-    rng_tracker.set_states({"state1":1234})
-    assert(rng_tracker.get_states()["state1"] == 1234)
+    rng_tracker.set_states({"state1": 1234})
+    assert rng_tracker.get_states()["state1"] == 1234
     rng_tracker.reset()
-    assert(rng_tracker.get_states() == {})
+    assert rng_tracker.get_states() == {}
     seed = 1111
-    rng_tracker.add("state2",seed)
+    rng_tracker.add("state2", seed)
     with pytest.raises(Exception):
-        assert(rng_tracker.add("state3",seed))
+        assert rng_tracker.add("state3", seed)
     with pytest.raises(Exception):
-        assert(rng_tracker.add("state2",111))
-    assert(rng_tracker.get_states()['state2'] is not None)
+        assert rng_tracker.add("state2", 111)
+    assert rng_tracker.get_states()['state2'] is not None
     with pytest.raises(Exception):
-        assert()
-    
+        assert ()
+
     rng_tracker.fork("state2")
     torch.cuda.manual_seed(seed)
     rng_state = torch.cuda.get_rng_state()
     assert torch.equal(rng_tracker.get_states()['state2'], rng_state)
 
+
 def test_model_parallel_cuda_manual_seed():
-    Utils.initialize_model_parallel(4,2)
+    Utils.initialize_model_parallel(4, 2)
     model_parallel_cuda_manual_seed(0)
     rng_tracker = get_cuda_rng_tracker()
-    assert(rng_tracker.get_states()['model-parallel-rng'] is not None)
+    assert rng_tracker.get_states()['model-parallel-rng'] is not None
     Utils.destroy_model_parallel()
 
+
 def test_checkpoint():
     def test_forward(*input):
-        return input[0]+input[1]
-    assert(torch.equal(torch.ones(16)*3,checkpoint(test_forward, None, torch.ones(16), torch.ones(16)*2)))
+        return input[0] + input[1]
+
+    assert torch.equal(
+        torch.ones(16) * 3, checkpoint(test_forward, None, torch.ones(16), torch.ones(16) * 2)
+    )
     Utils.initialize_model_parallel()
-    input1 = torch.ones((4,4))
-    checkpoint(test_forward, True, input1, torch.ones((4,4))*2)
-    assert(torch.equal(torch.ones(input1.numel()).cuda(), input1))
+    input1 = torch.ones((4, 4))
+    checkpoint(test_forward, True, input1, torch.ones((4, 4)) * 2)
+    assert torch.equal(torch.ones(input1.numel()).cuda(), input1)
     Utils.destroy_model_parallel()
diff --git a/tests/unit_tests/tensor_parallel/test_tensor_parallel_utils.py b/tests/unit_tests/tensor_parallel/test_tensor_parallel_utils.py
index f82e5fa693..5df774e5ff 100644
--- a/tests/unit_tests/tensor_parallel/test_tensor_parallel_utils.py
+++ b/tests/unit_tests/tensor_parallel/test_tensor_parallel_utils.py
@@ -1,43 +1,55 @@
 import torch
-import megatron.core.tensor_parallel.utils as util
+
 import megatron.core.parallel_state as ps
+import megatron.core.tensor_parallel.utils as util
 from tests.unit_tests.test_utilities import Utils
 
 rank = Utils.rank
 
+
 def test_split_tensor_along_last_dim():
-    input_tensor = torch.rand((3,4))
-    torch.equal(input_tensor[0:2,0:2], util.split_tensor_along_last_dim(input_tensor,2)[0])
-    torch.equal(input_tensor[2:,2:], util.split_tensor_along_last_dim(input_tensor,2)[1])
+    input_tensor = torch.rand((3, 4))
+    torch.equal(input_tensor[0:2, 0:2], util.split_tensor_along_last_dim(input_tensor, 2)[0])
+    torch.equal(input_tensor[2:, 2:], util.split_tensor_along_last_dim(input_tensor, 2)[1])
+
 
 def test_split_tensor_into_1d_equal_chunks():
     Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
-    input_tensor = torch.rand((3,4))
+    input_tensor = torch.rand((3, 4))
     output_tensor = util.split_tensor_into_1d_equal_chunks(input_tensor)
-    if rank % 2 == 0 :
+    if rank % 2 == 0:
         start = 0
-        end = int(input_tensor.numel()/2)
-    else :
-        start = int(input_tensor.numel()/2)
+        end = int(input_tensor.numel() / 2)
+    else:
+        start = int(input_tensor.numel() / 2)
         end = input_tensor.numel()
-        
+
     assert torch.equal(output_tensor, input_tensor.flatten()[start:end])
     Utils.destroy_model_parallel()
 
+
 def test_gather_split_1d_tensor():
     Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
-    input_tensor = torch.ones((2,4)).cuda() * rank
+    input_tensor = torch.ones((2, 4)).cuda() * rank
     actual_output_tensor = util.gather_split_1d_tensor(input_tensor)
-    if rank %2 == 0:
+    if rank % 2 == 0:
         expected_output_tensor = torch.concat((input_tensor.flatten(), input_tensor.flatten() + 1))
-    else : 
+    else:
         expected_output_tensor = torch.concat((input_tensor.flatten() - 1, input_tensor.flatten()))
-    assert(torch.equal(actual_output_tensor, expected_output_tensor))
+    assert torch.equal(actual_output_tensor, expected_output_tensor)
     Utils.destroy_model_parallel()
 
+
 def test_vocab():
     global_vocab_size = 1600
     per_partition_vocab_size = 1600 / Utils.world_size
-    assert((rank * per_partition_vocab_size, (rank + 1)* per_partition_vocab_size) == (util.VocabUtility.vocab_range_from_per_partition_vocab_size(global_vocab_size // Utils.world_size, rank, Utils.world_size)))
-    assert((rank * per_partition_vocab_size, (rank + 1)* per_partition_vocab_size) == (util.VocabUtility.vocab_range_from_global_vocab_size(global_vocab_size, rank, Utils.world_size)))
-    
\ No newline at end of file
+    assert (rank * per_partition_vocab_size, (rank + 1) * per_partition_vocab_size) == (
+        util.VocabUtility.vocab_range_from_per_partition_vocab_size(
+            global_vocab_size // Utils.world_size, rank, Utils.world_size
+        )
+    )
+    assert (rank * per_partition_vocab_size, (rank + 1) * per_partition_vocab_size) == (
+        util.VocabUtility.vocab_range_from_global_vocab_size(
+            global_vocab_size, rank, Utils.world_size
+        )
+    )
diff --git a/tests/unit_tests/test_basic.py b/tests/unit_tests/test_basic.py
index 915d2c1001..d2a60f92c8 100644
--- a/tests/unit_tests/test_basic.py
+++ b/tests/unit_tests/test_basic.py
@@ -1,3 +1,2 @@
 def test_import():
     import megatron
-
diff --git a/tests/unit_tests/test_imports.py b/tests/unit_tests/test_imports.py
index 49e7c77b55..bad67cd8d5 100644
--- a/tests/unit_tests/test_imports.py
+++ b/tests/unit_tests/test_imports.py
@@ -81,8 +81,7 @@ def _test_domain_module_imports(module, subdomains: list):
 
     if error is None:
         for imp in dir(module):
-            class_, result, error = _get_class_from_path(
-                subdomains, imp)
+            class_, result, error = _get_class_from_path(subdomains, imp)
 
             if result is not None:
                 module_list.append(class_)
@@ -99,7 +98,8 @@ def _test_domain_module_imports(module, subdomains: list):
     print()
     for module in failed_list:
         print(
-            "Module did not match a valid signature of Megatron core Model (hence ignored):", module)
+            "Module did not match a valid signature of Megatron core Model (hence ignored):", module
+        )
 
     print()
     if len(error_list) > 0:
@@ -125,29 +125,21 @@ def _test_domain_module_imports(module, subdomains: list):
 def test_domain_mcore():
     import megatron.core as mcore
 
-    all_passed = _test_domain_module_imports(
-        mcore,  subdomains=['models'])
+    all_passed = _test_domain_module_imports(mcore, subdomains=['models'])
 
-    all_passed = _test_domain_module_imports(
-        mcore,  subdomains=['pipeline_parallel'])
+    all_passed = _test_domain_module_imports(mcore, subdomains=['pipeline_parallel'])
 
-    all_passed = _test_domain_module_imports(
-        mcore,  subdomains=['tensor_parallel'])
+    all_passed = _test_domain_module_imports(mcore, subdomains=['tensor_parallel'])
 
-    all_passed = _test_domain_module_imports(
-        mcore,  subdomains=['transformer'])
+    all_passed = _test_domain_module_imports(mcore, subdomains=['transformer'])
 
-    all_passed = _test_domain_module_imports(
-        mcore,  subdomains=['fusions'])
+    all_passed = _test_domain_module_imports(mcore, subdomains=['fusions'])
 
-    all_passed = _test_domain_module_imports(
-        mcore,  subdomains=['distributed'])
+    all_passed = _test_domain_module_imports(mcore, subdomains=['distributed'])
 
-    all_passed = _test_domain_module_imports(
-        mcore,  subdomains=['datasets'])
+    all_passed = _test_domain_module_imports(mcore, subdomains=['datasets'])
 
-    all_passed = _test_domain_module_imports(
-        mcore,  subdomains=['dist_checkpointing'])
+    all_passed = _test_domain_module_imports(mcore, subdomains=['dist_checkpointing'])
 
     if not all_passed:
         exit(1)
diff --git a/tests/unit_tests/test_local_multi_tensor_fns.py b/tests/unit_tests/test_local_multi_tensor_fns.py
index f47d549f98..086de6f6d0 100644
--- a/tests/unit_tests/test_local_multi_tensor_fns.py
+++ b/tests/unit_tests/test_local_multi_tensor_fns.py
@@ -1,11 +1,14 @@
 import copy
+
+import pytest
+import torch
+
 from megatron.core.utils import (
     local_multi_tensor_applier,
     local_multi_tensor_l2_norm,
-    local_multi_tensor_scale
+    local_multi_tensor_scale,
 )
-import pytest
-import torch
+
 
 def test_local_multi_tensor_l2_norm_and_scale():
     amp_C = pytest.importorskip("amp_C")
@@ -13,24 +16,55 @@ def test_local_multi_tensor_l2_norm_and_scale():
 
     torch.manual_seed(42)
 
-    tensor_list = [torch.rand(5,5).cuda() for _ in range(10)]
+    tensor_list = [torch.rand(5, 5).cuda() for _ in range(10)]
     tensor_list_copy = copy.deepcopy(tensor_list)
 
-    norm_apex, _ = multi_tensor_apply.multi_tensor_applier(amp_C.multi_tensor_l2norm, torch.tensor([0], dtype=torch.int, device='cuda'), [tensor_list], False)
-    norm_local, _ = multi_tensor_apply.multi_tensor_applier(local_multi_tensor_l2_norm, torch.tensor([0], dtype=torch.int, device='cuda'), [tensor_list_copy], False)
+    norm_apex, _ = multi_tensor_apply.multi_tensor_applier(
+        amp_C.multi_tensor_l2norm,
+        torch.tensor([0], dtype=torch.int, device='cuda'),
+        [tensor_list],
+        False,
+    )
+    norm_local, _ = multi_tensor_apply.multi_tensor_applier(
+        local_multi_tensor_l2_norm,
+        torch.tensor([0], dtype=torch.int, device='cuda'),
+        [tensor_list_copy],
+        False,
+    )
     torch.testing.assert_close(norm_apex, norm_local)
 
     clip_coeff = 0.05
-    multi_tensor_apply.multi_tensor_applier(amp_C.multi_tensor_scale, torch.tensor([0], dtype=torch.int, device='cuda'), [tensor_list, tensor_list], clip_coeff)
-    multi_tensor_apply.multi_tensor_applier(local_multi_tensor_scale, torch.tensor([0], dtype=torch.int, device='cuda'), [tensor_list_copy, tensor_list_copy], clip_coeff)
+    multi_tensor_apply.multi_tensor_applier(
+        amp_C.multi_tensor_scale,
+        torch.tensor([0], dtype=torch.int, device='cuda'),
+        [tensor_list, tensor_list],
+        clip_coeff,
+    )
+    multi_tensor_apply.multi_tensor_applier(
+        local_multi_tensor_scale,
+        torch.tensor([0], dtype=torch.int, device='cuda'),
+        [tensor_list_copy, tensor_list_copy],
+        clip_coeff,
+    )
     torch.testing.assert_close(tensor_list, tensor_list_copy)
 
+
 def test_local_multi_tensor_apply():
     amp_C = pytest.importorskip("amp_C")
     multi_tensor_apply = pytest.importorskip("apex.multi_tensor_apply")
 
-    tensor_list = [torch.rand(5,5).cuda() for _ in range(10)]
+    tensor_list = [torch.rand(5, 5).cuda() for _ in range(10)]
 
-    norm_apex, _ = multi_tensor_apply.multi_tensor_applier(amp_C.multi_tensor_l2norm, torch.tensor([0], dtype=torch.int, device='cuda'), [tensor_list], False)
-    norm_local, _ = local_multi_tensor_applier(amp_C.multi_tensor_l2norm, torch.tensor([0], dtype=torch.int, device='cuda'), [tensor_list], False)
+    norm_apex, _ = multi_tensor_apply.multi_tensor_applier(
+        amp_C.multi_tensor_l2norm,
+        torch.tensor([0], dtype=torch.int, device='cuda'),
+        [tensor_list],
+        False,
+    )
+    norm_local, _ = local_multi_tensor_applier(
+        amp_C.multi_tensor_l2norm,
+        torch.tensor([0], dtype=torch.int, device='cuda'),
+        [tensor_list],
+        False,
+    )
     torch.testing.assert_close(norm_apex, norm_local)
diff --git a/tests/unit_tests/test_optimizer.py b/tests/unit_tests/test_optimizer.py
index 247da4aeb9..732a68cfa6 100644
--- a/tests/unit_tests/test_optimizer.py
+++ b/tests/unit_tests/test_optimizer.py
@@ -28,8 +28,8 @@ def forward(self, x):
 
 def test_chained_optimizer():
     net = Net()
-    optimizer_1 = Adam(list(net.parameters())[:2], lr=0.01,)
-    optimizer_2 = SGD(list(net.parameters())[2:], lr=0.1, momentum=0.9,)
+    optimizer_1 = Adam(list(net.parameters())[:2], lr=0.01)
+    optimizer_2 = SGD(list(net.parameters())[2:], lr=0.1, momentum=0.9)
     chained_optimizer = ChainedOptimizer([optimizer_1, optimizer_2])
 
     # Test the chained optimizer's param groups is a reference of the underlying optimizers' param groups
diff --git a/tests/unit_tests/test_parallel_state.py b/tests/unit_tests/test_parallel_state.py
index af58872ac0..abe3ea3d2e 100644
--- a/tests/unit_tests/test_parallel_state.py
+++ b/tests/unit_tests/test_parallel_state.py
@@ -1,114 +1,132 @@
+import os
+
+import pytest
 import torch
+
 import megatron.core.parallel_state as ps
-import pytest
 from tests.unit_tests.test_utilities import Utils
-import os
 
 rank = Utils.rank
 world_size = Utils.world_size
 test_parallel_order = ['tp-cp-ep-dp-pp', 'tp-cp-pp-ep-dp']
 
+
 @pytest.mark.parametrize('order', test_parallel_order)
 def test_initialize_and_destroy_model_parallel(order):
     with pytest.raises(AssertionError):
-        assert(ps.initialize_model_parallel(order=order))
+        assert ps.initialize_model_parallel(order=order)
     Utils.initialize_distributed()
     with pytest.raises(RuntimeError):
-        assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size, order=order))
+        assert ps.initialize_model_parallel(tensor_model_parallel_size=2 * world_size, order=order)
     with pytest.raises(RuntimeError):
-        assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size, order=order))
+        assert ps.initialize_model_parallel(
+            pipeline_model_parallel_size=2 * world_size, order=order
+        )
     with pytest.raises(RuntimeError):
-        assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size, order=order))
+        assert ps.initialize_model_parallel(
+            pipeline_model_parallel_size=world_size,
+            tensor_model_parallel_size=world_size,
+            order=order,
+        )
     with pytest.raises(RuntimeError):
-        assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2, order=order))
-    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4, order=order)
-
-    assert(ps.model_parallel_is_initialized())
-    assert(ps.get_model_parallel_group() is not None)
-    assert(ps.get_tensor_model_parallel_group() is not None)
-    assert(ps.get_pipeline_model_parallel_group() is not None)
-    assert(ps.get_data_parallel_group() is not None)
+        assert ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2, order=order)
+    Utils.initialize_model_parallel(
+        tensor_model_parallel_size=2, pipeline_model_parallel_size=4, order=order
+    )
+
+    assert ps.model_parallel_is_initialized()
+    assert ps.get_model_parallel_group() is not None
+    assert ps.get_tensor_model_parallel_group() is not None
+    assert ps.get_pipeline_model_parallel_group() is not None
+    assert ps.get_data_parallel_group() is not None
     Utils.destroy_model_parallel()
-    assert(ps._MODEL_PARALLEL_GROUP is None)
+    assert ps._MODEL_PARALLEL_GROUP is None
+
 
 @pytest.mark.parametrize('order', test_parallel_order)
 def test_pipeline_parallel_initializations(order):
-    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4, order=order)
-    assert(ps.get_pipeline_model_parallel_first_rank() == rank % 2 )
-    assert(ps.get_data_parallel_src_rank() == rank)
-    assert(ps.get_pipeline_model_parallel_next_rank() == ((rank + 2) % world_size))
-    assert(ps.get_pipeline_model_parallel_prev_rank() == ((rank - 2) % world_size))
+    Utils.initialize_model_parallel(
+        tensor_model_parallel_size=2, pipeline_model_parallel_size=4, order=order
+    )
+    assert ps.get_pipeline_model_parallel_first_rank() == rank % 2
+    assert ps.get_data_parallel_src_rank() == rank
+    assert ps.get_pipeline_model_parallel_next_rank() == ((rank + 2) % world_size)
+    assert ps.get_pipeline_model_parallel_prev_rank() == ((rank - 2) % world_size)
     Utils.destroy_model_parallel()
 
+
 @pytest.mark.parametrize('order', test_parallel_order)
 def test_data_parallel_initializations(order):
     Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order)
-    assert(ps.get_data_parallel_src_rank() == rank)
-    assert(ps.get_data_parallel_world_size() == 1)
-    assert(ps.get_data_parallel_rank() == 0)
+    assert ps.get_data_parallel_src_rank() == rank
+    assert ps.get_data_parallel_world_size() == 1
+    assert ps.get_data_parallel_rank() == 0
     Utils.destroy_model_parallel()
 
+
 @pytest.mark.parametrize('order', test_parallel_order)
 def test_tensor_model_parellel_world_size(order):
     Utils.initialize_model_parallel(tensor_model_parallel_size=world_size, order=order)
-    assert(ps.get_tensor_model_parallel_world_size() == world_size)
+    assert ps.get_tensor_model_parallel_world_size() == world_size
     ps.set_tensor_model_parallel_world_size(None)
-    assert(ps.get_tensor_model_parallel_world_size() == world_size)
+    assert ps.get_tensor_model_parallel_world_size() == world_size
     Utils.destroy_model_parallel()
 
 
 @pytest.mark.parametrize('order', test_parallel_order)
 def test_pipeline_model_parallel_world_size(order):
     Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order)
-    assert(ps.get_pipeline_model_parallel_world_size() == world_size)
+    assert ps.get_pipeline_model_parallel_world_size() == world_size
     ps.set_pipeline_model_parallel_world_size(None)
-    assert(ps.get_pipeline_model_parallel_world_size() == world_size)
+    assert ps.get_pipeline_model_parallel_world_size() == world_size
     Utils.destroy_model_parallel()
 
 
 @pytest.mark.parametrize('order', test_parallel_order)
 def test_tensor_model_parallel_rank(order):
     Utils.initialize_model_parallel(tensor_model_parallel_size=world_size, order=order)
-    assert(ps.get_tensor_model_parallel_rank() == rank)
+    assert ps.get_tensor_model_parallel_rank() == rank
     ps.set_tensor_model_parallel_rank(None)
-    assert(ps.get_tensor_model_parallel_rank() == rank)
+    assert ps.get_tensor_model_parallel_rank() == rank
     Utils.destroy_model_parallel()
 
 
 @pytest.mark.parametrize('order', test_parallel_order)
 def test_pipeline_model_parallel_rank(order):
     Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order)
-    assert(ps.get_pipeline_model_parallel_rank() == rank)
+    assert ps.get_pipeline_model_parallel_rank() == rank
     ps.set_pipeline_model_parallel_rank(None)
-    assert(ps.get_pipeline_model_parallel_rank() == rank)
+    assert ps.get_pipeline_model_parallel_rank() == rank
     Utils.destroy_model_parallel()
 
+
 def test_context_parallel_rank():
     Utils.initialize_model_parallel(context_parallel_size=world_size)
-    assert(ps.get_context_parallel_rank() == rank)
+    assert ps.get_context_parallel_rank() == rank
     Utils.destroy_model_parallel()
 
+
 def test_expert_model_parallel_rank():
     Utils.initialize_model_parallel(expert_model_parallel_size=world_size)
-    assert(ps.get_expert_model_parallel_rank() == rank)
+    assert ps.get_expert_model_parallel_rank() == rank
     ps.set_expert_model_parallel_rank(None)
-    assert(ps.get_expert_model_parallel_rank() == rank)
+    assert ps.get_expert_model_parallel_rank() == rank
     Utils.destroy_model_parallel()
 
 
 @pytest.mark.parametrize('order', test_parallel_order)
 def test_is_pipeline_first_stage(order):
     Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order)
-    assert(ps.is_pipeline_first_stage(ignore_virtual=True) == (rank == 0))
-    assert(ps.is_pipeline_first_stage() == (rank == 0))
+    assert ps.is_pipeline_first_stage(ignore_virtual=True) == (rank == 0)
+    assert ps.is_pipeline_first_stage() == (rank == 0)
     Utils.destroy_model_parallel()
 
 
 @pytest.mark.parametrize('order', test_parallel_order)
 def test_is_pipeline_last_stage(order):
     Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order)
-    assert(ps.is_pipeline_last_stage(ignore_virtual=True) == (rank == world_size-1))
-    assert(ps.is_pipeline_last_stage() == (rank == world_size-1))
+    assert ps.is_pipeline_last_stage(ignore_virtual=True) == (rank == world_size - 1)
+    assert ps.is_pipeline_last_stage() == (rank == world_size - 1)
     Utils.destroy_model_parallel()
 
 
@@ -116,14 +134,14 @@ def test_is_pipeline_last_stage(order):
 def test_virtual_pipeline_model_parallel_rank(order):
     Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order)
     ps.set_virtual_pipeline_model_parallel_rank(rank)
-    assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)
+    assert ps.get_virtual_pipeline_model_parallel_rank() == rank
     Utils.destroy_model_parallel()
 
 
 @pytest.mark.parametrize('order', test_parallel_order)
 def test_get_tensor_model_parallel_src_rank(order):
     Utils.initialize_model_parallel(tensor_model_parallel_size=world_size, order=order)
-    assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))
+    assert ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size)
     Utils.destroy_model_parallel()
 
 
@@ -215,7 +233,7 @@ def test_different_initialize_order_consistency(src_tp_pp, ep_size):
 
 @pytest.mark.parametrize(
     'src_tp_pp, ep_size',
-    [((1, 2), 1), ((1, 4), 1), ((2, 2), 1), ((1, 2), 2), ((1, 4), 2), ((2, 2), 2),],
+    [((1, 2), 1), ((1, 4), 1), ((2, 2), 1), ((1, 2), 2), ((1, 4), 2), ((2, 2), 2)],
 )
 def test_different_initialize_order_unconsistency(src_tp_pp, ep_size):
     Utils.initialize_model_parallel(
@@ -350,7 +368,9 @@ def golden_rank_result_from_past_code(
 
         tp_dp_group = []
         tp_dp_cp_group = []
-        tensor_and_data_group_size_with_cp: int = tensor_model_parallel_size * data_parallel_size * context_parallel_size
+        tensor_and_data_group_size_with_cp: int = (
+            tensor_model_parallel_size * data_parallel_size * context_parallel_size
+        )
         num_tensor_and_data_groups_with_cp: int = world_size // tensor_and_data_group_size_with_cp
         for i in range(num_tensor_and_data_groups_with_cp):
             start_rank = i * tensor_and_data_group_size_with_cp
@@ -374,16 +394,20 @@ def golden_rank_result_from_past_code(
         dp_no_ep_group = []
         dp_no_ep_group_with_cp = []
 
-        all_ranks = torch.arange(world_size).reshape((
-            pipeline_model_parallel_size,
-            data_parallel_size // expert_model_parallel_size,
-            expert_model_parallel_size,
-            context_parallel_size,
-            tensor_model_parallel_size
-        ))
+        all_ranks = torch.arange(world_size).reshape(
+            (
+                pipeline_model_parallel_size,
+                data_parallel_size // expert_model_parallel_size,
+                expert_model_parallel_size,
+                context_parallel_size,
+                tensor_model_parallel_size,
+            )
+        )
         # 'pp edp ep cp tp -> (pp edp cp) (ep tp)'
         tp_ep_rearrange = torch.transpose(all_ranks, 2, 3)
-        tp_ep_rearrange = torch.reshape(tp_ep_rearrange, (-1, expert_model_parallel_size * tensor_model_parallel_size))
+        tp_ep_rearrange = torch.reshape(
+            tp_ep_rearrange, (-1, expert_model_parallel_size * tensor_model_parallel_size)
+        )
         tp_ep_rearrange = tp_ep_rearrange.tolist()
         tp_ep_rearrange.sort()
         for tensor_and_expert_parallel_ranks in tp_ep_rearrange:
@@ -392,7 +416,9 @@ def golden_rank_result_from_past_code(
             tp_ep_group.append(tensor_and_expert_parallel_ranks)
         # 'pp edp ep cp tp -> (pp ep cp tp) edp'
         edp_rearrange = torch.transpose(all_ranks, 1, 4)
-        edp_rearrange = torch.reshape(edp_rearrange, (-1, data_parallel_size // expert_model_parallel_size))
+        edp_rearrange = torch.reshape(
+            edp_rearrange, (-1, data_parallel_size // expert_model_parallel_size)
+        )
         edp_rearrange = edp_rearrange.tolist()
         edp_rearrange.sort()
         for expert_data_parallel_ranks in edp_rearrange:
@@ -404,7 +430,7 @@ def golden_rank_result_from_past_code(
         edp_cp_rearrange = torch.transpose(edp_cp_rearrange, 2, 4)
         edp_cp_rearrange = torch.reshape(
             edp_cp_rearrange,
-            (-1, context_parallel_size * data_parallel_size // expert_model_parallel_size)
+            (-1, context_parallel_size * data_parallel_size // expert_model_parallel_size),
         )
         edp_cp_rearrange = edp_cp_rearrange.tolist()
         edp_cp_rearrange.sort()
@@ -452,7 +478,7 @@ def golden_rank_result_from_past_code(
         context_parallel_size=cp,
         expert_model_parallel_size=ep,
     )
-    rank_generator = ps.RankGenerator(tp=tp, ep=ep, dp=dp, pp=pp, cp=cp, order="tp-cp-ep-dp-pp",)
+    rank_generator = ps.RankGenerator(tp=tp, ep=ep, dp=dp, pp=pp, cp=cp, order="tp-cp-ep-dp-pp")
     assert dp_groups == rank_generator.get_ranks(
         "dp"
     ), f"{dp_groups} != {rank_generator.get_ranks('dp')}"
diff --git a/tests/unit_tests/test_training.py b/tests/unit_tests/test_training.py
index 7ac6ff360a..a23496f981 100644
--- a/tests/unit_tests/test_training.py
+++ b/tests/unit_tests/test_training.py
@@ -1,8 +1,8 @@
 from types import SimpleNamespace
 
 from megatron.training.global_vars import set_args
-from megatron.training.training import build_train_valid_test_data_iterators
 from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding
+from megatron.training.training import build_train_valid_test_data_iterators
 from tests.unit_tests.test_utilities import Utils
 
 
@@ -40,7 +40,6 @@ def test_build_train_valid_test_data_iterators(self):
 
         assert (train_iter, valid_iter, test_iter) == (1, 2, 3)
 
-
     def test_closed_formula_vocab_size_with_padding(self):
         def old_round_impl(after, multiple):
             while (after % multiple) != 0:
@@ -54,12 +53,16 @@ def old_round_impl(after, multiple):
         for vocab in range(1, 600000, 1000):
             for mult in [1, 17, 32, 64, 128]:
                 args.make_vocab_size_divisible_by = mult
-                assert old_round_impl(vocab, mult) == _vocab_size_with_padding(vocab, args, False), (vocab, mult)
+                assert old_round_impl(vocab, mult) == _vocab_size_with_padding(
+                    vocab, args, False
+                ), (vocab, mult)
 
         for vocab in range(1, 10_000, 500):
-            for mult in range(1, 1024+1):
+            for mult in range(1, 1024 + 1):
                 args.make_vocab_size_divisible_by = mult
-                assert old_round_impl(vocab, mult) == _vocab_size_with_padding(vocab, args, False), (vocab, mult)
+                assert old_round_impl(vocab, mult) == _vocab_size_with_padding(
+                    vocab, args, False
+                ), (vocab, mult)
 
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
diff --git a/tests/unit_tests/test_utilities.py b/tests/unit_tests/test_utilities.py
index 1de1fbe9f9..27e87378ba 100644
--- a/tests/unit_tests/test_utilities.py
+++ b/tests/unit_tests/test_utilities.py
@@ -47,10 +47,7 @@ def initialize_distributed():
             Utils.store = store
 
             torch.distributed.init_process_group(
-                backend='nccl',
-                world_size=Utils.world_size,
-                rank=Utils.rank,
-                store=store,
+                backend='nccl', world_size=Utils.world_size, rank=Utils.rank, store=store
             )
 
             torch.distributed.barrier()
diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py
index e0a0c2d07d..b2095e3506 100644
--- a/tests/unit_tests/test_utils.py
+++ b/tests/unit_tests/test_utils.py
@@ -11,36 +11,42 @@
 
 
 def test_divide_properly():
-    assert util.divide(4,2) == 2
+    assert util.divide(4, 2) == 2
+
 
 def test_divide_improperly():
     with pytest.raises(AssertionError):
-        util.divide(4,5)
+        util.divide(4, 5)
+
 
 def test_global_memory_buffer():
     global_memory_buffer = util.GlobalMemoryBuffer()
-    obtained_tensor = global_memory_buffer.get_tensor((3,2), torch.float32, "test_tensor")
-    expected_tensor = torch.empty((3,2), dtype=torch.float32, device=torch.cuda.current_device())
+    obtained_tensor = global_memory_buffer.get_tensor((3, 2), torch.float32, "test_tensor")
+    expected_tensor = torch.empty((3, 2), dtype=torch.float32, device=torch.cuda.current_device())
     assert obtained_tensor.shape == expected_tensor.shape
 
+
 def test_make_viewless_tensor():
-    inp = torch.rand((3,4))
-    assert(torch.equal(inp, util.make_viewless_tensor(inp, True, True)))
-    assert(torch.equal(inp, util.make_viewless_tensor(inp, True, False)))
+    inp = torch.rand((3, 4))
+    assert torch.equal(inp, util.make_viewless_tensor(inp, True, True))
+    assert torch.equal(inp, util.make_viewless_tensor(inp, True, False))
+
 
 def test_safely_set_viewless_tensor_data():
-    tensor = torch.zeros((3,4))
-    new_data_tensor = torch.tensor(np.random.rand(3,4))
+    tensor = torch.zeros((3, 4))
+    new_data_tensor = torch.tensor(np.random.rand(3, 4))
     util.safely_set_viewless_tensor_data(tensor, new_data_tensor)
-    assert(torch.equal(tensor, new_data_tensor))
+    assert torch.equal(tensor, new_data_tensor)
+
 
 def test_assert_viewless_tensor():
-    tensor = torch.rand((3,4))
-    assert(torch.equal(util.assert_viewless_tensor(tensor), tensor))
-    input_tensor_list=[tensor,tensor,tensor]
+    tensor = torch.rand((3, 4))
+    assert torch.equal(util.assert_viewless_tensor(tensor), tensor)
+    input_tensor_list = [tensor, tensor, tensor]
     output_tensor_list = util.assert_viewless_tensor(input_tensor_list)
-    for inp,out in zip(input_tensor_list, output_tensor_list):
-        assert(torch.equal(inp,out))
+    for inp, out in zip(input_tensor_list, output_tensor_list):
+        assert torch.equal(inp, out)
+
 
 # Initialize torch.distributed; do not call init_process_group here, call
 # Utils.initialize_distributed() instead.
@@ -51,12 +57,14 @@ def _init_distributed(world, rank):
     assert torch.cuda.device_count() == world
     torch.distributed.barrier()
 
+
 # Deinitialization and cleanup.
 # Do not call torch.distributed.destroy_process_group, may be needed by other tests.
 def _deinit_distributed():
     assert torch.distributed.is_initialized() == True
     torch.distributed.barrier()
 
+
 def test_check_param_hashes_across_dp_replicas():
     world = int(os.getenv('WORLD_SIZE', '1'))
     rank = int(os.getenv('RANK', '0'))
@@ -74,7 +82,7 @@ def test_check_param_hashes_across_dp_replicas():
     if rank == 0:
         model.weight.data.fill_(0.0)
     param_hashes_match = util.check_param_hashes_across_dp_replicas([model])
-    expected_param_hashes_match = (rank == 0)
+    expected_param_hashes_match = rank == 0
     assert param_hashes_match == expected_param_hashes_match
 
     # Teardown.
@@ -117,7 +125,7 @@ def straggler_detector_timeit():
         # GEMM.
         with stimer:
             res = torch.matmul(mat1, mat2)
-        delta, batch_delta, _, _, _, _, = stimer.elapsed()
+        delta, batch_delta, _, _, _, _ = stimer.elapsed()
         assert delta > 0.0
         assert batch_delta >= s
 
diff --git a/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py
index 38eb9aa15e..68b12b36f5 100644
--- a/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py
+++ b/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py
@@ -7,6 +7,7 @@
 from tests.unit_tests.test_utilities import Utils
 from tests.unit_tests.transformer.moe.test_token_dispatcher import MoEModelTestContainer
 
+
 class TestAlltoAllDispatcher:
     def setup_method(self, method):
         pass
@@ -16,12 +17,7 @@ def teardown_method(self, method):
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @pytest.mark.timeout(120)
-    @pytest.mark.parametrize("tp_size,ep_size", [
-        (1, 8),
-        (8, 1),
-        (4, 2),
-        (1, 1),
-    ])
+    @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2), (1, 1)])
     def test_forward_backward(self, tp_size, ep_size):
         container = MoEModelTestContainer(
             tp_size=tp_size,
@@ -36,12 +32,7 @@ def test_forward_backward(self, tp_size, ep_size):
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @pytest.mark.timeout(120)
-    @pytest.mark.parametrize("tp_size,ep_size", [
-        (1, 8),
-        (8, 1),
-        (4, 2),
-        (1, 1),
-    ])
+    @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2), (1, 1)])
     def test_capacity_forward_backward(self, tp_size, ep_size):
         container = MoEModelTestContainer(
             tp_size=tp_size,
@@ -59,14 +50,10 @@ def test_capacity_forward_backward(self, tp_size, ep_size):
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @pytest.mark.timeout(120)
-    @pytest.mark.parametrize("tp_size,ep_size", [
-        (1, 8),
-        (8, 1),
-        (4, 2),
-        (1, 1)
-    ])
+    @pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2), (1, 1)])
     def test_capacity_padding_forward_backward(self, tp_size, ep_size):
         import time
+
         time.sleep(5)
         container = MoEModelTestContainer(
             tp_size=tp_size,
@@ -81,4 +68,3 @@ def test_capacity_padding_forward_backward(self, tp_size, ep_size):
             moe_pad_expert_input_to_capacity=True,
         )
         container.dispatcher_drop_and_pad_test()
-
diff --git a/tests/unit_tests/transformer/moe/test_aux_loss.py b/tests/unit_tests/transformer/moe/test_aux_loss.py
index 217a0a2711..2e26f01551 100644
--- a/tests/unit_tests/transformer/moe/test_aux_loss.py
+++ b/tests/unit_tests/transformer/moe/test_aux_loss.py
@@ -2,15 +2,18 @@
 
 import pytest
 import torch
-from megatron.core.transformer.moe.moe_utils import clear_aux_losses_tracker
 
+from megatron.core import parallel_state
+from megatron.core.transformer.moe.moe_utils import clear_aux_losses_tracker
 from tests.unit_tests.test_utilities import Utils
 from tests.unit_tests.transformer.moe.test_token_dispatcher import MoEModelTestContainer
-from megatron.core import parallel_state
+
 
 class AuxlossTestContainer(MoEModelTestContainer):
     def partition_input(self, input):
-        partitioned_input = input.chunk(parallel_state.get_tensor_and_context_parallel_world_size(), dim=1)[parallel_state.get_tensor_and_context_parallel_rank()]
+        partitioned_input = input.chunk(
+            parallel_state.get_tensor_and_context_parallel_world_size(), dim=1
+        )[parallel_state.get_tensor_and_context_parallel_rank()]
         output = partitioned_input.clone().detach()
         output.requires_grad = True
         return output
@@ -27,6 +30,7 @@ def aux_loss_test(self, input, baseline_grad):
         loss = parallel_state.get_moe_layer_wise_logging_tracker()['load_balancing_loss']
         clear_aux_losses_tracker()
 
+
 class TestAuxLoss:
     def setup_method(self, method):
         baseline_container = AuxlossTestContainer(
@@ -44,7 +48,7 @@ def setup_method(self, method):
         self.input = torch.randn((32, 8, moe_layer.config.hidden_size)).cuda()
         self.input.requires_grad = True
         probs, indices = moe_layer.router(self.input)
-        probs.sum().mul_(0).backward() # zero out the main gradients
+        probs.sum().mul_(0).backward()  # zero out the main gradients
         self.baseline_grad = self.input.grad
         self.input.grad = None
         clear_aux_losses_tracker()
@@ -53,13 +57,9 @@ def teardown_method(self, method):
         Utils.destroy_model_parallel()
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    @pytest.mark.parametrize("tp_size,ep_size,cp_size", [
-        (8, 1, 1),
-        (4, 2, 1),
-        (1, 1, 8),
-        (2, 1, 4),
-        (2, 2, 2),
-    ])
+    @pytest.mark.parametrize(
+        "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)]
+    )
     def test_allgather_dispatcher(self, tp_size, ep_size, cp_size):
         container = AuxlossTestContainer(
             tp_size=tp_size,
@@ -75,13 +75,9 @@ def test_allgather_dispatcher(self, tp_size, ep_size, cp_size):
         container.aux_loss_test(self.input, self.baseline_grad)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    @pytest.mark.parametrize("tp_size,ep_size,cp_size", [
-        (8, 1, 1),
-        (4, 2, 1),
-        (1, 1, 8),
-        (2, 1, 4),
-        (2, 2, 2),
-    ])
+    @pytest.mark.parametrize(
+        "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)]
+    )
     def test_a2a_dispatcher(self, tp_size, ep_size, cp_size):
         container = AuxlossTestContainer(
             tp_size=tp_size,
diff --git a/tests/unit_tests/transformer/moe/test_grouped_mlp.py b/tests/unit_tests/transformer/moe/test_grouped_mlp.py
index b86edde68d..757be59232 100644
--- a/tests/unit_tests/transformer/moe/test_grouped_mlp.py
+++ b/tests/unit_tests/transformer/moe/test_grouped_mlp.py
@@ -1,20 +1,20 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
-import pytest
-from pkg_resources import packaging
 from importlib.metadata import version
 
+import pytest
 import torch
 import torch.nn.functional as F
+from pkg_resources import packaging
 
-from megatron.training.arguments import parse_args
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 from megatron.core.transformer.moe import grouped_gemm_util as gg
-from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.moe.experts import TEGroupedMLP
+from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.training.initialize import _set_random_seed
 from megatron.legacy.model import Float16Module
+from megatron.training.arguments import parse_args
+from megatron.training.initialize import _set_random_seed
 from tests.unit_tests.test_utilities import Utils
 
 DEVICE_CAPABILITY = None
@@ -28,23 +28,37 @@ class TestParallelGroupedMLP:
 
     def setup_method(self, method, use_cpu_initialization=False, swiglu=True):
         print("============")
-        print("Test for use_cpu_initilization={} and swiglu={}.".format(use_cpu_initialization, swiglu))
+        print(
+            "Test for use_cpu_initilization={} and swiglu={}.".format(
+                use_cpu_initialization, swiglu
+            )
+        )
         print("============")
-        Utils.initialize_model_parallel(1,1)
-        num_layers = 1 # 2
-        self.hidden_size = 16 # must be an multiple of 16, otherwise trigger CUTLASS misaligned issue
+        Utils.initialize_model_parallel(1, 1)
+        num_layers = 1  # 2
+        self.hidden_size = (
+            16  # must be a multiple of 16, otherwise it triggers a CUTLASS misaligned issue
+        )
         self.num_experts = 2
         self.gated_linear_unit = swiglu
         self.activation_func = F.silu if swiglu else F.gelu
         self.use_cpu_initialization = use_cpu_initialization
 
         tf_config = TransformerConfig(
-            num_layers=num_layers, hidden_size=self.hidden_size, num_attention_heads=4,
-            num_moe_experts=self.num_experts, use_cpu_initialization=self.use_cpu_initialization,
-            add_bias_linear=False, gated_linear_unit=self.gated_linear_unit,
+            num_layers=num_layers,
+            hidden_size=self.hidden_size,
+            num_attention_heads=4,
+            num_moe_experts=self.num_experts,
+            use_cpu_initialization=self.use_cpu_initialization,
+            add_bias_linear=False,
+            gated_linear_unit=self.gated_linear_unit,
             activation_func=self.activation_func,
             bias_activation_fusion=False,
-            bf16=True, params_dtype=torch.bfloat16, moe_router_load_balancing_type="sinkhorn", moe_router_topk=1)
+            bf16=True,
+            params_dtype=torch.bfloat16,
+            moe_router_load_balancing_type="sinkhorn",
+            moe_router_topk=1,
+        )
 
         self.fc1_ffn_hidden_size = tf_config.ffn_hidden_size
         self.fc2_ffn_hidden_size = tf_config.ffn_hidden_size
@@ -56,15 +70,15 @@ def setup_method(self, method, use_cpu_initialization=False, swiglu=True):
         # Set random seed for reproducibility
         _set_random_seed(seed_=123, data_parallel_random_init=False)
         transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
-            self.num_experts, moe_grouped_gemm=False)
-        self.sequential_mlp = MoELayer(tf_config,
-            transformer_layer_spec.submodules.mlp.submodules)
+            self.num_experts, moe_grouped_gemm=False
+        )
+        self.sequential_mlp = MoELayer(tf_config, transformer_layer_spec.submodules.mlp.submodules)
 
         self.args = parse_args(ignore_unknown_args=True)
-        self.args.bf16=True
+        self.args.bf16 = True
         # Bias is not supported in grouped gemm currently, thus we disable the
         # bias in the linear layer.
-        self.args.add_bias_linear=False
+        self.args.add_bias_linear = False
         self.sequential_mlp = Float16Module(self.sequential_mlp, self.args).module
         print("done intializing for sequential gemm")
 
@@ -89,9 +103,12 @@ def test_constructor(self):
         # GroupedGEMM and sequential GEMMs should hold the same number of params.
         assert num_weights_smm == num_weights_gmm
         # expected num weights: router linear weights+bias + MLP weights(no bias) of all experts
-        expected_num_weights = \
-            self.hidden_size * self.num_experts + \
-            self.hidden_size * (self.fc1_ffn_hidden_size + self.fc2_ffn_hidden_size) * self.num_experts
+        expected_num_weights = (
+            self.hidden_size * self.num_experts
+            + self.hidden_size
+            * (self.fc1_ffn_hidden_size + self.fc2_ffn_hidden_size)
+            * self.num_experts
+        )
         assert num_weights_smm == expected_num_weights
 
         assert torch.equal(self.sequential_mlp.router.weight, self.grouped_mlp.router.weight)
@@ -99,12 +116,19 @@ def test_constructor(self):
         # weight1: [h, num_experts*4h]
         # weight2: [num_experts*4h, h]
         assert self.grouped_mlp.experts.weight1.shape[0] == self.hidden_size
-        assert self.grouped_mlp.experts.weight1.shape[1] == self.num_experts * self.fc1_ffn_hidden_size
+        assert (
+            self.grouped_mlp.experts.weight1.shape[1] == self.num_experts * self.fc1_ffn_hidden_size
+        )
         if self.gated_linear_unit:
-            assert self.grouped_mlp.experts.weight2.shape[0] == self.num_experts * self.fc2_ffn_hidden_size
+            assert (
+                self.grouped_mlp.experts.weight2.shape[0]
+                == self.num_experts * self.fc2_ffn_hidden_size
+            )
             assert self.grouped_mlp.experts.weight2.shape[1] == self.hidden_size
         else:
-            assert self.grouped_mlp.experts.weight1.shape == self.grouped_mlp.experts.weight2.t().shape
+            assert (
+                self.grouped_mlp.experts.weight1.shape == self.grouped_mlp.experts.weight2.t().shape
+            )
 
     def test_weight_init_value_the_same(self):
         gmm_w1 = self.grouped_mlp.experts.weight1.view(self.num_experts, -1, self.hidden_size)
@@ -130,17 +154,18 @@ def test_weight_init_value_the_same(self):
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @pytest.mark.skipif(
-        not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, reason='GroupedGEMM kernels are not supported on this device.'
+        not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8,
+        reason='GroupedGEMM kernels are not supported on this device.',
     )
     def test_gpu_forward(self):
         self.sequential_mlp.cuda()
         self.grouped_mlp.cuda()
         # [sequence length, batch size, hidden size]
-        seq_len = 3 #32
+        seq_len = 3  # 32
         batch_size = 2
         hidden_states = torch.rand(
-            (seq_len, batch_size, self.sequential_mlp.config.hidden_size),
-            dtype=torch.bfloat16)
+            (seq_len, batch_size, self.sequential_mlp.config.hidden_size), dtype=torch.bfloat16
+        )
         hidden_states = hidden_states.cuda()
         output_smm, _ = self.sequential_mlp(hidden_states)
         output_gmm, _ = self.grouped_mlp(hidden_states)
@@ -151,7 +176,8 @@ def test_gpu_forward(self):
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @pytest.mark.skipif(
-        not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, reason='GroupedGEMM kernels are not supported on this device.'
+        not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8,
+        reason='GroupedGEMM kernels are not supported on this device.',
     )
     def test_gpu_forward_with_no_tokens_allocated(self):
         """Test the case when no token is allocated for groupedGEMM kernels."""
@@ -168,7 +194,8 @@ def test_gpu_forward_with_no_tokens_allocated(self):
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @pytest.mark.skipif(
-        not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, reason='GroupedGEMM kernels are not supported on this device.'
+        not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8,
+        reason='GroupedGEMM kernels are not supported on this device.',
     )
     def test_gradient_with_no_tokens_allocated(self):
         """Test that when no token is passed in, the parameters of the grouped MLP will also have gradients."""
@@ -177,10 +204,7 @@ def test_gradient_with_no_tokens_allocated(self):
         tokens_per_expert = torch.zeros(self.num_experts)
         hidden_states = torch.rand((num_allocated_tokens, self.hidden_size), dtype=torch.bfloat16)
         hidden_states = hidden_states.cuda()
-        output_gmm, _ = self.grouped_mlp.experts(
-            hidden_states,
-            tokens_per_expert=tokens_per_expert,
-        )
+        output_gmm, _ = self.grouped_mlp.experts(hidden_states, tokens_per_expert=tokens_per_expert)
         output_gmm.mean().backward()
         assert self.grouped_mlp.experts.weight1.grad is not None
 
@@ -193,7 +217,7 @@ class TestTEGroupedMLP:
 
     def setup_method(self, method, use_cpu_initialization=False, swiglu=True):
         Utils.initialize_model_parallel(1, 1)
-        num_layers = 1 
+        num_layers = 1
         self.hidden_size = 16
         self.num_experts = 2
         self.gated_linear_unit = swiglu
@@ -348,9 +372,8 @@ def test_gpu_forward_backward_with_no_tokens_allocated(self):
         for swiglu in [True, False]:
             GMLP_test = TestParallelGroupedMLP()
             GMLP_test.setup_method(
-                method=None,
-                use_cpu_initialization=use_cpu_unitilization,
-                swiglu=swiglu)
+                method=None, use_cpu_initialization=use_cpu_unitilization, swiglu=swiglu
+            )
             GMLP_test.test_constructor()
             GMLP_test.test_weight_init_value_the_same()
             GMLP_test.test_gpu_forward()
diff --git a/tests/unit_tests/transformer/moe/test_routers.py b/tests/unit_tests/transformer/moe/test_routers.py
index fbeb744f1e..ef4c9d4aed 100644
--- a/tests/unit_tests/transformer/moe/test_routers.py
+++ b/tests/unit_tests/transformer/moe/test_routers.py
@@ -1,15 +1,14 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 import pytest
-
 import torch
 
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
+from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.moe.router import Router
+from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.training.initialize import _set_random_seed
 from tests.unit_tests.test_utilities import Utils
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.transformer.moe.moe_layer import MoELayer
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 
 
 class TestTop2Router:
@@ -46,10 +45,7 @@ def test_constructor(self):
         assert num_weights == 12 * 4, num_weights
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    @pytest.mark.parametrize("moe_router_pre_softmax", [
-        (True),
-        (False),
-    ])
+    @pytest.mark.parametrize("moe_router_pre_softmax", [(True), (False)])
     def test_router_forward(self, moe_router_pre_softmax):
         with torch.no_grad():
             self.router = self.router.cuda()
@@ -62,30 +58,33 @@ def test_router_forward(self, moe_router_pre_softmax):
             assert scores.shape == (64, 2)
             assert indices.shape == (64, 2)
             print(
-                (indices == 0).sum(), (indices == 1).sum(), (indices == 2).sum(), (indices == 3).sum()
+                (indices == 0).sum(),
+                (indices == 1).sum(),
+                (indices == 2).sum(),
+                (indices == 3).sum(),
             )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_aux_loss(self):
         self.sequential_mlp = self.sequential_mlp.cuda()
-        
+
         # Without aux loss
         hidden_states = torch.randn((32, 2, self.router.config.hidden_size))
         hidden_states = hidden_states.cuda()
         out = self.sequential_mlp(hidden_states)[0]
         out.sum().mul_(0).backward()
         assert self.sequential_mlp.router.weight.grad.abs().sum() == 0
-        
+
         # With aux loss
         self.transformer_config.moe_aux_loss_coeff = 1
         out = self.sequential_mlp(hidden_states)[0]
         out.sum().mul_(0).backward()
         assert self.sequential_mlp.router.weight.grad.abs().sum() > 0
-        
+
         # With Z loss
         self.transformer_config.moe_aux_loss_coeff = 0
         self.transformer_config.moe_z_loss_coeff = 1
         self.sequential_mlp.router.weight.grad.fill_(0)
         out = self.sequential_mlp(hidden_states)[0]
         out.sum().mul_(0).backward()
-        assert self.sequential_mlp.router.weight.grad.abs().sum() > 0
\ No newline at end of file
+        assert self.sequential_mlp.router.weight.grad.abs().sum() > 0
diff --git a/tests/unit_tests/transformer/moe/test_sequential_mlp.py b/tests/unit_tests/transformer/moe/test_sequential_mlp.py
index 0ebb85333e..21fcc23ca2 100644
--- a/tests/unit_tests/transformer/moe/test_sequential_mlp.py
+++ b/tests/unit_tests/transformer/moe/test_sequential_mlp.py
@@ -1,19 +1,19 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 import pytest
-
 import torch
 
-from megatron.core.transformer.moe.moe_layer import MoELayer
-from tests.unit_tests.test_utilities import Utils
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
+from tests.unit_tests.test_utilities import Utils
+
 
 class TestParallelSequentialMLP:
 
     def setup_method(self, method):
-        Utils.initialize_model_parallel(1,1)
+        Utils.initialize_model_parallel(1, 1)
         model_parallel_cuda_manual_seed(123)
         print("done intializing")
         num_moe_experts = 2
@@ -27,11 +27,14 @@ def setup_method(self, method):
             gated_linear_unit=True,
             bias_activation_fusion=True,
             moe_router_load_balancing_type="sinkhorn",
-            moe_router_topk=1
+            moe_router_topk=1,
         )
         transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
-            num_experts=num_moe_experts, moe_grouped_gemm=False)
-        self.sequential_mlp = MoELayer(transformer_config, transformer_layer_spec.submodules.mlp.submodules)
+            num_experts=num_moe_experts, moe_grouped_gemm=False
+        )
+        self.sequential_mlp = MoELayer(
+            transformer_config, transformer_layer_spec.submodules.mlp.submodules
+        )
 
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
@@ -42,7 +45,6 @@ def test_constructor(self):
         num_weights = sum([p.numel() for p in self.sequential_mlp.parameters()])
         assert num_weights == 3696
 
-
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_gpu_forward(self):
         sequential_mlp = self.sequential_mlp
@@ -58,4 +60,3 @@ def test_gpu_forward(self):
         assert output.dtype == torch.float32
         assert output.device.type == 'cuda'
         assert output_bias.device.type == 'cuda'
-
diff --git a/tests/unit_tests/transformer/moe/test_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_token_dispatcher.py
index f5384143ce..f2c6d3c307 100644
--- a/tests/unit_tests/transformer/moe/test_token_dispatcher.py
+++ b/tests/unit_tests/transformer/moe/test_token_dispatcher.py
@@ -2,8 +2,8 @@
 
 import pytest
 import torch
-from megatron.core import parallel_state
 
+from megatron.core import parallel_state
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.moe.moe_utils import permute, unpermute
@@ -34,7 +34,7 @@ def __init__(
             tensor_model_parallel_size=tp_size,
             pipeline_model_parallel_size=pp_size,
             expert_model_parallel_size=ep_size,
-            context_parallel_size=cp_size
+            context_parallel_size=cp_size,
         )
         _set_random_seed(seed_=123, data_parallel_random_init=data_parallel_random_init)
         local_expert_indices_offset = (
@@ -74,7 +74,7 @@ def __init__(
             self.config, transformer_layer_spec.submodules.mlp.submodules
         ).cuda()
         self.moe_layer.set_layer_number(0)
-    
+
     def __del__(self):
         torch.distributed.barrier()
         torch.cuda.synchronize()
@@ -96,11 +96,8 @@ def dispatcher_dropless_test(self):
         # indices = torch.ones_like(indices) * torch.distributed.get_rank()
         # print(permuted_local_hidden_states)
 
-        (
-            permuted_local_hidden_states,
-            tokens_per_expert,
-        ) = moe_layer.token_dispatcher.token_permutation(
-            hidden_states, probs, indices
+        (permuted_local_hidden_states, tokens_per_expert) = (
+            moe_layer.token_dispatcher.token_permutation(hidden_states, probs, indices)
         )
 
         permuted_local_hidden_states /= moe_layer.config.tensor_model_parallel_size
@@ -136,11 +133,8 @@ def dispacher_capacity_test(self):
         ]
         restored_hidden_states_answer = hidden_states * local_probss.sum(dim=1).unsqueeze(1)
 
-        (
-            permuted_local_hidden_states,
-            tokens_per_expert,
-        ) = moe_layer.token_dispatcher.token_permutation(
-            hidden_states, probs, indices
+        (permuted_local_hidden_states, tokens_per_expert) = (
+            moe_layer.token_dispatcher.token_permutation(hidden_states, probs, indices)
         )
 
         print(f"Dispatched tokens per expert: {tokens_per_expert}")
@@ -181,7 +175,7 @@ def dispatcher_drop_and_pad_test(self):
         # num_local_tokens_per_expert = torch.tensor([2, 2, 2, 2, 2, 2, 2, 2]).cuda()
 
         probs_1, indices_1 = moe_layer.router(hidden_states)
-        (permuted_input_1, tokens_per_expert,) = moe_layer.token_dispatcher.token_permutation(
+        (permuted_input_1, tokens_per_expert) = moe_layer.token_dispatcher.token_permutation(
             hidden_states, probs_1, indices_1
         )
         torch.distributed.barrier()
@@ -197,7 +191,7 @@ def dispatcher_drop_and_pad_test(self):
         # End
 
         probs_2, indices_2 = moe_layer.router(hidden_states)
-        (permuted_input_2, tokens_per_expert,) = moe_layer.token_dispatcher.token_permutation(
+        (permuted_input_2, tokens_per_expert) = moe_layer.token_dispatcher.token_permutation(
             hidden_states, probs_2, indices_2
         )
         restored_hidden_states, restored_bias = moe_layer.token_dispatcher.token_unpermutation(
@@ -230,9 +224,7 @@ def teardown_method(self, method):
         Utils.destroy_model_parallel()
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    @pytest.mark.parametrize("tp_size,ep_size", [
-        (8, 1),
-    ])
+    @pytest.mark.parametrize("tp_size,ep_size", [(8, 1)])
     def test_forward_backward(self, tp_size, ep_size):
         container = MoEModelTestContainer(
             tp_size=tp_size,
@@ -269,13 +261,15 @@ def test_extended_tp_forward_backward(self):
         assert scores.shape == (256, moe_layer.router.topk), "Scores shape is not correct"
         assert indices.shape == (256, moe_layer.router.topk), "Indices shape is not correct"
         scores = torch.ones_like(scores) / 2
-        (
-            permuted_local_hidden_states,
-            tokens_per_expert,
-        ) = moe_layer.token_dispatcher.token_permutation(hidden_states, scores, indices)
-        permuted_local_hidden_states /= moe_layer.config.tensor_model_parallel_size * moe_layer.config.expert_model_parallel_size
+        (permuted_local_hidden_states, tokens_per_expert) = (
+            moe_layer.token_dispatcher.token_permutation(hidden_states, scores, indices)
+        )
+        permuted_local_hidden_states /= (
+            moe_layer.config.tensor_model_parallel_size
+            * moe_layer.config.expert_model_parallel_size
+        )
         restored_hidden_states, restored_bias = moe_layer.token_dispatcher.token_unpermutation(
-            permuted_local_hidden_states, bias=torch.zeros_like(permuted_local_hidden_states),
+            permuted_local_hidden_states, bias=torch.zeros_like(permuted_local_hidden_states)
         )
 
         assert torch.allclose(
diff --git a/tests/unit_tests/transformer/test_attention.py b/tests/unit_tests/transformer/test_attention.py
index 4a5680ea05..8c13ff3f8c 100644
--- a/tests/unit_tests/transformer/test_attention.py
+++ b/tests/unit_tests/transformer/test_attention.py
@@ -1,25 +1,28 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 import pytest
-
 import torch
 
-from megatron.core.transformer.attention import SelfAttention
-from tests.unit_tests.test_utilities import Utils
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+from megatron.core.transformer.attention import SelfAttention
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
+from tests.unit_tests.test_utilities import Utils
+
 
 class TestParallelAttention:
 
     def setup_method(self, method):
-        Utils.initialize_model_parallel(1,1)
+        Utils.initialize_model_parallel(1, 1)
         model_parallel_cuda_manual_seed(123)
-        self.transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
-        self.parallel_attention = SelfAttention(self.transformer_config,
-                                                get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules,
-                                                layer_number=1)
-
+        self.transformer_config = TransformerConfig(
+            num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
+        )
+        self.parallel_attention = SelfAttention(
+            self.transformer_config,
+            get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules,
+            layer_number=1,
+        )
 
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
@@ -44,7 +47,9 @@ def test_gpu_forward(self):
         self.parallel_attention.cuda()
 
         # [sequence length, batch size, hidden size]
-        hidden_states = torch.ones((sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size))
+        hidden_states = torch.ones(
+            (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size)
+        )
         hidden_states = hidden_states.cuda()
 
         attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
@@ -66,12 +71,18 @@ def test_fused_rope_gpu_forward(self):
         self.parallel_attention.cuda()
 
         # [sequence length, batch size, hidden size]
-        hidden_states = torch.ones((sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size))
+        hidden_states = torch.ones(
+            (sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size)
+        )
         hidden_states = hidden_states.cuda()
 
         attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
-        rotary_pos_emb = torch.ones(sequence_length, 1, 1, self.parallel_attention.config.kv_channels).cuda()
-        output, bias = self.parallel_attention(hidden_states, attention_mask, rotary_pos_emb=rotary_pos_emb)
+        rotary_pos_emb = torch.ones(
+            sequence_length, 1, 1, self.parallel_attention.config.kv_channels
+        ).cuda()
+        output, bias = self.parallel_attention(
+            hidden_states, attention_mask, rotary_pos_emb=rotary_pos_emb
+        )
 
         assert config.recompute_granularity is None
         assert output.shape[0] == sequence_length
@@ -80,13 +91,14 @@ def test_fused_rope_gpu_forward(self):
         assert bias.shape[0] == config.hidden_size
         self.parallel_attention.config.apply_rope_fusion = False
 
-
     def test_checkpointed_gpu_forward(self):
         transformer_config = self.transformer_config
-        transformer_config.recompute_granularity='selective'
-        checkpointed_parallel_attention = SelfAttention(transformer_config,
-                                                        get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules,
-                                                        layer_number=1)
+        transformer_config.recompute_granularity = 'selective'
+        checkpointed_parallel_attention = SelfAttention(
+            transformer_config,
+            get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules,
+            layer_number=1,
+        )
         config = checkpointed_parallel_attention.config
 
         sequence_length = 32
diff --git a/tests/unit_tests/transformer/test_attention_packed_seq.py b/tests/unit_tests/transformer/test_attention_packed_seq.py
index c8be7dba3d..54c8787579 100644
--- a/tests/unit_tests/transformer/test_attention_packed_seq.py
+++ b/tests/unit_tests/transformer/test_attention_packed_seq.py
@@ -1,16 +1,15 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 import pytest
-
 import torch
 
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.attention import SelfAttention
 from megatron.core.transformer.enums import AttnMaskType
-from tests.unit_tests.test_utilities import Utils
-from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
+from tests.unit_tests.test_utilities import Utils
 
 # Note: this test requires TE >= 0.13 as well as Flash Attention to run
 # FIXME this unit test doesn't work in the current test container. to be fixed soon
@@ -128,4 +127,4 @@ def test_checkpointed_gpu_forward(self):
         assert output.shape[1] == micro_batch_size
         assert output.shape[2] == config.hidden_size
         assert bias.shape[0] == config.hidden_size
-"""
\ No newline at end of file
+"""
diff --git a/tests/unit_tests/transformer/test_core_attention.py b/tests/unit_tests/transformer/test_core_attention.py
index 2966b98f89..d8710e2242 100644
--- a/tests/unit_tests/transformer/test_core_attention.py
+++ b/tests/unit_tests/transformer/test_core_attention.py
@@ -2,10 +2,10 @@
 
 
 import pytest
-
 import torch
 
 from megatron.core.transformer.attention import CrossAttention
+
 """ 
 
 @pytest.fixture
@@ -61,4 +61,4 @@ def test_gpu_forward(self, core_attention):
         assert context_layer.device.type == 'cuda'
         assert context_layer.dtype == torch.float32
 
-"""
\ No newline at end of file
+"""
diff --git a/tests/unit_tests/transformer/test_mlp.py b/tests/unit_tests/transformer/test_mlp.py
index 8e3f14688c..d2c25e0cc5 100644
--- a/tests/unit_tests/transformer/test_mlp.py
+++ b/tests/unit_tests/transformer/test_mlp.py
@@ -1,23 +1,24 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 import pytest
-
 import torch
 
-from megatron.core.transformer.mlp import MLP
-from tests.unit_tests.test_utilities import Utils
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+from megatron.core.transformer.mlp import MLP
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
+from tests.unit_tests.test_utilities import Utils
+
 
 class TestParallelMLP:
 
     def setup_method(self, method):
-        Utils.initialize_model_parallel(1,1)
+        Utils.initialize_model_parallel(1, 1)
         model_parallel_cuda_manual_seed(123)
-        transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
-        self.mlp = MLP(transformer_config,
-                       get_gpt_layer_local_spec().submodules.mlp.submodules)
+        transformer_config = TransformerConfig(
+            num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
+        )
+        self.mlp = MLP(transformer_config, get_gpt_layer_local_spec().submodules.mlp.submodules)
 
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
@@ -55,4 +56,3 @@ def test_gpu_forward(self):
         assert output.dtype == torch.float32
         assert output.device.type == 'cuda'
         assert output_bias.device.type == 'cuda'
-
diff --git a/tests/unit_tests/transformer/test_module.py b/tests/unit_tests/transformer/test_module.py
index b530709915..64826a0ee5 100644
--- a/tests/unit_tests/transformer/test_module.py
+++ b/tests/unit_tests/transformer/test_module.py
@@ -1,13 +1,12 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 import pytest
-
 import torch
 
+from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.module import Float16Module, MegatronModule
 from megatron.core.transformer.transformer_config import TransformerConfig
 from tests.unit_tests.test_utilities import Utils
-from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 
 DEVICE_CAPABILITY = None
 if torch.cuda.is_available():
@@ -24,16 +23,19 @@ def __init__(self, config: TransformerConfig):
     def forward(self, x):
         return self.linear(x)
 
+
 class TestMegatronModule:
 
     def setup_method(self, method):
-        Utils.initialize_model_parallel(1,1)
+        Utils.initialize_model_parallel(1, 1)
         model_parallel_cuda_manual_seed(123)
-        transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
+        transformer_config = TransformerConfig(
+            num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
+        )
         self.megatron_module = DummyModule(config=transformer_config).cuda()
 
     def teardown_method(self, method):
-        Utils.destroy_model_parallel()   
+        Utils.destroy_model_parallel()
 
     def test_megatron_module(self):
         megatron_module = self.megatron_module
@@ -54,14 +56,16 @@ def test_megatron_module(self):
 class TestFloat16Module:
 
     def setup_method(self, method):
-        Utils.initialize_model_parallel(1,1)
+        Utils.initialize_model_parallel(1, 1)
         model_parallel_cuda_manual_seed(123)
-        self.transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
+        self.transformer_config = TransformerConfig(
+            num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
+        )
         self.megatron_module = DummyModule(config=self.transformer_config).cuda()
 
     def teardown_method(self, method):
-        Utils.destroy_model_parallel()   
-        
+        Utils.destroy_model_parallel()
+
     def test_fp16_module(self):
         transformer_config = self.transformer_config
         megatron_module = self.megatron_module
@@ -78,7 +82,8 @@ def test_fp16_module(self):
         assert fp16_module(x).dtype == torch.float32
 
     pytest.mark.skipif(
-        not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, reason='bfloat16 is not supported on this device'
+        not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8,
+        reason='bfloat16 is not supported on this device',
     )
 
     def test_bf16_module(self):
@@ -95,4 +100,3 @@ def test_bf16_module(self):
         x = torch.ones((2, 2)).cuda()
         # inputs are converted to bf16 then outputs are converted to fp32
         assert bf16_module(x).dtype == torch.float32
-
diff --git a/tests/unit_tests/transformer/test_retro_attention.py b/tests/unit_tests/transformer/test_retro_attention.py
index 11ec7d5faa..d7c5a5f155 100644
--- a/tests/unit_tests/transformer/test_retro_attention.py
+++ b/tests/unit_tests/transformer/test_retro_attention.py
@@ -1,16 +1,17 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
-import torch
 import types
 
+import torch
+
 from megatron.core.models.retro import RetroConfig, get_retro_decoder_block_spec
 from megatron.core.models.retro.decoder_attention import (
-    RetroDecoderCrossAttention,
     RetroDecoderBiasDropoutAdd,
+    RetroDecoderCrossAttention,
 )
 from megatron.core.models.retro.encoder_attention import (
-    RetroEncoderCrossAttention,
     RetroEncoderBiasDropoutAdd,
+    RetroEncoderCrossAttention,
     RetroEncoderLayerNorm,
 )
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
@@ -38,33 +39,42 @@ def get_modules(cls, config, use_transformer_engine, use_gpu):
 
         # Retro decoder layer.
         decoder_block_spec = get_retro_decoder_block_spec(
-            config, use_transformer_engine=use_transformer_engine)
+            config, use_transformer_engine=use_transformer_engine
+        )
         decoder_block = TransformerBlock(config=config, spec=decoder_block_spec)
-        decoder_layers = [ layer for layer in decoder_block.layers if isinstance(layer.cross_attention, RetroDecoderCrossAttention) ]
+        decoder_layers = [
+            layer
+            for layer in decoder_block.layers
+            if isinstance(layer.cross_attention, RetroDecoderCrossAttention)
+        ]
         decoder_layer = decoder_layers[0]
 
         # Retro encoder layer.
         encoder_block = decoder_layer.cross_attention.encoder
-        encoder_layers = [ layer for layer in encoder_block.layers if isinstance(layer.cross_attention, RetroEncoderCrossAttention) ]
+        encoder_layers = [
+            layer
+            for layer in encoder_block.layers
+            if isinstance(layer.cross_attention, RetroEncoderCrossAttention)
+        ]
         encoder_layer = encoder_layers[0]
 
         # Modules.
         modules = types.SimpleNamespace(
-            decoder_attn = decoder_layer.cross_attention,
-            decoder_bda = decoder_layer.cross_attn_bda,
-            encoder_attn = encoder_layer.cross_attention,
-            encoder_bda = encoder_layer.cross_attn_bda,
-            encoder_norm = encoder_layer.pre_mlp_layernorm,
+            decoder_attn=decoder_layer.cross_attention,
+            decoder_bda=decoder_layer.cross_attn_bda,
+            encoder_attn=encoder_layer.cross_attention,
+            encoder_bda=encoder_layer.cross_attn_bda,
+            encoder_norm=encoder_layer.pre_mlp_layernorm,
         )
 
         # GPU.
         if use_gpu:
-            [ m.cuda() for m in vars(modules).values() ]
+            [m.cuda() for m in vars(modules).values()]
 
         return modules
 
     def setup_method(self, method):
-        Utils.initialize_model_parallel(1,1)
+        Utils.initialize_model_parallel(1, 1)
         model_parallel_cuda_manual_seed(123)
 
     def teardown_method(self, method):
@@ -73,11 +83,7 @@ def teardown_method(self, method):
     def test_constructor(self):
 
         config = self.get_config()
-        modules = self.get_modules(
-            config,
-            use_transformer_engine=True,
-            use_gpu=False,
-        )
+        modules = self.get_modules(config, use_transformer_engine=True, use_gpu=False)
 
         assert isinstance(modules.decoder_attn, RetroDecoderCrossAttention)
         assert isinstance(modules.decoder_bda, RetroDecoderBiasDropoutAdd)
@@ -88,7 +94,7 @@ def test_constructor(self):
         assert modules.decoder_attn.attn.layer_number == 6
         assert modules.encoder_attn.attn.layer_number == 1
 
-        get_nparams = lambda m : sum(p.numel() for p in m.parameters())
+        get_nparams = lambda m: sum(p.numel() for p in m.parameters())
         assert get_nparams(modules.decoder_attn) == 8768
         assert get_nparams(modules.decoder_bda) == 0
         assert get_nparams(modules.encoder_attn) == 1088
@@ -110,52 +116,38 @@ def run_gpu_forward(self, recompute_granularity, use_transformer_engine):
         n_chunks_per_sample = seq_length // config.retro_chunk_length
 
         # Init tensors.
-        hidden_states = torch.ones((
-            seq_length,
-            micro_batch_size,
-            config.hidden_size,
-        )).cuda()
+        hidden_states = torch.ones((seq_length, micro_batch_size, config.hidden_size)).cuda()
         attention_mask = None
-        decoder_context = torch.ones((
-            config.retro_retrieved_length,
-            config.retro_num_neighbors * micro_batch_size * n_chunks_per_sample,
-            config.hidden_size,
-        )).cuda()
-        encoder_context = torch.ones((
-            config.retro_chunk_length,
-            micro_batch_size * n_chunks_per_sample,
-            config.hidden_size,
-        )).cuda()
+        decoder_context = torch.ones(
+            (
+                config.retro_retrieved_length,
+                config.retro_num_neighbors * micro_batch_size * n_chunks_per_sample,
+                config.hidden_size,
+            )
+        ).cuda()
+        encoder_context = torch.ones(
+            (config.retro_chunk_length, micro_batch_size * n_chunks_per_sample, config.hidden_size)
+        ).cuda()
 
         # Forward decoder.
-        decoder_attn_output = modules.decoder_attn(
-            hidden_states,
-            attention_mask,
-            decoder_context,
-        )
+        decoder_attn_output = modules.decoder_attn(hidden_states, attention_mask, decoder_context)
         with torch.enable_grad():
             decoder_bda_output = modules.decoder_bda(True, True)(
-                decoder_attn_output,
-                hidden_states,
-                config.hidden_dropout,
+                decoder_attn_output, hidden_states, config.hidden_dropout
             )
 
         # Forward encoder.
-        encoder_attn_output_tuples = modules.encoder_attn(
-            decoder_context,
-            None,
-            encoder_context,
-        )
+        encoder_attn_output_tuples = modules.encoder_attn(decoder_context, None, encoder_context)
         with torch.enable_grad():
             encoder_bda_output = modules.encoder_bda(True, True)(
-                encoder_attn_output_tuples,
-                decoder_context,
-                config.retro_encoder_hidden_dropout,
+                encoder_attn_output_tuples, decoder_context, config.retro_encoder_hidden_dropout
             )
         encoder_norm_output = modules.encoder_norm(encoder_bda_output)
 
         # Verify decoder.
-        assert set(decoder_attn_output.keys()) == set([ "ns", "bs", "d", "l", "pad", "attention_output", "attention_bias", "context"])
+        assert set(decoder_attn_output.keys()) == set(
+            ["ns", "bs", "d", "l", "pad", "attention_output", "attention_bias", "context"]
+        )
         assert decoder_attn_output["ns"] == seq_length
         assert decoder_attn_output["bs"] == micro_batch_size
         assert decoder_attn_output["d"] == config.hidden_size
@@ -166,9 +158,7 @@ def run_gpu_forward(self, recompute_granularity, use_transformer_engine):
             micro_batch_size * n_chunks_per_sample,
             config.hidden_size,
         )
-        assert tuple(decoder_attn_output["attention_bias"].shape) == (
-            config.hidden_size,
-        )
+        assert tuple(decoder_attn_output["attention_bias"].shape) == (config.hidden_size,)
         assert decoder_attn_output["context"].shape == (
             config.retro_retrieved_length * config.retro_num_neighbors,
             micro_batch_size * n_chunks_per_sample,
diff --git a/tests/unit_tests/transformer/test_spec_customization.py b/tests/unit_tests/transformer/test_spec_customization.py
index f0ee9e79af..e6b1fc04b7 100755
--- a/tests/unit_tests/transformer/test_spec_customization.py
+++ b/tests/unit_tests/transformer/test_spec_customization.py
@@ -55,7 +55,7 @@ def setup_method(self, method):
 
         # specify layernorm spec with module path to test dynamic importing
         self.layernorm_spec = ModuleSpec(
-            module=("megatron.core.transformer.custom_layers.transformer_engine", "TENorm"),
+            module=("megatron.core.transformer.custom_layers.transformer_engine", "TENorm")
         )
 
         # specify bias dropout add with module path
@@ -97,7 +97,7 @@ def test_build_module(self):
         assert x == random_input
 
         # Check SelfAttention
-        self_attention = build_module(self.attention_spec, config=self.config, layer_number=1,)
+        self_attention = build_module(self.attention_spec, config=self.config, layer_number=1)
         assert isinstance(self_attention, SelfAttention)
         assert self_attention.layer_number == 1
         assert self_attention.attn_mask_type == self.attention_spec.params['attn_mask_type']
diff --git a/tests/unit_tests/transformer/test_transformer_block.py b/tests/unit_tests/transformer/test_transformer_block.py
index 6a2227b52c..02702a9ff7 100644
--- a/tests/unit_tests/transformer/test_transformer_block.py
+++ b/tests/unit_tests/transformer/test_transformer_block.py
@@ -1,26 +1,31 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 import os
-import pytest
 
+import pytest
 import torch
+
 from megatron.core import dist_checkpointing
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+from megatron.core.transformer.transformer_block import TransformerBlock
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import TransformerLayer
-from megatron.core.transformer.transformer_block import TransformerBlock
 from tests.unit_tests.test_utilities import Utils
-from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
+
 
 class TestParallelTransformerBlock:
 
     def setup_method(self, method):
-        Utils.initialize_model_parallel(1,1)
+        Utils.initialize_model_parallel(1, 1)
         model_parallel_cuda_manual_seed(123)
-        self.transformer_config = TransformerConfig(num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True)   
-        self.parallel_transformer_block = TransformerBlock(self.transformer_config,
-                                                           get_gpt_layer_with_transformer_engine_spec())
+        self.transformer_config = TransformerConfig(
+            num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True
+        )
+        self.parallel_transformer_block = TransformerBlock(
+            self.transformer_config, get_gpt_layer_with_transformer_engine_spec()
+        )
 
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
@@ -51,7 +56,9 @@ def test_gpu_forward(self):
 
         attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
 
-        hidden_states = parallel_transformer_block(hidden_states=hidden_states, attention_mask=attention_mask)
+        hidden_states = parallel_transformer_block(
+            hidden_states=hidden_states, attention_mask=attention_mask
+        )
         assert hidden_states.shape[0] == sequence_length
         assert hidden_states.shape[1] == micro_batch_size
         assert hidden_states.shape[2] == config.hidden_size
@@ -75,8 +82,9 @@ def _run_full_checkpoint_test(self, fp8):
         config.recompute_method = 'block'
         config.fp8 = fp8
         config.recompute_num_layers = config.num_layers
-        full_transformer_block = TransformerBlock(config,
-                                                  get_gpt_layer_with_transformer_engine_spec())
+        full_transformer_block = TransformerBlock(
+            config, get_gpt_layer_with_transformer_engine_spec()
+        )
         assert full_transformer_block.config.recompute_granularity == 'full'
         assert full_transformer_block.config.recompute_method == 'block'
         assert full_transformer_block.config.fp8 == fp8
@@ -91,7 +99,9 @@ def _run_full_checkpoint_test(self, fp8):
 
         attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
 
-        hidden_states = full_transformer_block(hidden_states=hidden_states, attention_mask=attention_mask)
+        hidden_states = full_transformer_block(
+            hidden_states=hidden_states, attention_mask=attention_mask
+        )
         assert hidden_states.shape[0] == sequence_length
         assert hidden_states.shape[1] == micro_batch_size
         assert hidden_states.shape[2] == config.hidden_size
@@ -101,8 +111,9 @@ def _run_selective_checkpoint_test(self, fp8):
         config = transformer_config
         config.recompute_granularity = 'selective'
         config.fp8 = fp8
-        selective_transformer_block = TransformerBlock(config,
-                                                       get_gpt_layer_with_transformer_engine_spec())
+        selective_transformer_block = TransformerBlock(
+            config, get_gpt_layer_with_transformer_engine_spec()
+        )
         assert selective_transformer_block.config.recompute_granularity == 'selective'
         assert selective_transformer_block.checkpoint_core_attention
         assert selective_transformer_block.config.fp8 == fp8
@@ -117,7 +128,9 @@ def _run_selective_checkpoint_test(self, fp8):
 
         attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
 
-        hidden_states = selective_transformer_block(hidden_states=hidden_states, attention_mask=attention_mask)
+        hidden_states = selective_transformer_block(
+            hidden_states=hidden_states, attention_mask=attention_mask
+        )
         assert hidden_states.shape[0] == sequence_length
         assert hidden_states.shape[1] == micro_batch_size
         assert hidden_states.shape[2] == config.hidden_size
diff --git a/tests/unit_tests/transformer/test_transformer_layer.py b/tests/unit_tests/transformer/test_transformer_layer.py
index 31792dbe5c..ad8d3ea0f2 100644
--- a/tests/unit_tests/transformer/test_transformer_layer.py
+++ b/tests/unit_tests/transformer/test_transformer_layer.py
@@ -2,26 +2,28 @@
 
 
 import pytest
-
 import torch
 
 from megatron.core import parallel_state
 from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensor
-from megatron.core.transformer.transformer_layer import TransformerLayer
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
+from megatron.core.transformer.transformer_layer import TransformerLayer
 from tests.unit_tests.test_utilities import Utils
 
 
 class TestParallelTransformerLayer:
 
     def setup_method(self, method):
-        Utils.initialize_model_parallel(1,1)
+        Utils.initialize_model_parallel(1, 1)
         model_parallel_cuda_manual_seed(123)
-        transformer_config = TransformerConfig(num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True)
-        self.parallel_transformer_layer = TransformerLayer(transformer_config,
-                                                           get_gpt_layer_with_transformer_engine_spec().submodules)
+        transformer_config = TransformerConfig(
+            num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
+        )
+        self.parallel_transformer_layer = TransformerLayer(
+            transformer_config, get_gpt_layer_with_transformer_engine_spec().submodules
+        )
 
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
@@ -47,7 +49,9 @@ def test_gpu_forward(self):
 
         attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()
 
-        hidden_states, context = parallel_transformer_layer(hidden_states=hidden_states, attention_mask=attention_mask)
+        hidden_states, context = parallel_transformer_layer(
+            hidden_states=hidden_states, attention_mask=attention_mask
+        )
         assert hidden_states.shape[0] == sequence_length
         assert hidden_states.shape[1] == micro_batch_size
         assert hidden_states.shape[2] == config.hidden_size
@@ -59,14 +63,19 @@ def test_sharded_state_dict(self, tp_pp, order):
         Utils.initialize_model_parallel(*tp_pp, order=order)
 
         model_parallel_cuda_manual_seed(123)
-        transformer_config = TransformerConfig(num_layers=2, hidden_size=128, num_attention_heads=8, use_cpu_initialization=True)
-        parallel_transformer_layer = TransformerLayer(transformer_config,
-                                                      get_gpt_layer_with_transformer_engine_spec().submodules)
+        transformer_config = TransformerConfig(
+            num_layers=2, hidden_size=128, num_attention_heads=8, use_cpu_initialization=True
+        )
+        parallel_transformer_layer = TransformerLayer(
+            transformer_config, get_gpt_layer_with_transformer_engine_spec().submodules
+        )
 
         sharded_state_dict = parallel_transformer_layer.sharded_state_dict()
 
         extra_states = {k: v for k, v in sharded_state_dict.items() if k.endswith('extra_state')}
-        sharded_tensors = {k: v for k, v in sharded_state_dict.items() if not k.endswith('extra_state')}
+        sharded_tensors = {
+            k: v for k, v in sharded_state_dict.items() if not k.endswith('extra_state')
+        }
         assert all(isinstance(t, ShardedObject) for t in extra_states.values())
         assert all(isinstance(t, ShardedTensor) for t in sharded_tensors.values())
 
diff --git a/tools/autoformat.sh b/tools/autoformat.sh
index 784a7846e2..bb5473bcfa 100755
--- a/tools/autoformat.sh
+++ b/tools/autoformat.sh
@@ -3,7 +3,7 @@ set -euox pipefail
 
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 CHECK_ONLY=${CHECK_ONLY:-false}
-CHANGED_FILES=$(git diff --name-only --diff-filter=d --merge-base origin/main megatron/core | grep '\.py$' || true)
+CHANGED_FILES=$(git diff --name-only --diff-filter=d --merge-base origin/main megatron/core tests/ | grep '\.py$' || true)
 ADDITIONAL_ARGS=""
 ADDITIONAL_BLACK_ARGS=""
 
@@ -12,9 +12,8 @@ if [[ $CHECK_ONLY == true ]]; then
     ADDITIONAL_BLACK_ARGS="--diff"
 fi
 
-# for now we just format core
 if [[ -n "$CHANGED_FILES" ]]; then
-    black $ADDITIONAL_ARGS $ADDITIONAL_BLACK_ARGS --verbose $CHANGED_FILES
+    black --skip-magic-trailing-comma $ADDITIONAL_ARGS $ADDITIONAL_BLACK_ARGS --verbose $CHANGED_FILES
     isort $ADDITIONAL_ARGS $CHANGED_FILES
 else
     echo Changeset is empty, all good.