forked from rygorous/ryg_rans
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rans64.h
318 lines (267 loc) · 10.2 KB
/
rans64.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
// 64-bit rANS encoder/decoder - public domain - Fabian 'ryg' Giesen 2014
//
// This uses 64-bit states (63-bit actually) which allows renormalizing
// by writing out a whole 32 bits at a time (b=2^32) while still
// retaining good precision and allowing for high probability resolution.
//
// The only caveat is that this version requires 64-bit arithmetic; in
// particular, the encoder approximation in the bottom half requires a
// fast way to obtain the top 64 bits of an unsigned 64*64 bit product.
//
// In short, as written, this code works on 64-bit targets only!
#ifndef RANS64_HEADER
#define RANS64_HEADER
#include <stdint.h>
// Rans64Assert forwards to assert() when an "assert" macro is already defined
// at the point this header is included (e.g. <assert.h> was included first).
// Otherwise it expands to nothing, so the checks compile away.
#ifdef assert
#define Rans64Assert assert
#else
#define Rans64Assert(x)
#endif
// --------------------------------------------------------------------------
// This code needs support for 64-bit long multiplies with 128-bit result
// (or more precisely, the top 64 bits of a 128-bit result). This is not
// really portable functionality, so we need some compiler-specific hacks
// here.
#if defined(_MSC_VER)
#include <intrin.h>
// MSVC variant: returns the top 64 bits of the unsigned 64*64 bit product of
// a and b, using the __umulh intrinsic (declared via <intrin.h>).
static inline uint64_t Rans64MulHi(uint64_t a, uint64_t b)
{
return __umulh(a, b);
}
#elif defined(__GNUC__)
static inline uint64_t Rans64MulHi(uint64_t a, uint64_t b)
{
    // GCC/Clang variant: widen one operand to 128 bits so the multiply keeps
    // every product bit, then hand back only the upper 64-bit half.
    unsigned __int128 wide = (unsigned __int128)a;
    wide *= b;
    return (uint64_t)(wide >> 64);
}
#else
#error Unknown/unsupported compiler!
#endif
// --------------------------------------------------------------------------
// L ('l' in the paper) is the lower bound of our normalization interval.
// Between this and our 32-bit-aligned emission, we use 63 (not 64!) bits.
// This is done intentionally because exact reciprocals for 63-bit uints
// fit in 64-bit uints: this permits some optimizations during encoding.
#define RANS64_L (1ull << 31) // lower bound of our normalization interval: 2^31
// State for a rANS encoder or decoder: a single 64-bit integer. Yep, that's
// all there is to it. Both the Rans64Enc* and Rans64Dec* functions in this
// header operate on this type.
typedef uint64_t Rans64State;
// Puts a rANS encoder into its initial state: the state is set to RANS64_L,
// the lower bound of the normalization interval.
static inline void Rans64EncInit(Rans64State* r)
{
    const Rans64State initial = RANS64_L;
    *r = initial;
}
// Encodes a single symbol with range start "start" and frequency "freq".
// All frequencies are assumed to sum to "1 << scale_bits". Any output is
// emitted as one 32-bit word stored just below *pptr, which is updated.
//
// NOTE: With rANS, you need to encode symbols in *reverse order*, i.e. from
// end to beginning! Likewise, the output stream is written *backwards*:
// the pointer starts at the end of the output buffer and keeps decrementing.
static inline void Rans64EncPut(Rans64State* r, uint32_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits)
{
    Rans64Assert(freq != 0);

    uint64_t state = *r;

    // Renormalize (never needs to loop): once the state reaches this limit,
    // its low 32 bits are moved out to the stream.
    const uint64_t limit = ((RANS64_L >> scale_bits) << 32) * freq;
    if (state >= limit) {
        uint32_t* out = *pptr - 1;
        *out = (uint32_t)state;
        *pptr = out;
        state >>= 32;
        Rans64Assert(state < limit);
    }

    // x = C(s,x)
    const uint64_t quot = state / freq;
    const uint64_t rem = state % freq;
    *r = (quot << scale_bits) + rem + start;
}
// Flushes the rANS encoder: stores the 64-bit state as two 32-bit words just
// below the current write position (low half first, high half second) and
// moves *pptr down to point at them.
static inline void Rans64EncFlush(Rans64State* r, uint32_t** pptr)
{
    const uint64_t final_state = *r;
    uint32_t* out = *pptr - 2;

    out[0] = (uint32_t)(final_state & 0xffffffffu);
    out[1] = (uint32_t)(final_state >> 32);
    *pptr = out;
}
// Initializes a rANS decoder by reading the 64-bit state from the first two
// 32-bit words at *pptr (low half first), then advancing *pptr past them.
// Unlike the encoder, the decoder works forwards as you'd expect.
static inline void Rans64DecInit(Rans64State* r, uint32_t** pptr)
{
    const uint32_t* in = *pptr;
    const uint64_t lo = in[0];
    const uint64_t hi = in[1];

    *pptr += 2;
    *r = (hi << 32) | lo;
}
// Returns the current cumulative frequency, i.e. the low "scale_bits" bits of
// the decoder state (map it to a symbol yourself!). The state is not modified.
static inline uint32_t Rans64DecGet(Rans64State* r, uint32_t scale_bits)
{
    const uint32_t low_mask = (1u << scale_bits) - 1;
    return (uint32_t)(*r & low_mask);
}
// Advances in the bit stream by "popping" a single symbol with range start
// "start" and frequency "freq". All frequencies are assumed to sum to
// "1 << scale_bits". If the state drops below RANS64_L, one 32-bit word is
// read from *pptr (which is then advanced) to renormalize.
static inline void Rans64DecAdvance(Rans64State* r, uint32_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits)
{
    const uint64_t low_mask = (1ull << scale_bits) - 1;
    uint64_t state = *r;

    // s, x = D(x)
    state = freq * (state >> scale_bits) + (state & low_mask) - start;

    // renormalize: shift in the next input word as the new low 32 bits
    if (state < RANS64_L) {
        const uint32_t word = **pptr;
        *pptr += 1;
        state = (state << 32) | word;
        Rans64Assert(state >= RANS64_L);
    }

    *r = state;
}
// --------------------------------------------------------------------------
// That's all you need for a full encoder; below here are some utility
// functions with extra convenience or optimizations.
// Encoder symbol description
// This (admittedly odd) selection of parameters was chosen to make
// Rans64EncPutSymbol as cheap as possible. The fields are filled in by
// Rans64EncSymbolInit.
typedef struct {
uint64_t rcp_freq; // Fixed-point reciprocal frequency (multiplied with the state via Rans64MulHi)
uint32_t freq; // Symbol frequency
uint32_t bias; // Bias: start, or start + (1 << scale_bits) - 1 when freq < 2
uint32_t cmpl_freq; // Complement of frequency: (1 << scale_bits) - freq
uint32_t rcp_shift; // Reciprocal shift: right shift applied to the Rans64MulHi result
} Rans64EncSymbol;
// Decoder symbols are straightforward: just the two values that
// Rans64DecAdvance / Rans64DecAdvanceStep take as "start" and "freq".
typedef struct {
uint32_t start; // Start of range.
uint32_t freq; // Symbol frequency.
} Rans64DecSymbol;
// Initializes an encoder symbol to start "start" and frequency "freq".
//
// Precomputes the values Rans64EncPutSymbol needs so that encoding can use a
// multiply-high and a shift instead of a division. "scale_bits" must be <= 31
// and start + freq must not exceed 1 << scale_bits (both are asserted).
static inline void Rans64EncSymbolInit(Rans64EncSymbol* s, uint32_t start, uint32_t freq, uint32_t scale_bits)
{
    Rans64Assert(scale_bits <= 31);
    Rans64Assert(start <= (1u << scale_bits));
    Rans64Assert(freq <= (1u << scale_bits) - start);

    // Say M := 1 << scale_bits ("m" in the code). It is formed from an
    // unsigned 1 because scale_bits may be 31, and shifting a signed int 1
    // left by 31 overflows.
    const uint32_t m = 1u << scale_bits;

    // The original encoder does:
    //   x_new = (x/freq)*M + start + (x%freq)
    //
    // The fast encoder does (schematically):
    //   q     = mul_hi(x, rcp_freq) >> rcp_shift   (division)
    //   r     = x - q*freq                         (remainder)
    //   x_new = q*M + bias + r                     (new x)
    // plugging in r into x_new yields:
    //   x_new = bias + x + q*(M - freq)
    //        =: bias + x + q*cmpl_freq             (*)
    //
    // and we can just precompute cmpl_freq. Now we just need to
    // set up our parameters such that the original encoder and
    // the fast encoder agree.
    s->freq = freq;
    s->cmpl_freq = m - freq;
    if (freq < 2) {
        // freq=0 symbols are never valid to encode, so it doesn't matter what
        // we set our values to.
        //
        // freq=1 is tricky, since the reciprocal of 1 is 1; unfortunately,
        // our fixed-point reciprocal approximation can only multiply by values
        // smaller than 1.
        //
        // So we use the "next best thing": rcp_freq=~0, rcp_shift=0.
        // This gives:
        //   q = mul_hi(x, rcp_freq) >> rcp_shift
        //     = mul_hi(x, (1<<64) - 1) >> 0
        //     = floor(x - x/(2^64))
        //     = x - 1 if 1 <= x < 2^64
        // and we know that x>0 (x=0 is never in a valid normalization interval).
        //
        // So we now need to choose the other parameters such that
        //   x_new = x*M + start
        // plug it in:
        //     x*M + start                   (desired result)
        //   = bias + x + q*cmpl_freq        (*)
        //   = bias + x + (x - 1)*(M - 1)    (plug in q=x-1, cmpl_freq)
        //   = bias + 1 + (x - 1)*M
        //   = x*M + (bias + 1 - M)
        //
        // so we have start = bias + 1 - M, or equivalently
        //   bias = start + M - 1.
        s->rcp_freq = ~0ull;
        s->rcp_shift = 0;
        s->bias = start + m - 1;
    } else {
        // Alverson, "Integer Division using reciprocals"
        // shift=ceil(log2(freq))
        uint32_t shift = 0;
        uint64_t x0, x1, t0, t1;
        while (freq > (1u << shift))
            shift++;

        // long divide ((uint128) (1 << (shift + 63)) + freq-1) / freq
        // by splitting it into two 64:64 bit divides (this works because
        // the dividend has a simple form.)
        x0 = freq - 1;
        x1 = 1ull << (shift + 31);

        t1 = x1 / freq;
        x0 += (x1 % freq) << 32;
        t0 = x0 / freq;

        s->rcp_freq = t0 + (t1 << 32);
        s->rcp_shift = shift - 1;

        // With these values, 'q' is the correct quotient, so we
        // have bias=start.
        s->bias = start;
    }
}
// Initialize a decoder symbol to start "start" and frequency "freq".
// Both values are stored unchanged; the asserts check that start + freq does
// not exceed 1 << 31.
static inline void Rans64DecSymbolInit(Rans64DecSymbol* s, uint32_t start, uint32_t freq)
{
    // Unsigned constant on purpose: "1 << 31" overflows a signed int where
    // int is 32 bits wide.
    Rans64Assert(start <= (1u << 31));
    Rans64Assert(freq <= (1u << 31) - start);
    s->start = start;
    s->freq = freq;
}
// Encodes a given symbol. This is faster than straight Rans64EncPut since we
// can do multiplications instead of a divide.
//
// See Rans64EncSymbolInit for a description of how this works. The symbol's
// cmpl_freq and bias were computed from the scale_bits given to
// Rans64EncSymbolInit, so pass the same scale_bits here. Any output is one
// 32-bit word stored just below *pptr, which is updated.
static inline void Rans64EncPutSymbol(Rans64State* r, uint32_t** pptr, Rans64EncSymbol const* sym, uint32_t scale_bits)
{
    // Can't encode a symbol with freq=0. The check is on "freq" because
    // Rans64EncSymbol has no x_max member.
    Rans64Assert(sym->freq != 0);

    // renormalize: if the state has reached the threshold for this symbol,
    // move its low 32 bits out to the stream.
    uint64_t x = *r;
    uint64_t x_max = ((RANS64_L >> scale_bits) << 32) * sym->freq;
    if (x >= x_max) {
        *pptr -= 1;
        **pptr = (uint32_t) x;
        x >>= 32;
    }

    // x = C(s,x), with the division replaced by a multiply-high and shift
    uint64_t q = Rans64MulHi(x, sym->rcp_freq) >> sym->rcp_shift;
    *r = x + sym->bias + q * sym->cmpl_freq;
}
// Equivalent to Rans64DecAdvance, but takes the range start and frequency
// from a decoder symbol.
static inline void Rans64DecAdvanceSymbol(Rans64State* r, uint32_t** pptr, Rans64DecSymbol const* sym, uint32_t scale_bits)
{
    const uint32_t range_start = sym->start;
    const uint32_t range_freq = sym->freq;

    Rans64DecAdvance(r, pptr, range_start, range_freq, scale_bits);
}
// Advances in the bit stream by "popping" a single symbol with range start
// "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits".
// No renormalization or input happens here; Rans64DecRenorm performs the
// renormalization as a separate step.
static inline void Rans64DecAdvanceStep(Rans64State* r, uint32_t start, uint32_t freq, uint32_t scale_bits)
{
    // Build the mask in 64 bits, the same way Rans64DecAdvance does.
    uint64_t mask = (1ull << scale_bits) - 1;

    // s, x = D(x)
    uint64_t x = *r;
    *r = freq * (x >> scale_bits) + (x & mask) - start;
}
// Equivalent to Rans64DecAdvanceStep, but takes the range start and frequency
// from a decoder symbol.
static inline void Rans64DecAdvanceSymbolStep(Rans64State* r, Rans64DecSymbol const* sym, uint32_t scale_bits)
{
    const uint32_t range_start = sym->start;
    const uint32_t range_freq = sym->freq;

    Rans64DecAdvanceStep(r, range_start, range_freq, scale_bits);
}
// Renormalize the decoder state: if it has dropped below RANS64_L, shift it
// up by 32 bits, bring in the next 32-bit word from *pptr as the low half,
// and advance *pptr by one word.
static inline void Rans64DecRenorm(Rans64State* r, uint32_t** pptr)
{
    uint64_t state = *r;

    if (state < RANS64_L) {
        const uint32_t word = **pptr;
        *pptr += 1;
        state = (state << 32) | word;
        Rans64Assert(state >= RANS64_L);
    }

    *r = state;
}
#endif // RANS64_HEADER