Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize infer_tower_product_witness with less traverse #780

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

hero78119
Copy link
Collaborator

design rationale

a simple way by just iterate chunk size 2 at once.

benchmark

this result benefits opcode read/write product argument witness inferring.

with command

cargo bench --bench fibonacci --package ceno_zkvm -- --baseline baseline

on ceno server it shows around 3-4% fibonacci e2e performance

fibonacci_max_steps_1048576/prove_fibonacci/fibonacci_max_steps_1048576
                        time:   [3.9635 s 3.9860 s 4.0097 s]
                        change: [-5.4827% -4.7524% -3.9562%] (p = 0.00 < 0.05)
                        Performance has improved.

.with_min_len(MIN_PAR_SIZE)
.map(|(v, evaluations)| *evaluations *= *v)
.collect()
next_layer.chunks_exact(2).for_each(|f| {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest using .tuples so that the compiler can help you more.

Something a bit like this:

diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs
index 26a67c3e..7bac7044 100644
--- a/ceno_zkvm/src/scheme/utils.rs
+++ b/ceno_zkvm/src/scheme/utils.rs
@@ -211,21 +211,21 @@ pub(crate) fn infer_tower_product_witness<E: ExtensionField>(
             let cur_layer: Vec<ArcMultilinearExtension<E>> = (0..num_product_fanin)
                 .map(|index| {
                     let mut evaluations = vec![E::ONE; cur_len];
-                    next_layer.chunks_exact(2).for_each(|f| {
-                        match (f[0].evaluations(), f[1].evaluations()) {
+                    next_layer
+                        .iter()
+                        .map(|f| f.evaluations())
+                        .tuples()
+                        .for_each(|(f1, f2)| match (f1, f2) {
                             (FieldType::Ext(f1), FieldType::Ext(f2)) => {
                                 let start: usize = index * cur_len;
-                                (start..(start + cur_len))
+                                (start..start + cur_len)
                                     .into_par_iter()
-                                    .zip(evaluations.par_iter_mut())
+                                    .zip(&mut evaluations)
                                     .with_min_len(MIN_PAR_SIZE)
-                                    .map(|(index, evaluations)| {
-                                        *evaluations *= f1[index] * f2[index]
-                                    })
+                                    .map(|(index, evaluation)| *evaluation *= f1[index] * f2[index])
                                     .collect()
                             }
                             _ => unreachable!("must be extension field"),
-                        }
                         });
                     evaluations.into_mle().into()
                 })

Copy link
Collaborator Author

@hero78119 hero78119 Dec 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the suggestion, but tuples will silently skip the remainder for non-even result without error hints. I would prefer stick to current way of chunks_exact(2) so it's self-explanantion, and terminate with error precisely in runtime

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants