diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py
index a53195f..cb782dd 100644
--- a/src/mistral_inference/transformer.py
+++ b/src/mistral_inference/transformer.py
@@ -36,6 +36,7 @@ def __init__(
         args: TransformerArgs,
         pipeline_rank: int = 0,
         num_pipeline_ranks: int = 1,
+        softmax_fp32: bool = True,
     ):
         super().__init__()
         self.args = args
@@ -46,6 +47,8 @@ def __init__(
         assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks)
         self.pipeline_rank = pipeline_rank
         self.num_pipeline_ranks = num_pipeline_ranks
+        self.softmax_fp32 = softmax_fp32
+
         # Modules specific to some ranks:
         self.tok_embeddings: Optional[nn.Embedding] = None
         self.norm: Optional[RMSNorm] = None
@@ -207,7 +210,11 @@ def forward(
         outs = self.output(h)
         if self.num_pipeline_ranks > 1:
             torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
-        return outs.float()
+
+        if self.softmax_fp32:
+            return outs.float()
+        else:
+            return outs
 
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None:
         state_to_load = {}
@@ -259,6 +266,7 @@ def from_folder(
         num_pipeline_ranks: int = 1,
         device: Union[torch.device, str] = "cuda",
         dtype: Optional[torch.dtype] = None,
+        softmax_fp32: bool = True,
     ) -> "Transformer":
         with open(Path(folder) / "params.json", "r") as f:
             model_args = TransformerArgs.from_dict(json.load(f))
@@ -272,6 +280,7 @@ def from_folder(
             model_args,
             pipeline_rank=pipeline_rank,
             num_pipeline_ranks=num_pipeline_ranks,
+            softmax_fp32=softmax_fp32,
         )
 
         pt_model_file = Path(folder) / "consolidated.00.pth"
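
A minimal usage sketch of the new flag, for review purposes only: the checkpoint path and the bf16 dtype below are illustrative and not part of the patch. `softmax_fp32=True` is the default and preserves the existing behavior, where the final logits are upcast to float32 (the numerically safer input for a downstream softmax); passing `False` returns logits in the model's compute dtype instead.

```python
import torch

from mistral_inference.transformer import Transformer

# Default (softmax_fp32=True): unchanged behavior, logits are
# returned as float32 regardless of the model's compute dtype.
model = Transformer.from_folder(
    "/path/to/model",  # illustrative path, not from the patch
    device="cuda",
    dtype=torch.bfloat16,
)

# Opting out keeps the logits in the model's compute dtype
# (bf16 here), halving the memory of the returned logits tensor,
# which matters for large vocabularies and long sequences.
model_bf16_logits = Transformer.from_folder(
    "/path/to/model",  # illustrative path, not from the patch
    device="cuda",
    dtype=torch.bfloat16,
    softmax_fp32=False,
)
```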