Skip to content

PseudoCode Snippets

Juan Emmanuel Johnson edited this page Feb 20, 2024 · 2 revisions
# Attach the band ids and the timestamps taken from ``ds`` onto ``da``.
da = da.assign(
    band=ds.band_id,
    time=ds.time,
)

# Keep only the time steps whose hour lies on the sampling grid
# 0, time_resolution_hours, 2 * time_resolution_hours, ... (< 24).
ds = ds.sel(
    time=ds.time.dt.hour.isin(range(0, 24, time_resolution_hours)),
)

# Stack all data variables along a new leading "variable" axis, order the
# axes as (variable, time, latitude, longitude) and take the underlying array.
ds: np.ndarray = (
    ds.to_array(dim="variable")
    .transpose("variable", "time", "latitude", "longitude")
    .data
)
# get nearest lat-lon index
def get_nearest_latlon_index(ds: "xr.Dataset", lat: float, lon: float) -> tuple[int, int]:
    """Return the positional indices of the grid point nearest to ``(lat, lon)``.

    Latitude and longitude are searched independently: each index is the
    position of the smallest absolute difference along ``ds.latitude`` and
    ``ds.longitude`` respectively. On a tie the first (lowest) index wins.
    The comparison is a plain absolute difference, so no 0/360 longitude
    wrap-around is taken into account.

    Args:
        ds: Dataset exposing ``latitude`` and ``longitude`` coordinates
            (assumed to be 1-D numeric arrays; this is not validated).
        lat: Target latitude, in the same units as ``ds.latitude``.
        lon: Target longitude, in the same units as ``ds.longitude``.

    Returns:
        ``(lat_index, lon_index)`` as plain Python ints, usable with ``isel``.
    """
    # argmin returns a NumPy integer; cast so the result matches the annotation.
    lat_index = int(np.abs(ds.latitude.data - lat).argmin())
    lon_index = int(np.abs(ds.longitude.data - lon).argmin())
    return lat_index, lon_index

# Crop an ``image_size`` x ``image_size`` window centred on the grid point
# nearest to (lat, lon).
lat_index, lon_index = get_nearest_latlon_index(ds, lat, lon)
half_size = image_size // 2
# Clamp the start at 0: a negative slice start is read as "counted from the
# end" and would give an empty or wrong window near the lower grid edge.
lat_start = max(lat_index - half_size, 0)
lon_start = max(lon_index - half_size, 0)
# Deriving the stop from the start keeps the window exactly ``image_size``
# wide for odd sizes too; near the upper edge the slice is simply truncated.
ds = ds.isel(
    latitude=slice(lat_start, lat_start + image_size),
    longitude=slice(lon_start, lon_start + image_size),
)

Plot a map where you highlight a country or region

        # Create the axes with a PlateCarree projection
        ax_low = plt.axes(projection=ccrs.PlateCarree())

        # Draw coastlines
        ax_low.coastlines()

        # State/province boundary lines (Natural Earth layer
        # 'admin_1_states_provinces_lines') in white.
        # NOTE(review): nothing in this snippet restricts the layer to Brazil,
        # and its linewidth (0.5) is thinner than the country outlines below
        # (0.7) - adjust if a single country is meant to stand out.
        brazil = cfeature.NaturalEarthFeature(
            category='cultural',
            name='admin_1_states_provinces_lines',
            scale='50m',
            facecolor='none',
            edgecolor='white',  # white boundary lines, no fill
            linewidth=0.5       # thinner than the 0.7 used for country outlines
        )
        ax_low.add_feature(brazil)

        # Country outlines (Natural Earth layer 'admin_0_countries') in white
        other_countries = cfeature.NaturalEarthFeature(
            category='cultural',
            name='admin_0_countries',
            scale='50m',
            facecolor='none',
            edgecolor='white',    # same colour as the state/province lines
            linewidth=0.7        # slightly thicker than the state/province lines
        )
        ax_low.add_feature(other_countries)

        # Plot the selected time step on the map; `data`, `time`, `vmin`,
        # `vmax` and `cmap` must already be defined by the caller
        data.isel(time=time).plot(ax=ax_low, transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax, cmap=cmap)

        # Add thin grey gridlines with labels; the returned object is kept in `gl`
        gl = ax_low.gridlines(draw_labels=True, color='grey', linewidth=0.3)

        # Show the plot
        plt.show()

