Skip to content

Commit

Permalink
more itemizations in assay-matrix ingestor
Browse files: browse the repository at this point in the history
  • Loading branch information
johnkerl committed Sep 2, 2022
1 parent 6733c13 commit ad10f3e
Showing 1 changed file with 52 additions and 3 deletions.
55 changes: 52 additions & 3 deletions apis/python/src/tiledbsc/assay_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,14 @@ def __init__(
* For reading from an already-populated SOMA, we wish to avoid cache-coherency issues.
"""
super().__init__(uri=uri, name=name, parent=parent)
s0 = self.timing_start("__init__", "total")

self.row_dim_name = row_dim_name
self.col_dim_name = col_dim_name
self.attr_name = "value"
self.row_dataframe = row_dataframe
self.col_dataframe = col_dataframe
self.timing_end(s0)

# ----------------------------------------------------------------
def shape(self) -> Tuple[int, int]:
Expand All @@ -69,12 +71,15 @@ def shape(self) -> Tuple[int, int]:
Note: currently implemented via data scan -- will be optimized for TileDB core 2.10.
"""
s1 = self.timing_start("shape", "total")
with self._open():
# These TileDB arrays are string-dimensioned sparse arrays so there is no '.shape'.
# Instead we compute it ourselves from the row and column dataframes.
num_rows = self.row_dataframe.shape()[0]
num_cols = self.col_dataframe.shape()[0]
return (num_rows, num_cols)
retval = (num_rows, num_cols)
self.timing_end(s1)
return retval

# ----------------------------------------------------------------
def dim_select(
Expand All @@ -89,7 +94,13 @@ def dim_select(
Either or both of the ID lists may be `None`, meaning, do not subselect along
that dimension. If both ID lists are `None`, the entire matrix is returned.
"""
s0 = self.timing_start("dim_select", "open")

s1 = self.timing_start("dim_select", "open")
with tiledb.open(self.uri, ctx=self._ctx) as A:
self.timing_end(s1)

s2 = self.timing_start("dim_select", "tiledb_query")
query = A.query(return_arrow=return_arrow)
if obs_ids is None:
if var_ids is None:
Expand All @@ -101,8 +112,14 @@ def dim_select(
df = query.df[obs_ids, :]
else:
df = query.df[obs_ids, var_ids]
self.timing_end(s2)

s3 = self.timing_start("dim_select", "set_index")
if not return_arrow:
df.set_index([self.row_dim_name, self.col_dim_name], inplace=True)
self.timing_end(s3)

self.timing_end(s0)
return df

# ----------------------------------------------------------------
Expand All @@ -126,15 +143,21 @@ def csr(
"""
Like `.df()` but returns results in `scipy.sparse.csr_matrix` format.
"""
return self._csr_or_csc("csr", obs_ids, var_ids)
s0 = self.timing_start("csr", "total")
retval = self._csr_or_csc("csr", obs_ids, var_ids)
self.timing_end(s0)
return retval

def csc(
self, obs_ids: Optional[Ids] = None, var_ids: Optional[Ids] = None
) -> sp.csc_matrix:
"""
Like `.df()` but returns results in `scipy.sparse.csc_matrix` format.
"""
return self._csr_or_csc("csc", obs_ids, var_ids)
s0 = self.timing_start("csc", "total")
retval = self._csr_or_csc("csc", obs_ids, var_ids)
self.timing_end(s0)
return retval

def _csr_or_csc(
self,
Expand Down Expand Up @@ -168,6 +191,7 @@ def from_matrix_and_dim_values(
`scipy.sparse.csr_matrix`, `scipy.sparse.csc_matrix`, `numpy.ndarray`, etc.
For ingest from `AnnData`, these should be `ann.obs_names` and `ann.var_names`.
"""
s0 = self.timing_start("from_matrix_and_dim_values", "total")

s = util.get_start_stamp()
log_io(
Expand Down Expand Up @@ -206,11 +230,14 @@ def from_matrix_and_dim_values(
util.format_elapsed(s, f"{self._indent}FINISH WRITING {self.uri}"),
)

self.timing_end(s0)

# ----------------------------------------------------------------
def _create_empty_array(self, matrix_dtype: np.dtype) -> None:
"""
Create a TileDB 2D sparse array with string dimensions and a single attribute.
"""
s0 = self.timing_start("_create_empty_array", "total")

level = self._soma_options.string_dim_zstd_level
dom = tiledb.Domain(
Expand Down Expand Up @@ -253,6 +280,7 @@ def _create_empty_array(self, matrix_dtype: np.dtype) -> None:
)

tiledb.Array.create(self.uri, sch, ctx=self._ctx)
self.timing_end(s0)

# ----------------------------------------------------------------
def ingest_data_whole(
Expand All @@ -269,6 +297,7 @@ def ingest_data_whole(
:param row_names: List of row names.
:param col_names: List of column names.
"""
s0 = self.timing_start("ingest_data_whole", "total")

assert len(row_names) == matrix.shape[0]
assert len(col_names) == matrix.shape[1]
Expand All @@ -279,6 +308,7 @@ def ingest_data_whole(

with tiledb.open(self.uri, mode="w", ctx=self._ctx) as A:
A[d0, d1] = mat_coo.data
self.timing_end(s0)

# ----------------------------------------------------------------
# Example: suppose this 4x3 is to be written in two chunks of two rows each
Expand Down Expand Up @@ -326,7 +356,9 @@ def ingest_data_rows_chunked(
:param row_names: List of row names.
:param col_names: List of column names.
"""
s0 = self.timing_start("ingest_data_rows_chunked", "total")

s1 = self.timing_start("ingest_data_rows_chunked", "sortprep")
assert len(row_names) == matrix.shape[0]
assert len(col_names) == matrix.shape[1]

Expand All @@ -346,13 +378,20 @@ def ingest_data_rows_chunked(
f"{self._indent}START ingest_data_rows_chunked",
)

self.timing_end(s1)

eta_tracker = util.ETATracker()
s2 = self.timing_start("ingest_data_rows_chunked", "open")
with tiledb.open(self.uri, mode="w", ctx=self._ctx) as A:
self.timing_end(s2)

nrow = len(sorted_row_names)

i = 0
while i < nrow:
t1 = time.time()

s3 = self.timing_start("ingest_data_rows_chunked", "chunkprep")
# Find a number of CSR rows which will result in a desired nnz for the chunk.
chunk_size = util._find_csr_chunk_size(
matrix, permutation, i, self._soma_options.goal_chunk_nnz
Expand All @@ -365,6 +404,7 @@ def ingest_data_rows_chunked(
# Write the chunk-COO to TileDB.
d0 = sorted_row_names[chunk_coo.row + i]
d1 = col_names[chunk_coo.col]
self.timing_end(s3)

if len(d0) == 0:
i = i2
Expand All @@ -390,7 +430,9 @@ def ingest_data_rows_chunked(
)

# Write a TileDB fragment
s4 = self.timing_start("ingest_data_rows_chunked", "tiledb-write")
A[d0, d1] = chunk_coo.data
self.timing_end(s4)

t2 = time.time()
chunk_seconds = t2 - t1
Expand All @@ -413,6 +455,7 @@ def ingest_data_rows_chunked(
f"{self._indent}FINISH __ingest_coo_data_string_dims_rows_chunked",
),
)
self.timing_end(s0)

# This method is very similar to ingest_data_rows_chunked. The code is largely repeated,
# and this is intentional. The algorithm here is non-trivial (among the most non-trivial
Expand All @@ -432,6 +475,7 @@ def ingest_data_cols_chunked(
:param row_names: List of row names.
:param col_names: List of column names.
"""
s0 = self.timing_start("ingest_data_cols_chunked", "total")

assert len(row_names) == matrix.shape[0]
assert len(col_names) == matrix.shape[1]
Expand Down Expand Up @@ -519,6 +563,7 @@ def ingest_data_cols_chunked(
f"{self._indent}FINISH __ingest_coo_data_string_dims_rows_chunked",
),
)
self.timing_end(s0)

# This method is very similar to ingest_data_rows_chunked. The code is largely repeated,
# and this is intentional. The algorithm here is non-trivial (among the most non-trivial
Expand All @@ -538,6 +583,7 @@ def ingest_data_dense_rows_chunked(
:param row_names: List of row names.
:param col_names: List of column names.
"""
s0 = self.timing_start("ingest_data_dense_rows_chunked", "total")

assert len(row_names) == matrix.shape[0]
assert len(col_names) == matrix.shape[1]
Expand Down Expand Up @@ -627,6 +673,7 @@ def ingest_data_dense_rows_chunked(
f"{self._indent}FINISH __ingest_coo_data_string_dims_dense_rows_chunked",
),
)
self.timing_end(s0)

# ----------------------------------------------------------------
def to_csr_matrix(self, row_labels: Labels, col_labels: Labels) -> sp.csr_matrix:
Expand All @@ -638,6 +685,7 @@ def to_csr_matrix(self, row_labels: Labels, col_labels: Labels) -> sp.csr_matrix
be in the same order as they were in any anndata object which was used to create the
TileDB storage.
"""
s0 = self.timing_start("to_csr_matrix", "total")

s = util.get_start_stamp()
log_io(None, f"{self._indent}START read {self.uri}")
Expand All @@ -649,4 +697,5 @@ def to_csr_matrix(self, row_labels: Labels, col_labels: Labels) -> sp.csr_matrix
util.format_elapsed(s, f"{self._indent}FINISH read {self.uri}"),
)

self.timing_end(s0)
return csr

0 comments on commit ad10f3e

Please sign in to comment.