Skip to content

Commit

Permalink
Published jaxite_ec, a JAX based library to enable TPU to accelerate …
Browse files Browse the repository at this point in the history
…Multi-Scalar Multiplication (MSM) for enabling faster Zero Knowledge Proof.

PiperOrigin-RevId: 724030943
  • Loading branch information
The Jaxite Team authored and copybara-github committed Feb 9, 2025
1 parent d537bb2 commit d670ea8
Show file tree
Hide file tree
Showing 36 changed files with 6,250 additions and 2 deletions.
94 changes: 92 additions & 2 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ load("@rules_python//python:defs.bzl", "py_library", "py_test")

package(
default_applicable_licenses = ["@jaxite//:license"],
default_visibility = ["//visibility:public"],
default_visibility = [
"//visibility:public",
],
)

package_group(
Expand Down Expand Up @@ -37,10 +39,12 @@ py_library(
),
visibility = [":internal"],
deps = [
"@jaxite_deps_gmpy2//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
# copybara: jax:pallas_lib
# copybara: jax:pallas_tpu
"@jaxite_deps_numpy//:pkg",
],
)

Expand All @@ -50,19 +54,23 @@ py_library(
srcs = ["jaxite/jaxite_lib/test_utils.py"],
deps = [
":jaxite",
"@jaxite_deps_gmpy2//:pkg",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
],
)

