Skip to content

Commit

Permalink
Seeding with greedy...
Browse files Browse the repository at this point in the history
  • Loading branch information
aryavohra committed Feb 10, 2025
1 parent ffeb003 commit d6410b6
Showing 1 changed file with 32 additions and 9 deletions.
41 changes: 32 additions & 9 deletions src/enzyme_ad/jax/deps/tensat/src/optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,17 @@ pub fn candidate_to_recexpr(
egraph: &EGraph<Mdl, TensorAnalysis>,
root: Id,
) -> (RecExpr<Mdl>, HashMap<Id, Id>) {
// TODO: untested
let mut node_picked: HashMap<Id, Mdl> = HashMap::new();
for eclass in egraph.classes() {
let eclass_id = egraph.find(eclass.id);
let enodes = &egraph[eclass_id].nodes;
// We treat the default state as 0. TODO: improve (maybe initialise with input graph)
let node_idx = candidate.get(&eclass_id).or(Some(&0)).unwrap();
let enode = enodes[*node_idx].clone();
assert!(node_picked.insert(eclass_id.clone(), enode).is_none());
// Check if the candidate mapping is valid; if not, use 0 as default.
let node_idx = match candidate.get(&eclass_id) {
Some(&i) if i < enodes.len() => i,
_ => 0,
};
let enode = enodes[node_idx].clone();
node_picked.insert(eclass_id.clone(), enode);
}

let mut egraph_to_recexpr: HashMap<Id, Id> = HashMap::new();
Expand All @@ -150,8 +152,8 @@ pub fn candidate_to_recexpr(
pub fn extract_by_optimization(extractor: GlobalExtractor, method: OptimizationMethod) -> Candidate {
match method {
OptimizationMethod::SimulatedAnnealing => {
let init_temp = 0.05f64;
let sa = SimulatedAnnealing::new(init_temp)
let greedy_candidate = compute_greedy_candidate(&extractor.egraph, &extractor.cost_model);
let sa = SimulatedAnnealing::new(0.05f64)
.unwrap()
.with_temp_func(SATempFunc::Boltzmann)
.with_stall_best(1000)
Expand All @@ -160,9 +162,8 @@ pub fn extract_by_optimization(extractor: GlobalExtractor, method: OptimizationM
.with_reannealing_accepted(400)
.with_reannealing_best(400);

let default: HashMap<Id, usize> = HashMap::new();
let solver = Executor::new(extractor, sa)
.configure(|state| state.param(default).max_iters(50000))
.configure(|state| state.param(greedy_candidate).max_iters(1000))
.add_observer(SlogLogger::term(), observers::ObserverMode::Every(10));
solver.run().unwrap().state.param.unwrap()
}
Expand Down Expand Up @@ -406,3 +407,25 @@ fn get_init_rec(
added_memo.insert(id);
}
}

pub fn compute_greedy_candidate(
egraph: &EGraph<Mdl, TensorAnalysis>,
cost_model: &CostModel,
) -> Candidate {
let mut candidate = Candidate::new();
for eclass in egraph.classes() {
let id = egraph.find(eclass.id);
let enodes = &egraph[id].nodes;
let (min_idx, _) = enodes
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
let cost_a = cost_model.get_self_cost(egraph, a).0;
let cost_b = cost_model.get_self_cost(egraph, b).0;
cost_a.partial_cmp(&cost_b).unwrap()
})
.unwrap();
candidate.insert(id, min_idx);
}
candidate
}

0 comments on commit d6410b6

Please sign in to comment.