Dumb but fast writer #120

Merged: 1 commit, Nov 25, 2024
Changes from all commits
src/structs/vpx_bool_writer.rs: 114 changes (50 additions, 64 deletions)
@@ -28,11 +28,6 @@ use crate::metrics::{Metrics, ModelComponent};
use crate::structs::branch::Branch;
use crate::structs::simple_hash::SimpleHash;

// MAX_STREAM_BITS should be a multiple of 8 larger than 8,
// and (MAX_STREAM_BITS + 1 bit of carry + 1 bit of divider)
// should fit into 64 bits of `low_value`
const MAX_STREAM_BITS: i32 = 56; //48; //40;// 32;// 24;// 16;//

pub struct VPXBoolWriter<W> {
low_value: u64,
range: u32,
@@ -68,23 +63,23 @@ impl<W: Write> VPXBoolWriter<W> {
#[inline(always)]
pub fn put(
&mut self,
value: bool,
bit: bool,
branch: &mut Branch,
tmp_value: &mut u64,
tmp_range: &mut u32,
mut tmp_value: u64,
mut tmp_range: u32,
_cmp: ModelComponent,
) -> Result<()> {
) -> (u64, u32) {
#[cfg(feature = "detailed_tracing")]
{
// used to detect divergences between the C++ and rust versions
self.hash.hash(branch.get_u64());
self.hash.hash(*tmp_value);
self.hash.hash(*tmp_range);
self.hash.hash(tmp_value);
self.hash.hash(tmp_range);

let hashed_value = self.hash.get();
//if hashedValue == 0xe35c28fd
{
print!("({0}:{1:x})", value as u8, hashed_value);
print!("({0}:{1:x})", bit as u8, hashed_value);
if hashed_value % 8 == 0 {
println!();
}
@@ -93,52 +88,48 @@ impl<W: Write> VPXBoolWriter<W> {

let probability = branch.get_probability() as u32;

let split = 1 + (((*tmp_range - 1) * probability) >> 8);
let split = 1 + (((tmp_range - 1) * probability) >> 8);

let mut shift;
branch.record_and_update_bit(value);
branch.record_and_update_bit(bit);

if value {
*tmp_value += split as u64;
*tmp_range -= split;
if bit {
tmp_value += split as u64;
tmp_range -= split;
} else {
*tmp_range = split;
tmp_range = split;
}

shift = (*tmp_range as u8).leading_zeros() as i32;
let shift = (tmp_range as u8).leading_zeros();

#[cfg(feature = "compression_stats")]
{
self.model_statistics
.record_compression_stats(_cmp, 1, i64::from(shift));
}

*tmp_range <<= shift;
tmp_range <<= shift;
tmp_value <<= shift;

// check whether we have more than MAX_STREAM_BITS stream bits after shift
let stream_bits = 64 - (*tmp_value).leading_zeros() as i32 - 2;
let count = shift + stream_bits - MAX_STREAM_BITS;
if count >= 0 {
// check carry
*tmp_value <<= MAX_STREAM_BITS - stream_bits;
if (*tmp_value & (1 << MAX_STREAM_BITS)) != 0 {
// check whether we cannot put next bit into stream
if tmp_value & (u64::MAX << 57) != 0 {
let mut stream_bits = 64 - tmp_value.leading_zeros() - 2;
// 62 >= stream_bits >= 56

if tmp_value & (1 << stream_bits) != 0 {
self.carry();
}
// write all full bytes
let mut sh = MAX_STREAM_BITS - 8;
while sh > 0 {
self.buffer.push((*tmp_value >> sh) as u8);
sh -= 8;

for _stream_bytes in 0..6 {
stream_bits -= 8;
self.buffer.push((tmp_value >> stream_bits) as u8);
}
*tmp_value &= (1 << 8) - 1; // exclude written bytes
*tmp_value |= 1 << 9; // restore divider bit

shift = count;
tmp_value &= (1 << stream_bits) - 1;
tmp_value |= 1 << (stream_bits + 1);
// 14 >= stream_bits >= 8
}

*tmp_value <<= shift;

Ok(())
// 55 >= stream_bits >= 8
(tmp_value, tmp_range)
}
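A note on the bookkeeping above: the expression `64 - leading_zeros - 2` works because `low_value` keeps a divider bit one position above a carry slot, which in turn sits above the pending stream bits. A minimal standalone sketch of that invariant (not part of the diff; the function name is illustrative only):

fn pending_stream_bits(low_value: u64) -> u32 {
    // locate the divider bit, then discount the divider and the carry slot below it
    64 - low_value.leading_zeros() - 2
}

fn main() {
    // divider at bit 9, carry slot at bit 8, payload 0xAB in bits 0..=7:
    // 8 pending bits, the state `put` restores right after flushing six bytes
    assert_eq!(pending_stream_bits((1u64 << 9) | 0xAB), 8);

    // once repeated shifts push the divider to bit 57 or higher there are 56+
    // pending bits, the `u64::MAX << 57` mask becomes non-zero and `put` flushes
    assert_eq!(pending_stream_bits(1u64 << 57), 56);
}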

/// Safe as: the initial `false` put at the stream beginning ensures that a carry cannot get out
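The function this doc comment belongs to is not shown in the diff. Purely as a hypothetical illustration of the usual byte-wise carry propagation such a writer relies on (an assumption, not the crate's actual code):

// Hypothetical sketch only: walk back through the already-buffered bytes,
// rolling 0xFF over to 0x00 until one byte can absorb the increment. The
// initial `false` bit mentioned above guarantees the walk terminates inside
// the buffer.
fn propagate_carry(buffer: &mut [u8]) {
    for byte in buffer.iter_mut().rev() {
        if *byte == 0xFF {
            *byte = 0; // carry keeps rippling backwards
        } else {
            *byte += 1; // carry absorbed here
            return;
        }
    }
}

fn main() {
    let mut buffer = vec![0x12, 0xFF, 0xFF];
    propagate_carry(&mut buffer);
    assert_eq!(buffer, vec![0x13, 0x00, 0x00]);
}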
@@ -176,13 +167,13 @@ impl<W: Write> VPXBoolWriter<W> {

loop {
let cur_bit = (v & (1 << index)) != 0;
self.put(
(tmp_value, tmp_range) = self.put(
cur_bit,
&mut branches[serialized_so_far],
&mut tmp_value,
&mut tmp_range,
tmp_value,
tmp_range,
cmp,
)?;
);

if index == 0 {
break;
@@ -213,13 +204,13 @@ impl<W: Write> VPXBoolWriter<W> {

let mut i: i32 = (num_bits - 1) as i32;
while i >= 0 {
self.put(
(tmp_value, tmp_range) = self.put(
(bits & (1 << i)) != 0,
&mut branches[i as usize],
&mut tmp_value,
&mut tmp_range,
tmp_value,
tmp_range,
cmp,
)?;
);
i -= 1;
}

@@ -244,13 +235,7 @@ impl<W: Write> VPXBoolWriter<W> {
for i in 0..A {
let cur_bit = v != i;

self.put(
cur_bit,
&mut branches[i],
&mut tmp_value,
&mut tmp_range,
cmp,
)?;
(tmp_value, tmp_range) = self.put(cur_bit, &mut branches[i], tmp_value, tmp_range, cmp);
if !cur_bit {
break;
}
@@ -272,7 +257,7 @@ impl<W: Write> VPXBoolWriter<W> {
let mut tmp_value = self.low_value;
let mut tmp_range = self.range;

self.put(value, branch, &mut tmp_value, &mut tmp_range, _cmp)?;
(tmp_value, tmp_range) = self.put(value, branch, tmp_value, tmp_range, _cmp);

self.low_value = tmp_value;
self.range = tmp_range;
@@ -284,20 +269,21 @@ impl<W: Write> VPXBoolWriter<W> {
// opposite to the initial Lepton implementation, which writes out the whole buffer.
pub fn finish(&mut self) -> Result<()> {
let mut tmp_value = self.low_value;
let stream_bits = 64 - tmp_value.leading_zeros() as i32 - 2;
let stream_bits = 64 - tmp_value.leading_zeros() - 2;
// 55 >= stream_bits >= 8

tmp_value <<= MAX_STREAM_BITS - stream_bits;
if (tmp_value & (1 << MAX_STREAM_BITS)) != 0 {
tmp_value <<= 63 - stream_bits;
if tmp_value & (1 << 63) != 0 {
self.carry();
}

let mut shift = MAX_STREAM_BITS - 8;
let mut stream_bytes = (stream_bits + 7) >> 3;
while stream_bytes > 0 {
self.buffer.push((tmp_value >> shift) as u8);
let mut shift = 63;
for _stream_bytes in 0..(stream_bits + 7) >> 3 {
shift -= 8;
stream_bytes -= 1;
self.buffer.push((tmp_value >> shift) as u8);
}
// check that no stream bits remain in the buffer
debug_assert!(!(u64::MAX << shift) & tmp_value == 0);

self.writer.write_all(&self.buffer[..])?;
Ok(())
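To make the flush arithmetic in `finish` concrete, here is a small self-contained sketch (not part of the diff; `flush_remaining` is an illustrative stand-in, not a crate API). It mirrors the shift and byte-count computations above under the same `low_value` layout:

fn flush_remaining(mut low_value: u64) -> Vec<u8> {
    // pending stream bits below the carry slot and the divider bit
    let stream_bits = 64 - low_value.leading_zeros() - 2;
    let mut out = Vec::new();

    // align the value so the carry slot lands on bit 63; the real writer
    // calls carry() at this point when that bit is set
    low_value <<= 63 - stream_bits;

    // emit whole bytes, rounding the pending bit count up
    let mut shift = 63;
    for _ in 0..((stream_bits + 7) >> 3) {
        shift -= 8;
        out.push((low_value >> shift) as u8);
    }
    out
}

fn main() {
    // divider at bit 9 marks 8 pending bits (0xAB): exactly one byte comes out
    assert_eq!(flush_remaining((1u64 << 9) | 0xAB), vec![0xAB]);
}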