From 2eb8715c273f8e7cd38ee8e5a1ad3bbce611475c Mon Sep 17 00:00:00 2001 From: Panagiotis Karagiannis Date: Thu, 7 Dec 2023 17:39:11 +0100 Subject: [PATCH] added canvas option to upload image --- drawing_to_fsd_layout/canvas_image.py | 31 +++++++++++++++++++++++++++ requirements.txt | 1 + streamlit_app.py | 21 +++++++++++++----- 3 files changed, 48 insertions(+), 5 deletions(-) create mode 100644 drawing_to_fsd_layout/canvas_image.py diff --git a/drawing_to_fsd_layout/canvas_image.py b/drawing_to_fsd_layout/canvas_image.py new file mode 100644 index 0000000..ed8fb9a --- /dev/null +++ b/drawing_to_fsd_layout/canvas_image.py @@ -0,0 +1,31 @@ +from drawing_to_fsd_layout.image_processing import Image +from streamlit_drawable_canvas import st_canvas +import streamlit as st +import numpy as np + +def show_canvas_warning(): + st.warning("You need to draw a track in order to continue.") + st.stop() + +def get_canvas_image() -> Image: + + stroke_width = st.slider("Stroke width", 1, 25, 3) + + canvas_result = st_canvas( + stroke_width=stroke_width, + drawing_mode='freedraw', + key="canvas", + ) + + if canvas_result.image_data is None: + show_canvas_warning() + + # by default canvas changes the alpha channel, not the rgb channels + image_data = canvas_result.image_data[:, :, [-1]] + image_data = np.broadcast_to(image_data, image_data.shape[:-1] + (3,)) + image_data = 255 - image_data + + if np.all(image_data == 255): + show_canvas_warning() + + return image_data diff --git a/requirements.txt b/requirements.txt index 3992745..85478bf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ scikit-image==0.19.3 scipy==1.10.0 streamlit==1.15.2 altair<5 +streamlit-drawable-canvas==0.9.3 \ No newline at end of file diff --git a/streamlit_app.py b/streamlit_app.py index 2c08b1f..bde506c 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -25,6 +25,7 @@ rotate, ) from drawing_to_fsd_layout.spline_fit import SplineFitterFactory +from drawing_to_fsd_layout.canvas_image import get_canvas_image class UploadType(str, Enum): @@ -63,8 +64,10 @@ def load_example_image() -> np.ndarray: return io.imread("media/before.png") -def image_upload_widget() -> np.ndarray: - mode = st.radio("Image upload", ["Upload", "Example Image"], horizontal=True) +def image_upload_widget() -> tuple[np.ndarray, bool]: + mode = st.radio("Image upload", ["Upload", "Canvas", "Example Image"], horizontal=True) + + should_show_image = True if mode == "Upload": upload_type = UploadType[st.radio("Upload type", [x.name for x in UploadType])] @@ -91,8 +94,14 @@ def image_upload_widget() -> np.ndarray: elif mode == "Example Image": image = load_example_image() + + elif mode == "Canvas": + # the canvas already shows the image, so we don't need to show it again + should_show_image = False + image = get_canvas_image() - return image + assert image is not None + return image, should_show_image def plot_contours( @@ -117,8 +126,10 @@ def main() -> None: ) st.markdown("## Upload image") - image = image_upload_widget() - st.image(image, caption="Uploaded image") + image, should_show_image = image_upload_widget() + if should_show_image: + + st.image(image, caption="Uploaded image") with st.spinner("Preprocessing image"): preprocessed_image = load_image_and_preprocess(image)