@@ -28,7 +28,7 @@ def generateCppArrayType(config):
28
28
29
29
30
30
def generateCppArrayViewType (config ):
31
- return f"lrc::array::GeneralArrayView<{ generateCppArrayType (config )} >"
31
+ return f"lrc::array::GeneralArrayView<{ generateCppArrayType (config )} &, lrc::Shape >"
32
32
33
33
34
34
def generateFunctionsForArray (config ):
@@ -64,7 +64,6 @@ def generateFunctionsForArray(config):
64
64
"""
65
65
)
66
66
)
67
-
68
67
methods += [
69
68
# Shape
70
69
function .Function (
@@ -111,6 +110,107 @@ def generateFunctionsForArray(config):
111
110
]
112
111
),
113
112
113
+ # Get item
114
+ function .Function (
115
+ name = "__getitem__" ,
116
+ args = [
117
+ argument .Argument (
118
+ name = "self" ,
119
+ type = generateCppArrayType (config ),
120
+ const = False ,
121
+ ref = True
122
+ ),
123
+ argument .Argument (
124
+ name = "index" ,
125
+ type = "int64_t"
126
+ )
127
+ ],
128
+ op = """
129
+ return self[index];
130
+ """
131
+ ),
132
+
133
+ # Set item (GeneralArrayView)
134
+ function .Function (
135
+ name = "__setitem__" ,
136
+ args = [
137
+ argument .Argument (
138
+ name = "self" ,
139
+ type = generateCppArrayType (config ),
140
+ const = False ,
141
+ ref = True
142
+ ),
143
+ argument .Argument (
144
+ name = "index" ,
145
+ type = "int64_t"
146
+ ),
147
+ argument .Argument (
148
+ name = "other" ,
149
+ type = generateCppArrayViewType (config ),
150
+ const = True ,
151
+ ref = True
152
+ )
153
+ ],
154
+ op = """
155
+ self[index] = other;
156
+ return self;
157
+ """
158
+ ),
159
+
160
+ # Set item (Array)
161
+ function .Function (
162
+ name = "__setitem__" ,
163
+ args = [
164
+ argument .Argument (
165
+ name = "self" ,
166
+ type = generateCppArrayType (config ),
167
+ const = False ,
168
+ ref = True
169
+ ),
170
+ argument .Argument (
171
+ name = "index" ,
172
+ type = "int64_t"
173
+ ),
174
+ argument .Argument (
175
+ name = "other" ,
176
+ type = generateCppArrayType (config ),
177
+ const = True ,
178
+ ref = True
179
+ )
180
+ ],
181
+ op = """
182
+ self[index] = other;
183
+ return self;
184
+ """
185
+ ),
186
+
187
+ # Set item (Scalar)
188
+ function .Function (
189
+ name = "__setitem__" ,
190
+ args = [
191
+ argument .Argument (
192
+ name = "self" ,
193
+ type = generateCppArrayType (config ),
194
+ const = False ,
195
+ ref = True
196
+ ),
197
+ argument .Argument (
198
+ name = "index" ,
199
+ type = "int64_t"
200
+ ),
201
+ argument .Argument (
202
+ name = "other" ,
203
+ type = config ["scalar" ],
204
+ const = True ,
205
+ ref = True
206
+ )
207
+ ],
208
+ op = """
209
+ self[index] = other;
210
+ return self;
211
+ """
212
+ ),
213
+
114
214
# Addition
115
215
function .Function (
116
216
name = "__add__" ,
@@ -133,6 +233,72 @@ def generateFunctionsForArray(config):
133
233
"""
134
234
),
135
235
236
+ # Subtraction
237
+ function .Function (
238
+ name = "__sub__" ,
239
+ args = [
240
+ argument .Argument (
241
+ name = "self" ,
242
+ type = generateCppArrayType (config ),
243
+ const = True ,
244
+ ref = True
245
+ ),
246
+ argument .Argument (
247
+ name = "other" ,
248
+ type = generateCppArrayType (config ),
249
+ const = True ,
250
+ ref = True
251
+ )
252
+ ],
253
+ op = """
254
+ return (self - other).eval();
255
+ """
256
+ ),
257
+
258
+ # Multiplication
259
+ function .Function (
260
+ name = "__mul__" ,
261
+ args = [
262
+ argument .Argument (
263
+ name = "self" ,
264
+ type = generateCppArrayType (config ),
265
+ const = True ,
266
+ ref = True
267
+ ),
268
+ argument .Argument (
269
+ name = "other" ,
270
+ type = generateCppArrayType (config ),
271
+ const = True ,
272
+ ref = True
273
+ )
274
+ ],
275
+ op = """
276
+ return (self * other).eval();
277
+ """
278
+ ),
279
+
280
+ # Division
281
+ function .Function (
282
+ name = "__div__" ,
283
+ args = [
284
+ argument .Argument (
285
+ name = "self" ,
286
+ type = generateCppArrayType (config ),
287
+ const = True ,
288
+ ref = True
289
+ ),
290
+ argument .Argument (
291
+ name = "other" ,
292
+ type = generateCppArrayType (config ),
293
+ const = True ,
294
+ ref = True
295
+ )
296
+ ],
297
+ op = """
298
+ return (self / other).eval();
299
+ """
300
+ ),
301
+
136
302
# String representation
137
303
function .Function (
138
304
name = "__str__" ,
@@ -161,7 +327,15 @@ def generateFunctionsForArray(config):
161
327
)
162
328
],
163
329
op = f"""
164
- return fmt::format("<librapid.{ config ['name' ]} ~ {{}}>", self.shape());
330
+ std::string thisStr = fmt::format("{{}}", self);
331
+ std::string padded;
332
+ for (const auto &c : thisStr) {{
333
+ padded += c;
334
+ if (c == '\\ n') {{
335
+ padded += std::string(16, ' ');
336
+ }}
337
+ }}
338
+ return fmt::format("<librapid.Array {{}} dtype={ config ['scalar' ]} backend={ config ['backend' ]} >", padded);
165
339
"""
166
340
),
167
341
0 commit comments