From da1608216468ffe56cd7c5addb40568721ae6ad3 Mon Sep 17 00:00:00 2001 From: Miki <100796045+mickeyasa@users.noreply.github.com> Date: Wed, 11 Dec 2024 19:30:08 +0200 Subject: [PATCH] Add-returnning-value-functiom-to-program (#688) ReturningValueProgram provide the ability to build a program based on a returning value function. --- icicle/include/icicle/program/program.h | 2 +- .../icicle/program/returning_value_program.h | 32 +++++ icicle/include/icicle/program/symbol.h | 43 +++++-- icicle/tests/test_field_api.cpp | 113 +++++++++++++++++- 4 files changed, 177 insertions(+), 13 deletions(-) create mode 100644 icicle/include/icicle/program/returning_value_program.h diff --git a/icicle/include/icicle/program/program.h b/icicle/include/icicle/program/program.h index 3cedb80d4..f8cf950d0 100644 --- a/icicle/include/icicle/program/program.h +++ b/icicle/include/icicle/program/program.h @@ -102,7 +102,7 @@ namespace icicle { int m_nof_constants = 0; int m_nof_intermidiates = 0; - const int get_nof_vars() const { return m_nof_parameters + m_nof_constants + m_nof_intermidiates; } + int get_nof_vars() const { return m_nof_parameters + m_nof_constants + m_nof_intermidiates; } static inline const int INST_OPCODE = 0; static inline const int INST_OPERAND1 = 1; diff --git a/icicle/include/icicle/program/returning_value_program.h b/icicle/include/icicle/program/returning_value_program.h new file mode 100644 index 000000000..9eff6515c --- /dev/null +++ b/icicle/include/icicle/program/returning_value_program.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include +#include "icicle/program/symbol.h" +#include "icicle/program/program.h" + +namespace icicle { + + /** + * @brief A class that convert the function with inputs and return value described by user into a program that can be + * executed. + */ + + template + class ReturningValueProgram : public Program + { + public: + // Generate a program based on a lambda function with multiple inputs and 1 output as a return value + ReturningValueProgram(std::function(std::vector>&)> program_func, int nof_inputs) + { + this->m_nof_parameters = nof_inputs + 1; + std::vector> program_parameters(this->m_nof_parameters); + this->set_as_inputs(program_parameters); + program_parameters[nof_inputs] = program_func(program_parameters); // place the output after the all inputs + this->generate_program(program_parameters); + } + + // Generate a program based on a PreDefinedPrograms + ReturningValueProgram(PreDefinedPrograms pre_def) : Program(pre_def) {} + }; +} // namespace icicle diff --git a/icicle/include/icicle/program/symbol.h b/icicle/include/icicle/program/symbol.h index 9e01329f5..47725d183 100644 --- a/icicle/include/icicle/program/symbol.h +++ b/icicle/include/icicle/program/symbol.h @@ -41,6 +41,7 @@ namespace icicle { // optional parameters: std::unique_ptr m_constant; // for OP_CONST: const value + int m_poly_degree; // number of multiplications so far // implementation: int m_variable_idx; // location at the intermediate variables vectors @@ -55,6 +56,7 @@ namespace icicle { : m_opcode(opcode), m_operand1(operand1), m_operand2(operand2), m_variable_idx(variable_idx), m_constant(std::move(constant)) { + update_poly_degree(); } bool is_visited(bool set_as_visit) @@ -76,6 +78,30 @@ namespace icicle { private: unsigned int m_visit_idx = 0; static inline unsigned int s_last_visit = 1; + + // update the current poly_degree based on the operands + void update_poly_degree() + { + // if one of the operand has undef poly_degree + if ((m_operand1 && m_operand1->m_poly_degree < 0) || (m_operand2 && m_operand2->m_poly_degree < 0)) { + m_poly_degree = -1; + return; + } + switch (m_opcode) { + case OP_ADD: + case OP_SUB: + m_poly_degree = std::max(m_operand1->m_poly_degree, m_operand2->m_poly_degree); + return; + case OP_MULT: + m_poly_degree = m_operand1->m_poly_degree + m_operand2->m_poly_degree; + return; + case OP_INV: + m_poly_degree = -1; // undefined + return; + default: + m_poly_degree = 0; + } + } }; /** @@ -117,7 +143,14 @@ namespace icicle { Symbol operator*(const S& operand) const { return multiply(Symbol(operand)); } Symbol operator*=(const Symbol& operand) { return assign(multiply(operand)); } Symbol operator*=(const S& operand) { return assign(multiply(Symbol(operand))); } - Symbol operator!() const { return inverse(); } + + // inverse + Symbol inverse() const + { + Symbol rv; + rv.m_operation = std::make_shared>(OP_INV, m_operation); + return rv; + } void set_as_input(int input_idx) { @@ -155,14 +188,6 @@ namespace icicle { return rv; } - // inverse - Symbol inverse() const - { - Symbol rv; - rv.m_operation = std::make_shared>(OP_INV, m_operation); - return rv; - } - std::shared_ptr> m_operation; }; diff --git a/icicle/tests/test_field_api.cpp b/icicle/tests/test_field_api.cpp index ae863832a..3414811a7 100644 --- a/icicle/tests/test_field_api.cpp +++ b/icicle/tests/test_field_api.cpp @@ -13,6 +13,7 @@ #include "icicle/program/symbol.h" #include "icicle/program/program.h" +#include "icicle/program/returning_value_program.h" #include "../../icicle/backend/cpu/include/cpu_program_executor.h" #include "test_base.h" @@ -906,6 +907,8 @@ TYPED_TEST(FieldApiTest, ntt) // define program using MlePoly = Symbol; +// define program +using MlePoly = Symbol; void lambda_multi_result(std::vector& vars) { const MlePoly& A = vars[0]; @@ -913,7 +916,8 @@ void lambda_multi_result(std::vector& vars) const MlePoly& C = vars[2]; const MlePoly& EQ = vars[3]; vars[4] = EQ * (A * B - C) + scalar_t::from(9); - vars[5] = A * B - !C; + vars[5] = A * B - C.inverse(); + vars[6] = vars[5]; } TEST_F(FieldApiTestBase, CpuProgramExecutorMultiRes) @@ -924,8 +928,9 @@ TEST_F(FieldApiTestBase, CpuProgramExecutorMultiRes) scalar_t eq = scalar_t::rand_host(); scalar_t res_0; scalar_t res_1; + scalar_t res_2; - Program program(lambda_multi_result, 6); + Program program(lambda_multi_result, 7); CpuProgramExecutor prog_exe(program); // init program @@ -935,6 +940,7 @@ TEST_F(FieldApiTestBase, CpuProgramExecutorMultiRes) prog_exe.m_variable_ptrs[3] = &eq; prog_exe.m_variable_ptrs[4] = &res_0; prog_exe.m_variable_ptrs[5] = &res_1; + prog_exe.m_variable_ptrs[6] = &res_2; // execute prog_exe.execute(); @@ -945,10 +951,111 @@ TEST_F(FieldApiTestBase, CpuProgramExecutorMultiRes) scalar_t expected_res_1 = a * b - scalar_t::inverse(c); ASSERT_EQ(res_1, expected_res_1); + ASSERT_EQ(res_2, res_1); +} + +MlePoly returning_value_func(const std::vector& inputs) +{ + const MlePoly& A = inputs[0]; + const MlePoly& B = inputs[1]; + const MlePoly& C = inputs[2]; + const MlePoly& EQ = inputs[3]; + return (EQ * (A * B - C)); +} + +TEST_F(FieldApiTestBase, CpuProgramExecutorReturningVal) +{ + // randomize input vectors + const int total_size = 100000; + auto in_a = std::make_unique(total_size); + scalar_t::rand_host_many(in_a.get(), total_size); + auto in_b = std::make_unique(total_size); + scalar_t::rand_host_many(in_b.get(), total_size); + auto in_c = std::make_unique(total_size); + scalar_t::rand_host_many(in_c.get(), total_size); + auto in_eq = std::make_unique(total_size); + scalar_t::rand_host_many(in_eq.get(), total_size); + + //----- element wise operation ---------------------- + auto out_element_wise = std::make_unique(total_size); + START_TIMER(element_wise_op) + for (int i = 0; i < 100000; ++i) { + out_element_wise[i] = in_eq[i] * (in_a[i] * in_b[i] - in_c[i]); + } + END_TIMER(element_wise_op, "Straight forward function (Element wise) time: ", true); + + //----- explicit program ---------------------- + ReturningValueProgram program_explicit(returning_value_func, 4); + + CpuProgramExecutor prog_exe_explicit(program_explicit); + auto out_explicit_program = std::make_unique(total_size); + + // init program + prog_exe_explicit.m_variable_ptrs[0] = in_a.get(); + prog_exe_explicit.m_variable_ptrs[1] = in_b.get(); + prog_exe_explicit.m_variable_ptrs[2] = in_c.get(); + prog_exe_explicit.m_variable_ptrs[3] = in_eq.get(); + prog_exe_explicit.m_variable_ptrs[4] = out_explicit_program.get(); + + // run on all vectors + START_TIMER(explicit_program) + for (int i = 0; i < total_size; ++i) { + prog_exe_explicit.execute(); + (prog_exe_explicit.m_variable_ptrs[0])++; + (prog_exe_explicit.m_variable_ptrs[1])++; + (prog_exe_explicit.m_variable_ptrs[2])++; + (prog_exe_explicit.m_variable_ptrs[3])++; + (prog_exe_explicit.m_variable_ptrs[4])++; + } + END_TIMER(explicit_program, "Explicit program executor time: ", true); + + // check correctness + ASSERT_EQ(0, memcmp(out_element_wise.get(), out_explicit_program.get(), total_size * sizeof(scalar_t))); + + //----- predefined program ---------------------- + Program predef_program(EQ_X_AB_MINUS_C); + + CpuProgramExecutor prog_exe_predef(predef_program); + auto out_predef_program = std::make_unique(total_size); + + // init program + prog_exe_predef.m_variable_ptrs[0] = in_a.get(); + prog_exe_predef.m_variable_ptrs[1] = in_b.get(); + prog_exe_predef.m_variable_ptrs[2] = in_c.get(); + prog_exe_predef.m_variable_ptrs[3] = in_eq.get(); + prog_exe_predef.m_variable_ptrs[4] = out_predef_program.get(); + + // run on all vectors + START_TIMER(predef_program) + for (int i = 0; i < total_size; ++i) { + prog_exe_predef.execute(); + (prog_exe_predef.m_variable_ptrs[0])++; + (prog_exe_predef.m_variable_ptrs[1])++; + (prog_exe_predef.m_variable_ptrs[2])++; + (prog_exe_predef.m_variable_ptrs[3])++; + (prog_exe_predef.m_variable_ptrs[4])++; + } + END_TIMER(predef_program, "Program predefined time: ", true); + + // check correctness + ASSERT_EQ(0, memcmp(out_element_wise.get(), out_predef_program.get(), total_size * sizeof(scalar_t))); + + //----- Vecops operation ---------------------- + auto config = default_vec_ops_config(); + auto out_vec_ops = std::make_unique(total_size); + + START_TIMER(vecop) + vector_mul(in_a.get(), in_b.get(), total_size, config, out_vec_ops.get()); // A * B + vector_sub(out_vec_ops.get(), in_c.get(), total_size, config, out_vec_ops.get()); // A * B - C + vector_mul(out_vec_ops.get(), in_eq.get(), total_size, config, out_vec_ops.get()); // EQ * (A * B - C) + END_TIMER(predef_program, "Vec ops time: ", true); + + // check correctness + ASSERT_EQ(0, memcmp(out_element_wise.get(), out_vec_ops.get(), total_size * sizeof(scalar_t))); } int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); -} +} \ No newline at end of file