Skip to content

Commit

Permalink
etl: keras transformer for images
Browse files Browse the repository at this point in the history
Signed-off-by: Abhishek Gaikwad <[email protected]>
  • Loading branch information
gaikwadabhishek committed Jul 28, 2023
1 parent 72a99c4 commit ff276e7
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 0 deletions.
12 changes: 12 additions & 0 deletions transformers/keras_transformer/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
FROM python:slim

COPY requirements.txt requirements.txt
RUN pip3 install --upgrade -r requirements.txt

RUN mkdir /code
WORKDIR /code
COPY server.py server.py

ENV PYTHONUNBUFFERED 1

EXPOSE 80
9 changes: 9 additions & 0 deletions transformers/keras_transformer/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
REGISTRY_URL ?= docker.io/aistorage

all: build push

build:
docker build -t $(REGISTRY_URL)/transformer_keras:latest .

push:
docker push $(REGISTRY_URL)/transformer_keras:latest
59 changes: 59 additions & 0 deletions transformers/keras_transformer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Keras Transformer - Image Data Augmentation and Preprocessing

The Keras Transformer is a powerful tool designed for image data preprocessing and data augmentation. Leveraging the `apply_transform` function from Keras (TensorFlow), this transformer allows users to define transformations by providing a JSON string with parameter-value pairs. Currently, the following parameters are supported:

| Parameter | Description |
|-------------------------|---------------------------------------------------------|
| 'theta' | Rotation angle in degrees. |
| 'tx' | Shift in the x direction. |
| 'ty' | Shift in the y direction. |
| 'shear' | Shear angle in degrees. |
| 'zx' | Zoom in the x direction. |
| 'zy' | Zoom in the y direction. |
| 'flip_horizontal' | Boolean. Enable horizontal flip. |
| 'flip_vertical' | Boolean. Enable vertical flip. |
| 'channel_shift_intensity' | Float. Channel shift intensity. |
| 'brightness' | Float. Brightness shift intensity. |

The image format (JPEG, PNG, etc.) of the images to be processed or stored is specified in the `spec.yaml`.

The transformer supports both `hpull`, `hpush` and `hrev` communication mechanisms for seamless integration.

