import base64
import datetime
import glob
import io
import os
import sys
import traceback
from concurrent.futures import ThreadPoolExecutor
from tkinter import filedialog, Tk

import gradio as gr
import numpy as np
from PIL import Image
from openai import OpenAI

MAX_IMAGE_WIDTH = 2048
IMAGE_FORMAT = "JPEG"
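
# Load every image in a folder together with its sidecar .txt caption.
# Images without a matching caption file are skipped.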
def load_images_and_text(folder_path):
    # Accept the same extensions as the tagging tab (.jpg, .jpeg, .png).
    image_files = sorted(
        glob.glob(os.path.join(folder_path, "*.jpg"))
        + glob.glob(os.path.join(folder_path, "*.jpeg"))
        + glob.glob(os.path.join(folder_path, "*.png"))
    )
images = []
texts = []
for img_path in image_files:
txt_path = os.path.splitext(img_path)[0] + ".txt"
if os.path.exists(txt_path):
with open(txt_path, 'r', encoding='utf-8') as f:
text = f.read()
images.append(img_path)
texts.append(text)
return images, texts
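
# Persist an edited caption back to the image's sidecar .txt file.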
def save_edited_text(image_path, new_text):
txt_path = os.path.splitext(image_path)[0] + ".txt"
with open(txt_path, 'w', encoding='utf-8') as f:
f.write(new_text)
return f"Saved changes for {os.path.basename(image_path)}"
def generate_description(api_key, image, prompt, detail, max_tokens):
    try:
        img = Image.fromarray(image) if isinstance(image, np.ndarray) else Image.open(image)
        img = scale_image(img)
        if img.mode != "RGB":
            img = img.convert("RGB")  # JPEG cannot store alpha, so flatten RGBA/paletted PNGs
        buffered = io.BytesIO()
        img.save(buffered, format=IMAGE_FORMAT)
        img_base64 = base64.b64encode(buffered.getvalue()).decode()
client = OpenAI(api_key=api_key)
payload = {
"model": "gpt-4o-mini",
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{img_base64}", "detail": detail}}
]
}],
"max_tokens": max_tokens
}
response = client.chat.completions.create(**payload)
return response.choices[0].message.content
except Exception as e:
with open("error_log.txt", 'a') as log_file:
log_file.write(str(e) + '\n')
log_file.write(traceback.format_exc() + '\n')
return f"Error: {str(e)}"
history = []
columns = ["Time", "Prompt", "GPT4-Vision Caption"]
def clear_fields():
global history
history = []
return "", []
def update_history(prompt, response):
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
history.append({"Time": timestamp, "Prompt": prompt, "GPT4-Vision Caption": response})
return [[entry[column] for column in columns] for entry in history]
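
# Cap image width so the base64 payload stays small; aspect ratio is preserved.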
def scale_image(img):
if img.width > MAX_IMAGE_WIDTH:
ratio = MAX_IMAGE_WIDTH / img.width
new_height = int(img.height * ratio)
return img.resize((MAX_IMAGE_WIDTH, new_height), Image.Resampling.LANCZOS)
return img
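
# Native folder picker: a hidden Tk root window hosts the directory dialog.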
def get_dir(file_path):
dir_path, file_name = os.path.split(file_path)
return dir_path, file_name
def get_folder_path(folder_path=''):
current_folder_path = folder_path
initial_dir, initial_file = get_dir(folder_path)
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
if sys.platform == 'darwin':
root.call('wm', 'attributes', '.', '-topmost', True)
folder_path = filedialog.askdirectory(initialdir=initial_dir)
root.destroy()
if folder_path == '':
folder_path = current_folder_path
return folder_path
# Cooperative cancellation flag shared between the batch worker and the Cancel button.
is_processing = False
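
# Caption every image in a folder concurrently, writing one sidecar .txt per image.
# Images that already have a caption file are skipped, so runs can be resumed.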
def process_folder(api_key, folder_path, prompt, detail, max_tokens, pre_prompt="", post_prompt="",
progress=gr.Progress(), num_workers=4):
global is_processing
is_processing = True
if not os.path.isdir(folder_path):
return f"No such directory: {folder_path}"
file_list = [f for f in os.listdir(folder_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
progress(0)
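
    # Worker run for one file in the thread pool.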
def process_file(file):
global is_processing
if not is_processing:
return "Processing canceled by user"
image_path = os.path.join(folder_path, file)
txt_path = os.path.join(folder_path, os.path.splitext(file)[0] + ".txt")
# Check if the *.txt file already exists
if os.path.exists(txt_path):
print(f'File {txt_path} already exists. Skipping.')
return # Exit the function
description = generate_description(api_key, image_path, prompt, detail, max_tokens)
        # Join the optional prefix, the caption, and the optional postfix,
        # skipping empty parts so no stray separators are written.
        parts = [p for p in (pre_prompt.strip(), description, post_prompt.strip()) if p]
        with open(txt_path, 'w', encoding='utf-8') as f:
            f.write(", ".join(parts))
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for i, _ in enumerate(executor.map(process_file, file_list), 1):
            progress((i, len(file_list)))
            if not is_processing:
                break
    if not is_processing:
        return f"Processing canceled after {i} of {len(file_list)} images"
    is_processing = False
    return f"Processed {len(file_list)} images in folder {folder_path}"
with gr.Blocks() as app:
with gr.Accordion("Configuration", open=False):
        api_key_input = gr.Textbox(label="OpenAI API Key", placeholder="Enter your API key here", type="password",
                                   info="OpenAI API requests are rate limited, so a large dataset can take a long time to tag.")
with gr.Tab("Prompt Engineering"):
image_input = gr.Image(label="Upload Image")
with gr.Row():
prompt_input = gr.Textbox(scale=6, label="Prompt",
value="What’s in this image? Provide a description without filler text like: The image depicts...",
interactive=True)
detail_level = gr.Radio(["high", "low", "auto"], scale=2, label="Detail", value="auto")
max_tokens_input = gr.Number(scale=0, value=300, label="Max Tokens")
submit_button = gr.Button("Generate Caption")
output = gr.Textbox(label="GPT4-Vision Caption")
history_table = gr.Dataframe(headers=columns)
clear_button = gr.Button("Clear")
clear_button.click(clear_fields, inputs=[], outputs=[output, history_table])
with gr.Tab("GPT4-Vision Tagging"):
with gr.Row():
            folder_path_dataset = gr.Textbox(scale=8, label="Dataset Folder Path", placeholder="/home/user/dataset",
                                             interactive=True,
                                             info="The folder select button is a bit of a hack; if it doesn't work, just copy and paste the path to your dataset.")
folder_button = gr.Button(
'📂', elem_id='open_folder_small'
)
folder_button.click(
get_folder_path,
outputs=folder_path_dataset,
show_progress="hidden",
)
with gr.Row():
prompt_input_dataset = gr.Textbox(scale=6, label="Prompt",
value="What’s in this image? Provide a description without filler text like: The image depicts...",
interactive=True)
detail_level_dataset = gr.Radio(["high", "low", "auto"], scale=2, label="Detail", value="auto")
max_tokens_input_dataset = gr.Number(scale=0, value=300, label="Max Tokens")
with gr.Row():
pre_prompt_input = gr.Textbox(scale=6, label="Prefix", placeholder="(Optional)",
info="Will be added at the front of the caption.", interactive=True)
post_prompt_input = gr.Textbox(scale=6, label="Postfix", placeholder="(Optional)",
info="Will be added at the end of the caption.", interactive=True)
with gr.Row():
worker_slider = gr.Slider(minimum=1, maximum=4, value=2, step=1, scale=2, label="Number of Workers")
submit_button_dataset = gr.Button("Generate Captions", scale=3)
cancel_button = gr.Button("Cancel", scale=3)
with gr.Row():
processing_results_output = gr.Textbox(label="Processing Results")
with gr.Tab("View and Edit Captions"):
with gr.Row():
folder_path_view = gr.Textbox(label="Dataset Folder Path", placeholder="/home/user/dataset", scale=8)
folder_button = gr.Button('📂', elem_id='open_folder_small', scale=1)
load_button = gr.Button("Load Images and Captions")
with gr.Row():
image_output = gr.Gallery(label="Image", show_label=False, elem_id="preview_gallery", columns=[1], rows=[1], height="420px", allow_preview=True)
text_output = gr.Textbox(label="Caption", lines=5, interactive=True)
save_button = gr.Button("Save Changes")
with gr.Row():
prev_button = gr.Button("Previous")
next_button = gr.Button("Next")
status_output = gr.Textbox(label="Status")
gallery = gr.Gallery(label="Image Gallery", show_label=False, elem_id="gallery", columns=[5], rows=[2], height="auto", allow_preview=False)
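
        # Per-session navigation state: selected index plus parallel lists of
        # image paths and captions.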
current_index = gr.State(0)
images_list = gr.State([])
texts_list = gr.State([])
folder_button.click(
get_folder_path,
outputs=folder_path_view,
show_progress="hidden",
)
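
    # Write the edited caption to disk and keep the in-memory list in sync.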
def save_caption(index, images, texts, new_text):
if 0 <= index < len(images):
image_path = images[index]
save_edited_text(image_path, new_text)
texts[index] = new_text
return texts, "Changes saved successfully"
return texts, "Error: Invalid image index"
save_button.click(
save_caption,
inputs=[current_index, images_list, texts_list, text_output],
outputs=[texts_list, status_output]
)
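
    # Load a dataset folder and display its first image/caption pair.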
def load_data_and_display_first(folder_path):
images, texts = load_images_and_text(folder_path)
if images and texts:
img_path = images[0]
txt = texts[0]
return (
0, images, texts,
[(img_path, os.path.basename(img_path))],
txt,
f"Loaded {len(images)} images",
[(img, os.path.basename(img)) for img in images]
)
return 0, [], [], [], "", "No images found in the specified folder", []
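
    # Build the preview gallery entry and caption text for a given index.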
def update_display(index, images, texts):
if 0 <= index < len(images):
img_path = images[index]
txt = texts[index]
return [(img_path, os.path.basename(img_path))], txt
return [], ""
def nav_previous(current, images, texts):
new_index = max(0, current - 1)
preview_gallery, txt = update_display(new_index, images, texts)
return new_index, preview_gallery, txt
def nav_next(current, images, texts):
new_index = min(len(images) - 1, current + 1)
preview_gallery, txt = update_display(new_index, images, texts)
return new_index, preview_gallery, txt
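
    # Clicking a thumbnail in the overview gallery jumps straight to that image.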
def gallery_select(evt: gr.SelectData, images, texts):
index = evt.index
preview_gallery, txt = update_display(index, images, texts)
return index, preview_gallery, txt
load_button.click(
load_data_and_display_first,
inputs=[folder_path_view],
outputs=[current_index, images_list, texts_list, image_output, text_output, status_output, gallery]
)
prev_button.click(
nav_previous,
inputs=[current_index, images_list, texts_list],
outputs=[current_index, image_output, text_output]
)
next_button.click(
nav_next,
inputs=[current_index, images_list, texts_list],
outputs=[current_index, image_output, text_output]
)
gallery.select(
gallery_select,
inputs=[images_list, texts_list],
outputs=[current_index, image_output, text_output]
)
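
    # Flip the shared flag so the batch loop stops consuming new results.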
def cancel_processing():
global is_processing
is_processing = False
return "Processing canceled"
cancel_button.click(cancel_processing, inputs=[], outputs=[processing_results_output])
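
    # Single-image captioning for the Prompt Engineering tab.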
    def on_click(api_key, image, prompt, detail, max_tokens):
        if not api_key.strip():
            raise gr.Error("Please enter your OpenAI API key.")
        if image is None:
            raise gr.Error("Please upload an image.")
        description = generate_description(api_key, image, prompt, detail, max_tokens)
        new_history = update_history(prompt, description)
        return description, new_history
submit_button.click(on_click, inputs=[api_key_input, image_input, prompt_input, detail_level, max_tokens_input],
outputs=[output, history_table])
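
    # Batch captioning for the GPT4-Vision Tagging tab.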
    def on_click_folder(api_key, folder_path, prompt, detail, max_tokens, pre_prompt, post_prompt, num_workers):
        if not api_key.strip():
            raise gr.Error("Please enter your OpenAI API key.")
        if not folder_path.strip():
            raise gr.Error("Please enter the folder path.")
        return process_folder(api_key, folder_path, prompt, detail, max_tokens, pre_prompt, post_prompt,
                              num_workers=num_workers)
submit_button_dataset.click(
on_click_folder,
inputs=[
api_key_input,
folder_path_dataset,
prompt_input_dataset,
detail_level_dataset,
max_tokens_input_dataset,
pre_prompt_input,
post_prompt_input,
worker_slider
],
outputs=[processing_results_output]
)
app.launch()