Skip to content

Commit c2ca303

Browse files
author
Behrooz
committed
Add main code
1 parent 6c62f4e commit c2ca303

File tree

1 file changed

+140
-0
lines changed

1 file changed

+140
-0
lines changed

projector.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import os
2+
from typing import Iterable
3+
4+
import numpy as np
5+
import tensorflow as tf
6+
from matplotlib import pyplot as plt
7+
from tensorboard.plugins import projector
8+
9+
10+
class Projector:
11+
def __init__(
12+
self,
13+
images: np.array,
14+
labels: Iterable,
15+
log_dir: str,
16+
data_name: str):
17+
"""
18+
Create all necessary artifacts and configs for Tensorboard Projector.
19+
20+
Parameters
21+
----------
22+
images : np.array
23+
An n-d array of images (NumberOFImages X Width X Height).
24+
labels : Iterable
25+
A one-to-one asociated labes to images. It can be a list, numpy array, or any iterable.
26+
log_dir : str
27+
The location all the artifacts are being saved. The directory to which Tensorboard is directd.
28+
`Tensorboard --logdir "log_dir"`
29+
data_name : str
30+
The name of the dataset, which is appended to the name of all artifacts.
31+
"""
32+
self.log_dir = log_dir
33+
self.images = images
34+
self.labels = labels
35+
self.data_name = data_name
36+
37+
self.n_images = None
38+
self.image_width = None
39+
self.image_height = None
40+
self.points = None
41+
42+
if self.images:
43+
self.convert_images_to_points()
44+
45+
def convert_images_to_points(self):
46+
"""
47+
Convert images array to high-dimentional data points (NumberOFImages X NumberOfDimensions).
48+
"""
49+
self.n_images, self.image_width, self.image_height = self.images.shape
50+
self.points = np.reshape(
51+
self.images, (-1, self.image_width * self.image_height))
52+
53+
def save_points(self):
54+
"""
55+
Save high-dimensional data points into a model checkpoint.
56+
"""
57+
points_filename = os.path.join(self.log_dir, f'images_{self.data_name}.ckpt')
58+
points_tensor = tf.Variable(self.points, name=self.data_name)
59+
ckpt = tf.train.Checkpoint(**{self.data_name: points_tensor})
60+
ckpt.save(points_filename)
61+
print('> Images are saved in {}'.format(points_filename))
62+
63+
def save_labels(self):
64+
"""
65+
Save labels into a metadata tab-separated-value file.
66+
"""
67+
meta_filename = os.path.join(
68+
self.log_dir, f'metadata_{self.data_name}.tsv')
69+
with open(meta_filename, 'w') as metadata_file:
70+
for row in self.labels:
71+
metadata_file.write(f'{row}\n')
72+
print('> Metadata file is saved in {}'.format(meta_filename))
73+
74+
def write_sprite_image(self):
75+
"""
76+
Create and write a sprite image, a single PNG file containing all images (possibly downsampled).
77+
"""
78+
# Calculate number of plot
79+
n_plots = int(np.ceil(np.sqrt(self.n_images)))
80+
81+
# Preallocate the sprite image
82+
sprite_image = np.ones(
83+
(self.image_height * n_plots, self.image_width * n_plots))
84+
85+
for i in range(n_plots):
86+
for j in range(n_plots):
87+
img_idx = i * n_plots + j
88+
if img_idx < self.n_images:
89+
img = self.images[img_idx]
90+
sprite_image[i * self.image_height: (i + 1) * self.image_height,
91+
j * self.image_width: (j + 1) * self.image_width] = img
92+
93+
sprite_filename = os.path.join(
94+
self.log_dir, f'sprite_{self.data_name}.png')
95+
plt.imsave(sprite_filename, sprite_image, cmap='gray')
96+
print('> Sprite image saved in {}'.format(sprite_filename))
97+
98+
def create_config(self, with_sprite=True):
99+
"""
100+
Create a congfig files that defines image tensor name, path to metadata file, path to the sprite image,
101+
and the size of individual image whithin the sprite image.
102+
103+
Parameters
104+
----------
105+
with_sprite : bool, optional
106+
If to save sprite or not, by default True
107+
"""
108+
config = projector.ProjectorConfig()
109+
embedding = config.embeddings.add()
110+
embedding.tensor_name = f'{self.data_name}/.ATTRIBUTES/VARIABLE_VALUE'
111+
embedding.metadata_path = f'metadata_{self.data_name}.tsv'
112+
if with_sprite:
113+
embedding.sprite.image_path = f'sprite_{self.data_name}.png'
114+
embedding.sprite.single_image_dim.extend(
115+
[self.image_width, self.image_height])
116+
projector.visualize_embeddings(self.log_dir, config)
117+
118+
def make(self):
119+
self.save_points()
120+
self.save_labels()
121+
self.write_sprite_image()
122+
self.create_config()
123+
124+
125+
if __name__ == "__main__":
126+
fashion_mnist = tf.keras.datasets.fashion_mnist
127+
(train_images, train_labels), (test_images,
128+
test_labels) = fashion_mnist.load_data()
129+
130+
log_dir = '/Users/behrooz/workspace/unsupervised/logs/projector3'
131+
data_name = 'fmnist_with_image'
132+
labels = train_labels[:1000]
133+
images = train_images[:1000]
134+
135+
proj = Projector(
136+
images=images,
137+
labels=labels,
138+
log_dir=log_dir,
139+
data_name=data_name)
140+
proj.make()

0 commit comments

Comments
 (0)