Skip to content

Commit

Permalink
Dumb but fast writer (#120)
Browse files Browse the repository at this point in the history
The same performance as "Bool writer like reader #118"
  • Loading branch information
Melirius authored Nov 25, 2024
1 parent 06b5d7b commit 2545e7e
Showing 1 changed file with 50 additions and 64 deletions.
114 changes: 50 additions & 64 deletions src/structs/vpx_bool_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ use crate::metrics::{Metrics, ModelComponent};
use crate::structs::branch::Branch;
use crate::structs::simple_hash::SimpleHash;

// MAX_STREAM_BITS should be a multiple of 8 larger than 8,
// and (MAX_STREAM_BITS + 1 bit of carry + 1 bit of divider)
// should fit into 64 bits of `low_value`
const MAX_STREAM_BITS: i32 = 56; //48; //40;// 32;// 24;// 16;//

pub struct VPXBoolWriter<W> {
low_value: u64,
range: u32,
Expand Down Expand Up @@ -68,23 +63,23 @@ impl<W: Write> VPXBoolWriter<W> {
#[inline(always)]
pub fn put(
&mut self,
value: bool,
bit: bool,
branch: &mut Branch,
tmp_value: &mut u64,
tmp_range: &mut u32,
mut tmp_value: u64,
mut tmp_range: u32,
_cmp: ModelComponent,
) -> Result<()> {
) -> (u64, u32) {
#[cfg(feature = "detailed_tracing")]
{
// used to detect divergences between the C++ and rust versions
self.hash.hash(branch.get_u64());
self.hash.hash(*tmp_value);
self.hash.hash(*tmp_range);
self.hash.hash(tmp_value);
self.hash.hash(tmp_range);

let hashed_value = self.hash.get();
//if hashedValue == 0xe35c28fd
{
print!("({0}:{1:x})", value as u8, hashed_value);
print!("({0}:{1:x})", bit as u8, hashed_value);
if hashed_value % 8 == 0 {
println!();
}
Expand All @@ -93,52 +88,48 @@ impl<W: Write> VPXBoolWriter<W> {

let probability = branch.get_probability() as u32;

let split = 1 + (((*tmp_range - 1) * probability) >> 8);
let split = 1 + (((tmp_range - 1) * probability) >> 8);

let mut shift;
branch.record_and_update_bit(value);
branch.record_and_update_bit(bit);

if value {
*tmp_value += split as u64;
*tmp_range -= split;
if bit {
tmp_value += split as u64;
tmp_range -= split;
} else {
*tmp_range = split;
tmp_range = split;
}

shift = (*tmp_range as u8).leading_zeros() as i32;
let shift = (tmp_range as u8).leading_zeros();

#[cfg(feature = "compression_stats")]
{
self.model_statistics
.record_compression_stats(_cmp, 1, i64::from(shift));
}

*tmp_range <<= shift;
tmp_range <<= shift;
tmp_value <<= shift;

// check whether we have more than MAX_STREAM_BITS stream bits after shift
let stream_bits = 64 - (*tmp_value).leading_zeros() as i32 - 2;
let count = shift + stream_bits - MAX_STREAM_BITS;
if count >= 0 {
// check carry
*tmp_value <<= MAX_STREAM_BITS - stream_bits;
if (*tmp_value & (1 << MAX_STREAM_BITS)) != 0 {
// check whether we cannot put next bit into stream
if tmp_value & (u64::MAX << 57) != 0 {
let mut stream_bits = 64 - tmp_value.leading_zeros() - 2;
// 62 >= stream_bits >= 56

if tmp_value & (1 << stream_bits) != 0 {
self.carry();
}
// write all full bytes
let mut sh = MAX_STREAM_BITS - 8;
while sh > 0 {
self.buffer.push((*tmp_value >> sh) as u8);
sh -= 8;

for _stream_bytes in 0..6 {
stream_bits -= 8;
self.buffer.push((tmp_value >> stream_bits) as u8);
}
*tmp_value &= (1 << 8) - 1; // exclude written bytes
*tmp_value |= 1 << 9; // restore divider bit

shift = count;
tmp_value &= (1 << stream_bits) - 1;
tmp_value |= 1 << (stream_bits + 1);
// 14 >= stream_bits >= 8
}

*tmp_value <<= shift;

Ok(())
// 55 >= stream_bits >= 8
(tmp_value, tmp_range)
}

/// Safe as: at the stream beginning initially put `false` ensure that carry cannot get out
Expand Down Expand Up @@ -176,13 +167,13 @@ impl<W: Write> VPXBoolWriter<W> {

loop {
let cur_bit = (v & (1 << index)) != 0;
self.put(
(tmp_value, tmp_range) = self.put(
cur_bit,
&mut branches[serialized_so_far],
&mut tmp_value,
&mut tmp_range,
tmp_value,
tmp_range,
cmp,
)?;
);

if index == 0 {
break;
Expand Down Expand Up @@ -213,13 +204,13 @@ impl<W: Write> VPXBoolWriter<W> {

let mut i: i32 = (num_bits - 1) as i32;
while i >= 0 {
self.put(
(tmp_value, tmp_range) = self.put(
(bits & (1 << i)) != 0,
&mut branches[i as usize],
&mut tmp_value,
&mut tmp_range,
tmp_value,
tmp_range,
cmp,
)?;
);
i -= 1;
}

Expand All @@ -244,13 +235,7 @@ impl<W: Write> VPXBoolWriter<W> {
for i in 0..A {
let cur_bit = v != i;

self.put(
cur_bit,
&mut branches[i],
&mut tmp_value,
&mut tmp_range,
cmp,
)?;
(tmp_value, tmp_range) = self.put(cur_bit, &mut branches[i], tmp_value, tmp_range, cmp);
if !cur_bit {
break;
}
Expand All @@ -272,7 +257,7 @@ impl<W: Write> VPXBoolWriter<W> {
let mut tmp_value = self.low_value;
let mut tmp_range = self.range;

self.put(value, branch, &mut tmp_value, &mut tmp_range, _cmp)?;
(tmp_value, tmp_range) = self.put(value, branch, tmp_value, tmp_range, _cmp);

self.low_value = tmp_value;
self.range = tmp_range;
Expand All @@ -284,20 +269,21 @@ impl<W: Write> VPXBoolWriter<W> {
// opposite to initial Lepton implementation that writes down all the buffer.
pub fn finish(&mut self) -> Result<()> {
let mut tmp_value = self.low_value;
let stream_bits = 64 - tmp_value.leading_zeros() as i32 - 2;
let stream_bits = 64 - tmp_value.leading_zeros() - 2;
// 55 >= stream_bits >= 8

tmp_value <<= MAX_STREAM_BITS - stream_bits;
if (tmp_value & (1 << MAX_STREAM_BITS)) != 0 {
tmp_value <<= 63 - stream_bits;
if tmp_value & (1 << 63) != 0 {
self.carry();
}

let mut shift = MAX_STREAM_BITS - 8;
let mut stream_bytes = (stream_bits + 7) >> 3;
while stream_bytes > 0 {
self.buffer.push((tmp_value >> shift) as u8);
let mut shift = 63;
for _stream_bytes in 0..(stream_bits + 7) >> 3 {
shift -= 8;
stream_bytes -= 1;
self.buffer.push((tmp_value >> shift) as u8);
}
// check that no stream bits remain in the buffer
debug_assert!(!(u64::MAX << shift) & tmp_value == 0);

self.writer.write_all(&self.buffer[..])?;
Ok(())
Expand Down

0 comments on commit 2545e7e

Please sign in to comment.