Skip to content

Commit 4e6be3a

Browse files
feat(framework): add cython wrappers mode (#27770)
Co-authored-by: ivy-branch <[email protected]>
1 parent d755ad4 commit 4e6be3a

File tree

13 files changed

+375
-8
lines changed

13 files changed

+375
-8
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ with_time_logs/
2121
*.jpg
2222
*.jpeg
2323
*.gif
24+
*.so
2425
.hypothesis
2526
.array_api_tests_k_flag*
2627
internal_automation_tools/

ivy/__init__.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ class Node(str):
591591
nan_policy_stack = []
592592
dynamic_backend_stack = []
593593
warn_to_regex = {"all": "!.*", "ivy_only": "^(?!.*ivy).*$", "none": ".*"}
594-
594+
cython_wrappers_stack = []
595595

596596
# local
597597
import threading
@@ -741,7 +741,7 @@ class Node(str):
741741

742742
locks = {"backend_setter": threading.Lock()}
743743

744-
744+
from .wrappers import *
745745
from .func_wrapper import *
746746
from .data_classes.array import Array, add_ivy_array_instance_methods
747747
from .data_classes.array.conversions import *
@@ -776,6 +776,7 @@ class Node(str):
776776
choose_random_backend,
777777
unset_backend,
778778
)
779+
from . import wrappers
779780
from . import func_wrapper
780781
from .utils import assertions, exceptions, verbosity
781782
from .utils.backend import handler
@@ -961,6 +962,7 @@ def __deepcopy__(self, memo):
961962
"default_uint_dtype_stack": data_type.default_uint_dtype_stack,
962963
"nan_policy_stack": nan_policy_stack,
963964
"dynamic_backend_stack": dynamic_backend_stack,
965+
"cython_wrappers_stack": cython_wrappers_stack,
964966
})
965967

966968
_default_globals = copy.deepcopy(globals_vars)
@@ -1144,7 +1146,7 @@ def unset_nan_policy():
11441146
ivy.dynamic_backend = dynamic_backend_stack[-1] if dynamic_backend_stack else True
11451147

11461148

1147-
def set_dynamic_backend(flag):
1149+
def set_dynamic_backend(flag): # noqa: D209
11481150
"""Set the global dynamic backend setting to the provided flag (True or
11491151
False)"""
11501152
global dynamic_backend_stack
@@ -1166,6 +1168,37 @@ def unset_dynamic_backend():
11661168
ivy.__setattr__("dynamic_backend", flag, True)
11671169

11681170

1171+
# Cython wrappers
1172+
1173+
ivy.cython_wrappers_mode = cython_wrappers_stack[-1] if cython_wrappers_stack else False
1174+
1175+
1176+
@handle_exceptions
1177+
def set_cython_wrappers_mode(flag: bool = True) -> None:
1178+
"""Set the mode of whether to use cython wrappers for functions.
1179+
1180+
Parameter
1181+
---------
1182+
flag
1183+
boolean whether to use cython wrappers for functions
1184+
1185+
Examples
1186+
--------
1187+
>>> ivy.set_cython_wrappers_mode(False)
1188+
>>> ivy.cython_wrappers_mode
1189+
False
1190+
1191+
>>> ivy.set_cython_wrappers_mode(True)
1192+
>>> ivy.cython_wrappers_mode
1193+
True
1194+
"""
1195+
global cython_wrappers_stack
1196+
if flag not in [True, False]:
1197+
raise ValueError("cython_wrappers_mode must be a boolean value (True or False)")
1198+
cython_wrappers_stack.append(flag)
1199+
ivy.__setattr__("cython_wrappers_mode", flag, True)
1200+
1201+
11691202
# Context Managers
11701203

11711204

@@ -1438,6 +1471,7 @@ def cast_data_types(val=True):
14381471
"default_int_dtype",
14391472
"default_complex_dtype",
14401473
"default_uint_dtype",
1474+
"cython_wrappers_mode",
14411475
]
14421476

14431477

@@ -1479,7 +1513,7 @@ def set_logging_mode(self, mode):
14791513
logging.getLogger().setLevel(mode)
14801514
self.logging_mode_stack.append(mode)
14811515

