-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from NillionNetwork/feature/nada-algbra-0.3.0
Feature/nada algbra 0.3.0
- Loading branch information
Showing
18 changed files
with
145 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,57 @@ | ||
"""NN activations logic""" | ||
|
||
from typing import Union | ||
import nada_algebra as na | ||
from nada_ai.nn.module import Module | ||
from nada_dsl import Integer | ||
from nada_dsl import Integer, NadaType, SecretInteger, SecretBoolean, PublicBoolean | ||
|
||
|
||
class ReLU(Module): | ||
"""ReLU layer implementation""" | ||
|
||
def forward(self, x: na.NadaArray) -> na.NadaArray: | ||
"""Forward pass. | ||
""" | ||
Forward pass. | ||
Args: | ||
x (na.NadaArray): Input array. | ||
Returns: | ||
na.NadaArray: Module output. | ||
""" | ||
dtype = type(x.item(0)) | ||
if dtype in (na.Rational, na.SecretRational): | ||
mask = x.applypyfunc( | ||
lambda a: (a > na.Rational(0)).if_else( | ||
Integer(1 << a.log_scale), Integer(0) | ||
) | ||
) | ||
mask = mask.applypyfunc(lambda a: na.SecretRational(value=a)) | ||
if x.dtype in (na.Rational, na.SecretRational): | ||
mask = x.apply(self._rational_relu) | ||
else: | ||
mask = x.applypyfunc( | ||
lambda a: (a > Integer(0)).if_else(Integer(1), Integer(0)) | ||
) | ||
result = x * mask | ||
return result | ||
mask = x.apply(self._relu) | ||
|
||
return x * mask | ||
|
||
@staticmethod | ||
def _rational_relu( | ||
value: Union[na.Rational, na.SecretRational] | ||
) -> na.SecretRational: | ||
""" | ||
Element-wise ReLU logic for rational values. | ||
Args: | ||
value (Union[na.Rational, na.SecretRational]): Input rational. | ||
Returns: | ||
na.SecretRational: ReLU output rational. | ||
""" | ||
above_zero: Union[PublicBoolean, SecretBoolean] = value > na.rational(0) | ||
return above_zero.if_else(na.rational(1), na.rational(0)) | ||
|
||
@staticmethod | ||
def _relu(value: NadaType) -> SecretInteger: | ||
""" | ||
Element-wise ReLU logic for NadaType values. | ||
Args: | ||
value (NadaType): Input nada value. | ||
Returns: | ||
SecretInteger: Output nada value. | ||
""" | ||
above_zero: Union[PublicBoolean, SecretBoolean] = value > Integer(0) | ||
return above_zero.if_else(Integer(1), Integer(0)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.