Skip to content

Commit

Permalink
Black'en code
Browse files Browse the repository at this point in the history
  • Loading branch information
regiellis committed Oct 17, 2024
1 parent a2b5ee9 commit db04701
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 52 deletions.
1 change: 0 additions & 1 deletion ecko_cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,3 @@ def load_environment_variables() -> None:

load_environment_variables()
os.environ[CAPTION_MODEL] = os.getenv(CAPTION_MODEL, "microsoft/Florence-2-large")

6 changes: 2 additions & 4 deletions ecko_cli/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def analyze_image(
progress.update(task, advance=1)
return None


class CaptionType(Enum):
DETAILED = "<MORE_DETAILED_CAPTION>"
OBJECT_DETECTION = "<OD>"
Expand Down Expand Up @@ -287,15 +288,12 @@ def generate_florence_description(
# Get task prompt as per input flags
task_prompt = get_task_prompt(is_object, is_anime, is_style)


inputs = processor(text=task_prompt, images=image, return_tensors="pt")
for key in inputs.keys():
inputs[key] = inputs[key].to(device)

if key == "pixel_values" and torch.cuda.is_available():
inputs[key] = inputs[key].to(
torch.float16
)
inputs[key] = inputs[key].to(torch.float16)

# Ensure that input_ids are Long
if "input_ids" in inputs:
Expand Down
99 changes: 56 additions & 43 deletions ecko_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
create_output_directory,
generate_caption_file,
create_table,
feedback_message
feedback_message,
)

from .images import ImageProcessor
Expand Down Expand Up @@ -138,7 +138,7 @@ def __init__(
button_secondary_background_fill_hover="*neutral_100",
button_secondary_text_color="*neutral_800",
button_secondary_border_color="*neutral_200",
#button_shadow="0 1px 2px 0 rgba(0, 0, 0, 0.05)",
# button_shadow="0 1px 2px 0 rgba(0, 0, 0, 0.05)",
block_label_background_fill="*neutral_50",
block_label_background_fill_dark="*neutral_800",
input_background_fill="white",
Expand All @@ -150,6 +150,7 @@ def __init__(
slider_color_dark="*primary_400",
)


playlogic = PlaylogicTheme()


Expand Down Expand Up @@ -259,61 +260,59 @@ def display_results_table(results: List[Dict[str, str]]) -> None:


def serve(dataset_path):
training_data_path = os.path.join(dataset_path, 'training_data')
dataset_file = os.path.join(training_data_path, 'dataset.json')
training_data_path = os.path.join(dataset_path, "training_data")
dataset_file = os.path.join(training_data_path, "dataset.json")

if not os.path.exists(dataset_file):
feedback_message("Dataset not found", "error")
raise typer.Exit()
with open(dataset_file, 'r') as f:

with open(dataset_file, "r") as f:
dataset = json.load(f)

image_count = len(dataset)

image_count = len(dataset)

def load_image_and_caption(index):
if index is None or not (0 <= index < image_count):
return None, "No image selected", None
item = dataset[index]
image_path = os.path.join(training_data_path, item['image'])
return Image.open(image_path), item['text'], index

image_path = os.path.join(training_data_path, item["image"])
return Image.open(image_path), item["text"], index

def update_caption(index, new_caption):
if index is None or not (0 <= index < image_count):
return "No image selected for update"
dataset[index]['text'] = new_caption

dataset[index]["text"] = new_caption

# Update all dataset files
for ext in ['json', 'jsonl', 'csv']:
file_path = os.path.join(training_data_path, f'dataset.{ext}')
for ext in ["json", "jsonl", "csv"]:
file_path = os.path.join(training_data_path, f"dataset.{ext}")
if os.path.exists(file_path):
if ext == 'json':
with open(file_path, 'w') as f:
if ext == "json":
with open(file_path, "w") as f:
json.dump(dataset, f, indent=2)
elif ext == 'jsonl':
with open(file_path, 'w') as f:
elif ext == "jsonl":
with open(file_path, "w") as f:
for item in dataset:
f.write(json.dumps(item) + '\n')
elif ext == 'csv':
with open(file_path, 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=['image', 'text'])
f.write(json.dumps(item) + "\n")
elif ext == "csv":
with open(file_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=["image", "text"])
writer.writeheader()
writer.writerows(dataset)

# Update the corresponding .txt file
image_filename = dataset[index]['image']
txt_filename = os.path.splitext(image_filename)[0] + '.txt'
image_filename = dataset[index]["image"]
txt_filename = os.path.splitext(image_filename)[0] + ".txt"
txt_path = os.path.join(training_data_path, txt_filename)
with open(txt_path, 'w') as f:
with open(txt_path, "w") as f:
f.write(new_caption)

return f"Caption updated for image {txt_filename[:-4]}"

def get_image_paths():
return [os.path.join(training_data_path, item['image']) for item in dataset]
return [os.path.join(training_data_path, item["image"]) for item in dataset]

with gr.Blocks(theme=playlogic) as demo:
gr.Markdown("## ECKO Editor")
Expand All @@ -324,31 +323,45 @@ def get_image_paths():
caption_input = gr.Textbox(label="Caption", lines=4)
update_button = gr.Button("Update Caption")
update_status = gr.Markdown("### Update status will appear here")
with gr.Row():

with gr.Row():
with gr.Column(scale=1):
gallery = gr.Gallery(value=get_image_paths(), columns=8, rows=4, label="Image Gallery", allow_preview=False)
gallery = gr.Gallery(
value=get_image_paths(),
columns=8,
rows=4,
label="Image Gallery",
allow_preview=False,
)
gr.Markdown(f"### Total Images: {image_count}")

selected_index = gr.State(None)

def select_image(evt: gr.SelectData):
return evt.index

gallery.select(select_image, outputs=[selected_index])
selected_index.change(load_image_and_caption, inputs=[selected_index], outputs=[image_output, caption_input, selected_index])

selected_index.change(
load_image_and_caption,
inputs=[selected_index],
outputs=[image_output, caption_input, selected_index],
)

update_button.click(
update_caption,
inputs=[selected_index, caption_input],
outputs=[update_status]
outputs=[update_status],
)

demo.launch(server_name="0.0.0.0")


@ecko_cli.command()
def ui(dataset_path: str = typer.Argument(..., help="Path to the directory containing dataset.json and images")):
def ui(
dataset_path: str = typer.Argument(
..., help="Path to the directory containing dataset.json and images"
)
):
"""Serve a Gradio interface for viewing image captions"""
serve(dataset_path)

Expand Down Expand Up @@ -488,11 +501,11 @@ def process_directory(
"status": "Failed to generate caption",
}
)

datasets = ["jsonl", "json"]
for dataset in datasets:
create_dataset_from_images(output_dir, "dataset", dataset)

display_results_table(results)


Expand Down
7 changes: 3 additions & 4 deletions ecko_cli/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,14 @@ def smart_resize(img, size):


def load_dataset(json_path):

if not os.path.exists(json_path):
feedback_message(f"Dataset file not found at {json_path}", type="warning")
return None
with open(json_path, 'r') as f:

with open(json_path, "r") as f:
return json.load(f)


def get_image_count(dataset_path):
return len(dataset_path)

0 comments on commit db04701

Please sign in to comment.