Create an xarray dataset in the safest way possible

   def create_xarray(
        data: np.ndarray,
        start_date: pd.Timestamp = pd.Timestamp('2007-01-01'),
        bbox=(-35, -75, 5, -35),
   ):
        """Wrap a 3-D array in an ``xr.DataArray`` with time/latitude/longitude coords.

        Args:
            data: Array of shape ``(time, latitude, longitude)``. The two
                spatial axes may have different lengths.
            start_date: First date of the generated time axis.
            bbox: ``(lat_start, lon_start, lat_end, lon_end)``. Latitudes are
                spaced evenly from ``bbox[0]`` to ``bbox[2]`` and longitudes
                from ``bbox[1]`` to ``bbox[3]``, both endpoints included.

        Returns:
            An ``xr.DataArray`` with dims ``("time", "latitude", "longitude")``.

        Raises:
            ValueError: If ``data`` is not 3-dimensional (from the shape unpacking).
        """
        # Read both spatial sizes from the data so non-square grids work too.
        num_time_steps, num_latitudes, num_longitudes = data.shape

        # One time stamp per step at the "1M" frequency, starting at start_date
        date_range = xr.cftime_range(start=start_date, periods=num_time_steps, freq="1M")
        time_coords = xr.DataArray(date_range, dims=("time",), attrs={"units": "months"})

        # Evenly spaced coordinate values spanning the bounding box
        latitude_values = np.linspace(bbox[0], bbox[2], num_latitudes)
        longitude_values = np.linspace(bbox[1], bbox[3], num_longitudes)

        # Coordinate arrays carrying their units as attributes
        latitude_coords = xr.DataArray(latitude_values, dims=("latitude",), attrs={"units": "degrees_north"})
        longitude_coords = xr.DataArray(longitude_values, dims=("longitude",), attrs={"units": "degrees_east"})

        # Combine the values with their coordinates
        data_array = xr.DataArray(
            data,
            dims=("time", "latitude", "longitude"),
            coords={"time": time_coords, "latitude": latitude_coords, "longitude": longitude_coords},
        )

        return data_array

A minimal example to load many files and clean them

# Open every NetCDF file in ``data_dir`` as one dataset, combined by coordinates
data = xr.open_mfdataset(f'{data_dir}/*.nc', combine='by_coords')
# Scale by the number of seconds in a day (86400).
# NOTE(review): presumably converts a per-second rate into a daily amount - confirm units
data = data * 86400
# Sum within each calendar month; '1MS' labels each bin by the month start
data = data.resample(time='1MS').sum()
# Wrap longitudes from [0, 360) into [-180, 180) ...
data = data.assign_coords(lon=(((data.lon + 180) % 360) - 180))
# ... then sort so the axis is increasing again. Unlike rolling by half the
# axis length, sorting is also correct for odd-length grids and for inputs
# that were already in [-180, 180); slice-based ``sel`` needs a sorted axis.
data = data.sortby('lon')
# Crop to the period and bounding box; ``bbox`` is read as
# [lat_start, lon_start, lat_end, lon_end]
data = data.sel(
    time=slice(time_init, time_end),
    lat=slice(bbox[0], bbox[2]),
    lon=slice(bbox[1], bbox[3]),
)
data = data.rename({'lat': 'latitude', 'lon': 'longitude'})

A minimal example to load a file and convert it to a raster

# Open the file and crop it to the requested period and bounding box;
# ``bbox`` is read as [lat_start, lon_start, lat_end, lon_end]
data = xr.open_dataset(data_dir)
data = data.sel(
    time=slice(time_init, time_end),
    latitude=slice(bbox[0], bbox[2]),
    longitude=slice(bbox[1], bbox[3]),
)
data = data.rename({'precip': 'pr'})

# Record the CRS (EPSG:4326) on the data so it can be reprojected
data = data.rio.write_crs("EPSG:4326")

# Despite the name, this factor shrinks the grid: width and height are
# divided by it, so the output has fewer, larger cells.
upscale_factor = 5
new_width = data.rio.width // upscale_factor
new_height = data.rio.height // upscale_factor

# Same CRS, smaller shape, using the ``Resampling.average`` method
data = data.rio.reproject(
    data.rio.crs,
    shape=(new_height, new_width),
    resampling=Resampling.average,
)
# NOTE(review): assumes the reprojected spatial dims are named 'y'/'x' - confirm
data = data.rename({'y': 'latitude', 'x': 'longitude'})

Create Slices for custom windows of xarray data

def slicing(data: xr.Dataset, size: int, variable: str = "pr") -> np.ndarray:
    """Cut one variable of ``data`` into ``size`` x ``size`` spatial tiles.

    For every time step the latitude/longitude grid is walked in steps of
    ``size`` and each window is extracted as a NumPy array.

    Args:
        data: Dataset with ``time``, ``latitude`` and ``longitude`` dimensions.
        size: Edge length of each tile, in grid cells.
        variable: Name of the data variable to tile. Defaults to ``"pr"``.

    Returns:
        The tiles stacked along a new leading axis, ordered by time step,
        then latitude window, then longitude window. The latitude and
        longitude lengths should be multiples of ``size``; otherwise the
        edge tiles are smaller and the tiles no longer share one shape.
    """
    # Look the variable up once instead of on every iteration.
    field = data[variable]
    num_time_steps = len(data['time'])
    num_latitudes = len(data['latitude'])
    num_longitudes = len(data['longitude'])

    data_slices = []
    for time_index in range(num_time_steps):
        for lat_start in range(0, num_latitudes, size):
            for lon_start in range(0, num_longitudes, size):
                tile = field.isel(
                    time=time_index,
                    latitude=slice(lat_start, lat_start + size),
                    longitude=slice(lon_start, lon_start + size),
                )
                data_slices.append(tile.values)

    return np.array(data_slices)