Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sumcheck frontend #692

Merged
merged 74 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
b848b0e
add sumcheck frontend
mickeyasa Dec 11, 2024
92f3579
format
mickeyasa Dec 11, 2024
d6a9c59
strange compilation err
mickeyasa Dec 11, 2024
e6672ff
added test for kickoff
mickeyasa Dec 12, 2024
bb3b1d5
small fixes to sumcheck frontend compilation
yshekel Dec 15, 2024
46b4d72
fix rust warppers
yshekel Dec 15, 2024
563694d
enable sumcheck for all fields and curves
yshekel Dec 15, 2024
ee5b5de
removed Sumcheck test from field api
mickeyasa Dec 15, 2024
aa68bb5
dummy backend added
mickeyasa Dec 16, 2024
ffae6b4
compilation fix
yshekel Dec 16, 2024
f3331cc
test fix
mickeyasa Dec 19, 2024
6cdcee8
format
mickeyasa Dec 19, 2024
9e9ca5a
changinf sumcheck without hash compilatio to warning
mickeyasa Dec 19, 2024
766dbe9
removed sumcheck file duplication
mickeyasa Dec 22, 2024
588a84c
default hash on for cargo
mickeyasa Dec 22, 2024
1a83ec2
cmake fix
mickeyasa Dec 22, 2024
e7ed244
removed commented code
mickeyasa Dec 23, 2024
7248477
format
mickeyasa Dec 23, 2024
4a22355
review
mickeyasa Dec 23, 2024
a9b6256
added comment to the MLE polynomial
mickeyasa Dec 24, 2024
b989711
review fixes
mickeyasa Dec 24, 2024
6a71b95
format
mickeyasa Dec 24, 2024
44663c1
mle polynomials is a vector of pointers
mickeyasa Dec 25, 2024
85de44d
format
mickeyasa Dec 25, 2024
01b7137
cpu backend inplementation start
mickeyasa Dec 25, 2024
d61ca9e
nackend implementation
mickeyasa Dec 31, 2024
d8abe7a
backend implementation
mickeyasa Dec 31, 2024
6550860
Merge remote-tracking branch 'origin/main' into add-sumcheck-frontend
mickeyasa Jan 12, 2025
30fc2bb
compilation start
mickeyasa Jan 12, 2025
1628b52
compile, fail on verification test
mickeyasa Jan 12, 2025
3b1165e
verification failed on round 1
mickeyasa Jan 12, 2025
698fb7e
test pass
mickeyasa Jan 12, 2025
b2b5ff1
format
mickeyasa Jan 12, 2025
3928fd9
spell check
mickeyasa Jan 12, 2025
8cc6078
test fix
mickeyasa Jan 12, 2025
a00a1d1
format
mickeyasa Jan 12, 2025
17fb2d1
adjusting sumcheck test to all fields
mickeyasa Jan 13, 2025
d7673f6
format
mickeyasa Jan 13, 2025
3cb70fa
reduce alpha feet small and large fields
mickeyasa Jan 13, 2025
07c1e85
format
mickeyasa Jan 13, 2025
9e60bb9
documentation
mickeyasa Jan 13, 2025
53dd569
format
mickeyasa Jan 13, 2025
b7d5ddc
go comp issue
mickeyasa Jan 13, 2025
d94ebec
remove prints
mickeyasa Jan 13, 2025
7b06f73
return the std::
mickeyasa Jan 13, 2025
2840f91
name fix
mickeyasa Jan 13, 2025
c22d38e
include added for compilation
mickeyasa Jan 13, 2025
48f30df
review fixes
mickeyasa Jan 13, 2025
81d870f
format
mickeyasa Jan 13, 2025
5baee05
removed OR HASH
mickeyasa Jan 13, 2025
ae57939
review fixes
mickeyasa Jan 15, 2025
7956576
format
mickeyasa Jan 15, 2025
51d6e8d
spell
mickeyasa Jan 15, 2025
5bd4e56
added use_extension_field for Sumcheckconfig
mickeyasa Jan 15, 2025
6d20856
format
mickeyasa Jan 15, 2025
9ea4eea
enlarging the sumcheck test to 8k
mickeyasa Jan 15, 2025
a9da45d
fiat shamir moved to frontend
mickeyasa Jan 19, 2025
17a659e
format
mickeyasa Jan 19, 2025
12469fb
another format
mickeyasa Jan 19, 2025
bc5e639
Fix/release script (#721)
LeonHibnik Jan 13, 2025
8447e1a
Fix bug in CPU vec ops regarding nof workers (#731)
yshekel Jan 13, 2025
db78555
Support android and vulkan (#735)
yshekel Jan 14, 2025
28cab47
Parallelize-vecop-program-execution (#736)
mickeyasa Jan 14, 2025
31dbb47
Create docs for program & program execution (#722)
idanfr-ingo Jan 14, 2025
64767a6
Feat: Blake3 (#733)
aviadingo Jan 14, 2025
be8cd4c
Bump rust crates' version
Jan 14, 2025
7f98ad6
Bump docs version
Jan 14, 2025
8a02604
Update sidebars.ts (#729)
ShaniBabayoff Jan 14, 2025
752dc93
Update documentation for v3.4 (#738)
ShaniBabayoff Jan 14, 2025
87677d0
Deprecated icicle/api headers and updated examples/docs (#740)
yshekel Jan 15, 2025
31a5efb
avoid warning
mickeyasa Jan 20, 2025
5aa4982
Merge remote-tracking branch 'origin/main' into add-sumcheck-frontend
yshekel Jan 20, 2025
ff25dc3
review fixes
mickeyasa Jan 20, 2025
15073f0
format
mickeyasa Jan 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/cuda
Submodule cuda added at f373d8
4 changes: 4 additions & 0 deletions icicle/backend/cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ if (FIELD)
if (POSEIDON2)
target_sources(icicle_field PRIVATE src/hash/cpu_poseidon2.cpp)
endif()
if(SUMCHECK)
target_sources(icicle_field PRIVATE src/field/cpu_sumcheck.cpp)
endif()
target_include_directories(icicle_field PRIVATE include)
endif() # FIELD

Expand Down Expand Up @@ -76,3 +79,4 @@ if (HASH)
target_include_directories(icicle_hash PUBLIC include)
endif()


147 changes: 147 additions & 0 deletions icicle/backend/cpu/include/cpu_sumcheck.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
#pragma once

#include <vector>
#include <functional>
#include "icicle/program/symbol.h"
#include "icicle/program/program.h"
#include "cpu_sumcheck_transcript.h"
#include "cpu_program_executor.h"
#include "icicle/backend/sumcheck_backend.h"
#include "cpu_sumcheck_transcript.h"
namespace icicle {
template <typename F>
class CpuSumcheckBackend : public SumcheckBackend<F>
{
public:
CpuSumcheckBackend() : SumcheckBackend<F>() {}

// Calculate a proof for the mle polynomials
eIcicleError get_proof(
const std::vector<F*>& mle_polynomials,
const uint64_t mle_polynomial_size,
const F& claimed_sum,
const CombineFunction<F>& combine_function,
const SumcheckTranscriptConfig<F>&& transcript_config,
const SumcheckConfig& sumcheck_config,
SumcheckProof<F>& sumcheck_proof /*out*/) override
{
if (sumcheck_config.use_extension_field) {
ICICLE_LOG_ERROR << "SumcheckConfig::use_extension_field field = true is currently unsupported";
return eIcicleError::INVALID_ARGUMENT;
}
// Allocate memory for the intermediate calculation: the folded mle polynomials
const int nof_mle_poly = mle_polynomials.size();
std::vector<F*> folded_mle_polynomials(nof_mle_poly); // folded mle_polynomials with the same format as inputs
std::vector<F> folded_mle_polynomials_values(
nof_mle_poly * mle_polynomial_size / 2); // folded_mle_polynomials data itself
// init the folded_mle_polynomials pointers
for (int mle_polynomial_idx = 0; mle_polynomial_idx < nof_mle_poly; mle_polynomial_idx++) {
folded_mle_polynomials[mle_polynomial_idx] =
&(folded_mle_polynomials_values[mle_polynomial_idx * mle_polynomial_size / 2]);
}

// Check that the size of the the proof feet the size of the mle polynomials.
const uint32_t nof_rounds = std::log2(mle_polynomial_size);

// check that the combine function has a legal polynomial degree
int combine_function_poly_degree = combine_function.get_polynomial_degee();
if (combine_function_poly_degree < 0) {
ICICLE_LOG_ERROR << "Illegal polynomial degree (" << combine_function_poly_degree
<< ") for provided combine function";
return eIcicleError::INVALID_ARGUMENT;
}

// create sumcheck_transcript for the Fiat-Shamir
const uint32_t combine_function_poly_degree_u = combine_function_poly_degree;
CpuSumcheckTranscript<F> sumcheck_transcript(
claimed_sum, nof_rounds, combine_function_poly_degree_u, std::move(transcript_config));
sumcheck_proof.init(
nof_rounds,
combine_function_poly_degree_u); // reset the sumcheck proof to accumulate the round polynomials

// generate a program executor for the combine function
CpuProgramExecutor program_executor(combine_function);

// run log2(poly_size) rounds
int cur_mle_polynomial_size = mle_polynomial_size;
for (int round_idx = 0; round_idx < nof_rounds; ++round_idx) {
// For the first round work on the input mle_polynomials, else work on the folded
const std::vector<F*>& in_mle_polynomials = (round_idx == 0) ? mle_polynomials : folded_mle_polynomials;
std::vector<F>& round_polynomial = sumcheck_proof.get_round_polynomial(round_idx);

// build round polynomial and update the proof
build_round_polynomial(in_mle_polynomials, cur_mle_polynomial_size, program_executor, round_polynomial);

// if its not the last round, calculate alpha and fold the mle polynomials
if (round_idx + 1 < nof_rounds) {
F alpha = sumcheck_transcript.get_alpha(round_polynomial);
fold_mle_polynomials(alpha, cur_mle_polynomial_size, in_mle_polynomials, folded_mle_polynomials);
}
}
return eIcicleError::SUCCESS;
}

private:
void build_round_polynomial(
const std::vector<F*>& in_mle_polynomials,
const int mle_polynomial_size,
CpuProgramExecutor<F>& program_executor,
std::vector<F>& round_polynomial)
{
// init program_executor input pointers
const int nof_polynomials = in_mle_polynomials.size();
std::vector<F> combine_func_inputs(nof_polynomials);
for (int poly_idx = 0; poly_idx < nof_polynomials; ++poly_idx) {
program_executor.m_variable_ptrs[poly_idx] = &(combine_func_inputs[poly_idx]);
}
// init m_program_executor output pointer
F combine_func_result;
program_executor.m_variable_ptrs[nof_polynomials] = &combine_func_result;

const int round_poly_size = round_polynomial.size();
for (int element_idx = 0; element_idx < mle_polynomial_size / 2; ++element_idx) {
for (int poly_idx = 0; poly_idx < nof_polynomials; ++poly_idx) {
combine_func_inputs[poly_idx] = in_mle_polynomials[poly_idx][element_idx];
}
for (int k = 0; k < round_poly_size; ++k) {
// execute the combine functions and append to the round polynomial
program_executor.execute();
round_polynomial[k] = round_polynomial[k] + combine_func_result;

// if this is not the last k
if (k + 1 < round_poly_size) {
// update the combine program inputs for the next k
for (int poly_idx = 0; poly_idx < nof_polynomials; ++poly_idx) {
combine_func_inputs[poly_idx] = combine_func_inputs[poly_idx] -
in_mle_polynomials[poly_idx][element_idx] +
in_mle_polynomials[poly_idx][element_idx + mle_polynomial_size / 2];
}
}
}
}
}

// Fold the MLE polynomials based on alpha
void fold_mle_polynomials(
const F& alpha,
int& mle_polynomial_size,
const std::vector<F*>& in_mle_polynomials, // input
std::vector<F*>& folded_mle_polynomials) // output
{
const int nof_polynomials = in_mle_polynomials.size();
const F one_minus_alpha = F::one() - alpha;
mle_polynomial_size >>= 1; // update the mle_polynomial size to /2 det to folding

// run over all elements in all polynomials
for (int element_idx = 0; element_idx < mle_polynomial_size; ++element_idx) {
// init combine_func_inputs for k=0
for (int poly_idx = 0; poly_idx < nof_polynomials; ++poly_idx) {
folded_mle_polynomials[poly_idx][element_idx] =
one_minus_alpha * in_mle_polynomials[poly_idx][element_idx] +
alpha * in_mle_polynomials[poly_idx][element_idx + mle_polynomial_size];
}
}
}
};

} // namespace icicle
128 changes: 128 additions & 0 deletions icicle/backend/cpu/include/cpu_sumcheck_transcript.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#pragma once
#include "icicle/sumcheck/sumcheck_transcript_config.h"
#include <string.h>

template <typename S>
class CpuSumcheckTranscript
{
public:
CpuSumcheckTranscript(
const S& claimed_sum,
const uint32_t mle_polynomial_size,
const uint32_t combine_function_poly_degree,
const SumcheckTranscriptConfig<S>&& transcript_config)
: m_claimed_sum(claimed_sum), m_mle_polynomial_size(mle_polynomial_size),
m_combine_function_poly_degree(combine_function_poly_degree), m_transcript_config(std::move(transcript_config))
{
m_entry_0.clear();
m_round_idx = 0;
}

// add round polynomial to the transcript
S get_alpha(const std::vector<S>& round_poly)
{
// Make sure reset was called (Internal assertion)
ICICLE_ASSERT(m_mle_polynomial_size > 0) << "mle_polynomial_size must reset with value > 0";
ICICLE_ASSERT(m_combine_function_poly_degree > 0) << "combine_function_poly_degree must reset with value > 0";

const std::vector<std::byte>& round_poly_label = m_transcript_config.get_round_poly_label();
std::vector<std::byte> hash_input;
hash_input.reserve(2048);
(m_round_idx == 0) ? build_hash_input_round_0(hash_input, round_poly)
: build_hash_input_round_i(hash_input, round_poly);

// hash hash_input and return alpha
const Hash& hasher = m_transcript_config.get_hasher();
std::vector<std::byte> hash_result(hasher.output_size());
hasher.hash(hash_input.data(), hash_input.size(), m_config, hash_result.data());
m_round_idx++;
reduce_hash_result_to_field(m_prev_alpha, hash_result);
return m_prev_alpha;
}

private:
const SumcheckTranscriptConfig<S>&& m_transcript_config; // configuration how to build the transcript
HashConfig m_config; // hash config - default
uint32_t m_round_idx; //
std::vector<std::byte> m_entry_0; //
uint32_t m_mle_polynomial_size = 0;
uint32_t m_combine_function_poly_degree = 0;
const S m_claimed_sum;
S m_prev_alpha;

// append to hash_input a stream of bytes received as chars
void append_data(std::vector<std::byte>& byte_vec, const std::vector<std::byte>& label)
{
byte_vec.insert(byte_vec.end(), label.begin(), label.end());
}

// convert a vector of bytes to a field
void reduce_hash_result_to_field(S& alpha, const std::vector<std::byte>& hash_result)
{
alpha = S::zero();
const int nof_bytes_to_copy = std::min(sizeof(alpha), hash_result.size());
memcpy(&alpha, hash_result.data(), nof_bytes_to_copy);
alpha = alpha * S::one();
}

// append an integer uint32_t to hash input
void append_u32(std::vector<std::byte>& byte_vec, const uint32_t data)
{
const std::byte* data_bytes = reinterpret_cast<const std::byte*>(&data);
byte_vec.insert(byte_vec.end(), data_bytes, data_bytes + sizeof(uint32_t));
}

// append a field to hash input
void append_field(std::vector<std::byte>& byte_vec, const S& field)
{
const std::byte* data_bytes = reinterpret_cast<const std::byte*>(field.limbs_storage.limbs);
byte_vec.insert(byte_vec.end(), data_bytes, data_bytes + sizeof(S));
}

// round 0 hash input
void build_hash_input_round_0(std::vector<std::byte>& hash_input, const std::vector<S>& round_poly)
{
const std::vector<std::byte>& round_poly_label = m_transcript_config.get_round_poly_label();
// append entry_DS = [domain_separator_label || proof.mle_polynomial_size || proof.degree || public (hardcoded?) ||
// claimed_sum]
append_data(hash_input, m_transcript_config.get_domain_separator_label());
append_u32(hash_input, m_mle_polynomial_size);
append_u32(hash_input, m_combine_function_poly_degree);
append_field(hash_input, m_claimed_sum);

// append seed_rng
append_field(hash_input, m_transcript_config.get_seed_rng());

// append round_challenge_label
append_data(hash_input, m_transcript_config.get_round_challenge_label());

// build entry_0 = [round_poly_label || r_0[x].len() || k=0 || r_0[x]]
append_data(m_entry_0, round_poly_label);
append_u32(m_entry_0, round_poly.size());
append_u32(m_entry_0, m_round_idx);
for (const S& r_i : round_poly) {
append_field(hash_input, r_i);
}

// append entry_0
append_data(hash_input, m_entry_0);
}

// round !=0 hash input
void build_hash_input_round_i(std::vector<std::byte>& hash_input, const std::vector<S>& round_poly)
{
const std::vector<std::byte>& round_poly_label = m_transcript_config.get_round_poly_label();
// entry_i = [round_poly_label || r_i[x].len() || k=i || r_i[x]]
// alpha_i = Hash(entry_0 || alpha_(i-1) || round_challenge_label || entry_i).to_field()
append_data(hash_input, m_entry_0);
append_field(hash_input, m_prev_alpha);
append_data(hash_input, m_transcript_config.get_round_challenge_label());

append_data(hash_input, round_poly_label);
append_u32(hash_input, round_poly.size());
append_u32(hash_input, m_round_idx);
for (const S& r_i : round_poly) {
append_field(hash_input, r_i);
}
}
};
17 changes: 17 additions & 0 deletions icicle/backend/cpu/src/field/cpu_sumcheck.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "icicle/backend/sumcheck_backend.h"
#include "cpu_sumcheck.h"

using namespace field_config;

namespace icicle {

template <typename F>
eIcicleError cpu_create_sumcheck_backend(const Device& device, std::shared_ptr<SumcheckBackend<F>>& backend /*OUT*/)
{
backend = std::make_shared<CpuSumcheckBackend<F>>();
return eIcicleError::SUCCESS;
}

REGISTER_SUMCHECK_FACTORY_BACKEND("CPU", cpu_create_sumcheck_backend<scalar_t>);

} // namespace icicle
16 changes: 8 additions & 8 deletions icicle/cmake/fields_and_curves.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
# Define available fields with an index and their supported features
# Format: index:field:features
set(ICICLE_FIELDS
1001:babybear:NTT,EXT_FIELD,POSEIDON,POSEIDON2
1002:stark252:NTT,POSEIDON,POSEIDON2
1003:m31:EXT_FIELD,POSEIDON,POSEIDON2
1004:koalabear:NTT,EXT_FIELD,POSEIDON,POSEIDON2
1001:babybear:NTT,EXT_FIELD,POSEIDON,POSEIDON2,SUMCHECK
1002:stark252:NTT,POSEIDON,POSEIDON2,SUMCHECK
1003:m31:EXT_FIELD,POSEIDON,POSEIDON2,SUMCHECK
1004:koalabear:NTT,EXT_FIELD,POSEIDON,POSEIDON2,SUMCHECK
)

# Define available curves with an index and their supported features
# Format: index:curve:features
set(ICICLE_CURVES
1:bn254:NTT,MSM,G2,ECNTT,POSEIDON,POSEIDON2,SUMCHECK
2:bls12_381:NTT,MSM,G2,ECNTT,POSEIDON,POSEIDON2
3:bls12_377:NTT,MSM,G2,ECNTT,POSEIDON,POSEIDON2
4:bw6_761:NTT,MSM,G2,ECNTT,POSEIDON,POSEIDON2
5:grumpkin:MSM,POSEIDON,POSEIDON2
2:bls12_381:NTT,MSM,G2,ECNTT,POSEIDON,POSEIDON2,SUMCHECK
3:bls12_377:NTT,MSM,G2,ECNTT,POSEIDON,POSEIDON2,SUMCHECK
4:bw6_761:NTT,MSM,G2,ECNTT,POSEIDON,POSEIDON2,SUMCHECK
5:grumpkin:MSM,POSEIDON,POSEIDON2,SUMCHECK
)
2 changes: 1 addition & 1 deletion icicle/cmake/target_editor.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ endfunction()
function(handle_sumcheck TARGET FEATURE_LIST)
if(SUMCHECK AND "SUMCHECK" IN_LIST FEATURE_LIST)
target_compile_definitions(${TARGET} PUBLIC SUMCHECK=${SUMCHECK})
target_sources(${TARGET} PRIVATE src/sumcheck/sumcheck_c_api.cpp)
target_sources(${TARGET} PRIVATE src/sumcheck/sumcheck.cpp src/sumcheck/sumcheck_c_api.cpp)
set(SUMCHECK ON CACHE BOOL "Enable SUMCHECK feature" FORCE)
else()
set(SUMCHECK OFF CACHE BOOL "SUMCHECK not available for this field" FORCE)
Expand Down
Loading
Loading