cpu_gpu_tpu_test(
# Test rules are below, though the source files are in subdirectories.
tpu_test(
name = "matrix_utils_test",
size = "small",
timeout = "moderate",
srcs = ["jaxite/jaxite_lib/matrix_utils_test.py"],
shard_count = 3,
deps = [
":jaxite",
# copybara: xprof_analysis_client # buildcleaner: keep
# copybara: xprof_session # buildcleaner: keep
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_hypothesis//:pkg",
Expand All @@ -80,6 +88,88 @@ tpu_test(
shard_count = 3,
deps = [
":jaxite",
# copybara: xprof_analysis_client # buildcleaner: keep
# copybara: xprof_session # buildcleaner: keep
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

tpu_test(
name = "finite_field_test",
size = "large",
timeout = "moderate",
srcs = ["jaxite_ec/finite_field_test.py"],
python_version = "PY3",
shard_count = 3,
srcs_version = "PY3ONLY",
deps = [
":jaxite",
# copybara: xprof_analysis_client # buildcleaner: keep
# copybara: xprof_session # buildcleaner: keep
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

tpu_test(
name = "msm_test",
size = "large",
timeout = "eternal",
srcs = [
"jaxite_ec/msm_test.py",
],
data = [
"jaxite_ec/test_case/t1/zprize_msm_curve_377_bases_dim_1_seed_0.csv",
"jaxite_ec/test_case/t1/zprize_msm_curve_377_res_dim_1_seed_0.csv",
"jaxite_ec/test_case/t1/zprize_msm_curve_377_scalars_dim_1_seed_0.csv",
"jaxite_ec/test_case/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv",
"jaxite_ec/test_case/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv",
"jaxite_ec/test_case/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv",
"jaxite_ec/test_case/t2/zprize_msm_curve_377_bases_dim_2_seed_0.csv",
"jaxite_ec/test_case/t2/zprize_msm_curve_377_res_dim_2_seed_0.csv",
"jaxite_ec/test_case/t2/zprize_msm_curve_377_scalars_dim_2_seed_0.csv",
"jaxite_ec/test_case/t4/zprize_msm_curve_377_bases_dim_4_seed_0.csv",
"jaxite_ec/test_case/t4/zprize_msm_curve_377_res_dim_4_seed_0.csv",
"jaxite_ec/test_case/t4/zprize_msm_curve_377_scalars_dim_4_seed_0.csv",
"jaxite_ec/test_case/t8/zprize_msm_curve_377_bases_dim_8_seed_0.csv",
"jaxite_ec/test_case/t8/zprize_msm_curve_377_res_dim_8_seed_0.csv",
"jaxite_ec/test_case/t8/zprize_msm_curve_377_scalars_dim_8_seed_0.csv",
],
python_version = "PY3",
shard_count = 3,
srcs_version = "PY3ONLY",
deps = [
":jaxite",
# copybara: xprof_analysis_client # buildcleaner: keep
# copybara: xprof_session # buildcleaner: keep
# copybara: resources
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

tpu_test(
name = "elliptic_curve_test",
size = "large",
timeout = "moderate",
srcs = ["jaxite_ec/elliptic_curve_test.py"],
python_version = "PY3",
shard_count = 3,
srcs_version = "PY3ONLY",
deps = [
":jaxite",
# copybara: xprof_analysis_client # buildcleaner: keep
# copybara: xprof_session # buildcleaner: keep
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
Expand Down
Empty file added jaxite_ec/__init__.py
Empty file.
169 changes: 169 additions & 0 deletions jaxite_ec/algorithm/big_integer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
"""Big integer classes for jaxite_ec."""

import gmpy2


class GMPBigInteger:
"""A class representing a big integer using gmpy2.
This class provides basic arithmetic operations for big integers using gmpy2.
"""

def __init__(self, value) -> None:
if isinstance(value, (int, gmpy2.mpz)):
self.value = gmpy2.mpz(value)
elif isinstance(value, GMPBigInteger):
self.value = value.value
else:
raise TypeError("Unsupported type for GMPBigInteger initialization")

def __add__(self, other):
if isinstance(other, (GMPBigInteger, int)):
return GMPBigInteger(
self.value
+ gmpy2.mpz(
other.value if isinstance(other, GMPBigInteger) else other
)
)
return NotImplemented

def __sub__(self, other):
if isinstance(other, (GMPBigInteger, int)):
return GMPBigInteger(
self.value
- gmpy2.mpz(
other.value if isinstance(other, GMPBigInteger) else other
)
)
return NotImplemented

def __mul__(self, other):
if isinstance(other, (GMPBigInteger, int)):
return GMPBigInteger(
self.value
* gmpy2.mpz(
other.value if isinstance(other, GMPBigInteger) else other
)
)
return NotImplemented

def __truediv__(self, other):
if isinstance(other, (GMPBigInteger, int)):
if (
gmpy2.mpz(other.value if isinstance(other, GMPBigInteger) else other)
== 0
):
raise ZeroDivisionError("division by zero")
return GMPBigInteger(
self.value
// gmpy2.mpz(
other.value if isinstance(other, GMPBigInteger) else other
)
)
return NotImplemented

def __mod__(self, other):
if isinstance(other, (GMPBigInteger, int)):
return GMPBigInteger(
self.value
% gmpy2.mpz(
other.value if isinstance(other, GMPBigInteger) else other
)
)
return NotImplemented

def __eq__(self, other):
if isinstance(other, (GMPBigInteger, int)):
return self.value == gmpy2.mpz(
other.value if isinstance(other, GMPBigInteger) else other
)
return NotImplemented

def __ne__(self, other):
if isinstance(other, (GMPBigInteger, int)):
return self.value != gmpy2.mpz(
other.value if isinstance(other, GMPBigInteger) else other
)
return NotImplemented

def __lt__(self, other):
if isinstance(other, (GMPBigInteger, int)):
return self.value < gmpy2.mpz(
other.value if isinstance(other, GMPBigInteger) else other
)
return NotImplemented

def __le__(self, other):
if isinstance(other, (GMPBigInteger, int)):
return self.value <= gmpy2.mpz(
other.value if isinstance(other, GMPBigInteger) else other
)
return NotImplemented

def __gt__(self, other):
if isinstance(other, (GMPBigInteger, int)):
return self.value > gmpy2.mpz(
other.value if isinstance(other, GMPBigInteger) else other
)
return NotImplemented

def __ge__(self, other):
if isinstance(other, (GMPBigInteger, int)):
return self.value >= gmpy2.mpz(
other.value if isinstance(other, GMPBigInteger) else other
)
return NotImplemented

def __pow__(self, exponent, modulus=None):
if isinstance(exponent, (GMPBigInteger, int, gmpy2.mpz)):
if isinstance(exponent, GMPBigInteger):
exponent = gmpy2.mpz(exponent.value)
if isinstance(modulus, GMPBigInteger):
modulus = gmpy2.mpz(modulus.value)
if modulus is None:
return GMPBigInteger(self.value**exponent)
else:
return GMPBigInteger(gmpy2.powmod(self.value, exponent, modulus))
else:
print(type(exponent))
raise TypeError("Exponent must be an integer")

def __lshift__(self, shift):
"""Left shift operator (<<)."""
if isinstance(shift, GMPBigInteger):
shift = shift.value
return GMPBigInteger(self.value << shift)

def __rshift__(self, shift):
"""Right shift operator (>>)."""
if isinstance(shift, GMPBigInteger):
shift = shift.value
return GMPBigInteger(self.value >> shift)

def __and__(self, other):
if isinstance(other, (GMPBigInteger, int)):
return GMPBigInteger(
self.value
& gmpy2.mpz(
other.value if isinstance(other, GMPBigInteger) else other
)
)
return NotImplemented

def ceil_log2(self):
"""Calculate the base-2 logarithm of the GMPBigInteger."""
if self.value <= 0:
raise ValueError("log2 is only defined for positive integers")
return GMPBigInteger(gmpy2.ceil(gmpy2.log2(self.value)))

def __int__(self):
return int(self.value)

def __str__(self):
return str(self.value)

def __repr__(self):
return f"GMPBigInteger({self.value})"

def hex_value_str(self) -> str:
return hex(self.value)
13 changes: 13 additions & 0 deletions jaxite_ec/algorithm/config_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Jaxite EC algorithm configuration file."""

config_BLS12_377 = {
# A small prime field for simplicity
'prime': 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001,
'order': 0x12AB655E9A2CA55660B44D1E5C37B00159AA76FED00000010A11800000000001,
'a': 0, # Coefficient a = 0
'b': 1, # Coefficient b = 1
'generator': [
0x008848DEFE740A67C8FC6225BF87FF5485951E2CAA9D41BB188282C8BD37CB5CD5481512FFCD394EEAB9B16EB21BE9EF,
0x01914A69C5102EFF1F674F5D30AFEEC4BD7FB348CA3E52D96D182AD44FB82305C2FE3D3634A9591AFD82DE55559C8EA6,
],
}
Loading

0 comments on commit d670ea8

Please sign in to comment.