Skip to content

Commit

Permalink
Updated the docstring, generalised the code for every data type, wrot…
Browse files Browse the repository at this point in the history
…e tests for the function
  • Loading branch information
Ishaanj18 committed Oct 3, 2023
1 parent 0630abe commit 679866f
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 22 deletions.
40 changes: 40 additions & 0 deletions tests/test_get_keepbits.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import xarray as xr

import xbitinfo as xb
from xbitinfo.xbitinfo import get_keepbits


@pytest.fixture
Expand All @@ -29,3 +30,42 @@ def test_get_keepbits_inflevel_dim(rasm_info_per_bit, inflevel):
if isinstance(inflevel, (int, float)):
inflevel = [inflevel]
assert (keepbits.inflevel == inflevel).all()


def test_get_keepbits_informationFilter():
ds = xr.tutorial.load_dataset("air_temperature")
info = xb.get_bitinformation(ds, dim="lat")
var = info["air"]
for i in range(var.size):
if i >= 19 and i <= 24:
var[i] = 0.05
keepbits_dataset = get_keepbits(
info,
inflevel=[0.90],
information_filter="On",
**{"threshold": 0.7, "tolerance": 0.001}
)
keepbits = keepbits_dataset["air"].values
assert keepbits == 5


def test_get_keepbits_informationFilter_1():
ds = xr.tutorial.load_dataset("air_temperature")
info = xb.get_bitinformation(ds, dim="lat")
keepbitsOff_dataset = get_keepbits(
info,
inflevel=[0.99],
information_filter="Off",
**{"threshold": 0.7, "tolerance": 0.001}
)
keepbits_Off = keepbitsOff_dataset["air"].values

keepbitsOn_dataset = get_keepbits(
info,
inflevel=[0.99],
information_filter="On",
**{"threshold": 0.7, "tolerance": 0.001}
)
keepbits_On = keepbitsOn_dataset["air"].values

assert keepbits_Off == keepbits_On
111 changes: 89 additions & 22 deletions xbitinfo/xbitinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,43 +376,106 @@ def load_bitinformation(label):
raise FileNotFoundError(f"No bitinformation could be found at {label+'.json'}")


def get_trueKeepbits(info_per_bit, bitdim, threshold, tolerance, bit_vars):
def get_cdf_without_artificial_information(
info_per_bit, bitdim, threshold, tolerance, bit_vars
):
"""
Calculate a Cumulative Distribution Function (CDF) with artificial information removal.
This function calculates a modified CDF for a given set of bit information and variable dimensions,
removing artificial information while preserving the desired threshold of information content.
Parameters:
-----------
info_per_bit : :py:class: 'xarray.Dataset'
Information content of each bit. This is the output from :py:func:`xbitinfo.xbitinfo.get_bitinformation`.
bitdim : str
The dimension representing the bit information.
threshold : float
Determines the percentage of total information above which the keepbits should lie.
tolerance : float
The tolerance is the value below which gradient starts becoming constant
bit_vars : list
List of variable names of the dataset.
Returns:
--------
xarray.Dataset
A modified CDF dataset with artificial information removed.
Example:
--------
>>> ds = xr.tutorial.load_dataset("air_temperature")
>>> info_per_bit = xb.get_bitinformation(ds)
>>> get_keepbits(
... info,
... inflevel=[0.99],
... information_filter="On",
... **{"threshold": 0.7, "tolerance": 0.001}
... )
>>> get_cdf_without_artificial_information(
... info_per_bit, bitdim, threshold, tolerance, bit_vars
... )
<xarray.Dataset>
Dimensions: (dim: 3, bit32: 32)
Coordinates:
* dim (dim) <U4 'lat' 'lon' 'time'
Dimensions without coordinates: bit32
Data variables:
air (dim, bit32) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
"""
coordinates = info_per_bit.coords
coordinates_array = coordinates["dim"].values
flag_scalar_value = False
if coordinates_array.ndim == 0:
value = coordinates_array.item()
flag_scalar_value = True
coordinates_array = np.array([value])

cdf = _cdf_from_info_per_bit(info_per_bit, bitdim)
dimensions = info_per_bit.dims
dim_size = dimensions["dim"]
for var_name in bit_vars:
for i in range(dim_size):
infoArray = info_per_bit[var_name].isel(dim=i)
for dimension in coordinates_array:
if flag_scalar_value:
infoArray = info_per_bit[var_name]
else:
infoArray = info_per_bit[var_name].sel(dim=dimension)
# total sum of information along a dimension
infSum = sum(infoArray).item()

# sum of first nine bits
first_Ninebits_sum = sum(infoArray[:9]).item()
if int(bitdim[3:]) == 16:
sign_and_exponent = 6

cdf_array = cdf[var_name].isel(dim=i)
gradient_array = np.diff(cdf_array.values)
if int(bitdim[3:]) == 32:
sign_and_exponent = 9

for i in range(9, len(gradient_array) - 1):
first_Ninebits_sum = first_Ninebits_sum + infoArray[i].item()
if int(bitdim[3:]) == 64:
sign_and_exponent = 12

# sum of sign and exponent bits
SignExpSum = sum(infoArray[:sign_and_exponent]).item()
if flag_scalar_value:
cdf_array = cdf[var_name]
else:
cdf_array = cdf[var_name].sel(dim=dimension)

gradient_array = np.diff(cdf_array.values)
CurrentBit_Sum = SignExpSum
for i in range(sign_and_exponent, len(gradient_array) - 1):
CurrentBit_Sum = CurrentBit_Sum + infoArray[i].item()
if (
gradient_array[i]
) < tolerance and first_Ninebits_sum >= threshold * infSum:
) < tolerance and CurrentBit_Sum >= threshold * infSum:
infbits = i
break

for i in range(infbits + 1, len(cdf_array) - 1):
for i in range(infbits + 1, len(cdf_array)):
cdf_array[i] = 0

return cdf


def get_keepbits(
info_per_bit,
inflevel=0.99,
information_filter="Off",
threshold=0.7,
tolerance=0.001,
):
def get_keepbits(info_per_bit, inflevel=0.99, information_filter="Off", **kwargs):
"""Get the number of mantissa bits to keep. To be used in :py:func:`xbitinfo.bitround.xr_bitround` and :py:func:`xbitinfo.bitround.jl_bitround`.
Parameters
Expand Down Expand Up @@ -476,8 +539,12 @@ def get_keepbits(
bit_vars = [v for v in info_per_bit.data_vars if bitdim in info_per_bit[v].dims]
if bit_vars != []:
if information_filter == "On":
cdf = get_trueKeepbits(
info_per_bit[bit_vars], bitdim, threshold, tolerance, bit_vars
cdf = get_cdf_without_artificial_information(
info_per_bit[bit_vars],
bitdim,
kwargs["threshold"],
kwargs["tolerance"],
bit_vars,
)
else:
cdf = _cdf_from_info_per_bit(info_per_bit[bit_vars], bitdim)
Expand Down

0 comments on commit 679866f

Please sign in to comment.