Skip to content

Commit 7386298

Browse files
committed
Better array printing
1 parent c0f7d97 commit 7386298

14 files changed

+589
-102
lines changed

CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ if (${SKBUILD})
8888
set(module_name "_librapid")
8989

9090
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libraries" FORCE)
91+
set(LIBRAPID_QUIET ON) # Disable warnings for a cleaner output.
9192

9293
message(STATUS "[ LIBRAPID ] Cloning PyBind11")
9394
FetchContent_Declare(

librapid/bindings/generators/arrayGenerator.py

+177-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def generateCppArrayType(config):
2828

2929

3030
def generateCppArrayViewType(config):
31-
return f"lrc::array::GeneralArrayView<{generateCppArrayType(config)}>"
31+
return f"lrc::array::GeneralArrayView<{generateCppArrayType(config)} &, lrc::Shape>"
3232

3333

3434
def generateFunctionsForArray(config):
@@ -64,7 +64,6 @@ def generateFunctionsForArray(config):
6464
"""
6565
)
6666
)
67-
6867
methods += [
6968
# Shape
7069
function.Function(
@@ -111,6 +110,107 @@ def generateFunctionsForArray(config):
111110
]
112111
),
113112

113+
# Get item
114+
function.Function(
115+
name="__getitem__",
116+
args=[
117+
argument.Argument(
118+
name="self",
119+
type=generateCppArrayType(config),
120+
const=False,
121+
ref=True
122+
),
123+
argument.Argument(
124+
name="index",
125+
type="int64_t"
126+
)
127+
],
128+
op="""
129+
return self[index];
130+
"""
131+
),
132+
133+
# Set item (GeneralArrayView)
134+
function.Function(
135+
name="__setitem__",
136+
args=[
137+
argument.Argument(
138+
name="self",
139+
type=generateCppArrayType(config),
140+
const=False,
141+
ref=True
142+
),
143+
argument.Argument(
144+
name="index",
145+
type="int64_t"
146+
),
147+
argument.Argument(
148+
name="other",
149+
type=generateCppArrayViewType(config),
150+
const=True,
151+
ref=True
152+
)
153+
],
154+
op="""
155+
self[index] = other;
156+
return self;
157+
"""
158+
),
159+
160+
# Set item (Array)
161+
function.Function(
162+
name="__setitem__",
163+
args=[
164+
argument.Argument(
165+
name="self",
166+
type=generateCppArrayType(config),
167+
const=False,
168+
ref=True
169+
),
170+
argument.Argument(
171+
name="index",
172+
type="int64_t"
173+
),
174+
argument.Argument(
175+
name="other",
176+
type=generateCppArrayType(config),
177+
const=True,
178+
ref=True
179+
)
180+
],
181+
op="""
182+
self[index] = other;
183+
return self;
184+
"""
185+
),
186+
187+
# Set item (Scalar)
188+
function.Function(
189+
name="__setitem__",
190+
args=[
191+
argument.Argument(
192+
name="self",
193+
type=generateCppArrayType(config),
194+
const=False,
195+
ref=True
196+
),
197+
argument.Argument(
198+
name="index",
199+
type="int64_t"
200+
),
201+
argument.Argument(
202+
name="other",
203+
type=config["scalar"],
204+
const=True,
205+
ref=True
206+
)
207+
],
208+
op="""
209+
self[index] = other;
210+
return self;
211+
"""
212+
),
213+
114214
# Addition
115215
function.Function(
116216
name="__add__",
@@ -133,6 +233,72 @@ def generateFunctionsForArray(config):
133233
"""
134234
),
135235

236+
# Subtraction
237+
function.Function(
238+
name="__sub__",
239+
args=[
240+
argument.Argument(
241+
name="self",
242+
type=generateCppArrayType(config),
243+
const=True,
244+
ref=True
245+
),
246+
argument.Argument(
247+
name="other",
248+
type=generateCppArrayType(config),
249+
const=True,
250+
ref=True
251+
)
252+
],
253+
op="""
254+
return (self - other).eval();
255+
"""
256+
),
257+
258+
# Multiplication
259+
function.Function(
260+
name="__mul__",
261+
args=[
262+
argument.Argument(
263+
name="self",
264+
type=generateCppArrayType(config),
265+
const=True,
266+
ref=True
267+
),
268+
argument.Argument(
269+
name="other",
270+
type=generateCppArrayType(config),
271+
const=True,
272+
ref=True
273+
)
274+
],
275+
op="""
276+
return (self * other).eval();
277+
"""
278+
),
279+
280+
# Division
281+
function.Function(
282+
name="__div__",
283+
args=[
284+
argument.Argument(
285+
name="self",
286+
type=generateCppArrayType(config),
287+
const=True,
288+
ref=True
289+
),
290+
argument.Argument(
291+
name="other",
292+
type=generateCppArrayType(config),
293+
const=True,
294+
ref=True
295+
)
296+
],
297+
op="""
298+
return (self / other).eval();
299+
"""
300+
),
301+
136302
# String representation
137303
function.Function(
138304
name="__str__",
@@ -161,7 +327,15 @@ def generateFunctionsForArray(config):
161327
)
162328
],
163329
op=f"""
164-
return fmt::format("<librapid.{config['name']} ~ {{}}>", self.shape());
330+
std::string thisStr = fmt::format("{{}}", self);
331+
std::string padded;
332+
for (const auto &c : thisStr) {{
333+
padded += c;
334+
if (c == '\\n') {{
335+
padded += std::string(16, ' ');
336+
}}
337+
}}
338+
return fmt::format("<librapid.Array {{}} dtype={config['scalar']} backend={config['backend']}>", padded);
165339
"""
166340
),
167341

0 commit comments

Comments
 (0)