Skip to content

Commit

Permalink
fix code style bug
Browse files Browse the repository at this point in the history
  • Loading branch information
杨熙 committed Jan 2, 2025
1 parent 265b845 commit 67c0e39
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ repos:
hooks:
- id: insert-license
files: \.py$
args: ["--license-filepath", "data/.license_header.txt", "--use-current-year"]
args: ["--license-filepath", "data/.license_header.txt", "--allow-past-years"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.1
hooks:
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Alibaba Group;
# Copyright (c) 2024-2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
2 changes: 1 addition & 1 deletion scripts/pyre_check.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Alibaba Group;
# Copyright (c) 2024-2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Alibaba Group;
# Copyright (c) 2024-2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down
6 changes: 4 additions & 2 deletions tzrec/acc/aot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

# skip default bound check which is not allow by aot
if "ENABLE_AOT" in os.environ:
# pyre-ignore [8]
IntNBitTableBatchedEmbeddingBagsCodegen.__init__ = functools.partialmethod(
IntNBitTableBatchedEmbeddingBagsCodegen.__init__,
bounds_check_mode=BoundsCheckMode.NONE,
Expand All @@ -45,9 +46,10 @@
del decomposition_table[aten._softmax.out]


# pyre-ignore [56]
@register_decomposition(aten._softmax)
@out_wrapper()
def _softmax(x: torch.Tensor, dim: int, half_to_float: bool):
def _softmax(x: torch.Tensor, dim: int, half_to_float: bool) -> torch.Tensor:
# eager softmax returns a contiguous tensor. Ensure that decomp also returns
# a contiguous tensor.
x = x.contiguous()
Expand All @@ -67,7 +69,7 @@ def _softmax(x: torch.Tensor, dim: int, half_to_float: bool):

def export_model_aot(
model: nn.Module, data: Dict[str, torch.Tensor], save_dir: str
) -> None:
) -> torch.export.ExportedProgram:
"""Export aot model.
Args:
Expand Down
1 change: 1 addition & 0 deletions tzrec/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def forward(
class ScriptWrapperAOT(ScriptWrapper):
"""Model inference wrapper for aot export."""

# pyre-ignore [14]
def forward(
self,
data: Dict[str, torch.Tensor],
Expand Down

0 comments on commit 67c0e39

Please sign in to comment.