Skip to content

Commit 9e4d289

Browse files
committed
fix program chip
1 parent 360c72d commit 9e4d289

File tree

12 files changed

+104
-51
lines changed

12 files changed

+104
-51
lines changed

alu_u32/src/add/stark.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use super::Add32Chip;
33
use core::borrow::Borrow;
44

55
use p3_air::{Air, AirBuilder, BaseAir};
6-
use p3_field::PrimeField;
6+
use p3_field::{AbstractField, PrimeField};
77
use p3_matrix::MatrixRowSlices;
88

99
impl<F> BaseAir<F> for Add32Chip {}

alu_u32/src/div/stark.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use super::Div32Chip;
22
use core::borrow::Borrow;
33

44
use p3_air::{Air, AirBuilder, BaseAir};
5-
use p3_field::PrimeField;
5+
use p3_field::{AbstractField, PrimeField};
66

77
impl<F> BaseAir<F> for Div32Chip {}
88

cpu/src/lib.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use core::marker::Sync;
1010
use core::mem::transmute;
1111
use valida_bus::{MachineWithGeneralBus, MachineWithMemBus, MachineWithProgramBus};
1212
use valida_machine::{
13-
instructions, Chip, Instruction, InstructionWord, Interaction, Operands, Word,
13+
instructions, Chip, Instruction, InstructionWord, Interaction, Operands, Word, OPERAND_ELEMENTS,
1414
};
1515
use valida_memory::{MachineWithMemoryChip, Operation as MemoryOperation};
1616
use valida_opcodes::{
@@ -156,16 +156,16 @@ where
156156
};
157157

158158
// Program ROM bus channel
159-
let pc = VirtualPairCol::single_preprocessed(CPU_COL_MAP.pc);
160-
let opcode = VirtualPairCol::single_preprocessed(CPU_COL_MAP.instruction.opcode);
159+
let pc = VirtualPairCol::single_main(CPU_COL_MAP.pc);
160+
let opcode = VirtualPairCol::single_main(CPU_COL_MAP.instruction.opcode);
161161
let mut fields = vec![pc, opcode];
162162
fields.extend(
163163
CPU_COL_MAP
164164
.instruction
165165
.operands
166166
.0
167167
.iter()
168-
.map(|op| VirtualPairCol::single_preprocessed(*op)),
168+
.map(|op| VirtualPairCol::single_main(*op)),
169169
);
170170
let send_program = Interaction {
171171
fields,
@@ -329,6 +329,11 @@ impl CpuChip {
329329
let fp = last_row[CPU_COL_MAP.fp];
330330
let clk = last_row[CPU_COL_MAP.clk];
331331

332+
let opcode = last_row[CPU_COL_MAP.instruction.opcode];
333+
let operands = &last_row[CPU_COL_MAP.instruction.operands.0[0]
334+
..CPU_COL_MAP.instruction.operands.0[0] + OPERAND_ELEMENTS]
335+
.to_vec();
336+
332337
values.resize(n_real_rows.next_power_of_two() * NUM_CPU_COLS, F::ZERO);
333338

334339
// Interpret values as a slice of arrays of length `NUM_CPU_COLS`
@@ -346,7 +351,15 @@ impl CpuChip {
346351
padded_row[CPU_COL_MAP.pc] = pc;
347352
padded_row[CPU_COL_MAP.fp] = fp;
348353
padded_row[CPU_COL_MAP.clk] = clk + F::from_canonical_u32(n as u32 + 1);
354+
355+
// Instruction columns
349356
padded_row[CPU_COL_MAP.opcode_flags.is_stop] = F::ONE;
357+
padded_row[CPU_COL_MAP.instruction.opcode] = opcode;
358+
for i in 0..OPERAND_ELEMENTS {
359+
padded_row[CPU_COL_MAP.instruction.operands.0[i]] = operands[i];
360+
}
361+
362+
// Memory columns
350363
padded_row[CPU_COL_MAP.mem_channels[0].is_read] = F::ONE;
351364
padded_row[CPU_COL_MAP.mem_channels[1].is_read] = F::ONE;
352365
padded_row[CPU_COL_MAP.mem_channels[2].is_read] = F::ZERO;

derive/src/lib.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,19 +123,25 @@ fn run_method(machine: &Ident, instructions: &[&Field]) -> TokenStream2 {
123123
let instruction = program.get_instruction(pc);
124124
let opcode = instruction.opcode;
125125
let ops = instruction.operands;
126-
self.read_word(pc as usize);
127126

128127
// Execute
129128
match opcode {
130129
#opcode_arms
131130
_ => panic!("Unrecognized opcode: {}", opcode),
132131
};
132+
self.read_word(pc as usize);
133133

134134
// A STOP instruction signals the end of the program
135135
if opcode == <StopInstruction as Instruction<Self>>::OPCODE {
136136
break;
137137
}
138138
}
139+
140+
// Record infinite loop cycles
141+
let n = self.cpu().clock.next_power_of_two() - self.cpu().clock;
142+
for _ in 0..n {
143+
self.read_word(self.cpu().pc as usize);
144+
}
139145
}
140146
}
141147
}

