Skip to content

Commit c0f7d97

Browse files
committed
General Array View
1 parent 53f3232 commit c0f7d97

File tree

5 files changed

+9096
-10687
lines changed

5 files changed

+9096
-10687
lines changed

librapid/bindings/generators/argument.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(self, *args, **kwargs):
2121
- default
2222
- const
2323
- ref
24+
- move
2425
- pointer
2526
- noConvert
2627
- returnPolicy
@@ -34,6 +35,7 @@ def __init__(self, *args, **kwargs):
3435
self.default = kwargs.get("default", None)
3536
self.const = kwargs.get("const", True)
3637
self.ref = kwargs.get("ref", True)
38+
self.move = kwargs.get("move", False)
3739
self.pointer = kwargs.get("pointer", False)
3840
self.noConvert = kwargs.get("noConvert", False)
3941

@@ -48,9 +50,11 @@ def __init__(self, *args, **kwargs):
4850
self.const = args[i]
4951
elif i == 4 and self.ref is None:
5052
self.ref = args[i]
51-
elif i == 5 and self.pointer is None:
53+
elif i == 5 and self.move is None:
54+
self.move = args[i]
55+
elif i == 6 and self.pointer is None:
5256
self.pointer = args[i]
53-
elif i == 6 and self.noConvert is None:
57+
elif i == 7 and self.noConvert is None:
5458
self.noConvert = args[i]
5559
else:
5660
raise ValueError("Too many arguments")
@@ -71,7 +75,7 @@ def param(self):
7175
return f"py::kwargs kwargs"
7276
else:
7377
isPrimitiveType = isPrimitive(self.type)
74-
return f"{'const ' if self.const and not isPrimitiveType else ''}{self.type} {'&' if self.ref and not isPrimitiveType else ''}{'*' if self.pointer else ''}{self.name}"
78+
return f"{'const ' if self.const and not isPrimitiveType else ''}{self.type} {'&' if self.ref and not isPrimitiveType else ''}{"&&" if self.move and not self.ref else ""}{'*' if self.pointer else ''}{self.name}"
7579

