Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tiled arbitrary grids #10

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
further improve typing
chrishavlin committed Oct 11, 2024
commit f5c368d9614519dd094808a1f3ef571aa70aa254
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -119,4 +119,4 @@ enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
warn_unreachable = true
disallow_untyped_defs = false
disallow_incomplete_defs = false
disable_error_code = ["import-untyped", "import-not-found", "no-untyped-call"]
disable_error_code = ["import-untyped", "import-not-found"]
11 changes: 11 additions & 0 deletions yt_experiments/tiled_grid/tests/test_tiled_grid.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest
import unyt
from numpy.testing import assert_equal
from yt.testing import fake_amr_ds, requires_module
@@ -78,6 +79,16 @@ def test_arbitrary_grid_oct():
assert level_arrays[ilev].shape == expected_levels[ilev]


def test_missing_ds():
with pytest.raises(ValueError, match="Please provide a dataset"):
_ = YTTiledArbitraryGrid(
unyt.unyt_array([0, 0, 0], "m"),
unyt.unyt_array([1, 1, 1], "m"),
(20, 20, 20),
5,
)


@requires_module("xarray")
def test_arbitrary_grid_to_xarray():
import xarray as xr
60 changes: 43 additions & 17 deletions yt_experiments/tiled_grid/tiled_grid.py
Original file line number Diff line number Diff line change
@@ -8,6 +8,10 @@
from yt.data_objects.construction_data_containers import YTArbitraryGrid
from yt.data_objects.static_output import Dataset

_GridInfo = tuple[
npt.NDArray, npt.NDArray, unyt.unyt_array, unyt.unyt_array, Any, npt.NDArray
]


def _validate_edge(edge: npt.ArrayLike, ds: Dataset):
if not isinstance(edge, unyt.unyt_array):
@@ -62,6 +66,9 @@ def __init__(

"""

if ds is None:
raise ValueError("Please provide a dataset via the ds keyword argument")

self.ds = ds
self.left_edge = _validate_edge(left_edge, ds)
self.right_edge = _validate_edge(right_edge, ds)
@@ -86,7 +93,7 @@ def __init__(
self._left_cell_center = self.left_edge + self.dds / 2.0
self._right_cell_center = self.right_edge - self.dds / 2.0

def __repr__(self):
def __repr__(self) -> str:
nm = self.__class__.__name__
shape = tuple(self.dims)
n_chunks = tuple(self.nchunks)
@@ -97,13 +104,13 @@ def __repr__(self):
)
return msg

def _get_grid_by_ijk(self, ijk_grid):
def _get_grid_by_ijk(self, ijk_grid: npt.NDArray[int]) -> _GridInfo:
chunksizes = self.chunks

le_index = []
re_index = []
le_val = self.ds.domain_left_edge.copy()
re_val = self.ds.domain_right_edge.copy()
le_val: unyt.unyt_array = self.ds.domain_left_edge.copy()
re_val: unyt.unyt_array = self.ds.domain_right_edge.copy()

for idim in range(self._ndim):
chunk_i = ijk_grid[idim]
@@ -122,29 +129,29 @@ def _get_grid_by_ijk(self, ijk_grid):
le_index[2] : re_index[2],
]

le_index = np.array(le_index, dtype=int)
re_index = np.array(re_index, dtype=int)
le_index_ = np.array(le_index, dtype=int)
re_index_ = np.array(re_index, dtype=int)
shape = chunksizes

return le_index, re_index, le_val, re_val, slc, shape
return le_index_, re_index_, le_val, re_val, slc, shape

def _get_grid(self, igrid: int):
def _get_grid(self, igrid: int) -> _GridInfo:
# get grid extent of a **single** grid
ijk_grid = np.unravel_index(igrid, self.nchunks)
return self._get_grid_by_ijk(ijk_grid)

def _coord_array(self, idim):
def _coord_array(self, idim: int) -> npt.NDArray:
LE = self._left_cell_center[idim]
RE = self._right_cell_center[idim]
N = self.dims[idim]
return np.mgrid[LE : RE : N * 1j]

def to_xarray(self, field, *, output_array=None):
def to_xarray(
self, field: tuple[str, str], *, output_array: npt.ArrayLike | None = None
) -> Any:

import xarray as xr

# ToDo: import from on_demand_imports

vals = self.to_array(field, output_array=output_array)

dims = self.ds.coordinates.axis_order
@@ -162,7 +169,13 @@ def to_xarray(self, field, *, output_array=None):
)
return xr_ds

def single_grid_values(self, igrid, field, *, ops=None):
def single_grid_values(
self,
igrid: int,
field: tuple[str, str],
*,
ops: list[Callable[[npt.NDArray], npt.NDArray]] | None = None,
) -> tuple[npt.NDArray, Any]:
"""
Get the values for a field for a single grid chunk as in-memory array.

@@ -308,7 +321,9 @@ def __init__(

self.levels: list[YTTiledArbitraryGrid] = levels

def _validate_levels(self, levels):
def _validate_levels(
self, levels: Sequence[int | tuple[int, int, int] | npt.ArrayLike]
):

for ilev in range(1, self.n_levels):
res = np.prod(levels[ilev])
@@ -321,7 +336,7 @@ def _validate_levels(self, levels):
)
raise ValueError(msg)

def __repr__(self):
def __repr__(self) -> str:
return (
f"{self.__class__.__name__} with {self.n_levels} levels and base resolution "
f"{self.base_resolution}"
@@ -330,7 +345,11 @@ def __repr__(self):
def base_resolution(self) -> tuple[int, int, int]:
return tuple(self[0].dims)

def to_arrays(self, field, output_arrays=None):
def to_arrays(
self,
field: tuple[str, str],
output_arrays: list[npt.ArrayLike | None] | None = None,
) -> list[npt.ArrayLike]:
if output_arrays is None:
output_arrays = [None for _ in range(len(self.levels))]

@@ -390,7 +409,14 @@ def _validate_factor(
return np.asarray(input_factor, dtype=int)


def _get_filled_grid(le, re, shp, field, ds, field_parameters):
def _get_filled_grid(
le: npt.NDArray,
re: npt.NDArray,
shp: npt.NDArray,
field: tuple[str, str],
ds: Dataset,
field_parameters: Any,
) -> npt.NDArray:
grid = YTArbitraryGrid(le, re, shp, ds=ds, field_parameters=field_parameters)
vals = grid[field]
return vals