diff --git a/transformers/keras_transformer/Dockerfile b/transformers/keras_transformer/Dockerfile
new file mode 100644
index 0000000..c71b238
--- /dev/null
+++ b/transformers/keras_transformer/Dockerfile
@@ -0,0 +1,12 @@
+FROM python:slim
+
+COPY requirements.txt requirements.txt
+RUN pip3 install --upgrade -r requirements.txt
+
+RUN mkdir /code
+WORKDIR /code
+COPY server.py server.py
+
+ENV PYTHONUNBUFFERED 1
+
+EXPOSE 80
diff --git a/transformers/keras_transformer/Makefile b/transformers/keras_transformer/Makefile
new file mode 100644
index 0000000..67e5016
--- /dev/null
+++ b/transformers/keras_transformer/Makefile
@@ -0,0 +1,9 @@
+REGISTRY_URL ?= docker.io/aistorage
+
+all: build push
+
+build:
+	docker build -t $(REGISTRY_URL)/transformer_keras:latest .
+
+push:
+	docker push $(REGISTRY_URL)/transformer_keras:latest
diff --git a/transformers/keras_transformer/README.md b/transformers/keras_transformer/README.md
new file mode 100644
index 0000000..997ee15
--- /dev/null
+++ b/transformers/keras_transformer/README.md
@@ -0,0 +1,59 @@
+# Keras Transformer - Image Data Augmentation and Preprocessing
+
+The Keras Transformer is a tool for image data preprocessing and augmentation. It leverages the `apply_transform` function from Keras (TensorFlow) and lets users define transformations by providing a JSON string of parameter-value pairs. Currently, the following parameters are supported:
+
+| Parameter                   | Description                        |
+|-----------------------------|------------------------------------|
+| `theta`                     | Rotation angle in degrees.         |
+| `tx`                        | Shift in the x direction.          |
+| `ty`                        | Shift in the y direction.          |
+| `shear`                     | Shear angle in degrees.            |
+| `zx`                        | Zoom in the x direction.           |
+| `zy`                        | Zoom in the y direction.           |
+| `flip_horizontal`           | Boolean. Enable horizontal flip.   |
+| `flip_vertical`             | Boolean. Enable vertical flip.     |
+| `channel_shift_intensity`   | Float. Channel shift intensity.    |
+| `brightness`                | Float. Brightness shift intensity. |
+
+The image format (JPEG, PNG, etc.) in which images are processed and stored is specified via the `FORMAT` environment variable in the pod spec (`pod.yaml`).
+
+The transformer supports the `hpull`, `hpush`, and `hrev` communication mechanisms for seamless integration.
+
+> For more information on communication mechanisms, please refer to [this link](https://github.com/NVIDIA/aistore/blob/master/docs/etl.md#communication-mechanisms).
+
+## Parameters
+Only two parameters need to be updated in the `pod.yaml` file.
+
+| Argument    | Description                                                           | Default Value    |
+| ----------- | --------------------------------------------------------------------- | ---------------- |
+| `TRANSFORM` | JSON string of transformation parameter-value pairs to apply          | None (mandatory) |
+| `FORMAT`    | Image format in which objects are processed/stored (PNG, JPEG, etc.)  | `JPEG`           |
+
+Adjust these parameters according to your specific requirements.
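+
+To illustrate what the `TRANSFORM` string does, the snippet below approximates the per-object processing that `server.py` performs (decode, `apply_transform`, re-encode). It is a standalone sketch for local experimentation only; the input/output file names are placeholders.
+
+```python
+import io
+import json
+
+from keras.preprocessing.image import (
+    ImageDataGenerator,
+    array_to_img,
+    img_to_array,
+    load_img,
+)
+
+# Same JSON shape as the TRANSFORM environment variable consumed by the container.
+transform_parameters = json.loads('{"theta":40, "brightness":0.8, "zx":0.9, "zy":0.9}')
+
+with open("input.JPEG", "rb") as f:  # placeholder input object
+    img = img_to_array(load_img(io.BytesIO(f.read())))
+
+# Apply the rotation, brightness, and zoom parameters in a single call.
+img = ImageDataGenerator().apply_transform(x=img, transform_parameters=transform_parameters)
+
+buf = io.BytesIO()
+array_to_img(img).save(buf, format="JPEG")  # FORMAT defaults to JPEG
+
+with open("output.JPEG", "wb") as f:  # placeholder output object
+    f.write(buf.getvalue())
+```
+
+The container performs the same decode/transform/re-encode sequence on every object it serves.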
+
+### Initializing ETL with AIStore CLI
+
+The following steps demonstrate how to initialize the `Keras Transformer` using the [AIStore CLI](https://github.com/NVIDIA/aistore/blob/master/docs/cli.md):
+
+```bash
+$ cd transformers/keras_transformer
+
+$ # Set values for FORMAT and TRANSFORM
+$ export FORMAT="JPEG"
+$ export TRANSFORM='{"theta":40, "brightness":0.8, "zx":0.9, "zy":0.9}'
+
+$ # Specify the communication type between target and container
+$ export COMMUNICATION_TYPE='hpull://'
+
+$ # Substitute env variables in the spec file
+$ envsubst < pod.yaml > init_spec.yaml
+
+$ # Initialize ETL
+$ ais etl init spec --from-file init_spec.yaml --name <etl-name>
+
+$ # Transform and retrieve objects from the bucket using this ETL
+$ # For inline transformation
+$ ais etl object <etl-name> ais://src/<image-name>.JPEG dst.JPEG
+$ # Or, for offline (bucket-to-bucket) transformation
+$ ais etl bucket <etl-name> ais://src-bck ais://dst-bck --ext="{JPEG:JPEG}"
+```
\ No newline at end of file
diff --git a/transformers/keras_transformer/pod.yaml b/transformers/keras_transformer/pod.yaml
new file mode 100644
index 0000000..cea63eb
--- /dev/null
+++ b/transformers/keras_transformer/pod.yaml
@@ -0,0 +1,32 @@
+# https://github.com/NVIDIA/ais-etl/blob/master/transformers/keras_transformer/README.md
+apiVersion: v1
+kind: Pod
+metadata:
+  name: transformer-keras
+  annotations:
+    # Values `communication_type` can take are ["hpull://", "hrev://", "hpush://", "io://"].
+    # Visit https://github.com/NVIDIA/aistore/blob/master/docs/etl.md#communication-mechanisms
+    communication_type: ${COMMUNICATION_TYPE:-"\"hpull://\""}
+    wait_timeout: 5m
+spec:
+  containers:
+    - name: server
+      image: gaikwadabhishek/transformer_keras:latest
+      imagePullPolicy: Always
+      ports:
+        - name: default
+          containerPort: 80
+      command: ['/code/server.py', '--listen', '0.0.0.0', '--port', '80']
+      env:
+        - name: FORMAT
+          # expected values - PNG, JPEG, etc.
+          value: ${FORMAT:-"JPEG"}
+        - name: TRANSFORM
+          # MANDATORY: expected JSON string of parameter-value pairs.
+          # https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator#apply_transform
+          # e.g. '{"theta":40, "brightness":0.8, "zx":0.9, "zy":0.9}'
+          value: ${TRANSFORM}
+      readinessProbe:
+        httpGet:
+          path: /health
+          port: default
diff --git a/transformers/keras_transformer/requirements.txt b/transformers/keras_transformer/requirements.txt
new file mode 100644
index 0000000..c405b85
--- /dev/null
+++ b/transformers/keras_transformer/requirements.txt
@@ -0,0 +1,5 @@
+requests
+pillow
+scipy
+keras
+tensorflow
\ No newline at end of file
diff --git a/transformers/keras_transformer/server.py b/transformers/keras_transformer/server.py
new file mode 100755
index 0000000..e46bfc6
--- /dev/null
+++ b/transformers/keras_transformer/server.py
@@ -0,0 +1,122 @@
+#!/usr/bin/env python
+#
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+
+import os
+import json
+import logging
+import requests
+from http.server import HTTPServer, BaseHTTPRequestHandler
+from socketserver import ThreadingMixIn
+import io
+from keras.preprocessing.image import (
+    ImageDataGenerator,
+    load_img,
+    array_to_img,
+    img_to_array,
+)
+
+# Constants
+FORMAT = os.getenv("FORMAT", "JPEG")
+ARG_TYPE = os.getenv("ARG_TYPE", "bytes")
+
+# Environment Variables
+host_target = os.environ.get("AIS_TARGET_URL")
+TRANSFORM = os.environ.get("TRANSFORM")
+if not host_target:
+    raise EnvironmentError("AIS_TARGET_URL environment variable missing")
+if not TRANSFORM:
+    raise EnvironmentError(
+        "TRANSFORM environment variable missing. Check documentation for examples (link)"
+    )
+transform_dict = json.loads(TRANSFORM)
+
+
+class Handler(BaseHTTPRequestHandler):
+    def log_request(self, code="-", size="-"):
+        """Override log_request to not log successful requests."""
+        pass
+
+    def _set_headers(self):
+        """Set standard headers for responses."""
+        self.send_response(200)
+        self.send_header("Content-Type", "application/octet-stream")
+        self.end_headers()
+
+    def transform(self, data: bytes) -> bytes:
+        """Process image data as bytes using the specified transformation."""
+        try:
+            img = load_img(io.BytesIO(data))
+            img = img_to_array(img)
+            datagen = ImageDataGenerator()
+            img = datagen.apply_transform(x=img, transform_parameters=transform_dict)
+            img = array_to_img(img)
+            buf = io.BytesIO()
+            img.save(buf, format=FORMAT)
+            return buf.getvalue()
+        except Exception as e:
+            logging.error("Error processing data: %s", str(e))
+            raise
+
+    def do_PUT(self):
+        """PUT handler supports `hpush` operation."""
+        try:
+            content_length = int(self.headers["Content-Length"])
+            post_data = self.rfile.read(content_length)
+            processed_data = self.transform(post_data)
+            if processed_data is not None:
+                self._set_headers()
+                self.wfile.write(processed_data)
+            else:
+                self.send_response(500)
+                self.end_headers()
+                self.wfile.write(b"Data processing failed")
+        except Exception as e:
+            logging.error("Error processing PUT request: %s", str(e))
+            self.send_response(500)
+            self.end_headers()
+            self.wfile.write(b"Data processing failed")
+
+    def do_GET(self):
+        """GET handler supports `hpull` operation."""
+        try:
+            if self.path == "/health":
+                self._set_headers()
+                self.wfile.write(b"OK")
+                return
+
+            query_path = host_target + self.path
+
+            if ARG_TYPE == "url":  # need this for webdataset
+                result = self.transform(query_path)
+            else:
+                input_bytes = requests.get(query_path).content
+                result = self.transform(input_bytes)
+
+            if result is not None:
+                self._set_headers()
+                self.wfile.write(result)
+            else:
+                self.send_response(500)
+                self.end_headers()
+                self.wfile.write(b"Data processing failed")
+        except Exception as e:
+            logging.error("Error processing GET request: %s", str(e))
+            self.send_response(500)
+            self.end_headers()
+            self.wfile.write(b"Data processing failed")
+
+
+class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
+    """Handle requests in a separate thread."""
+
+
+def run(addr="0.0.0.0", port=80):
+    server = ThreadedHTTPServer((addr, port), Handler)
+    logging.info(f"Starting HTTP server on {addr}:{port}")
+    server.serve_forever()
+
+
+if __name__ == "__main__":
+    run(addr="0.0.0.0", port=80)