Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
jjsjann123 committed Dec 24, 2024
1 parent 1e0c338 commit 17d9e39
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions csrc/preseg_passes/consecutive_cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <ir/utils.h>
#include <ops/arith.h>
#include <ops/utils.h>
#include <transform_replay.h>

namespace nvfuser::preseg_passes {
Expand All @@ -24,7 +25,8 @@ bool isCast(Expr* expr) {

bool isMovableMeta(Expr* expr) {
return (!expr->output(0)->isFusionOutput()) &&
(expr->isOneOf<SqueezeOp, BroadcastOp, ViewOp>() || ir_utils::isSimpleTVSet(expr));
(expr->isOneOf<SqueezeOp, BroadcastOp, ViewOp>() ||
ir_utils::isSimpleTVSet(expr));
}

// replaces input to the cast op that produes cast_output, return the new
Expand Down Expand Up @@ -167,7 +169,8 @@ void moveChainedCasts(Expr* expr, std::unordered_set<Expr*>& visited) {
if (starting_anchor->getDataType().value() == output_dtype) {
// if output dtype is identical to starting_anchor dtype, we can't keep
// the last cast op and will need to re-write all uses here
ir_utils::replaceValue(expr->fusion(), {{expr->output(0), starting_anchor}});
ir_utils::replaceValue(
expr->fusion(), {{expr->output(0), starting_anchor}});
} else {
replaceInputInCast(expr->output(0), starting_anchor);
}
Expand Down Expand Up @@ -197,16 +200,18 @@ void castOptimizationPass(Fusion* fusion) {
Expr* meta = expr->input(0)->definition();

// replayed cast
Val* replayed_expr_out = castOp(expr->output(0)->dtype(), meta->input(0));
Val* replayed_expr_out =
castOp(expr->output(0)->dtype(), meta->input(0));

// replayed meta
// replay meta on new inputs
Expr* replayed_meta = nvfuser::ir_utils::replaceValInExprInputs(
meta, meta->input(0), replayed_expr_out);
// update replayed meta output
Val* replayed_meta_out = ops::newValLike(
meta->output(0), meta->output(0)->getDataType().value());
replayed_meta = transferDefinitionToNewOutputs(replayed_meta, {replayed_meta_out});
meta->output(0), meta->output(0)->getDataType().value());
replayed_meta = ir_utils::transferDefinitionToNewOutputs(
replayed_meta, {replayed_meta_out});

// replace uses of old second output
ir_utils::replaceValInAllExprInputsAndFusionOutputs(
Expand Down

0 comments on commit 17d9e39

Please sign in to comment.