Skip to content

Commit

Permalink
Adding stub for text_to_image
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed May 10, 2024
1 parent c1f3433 commit a3e083b
Showing 1 changed file with 86 additions and 50 deletions.
136 changes: 86 additions & 50 deletions examples/text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,87 @@
# This is my first effort at making an interpolated dictionary
# This will be an image generator based on parquet data https://huggingface.co/datasets/yuvalkirstain/pickapic_v2/tree/main/data

import pandas as pd
from sentence_transformers import SentenceTransformer
from PIL import Image
import numpy as np
import io


# Sentence-embedding model, looked up by name.
model = SentenceTransformer("all-MiniLM-L6-v2")

# Sample sentences used to exercise the encoder.
sentences = [
    "This framework generates embeddings for each input sentence",
    "Sentences are passed as a list of string.",
    "The quick brown fox jumps over the lazy dog.",
]

# The whole list is handed to model.encode() in a single call; the result is
# iterated in lockstep with the sentences below.
sentence_embeddings = model.encode(sentences)

# Show each sentence next to the length of its embedding.
for sentence, embedding in zip(sentences, sentence_embeddings):
    print("Sentence:", sentence)
    print("Embedding:", len(embedding))
    print("")

# Read a single parquet shard into a DataFrame and grab its first row
# (by position) for inspection.
ans = pd.read_parquet("train-00000-of-00645-b66ac786bf6fb553.parquet")
data = ans.iloc[0]

print("ans", data)
print(ans["caption"][0])

# jpg_0 is wrapped in io.BytesIO below, i.e. it is treated as encoded image
# bytes; jpg_1 is fetched but not used further in this script.
jpg_0 = ans["jpg_0"][0]
jpg_1 = ans["jpg_1"][0]

# Decode the first image with PIL and view it as a NumPy array.
img = Image.open(io.BytesIO(jpg_0))
arr = np.asarray(img)
print("arr", arr)

# Dump every row of the shard to stdout.
for index, row in ans.iterrows():
    print("row", row)

# NOTE(review): the bare string below is disabled scratch code (a
# torchdata/torcharrow parquet-loading experiment); it is evaluated and
# discarded, never executed. `ource_dp` on its fourth line looks like a typo
# for `source_dp`, which the following line uses -- fix before re-enabling.
"""
from torchdata.datapipes.iter import FileLister
import torcharrow.dtypes as dt
DTYPE = dt.Struct([dt.Field("Values", dt.int32)])
ource_dp = FileLister(".", masks="df*.parquet")
parquet_df_dp = source_dp.load_parquet_as_df(dtype=DTYPE)
list(parquet_df_dp)[0]
"""
import os
from omegaconf import DictConfig, OmegaConf
import hydra
from high_order_layers_torch.layers import *
from high_order_layers_torch.networks import *
from pytorch_lightning import Trainer
import matplotlib.pyplot as plt
from high_order_implicit_representation.networks import Net
from pytorch_lightning.callbacks import LearningRateMonitor
from high_order_implicit_representation.rendering import ImageGenerator
from high_order_implicit_representation.single_image_dataset import (
image_to_dataset,
ImageDataModule,
)
import logging

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Configure the root logger with the default settings, then lower its
# threshold so DEBUG-level records are emitted as well.
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)


@hydra.main(config_path="../config", config_name="images_config")
def run_implicit_images(cfg: DictConfig):
    """Train an implicit image network, or evaluate a saved checkpoint.

    When ``cfg.train`` is ``True`` a ``Net`` is built from ``cfg``, fitted and
    tested on the images listed in ``cfg.images``, and the best checkpoint
    path is logged. Otherwise the checkpoint at ``cfg.checkpoint`` is loaded,
    the first image in ``cfg.images`` is pushed through the model in batches,
    and the prediction is plotted next to the original image.

    All paths taken from the config are resolved against Hydra's original
    working directory.

    Args:
        cfg: Hydra configuration. Fields read here: ``train``, ``images``,
            ``batch_size``, ``rotations``, ``max_epochs``, ``gpus``,
            ``accelerator`` and ``checkpoint``.
    """
    # Imported locally so ``torch.cat`` below does not depend on the wildcard
    # imports at the top of the file happening to re-export ``torch``.
    import torch

    logger.info(OmegaConf.to_yaml(cfg))
    logger.info(f"Working directory {os.getcwd()}")
    logger.info(f"Orig working directory {hydra.utils.get_original_cwd()}")

    root_dir = hydra.utils.get_original_cwd()

    if cfg.train is True:
        full_path = [f"{root_dir}/{path}" for path in cfg.images]
        data_module = ImageDataModule(
            filenames=full_path, batch_size=cfg.batch_size, rotations=cfg.rotations
        )
        # Passed to the trainer as a callback; built from the first image only.
        image_generator = ImageGenerator(
            filename=full_path[0], rotations=cfg.rotations, batch_size=cfg.batch_size
        )
        lr_monitor = LearningRateMonitor(logging_interval="epoch")
        trainer = Trainer(
            max_epochs=cfg.max_epochs,
            devices=cfg.gpus,
            accelerator=cfg.accelerator,
            callbacks=[lr_monitor, image_generator],
        )
        model = Net(cfg)
        trainer.fit(model, datamodule=data_module)
        logger.info("testing")

        trainer.test(model, datamodule=data_module)
        logger.info("finished testing")
        # A %s placeholder is required: extra positional arguments are merged
        # into the message with %-formatting, and without a placeholder the
        # logging module reports a formatting error instead of the path.
        logger.info(
            "best check_point %s", trainer.checkpoint_callback.best_model_path
        )
    else:
        # plot some data
        logger.info("evaluating result")
        logger.info(f"cfg.checkpoint {cfg.checkpoint}")
        checkpoint_path = f"{root_dir}/{cfg.checkpoint}"

        logger.info(f"checkpoint_path {checkpoint_path}")
        model = Net.load_from_checkpoint(checkpoint_path)

        model.eval()
        image_dir = f"{root_dir}/{cfg.images[0]}"
        output, inputs, image = image_to_dataset(image_dir, rotations=cfg.rotations)

        batch_size = cfg.batch_size
        y_hat_list = []
        # Inference only, so skip autograd bookkeeping.
        with torch.no_grad():
            # Step through the inputs in batch_size strides. This visits every
            # sample exactly once and never produces a trailing empty batch
            # when len(inputs) is an exact multiple of batch_size.
            for batch, start in enumerate(range(0, len(inputs), batch_size)):
                print("batch", batch)
                res = model(inputs[start : start + batch_size])
                y_hat_list.append(res.cpu())

        y_hat = torch.cat(y_hat_list)

        # Reshape the flat predictions back to the original image's shape.
        ans = y_hat.reshape(image.shape[0], image.shape[1], image.shape[2])
        # NOTE(review): presumably maps network output from [-1, 1] to [0, 1]
        # for display -- confirm against image_to_dataset's normalization.
        ans = (ans + 1.0) / 2.0

        f, axarr = plt.subplots(1, 2)
        axarr[0].imshow(ans.detach().numpy())
        axarr[0].set_title("fit")
        axarr[1].imshow(image)
        axarr[1].set_title("original")
        plt.show()


# Script entry point: Hydra builds the config and passes it to the function.
if __name__ == "__main__":
    run_implicit_images()

0 comments on commit a3e083b

Please sign in to comment.