7680
def declaration(self):
7781
if self.default is None:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
import argument
2+
import function
3+
import class_
4+
import module
5+
import file
6+
7+
import itertools
8+
9+
# The set of Array types we support in Python
10+
arrayTypes = []
11+
12+
for scalar in [("int32_t", "Int32"),
13+
("int64_t", "Int64"),
14+
("float", "Float"),
15+
("double", "Double"),
16+
("lrc::Complex<float>", "ComplexFloat"),
17+
("lrc::Complex<double>", "ComplexDouble")]:
18+
for backend in ["CPU"]: # ["CPU", "OpenCL", "CUDA"]:
19+
arrayTypes.append({
20+
"scalar": scalar[0],
21+
"backend": backend,
22+
"name": f"GeneralArrayView{scalar[1]}{backend}"
23+
})
24+
25+
26+
def generateCppArrayType(config):
27+
return f"lrc::Array<{config['scalar']}, lrc::backend::{config['backend']}>"
28+
29+
30+
def generateCppArrayViewType(config):
31+
return f"lrc::array::GeneralArrayView<{generateCppArrayType(config)} &>"
32+
33+
34+
def generateFunctionsForGeneralArrayView(config):
35+
methods = [
36+
# From an existing Array type
37+
# function.Function(
38+
# name="__init__",
39+
# args=[
40+
# argument.Argument(
41+
# name="array",
42+
# type=generateCppArrayType(config),
43+
# const=False,
44+
# ref=False,
45+
# move=True
46+
# )
47+
# ],
48+
# op=f"""
49+
# return lrc::createGeneralArrayView(array);
50+
# """
51+
# ),
52+
53+
# Reference an existing GeneralArrayView
54+
# function.Function(
55+
# name="__init__",
56+
# args=[
57+
# argument.Argument(
58+
# name="arrView",
59+
# type=generateCppArrayViewType(config),
60+
# const=True,
61+
# ref=True
62+
# )
63+
# ],
64+
# op=f"""
65+
# return {generateCppArrayViewType(config)}(arrView);
66+
# """
67+
# ),
68+
69+
# Create a new GeneralArrayView
70+
function.Function(
71+
name="createFromArray",
72+
args=[
73+
argument.Argument(
74+
name="array",
75+
type=generateCppArrayType(config),
76+
const=False,
77+
ref=True
78+
)
79+
],
80+
op=f"""
81+
return lrc::createGeneralArrayView(array);
82+
""",
83+
static=True
84+
),
85+
86+
# Addition
87+
function.Function(
88+
name="__add__",
89+
args=[
90+
argument.Argument(
91+
name="self",
92+
type=generateCppArrayViewType(config),
93+
const=True,
94+
ref=True
95+
),
96+
argument.Argument(
97+
name="other",
98+
type=generateCppArrayViewType(config),
99+
const=True,
100+
ref=True
101+
)
102+
],
103+
op="""
104+
return (self + other).eval();
105+
"""
106+
),
107+
108+
# String representation
109+
function.Function(
110+
name="__str__",
111+
args=[
112+
argument.Argument(
113+
name="self",
114+
type=generateCppArrayViewType(config),
115+
const=True,
116+
ref=True
117+
)
118+
],
119+
op="""
120+
return fmt::format("{}", self);
121+
"""
122+
),
123+
124+
# String representation
125+
function.Function(
126+
name="__repr__",
127+
args=[
128+
argument.Argument(
129+
name="self",
130+
type=generateCppArrayViewType(config),
131+
const=True,
132+
ref=True
133+
)
134+
],
135+
op=f"""
136+
return fmt::format("<librapid.{config['name']} ~ {{}}>", self.shape());
137+
"""
138+
),
139+
140+
# Format (__format__)
141+
function.Function(
142+
name="__format__",
143+
args=[
144+
argument.Argument(
145+
name="self",
146+
type=generateCppArrayViewType(config),
147+
const=True,
148+
ref=True
149+
),
150+
argument.Argument(
151+
name="formatSpec",
152+
type="std::string",
153+
const=True,
154+
ref=True
155+
)
156+
],
157+
op="""
158+
std::string format = fmt::format("{{:{}}}", formatSpec);
159+
return fmt::format(fmt::runtime(format), self);
160+
"""
161+
)
162+
]
163+
164+
return methods, []
165+
166+
167+
def generateGeneralArrayViewModule(config):
168+
generalArrayViewClass = class_.Class(
169+
name=config["name"],
170+
type=generateCppArrayViewType(config)
171+
)
172+
173+
methods, functions = generateFunctionsForGeneralArrayView(config)
174+
generalArrayViewClass.functions.extend(methods)
175+
176+
includeGuard = None
177+
if config["backend"] == "CUDA":
178+
includeGuard = "defined(LIBRAPID_HAS_CUDA)"
179+
elif config["backend"] == "OpenCL":
180+
includeGuard = "defined(LIBRAPID_HAS_OPENCL)"
181+
182+
generalArrayViewModule = module.Module(
183+
name=f"librapid.GeneralArrayView.{config['name']}",
184+
includeGuard=includeGuard
185+
)
186+
generalArrayViewModule.addClass(generalArrayViewClass)
187+
generalArrayViewModule.functions.extend(functions)
188+
189+
return generalArrayViewModule
190+
191+
192+
def writeGeneralArrayView(root, config):
193+
fileType = file.File(
194+
path=f"{root}/GeneralArrayView_{config['name']}.cpp"
195+
)
196+
197+
fileType.modules.append(generateGeneralArrayViewModule(config))
198+
199+
interfaceFunctions = fileType.write()
200+
# Run clang-format if possible
201+
try:
202+
import subprocess
203+
204+
subprocess.run(["clang-format", "-i", fileType.path])
205+
except Exception as e:
206+
print("Unable to run clang-format:", e)
207+
208+
return interfaceFunctions
209+
210+
211+
def write(root):
212+
interfaces = []
213+
for config in arrayTypes:
214+
interfaces.extend(writeGeneralArrayView(root, config))
215+
return interfaces

librapid/bindings/generators/main.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33

44
import shapeGenerator
55
import arrayGenerator
6+
import generalArrayViewGenerator
67

78
outputDir = "../python/generated"
89

910
boilerplate = textwrap.dedent(f"""
1011
#pragma once
1112
12-
#define LIBRAPID_ASSERT
13+
#ifndef LIBRAPID_DEBUG
14+
#define LIBRAPID_DEBUG
15+
#endif
1316
1417
#include <librapid/librapid.hpp>
1518
#include <pybind11/pybind11.h>
@@ -30,6 +33,7 @@ def main():
3033

3134
interfaceFunctions += shapeGenerator.write(outputDir)
3235
interfaceFunctions += arrayGenerator.write(outputDir)
36+
interfaceFunctions += generalArrayViewGenerator.write(outputDir)
3337

3438
with open(f"{outputDir}/librapidPython.hpp", "w") as f:
3539
f.write(boilerplate)

librapid/include/librapid/array/assignOps.hpp

-2
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,6 @@ namespace librapid {
266266
*/
267267

268268
#if defined(LIBRAPID_HAS_OPENCL)
269-
270269
namespace opencl {
271270
template<typename T, typename std::enable_if_t<typetraits::TypeInfo<T>::type !=
272271
::librapid::detail::LibRapidType::Scalar,
@@ -347,7 +346,6 @@ namespace librapid {
347346
#endif // LIBRAPID_HAS_OPENCL
348347

349348
#if defined(LIBRAPID_HAS_CUDA)
350-
351349
namespace cuda {
352350
template<typename T, typename std::enable_if_t<typetraits::TypeInfo<T>::type !=
353351
::librapid::detail::LibRapidType::Scalar,

0 commit comments

Comments
 (0)