Skip to content

Commit

Permalink
Merge branch 'mlcommons:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
JoongunPark authored Nov 13, 2024
2 parents 2833ef8 + 47c8154 commit 7617e03
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/converter/text_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_comm_type(self, comm_type: str) -> int:
def get_comm_coll_node(self, layer_name: str, comm_type: str, comm_size: int) -> Any:
node = self.get_node(f"COMM_COLL_NODE_{layer_name}_{comm_type}", COMM_COLL_NODE)
node.attr.append(ChakraAttr(name="comm_type", int64_val=self.get_comm_type(comm_type)))
node.attr.append(ChakraAttr(name="comm_size", uint64_val=comm_size))
node.attr.append(ChakraAttr(name="comm_size", int64_val=comm_size))
return node

def add_parent(self, child_node: Any, parent_node: Any) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/feeder/et_feeder_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ETFeederNode::ETFeederNode(std::shared_ptr<ChakraProtoMsg::Node> node) {
} else if (attr_name == "comm_priority") {
this->comm_priority_ = static_cast<uint32_t>(attr.int32_val());
} else if (attr_name == "comm_size") {
this->comm_size_ = attr.int64_val();
this->comm_size_ = static_cast<uint64_t>(attr.int64_val());
} else if (attr_name == "comm_src") {
this->comm_src_ = static_cast<uint32_t>(attr.int32_val());
} else if (attr_name == "comm_dst") {
Expand Down
2 changes: 1 addition & 1 deletion src/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def generate_comm_coll_node(num_npus: int, comm_size: int, comm_type: int, node_

node = get_node(node_name, COMM_COLL_NODE)
node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False))
node.attr.extend([get_comm_type_attr(comm_type), ChakraAttr(name="comm_size", uint64_val=comm_size)])
node.attr.extend([get_comm_type_attr(comm_type), ChakraAttr(name="comm_size", int64_val=comm_size)])
encode_message(et, node)


Expand Down

0 comments on commit 7617e03

Please sign in to comment.