Skip to content

Commit

Permalink
Fix defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Jan 24, 2025
1 parent 8d596d5 commit 1dadeac
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 36 deletions.
69 changes: 36 additions & 33 deletions ndonnx/_data_types/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,47 +63,50 @@
nutf8: NUtf8 = NUtf8()


_canonical_names = {
bool: "bool",
float32: "float32",
float64: "float64",
int8: "int8",
int16: "int16",
int32: "int32",
int64: "int64",
uint8: "uint8",
uint16: "uint16",
uint32: "uint32",
uint64: "uint64",
utf8: "utf8",
}


def canonical_name(dtype: CoreType) -> str:
"""Return the canonical name of the data type."""
if dtype == bool:
return "bool"
elif dtype == float32:
return "float32"
elif dtype == float64:
return "float64"
elif dtype == int8:
return "int8"
elif dtype == int16:
return "int16"
elif dtype == int32:
return "int32"
elif dtype == int64:
return "int64"
elif dtype == uint8:
return "uint8"
elif dtype == uint16:
return "uint16"
elif dtype == uint32:
return "uint32"
elif dtype == uint64:
return "uint64"
elif dtype == utf8:
return "utf8"
if dtype in _canonical_names:
return _canonical_names[dtype]
else:
raise ValueError(f"Unknown data type: {dtype}")


_kinds = {
bool: ("bool",),
int8: ("signed integer", "integer", "numeric"),
int16: ("signed integer", "integer", "numeric"),
int32: ("signed integer", "integer", "numeric"),
int64: ("signed integer", "integer", "numeric"),
uint8: ("unsigned integer", "integer", "numeric"),
uint16: ("unsigned integer", "integer", "numeric"),
uint32: ("unsigned integer", "integer", "numeric"),
uint64: ("unsigned integer", "integer", "numeric"),
float32: ("floating", "numeric"),
float64: ("floating", "numeric"),
}


def kinds(dtype: CoreType) -> tuple[str, ...]:
"""Return the kinds of the data type."""
if dtype in (bool,):
return ("bool",)
if dtype in (int8, int16, int32, int64):
return ("signed integer", "integer", "numeric")
if dtype in (uint8, uint16, uint32, uint64):
return ("unsigned integer", "integer", "numeric")
if dtype in (float32, float64):
return ("floating", "numeric")
if dtype in (utf8,):
if dtype in _kinds:
return _kinds[dtype]
elif dtype in (utf8,):
raise ValueError(f"We don't get define a kind for {dtype}")
else:
raise ValueError(f"Unknown data type: {dtype}")
15 changes: 12 additions & 3 deletions ndonnx/_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,29 @@ def devices(self) -> list:
def dtypes(
self, *, device=None, kind: str | tuple[str, ...] | None = None
) -> dict[str, ndx.CoreType]:
# We don't care for device and don't use it.
# We don't care for device since we are writing ONNX graphs.
# We would rather not give users the impression that their arrays
# are tied to a specific device when serializing an ONNX graph as
# such a concept does not exist in the ONNX .
out: dict[str, ndx.CoreType] = {}
for dtype in self._all_array_api_types:
if kind is None or ndx.isdtype(dtype, kind):
out[canonical_name(dtype)] = dtype
return out

def default_dtypes(
self, *, device=None, kind: str | tuple[str, ...] | None
) -> dict[str, ndx.CoreType]:
self,
*,
device=None,
) -> dict[str, ndx.CoreType | None]:
# See comment in `dtypes` method regarding device.
return {
"real floating": ndx.float64,
"integral": ndx.int64,
"indexing": ndx.int64,
# We don't support complex numbers yet due to immaturity in the ONNX ecoystem, so "complex floating" is meaningless.
# The standard requires this key to be present.
"complex floating": None,
}


Expand Down

0 comments on commit 1dadeac

Please sign in to comment.