1482-
def unset_logging_mode(self):
1516+
def unset_logging_mode(self): # noqa: D209
14831517
"""Remove the most recently set logging mode, returning to the previous
14841518
one."""
14851519
if len(self.logging_mode_stack) > 1:

ivy/func_wrapper.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,6 +1002,29 @@ def _to_ivy_array(x):
10021002
return _temp_asarray_wrapper
10031003

10041004

1005+
# Download compiled cython wrapper wrapper
1006+
1007+
1008+
def download_cython_wrapper_wrapper(fn: Callable) -> Callable:
1009+
@functools.wraps(fn)
1010+
def _download_cython_wrapper_wrapper(*args, **kwargs):
1011+
"""Wrap the function to download compiled cython wrapper for the
1012+
function and re-wraps it with the downloaded wrapper.
1013+
1014+
Download the compiled cython wrapper by calling
1015+
ivy.wrappers.get_wrapper(func_name: str) and then wrap the
1016+
function with the downloaded wrapper.
1017+
"""
1018+
ivy.wrappers.download_cython_wrapper(fn.__name__)
1019+
ivy.wrappers.load_one_wrapper(fn.__name__)
1020+
ivy.functional.__dict__[fn.__name__] = getattr(
1021+
ivy.wrappers, fn.__name__ + "_wrapper"
1022+
)(fn)
1023+
return ivy.functional.__dict__[fn.__name__](*args, **kwargs)
1024+
1025+
return _download_cython_wrapper_wrapper
1026+
1027+
10051028
# Functions #
10061029

10071030

@@ -1046,6 +1069,12 @@ def _wrap_function(
10461069
)
10471070
return to_wrap
10481071
if isinstance(to_wrap, FunctionType):
1072+
if ivy.cython_wrappers_mode and ivy.wrappers.wrapper_exists(to_wrap.__name__):
1073+
if to_wrap.__name__ + "_wrapper" in ivy.wrappers.__all__:
1074+
to_wrap = getattr(ivy.wrappers, to_wrap.__name__ + "_wrapper")(to_wrap)
1075+
return to_wrap
1076+
else:
1077+
return download_cython_wrapper_wrapper(to_wrap)
10491078
# set attributes
10501079
for attr in original.__dict__.keys():
10511080
# private attribute or decorator

ivy/utils/_importlib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# expected. Import these modules along with Ivy initialization, as the import logic
1313
# assumes they exist in sys.modules.
1414

15-
MODULES_TO_SKIP = ["ivy.compiler", "ivy.engines"]
15+
MODULES_TO_SKIP = ["ivy.compiler", "ivy.engines", "ivy.wrappers"]
1616

1717
IS_COMPILING_WITH_BACKEND = False
1818

ivy/utils/backend/handler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,6 @@ def set_backend(backend: str, dynamic: bool = False):
358358
global ivy_original_dict
359359
if not backend_stack:
360360
ivy_original_dict = ivy.__dict__.copy()
361-
362361
_clear_current_sub_backends()
363362
if isinstance(backend, str):
364363
temp_stack = []

ivy/wrappers/__init__.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import os
2+
import sys
3+
import glob
4+
import importlib
5+
6+
dir_path = os.path.dirname(os.path.realpath(__file__))
7+
so_files = glob.glob(dir_path + "/*.so")
8+
sys.path.append(dir_path)
9+
10+
__all__ = []
11+
12+
for so_file in so_files:
13+
# if os.path.basename(so_file) != "add.so":
14+
# continue
15+
module_name = os.path.splitext(os.path.basename(so_file))[0]
16+
17+
locals()[module_name] = importlib.import_module(module_name)
18+
19+
if module_name + "_wrapper" in locals()[module_name].__dict__.keys():
20+
locals()[module_name + "_wrapper"] = getattr(
21+
locals()[module_name], module_name + "_wrapper"
22+
)
23+
__all__.append(module_name + "_wrapper")
24+
25+
del dir_path
26+
del so_files
27+
28+
import utils
29+
from utils import *

