Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 176 additions & 13 deletions weather_mv/loader_pipeline/ee.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import dataclasses
import datetime
import json
import logging
import os
Expand All @@ -24,16 +26,19 @@
import typing as t
from multiprocessing import Process, Queue

import affine
import apache_beam as beam
import ee
import numpy as np
import rasterio
import xarray as xr
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.gcp.gcsio import WRITE_CHUNK_SIZE
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.utils import retry
from google.auth import compute_engine, default, credentials
from google.auth.transport import requests
from pyproj import CRS
from rasterio.io import MemoryFile

from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin
Expand Down Expand Up @@ -243,7 +248,7 @@ def add_parser_arguments(cls, subparser: argparse.ArgumentParser):
help='The GCS location where the asset files will be pushed.')
subparser.add_argument('--ee_asset', type=str, required=True, default=None,
help='The asset folder path in earth engine project where the asset files'
' will be pushed.')
' will be pushed.')
subparser.add_argument('--ee_asset_type', type=str, choices=['IMAGE', 'TABLE'], default='IMAGE',
help='The type of asset to ingest in the earth engine.')
subparser.add_argument('--xarray_open_dataset_kwargs', type=json.loads, default='{}',
Expand Down Expand Up @@ -332,16 +337,16 @@ def expand(self, paths):
band_names_dict = json.load(f)
if not self.dry_run:
(
paths
| 'FilterFiles' >> FilterFilesTransform.from_kwargs(**vars(self))
| 'ReshuffleFiles' >> beam.Reshuffle()
| 'ConvertToAsset' >> ConvertToAsset.from_kwargs(band_names_dict=band_names_dict, **vars(self))
| 'IngestIntoEE' >> IngestIntoEETransform.from_kwargs(**vars(self))
paths
| 'FilterFiles' >> FilterFilesTransform.from_kwargs(**vars(self))
| 'ReshuffleFiles' >> beam.Reshuffle()
| 'ConvertToAsset' >> ConvertToAsset.from_kwargs(band_names_dict=band_names_dict, **vars(self))
| 'IngestIntoEE' >> IngestIntoEETransform.from_kwargs(**vars(self))
)
else:
(
paths
| 'Log Files' >> beam.Map(logger.info)
paths
| 'Log Files' >> beam.Map(logger.info)
)


Expand Down Expand Up @@ -404,7 +409,7 @@ def process(self, uri: str) -> t.Iterator[str]:

@dataclasses.dataclass
class ConvertToAsset(beam.DoFn, beam.PTransform, KwargsFactoryMixin):
"""Writes asset after extracting input data and uploads it to GCS.
"""Original way or preparing an asset for Earth Engine ingestion.

Attributes:
ee_asset_type: The type of asset to ingest in the earth engine. Default: IMAGE.
Expand Down Expand Up @@ -466,11 +471,11 @@ def convert_to_asset(self, queue: Queue, uri: str):
transform=transform,
compress='lzw') as f:
for i, da in enumerate(data):
f.write(da, i+1)
f.write(da, i + 1)
# Making the channel name EE-safe before adding it as a band name.
f.set_band_description(i+1, get_ee_safe_name(channel_names[i]))
f.update_tags(i+1, band_name=channel_names[i])
f.update_tags(i+1, **da.attrs)
f.set_band_description(i + 1, get_ee_safe_name(channel_names[i]))
f.update_tags(i + 1, band_name=channel_names[i])
f.update_tags(i + 1, **da.attrs)
# Write attributes as tags in tiff.
f.update_tags(**attrs)

Expand Down Expand Up @@ -542,6 +547,164 @@ def expand(self, pcoll):
return pcoll | beam.FlatMap(self.process)


