diff --git a/examples/text_to_image.py b/examples/text_to_image.py
index 8914f10..9ecd8fb 100644
--- a/examples/text_to_image.py
+++ b/examples/text_to_image.py
@@ -1,51 +1,94 @@
 # This is my first effort at making an interpolated dictionary
-# This will be an image generator based on parquet data https://huggingface.co/datasets/yuvalkirstain/pickapic_v2/tree/main/data
-
-import pandas as pd
-from sentence_transformers import SentenceTransformer
-from PIL import Image
-import numpy as np
-import io
-
-
-model = SentenceTransformer("all-MiniLM-L6-v2")
-
-# Our sentences we like to encode
-sentences = [
-    "This framework generates embeddings for each input sentence",
-    "Sentences are passed as a list of string.",
-    "The quick brown fox jumps over the lazy dog.",
-]
-
-# Sentences are encoded by calling model.encode()
-sentence_embeddings = model.encode(sentences)
-
-# Print the embeddings
-for sentence, embedding in zip(sentences, sentence_embeddings):
-    print("Sentence:", sentence)
-    print("Embedding:", len(embedding))
-    print("")
-
-ans = pd.read_parquet("train-00000-of-00645-b66ac786bf6fb553.parquet")
-data = ans.iloc[0]
-
-print('ans', data)
-print(ans['caption'][0])
-jpg_0 = ans['jpg_0'][0]
-jpg_1 = ans['jpg_1'][0]
-
-img = Image.open(io.BytesIO(jpg_0))
-arr = np.asarray(img)
-print('arr', arr)
-
-for index, row in ans.iterrows() :
-    print('row', row)
-
-"""
-from torchdata.datapipes.iter import FileLister
-import torcharrow.dtypes as dt
-DTYPE = dt.Struct([dt.Field("Values", dt.int32)])
-ource_dp = FileLister(".", masks="df*.parquet")
-parquet_df_dp = source_dp.load_parquet_as_df(dtype=DTYPE)
-list(parquet_df_dp)[0]
-"""
\ No newline at end of file
+import os
+import logging
+
+import hydra
+import matplotlib.pyplot as plt
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import LearningRateMonitor
+
+from high_order_layers_torch.layers import *
+from high_order_layers_torch.networks import *
+from high_order_implicit_representation.networks import Net
+from high_order_implicit_representation.rendering import ImageGenerator
+from high_order_implicit_representation.single_image_dataset import (
+    ImageDataModule,
+    image_to_dataset,
+)
+
+logging.basicConfig()
+logger = logging.getLogger(__name__)
+logging.getLogger().setLevel(logging.DEBUG)
+
+
+@hydra.main(config_path="../config", config_name="images_config")
+def run_implicit_images(cfg: DictConfig):
+    logger.info(OmegaConf.to_yaml(cfg))
+    logger.info(f"Working directory {os.getcwd()}")
+    logger.info(f"Orig working directory {hydra.utils.get_original_cwd()}")
+
+    root_dir = hydra.utils.get_original_cwd()
+
+    if cfg.train:
+        full_path = [f"{root_dir}/{path}" for path in cfg.images]
+        data_module = ImageDataModule(
+            filenames=full_path, batch_size=cfg.batch_size, rotations=cfg.rotations
+        )
+        # Callback that renders the network's current fit of the image as it trains.
+        image_generator = ImageGenerator(
+            filename=full_path[0], rotations=cfg.rotations, batch_size=cfg.batch_size
+        )
+        lr_monitor = LearningRateMonitor(logging_interval="epoch")
+        trainer = Trainer(
+            max_epochs=cfg.max_epochs,
+            devices=cfg.gpus,
+            accelerator=cfg.accelerator,
+            callbacks=[lr_monitor, image_generator],
+        )
+        model = Net(cfg)
+        trainer.fit(model, datamodule=data_module)
+
+        logger.info("testing")
+        trainer.test(model, datamodule=data_module)
+        logger.info("finished testing")
+        logger.info(f"best checkpoint {trainer.checkpoint_callback.best_model_path}")
+    else:
+        # Load a trained checkpoint and plot its reconstruction next to the original.
+        logger.info("evaluating result")
+        logger.info(f"cfg.checkpoint {cfg.checkpoint}")
+        checkpoint_path = f"{root_dir}/{cfg.checkpoint}"
+        logger.info(f"checkpoint_path {checkpoint_path}")
+
+        model = Net.load_from_checkpoint(checkpoint_path)
+        model.eval()
+
+        image_dir = f"{root_dir}/{cfg.images[0]}"
+        output, inputs, image = image_to_dataset(image_dir, rotations=cfg.rotations)
+
+        # Run inference in batches; ceiling division includes the final partial
+        # batch without ever producing an empty one.
+        y_hat_list = []
+        num_batches = (len(inputs) + cfg.batch_size - 1) // cfg.batch_size
+        for batch in range(num_batches):
+            logger.debug(f"batch {batch}")
+            res = model(inputs[batch * cfg.batch_size : (batch + 1) * cfg.batch_size])
+            y_hat_list.append(res.cpu())
+
+        y_hat = torch.cat(y_hat_list)
+
+        # Outputs are in [-1, 1]; reshape to the image and rescale to [0, 1].
+        ans = y_hat.reshape(image.shape[0], image.shape[1], image.shape[2])
+        ans = (ans + 1.0) / 2.0
+
+        f, axarr = plt.subplots(1, 2)
+        axarr[0].imshow(ans.detach().numpy())
+        axarr[0].set_title("fit")
+        axarr[1].imshow(image)
+        axarr[1].set_title("original")
+        plt.show()
+
+
+if __name__ == "__main__":
+    run_implicit_images()