Skip to content

Commit

Permalink
Merge pull request #6 from NillionNetwork/feature/nada-algbra-0.3.0
Browse files Browse the repository at this point in the history
Feature/nada algbra 0.3.0
  • Loading branch information
mathias-nillion authored Jun 5, 2024
2 parents 051d98a + b28d97a commit ba8dc6b
Show file tree
Hide file tree
Showing 18 changed files with 145 additions and 85 deletions.
2 changes: 1 addition & 1 deletion examples/complex_model/src/my_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MyOperations(nn.Module):

def forward(self, x: na.NadaArray) -> na.NadaArray:
"""Does some arbitrary operations for illustrative purposes"""
return (x * na.Rational(2)) - na.Rational(1)
return (x * na.rational(2)) - na.rational(1)


class MyModel(nn.Module):
Expand Down
15 changes: 10 additions & 5 deletions nada_ai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class ModelClient:
"""ML model client"""

def __init__(self, model: Any, state_dict: OrderedDict[str, np.ndarray]) -> None:
"""Initialization.
"""
Initialization.
Args:
model (Any): Model object to wrap around.
Expand All @@ -43,7 +44,8 @@ def __init__(self, model: Any, state_dict: OrderedDict[str, np.ndarray]) -> None

@classmethod
def from_torch(cls, model: nn.Module) -> "ModelClient":
"""Instantiates a model client from a PyTorch model.
"""
Instantiates a model client from a PyTorch model.
Args:
model (nn.Module): PyTorch nn.Module object.
Expand All @@ -56,7 +58,8 @@ def from_torch(cls, model: nn.Module) -> "ModelClient":

