diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index 9a05fc5e..7f1dbf63 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -18,12 +18,13 @@ use self::{ pub use sc::*; #[derive(Debug)] -/// A generic structure designed to handle the context of a circuit for generic types `F`, -/// `TraceArgs` and `StepArgs`. The struct contains a `Circuit` instance and implements -/// methods to build the circuit, add various components, and manipulate the circuit. `F` is a -/// generic type representing the field of the circuit. `TraceArgs` is a generic type -/// representing the arguments passed to the trace function. `StepArgs` is a generic type -/// representing the arguments passed to the `step_type_def` function. +/// A generic structure designed to handle the context of a circuit for generic types +/// `F`, `TraceArgs` and `StepArgs`. +/// The struct contains a `Circuit` instance and implements methods to build the circuit, +/// add various components, and manipulate the circuit. +/// `F` is a generic type representing the field of the circuit. +/// `TraceArgs` is a generic type representing the arguments passed to the trace function. +/// `StepArgs` is a generic type representing the arguments passed to the `step_type_def` function. pub struct CircuitContext { circuit: Circuit, tables: LookupTableRegistry, @@ -433,4 +434,205 @@ mod tests { assert!(!context.circuit.q_enable); } + + #[test] + fn test_set_num_steps() { + let circuit: Circuit = Circuit::default(); + let mut context = CircuitContext { + circuit, + tables: Default::default(), + }; + + context.pragma_num_steps(3); + assert_eq!(context.circuit.num_steps, 3); + + context.pragma_num_steps(0); + assert_eq!(context.circuit.num_steps, 0); + } + + #[test] + fn test_forward() { + // create circuit context + let circuit: Circuit = Circuit::default(); + let mut context = CircuitContext { + circuit, + tables: Default::default(), + }; + + // set forward signals + let forward_a: Queriable = context.forward("forward_a"); + let forward_b: Queriable = context.forward("forward_b"); + + // assert forward signals are correct + assert_eq!(context.circuit.forward_signals.len(), 2); + assert_eq!(context.circuit.forward_signals[0].uuid(), forward_a.uuid()); + assert_eq!(context.circuit.forward_signals[1].uuid(), forward_b.uuid()); + } + + #[test] + fn test_forward_with_phase() { + // create circuit context + let circuit: Circuit = Circuit::default(); + let mut context = CircuitContext { + circuit, + tables: Default::default(), + }; + + // set forward signals with specified phase + context.forward_with_phase("forward_a", 1); + context.forward_with_phase("forward_b", 2); + + // assert forward signals are correct + assert_eq!(context.circuit.forward_signals.len(), 2); + assert_eq!(context.circuit.forward_signals[0].phase(), 1); + assert_eq!(context.circuit.forward_signals[1].phase(), 2); + } + + #[test] + fn test_shared() { + // create circuit context + let circuit: Circuit = Circuit::default(); + let mut context = CircuitContext { + circuit, + tables: Default::default(), + }; + + // set shared signal + let shared_a: Queriable = context.shared("shared_a"); + + // assert shared signal is correct + assert_eq!(context.circuit.shared_signals.len(), 1); + assert_eq!(context.circuit.shared_signals[0].uuid(), shared_a.uuid()); + } + + #[test] + fn test_shared_with_phase() { + // create circuit context + let circuit: Circuit = Circuit::default(); + let mut context = CircuitContext { + circuit, + tables: Default::default(), + }; + + // set shared signal with specified phase + context.shared_with_phase("shared_a", 2); + + // assert shared signal is correct + assert_eq!(context.circuit.shared_signals.len(), 1); + assert_eq!(context.circuit.shared_signals[0].phase(), 2); + } + + #[test] + fn test_fixed() { + // create circuit context + let circuit: Circuit = Circuit::default(); + let mut context = CircuitContext { + circuit, + tables: Default::default(), + }; + + // set fixed signal + context.fixed("fixed_a"); + + // assert fixed signal was added to the circuit + assert_eq!(context.circuit.fixed_signals.len(), 1); + } + + #[test] + fn test_expose() { + // create circuit context + let circuit: Circuit = Circuit::default(); + let mut context = CircuitContext { + circuit, + tables: Default::default(), + }; + + // set forward signal and step to expose + let forward_a: Queriable = context.forward("forward_a"); + let step_offset: ExposeOffset = ExposeOffset::Last; + + // expose the forward signal of the final step + context.expose(forward_a, step_offset); + + // assert the signal is exposed + assert_eq!(context.circuit.exposed[0].0, forward_a); + assert_eq!( + std::mem::discriminant(&context.circuit.exposed[0].1), + std::mem::discriminant(&step_offset) + ); + } + + #[test] + fn test_step_type() { + // create circuit context + let circuit: Circuit = Circuit::default(); + let mut context = CircuitContext { + circuit, + tables: Default::default(), + }; + + // create a step type + let handler: StepTypeHandler = context.step_type("fibo_first_step"); + + // assert that the created step type was added to the circuit annotations + assert_eq!( + context.circuit.annotations[&handler.uuid()], + "fibo_first_step" + ); + } + + #[test] + fn test_step_type_def() { + // create circuit context + let circuit: Circuit = Circuit::default(); + let mut context = CircuitContext { + circuit, + tables: Default::default(), + }; + + // create a step type including its definition + let simple_step = context.step_type_def("simple_step", |context| { + context.setup(|_| {}); + context.wg(|_, _: u32| {}) + }); + + // assert step type was created and added to the circuit + assert_eq!( + context.circuit.annotations[&simple_step.uuid()], + "simple_step" + ); + assert_eq!( + simple_step.uuid(), + context.circuit.step_types[&simple_step.uuid()].uuid() + ); + } + + #[test] + fn test_step_type_def_pass_handler() { + // create circuit context + let circuit: Circuit = Circuit::default(); + let mut context = CircuitContext { + circuit, + tables: Default::default(), + }; + + // create a step type handler + let handler: StepTypeHandler = context.step_type("simple_step"); + + // create a step type including its definition + let simple_step = context.step_type_def(handler, |context| { + context.setup(|_| {}); + context.wg(|_, _: u32| {}) + }); + + // assert step type was created and added to the circuit + assert_eq!( + context.circuit.annotations[&simple_step.uuid()], + "simple_step" + ); + assert_eq!( + simple_step.uuid(), + context.circuit.step_types[&simple_step.uuid()].uuid() + ); + } }