Skip to content

Commit

Permalink
Fix byte order detection for np.uint8 and other types for which byte order is not defined. (#160)
Browse files Browse the repository at this point in the history

* Fix byte order detection for `np.uint8` and other types for which byte order is not defined.

* Fixing np.uint8 endianness for `save`.

Co-authored-by: KOLANICH <[email protected]>
  • Loading branch information
Narsil and KOLANICH authored Jan 16, 2023
1 parent e8b3ab6 commit 02e5707
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
20 changes: 14 additions & 6 deletions bindings/python/py_src/safetensors/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ def save(tensor_dict: Dict[str, np.ndarray], metadata: Optional[Dict[str, str]]
```
"""
for tensor in tensor_dict.values():
byte_order = _byte_order(tensor)
if byte_order != "<":
if not _is_little_endian(tensor):
raise ValueError("Safetensor format only accepts little endian")
flattened = {k: {"dtype": v.dtype.name, "shape": v.shape, "data": v.tobytes()} for k, v in tensor_dict.items()}
serialized = serialize(flattened, metadata=metadata)
Expand Down Expand Up @@ -68,6 +67,9 @@ def save_file(tensor_dict: Dict[str, np.ndarray], filename: str, metadata: Optio
save(tensors, "model.safetensors")
```
"""
for tensor in tensor_dict.values():
if not _is_little_endian(tensor):
raise ValueError("Safetensor format only accepts little endian")
flattened = {k: {"dtype": v.dtype.name, "shape": v.shape, "data": v.tobytes()} for k, v in tensor_dict.items()}
serialize_file(flattened, filename, metadata=metadata)

Expand Down Expand Up @@ -158,11 +160,17 @@ def _view2np(safeview) -> Dict[str, np.ndarray]:
return result


def _is_little_endian(tensor: np.ndarray) -> bool:
    """Return whether ``tensor``'s dtype is stored in little-endian order.

    Args:
        tensor: The array whose ``dtype.byteorder`` flag is inspected.

    Returns:
        ``True`` when the data is little endian, or when byte order is not
        applicable to the dtype (single-byte types such as ``np.uint8``);
        ``False`` when the data is big endian.

    Raises:
        ValueError: If ``dtype.byteorder`` is not one of ``"="``, ``"|"``,
            ``"<"`` or ``">"``.
    """
    byteorder = tensor.dtype.byteorder
    if byteorder == "=":
        # "=" means native order, so the answer depends on the host machine.
        return sys.byteorder == "little"
    if byteorder in ("|", "<"):
        # "|" marks dtypes with no byte order (e.g. uint8); their bytes are
        # identical on every platform, so they are accepted as little endian.
        return True
    if byteorder == ">":
        return False
    raise ValueError(f"Unexpected byte order {byteorder}")
4 changes: 4 additions & 0 deletions bindings/python/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,15 @@ def test_numpy_example(self):
tensors = {"a": np.zeros((2, 2)), "b": np.zeros((2, 3), dtype=np.uint8)}

save_file(tensors, "./out.safetensors")
out = save(tensors)

# Now loading
loaded = load_file("./out.safetensors")
self.assertTensorEqual(tensors, loaded, np.allclose)

loaded = load(out)
self.assertTensorEqual(tensors, loaded, np.allclose)

def test_torch_example(self):
tensors = {
"a": torch.zeros((2, 2)),
Expand Down

0 comments on commit 02e5707

Please sign in to comment.