-
Notifications
You must be signed in to change notification settings - Fork 1
/
nfloat.t
320 lines (276 loc) · 8.36 KB
/
nfloat.t
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
-- SPDX-FileCopyrightText: 2024 René Hiemstra <[email protected]>
-- SPDX-FileCopyrightText: 2024 Torsten Keßler <[email protected]>
--
-- SPDX-License-Identifier: MIT
require "terralibext"
-- Detect the host OS by running `uname`. The pipe is closed explicitly rather
-- than left to the garbage collector, and a failure to spawn the command is
-- reported by `assert` instead of a nil-index error.
local uname
do
    local pipe = assert(io.popen("uname", "r"))
    -- "*a" reads the whole output, including the trailing newline that the
    -- comparisons below rely on.
    uname = pipe:read("*a")
    pipe:close()
end
-- Wrap FLINT without inlines
local flint = terralib.includec("flint/nfloat.h", {"-DNFLOAT_INLINES_C=1"})
local gr = terralib.includec("flint/gr.h", {"-DGR_INLINES_C=1"})
-- Link the FLINT shared library that matches the detected platform.
if uname == "Darwin\n" then
    terralib.linklibrary("libflint.dylib")
elseif uname == "Linux\n" then
    terralib.linklibrary("libflint.so")
else
    error("Not implemented for this OS.")
end
local base = require("base")
local mathfun = require("mathfuns")
local concepts = require("concepts")
-- Precisions (in bits) for which FLINT provides an nfloat<N>_struct type.
local suffix = {64, 128, 192, 256, 384, 512, 1024, 2048, 4096}
-- float_type[N]: the FLINT struct type for precision N.
local float_type = {}
-- context[N]: the Terra global holding the FLINT context for precision N.
local context = {}
for _, bits in pairs(suffix) do
    local struct_name = "nfloat" .. bits .. "_struct"
    float_type[bits] = flint[struct_name]
    -- Meta information on fixed precision floats is stored in a context.
    -- Mathematically, they represent rings.
    -- Each context lives in a global variable kept in a table, so every
    -- float type uses exactly one context.
    local ring = global(flint.gr_ctx_t)
    context[bits] = ring
    -- Memory allocated by nfloat_ctx_init is released by clean_context().
    flint.nfloat_ctx_init(ring:get(), bits, 0)
