From 76af3c6c7afc5b6e582ec5d16a03e74efe5a11ee Mon Sep 17 00:00:00 2001 From: Ohad <137686240+ohad-starkware@users.noreply.github.com> Date: Wed, 4 Dec 2024 11:28:06 +0200 Subject: [PATCH] relation tracker bug fix (#921) --- .../constraint_framework/relation_tracker.rs | 38 ++++++++++++++++--- .../prover/src/examples/state_machine/mod.rs | 8 +++- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs index 3866df39a..b804d488e 100644 --- a/crates/prover/src/constraint_framework/relation_tracker.rs +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -194,14 +194,24 @@ impl RelationSummary { /// Returns the sum of every entry's yields and uses. /// The result is a map from relation name to a list of values(M31 vectors) and their sum. pub fn summarize_relations(entries: &[RelationTrackerEntry]) -> Self { + let mut entry_by_relation = HashMap::new(); + for entry in entries { + entry_by_relation + .entry(entry.relation.clone()) + .or_insert_with(Vec::new) + .push(entry); + } let mut summary = vec![]; - let relations = entries.iter().group_by(|entry| entry.relation.clone()); - for (relation, entries) in &relations { + for (relation, entries) in entry_by_relation { let mut relation_sums: HashMap, M31> = HashMap::new(); for entry in entries { - let mult = relation_sums - .entry(entry.values.clone()) - .or_insert(M31::zero()); + let mut values = entry.values.clone(); + + // Trailing zeroes do not affect the sum, remove for correct aggregation. + while values.last().is_some_and(|v| v.is_zero()) { + values.pop(); + } + let mult = relation_sums.entry(values).or_insert(M31::zero()); *mult += entry.mult; } let relation_sums = relation_sums.into_iter().collect_vec(); @@ -216,6 +226,24 @@ impl RelationSummary { .find(|(name, _)| name == relation) .map(|(_, entries)| entries.as_slice()) } + + /// Cleans up the summary by removing zero-sum entries, only keeping the non-zero ones. + /// Used for debugging. + pub fn cleaned(self) -> Self { + let mut cleaned = vec![]; + for (relation, entries) in self.0 { + let mut cleaned_entries = vec![]; + for (vector, sum) in entries { + if !sum.is_zero() { + cleaned_entries.push((vector, sum)); + } + } + if !cleaned_entries.is_empty() { + cleaned.push((relation, cleaned_entries)); + } + } + Self(cleaned) + } } impl Debug for RelationSummary { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index bdb265fff..84b64617c 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -305,8 +305,12 @@ mod tests { // Check the final state inferred from the summary. let mut curr_state = initial_state; for entry in relation_info { - let x_step = entry.0[0]; - let y_step = entry.0[1]; + let (x_step, y_step) = match entry.0.len() { + 2 => (entry.0[0], entry.0[1]), + 1 => (entry.0[0], M31::zero()), + 0 => (M31::zero(), M31::zero()), + _ => unreachable!(), + }; let mult = entry.1; let next_state = [curr_state[0] - x_step * mult, curr_state[1] - y_step * mult];