-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Abhishek Gaikwad <[email protected]>
- Loading branch information
1 parent
72a99c4
commit ff276e7
Showing
6 changed files
with
239 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
FROM python:slim | ||
|
||
COPY requirements.txt requirements.txt | ||
RUN pip3 install --upgrade -r requirements.txt | ||
|
||
RUN mkdir /code | ||
WORKDIR /code | ||
COPY server.py server.py | ||
|
||
ENV PYTHONUNBUFFERED 1 | ||
|
||
EXPOSE 80 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
REGISTRY_URL ?= docker.io/aistorage | ||
|
||
all: build push | ||
|
||
build: | ||
docker build -t $(REGISTRY_URL)/transformer_keras:latest . | ||
|
||
push: | ||
docker push $(REGISTRY_URL)/transformer_keras:latest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Keras Transformer - Image Data Augmentation and Preprocessing | ||
|
||
The Keras Transformer is a powerful tool designed for image data preprocessing and data augmentation. Leveraging the `apply_transform` function from Keras (TensorFlow), this transformer allows users to define transformations by providing a JSON string with parameter-value pairs. Currently, the following parameters are supported: | ||
|
||
| Parameter | Description | | ||
|-------------------------|---------------------------------------------------------| | ||
| 'theta' | Rotation angle in degrees. | | ||
| 'tx' | Shift in the x direction. | | ||
| 'ty' | Shift in the y direction. | | ||
| 'shear' | Shear angle in degrees. | | ||
| 'zx' | Zoom in the x direction. | | ||
| 'zy' | Zoom in the y direction. | | ||
| 'flip_horizontal' | Boolean. Enable horizontal flip. | | ||
| 'flip_vertical' | Boolean. Enable vertical flip. | | ||
| 'channel_shift_intensity' | Float. Channel shift intensity. | | ||
| 'brightness' | Float. Brightness shift intensity. | | ||
|
||
The image format (JPEG, PNG, etc.) of the images to be processed or stored is specified in the `spec.yaml`. | ||
|
||
The transformer supports both `hpull`, `hpush` and `hrev` communication mechanisms for seamless integration. | ||
|
||
> For more information on communication mechanisms, please refer to [this link](https://github.com/NVIDIA/aistore/blob/master/docs/etl.md#communication-mechanisms). | ||
## Parameters | ||
Only two parameters need to be updated in the `pod.yaml` file. | ||
|
||
| Argument | Description | Default Value | | ||
| ----------- | --------------------------------------------------------------------- | ------------- | | ||
| `TRANSFORM` | Specify a JSON string with operations to be performed | `` | | ||
| `FORMAT`| To process/store images in which image format (PNG, JPEG,etc) | `JPEG` | | ||
|
||
Please ensure to adjust these parameters according to your specific requirements. | ||
|
||
### Initializing ETL with AIStore CLI | ||
|
||
The following steps demonstrate how to initialize the `Keras Transformer` with using the [AIStore CLI](https://github.com/NVIDIA/aistore/blob/master/docs/cli.md): | ||
|
||
```!bash | ||
$ cd transformers/keras_transformer | ||
$ # Set values for FORMAT and TRANSFORM | ||
$ export FORMAT="JPEG" | ||
$ export TRANSFORM='{"theta":40, "brightness":0.8, "zx":0.9, "zy":0.9}' | ||
$ # Mention communication type b/w target and container | ||
$ export COMMUNICATION_TYPE = 'hpull://' | ||
# Substitute env variables in spec file | ||
$ envsubst < pod.yaml > init_spec.yaml | ||
$ # Initialize ETL | ||
$ ais etl init spec --from-file init_spec.yaml --name <etl-name> | ||
$ # Transform and retrieve objects from the bucket using this ETL | ||
$ # For inline transformation | ||
$ ais etl object <etl-name> ais://src/<image-name>.JPEG dst.JPEG | ||
$ # Or, for offline (bucket-to-bucket) transformation | ||
$ ais etl bucket <etl-name> ais://src-bck ais://dst-bck --ext="{JPEG:JPEG}" | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# https://github.com/NVIDIA/ais-etl/blob/master/transformers/keras_transformer/README.md | ||
apiVersion: v1 | ||
kind: Pod | ||
metadata: | ||
name: transformer-compress | ||
annotations: | ||
# Values `communication_type` can take are ["hpull://", "hrev://", "hpush://", "io://"]. | ||
# Visit https://github.com/NVIDIA/aistore/blob/master/docs/etl.md#communication-mechanisms | ||
communication_type: ${COMMUNICATION_TYPE:-"\"hpull://\""} | ||
wait_timeout: 5m | ||
spec: | ||
containers: | ||
- name: server | ||
image: gaikwadabhishek/transformer_keras:latest | ||
imagePullPolicy: Always | ||
ports: | ||
- name: default | ||
containerPort: 80 | ||
command: ['/code/server.py', '--listen', '0.0.0.0', '--port', '80'] | ||
env: | ||
- name: FORMAT | ||
# expected values - PNG, JPEG, etc | ||
value: ${FORMAT:-"JPEG"} | ||
- name: TRANSFORM | ||
# MANDATORY: expected json string parameter-value paris. | ||
# https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator#apply_transform | ||
# e.g. '{"theta":40, "brightness":0.8, "zx":0.9, "zy":0.9}' | ||
value: ${TRANSFORM} | ||
readinessProbe: | ||
httpGet: | ||
path: /health | ||
port: default |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
requests | ||
pillow | ||
scipy | ||
keras | ||
tensorflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
#!/usr/bin/env python | ||
# | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
|
||
import os | ||
import json | ||
import logging | ||
import requests | ||
from http.server import HTTPServer, BaseHTTPRequestHandler | ||
from socketserver import ThreadingMixIn | ||
import io | ||
from keras.preprocessing.image import ( | ||
ImageDataGenerator, | ||
load_img, | ||
array_to_img, | ||
img_to_array, | ||
) | ||
|
||
# Constants | ||
FORMAT = os.getenv("FORMAT", "JPEG") | ||
ARG_TYPE = os.getenv("ARG_TYPE", "bytes") | ||
|
||
# Environment Variables | ||
host_target = os.environ.get("AIS_TARGET_URL") | ||
TRANSFORM = os.environ.get("TRANSFORM") | ||
if not host_target: | ||
raise EnvironmentError("AIS_TARGET_URL environment variable missing") | ||
if not TRANSFORM: | ||
raise EnvironmentError( | ||
"TRANSFORM environment variable missing. Check documentation for examples (link)" | ||
) | ||
transform_dict = json.loads(TRANSFORM) | ||
|
||
|
||
class Handler(BaseHTTPRequestHandler): | ||
def log_request(self, code="-", size="-"): | ||
"""Override log_request to not log successful requests.""" | ||
pass | ||
|
||
def _set_headers(self): | ||
"""Set standard headers for responses.""" | ||
self.send_response(200) | ||
self.send_header("Content-Type", "application/octet-stream") | ||
self.end_headers() | ||
|
||
def transform(self, data: bytes) -> bytes: | ||
"""Process image data as bytes using the specified transformation.""" | ||
try: | ||
img = load_img(io.BytesIO(data)) | ||
img = img_to_array(img) | ||
datagen = ImageDataGenerator() | ||
img = datagen.apply_transform(x=img, transform_parameters=transform_dict) | ||
img = array_to_img(img) | ||
buf = io.BytesIO() | ||
img.save(buf, format=FORMAT) | ||
return buf.getvalue() | ||
except Exception as e: | ||
logging.error("Error processing data: %s", str(e)) | ||
raise | ||
|
||
def do_PUT(self): | ||
"""PUT handler supports `hpush` operation.""" | ||
try: | ||
content_length = int(self.headers["Content-Length"]) | ||
post_data = self.rfile.read(content_length) | ||
processed_data = self.transform(post_data) | ||
if processed_data is not None: | ||
self._set_headers() | ||
self.wfile.write(processed_data) | ||
else: | ||
self.send_response(500) | ||
self.end_headers() | ||
self.wfile.write(b"Data processing failed") | ||
except Exception as e: | ||
logging.error("Error processing PUT request: %s", str(e)) | ||
self.send_response(500) | ||
self.end_headers() | ||
self.wfile.write(b"Data processing failed") | ||
|
||
def do_GET(self): | ||
"""GET handler supports `hpull` operation.""" | ||
try: | ||
if self.path == "/health": | ||
self._set_headers() | ||
self.wfile.write(b"OK") | ||
return | ||
|
||
query_path = host_target + self.path | ||
|
||
if ARG_TYPE == "url": # need this for webdataset | ||
result = self.transform(query_path) | ||
else: | ||
input_bytes = requests.get(query_path).content | ||
result = self.transform(input_bytes) | ||
|
||
if result is not None: | ||
self._set_headers() | ||
self.wfile.write(result) | ||
else: | ||
self.send_response(500) | ||
self.end_headers() | ||
self.wfile.write(b"Data processing failed") | ||
except Exception as e: | ||
logging.error("Error processing GET request: %s", str(e)) | ||
self.send_response(500) | ||
self.end_headers() | ||
self.wfile.write(b"Data processing failed") | ||
|
||
|
||
class ThreadedHTTPServer(ThreadingMixIn, HTTPServer): | ||
"""Handle requests in a separate thread.""" | ||
|
||
|
||
def run(addr="0.0.0.0", port=80): | ||
server = ThreadedHTTPServer((addr, port), Handler) | ||
logging.info(f"Starting HTTP server on {addr}:{port}") | ||
server.serve_forever() | ||
|
||
|
||
if __name__ == "__main__": | ||
run(addr="0.0.0.0", port=80) |