machine/src/__internal/check_constraints.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ pub fn check_constraints<M, A>(
2626
return;
2727
}
2828

29-
let preprocessed = air
30-
.preprocessed_trace()
31-
.unwrap_or(RowMajorMatrix::new(vec![], 0));
29+
let preprocessed = air.preprocessed_trace();
3230

3331
let cumulative_sum = *perm.row_slice(perm.height() - 1).last().unwrap();
3432

@@ -38,8 +36,16 @@ pub fn check_constraints<M, A>(
3836

3937
let main_local = main.row_slice(i);
4038
let main_next = main.row_slice(i_next);
41-
let preprocessed_local = preprocessed.row_slice(i);
42-
let preprocessed_next = preprocessed.row_slice(i_next);
39+
let preprocessed_local = if preprocessed.is_some() {
40+
preprocessed.as_ref().unwrap().row_slice(i)
41+
} else {
42+
&[]
43+
};
44+
let preprocessed_next = if preprocessed.is_some() {
45+
preprocessed.as_ref().unwrap().row_slice(i_next)
46+
} else {
47+
&[]
48+
};
4349
let perm_local = perm.row_slice(i);
4450
let perm_next = perm.row_slice(i_next);
4551

machine/src/chip.rs

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use crate::Machine;
2+
use crate::__internal::ConstraintFolder;
23
use alloc::vec;
34
use alloc::vec::Vec;
45
use valida_util::batch_multiplicative_inverse;
56

6-
use crate::__internal::ConstraintFolder;
77
use p3_air::{Air, AirBuilder, PairBuilder, PermutationAirBuilder, VirtualPairCol};
88
use p3_field::{AbstractExtensionField, AbstractField, ExtensionField, Field, Powers, PrimeField};
99
use p3_matrix::{dense::RowMajorMatrix, Matrix, MatrixRowSlices};
@@ -119,6 +119,8 @@ where
119119
let (alphas_local, alphas_global) = generate_rlc_elements(machine, chip, &random_elements);
120120
let betas = random_elements[2].powers();
121121

122+
let preprocessed = chip.preprocessed_trace();
123+
122124
// Compute the reciprocal columns
123125
//
124126
// Row: | q_1 | q_2 | q_3 | ... | q_n | \phi |
@@ -131,15 +133,26 @@ where
131133
let perm_width = all_interactions.len() + 1;
132134
let mut perm_values = Vec::with_capacity(main.height() * perm_width);
133135

134-
for main_row in main.rows() {
136+
for (n, main_row) in main.rows().enumerate() {
135137
let mut row = vec![M::EF::ZERO; perm_width];
136-
for (n, (interaction, _)) in all_interactions.iter().enumerate() {
137-
let alpha_i = if interaction.is_local() {
138+
for (m, (interaction, _)) in all_interactions.iter().enumerate() {
139+
let alpha_m = if interaction.is_local() {
138140
alphas_local[interaction.argument_index()]
139141
} else {
140142
alphas_global[interaction.argument_index()]
141143
};
142-
row[n] = reduce_row(main_row, &interaction.fields, alpha_i, betas.clone());
144+
let preprocessed_row = if preprocessed.is_some() {
145+
preprocessed.as_ref().unwrap().row_slice(n)
146+
} else {
147+
&[]
148+
};
149+
row[m] = reduce_row(
150+
main_row,
151+
preprocessed_row,
152+
&interaction.fields,
153+
alpha_m,
154+
betas.clone(),
155+
);
143156
}
144157
perm_values.extend(row);
145158
}
@@ -152,8 +165,15 @@ where
152165
if n > 0 {
153166
phi[n] = phi[n - 1];
154167
}
168+
let preprocessed_row = if preprocessed.is_some() {
169+
preprocessed.as_ref().unwrap().row_slice(n)
170+
} else {
171+
&[]
172+
};
155173
for (m, (interaction, interaction_type)) in all_interactions.iter().enumerate() {
156-
let mult = interaction.count.apply::<M::F, M::F>(&[], main_row);
174+
let mult = interaction
175+
.count
176+
.apply::<M::F, M::F>(preprocessed_row, main_row);
157177
match interaction_type {
158178
InteractionType::LocalSend | InteractionType::GlobalSend => {
159179
phi[n] += M::EF::from_base(mult) * perm_row[m];
@@ -187,6 +207,7 @@ where
187207

188208
let preprocessed = builder.preprocessed();
189209
let preprocessed_local = preprocessed.row_slice(0);
210+
let preprocessed_next = preprocessed.row_slice(1);
190211

191212
let perm = builder.permutation();
192213
let perm_width = perm.width();
@@ -220,8 +241,10 @@ where
220241

221242
let mult_local = interaction
222243
.count
223-
.apply::<AB::Expr, AB::Var>(&[], main_local);
224-
let mult_next = interaction.count.apply::<AB::Expr, AB::Var>(&[], main_next);
244+
.apply::<AB::Expr, AB::Var>(preprocessed_local, main_local);
245+
let mult_next = interaction
246+
.count
247+
.apply::<AB::Expr, AB::Var>(preprocessed_next, main_next);
225248

226249
// Build the RHS of the permutation constraint
227250
match interaction_type {
@@ -294,14 +317,20 @@ where
294317

295318
// TODO: Use Var and Expr type bounds in place of concrete fields so that
296319
// this function can be used in `eval_permutation_constraints`.
297-
fn reduce_row<F, EF>(row: &[F], fields: &[VirtualPairCol<F>], alpha: EF, betas: Powers<EF>) -> EF
320+
fn reduce_row<F, EF>(
321+
main_row: &[F],
322+
preprocessed_row: &[F],
323+
fields: &[VirtualPairCol<F>],
324+
alpha: EF,
325+
betas: Powers<EF>,
326+
) -> EF
298327
where
299328
F: Field,
300329
EF: ExtensionField<F>,
301330
{
302331
let mut rlc = EF::ZERO;
303332
for (columns, beta) in fields.iter().zip(betas) {
304-
rlc += beta * columns.apply::<F, F>(&[], row)
333+
rlc += beta * columns.apply::<F, F>(preprocessed_row, main_row)
305334
}
306335
rlc += alpha;
307336
rlc

machine/src/lib.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,7 @@ impl InstructionWord<i32> {
4343
pub fn flatten<F: PrimeField32>(&self) -> [F; INSTRUCTION_ELEMENTS] {
4444
let mut result = [F::default(); INSTRUCTION_ELEMENTS];
4545
result[0] = F::from_canonical_u32(self.opcode);
46-
self.operands.0.into_iter().enumerate().for_each(|(i, x)| {
47-
result[i] = if x >= 0 {
48-
F::from_canonical_u32(x as u32)
49-
} else {
50-
F::from_wrapped_u32((x as i64 + F::ORDER_U32 as i64) as u32)
51-
};
52-
});
46+
result[1..].copy_from_slice(&Operands::<F>::from_i32_slice(&self.operands.0).0);
5347
result
5448
}
5549
}

native_field/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pub struct NativeFieldChip {
3434

3535
impl<F, M> Chip<M> for NativeFieldChip
3636
where
37-
F: PrimeField,
37+
F: PrimeField32,
3838
M: MachineWithGeneralBus<F = F> + MachineWithRangeBus8,
3939
{
4040
fn generate_trace(&self, _machine: &M) -> RowMajorMatrix<M::F> {

program/src/columns.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
use core::borrow::{Borrow, BorrowMut};
22
use core::mem::{size_of, transmute};
33
use valida_derive::AlignedBorrow;
4-
use valida_machine::Operands;
5-
use valida_machine::Word;
4+
use valida_machine::{Operands, Word};
65
use valida_util::indices_arr;
76

87
#[derive(AlignedBorrow, Default)]
@@ -11,6 +10,7 @@ pub struct ProgramCols<T> {
1110
pub multiplicity: T,
1211
}
1312

13+
#[derive(AlignedBorrow, Default)]
1414
pub struct ProgramPreprocessedCols<T> {
1515
pub opcode: T,
1616
pub operands: Operands<T>,

program/src/lib.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use alloc::vec;
77
use alloc::vec::Vec;
88
use core::iter;
99
use valida_bus::MachineWithProgramBus;
10-
use valida_machine::{Chip, Interaction, Machine, PrimeField, ProgramROM};
10+
use valida_machine::{Chip, Interaction, Machine, PrimeField32, ProgramROM};
1111

1212
use p3_air::VirtualPairCol;
1313
use p3_matrix::dense::RowMajorMatrix;
@@ -31,19 +31,21 @@ impl ProgramChip {
3131

3232
impl<F, M> Chip<M> for ProgramChip
3333
where
34-
F: PrimeField,
34+
F: PrimeField32,
3535
M: MachineWithProgramBus<F = F>,
3636
{
3737
fn generate_trace(&self, _machine: &M) -> RowMajorMatrix<M::F> {
3838
let n = self.program_rom.0.len();
39-
let col = self
39+
let cols = self
4040
.counts
4141
.iter()
42-
.map(|c| F::from_canonical_u32(*c))
42+
.enumerate()
43+
.flat_map(|(n, c)| [F::from_canonical_usize(n), F::from_canonical_u32(*c)])
4344
.chain(iter::repeat(F::ZERO))
44-
.take(n.next_power_of_two())
45+
.take(2 * n.next_power_of_two())
4546
.collect();
46-
RowMajorMatrix::new(col, 1)
47+
48+
RowMajorMatrix::new(cols, 2)
4749
}
4850

4951
fn global_receives(&self, machine: &M) -> Vec<Interaction<F>> {

0 commit comments

Comments
 (0)