initial commit for aot support #79
base: master
Conversation
杨熙 seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account. You have already signed the CLA but the status is still pending? Let us recheck it.
@@ -0,0 +1,7 @@
rm -rf experiments/multi_tower_din_taobao_local/export
Remove this script from git, and add ENABLE_AOT documentation in usage/export.md.
gm = gm.cuda()

print(gm)
Is print(gm) the same as writing gm.code?
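For context, in torch.fx a GraphModule's .code attribute and its printed form overlap but are not identical. A minimal sketch, assuming a hypothetical stand-in module Toy:

```python
import torch
from torch.fx import symbolic_trace

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2 + 1

# symbolic_trace produces a GraphModule; gm.code holds the Python
# source that FX generated for forward()
gm = symbolic_trace(Toy())
print(gm.code)
```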
    dynamic_shapes[key] = {0: batch}
elif key == "batch_size":
    dynamic_shapes[key] = {}
elif data[key].dtype == torch.float32 and "__" not in key:
"__" might also be present in regular feature names, not only in sequence features.
exported_gm = torch.export.export(
    gm, args=(data,), dynamic_shapes=(dynamic_shapes,)
)
print(exported_gm)
Do not print it; the output is already written to exported_gm.code.
)
dynamic_shapes[key] = {0: tmp_val_dim}

exported_gm = torch.export.export(
Renaming exported_gm to exported_program would be better.
@@ -920,6 +927,16 @@ def export(
)
for asset in assets:
    shutil.copy(asset, os.path.join(export_dir, "model"))
elif is_aot():
Move this to line 891:
InferWrapper = ExportWrapper if is_aot() else ScriptWrapper
and use InferWrapper later.
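The suggested refactor selects the wrapper class once and reuses it. A runnable sketch with placeholder classes; the is_aot() helper here is hypothetical and keyed off the ENABLE_AOT flag mentioned earlier in the review:

```python
import os

def is_aot() -> bool:
    # hypothetical helper: reads the ENABLE_AOT switch discussed in this review
    return os.environ.get("ENABLE_AOT", "0") == "1"

class ScriptWrapper:
    """Placeholder for the TorchScript inference wrapper."""

class ExportWrapper:
    """Placeholder for the torch.export (AOT) inference wrapper."""

# choose the wrapper class once, then use InferWrapper everywhere below
InferWrapper = ExportWrapper if is_aot() else ScriptWrapper
wrapper = InferWrapper()
```

This keeps the AOT/TorchScript branching in one place instead of repeating the is_aot() check at every use site.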
@@ -1,4 +1,4 @@
# Copyright (c) 2024, Alibaba Group;
# Copyright (c) 2024-2025, Alibaba Group;
Revert the copyright change.
@@ -20,6 +20,7 @@
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.dynamo.utils import to_torch_device
revert
@@ -236,3 +236,24 @@ def forward(
    """
    batch = self.get_batch(data, device)
    return self.model.predict(batch)


class ScriptWrapperAOT(ScriptWrapper):
Maybe ExportWrapper is a better name.
class ScriptWrapperAOT(ScriptWrapper):
    """Model inference wrapper for aot export."""
Model inference wrapper for torch.export
@@ -0,0 +1,139 @@
# Copyright (c) 2024, Alibaba Group;
add an AOT test in tzrec/tests/rank_integration_test.py
exported_gm_path = os.path.join(save_dir, "debug_exported_gm.py")
with open(exported_gm_path, "w") as fout:
    fout.write(str(exported_gm))
exported_gm.code and debug_exported_gm.py are the same; both save str(exported_gm).
@@ -746,6 +750,9 @@ def _script_model(
logger.info(f"Model Outputs: {result_info}")

    export_model_trt(model, data_cuda, save_dir)
elif is_aot():
    data_cuda = batch.to_dict(sparse_dtype=torch.int64)
The data type is the same on CPU and GPU, so the same data_cuda can be used in export/trt_export/aot_export.
# pyre-ignore [14]
def forward(
    self,
    data: Dict[str, torch.Tensor],
The AOT model's predict may not support a device argument; it needs a workaround in predict, such as https://github.com/alibaba/TorchEasyRec/blob/master/tzrec/main.py#L1076
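One possible workaround, sketched below: keep device handling outside the exported graph by moving the inputs in predict() before calling the compiled model. AOTPredictor and the stand-in model are hypothetical, not the project's actual API:

```python
from typing import Dict

import torch

class AOTPredictor:
    """Hypothetical sketch: wrap an AOT-exported callable whose forward
    only accepts tensors, moving inputs to the target device in predict()."""

    def __init__(self, model, device: str = "cpu") -> None:
        self._model = model
        self._device = torch.device(device)

    def predict(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
        # the exported graph has no device parameter, so the device
        # move happens here, outside the compiled forward
        data = {k: v.to(self._device) for k, v in data.items()}
        return self._model(data)

# usage with a stand-in "compiled" model
toy = lambda data: data["x"] * 2
pred = AOTPredictor(toy, device="cpu")
out = pred.predict({"x": torch.ones(3)})
```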
It is not fully done yet; this commit only covers the export part. The compile and prediction parts are still to be done.
@@ -737,6 +739,8 @@ def _script_model(
logger.info("quantize embeddings...")
quantize_embeddings(model, dtype=torch.qint8, inplace=True)

if is_aot():
    model = model.cuda()
Is just model.cuda() correct? When I use gloo to load:
device_state_dict = state_dict_to_device(
    model.state_dict(), pg=checkpoint_pg, device=torch.device("cpu")
)
model = model.to_empty(device="cpu")
then model = model.to("cuda:0") may be incorrect when I run forward with model(data_cuda).
I have tested it; there is no problem.
No description provided.