diff --git a/src/generator/generator.py b/src/generator/generator.py index c8352c6f..036c16f7 100644 --- a/src/generator/generator.py +++ b/src/generator/generator.py @@ -4,7 +4,11 @@ ALL_GATHER, ALL_REDUCE, ALL_TO_ALL, + BARRIER, + BROADCAST, COMM_COLL_NODE, + COMM_RECV_NODE, + COMM_SEND_NODE, COMP_NODE, MEM_LOAD_NODE, MEM_STORE_NODE, @@ -114,9 +118,8 @@ def one_remote_mem_load_node(num_npus: int, tensor_size: int) -> None: encode_message(et, GlobalMetadata(version="0.0.4")) node = get_node("MEM_LOAD_NODE", MEM_LOAD_NODE) - node.attr.extend( - [ChakraAttr(name="is_cpu_op", bool_val=False), ChakraAttr(name="tensor_size", uint64_val=tensor_size)] - ) + node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False)) + node.attr.append(ChakraAttr(name="tensor_size", uint64_val=tensor_size)) encode_message(et, node) @@ -128,9 +131,8 @@ def one_remote_mem_store_node(num_npus: int, tensor_size: int) -> None: encode_message(et, GlobalMetadata(version="0.0.4")) node = get_node("MEM_STORE_NODE", MEM_STORE_NODE) - node.attr.extend( - [ChakraAttr(name="is_cpu_op", bool_val=False), ChakraAttr(name="tensor_size", uint64_val=tensor_size)] - ) + node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False)) + node.attr.append(ChakraAttr(name="tensor_size", uint64_val=tensor_size)) encode_message(et, node) @@ -188,13 +190,8 @@ def generate_comm_coll_node(num_npus: int, comm_size: int, comm_type: int, node_ encode_message(et, GlobalMetadata(version="0.0.4")) node = get_node(node_name, COMM_COLL_NODE) - node.attr.extend( - [ - ChakraAttr(name="is_cpu_op", bool_val=False), - get_comm_type_attr(comm_type), - ChakraAttr(name="comm_size", uint64_val=comm_size), - ] - ) + node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False)) + node.attr.extend([get_comm_type_attr(comm_type), ChakraAttr(name="comm_size", uint64_val=comm_size)]) encode_message(et, node) @@ -218,6 +215,42 @@ def one_comm_coll_node_reducescatter(num_npus: int, comm_size: int) -> None: generate_comm_coll_node(num_npus, comm_size, REDUCE_SCATTER, "REDUCE_SCATTER") +def one_comm_coll_node_broadcast(num_npus: int, comm_size: int) -> None: + """Generate one Broadcast communication collective node.""" + generate_comm_coll_node(num_npus, comm_size, BROADCAST, "BROADCAST") + + +def one_comm_coll_node_barrier(num_npus: int) -> None: + """Generate one Barrier communication collective node.""" + generate_comm_coll_node(num_npus, comm_size=0, comm_type=BARRIER, node_name="BARRIER") + + +def one_comm_send_node(num_npus: int, tensor_size: int) -> None: + """Generate communication send nodes.""" + for npu_id in range(num_npus): + output_filename = f"one_comm_send_node.{npu_id}.et" + with open(output_filename, "wb") as et: + encode_message(et, GlobalMetadata(version="0.0.4")) + + node = get_node("COMM_SEND_NODE", COMM_SEND_NODE) + node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False)) + node.attr.append(ChakraAttr(name="tensor_size", uint64_val=tensor_size)) + encode_message(et, node) + + +def one_comm_recv_node(num_npus: int, tensor_size: int) -> None: + """Generate communication receive nodes.""" + for npu_id in range(num_npus): + output_filename = f"one_comm_recv_node.{npu_id}.et" + with open(output_filename, "wb") as et: + encode_message(et, GlobalMetadata(version="0.0.4")) + + node = get_node("COMM_RECV_NODE", COMM_RECV_NODE) + node.attr.append(ChakraAttr(name="is_cpu_op", bool_val=False)) + node.attr.append(ChakraAttr(name="tensor_size", uint64_val=tensor_size)) + encode_message(et, node) + + def main() -> None: parser = argparse.ArgumentParser(description="Execution Trace Generator") parser.add_argument("--num_npus", type=int, default=64, help="Number of NPUs") @@ -238,6 +271,10 @@ def main() -> None: one_comm_coll_node_alltoall(args.num_npus, args.default_comm_size) one_comm_coll_node_allgather(args.num_npus, args.default_comm_size) one_comm_coll_node_reducescatter(args.num_npus, args.default_comm_size) + one_comm_coll_node_broadcast(args.num_npus, args.default_comm_size) + one_comm_coll_node_barrier(args.num_npus) + one_comm_send_node(args.num_npus, args.default_tensor_size) + one_comm_recv_node(args.num_npus, args.default_tensor_size) if __name__ == "__main__":