-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
86 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,51 +1,87 @@ | ||
# This is my first effort at making an interpolated dictionary | ||
# This will be an image generator based on parquet data https://huggingface.co/datasets/yuvalkirstain/pickapic_v2/tree/main/data | ||
|
||
import pandas as pd | ||
from sentence_transformers import SentenceTransformer | ||
from PIL import Image | ||
import numpy as np | ||
import io | ||
|
||
|
||
model = SentenceTransformer("all-MiniLM-L6-v2") | ||
|
||
# Our sentences we like to encode | ||
sentences = [ | ||
"This framework generates embeddings for each input sentence", | ||
"Sentences are passed as a list of string.", | ||
"The quick brown fox jumps over the lazy dog.", | ||
] | ||
|
||
# Sentences are encoded by calling model.encode() | ||
sentence_embeddings = model.encode(sentences) | ||
|
||
# Print the embeddings | ||
for sentence, embedding in zip(sentences, sentence_embeddings): | ||
print("Sentence:", sentence) | ||
print("Embedding:", len(embedding)) | ||
print("") | ||
|
||
ans = pd.read_parquet("train-00000-of-00645-b66ac786bf6fb553.parquet") | ||
data = ans.iloc[0] | ||
|
||
print('ans', data) | ||
print(ans['caption'][0]) | ||
jpg_0 = ans['jpg_0'][0] | ||
jpg_1 = ans['jpg_1'][0] | ||
|
||
img = Image.open(io.BytesIO(jpg_0)) | ||
arr = np.asarray(img) | ||
print('arr', arr) | ||
|
||
for index, row in ans.iterrows() : | ||
print('row', row) | ||
|
||
""" | ||
from torchdata.datapipes.iter import FileLister | ||
import torcharrow.dtypes as dt | ||
DTYPE = dt.Struct([dt.Field("Values", dt.int32)]) | ||
ource_dp = FileLister(".", masks="df*.parquet") | ||
parquet_df_dp = source_dp.load_parquet_as_df(dtype=DTYPE) | ||
list(parquet_df_dp)[0] | ||
""" | ||
import os | ||
from omegaconf import DictConfig, OmegaConf | ||
import hydra | ||
from high_order_layers_torch.layers import * | ||
from high_order_layers_torch.networks import * | ||
from pytorch_lightning import Trainer | ||
import matplotlib.pyplot as plt | ||
from high_order_implicit_representation.networks import Net | ||
from pytorch_lightning.callbacks import LearningRateMonitor | ||
from high_order_implicit_representation.rendering import ImageGenerator | ||
from high_order_implicit_representation.single_image_dataset import ( | ||
image_to_dataset, | ||
ImageDataModule, | ||
) | ||
import logging | ||
|
||
logging.basicConfig() | ||
logger = logging.getLogger(__name__) | ||
logging.getLogger().setLevel(logging.DEBUG) | ||
|
||
|
||
@hydra.main(config_path="../config", config_name="images_config") | ||
def run_implicit_images(cfg: DictConfig): | ||
|
||
logger.info(OmegaConf.to_yaml(cfg)) | ||
logger.info(f"Working directory {os.getcwd()}") | ||
logger.info(f"Orig working directory {hydra.utils.get_original_cwd()}") | ||
|
||
root_dir = hydra.utils.get_original_cwd() | ||
|
||
if cfg.train is True: | ||
full_path = [f"{root_dir}/{path}" for path in cfg.images] | ||
data_module = ImageDataModule( | ||
filenames=full_path, batch_size=cfg.batch_size, rotations=cfg.rotations | ||
) | ||
image_generator = ImageGenerator( | ||
filename=full_path[0], rotations=cfg.rotations, batch_size=cfg.batch_size | ||
) | ||
lr_monitor = LearningRateMonitor(logging_interval="epoch") | ||
trainer = Trainer( | ||
max_epochs=cfg.max_epochs, | ||
devices=cfg.gpus, | ||
accelerator=cfg.accelerator, | ||
callbacks=[lr_monitor, image_generator], | ||
) | ||
model = Net(cfg) | ||
trainer.fit(model, datamodule=data_module) | ||
logger.info("testing") | ||
|
||
trainer.test(model, datamodule=data_module) | ||
logger.info("finished testing") | ||
logger.info("best check_point", trainer.checkpoint_callback.best_model_path) | ||
else: | ||
# plot some data | ||
logger.info("evaluating result") | ||
logger.info(f"cfg.checkpoint {cfg.checkpoint}") | ||
checkpoint_path = f"{hydra.utils.get_original_cwd()}/{cfg.checkpoint}" | ||
|
||
logger.info(f"checkpoint_path {checkpoint_path}") | ||
model = Net.load_from_checkpoint(checkpoint_path) | ||
|
||
model.eval() | ||
image_dir = f"{hydra.utils.get_original_cwd()}/{cfg.images[0]}" | ||
output, inputs, image = image_to_dataset(image_dir, rotations=cfg.rotations) | ||
|
||
y_hat_list = [] | ||
for batch in range((len(inputs) + cfg.batch_size) // cfg.batch_size): | ||
print("batch", batch) | ||
res = model(inputs[batch * cfg.batch_size : (batch + 1) * cfg.batch_size]) | ||
y_hat_list.append(res.cpu()) | ||
|
||
y_hat = torch.cat(y_hat_list) | ||
|
||
ans = y_hat.reshape(image.shape[0], image.shape[1], image.shape[2]) | ||
ans = (ans + 1.0) / 2.0 | ||
|
||
f, axarr = plt.subplots(1, 2) | ||
axarr[0].imshow(ans.detach().numpy()) | ||
axarr[0].set_title("fit") | ||
axarr[1].imshow(image) | ||
axarr[1].set_title("original") | ||
plt.show() | ||
|
||
|
||
if __name__ == "__main__": | ||
run_implicit_images() |