end
-- Names of one-argument math functions. For each entry FixedFloat registers
-- a wrapper around flint["nfloat_" .. name] with mathfun[name].
local unary_math = {
    "abs", "sqrt", "floor", "ceil",
    "exp", "expm1", "log", "log1p",
    "sin", "cos", "tan",
    "sinh", "cosh", "tanh",
    "gamma",
}
-- Names of two-argument math functions, wrapped the same way.
local binary_math = { "pow" }
-- Concept for FLINT-backed fixed precision floats: a fresh Terra struct named
-- "NFloat", stored on the shared `concepts` table.
concepts.NFloat = terralib.types.newstruct("NFloat")
-- NOTE(review): presumably installs the common concept machinery on the new
-- struct -- confirm against the concepts module.
concepts.Base(concepts.NFloat)
-- `precision` is declared as a trait tag here; FixedFloat(N) stores the
-- concrete bit count in the `precision` trait of each generated type.
concepts.NFloat.traits.precision = concepts.traittag
-- Macro: read the exponent of an nfloat.
-- Reinterprets the address of `value.data.head` as a pointer to uint32 and
-- loads the first 32-bit word found there.
-- NOTE(review): presumably the exponent is stored in the first header limb
-- and only its low 32 bits are needed on a little-endian host -- confirm
-- against the FLINT struct layout.
local exponent = macro(function(value)
    return `@([&uint32](&value.data.head))
end)
-- Macro: read the sign of an nfloat as an integer.
-- Loads the 32-bit word at the address of `value.data.head[1]` and yields
-- -1 when that word is odd, otherwise 1.
local sign = macro(function(value)
    return quote
        var sgnbit = @([&uint32](&value.data.head[1]))
    in
        terralib.select(sgnbit % 2 == 1, -1, 1)
    end
end)
-- Macro: read the last 64-bit limb of an nfloat's mantissa.
-- The limb count is taken from the `precision` trait of the pointee type of
-- `value`, so `value` must be a pointer to a FixedFloat; the limb at index
-- (limbs - 1) of `data.d` is returned as a uint64.
local significant_part_mantissa = macro(function(value)
    local limbs = value:gettype().type.traits.precision / 64
    return `@([&uint64](&value.data.d[limbs - 1]))
end)
--shift significant 64-bit part of mantissa
-- Converts a 64-bit mantissa limb `n` and an exponent `e` into a double.
-- The loop shifts `n` left until it becomes zero and counts the shifts in
-- `k`, i.e. k = 64 - (number of trailing zero bits of n) for nonzero n.
-- The untouched copy `res` is then shifted right, converted to double and
-- scaled by mathfun.ldexp with exponent e-k.
-- NOTE(review): when n == 0 the loop never runs and k stays 0 -- confirm
-- that callers never pass a zero limb.
local terra shiftandscale(n : uint64, e : int)
var res = n
var k = 0
while n > 0 do
k = k + 1
n = n << 1
end
return mathfun.ldexp(double(res >> 64 - k), e-k)
end
-- FixedFloat(N): Terra struct type wrapping FLINT's nfloat with N bits of
-- precision. N must be one of the values in `suffix`; otherwise an assertion
-- fails. The factory is memoized, so repeated calls with the same N yield
-- the same type.
--
-- The generated type provides:
--   * arithmetic metamethods (+, -, *, /, %, unary -) and comparisons,
--   * casts from arithmetic types (via double) and from rawstring,
--   * static methods `from`, `tostr` and `pi`,
--   * the method `truncatetodouble`,
--   * overloads registered with `mathfun` for the names in `unary_math` and
--     `binary_math`, plus fmod, round, min, max, conj, real, imag and
--     fusedmuladd.
local FixedFloat = terralib.memoize(function(N)
    local ctype = float_type[N]
    assert(ctype, "No support for precision " .. N .. " in FixedFloat")
    -- The FLINT context shared by every value of this precision.
    local ctx = context[N]:get()

    local struct nfloat {
        data: ctype
    }

    function nfloat.metamethods.__typename()
        return string.format("FixedFloat(%d)", N)
    end

    base.AbstractBase(nfloat)

    --type traits
    nfloat.traits.precision = N

    -- Create a value whose payload has been passed through nfloat_init.
    local terra new()
        var data: ctype
        flint.nfloat_init(&data, ctx)
        return nfloat {data}
    end

    -- Construct from a double.
    local terra from_double(x: double)
        var f = new()
        flint.nfloat_set_d(&f.data, x, ctx)
        return f
    end

    -- Construct from a C string.
    local terra from_str(s: rawstring)
        var f = new()
        flint.nfloat_set_str(&f.data, s, ctx)
        return f
    end

    local from = terralib.overloadedfunction("from", {from_double, from_str})

    -- Macro producing a C string representation of x. The digit count is
    -- floor(N * log10(2)).
    local to_str = macro(function(x)
        local digits = math.floor(N * (math.log(2) / math.log(10)))
        return quote
            var str: rawstring
            -- TODO: Fix memory leak
            -- defer flint.flint_free(str)
            gr.gr_get_str_n(&str, &x.data, [digits], ctx)
        in
            str
        end
    end)

    -- Casts into nfloat: arithmetic types go through from_double, pointers
    -- to int8 go through from_str. Every other request raises an error.
    function nfloat.metamethods.__cast(from, to, exp)
        if to == nfloat then
            if from:isarithmetic() then
                return `from_double(exp)
            elseif from:ispointer() and from.type == int8 then
                return `from_str(exp)
            else
                -- `from` and `to` are Terra type objects, not strings, so
                -- they are converted explicitly before concatenation.
                error("Cannot cast from " .. tostring(from)
                      .. " to " .. tostring(to))
            end
        end
        error("Unknown type")
    end

    -- Binary arithmetic operators map onto the matching FLINT routines.
    local binary = {
        __add = flint.nfloat_add,
        __mul = flint.nfloat_mul,
        __sub = flint.nfloat_sub,
        __div = flint.nfloat_div,
    }
    for key, method in pairs(binary) do
        nfloat.metamethods[key] = terra(self: nfloat, other: nfloat)
            var res = new()
            [method](&res.data, &self.data, &other.data, ctx)
            return res
        end
    end

    -- Floored modulo: value - floor(value / modulus) * modulus.
    local terra fmod(value : nfloat, modulus : nfloat)
        var tmp = new()
        flint.nfloat_div(&tmp.data, &value.data, &modulus.data, ctx)
        -- Pass the payload pointer, like every other FLINT call in this type.
        flint.nfloat_floor(&tmp.data, &tmp.data, ctx)
        flint.nfloat_mul(&tmp.data, &tmp.data, &modulus.data, ctx)
        flint.nfloat_sub(&tmp.data, &value.data, &tmp.data, ctx)
        return tmp
    end
    mathfun["fmod"]:adddefinition(fmod)

    nfloat.metamethods.__mod = terra(self: nfloat, other: nfloat)
        return fmod(self, other)
    end

    -- Unary operators.
    local unary = {
        __unm = flint.nfloat_neg,
    }
    for key, method in pairs(unary) do
        nfloat.metamethods[key] = terra(self: nfloat)
            var res = new()
            [method](&res.data, &self.data, ctx)
            return res
        end
    end

    -- Build a comparison that is true when the value nfloat_cmp writes to
    -- its result argument equals `sign`.
    local function cmp(sign)
        local terra impl(self: &ctype, other: &ctype, ctx: flint.gr_ctx_t)
            var res = 0
            flint.nfloat_cmp(&res, self, other, ctx)
            return res == sign
        end
        return impl
    end

    local boolean = {
        __eq = cmp(0),
        __lt = cmp(-1),
        __gt = cmp(1),
    }
    for key, method in pairs(boolean) do
        nfloat.metamethods[key] = terra(self: nfloat, other: nfloat)
            return [method](&self.data, &other.data, ctx)
        end
    end

    -- <= and >= are composed from the strict comparison and equality.
    nfloat.metamethods.__le = terra(self: nfloat, other: nfloat)
        return self < other or self == other
    end

    nfloat.metamethods.__ge = terra(self: nfloat, other: nfloat)
        return self > other or self == other
    end

    -- Rounds by adding 0.5 and taking the floor.
    local terra round(value : nfloat)
        value = value + 0.5
        flint.nfloat_floor(&value.data, &value.data, ctx)
        return value
    end
    mathfun["round"]:adddefinition(round)

    -- The constant pi at this precision.
    local terra pi()
        var res = new()
        flint.nfloat_pi(&res.data, ctx)
        return res
    end

    -- Approximate the value as a double, using only the sign, the exponent
    -- and the last 64-bit mantissa limb.
    terra nfloat:truncatetodouble()
        var m = significant_part_mantissa(self)
        var e = exponent(self)
        var s = sign(self)
        return s * shiftandscale(m, e)
    end

    -- Register flint.nfloat_<name> as the nfloat overload of mathfun.<name>.
    -- The result starts from new() so it is never returned uninitialised if
    -- the FLINT routine does not write its output.
    for _, func in pairs(unary_math) do
        local name = "nfloat_" .. func
        local terra impl(x: nfloat)
            var y = new()
            flint.[name](&y.data, &x.data, ctx)
            return y
        end
        mathfun[func]:adddefinition(impl)
    end

    for _, func in pairs(binary_math) do
        local name = "nfloat_" .. func
        local terra impl(x: nfloat, y: nfloat)
            var z = new()
            flint.[name](&z.data, &x.data, &y.data, ctx)
            return z
        end
        mathfun[func]:adddefinition(impl)
    end

    mathfun.min:adddefinition(terra(x : nfloat, y : nfloat)
        return terralib.select(x < y, x, y)
    end)

    mathfun.max:adddefinition(terra(x : nfloat, y : nfloat)
        return terralib.select(x > y, x, y)
    end)

    -- nfloat is real-valued: conj and real are the identity, imag is zero.
    mathfun.conj:adddefinition(terra(x: nfloat) return x end)
    mathfun.real:adddefinition(terra(x: nfloat) return x end)
    mathfun.imag:adddefinition(terra(x: nfloat) return [nfloat](0) end)

    do
        -- Computed as a separate multiply and add.
        local terra impl(x: nfloat, y: nfloat, z: nfloat)
            return x * y + z
        end
        mathfun.fusedmuladd:adddefinition(impl)
    end

    for k, v in pairs({from = from, tostr = to_str, pi = pi}) do
        nfloat.staticmethods[k] = v
    end

    -- Mark the new type as a friend of each listed concept.
    for _, C in pairs({"NFloat", "Real", "Float", "Number"}) do
        concepts[C].friends[nfloat] = true
    end

    return nfloat
end)
-- Clears every FLINT context created at load time. The escape block emits
-- one gr_ctx_clear call per entry of `suffix`.
local terra clean_context()
    escape
        for _, bits in pairs(suffix) do
            local ring = context[bits]:get()
            emit `gr.gr_ctx_clear(ring)
        end
    end
end
-- Module exports: the memoized type factory and the routine that clears the
-- FLINT contexts created when this module was loaded.
return {
FixedFloat = FixedFloat,
clean_context = clean_context
}