diff --git a/weather_mv/loader_pipeline/ee.py b/weather_mv/loader_pipeline/ee.py index 69b5b637..e622bbbc 100644 --- a/weather_mv/loader_pipeline/ee.py +++ b/weather_mv/loader_pipeline/ee.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import contextlib import dataclasses +import datetime import json import logging import os @@ -24,9 +26,11 @@ import typing as t from multiprocessing import Process, Queue +import affine import apache_beam as beam import ee import numpy as np +import rasterio import xarray as xr from apache_beam.io.filesystems import FileSystems from apache_beam.io.gcp.gcsio import WRITE_CHUNK_SIZE @@ -34,6 +38,7 @@ from apache_beam.utils import retry from google.auth import compute_engine, default, credentials from google.auth.transport import requests +from pyproj import CRS from rasterio.io import MemoryFile from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin @@ -243,7 +248,7 @@ def add_parser_arguments(cls, subparser: argparse.ArgumentParser): help='The GCS location where the asset files will be pushed.') subparser.add_argument('--ee_asset', type=str, required=True, default=None, help='The asset folder path in earth engine project where the asset files' - ' will be pushed.') + ' will be pushed.') subparser.add_argument('--ee_asset_type', type=str, choices=['IMAGE', 'TABLE'], default='IMAGE', help='The type of asset to ingest in the earth engine.') subparser.add_argument('--xarray_open_dataset_kwargs', type=json.loads, default='{}', @@ -332,16 +337,16 @@ def expand(self, paths): band_names_dict = json.load(f) if not self.dry_run: ( - paths - | 'FilterFiles' >> FilterFilesTransform.from_kwargs(**vars(self)) - | 'ReshuffleFiles' >> beam.Reshuffle() - | 'ConvertToAsset' >> ConvertToAsset.from_kwargs(band_names_dict=band_names_dict, **vars(self)) - | 'IngestIntoEE' >> IngestIntoEETransform.from_kwargs(**vars(self)) + paths + | 'FilterFiles' >> FilterFilesTransform.from_kwargs(**vars(self)) + | 'ReshuffleFiles' >> beam.Reshuffle() + | 'ConvertToAsset' >> ConvertToAsset.from_kwargs(band_names_dict=band_names_dict, **vars(self)) + | 'IngestIntoEE' >> IngestIntoEETransform.from_kwargs(**vars(self)) ) else: ( - paths - | 'Log Files' >> beam.Map(logger.info) + paths + | 'Log Files' >> beam.Map(logger.info) ) @@ -404,7 +409,7 @@ def process(self, uri: str) -> t.Iterator[str]: @dataclasses.dataclass class ConvertToAsset(beam.DoFn, beam.PTransform, KwargsFactoryMixin): - """Writes asset after extracting input data and uploads it to GCS. + """Original way or preparing an asset for Earth Engine ingestion. Attributes: ee_asset_type: The type of asset to ingest in the earth engine. Default: IMAGE. @@ -466,11 +471,11 @@ def convert_to_asset(self, queue: Queue, uri: str): transform=transform, compress='lzw') as f: for i, da in enumerate(data): - f.write(da, i+1) + f.write(da, i + 1) # Making the channel name EE-safe before adding it as a band name. - f.set_band_description(i+1, get_ee_safe_name(channel_names[i])) - f.update_tags(i+1, band_name=channel_names[i]) - f.update_tags(i+1, **da.attrs) + f.set_band_description(i + 1, get_ee_safe_name(channel_names[i])) + f.update_tags(i + 1, band_name=channel_names[i]) + f.update_tags(i + 1, **da.attrs) # Write attributes as tags in tiff. f.update_tags(**attrs) @@ -542,6 +547,164 @@ def expand(self, pcoll): return pcoll | beam.FlatMap(self.process) +class XarrayToAsset(ConvertToAsset): + """Convert Xarray data to an Earth Engine asset. + + Includes utility methods with good defaults to make loading Xarray-readable + data into Earth Engine easier. + + For nearly all use cases, to load (raster) data into Google Earth Engine, + all you need to do is overload `open_dataset()` (and maybe + `define_name_and_target_path()`). + + Overloading `apply()` is necessary in order to ingest tabular data as an + Earth Engine feature collection (see the base class implementation to + get a good idea for how to do this). + """ + + def open_dataset(self, local_path: str, uri: t.Optional[str] = None, **kwargs) -> xr.Dataset: + """Open raw weather data as an Xarray Dataset, using weather-mv's method by default.""" + with open_dataset(local_path, **kwargs) as ds: + return ds + + def define_name_and_target_path(self, uri: str) -> t.Tuple[str, str]: + """Define an EE-safe asset name and target path based on input bucket.""" + asset_name = get_ee_safe_name(uri) + + file_name = f'{asset_name}.tiff' + target_path = os.path.join( + self.asset_location, file_name + ) + + return asset_name, target_path + + def apply(self, uri: str) -> AssetData: + """Core XArray-to-COG transformation.""" + with self.open(uri) as ds: + asset_name, target_path = self.define_name_and_target_path(uri) + with self.to_cog(ds) as memfile: + self.upload(memfile, target_path) + return to_asset_data(ds, asset_name, target_path) + + @contextlib.contextmanager + def open(self, uri: str, **kwargs) -> t.Iterator[xr.Dataset]: + """Copies raw data to local VM before opening as an Xarray dataset.""" + try: + with open_local(uri) as local_path: + ds = self.open_dataset(local_path, uri=uri, **kwargs) + logger.info(f'Opened dataset size: {ds.nbytes}') + ( + beam.metrics.Metrics.counter('Success', 'ReadXArrayDataset') + .inc() + ) + yield ds + except Exception as e: + beam.metrics.Metrics.counter('Failure', 'ReadXArrayDataset').inc() + logger.error(f'Unable to open file {uri!r}: {e}') + raise + + @contextlib.contextmanager + def to_cog( + self, + ds: xr.Dataset, + channel_names: t.Optional[t.List[str]] = None, + attrs: t.Optional[t.Dict[str, t.Any]] = None, + width: t.Optional[int] = None, + height: t.Optional[int] = None, + dtype: t.Union[np.dtype, str, None] = None, + transform: t.Union[affine.Affine, t.List[float], None] = None, + crs: t.Union[str, CRS, None] = None, + compression: str = 'lzw', + nodata: float = np.nan, + ) -> t.Iterator[rasterio.MemoryFile]: + """Write a xarray.Dataset to an in-memory COG.""" + if channel_names is None: + channel_names = list(ds.keys()) + if attrs is None: + attrs = ds.attrs + if dtype is None: + dtype = attrs.get('dtype', list(ds.dtypes.values())[0]) + if crs is None: + if 'crs' not in attrs: + raise ValueError( + 'unknown CRS. Please specify CRS as a parameter or in dataset' + ' attrs.' + ) + crs = attrs['crs'] + if transform is None: + if 'transform' not in attrs: + raise ValueError( + 'unknown transform. Please specify transform as a parameter or in' + ' dataset attrs.' + ) + transform = attrs['transform'] + + shape = tuple(ds.dims.values()) + x_size = width or ds.dims.get('x', None) or shape[1] + y_size = height or ds.dims.get('y', None) or shape[0] + + with MemoryFile() as memfile: + with memfile.open(driver='COG', + dtype=dtype, + width=x_size, + height=y_size, + count=len(ds.data_vars), + nodata=nodata, + crs=crs, + transform=transform, + compress=compression) as f: + for i, (name, da) in enumerate(ds.items()): + f.write(da, i + 1) + # Making the channel name EE-safe setting the band name. + safe_name = get_ee_safe_name(name) + f.set_band_description(i + 1, safe_name) + f.update_tags(i + 1, band_name=safe_name) + if da.attrs: + f.update_tags(i + 1, **make_attrs_ee_compatible(da.attrs)) + + # Write attributes as tags in asset. + f.update_tags(**make_attrs_ee_compatible(attrs)) + + yield memfile + + def upload(self, src: t.IO, dst: str) -> None: + # TODO(alxr): Check to see if we can make this faster with gsutil. + # Here, we need to evaluate if it's worth writing the COG in memory... + with FileSystems().create(dst) as dst_: + shutil.copyfileobj(src, dst_, WRITE_CHUNK_SIZE) + + def convert_to_asset(self, queue: Queue, uri: str) -> None: + self.add_to_queue(queue, self.apply(uri)) + # Indicates the end of subprocess. + self.add_to_queue(queue, None) + + +def _to_epoch(time_: t.Union[datetime.datetime, str]) -> float: + if isinstance(time_, str): + time_ = datetime.datetime.fromisoformat(time_) + + if isinstance(time_, datetime.datetime): + time_ = time_.timestamp() + + return float(time_) + + +def to_asset_data(ds: xr.Dataset, name: str, target_path: str) -> AssetData: + """Parses xarray.Dataset and metadata to AssetData.""" + start_time, end_time = ( + _to_epoch(ds.attrs[key]) for key in ('start_time', 'end_time') + ) + channel_names = [str(k) for k in ds.keys()] + return AssetData( + name=name, + target_path=target_path, + channel_names=channel_names, + start_time=start_time, + end_time=end_time, + properties=ds.attrs, + ) + + class IngestIntoEETransform(SetupEarthEngine, KwargsFactoryMixin): """Ingests asset into earth engine and yields asset id. diff --git a/weather_mv/loader_pipeline/util.py b/weather_mv/loader_pipeline/util.py index 54850cf8..6582fc76 100644 --- a/weather_mv/loader_pipeline/util.py +++ b/weather_mv/loader_pipeline/util.py @@ -69,6 +69,8 @@ def make_attrs_ee_compatible(attrs: t.Dict) -> t.Dict: k = re.sub(r'[^a-zA-Z0-9-_]+', r'_', k) if type(v) not in [int, float]: + if isinstance(v, bytes): + v = v.decode('utf-8') v = str(v) if len(v) > 1024: v = f'{v[:1021]}...' # Since 1 char = 1 byte.