Skip to content

Commit

Permalink
Merge pull request #16 from serengil/feat-task-1212-homomorphic-suppo…
Browse files Browse the repository at this point in the history
…rt-for-tensors

Feat task 1212 homomorphic support for tensors
  • Loading branch information
serengil authored Dec 20, 2023
2 parents 2e855c0 + 378da76 commit ba53704
Show file tree
Hide file tree
Showing 6 changed files with 407 additions and 81 deletions.
28 changes: 18 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ Even though fully homomorphic encryption (FHE) has become available in recent ti
- 📏 Generating smaller ciphertexts
- 🧠 Well-suited for memory-constrained environments
- ⚖️ Strikes a favorable balance for practical use cases
- 🔑 Supporting encryption and decryption of vectors
- 🗝️ Performing homomorphic addition, homomorphic element-wise multiplication and scalar multiplication on encrypted vectors

# Installation [![PyPI](https://img.shields.io/pypi/v/lightphe.svg)](https://pypi.org/project/lightphe/)

Expand Down Expand Up @@ -157,24 +159,30 @@ with pytest.raises(ValueError, match="Paillier is not homomorphic with respect t

However, if you tried to multiply ciphertexts with RSA, or xor ciphertexts with Goldwasser-Micali, these will be succeeded because those cryptosystems support those homomorphic operations.

# Encrypt & Decrypt Tensors
# Working with vectors

You can encrypt the output tensors of machine learning models with LightPHE.
You can encrypt the output vectors of machine learning models with LightPHE. These encrypted tensors come with homomorphic operation support.

```python
# build an additively homomorphic cryptosystem
cs = LightPHE(algorithm_name="Paillier")

# define plain tensor
tensor = [1.005, 2.005, 3.005, -4.005, 5.005]
# define plain tensors
t1 = [1.005, 2.05, -3.5, 4]
t2 = [5, 6.2, 7.002, 8.02]

# encrypt tensor
encrypted_tensors = cs.encrypt(tensor)
# encrypt tensors
c1 = cs.encrypt(t1)
c2 = cs.encrypt(t2)

# decrypt tensor
decrypted_tensors = cs.decrypt(encrypted_tensors)
# perform homomorphic addition
c3 = c1 + c2

# decrypt the addition tensor
t3 = cs.decrypt(c3)

for i, decrypted_tensor in enumerate(decrypted_tensors):
assert tensor[i] == decrypted_tensor
for i, tensor in enumerate(t3):
assert abs((t1[i] + t2[i]) - restored_tensor) < 0.5
```

# Contributing
Expand Down
85 changes: 44 additions & 41 deletions lightphe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from lightphe.models.Homomorphic import Homomorphic
from lightphe.models.Ciphertext import Ciphertext
from lightphe.models.Algorithm import Algorithm
from lightphe.models.Tensor import EncryptedTensor, EncryptedTensors
from lightphe.models.Tensor import Fraction, EncryptedTensor
from lightphe.cryptosystems.RSA import RSA
from lightphe.cryptosystems.ElGamal import ElGamal
from lightphe.cryptosystems.Paillier import Paillier
Expand All @@ -29,6 +29,7 @@ def __init__(
keys: Optional[dict] = None,
key_file: Optional[str] = None,
key_size: Optional[int] = None,
precision: int = 5,
):
"""
Build LightPHE class
Expand All @@ -39,8 +40,10 @@ def __init__(
keys (dict): optional private-public key pair
key_file (str): if keys are exported, you can load them into cryptosystem
key_size (int): key size in bits
precision (int): precision for homomorphic operations on tensors
"""
self.algorithm_name = algorithm_name
self.precision = precision

if key_file is not None:
keys = self.restore_keys(target_file=key_file)
Expand Down Expand Up @@ -105,7 +108,7 @@ def __build_cryptosystem(
raise ValueError(f"unimplemented algorithm - {algorithm_name}")
return cs

def encrypt(self, plaintext: Union[int, float, list]) -> Union[Ciphertext, EncryptedTensors]:
def encrypt(self, plaintext: Union[int, float, list]) -> Union[Ciphertext, EncryptedTensor]:
"""
Encrypt a plaintext with a built cryptosystem
Args:
Expand All @@ -126,7 +129,7 @@ def encrypt(self, plaintext: Union[int, float, list]) -> Union[Ciphertext, Encry
return Ciphertext(algorithm_name=self.algorithm_name, keys=self.cs.keys, value=ciphertext)

def decrypt(
self, ciphertext: Union[Ciphertext, EncryptedTensors]
self, ciphertext: Union[Ciphertext, EncryptedTensor]
) -> Union[int, List[int], List[float]]:
"""
Decrypt a ciphertext with a buit cryptosystem
Expand All @@ -138,51 +141,62 @@ def decrypt(
if self.cs.keys.get("private_key") is None:
raise ValueError("You must have private key to perform decryption")

if isinstance(ciphertext, EncryptedTensors):
if isinstance(ciphertext, EncryptedTensor):
# then this is encrypted tensor
return self.__decrypt_tensors(encrypted_tensor=ciphertext)

return self.cs.decrypt(ciphertext=ciphertext.value)

def __encrypt_tensors(self, tensor: list) -> EncryptedTensors:
def __encrypt_tensors(self, tensor: list) -> EncryptedTensor:
"""
Encrypt a given tensor
Args:
tensor (list of int or float)
Returns
encrypted tensor (list of encrypted tensor object)
"""
encrypted_tensor: List[EncryptedTensor] = []
encrypted_tensor: List[Fraction] = []
for m in tensor:
sign = 1 if m >= 0 else -1
# get rid of sign anyway
m = m * sign
sign_encrypted = self.cs.encrypt(plaintext=sign)
if isinstance(m, int):
dividend_encrypted = self.cs.encrypt(plaintext=m)
divisor_encrypted = self.cs.encrypt(plaintext=1)
c = EncryptedTensor(
dividend_encrypted = self.cs.encrypt(
plaintext=(m % self.cs.plaintext_modulo) * pow(10, self.precision)
)
abs_dividend_encrypted = self.cs.encrypt(
plaintext=(abs(m) % self.cs.plaintext_modulo) * pow(10, self.precision)
)
divisor_encrypted = self.cs.encrypt(plaintext=pow(10, self.precision))
c = Fraction(
dividend=dividend_encrypted,
divisor=divisor_encrypted,
sign=sign_encrypted,
abs_dividend=abs_dividend_encrypted,
sign=1 if m >= 0 else -1,
)
elif isinstance(m, float):
dividend, divisor = phe_utils.fractionize(value=m, modulo=self.cs.plaintext_modulo)
dividend, divisor = phe_utils.fractionize(
value=(m % self.cs.plaintext_modulo),
modulo=self.cs.plaintext_modulo,
precision=self.precision,
)
abs_dividend, _ = phe_utils.fractionize(
value=(abs(m) % self.cs.plaintext_modulo),
modulo=self.cs.plaintext_modulo,
precision=self.precision,
)
dividend_encrypted = self.cs.encrypt(plaintext=dividend)
abs_dividend_encrypted = self.cs.encrypt(plaintext=abs_dividend)
divisor_encrypted = self.cs.encrypt(plaintext=divisor)
c = EncryptedTensor(
c = Fraction(
dividend=dividend_encrypted,
divisor=divisor_encrypted,
sign=sign_encrypted,
abs_dividend=abs_dividend_encrypted,
sign=1 if m >= 0 else -1,
)
else:
raise ValueError(f"unimplemented type - {type(m)}")
encrypted_tensor.append(c)
return EncryptedTensors(encrypted_tensor=encrypted_tensor)
return EncryptedTensor(fractions=encrypted_tensor, cs=self.cs)

def __decrypt_tensors(
self, encrypted_tensor: EncryptedTensors
) -> Union[List[int], List[float]]:
def __decrypt_tensors(self, encrypted_tensor: EncryptedTensor) -> Union[List[int], List[float]]:
"""
Decrypt a given encrypted tensor
Args:
Expand All @@ -191,26 +205,15 @@ def __decrypt_tensors(
List of plain tensors
"""
plain_tensor = []
for c in encrypted_tensor.encrypted_tensor:
if isinstance(c, EncryptedTensor) is False:
raise ValueError("Ciphertext items must be EncryptedTensor")

encrypted_dividend = c.dividend
encrypted_divisor = c.divisor
encrypted_sign = c.sign

dividend = self.cs.decrypt(ciphertext=encrypted_dividend)
divisor = self.cs.decrypt(ciphertext=encrypted_divisor)
sign = self.cs.decrypt(ciphertext=encrypted_sign)

if sign == self.cs.plaintext_modulo - 1:
sign = -1
elif sign == 1:
sign = 1
else:
raise ValueError("this cannot be true!")

m = sign * (dividend / divisor)
for c in encrypted_tensor.fractions:
if isinstance(c, Fraction) is False:
raise ValueError("Ciphertext items must be type of Fraction")

sign = c.sign
abs_dividend = self.cs.decrypt(ciphertext=c.abs_dividend)
# dividend = self.cs.decrypt(ciphertext=c.dividend)
divisor = self.cs.decrypt(ciphertext=c.divisor)
m = sign * abs_dividend / divisor

plain_tensor.append(m)
return plain_tensor
Expand Down
35 changes: 29 additions & 6 deletions lightphe/commons/phe_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Union, Tuple
from typing import Union, Tuple, Optional
from decimal import Decimal, getcontext
from lightphe.commons.logger import Logger

logger = Logger(module="lightphe/commons/phe_utils.py")

# pylint: disable=no-else-return
# pylint: disable=no-else-return, no-else-break


def parse_int(value: Union[int, float], modulo: int) -> int:
Expand All @@ -22,10 +23,32 @@ def parse_int(value: Union[int, float], modulo: int) -> int:
return result


def fractionize(value: float, modulo: int) -> Tuple[int, int]:
decimal_places = len(str(value).split(".")[1])
scaling_factor = 10**decimal_places
integer_value = int(value * scaling_factor) % modulo
def fractionize(value: float, modulo: int, precision: Optional[int] = None) -> Tuple[int, int]:
getcontext().prec = 50

if precision is None:
decimal_places = len(str(value).split(".")[1])
scaling_factor = 10**decimal_places
else:
scaling_factor = 10**precision

while True:
integer_value = int(Decimal(value) * Decimal(scaling_factor)) % modulo

if precision is None:
break

if scaling_factor > 10**precision:
# If scaling factor is too large, discard excess part of integer value
integer_value = int(integer_value / (10 ** (scaling_factor - 10**precision)))
break
elif scaling_factor < 10 ** (precision - 1):
# If scaling factor is too small, multiply dividend and divisor 10 times
value *= 10
scaling_factor *= 10
else:
break

logger.debug(f"{integer_value}*{scaling_factor}^-1 mod {modulo}")
return integer_value, scaling_factor

Expand Down
10 changes: 1 addition & 9 deletions lightphe/models/Ciphertext.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,7 @@ def __rmul__(self, constant: Union[int, float]) -> "Ciphertext":
Returns
scalar multiplication of ciphertext
"""
if self.cs.keys.get("public_key") is None:
raise ValueError("You must have public key to perform scalar multiplication")

if isinstance(constant, float):
constant = phe_utils.parse_int(value=constant, modulo=self.cs.plaintext_modulo)

# Handle multiplication with a constant on the right
result = self.cs.multiply_by_contant(ciphertext=self.value, constant=constant)
return Ciphertext(algorithm_name=self.algorithm_name, keys=self.keys, value=result)
return self.__mul__(other=constant)

def __xor__(self, other: "Ciphertext") -> "Ciphertext":
"""
Expand Down
Loading

0 comments on commit ba53704

Please sign in to comment.