@classmethod
def from_sklearn(cls, model: sklearn.base.BaseEstimator) -> "ModelClient":
"""Instantiates a model client from a Sklearn estimator.
"""
Instantiates a model client from a Sklearn estimator.
Args:
model (sklearn.base.BaseEstimator): Sklearn estimator object.
Expand Down Expand Up @@ -100,7 +103,8 @@ def export_state_as_secrets(
name: str,
nada_type: _NillionType = na.SecretRational,
) -> Dict[str, _NillionType]:
"""Exports model state as a Dict of Nillion secret types.
"""
Exports model state as a Dict of Nillion secret types.
Args:
name (str): Name to be used to store state secrets in the network.
Expand All @@ -123,7 +127,8 @@ def export_state_as_secrets(
return state_secrets

def __ensure_numpy(self, array_like: Any) -> np.ndarray:
"""Ensures an array-like input is converted to a NumPy array.
"""
Ensures an array-like input is converted to a NumPy array.
Args:
array_like (Any): Some array-like input.
Expand Down
54 changes: 39 additions & 15 deletions nada_ai/nn/activations.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,57 @@
"""NN activations logic"""

from typing import Union
import nada_algebra as na
from nada_ai.nn.module import Module
from nada_dsl import Integer
from nada_dsl import Integer, NadaType, SecretInteger, SecretBoolean, PublicBoolean


class ReLU(Module):
"""ReLU layer implementation"""

def forward(self, x: na.NadaArray) -> na.NadaArray:
"""Forward pass.
"""
Forward pass.
Args:
x (na.NadaArray): Input array.
Returns:
na.NadaArray: Module output.
"""
dtype = type(x.item(0))
if dtype in (na.Rational, na.SecretRational):
mask = x.applypyfunc(
lambda a: (a > na.Rational(0)).if_else(
Integer(1 << a.log_scale), Integer(0)
)
)
mask = mask.applypyfunc(lambda a: na.SecretRational(value=a))
if x.dtype in (na.Rational, na.SecretRational):
mask = x.apply(self._rational_relu)
else:
mask = x.applypyfunc(
lambda a: (a > Integer(0)).if_else(Integer(1), Integer(0))
)
result = x * mask
return result
mask = x.apply(self._relu)

return x * mask

@staticmethod
def _rational_relu(
value: Union[na.Rational, na.SecretRational]
) -> na.SecretRational:
"""
Element-wise ReLU logic for rational values.
Args:
value (Union[na.Rational, na.SecretRational]): Input rational.
Returns:
na.SecretRational: ReLU output rational.
"""
above_zero: Union[PublicBoolean, SecretBoolean] = value > na.rational(0)
return above_zero.if_else(na.rational(1), na.rational(0))

@staticmethod
def _relu(value: NadaType) -> SecretInteger:
"""
Element-wise ReLU logic for NadaType values.
Args:
value (NadaType): Input nada value.
Returns:
SecretInteger: Output nada value.
"""
above_zero: Union[PublicBoolean, SecretBoolean] = value > Integer(0)
return above_zero.if_else(Integer(1), Integer(0))
48 changes: 29 additions & 19 deletions nada_ai/nn/layers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""NN layers logic"""

from typing import Iterable, Optional, Union
from typing import Iterable, Union
import numpy as np
import nada_algebra as na
from nada_ai.nn.module import Module
Expand All @@ -16,7 +16,8 @@ class Linear(Module):
def __init__(
self, in_features: int, out_features: int, include_bias: bool = True
) -> None:
"""Linear (or fully-connected) layer.
"""
Linear (or fully-connected) layer.
Args:
in_features (int): Number of input features.
Expand All @@ -27,7 +28,8 @@ def __init__(
self.bias = Parameter(out_features) if include_bias else None

def forward(self, x: na.NadaArray) -> na.NadaArray:
"""Forward pass.
"""
Forward pass.
Args:
x (na.NadaArray): Input array.
Expand All @@ -48,18 +50,19 @@ def __init__(
in_channels: int,
out_channels: int,
kernel_size: _ShapeLike,
padding: Optional[_ShapeLike] = 0,
stride: Optional[_ShapeLike] = 1,
padding: _ShapeLike = 0,
stride: _ShapeLike = 1,
include_bias: bool = True,
) -> None:
"""2D-convolutional operator.
"""
2D-convolutional operator.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (_ShapeLike): Size of convolution kernel.
padding (Optional[_ShapeLike]): Padding length. Defaults to 0.
stride (Optional[_ShapeLike]): Stride length. Defaults to 1.
padding (_ShapeLike, optional): Padding length. Defaults to 0.
stride (_ShapeLike, optional): Stride length. Defaults to 1.
include_bias (bool, optional): Whether or not to include a bias term. Defaults to True.
"""
if isinstance(kernel_size, int):
Expand All @@ -80,7 +83,8 @@ def __init__(
self.bias = Parameter(out_channels) if include_bias else None

def forward(self, x: na.NadaArray) -> na.NadaArray:
"""Forward pass.
"""
Forward pass.
Args:
x (na.NadaArray): Input array.
Expand All @@ -98,6 +102,7 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:
out_channels, _, kernel_rows, kernel_cols = self.weight.shape

if any(pad > 0 for pad in self.padding):
# TODO: avoid side-step to NumPy
padded_input = np.pad(
x.inner,
[
Expand Down Expand Up @@ -158,15 +163,16 @@ class AvgPool2d(Module):
def __init__(
self,
kernel_size: _ShapeLike,
stride: Optional[_ShapeLike] = None,
padding: Optional[_ShapeLike] = 0,
stride: _ShapeLike = None,
padding: _ShapeLike = 0,
) -> None:
"""2D-average pooling layer.
"""
2D-average pooling layer.
Args:
kernel_size (_ShapeLike): Size of pooling kernel.
stride (Optional[_ShapeLike]): Stride length. Defaults to the size of the pooling kernel.
padding (Optional[_ShapeLike]): Padding length. Defaults to 0.
stride (_ShapeLike, optional): Stride length. Defaults to the size of the pooling kernel.
padding (_ShapeLike, optional): Padding length. Defaults to 0.
"""
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
Expand All @@ -183,7 +189,8 @@ def __init__(
self.stride = stride

def forward(self, x: na.NadaArray) -> na.NadaArray:
"""Forward pass.
"""
Forward pass.
Args:
x (na.NadaArray): Input array.
Expand All @@ -198,9 +205,10 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:
unbatched = True

batch_size, channels, input_height, input_width = x.shape
dtype = type(x.item(0))
dtype = x.dtype

if any(pad > 0 for pad in self.padding):
# TODO: avoid side-step to NumPy
padded_input = np.pad(
x.inner,
(
Expand Down Expand Up @@ -239,7 +247,7 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:
pool_region = padded_input[b, c, start_h:end_h, start_w:end_w]

if dtype in (na.Rational, na.SecretRational):
pool_size = na.Rational(pool_region.size)
pool_size = na.rational(pool_region.size)
else:
pool_size = Integer(pool_region.size)

Expand All @@ -255,7 +263,8 @@ class Flatten(Module):
"""Flatten layer implementation"""

def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
"""Flatten operator.
"""
Flatten operator.
Args:
start_dim (int, optional): Flatten start dimension. Defaults to 1.
Expand All @@ -265,7 +274,8 @@ def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
self.end_dim = end_dim

def forward(self, x: na.NadaArray) -> na.NadaArray:
"""Forward pass.
"""
Forward pass.
Args:
x (na.NadaArray): Input array.
Expand Down
6 changes: 4 additions & 2 deletions nada_ai/nn/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ class LinearRegression(Module):
"""Linear regression implementation"""

def __init__(self, in_features: int, include_bias: bool = True) -> None:
"""Initialization.
"""
Initialization.
Args:
in_features (int): Number of input features to regression.
Expand All @@ -19,7 +20,8 @@ def __init__(self, in_features: int, include_bias: bool = True) -> None:
self.intercept = Parameter(1) if include_bias else None

def forward(self, x: na.NadaArray) -> na.NadaArray:
"""Forward pass.
"""
Forward pass.
Args:
x (na.NadaArray): Input array.
Expand Down
47 changes: 33 additions & 14 deletions nada_ai/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,33 @@ class Module(ABC):
"""Generic neural network module"""

@abstractmethod
def forward(self, x: na.NadaArray) -> na.NadaArray:
"""Forward pass"""
def forward(self, x: na.NadaArray, *args, **kwargs) -> na.NadaArray:
"""
Forward pass.
Args:
x (na.NadaArray): Input array.
Returns:
na.NadaArray: Output array.
"""
...

def __call__(self, x: na.NadaArray) -> na.NadaArray:
"""All calls get passed to forward method"""
return self.forward(x)
def __call__(self, x: na.NadaArray, *args, **kwargs) -> na.NadaArray:
"""
Proxy for forward pass.
Args:
x (na.NadaArray): Input array.
Returns:
na.NadaArray: Output array.
"""
return self.forward(x, *args, **kwargs)

def __named_parameters(self, prefix: str) -> Iterator[Tuple[str, Parameter]]:
"""Recursively generates all parameters in Module, its submodules, their submodules, etc.
"""
Recursively generates all parameters in Module, its submodules, their submodules, etc.
Args:
prefix (str): Named parameter prefix. Parameter names have a "."-delimited trace of
Expand All @@ -53,15 +70,17 @@ def __named_parameters(self, prefix: str) -> Iterator[Tuple[str, Parameter]]:
yield from value.__named_parameters(prefix=name)

def named_parameters(self) -> Iterator[Tuple[str, Parameter]]:
"""Generates all parameters in Module, its submodules, their submodules, etc.
"""
Generates all parameters in Module, its submodules, their submodules, etc.
Yields:
Iterator[Tuple[str, Parameter]]: Iterator over named parameters.
"""
yield from self.__named_parameters(prefix="")

def __numel(self) -> Iterator[int]:
"""Recursively generates number of elements in each Parameter in the module.
"""
Recursively generates number of elements in each Parameter in the module.
Yields:
Iterator[int]: Number of elements in each Parameter.
Expand All @@ -73,7 +92,8 @@ def __numel(self) -> Iterator[int]:
yield from value.__numel()

def numel(self) -> int:
"""Returns total number of elements in the module.
"""
Returns total number of elements in the module.
Returns:
int: Total number of elements.
Expand All @@ -86,16 +106,15 @@ def load_state_from_network(
party: Party,
nada_type: _NadaInteger = na.SecretRational,
) -> None:
"""Loads the model state from the Nillion network.
"""
Loads the model state from the Nillion network.
Args:
name (str): Name to be used to find state secrets in the network.
party (Party): Party that provided the model state in the network.
nada_type (_NadaInteger, optional): NadaType to interpret the state values as. Defaults to na.SecretRational.
"""
for param_name, param in self.named_parameters():
param_state = na.array(
param.shape, party, f"{name}_{param_name}", nada_type
)

state_name = f"{name}_{param_name}"
param_state = na.array(param.shape, party, state_name, nada_type)
param.load_state(param_state)
Loading

0 comments on commit ba8dc6b

Please sign in to comment.