ivy/wrappers/utils.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import os
2+
import logging
3+
import json
4+
from urllib import request
5+
import importlib
6+
import ivy
7+
8+
folder_path = os.sep.join(__file__.split(os.sep)[:-3])
9+
wrappers_path = os.path.join(folder_path, "wrappers.json")
10+
wrappers = json.loads(open(wrappers_path).read())
11+
wrapers_dir = os.path.join(folder_path, "ivy/wrappers")
12+
13+
14+
def download_cython_wrapper(func_name: str):
15+
"""Get the wrapper for the given function name."""
16+
if func_name + ".so" not in wrappers["ivy"]["functional"]:
17+
logging.warn(f"Wrapper for {func_name} not found.")
18+
return False
19+
try:
20+
response = request.urlopen(
21+
"https://raw.githubusercontent.com/unifyai"
22+
+ "/binaries/cython_wrappers/wrappers/"
23+
+ func_name
24+
+ ".so"
25+
)
26+
os.makedirs(wrapers_dir, exist_ok=True)
27+
with open(os.path.join(wrapers_dir, func_name + ".so"), "wb") as f:
28+
f.write(response.read())
29+
print("Downloaded wrapper for " + func_name)
30+
return True
31+
except request.HTTPError:
32+
logging.warn(f"Unable to download wrapper for {func_name}.")
33+
return False
34+
35+
36+
def wrapper_exists(func_name: str):
37+
"""Check if the wrapper for the given function name exists."""
38+
return func_name + ".so" in wrappers["ivy"]["functional"]
39+
40+
41+
def load_one_wrapper(func_name: str):
42+
"""Load the wrapper for the given function name."""
43+
module_name = func_name
44+
dir_path = os.path.dirname(os.path.realpath(__file__))
45+
# check if file exists
46+
if os.path.isfile(os.path.join(dir_path, module_name + ".so")):
47+
ivy.wrappers.__dict__[module_name] = importlib.import_module(module_name)
48+
ivy.wrappers.__dict__[module_name + "_wrapper"] = getattr(
49+
ivy.wrappers.__dict__[module_name], module_name + "_wrapper"
50+
)
51+
ivy.wrappers.__all__.append(module_name + "_wrapper")
52+
return True
53+
else:
54+
return False

ivy_tests/test_ivy/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,10 @@ def process_cl_flags(config) -> Dict[str, bool]:
292292
False,
293293
getopt("--with-transpile"),
294294
),
295+
"test_cython_wrapper": (
296+
getopt("--skip-cython-wrapper-testing"),
297+
getopt("--with-cython-wrapper-testing"),
298+
),
295299
}
296300

297301
# whether to skip gt testing or not
@@ -358,6 +362,8 @@ def pytest_addoption(parser):
358362
default=None,
359363
help="Print test items in my custom format",
360364
)
365+
parser.addoption("--skip-cython-wrapper-testing", action="store_true")
366+
parser.addoption("--with-cython-wrapper-testing", action="store_true")
361367

362368

363369
def pytest_collection_finish(session):

ivy_tests/test_ivy/helpers/function_testing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,11 @@ def test_function_backend_computation(
154154
test_flags.container[0] for _ in range(total_num_arrays)
155155
]
156156

157+
if test_flags.test_cython_wrapper:
158+
ivy.set_cython_wrappers_mode(True)
159+
else:
160+
ivy.set_cython_wrappers_mode(False)
161+
157162
with BackendHandler.update_backend(fw) as ivy_backend:
158163
# Update variable flags to be compatible with float dtype and with_out args
159164
test_flags.as_variable = [

ivy_tests/test_ivy/helpers/pipeline_helper.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@ class BackendHandlerMode(Enum):
1111

1212

1313
class WithBackendContext:
14-
def __init__(self, backend) -> None:
14+
def __init__(self, backend, cached=True) -> None:
1515
self.backend = backend
16+
self.cached = cached
1617

1718
def __enter__(self):
18-
return ivy.with_backend(self.backend)
19+
return ivy.with_backend(self.backend, cached=self.cached)
1920

2021
def __exit__(self, exc_type, exc_val, exc_tb):
2122
return

0 commit comments

Comments
 (0)