
initial commit for aot support #79

Open · wants to merge 4 commits into master
Conversation

chengmengli06 (Collaborator):

No description provided.

@CLAassistant:

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.

杨熙 does not appear to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account.
Have you already signed the CLA but the status is still pending? Let us recheck it.

@@ -0,0 +1,7 @@
rm -rf experiments/multi_tower_din_taobao_local/export

Collaborator:

Remove this script from git, and add ENABLE_AOT documentation in usage/export.md.
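
A minimal sketch of how the gate behind that flag might look, assuming ENABLE_AOT is read from the environment (the variable name comes from this review; the actual is_aot() helper lives in the tzrec codebase):

import os

def is_aot() -> bool:
    # AOT export is opted into via the ENABLE_AOT environment variable
    return os.environ.get("ENABLE_AOT", "0") == "1"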


gm = gm.cuda()

print(gm)

Collaborator:

Is printing gm the same as writing out gm.code?
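
For reference, a minimal sketch of the two inspection paths on a torch.fx GraphModule (the Tiny module is illustrative, not the tzrec model):

import torch
import torch.fx

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

gm = torch.fx.symbolic_trace(Tiny())
print(gm.code)   # the generated forward() source as a string
print(gm.graph)  # the underlying graph IR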

dynamic_shapes[key] = {0: batch}
elif key == "batch_size":
dynamic_shapes[key] = {}
elif data[key].dtype == torch.float32 and "__" not in key:

Collaborator:

"__" might be present in regular feature names, not in sequence features only.

exported_gm = torch.export.export(
gm, args=(data,), dynamic_shapes=(dynamic_shapes,)
)
print(exported_gm)

Collaborator:

Do not print; this is already written out to exported_gm.code.

)
dynamic_shapes[key] = {0: tmp_val_dim}

exported_gm = torch.export.export(

Collaborator:

Renaming exported_gm to exported_program would be better; torch.export.export returns an ExportedProgram, not a GraphModule.
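
A minimal sketch of the export call with a dynamic batch dimension, mirroring the dynamic_shapes dict built above (the Tower module and the "clk_seq" feature name are illustrative, not the actual tzrec model):

import torch
from torch.export import Dim, export

class Tower(torch.nn.Module):
    def forward(self, data):
        return data["clk_seq"].float().mean(dim=1)

batch = Dim("batch")
data = {"clk_seq": torch.zeros(8, 50, dtype=torch.int64)}
dynamic_shapes = {"clk_seq": {0: batch}}  # dim 0 of this input is dynamic

exported_program = export(Tower(), args=(data,), dynamic_shapes=(dynamic_shapes,))
print(exported_program.graph_module.code)  # inspect the code, not the whole program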

@@ -920,6 +927,16 @@ def export(
)
for asset in assets:
shutil.copy(asset, os.path.join(export_dir, "model"))
elif is_aot():

Collaborator:

Move this to line 891:

InferWrapper = ExportWrapper if is_aot() else ScriptWrapper

and use InferWrapper later.

@@ -1,4 +1,4 @@
# Copyright (c) 2024, Alibaba Group;
# Copyright (c) 2024-2025, Alibaba Group;

Collaborator:

Revert the copyright change.

@@ -20,6 +20,7 @@
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.dynamo.utils import to_torch_device



Collaborator:

Revert this change.

@@ -236,3 +236,24 @@ def forward(
"""
batch = self.get_batch(data, device)
return self.model.predict(batch)


class ScriptWrapperAOT(ScriptWrapper):

Collaborator:

Maybe ExportWrapper would be a better name.



class ScriptWrapperAOT(ScriptWrapper):
"""Model inference wrapper for aot export."""

Collaborator:

Make the docstring "Model inference wrapper for torch.export."
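
A hedged sketch of what the renamed wrapper might look like, combining the pieces quoted in this review (ScriptWrapper, get_batch, and model.predict come from the tzrec codebase; the device handling shown is an assumption):

from typing import Dict

import torch

class ExportWrapper(ScriptWrapper):
    """Model inference wrapper for torch.export."""

    # pyre-ignore [14]
    def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # unlike ScriptWrapper.forward, no device argument is accepted;
        # the exported program is bound to the device it was exported on
        device = torch.device("cuda")
        batch = self.get_batch(data, device)
        return self.model.predict(batch)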

@@ -0,0 +1,139 @@
# Copyright (c) 2024, Alibaba Group;

Collaborator:

Add an AOT test in tzrec/tests/rank_integration_test.py.
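
A hedged sketch of such a test (the test-case name and body only loosely mirror the existing integration tests; ENABLE_AOT as the opt-in env var comes from this review):

import os
import unittest

class RankAOTTest(unittest.TestCase):
    def test_multi_tower_din_aot_export(self):
        os.environ["ENABLE_AOT"] = "1"
        try:
            # train a small model and run the export entry point here,
            # then assert the AOT artifacts exist under the export dir
            pass
        finally:
            os.environ.pop("ENABLE_AOT", None)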


exported_gm_path = os.path.join(save_dir, "debug_exported_gm.py")
with open(exported_gm_path, "w") as fout:
fout.write(str(exported_gm))

Collaborator:

exported_gm.code and debug_exported_gm.py are duplicates; both save str(exported_gm).

@@ -746,6 +750,9 @@ def _script_model(
logger.info(f"Model Outputs: {result_info}")

export_model_trt(model, data_cuda, save_dir)
elif is_aot():
data_cuda = batch.to_dict(sparse_dtype=torch.int64)

Collaborator:

The data types are the same on CPU and GPU, so the same data_cuda can be used across export/trt_export/aot_export.
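
A hedged sketch of that consolidation (export_model_trt appears in this diff; is_trt and export_model_aot are assumed names for the parallel branch and helper):

data_cuda = batch.to_dict(sparse_dtype=torch.int64)
if is_trt():
    export_model_trt(model, data_cuda, save_dir)
elif is_aot():
    export_model_aot(model, data_cuda, save_dir)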

# pyre-ignore [14]
def forward(
self,
data: Dict[str, torch.Tensor],

Collaborator:

The AOT model's predict may not support a device argument; it needs a workaround in predict, such as https://github.com/alibaba/TorchEasyRec/blob/master/tzrec/main.py#L1076
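
A hedged sketch of that kind of workaround (function and parameter names are illustrative): since the exported forward() only accepts the data dict, move the inputs to the target device before calling the model.

from typing import Dict

import torch

def predict_aot(
    model: torch.nn.Module,
    inputs: Dict[str, torch.Tensor],
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    # move inputs first; the exported model itself takes no device argument
    inputs_on_device = {k: v.to(device) for k, v in inputs.items()}
    return model(inputs_on_device)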


chengmengli06 (Author):

It is not fully done yet; this commit covers only the export part. The compile and prediction parts are still to be done.

@@ -737,6 +739,8 @@ def _script_model(
logger.info("quantize embeddings...")
quantize_embeddings(model, dtype=torch.qint8, inplace=True)

if is_aot():
model = model.cuda()

Collaborator:

Is just model.cuda() correct? When I use gloo to load the checkpoint with

device_state_dict = state_dict_to_device(
    model.state_dict(), pg=checkpoint_pg, device=torch.device("cpu")
)
model = model.to_empty(device="cpu")
model = model.to("cuda:0")

it may be incorrect when I run forward with model(data_cuda).


chengmengli06 (Author):

I have tested it; there is no problem.
