Skip to content

Commit 12508a9

Browse files
Fix bellman memory measurements
1 parent 92ec298 commit 12508a9

File tree

9 files changed

+167
-158
lines changed

9 files changed

+167
-158
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,5 @@ gnark/template
114114
.idea/
115115

116116
tags/
117+
118+
*.*.swp

_scripts/parsers/csv_parser_rust.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def parse_csv(csv_filename, memory_folder, circuit):
4141
csv_reader = csv.DictReader(file)
4242
for row in csv_reader:
4343
if row['circuit'] == circuit:
44-
memory_file = os.path.join(memory_folder, f'halo2_{circuit}_memory_{row["operation"]}.txt')
44+
files = os.listdir(memory_folder)
45+
memory_filename = next(f for f in files if row["operation"] in f)
46+
memory_file = os.path.join(memory_folder, memory_filename)
4547
ram = extract_ram_from_file(memory_file)
4648
row['ram(mb)'] = ram
4749
csv_rows.append(row)

_scripts/reader/helper.py

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# BELLMAN PATHS
99
BELLMAN = os.path.join(MAIN_DIR, "bellman_circuits")
1010
BELLMAN_BENCH = os.path.join(BENCHMARKS_DIR, "bellman")
11+
BELLMAN_BENCH_MEMORY = os.path.join(BELLMAN_BENCH, "memory")
1112
BELLMAN_BENCH_JSON = os.path.join(BELLMAN_BENCH, "jsons")
1213
# BELLMAN_CE PATHS
1314
BELLMAN_CE = os.path.join(MAIN_DIR, "bellman_ce_circuits")

_scripts/reader/process_circuit.py