> For more information on communication mechanisms, please refer to [this link](https://github.com/NVIDIA/aistore/blob/master/docs/etl.md#communication-mechanisms).
## Parameters
Only two parameters need to be updated in the `pod.yaml` file.

| Argument | Description | Default Value |
| ----------- | --------------------------------------------------------------------- | ------------- |
| `TRANSFORM` | Specify a JSON string with operations to be performed | `` |
| `FORMAT`| To process/store images in which image format (PNG, JPEG,etc) | `JPEG` |

Please ensure to adjust these parameters according to your specific requirements.

### Initializing ETL with AIStore CLI

The following steps demonstrate how to initialize the `Keras Transformer` with using the [AIStore CLI](https://github.com/NVIDIA/aistore/blob/master/docs/cli.md):

```!bash
$ cd transformers/keras_transformer
$ # Set values for FORMAT and TRANSFORM
$ export FORMAT="JPEG"
$ export TRANSFORM='{"theta":40, "brightness":0.8, "zx":0.9, "zy":0.9}'
$ # Mention communication type b/w target and container
$ export COMMUNICATION_TYPE = 'hpull://'
# Substitute env variables in spec file
$ envsubst < pod.yaml > init_spec.yaml
$ # Initialize ETL
$ ais etl init spec --from-file init_spec.yaml --name <etl-name>
$ # Transform and retrieve objects from the bucket using this ETL
$ # For inline transformation
$ ais etl object <etl-name> ais://src/<image-name>.JPEG dst.JPEG
$ # Or, for offline (bucket-to-bucket) transformation
$ ais etl bucket <etl-name> ais://src-bck ais://dst-bck --ext="{JPEG:JPEG}"
```
32 changes: 32 additions & 0 deletions transformers/keras_transformer/pod.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# https://github.com/NVIDIA/ais-etl/blob/master/transformers/keras_transformer/README.md
apiVersion: v1
kind: Pod
metadata:
name: transformer-compress
annotations:
# Values `communication_type` can take are ["hpull://", "hrev://", "hpush://", "io://"].
# Visit https://github.com/NVIDIA/aistore/blob/master/docs/etl.md#communication-mechanisms
communication_type: ${COMMUNICATION_TYPE:-"\"hpull://\""}
wait_timeout: 5m
spec:
containers:
- name: server
image: gaikwadabhishek/transformer_keras:latest
imagePullPolicy: Always
ports:
- name: default
containerPort: 80
command: ['/code/server.py', '--listen', '0.0.0.0', '--port', '80']
env:
- name: FORMAT
# expected values - PNG, JPEG, etc
value: ${FORMAT:-"JPEG"}
- name: TRANSFORM
# MANDATORY: expected json string parameter-value paris.
# https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator#apply_transform
# e.g. '{"theta":40, "brightness":0.8, "zx":0.9, "zy":0.9}'
value: ${TRANSFORM}
readinessProbe:
httpGet:
path: /health
port: default
5 changes: 5 additions & 0 deletions transformers/keras_transformer/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
requests
pillow
scipy
keras
tensorflow
122 changes: 122 additions & 0 deletions transformers/keras_transformer/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#!/usr/bin/env python
#
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#

import os
import json
import logging
import requests
from http.server import HTTPServer, BaseHTTPRequestHandler
from socketserver import ThreadingMixIn
import io
from keras.preprocessing.image import (
ImageDataGenerator,
load_img,
array_to_img,
img_to_array,
)

# Constants
FORMAT = os.getenv("FORMAT", "JPEG")
ARG_TYPE = os.getenv("ARG_TYPE", "bytes")

# Environment Variables
host_target = os.environ.get("AIS_TARGET_URL")
TRANSFORM = os.environ.get("TRANSFORM")
if not host_target:
raise EnvironmentError("AIS_TARGET_URL environment variable missing")
if not TRANSFORM:
raise EnvironmentError(
"TRANSFORM environment variable missing. Check documentation for examples (link)"
)
transform_dict = json.loads(TRANSFORM)


class Handler(BaseHTTPRequestHandler):
def log_request(self, code="-", size="-"):
"""Override log_request to not log successful requests."""
pass

def _set_headers(self):
"""Set standard headers for responses."""
self.send_response(200)
self.send_header("Content-Type", "application/octet-stream")
self.end_headers()

def transform(self, data: bytes) -> bytes:
"""Process image data as bytes using the specified transformation."""
try:
img = load_img(io.BytesIO(data))
img = img_to_array(img)
datagen = ImageDataGenerator()
img = datagen.apply_transform(x=img, transform_parameters=transform_dict)
img = array_to_img(img)
buf = io.BytesIO()
img.save(buf, format=FORMAT)
return buf.getvalue()
except Exception as e:
logging.error("Error processing data: %s", str(e))
raise

def do_PUT(self):
"""PUT handler supports `hpush` operation."""
try:
content_length = int(self.headers["Content-Length"])
post_data = self.rfile.read(content_length)
processed_data = self.transform(post_data)
if processed_data is not None:
self._set_headers()
self.wfile.write(processed_data)
else:
self.send_response(500)
self.end_headers()
self.wfile.write(b"Data processing failed")
except Exception as e:
logging.error("Error processing PUT request: %s", str(e))
self.send_response(500)
self.end_headers()
self.wfile.write(b"Data processing failed")

def do_GET(self):
"""GET handler supports `hpull` operation."""
try:
if self.path == "/health":
self._set_headers()
self.wfile.write(b"OK")
return

query_path = host_target + self.path

if ARG_TYPE == "url": # need this for webdataset
result = self.transform(query_path)
else:
input_bytes = requests.get(query_path).content
result = self.transform(input_bytes)

if result is not None:
self._set_headers()
self.wfile.write(result)
else:
self.send_response(500)
self.end_headers()
self.wfile.write(b"Data processing failed")
except Exception as e:
logging.error("Error processing GET request: %s", str(e))
self.send_response(500)
self.end_headers()
self.wfile.write(b"Data processing failed")


class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
"""Handle requests in a separate thread."""


def run(addr="0.0.0.0", port=80):
server = ThreadedHTTPServer((addr, port), Handler)
logging.info(f"Starting HTTP server on {addr}:{port}")
server.serve_forever()


if __name__ == "__main__":
run(addr="0.0.0.0", port=80)

0 comments on commit ff276e7

Please sign in to comment.