class XarrayToAsset(ConvertToAsset):
    """Convert Xarray data to an Earth Engine asset.

    Includes utility methods with good defaults to make loading Xarray-readable
    data into Earth Engine easier.

    For nearly all use cases, to load (raster) data into Google Earth Engine,
    all you need to do is override `open_dataset()` (and maybe
    `define_name_and_target_path()`).

    Overriding `apply()` is necessary in order to ingest tabular data as an
    Earth Engine feature collection (see the base class implementation to
    get a good idea for how to do this).
    """

    def open_dataset(self, local_path: str, uri: t.Optional[str] = None, **kwargs) -> xr.Dataset:
        """Open raw weather data as an Xarray Dataset, using weather-mv's method by default.

        Args:
            local_path: Path of the data to open.
            uri: Original location of the data. Not used by this implementation;
                presumably provided so that overrides can inspect it.
            **kwargs: Forwarded unchanged to the module-level `open_dataset()`.

        Returns:
            The dataset yielded by the module-level `open_dataset()`.
        """
        # NOTE(review): the dataset is returned from inside the `with` block, so
        # any cleanup the module-level `open_dataset()` performs on exit has
        # already run by the time the caller receives it. Confirm the dataset
        # is still readable at that point.
        with open_dataset(local_path, **kwargs) as ds:
            return ds

    def define_name_and_target_path(self, uri: str) -> t.Tuple[str, str]:
        """Define an EE-safe asset name and target path based on input bucket.

        Args:
            uri: Location of the input file.

        Returns:
            A tuple `(asset_name, target_path)`, where `target_path` is
            `<self.asset_location>/<asset_name>.tiff`.
        """
        asset_name = get_ee_safe_name(uri)
        target_path = os.path.join(self.asset_location, f'{asset_name}.tiff')
        return asset_name, target_path

    def apply(self, uri: str) -> AssetData:
        """Core XArray-to-COG transformation.

        Opens the data at `uri`, writes it to an in-memory COG, uploads the COG
        to the target path and returns the matching asset description.

        Args:
            uri: Location of the input file.

        Returns:
            The `AssetData` built by `to_asset_data()` for the uploaded file.
        """
        with self.open(uri) as ds:
            asset_name, target_path = self.define_name_and_target_path(uri)
            with self.to_cog(ds) as memfile:
                self.upload(memfile, target_path)
            return to_asset_data(ds, asset_name, target_path)

    @contextlib.contextmanager
    def open(self, uri: str, **kwargs) -> t.Iterator[xr.Dataset]:
        """Copies raw data to local VM before opening as an Xarray dataset.

        Increments the 'Success'/'ReadXArrayDataset' counter once the dataset
        has been opened. Because the `yield` sits inside the `try`, an exception
        raised by the caller's `with` body is handled here as well: it
        increments the 'Failure'/'ReadXArrayDataset' counter, is logged and is
        then re-raised.

        Args:
            uri: Location of the input file.
            **kwargs: Forwarded to `self.open_dataset()`.

        Yields:
            The opened dataset.
        """
        try:
            with open_local(uri) as local_path:
                ds = self.open_dataset(local_path, uri=uri, **kwargs)
                logger.info(f'Opened dataset size: {ds.nbytes}')
                beam.metrics.Metrics.counter('Success', 'ReadXArrayDataset').inc()
                yield ds
        except Exception as e:
            beam.metrics.Metrics.counter('Failure', 'ReadXArrayDataset').inc()
            logger.error(f'Unable to open file {uri!r}: {e}')
            raise

    @contextlib.contextmanager
    def to_cog(
        self,
        ds: xr.Dataset,
        channel_names: t.Optional[t.List[str]] = None,
        attrs: t.Optional[t.Dict[str, t.Any]] = None,
        width: t.Optional[int] = None,
        height: t.Optional[int] = None,
        dtype: t.Union[np.dtype, str, None] = None,
        transform: t.Union[affine.Affine, t.List[float], None] = None,
        crs: t.Union[str, CRS, None] = None,
        compression: str = 'lzw',
        nodata: float = np.nan,
    ) -> t.Iterator[rasterio.MemoryFile]:
        """Write a xarray.Dataset to an in-memory COG.

        Every data variable of `ds` becomes one band, in dataset order.

        Args:
            ds: The dataset to write.
            channel_names: Band names, one per data variable and in the same
                order. Defaults to the dataset's variable names.
            attrs: Asset-level attributes written as file tags. Defaults to
                `ds.attrs`. Also consulted for 'dtype', 'crs' and 'transform'
                when those arguments are not given.
            width: Raster width. Defaults to the size of the 'x' dimension,
                else the size of the dataset's second dimension.
            height: Raster height. Defaults to the size of the 'y' dimension,
                else the size of the dataset's first dimension.
            dtype: Raster data type. Defaults to `attrs['dtype']`, else the
                dtype of the first data variable.
            transform: Affine transform. Defaults to `attrs['transform']`.
            crs: Coordinate reference system. Defaults to `attrs['crs']`.
            compression: Compression passed to the COG driver.
            nodata: No-data value passed to the COG driver.

        Yields:
            The `MemoryFile` holding the finished COG.

        Raises:
            ValueError: If the CRS or transform is neither passed nor present
                in `attrs`, or if `channel_names` does not contain exactly one
                name per data variable.
        """
        if attrs is None:
            attrs = ds.attrs
        if channel_names is None:
            channel_names = list(ds.keys())
        elif len(channel_names) != len(ds.data_vars):
            # Bands are paired with names positionally, so a mismatch would
            # silently drop or mislabel bands.
            raise ValueError(
                f'expected {len(ds.data_vars)} channel names (one per data variable),'
                f' got {len(channel_names)}.'
            )
        if dtype is None:
            dtype = attrs.get('dtype', list(ds.dtypes.values())[0])
        if crs is None:
            if 'crs' not in attrs:
                raise ValueError(
                    'unknown CRS. Please specify CRS as a parameter or in dataset'
                    ' attrs.'
                )
            crs = attrs['crs']
        if transform is None:
            if 'transform' not in attrs:
                raise ValueError(
                    'unknown transform. Please specify transform as a parameter or in'
                    ' dataset attrs.'
                )
            transform = attrs['transform']

        # `Dataset.sizes` is the stable name -> length mapping; using
        # `Dataset.dims` as a mapping is deprecated in xarray.
        sizes = ds.sizes
        shape = tuple(sizes.values())
        x_size = width or sizes.get('x') or shape[1]
        y_size = height or sizes.get('y') or shape[0]

        with MemoryFile() as memfile:
            with memfile.open(driver='COG',
                              dtype=dtype,
                              width=x_size,
                              height=y_size,
                              count=len(ds.data_vars),
                              nodata=nodata,
                              crs=crs,
                              transform=transform,
                              compress=compression) as f:
                # Band indexes are 1-based.
                for band, (name, da) in enumerate(zip(channel_names, ds.values()), start=1):
                    f.write(da, band)
                    # Make the channel name EE-safe before setting it as the band name.
                    safe_name = get_ee_safe_name(name)
                    f.set_band_description(band, safe_name)
                    f.update_tags(band, band_name=safe_name)
                    if da.attrs:
                        f.update_tags(band, **make_attrs_ee_compatible(da.attrs))

                # Write attributes as tags in asset.
                f.update_tags(**make_attrs_ee_compatible(attrs))

            # The writer is closed before the file is handed to the caller.
            yield memfile

    def upload(self, src: t.IO, dst: str) -> None:
        """Copy the contents of `src` to `dst` in chunks of `WRITE_CHUNK_SIZE`.

        Args:
            src: Readable file-like object to copy from.
            dst: Destination path, created through Beam's `FileSystems`.
        """
        # TODO(alxr): Check to see if we can make this faster with gsutil.
        # Here, we need to evaluate if it's worth writing the COG in memory...
        with FileSystems().create(dst) as dst_:
            shutil.copyfileobj(src, dst_, WRITE_CHUNK_SIZE)

    def convert_to_asset(self, queue: Queue, uri: str) -> None:
        """Convert the file at `uri` and put the resulting asset data on `queue`.

        Args:
            queue: Queue that receives the result, followed by a `None` marker.
            uri: Location of the input file.
        """
        self.add_to_queue(queue, self.apply(uri))
        # Indicates the end of subprocess.
        self.add_to_queue(queue, None)


