
Commit

WIP
patrickvonplaten committed Oct 16, 2024
1 parent 5d4cc68 commit b952450
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/mistral_inference/transformer.py
@@ -36,6 +36,7 @@ def __init__(
         args: TransformerArgs,
         pipeline_rank: int = 0,
         num_pipeline_ranks: int = 1,
+        softmax_fp32: bool = True,
     ):
         super().__init__()
         self.args = args
@@ -46,6 +47,8 @@ def __init__(
         assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks)
         self.pipeline_rank = pipeline_rank
         self.num_pipeline_ranks = num_pipeline_ranks
+        self.softmax_fp32 = softmax_fp32
+
         # Modules specific to some ranks:
         self.tok_embeddings: Optional[nn.Embedding] = None
         self.norm: Optional[RMSNorm] = None
@@ -207,7 +210,11 @@ def forward(
         outs = self.output(h)
         if self.num_pipeline_ranks > 1:
             torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
-        return outs.float()
+
+        if self.softmax_fp32:
+            return outs.float()
+        else:
+            return outs

     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None:
         state_to_load = {}
@@ -259,6 +266,7 @@ def from_folder(
         num_pipeline_ranks: int = 1,
         device: Union[torch.device, str] = "cuda",
         dtype: Optional[torch.dtype] = None,
+        softmax_fp32: bool = True,
     ) -> "Transformer":
         with open(Path(folder) / "params.json", "r") as f:
             model_args = TransformerArgs.from_dict(json.load(f))
@@ -272,6 +280,7 @@
             model_args,
             pipeline_rank=pipeline_rank,
             num_pipeline_ranks=num_pipeline_ranks,
+            softmax_fp32=softmax_fp32,
         )

         pt_model_file = Path(folder) / "consolidated.00.pth"
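Taken together, the commit threads a new softmax_fp32 flag from Transformer.from_folder through __init__ into forward, so the final logits are cast to torch.float32 only when the flag is left at its default. A minimal usage sketch follows; it is not part of the commit, and the checkpoint path and dtype are placeholder assumptions:

    from pathlib import Path

    import torch

    from mistral_inference.transformer import Transformer

    # Hypothetical example: load a checkpoint folder containing params.json and
    # consolidated.00.pth, keeping the returned logits in the model dtype
    # instead of casting them to fp32.
    model = Transformer.from_folder(
        Path("/path/to/mistral-checkpoint"),  # placeholder path
        device="cuda",
        dtype=torch.bfloat16,
        softmax_fp32=False,  # forward() then returns outs without .float()
    )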
