Skip to content

Commit

Permalink
Make air public inputs deserializable (#1648)
Browse files Browse the repository at this point in the history
* Derive deserialize for AirPublicInput

* Add test

* Add more test cases

* Add Changelog entry

* Clippy

* Fix import

* Fix import

* fmt
  • Loading branch information
fmoletta authored Mar 7, 2024
1 parent 59a97c7 commit 6ae7b32
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#### Upcoming Changes

* feat: Make air public inputs deserializable [#1648](https://github.com/lambdaclass/cairo-vm/pull/1648)

* feat: Show only layout builtins in air private input [#1651](https://github.com/lambdaclass/cairo-vm/pull/1651)

* feat: Sort builtin segment info upon serialization for Cairo PIE [#1654](https://github.com/lambdaclass/cairo-vm/pull/1654)
Expand Down
95 changes: 90 additions & 5 deletions vm/src/air_public_input.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::Felt252;
use serde::Serialize;
use serde::{Deserialize, Serialize};
use thiserror_no_std::Error;

use crate::{
Expand All @@ -14,18 +14,21 @@ use crate::{
},
};

#[derive(Serialize, Debug)]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct PublicMemoryEntry {
pub address: usize,
#[serde(serialize_with = "mem_value_serde::serialize")]
#[serde(deserialize_with = "mem_value_serde::deserialize")]
pub value: Option<Felt252>,
pub page: usize,
}

mod mem_value_serde {
use core::fmt;

use super::*;

use serde::Serializer;
use serde::{de, Deserializer, Serializer};

pub(crate) fn serialize<S: Serializer>(
value: &Option<Felt252>,
Expand All @@ -37,9 +40,41 @@ mod mem_value_serde {
serializer.serialize_none()
}
}

pub(crate) fn deserialize<'de, D: Deserializer<'de>>(
d: D,
) -> Result<Option<Felt252>, D::Error> {
d.deserialize_str(Felt252OptionVisitor)
}

struct Felt252OptionVisitor;

impl<'de> de::Visitor<'de> for Felt252OptionVisitor {
type Value = Option<Felt252>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("Could not deserialize hexadecimal string")
}

fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(None)
}

fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Felt252::from_hex(value)
.map_err(de::Error::custom)
.map(Some)
}
}
}

#[derive(Serialize, Debug)]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct MemorySegmentAddresses {
pub begin_addr: usize,
pub stop_ptr: usize,
Expand All @@ -55,7 +90,7 @@ impl From<(usize, usize)> for MemorySegmentAddresses {
}
}

#[derive(Serialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
pub struct PublicInput<'a> {
pub layout: &'a str,
pub rc_min: isize,
Expand All @@ -64,6 +99,7 @@ pub struct PublicInput<'a> {
pub memory_segments: HashMap<&'a str, MemorySegmentAddresses>,
pub public_memory: Vec<PublicMemoryEntry>,
#[serde(rename = "dynamic_params")]
#[serde(skip_deserializing)] // This is set to None by default so we can skip it
layout_params: Option<&'a CairoLayout>,
}

Expand Down Expand Up @@ -139,3 +175,52 @@ pub enum PublicInputError {
#[error(transparent)]
Trace(#[from] TraceError),
}
#[cfg(test)]
mod tests {
#[cfg(feature = "std")]
use super::*;
#[cfg(feature = "std")]
use rstest::rstest;

#[cfg(feature = "std")]
#[rstest]
#[case(include_bytes!("../../cairo_programs/proof_programs/fibonacci.json"))]
#[case(include_bytes!("../../cairo_programs/proof_programs/bitwise_output.json"))]
#[case(include_bytes!("../../cairo_programs/proof_programs/keccak_builtin.json"))]
#[case(include_bytes!("../../cairo_programs/proof_programs/poseidon_builtin.json"))]
#[case(include_bytes!("../../cairo_programs/proof_programs/relocate_temporary_segment_append.json"))]
#[case(include_bytes!("../../cairo_programs/proof_programs/pedersen_test.json"))]
#[case(include_bytes!("../../cairo_programs/proof_programs/ec_op.json"))]
fn serialize_and_deserialize_air_public_input(#[case] program_content: &[u8]) {
let config = crate::cairo_run::CairoRunConfig {
proof_mode: true,
relocate_mem: true,
trace_enabled: true,
layout: "all_cairo",
..Default::default()
};
let (runner, vm) = crate::cairo_run::cairo_run(program_content, &config, &mut crate::hint_processor::builtin_hint_processor::builtin_hint_processor_definition::BuiltinHintProcessor::new_empty()).unwrap();
let public_input = runner.get_air_public_input(&vm).unwrap();
// We already know serialization works as expected due to the comparison against python VM
let serialized_public_input = public_input.serialize_json().unwrap();
let deserialized_public_input: PublicInput =
serde_json::from_str(&serialized_public_input).unwrap();
// Check that the deserialized public input is equal to the one we obtained from the vm first
assert_eq!(public_input.layout, deserialized_public_input.layout);
assert_eq!(public_input.rc_max, deserialized_public_input.rc_max);
assert_eq!(public_input.rc_min, deserialized_public_input.rc_min);
assert_eq!(public_input.n_steps, deserialized_public_input.n_steps);
assert_eq!(
public_input.memory_segments,
deserialized_public_input.memory_segments
);
assert_eq!(
public_input.public_memory,
deserialized_public_input.public_memory
);
assert!(
public_input.layout_params.is_none()
&& deserialized_public_input.layout_params.is_none()
);
}
}

0 comments on commit 6ae7b32

Please sign in to comment.