+28-13
Original file line numberDiff line numberDiff line change
@@ -136,34 +136,43 @@ def build_command_bellman(payload, count):
136136
for circuit, input_path in payload.circuit.items():
137137
for inp in helper.get_all_input_files(input_path):
138138
commands.append(f"cd {helper.BELLMAN}; ")
139-
output_mem_size = os.path.join(
140-
helper.BELLMAN_BENCH_JSON,
141-
circuit + "_" + os.path.basename(inp)
142-
)
143139
output_bench = os.path.join(
144140
helper.BELLMAN_BENCH_JSON,
145141
circuit + "_bench_" + os.path.basename(inp)
146142
)
147143
input_file = os.path.join("..", inp)
148-
command_mem_size: str = "RUSTFLAGS=-Awarnings cargo run --bin {binary} --release -- --input {input_file} --output {output}; ".format(
149-
binary=circuit,
150-
input_file=input_file,
151-
output=output_mem_size
152-
)
153-
commands.append(command_mem_size)
154144
command_bench: str = "RUSTFLAGS=-Awarnings INPUT_FILE={input_file} CIRCUIT={circuit} cargo criterion --message-format=json --bench {bench} 1> {output}; ".format(
155145
circuit=circuit,
156146
input_file=input_file,
157147
bench="benchmark_circuit",
158148
output=output_bench
159149
)
160150
commands.append(command_bench)
151+
# Memory commands
152+
os.makedirs(f"{helper.BELLMAN_BENCH_MEMORY}/{inp}", exist_ok=True)
153+
# Altough each operation need only a subset of the arguments we pass
154+
# all of them for simplicity
155+
os.makedirs(os.path.join(helper.BELLMAN, "tmp"), exist_ok=True)
156+
for op in payload.operation:
157+
cargo_cmd = "cargo run --bin {circuit} --release -- --input {inp} --phase {phase} --params {params} --proof {proof}".format(
158+
circuit=circuit,
159+
inp=input_file,
160+
phase=op,
161+
params=os.path.join("tmp", "params"),
162+
proof=os.path.join("tmp", "proof"),
163+
)
164+
commands.append(
165+
"RUSTFLAGS=-Awarnings {memory_cmd} -h -l {cargo} 2> {time_file} > /dev/null; ".format(
166+
memory_cmd=helper.MEMORY_CMD,
167+
cargo=cargo_cmd,
168+
time_file=f"{helper.BELLMAN_BENCH_MEMORY}/{inp}/bellman_{circuit}_memory_{op}.txt"
169+
)
170+
)
161171
commands.append("cd ..; ")
162172
out = os.path.join(
163173
helper.BELLMAN_BENCH,
164174
"bellman_bls12_381_" + circuit + ".csv"
165175
)
166-
167176
python_command = "python3"
168177
try:
169178
subprocess.run([python_command, "--version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
@@ -175,14 +184,20 @@ def build_command_bellman(payload, count):
175184
print("Neither Python nor Python3 are installed or accessible. Please install or check your path settings.")
176185
sys.exit(1)
177186

178-
transform_command = "{python} _scripts/parsers/criterion_rust_parser.py --framework bellman --category circuit --backend bellman --curve bls12_381 --input {inp} --criterion_json {bench} --mem_proof_json {mem} --output_csv {out}; ".format(
187+
transform_command = "{python} _scripts/parsers/criterion_rust_parser.py --framework bellman --category circuit --backend bellman --curve bls12_381 --input {inp} --criterion_json {bench} --proof {proof} --output_csv {out}; ".format(
179188
python=python_command,
180189
inp=inp,
181190
bench=output_bench,
182-
mem=output_mem_size,
191+
proof=os.path.join(helper.BELLMAN, "tmp", "proof"),
183192
out=out
184193
)
185194
commands.append(transform_command)
195+
time_merge = "python3 _scripts/parsers/csv_parser_rust.py --memory_folder {memory_folder} --time_filename {time_filename} --circuit {circuit}; ".format(
196+
memory_folder=os.path.join(helper.BELLMAN_BENCH_MEMORY, inp),
197+
time_filename=out,
198+
circuit=circuit
199+
)
200+
commands.append(time_merge)
186201

187202
# Join the commands into a single string
188203
command = "".join(commands)

bellman_circuits/.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
params
2+
proof
3+
tmp/

bellman_circuits/bellman_utils/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ psutil = "3.2.2"
1414
serde = { version = "1.0", features = ["derive"] }
1515
serde_json = "1.0"
1616
bellman = "0.14.0"
17-
bls12_381 = "0.8.0"
17+
bls12_381 = "0.8.0"
18+
clap = { version = "4.2.7", features = ["derive"] }

bellman_circuits/bellman_utils/src/lib.rs

+62-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,27 @@
11
use std::{fs::File};
22
use std::fs;
3-
use bellman::groth16::Proof;
4-
use bls12_381::Bls12;
3+
use bellman::{Circuit, groth16};
4+
use bellman::groth16::{Proof, Parameters};
5+
use bls12_381::{Bls12, Scalar};
6+
use clap::Parser;
7+
use rand::rngs::OsRng;
8+
9+
10+
#[derive(Parser, Debug)]
11+
#[command(author, version, about, long_about = None)]
12+
pub struct BinaryArgs {
13+
#[arg(short, long)]
14+
pub input: String,
15+
16+
#[arg(short, long)]
17+
pub phase: String,
18+
19+
#[arg(short, long)]
20+
pub params: Option<String>,
21+
22+
#[arg(short, long)]
23+
pub proof: Option<String>,
24+
}
525

626
pub fn measure_size_in_bytes(proof: &Proof<Bls12>) -> usize {
727
// TODO: Should we serialize the proof in another format?
@@ -21,4 +41,44 @@ pub fn measure_size_in_bytes(proof: &Proof<Bls12>) -> usize {
2141
fs::remove_file(&temp_file_path).expect("Cannot remove temp file");
2242

2343
return size_in_mb;
44+
}
45+
46+
pub fn save_params(params_file: String, params: Parameters<Bls12>) {
47+
let mut file = File::create(&params_file).expect("Failed to create file");
48+
// Write the init_params to the file
49+
params.write(&mut file).expect("Failed to write params to file");
50+
}
51+
52+
pub fn load_params(params_file: String) -> Parameters<Bls12> {
53+
let mut file = File::open(&params_file).expect("Failed to open file");
54+
Parameters::read(&mut file, true).expect("Failed to read params from file")
55+
}
56+
57+
pub fn save_proof(proof_file: String, proof: Proof<Bls12>) {
58+
let mut file = File::create(&proof_file).expect("Failed to create file");
59+
// Write the proof to the file
60+
proof.write(&mut file).expect("Failed to write proof to file");
61+
}
62+
63+
pub fn load_proof(proof_file: String) -> Proof<Bls12> {
64+
let mut file = File::open(&proof_file).expect("Failed to open file");
65+
Proof::read(&mut file).expect("Failed to read proof from file")
66+
}
67+
68+
pub fn f_setup<C: Circuit<Scalar> + Clone>(circuit: C, params_file: String) {
69+
let params = groth16::generate_random_parameters::<Bls12, _, _>(circuit.clone(), &mut OsRng).unwrap();
70+
save_params(params_file, params);
71+
}
72+
73+
pub fn f_prove<C: Circuit<Scalar> + Clone>(circuit: C, params_file: String, proof_file: String) {
74+
let params = load_params(params_file);
75+
let proof = groth16::create_random_proof(circuit, &params, &mut OsRng).unwrap();
76+
save_proof(proof_file, proof);
77+
}
78+
79+
pub fn f_verify(params_file: String, proof_file: String, public_input: Vec<Scalar>) {
80+
let params = load_params(params_file);
81+
let pvk = groth16::prepare_verifying_key(&params.vk);
82+
let proof = load_proof(proof_file);
83+
assert!(groth16::verify_proof(&pvk, &proof, &public_input).is_ok());
2484
}
+38-70
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,55 @@
1-
// use bellman_circuits::benches::benchmark_circuit; // Assuming this is the path to the bench_proof function
21
use bellman_circuits::circuits::exponentiate;
32
use clap::{Parser};
43
use rust_utils::{
5-
get_memory,
64
read_file_contents,
7-
save_results,
85
};
9-
use bellman_utils::measure_size_in_bytes;
10-
use bellman::groth16;
6+
use bellman_utils::{BinaryArgs, f_setup, f_verify, f_prove};
117
use bellman::gadgets::multipack;
12-
use bls12_381::{Bls12, Scalar};
13-
use rand::rngs::OsRng;
8+
use bls12_381::Scalar;
149
use ff::PrimeField;
1510

16-
#[derive(Parser, Debug)]
17-
#[clap(
18-
name = "MemoryBenchExponentiate",
19-
about = "MemoryBenchExponentiate CLI is a CLI Application to Benchmark memory consumption of Exponentiate",
20-
version = "0.0.1"
21-
)]
22-
23-
struct Args {
24-
#[arg(short, long)]
25-
input: String,
26-
27-
#[arg(short, long)]
28-
output: String,
29-
}
30-
3111
fn main() {
3212
// Parse command line arguments
33-
let args = Args::parse();
13+
let args = BinaryArgs::parse();
3414

3515
// Read and parse input from the specified JSON file
3616
let input_str = read_file_contents(args.input);
3717

38-
// Get data from config
39-
let (x_64, e, y_64) = exponentiate::get_exponentiate_data(input_str);
40-
41-
// Create Scalar from some values
42-
let x = Scalar::from(x_64);
43-
let y = Scalar::from(y_64);
44-
45-
// Public inputs are x and y
46-
let x_bits = multipack::bytes_to_bits_le(&x.to_repr().as_ref());
47-
let y_bits = multipack::bytes_to_bits_le(&y.to_repr().as_ref());
48-
let inputs = [multipack::compute_multipacking(&x_bits), multipack::compute_multipacking(&y_bits)].concat();
49-
50-
51-
// Define the circuit
52-
let circuit = exponentiate::ExponentiationCircuit {
53-
x: Some(x),
54-
e: e,
55-
y: Some(y),
56-
};
57-
58-
// Get the initial memory usage
59-
let initial_rss = get_memory();
60-
61-
// Generate Parameters
62-
let params = groth16::generate_random_parameters::<Bls12, _, _>(circuit.clone(), &mut OsRng).unwrap();
63-
64-
// Prepare the verification key
65-
let pvk = groth16::prepare_verifying_key(&params.vk);
66-
67-
// Get the memory usage after setup
68-
let setup_rss = get_memory();
69-
70-
// Create a Groth16 proof with our parameters
71-
let proof = groth16::create_random_proof(circuit, &params, &mut OsRng).unwrap();
72-
73-
// Get the memory usage after proof generation
74-
let proof_rss = get_memory();
75-
76-
// Verify the proof
77-
let _ = groth16::verify_proof(&pvk, &proof, &inputs);
78-
79-
// Get the memory usage after proof verification
80-
let verify_rss = get_memory();
18+
// Get data from config
19+
let (x_64, e, y_64) = exponentiate::get_exponentiate_data(input_str);
20+
21+
// Create Scalar from some values
22+
let x = Scalar::from(x_64);
23+
let y = Scalar::from(y_64);
24+
25+
if args.phase == "setup" {
26+
let circuit = exponentiate::ExponentiationCircuit {
27+
x: Some(x),
28+
e: e,
29+
y: Some(y),
30+
};
31+
let params_file = args.params.expect("Missing params argument");
32+
f_setup(circuit, params_file);
33+
} else if args.phase == "prove" {
34+
let circuit = exponentiate::ExponentiationCircuit {
35+
x: Some(x),
36+
e: e,
37+
y: Some(y),
38+
};
39+
let params_file = args.params.expect("Missing params argument");
40+
let proof_file = args.proof.expect("Missing proof argument");
41+
f_prove(circuit, params_file, proof_file);
42+
} else if args.phase == "verify" {
43+
// Public inputs are x and y
44+
let x_bits = multipack::bytes_to_bits_le(&x.to_repr().as_ref());
45+
let y_bits = multipack::bytes_to_bits_le(&y.to_repr().as_ref());
46+
let inputs: Vec<Scalar> = [multipack::compute_multipacking(&x_bits), multipack::compute_multipacking(&y_bits)].concat();
47+
let params_file = args.params.expect("Missing params argument");
48+
let proof_file = args.proof.expect("Missing proof argument");
49+
f_verify(params_file, proof_file, inputs)
50+
} else {
51+
panic!("Invalid phase (should be setup, prove, or verify)");
52+
}
8153

82-
// Measure the proof size
83-
let proof_size = measure_size_in_bytes(&proof);
8454

85-
// Save the results
86-
save_results(initial_rss, setup_rss, proof_rss, verify_rss, proof_size, args.output);
8755
}

0 commit comments

Comments
 (0)