diff --git a/csrc/preseg_passes/consecutive_cast.cpp b/csrc/preseg_passes/consecutive_cast.cpp index 564aa23b361..c292b33748f 100644 --- a/csrc/preseg_passes/consecutive_cast.cpp +++ b/csrc/preseg_passes/consecutive_cast.cpp @@ -199,21 +199,22 @@ void castOptimizationPass(Fusion* fusion) { if (isMovableMeta(expr->input(0)->definition())) { Expr* meta = expr->input(0)->definition(); - // replayed cast + // replayed cast. Val* replayed_expr_out = castOp(expr->output(0)->dtype(), meta->input(0)); - // replayed meta - // replay meta on new inputs - Expr* replayed_meta = nvfuser::ir_utils::replaceValInExprInputs( - meta, meta->input(0), replayed_expr_out); - // update replayed meta output + // preparing new meta output. Val* replayed_meta_out = ops::newValLike( meta->output(0), meta->output(0)->getDataType().value()); + + // replay meta on new inputs. + Expr* replayed_meta = nvfuser::ir_utils::replaceValInExprInputs( + meta, meta->input(0), replayed_expr_out); + // update replayed meta output. replayed_meta = ir_utils::transferDefinitionToNewOutputs( replayed_meta, {replayed_meta_out}); - // replace uses of old second output + // replace uses of old second output. ir_utils::replaceValInAllExprInputsAndFusionOutputs( expr->output(0), replayed_meta_out);