From f7de6145106ede80091325aa8c32fde543ef7cc8 Mon Sep 17 00:00:00 2001 From: Ohad Agadi Date: Thu, 28 Nov 2024 14:43:35 +0200 Subject: [PATCH] relation summary --- .../constraint_framework/relation_tracker.rs | 45 ++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs index df3996d63..b5220130a 100644 --- a/crates/prover/src/constraint_framework/relation_tracker.rs +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::fmt::Debug; use itertools::Itertools; @@ -74,7 +75,7 @@ impl RelationTrackerComponent { } /// Aggregates relation entries. -// TODO(Ohad): write a summarize function, test. +// TODO(Ohad): test. pub struct RelationTrackerEvaluator<'a> { entries: Vec, pub trace_eval: @@ -187,3 +188,45 @@ impl<'a> EvalAtRow for RelationTrackerEvaluator<'a> { } } } + +type RelationInfo = (String, Vec<(Vec, M31)>); +pub struct RelationSummary(Vec); +impl RelationSummary { + /// Returns the sum of every entry's yields and uses. + /// The result is a map from relation name to a list of values(M31 vectors) and their sum. + pub fn summarize_relations(entries: &[RelationTrackerEntry]) -> Self { + let mut summary = vec![]; + let relations = entries.iter().group_by(|entry| entry.relation.clone()); + for (relation, entries) in &relations { + let mut relation_sums: HashMap, M31> = HashMap::new(); + for entry in entries { + let mult = relation_sums + .entry(entry.values.clone()) + .or_insert(M31::zero()); + *mult += entry.mult; + } + let relation_sums = relation_sums.into_iter().collect_vec(); + summary.push((relation.clone(), relation_sums)); + } + Self(summary) + } + + pub fn get_relation_info(&self, relation: &str) -> Option<&[(Vec, M31)]> { + self.0 + .iter() + .find(|(name, _)| name == relation) + .map(|(_, entries)| entries.as_slice()) + } +} +impl Debug for RelationSummary { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for (relation, entries) in &self.0 { + writeln!(f, "{}:", relation)?; + for (vector, sum) in entries { + let vector = vector.iter().map(|v| v.0).collect_vec(); + writeln!(f, " {:?} -> {}", vector, sum)?; + } + } + Ok(()) + } +}