def _to_epoch(time_: t.Union[datetime.datetime, str]) -> float:
if isinstance(time_, str):
time_ = datetime.datetime.fromisoformat(time_)

if isinstance(time_, datetime.datetime):
time_ = time_.timestamp()

return float(time_)


def to_asset_data(ds: xr.Dataset, name: str, target_path: str) -> AssetData:
    """Build the AssetData record describing a dataset.

    Args:
        ds: Dataset whose attrs contain 'start_time' and 'end_time'.
        name: Name of the asset.
        target_path: Path of the asset file.

    Returns:
        An `AssetData` carrying the name, target path, the dataset's variable
        names as channel names, the start and end times converted with
        `_to_epoch()`, and the dataset attrs as properties.

    Raises:
        KeyError: If 'start_time' or 'end_time' is missing from `ds.attrs`.
    """
    dataset_attrs = ds.attrs
    # Converted in this order so a missing 'start_time' is reported first.
    start_epoch = _to_epoch(dataset_attrs['start_time'])
    end_epoch = _to_epoch(dataset_attrs['end_time'])
    band_names = list(map(str, ds.keys()))

    return AssetData(
        name=name,
        target_path=target_path,
        channel_names=band_names,
        start_time=start_epoch,
        end_time=end_epoch,
        properties=dataset_attrs,
    )


class IngestIntoEETransform(SetupEarthEngine, KwargsFactoryMixin):
"""Ingests asset into earth engine and yields asset id.

Expand Down
2 changes: 2 additions & 0 deletions weather_mv/loader_pipeline/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def make_attrs_ee_compatible(attrs: t.Dict) -> t.Dict:
k = re.sub(r'[^a-zA-Z0-9-_]+', r'_', k)

if type(v) not in [int, float]:
if isinstance(v, bytes):
v = v.decode('utf-8')
v = str(v)
if len(v) > 1024:
v = f'{v[:1021]}...' # Since 1 char = 1 byte.
Expand Down