Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/comm-group' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
JoongunPark committed Jul 26, 2024
2 parents ec7a492 + 072c562 commit 9114ae0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/converter/pytorch_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,10 @@ def convert_json_to_protobuf_nodes(
protobuf_node_map (Dict[int, ChakraNode]): Dictionary where the converted Protobuf nodes will be stored.
"""
for _, json_node in json_node_map.items():
if (json_node.get_op_type() == PyTorchNodeType.CPU_OP) or (
json_node.get_op_type() == PyTorchNodeType.LABEL
if (
(json_node.get_op_type() == PyTorchNodeType.CPU_OP)
or (json_node.get_op_type() == PyTorchNodeType.LABEL)
or (json_node.get_op_type() == PyTorchNodeType.METADATA)
):
chakra_node = self.convert_json_to_protobuf_node(json_node_map, protobuf_node_map, json_node)
protobuf_node_map[chakra_node.id] = chakra_node
Expand Down
4 changes: 3 additions & 1 deletion src/converter/pytorch_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def get_op_type(self) -> PyTorchNodeType:
Returns
PyTorchNodeType: The type of the PyTorch operation.
"""
if self.is_gpu_op():
if "process_group:init" in self.name:
return PyTorchNodeType.METADATA
elif self.is_gpu_op():
return PyTorchNodeType.GPU_OP
elif hasattr(self, "op_schema") or hasattr(self, "outputs"):
return PyTorchNodeType.CPU_OP
Expand Down

0 comments on commit 9114ae0

Please sign in to comment.