Skip to content

Commit

Permalink
Add intitial prediction.py script.
Browse files Browse the repository at this point in the history
  • Loading branch information
KelSolaar committed Jan 3, 2024
1 parent 81a5635 commit cc7b857
Show file tree
Hide file tree
Showing 6 changed files with 930 additions and 29 deletions.
2 changes: 1 addition & 1 deletion colour_checker_detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
colour.utilities.ANCILLARY_COLOUR_SCIENCE_PACKAGES[ # pyright: ignore
"colour-checker-detection"
] = _version
colour.utilities.ANCILLARY_RUNTIME_PACKAGES[ # pyright: ignore
colour.utilities.ANCILLARY_RUNTIME_PACKAGES[
"opencv"
] = cv2.__version__ # pyright: ignore

Expand Down
661 changes: 661 additions & 0 deletions colour_checker_detection/scripts/LICENSE

Large diffs are not rendered by default.

233 changes: 233 additions & 0 deletions colour_checker_detection/scripts/prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
#!/usr/bin/env python
"""
Colour Checker Detection - Prediction
=====================================
Defines the scripts for colour checker detection using segmentation based on
*Ultralytics YOLOv8* machine learning model:
- :attr:`colour_checker_detection.SETTINGS_SEGMENTATION_COLORCHECKER_CLASSIC`
- :attr:`colour_checker_detection.SETTINGS_SEGMENTATION_COLORCHECKER_SG`
- :func:`colour_checker_detection.colour_checkers_coordinates_segmentation`
- :func:`colour_checker_detection.extract_colour_checkers_segmentation`
- :func:`colour_checker_detection.detect_colour_checkers_segmentation`
Warnings
--------
This script is provided under the terms of the
*GNU Affero General Public License v3.0* as it uses the *Ultralytics YOLOv8*
API which is incompatible with the *BSD-3-Clause*.
"""

from __future__ import annotations

import logging
import os
from pathlib import Path

import click
import cv2
import numpy as np
from colour import read_image
from colour.hints import List, Literal, NDArray, Tuple
from ultralytics import YOLO
from ultralytics.utils.downloads import download

from colour_checker_detection.detection.segmentation import as_8_bit_BGR_image

__author__ = "Colour Developers"
__copyright__ = "Copyright 2024 Colour Developers"
__license__ = (
"GNU Affero General Public License v3.0 - "
"https://www.gnu.org/licenses/agpl-3.0.en.html"
)
__maintainer__ = "Colour Developers"
__email__ = "[email protected]"
__status__ = "Production"


__all__ = [
"ROOT_REPOSITORY",
"URL_BASE",
"URL_MODEL_FILE_DEFAULT",
"inference",
"segmentation",
]


logger = logging.getLogger(__name__)

ROOT_REPOSITORY: str = os.environ.get(
"COLOUR_SCIENCE__COLOUR_CHECKER_DETECTION__REPOSITORY",
os.path.join(
os.path.expanduser("~"),
".colour-science",
"colour-checker-detection",
),
)
"""Root of the local repository to download the hosted models to."""

URL_BASE: str = (
"https://huggingface.co/colour-science/colour-checker-detection-models"
)
"""URL of the remote repository to download the models from."""

URL_MODEL_FILE_DEFAULT: str = (
f"{URL_BASE}/resolve/main/models/colour-checker-detection-l-seg.pt"
)
"""URL for the default segmentation model."""


def inference(
source: str | Path | NDArray, model: YOLO, show: bool = False, **kwargs
) -> List[Tuple[NDArray, NDArray, NDArray]]:
"""
Run the inference on the provided source.
Parameters
----------
source
Source of the image to make predictions on. Accepts all source types
accepted by the *YOLOv8* model.
model
The model to use for the inference.
show
Whether to show the inference results on the image.
Other Parameters
----------------
\\**kwargs : dict, optional
Keywords arguments for the *YOLOv8* segmentation method.
Returns
-------
:class:`list`
Inference results.
"""

data = []

for result in model(source, show=show, **kwargs):
show and cv2.waitKey(0) == ord("n")

data_boxes = result.boxes.data
data_masks = result.masks.data

for i in range(data_boxes.shape[0]):
data.append(
(
data_boxes[i, 4].cpu().numpy(),
data_boxes[i, 5].cpu().numpy(),
data_masks[i].data.cpu().numpy(),
)
)

if np.any(data[-1][-1]):
logging.debug(
'Found a "%s" class object with "%s" confidence.',
data[-1][1],
data[-1][0],
)
else:
logging.warning("No objects were detected!")

return data


@click.command()
@click.option(
"--input",
required=True,
type=click.Path(exists=True),
help="Input file to run the segmentation model on.",
)
@click.option(
"--output",
help="Output file to write the segmentation results to.",
)
@click.option(
"--model",
"model",
type=click.Path(exists=True),
help='Segmentation model file to load. Default to the "colour-science" model '
'hosted on "HuggingFace". It will be downloaded if not cached already.',
)
@click.option(
"--show/--no-show",
default=False,
help="Whether to show the segmentation results.",
)
@click.option(
"--logging-level",
"logging_level",
default="INFO",
type=click.Choice(
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
),
help="Set the logging level.",
)
def segmentation(
input: str, # noqa: A002
output: str | None = None,
model: str | None = None,
show: bool = False,
logging_level: Literal[
"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
] = "INFO",
) -> NDArray:
"""
Run the segmentation model on the given input file and save the results to
given output file.
Parameters
----------
input
Input file to run the segmentation model on.
output
Output file to write the segmentation results to.
model
Segmentation model file to load. Default to the *colour-science8 model
hosted on *HuggingFace*. It will be downloaded if not cached already.
show
Whether to show the segmentation results.
logging_level
Set the logging level.
Returns
-------
:class:`numpy.ndarray`
Inference results.
"""

logging.getLogger().setLevel(getattr(logging, logging_level.upper()))

if model is None:
model = os.path.join(
ROOT_REPOSITORY, os.path.basename(URL_MODEL_FILE_DEFAULT)
)
logging.debug('Using "%s" default model.', model)
if not os.path.exists(model):
logging.info('Downloading "%s" model...', URL_MODEL_FILE_DEFAULT)
download(URL_MODEL_FILE_DEFAULT, ROOT_REPOSITORY)

if input.endswith((".npy", ".npz")):
logging.debug('Reading "%s" serialised array...', input)
source = np.load(input)
else:
logging.debug('Reading "%s" image...', input)
source = as_8_bit_BGR_image(read_image(input))

results = np.array(inference(source, YOLO(model), show), dtype=object)

if output is None:
output = f"{input}.npz"

np.savez(output, results)

return results


if __name__ == "__main__":
logging.basicConfig()

segmentation()
8 changes: 4 additions & 4 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ fonttools==4.47.0 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
imageio==2.33.1 ; python_version >= "3.9" and python_version < "3.13"
imagesize==1.4.1 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.0.0 ; python_version >= "3.9" and python_version < "3.10"
importlib-metadata==7.0.1 ; python_version >= "3.9" and python_version < "3.10"
importlib-resources==6.1.1 ; python_version >= "3.9" and python_version < "3.10"
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
kiwisolver==1.4.5 ; python_version >= "3.9" and python_version < "3.13"
latexcodec==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
matplotlib==3.8.2 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
opencv-python==4.8.1.78 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
opencv-python==4.9.0.80 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
pybtex==0.24.0 ; python_version >= "3.9" and python_version < "3.13"
pybtex-docutils==1.0.3 ; python_version >= "3.9" and python_version < "3.13"
pydata-sphinx-theme==0.14.4 ; python_version >= "3.9" and python_version < "3.13"
Expand Down
15 changes: 11 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ name = "colour-checker-detection"
version = "0.1.6"
description = "Colour checker detection with Python"
license = "BSD-3-Clause"
authors = [ "Colour Developers <[email protected]>" ]
maintainers = [ "Colour Developers <[email protected]>" ]
authors = ["Colour Developers <[email protected]>"]
maintainers = ["Colour Developers <[email protected]>"]
readme = 'README.rst'
repository = "https://github.com/colour-science/colour-checker-detection"
homepage = "https://www.colour-science.org/"
Expand Down Expand Up @@ -43,6 +43,9 @@ exclude = [
"colour_checker_detection/resources/colour-checker-detection-tests-datasets/colour_checker_detection/detection/*",
]

[tool.poetry.scripts]
colour-checker-detection-segmentation = 'colour_checker_detection.scripts.prediction:segmentation'

[tool.poetry.dependencies]
python = ">= 3.9, < 3.13"
colour-science = ">= 0.4.4"
Expand All @@ -55,6 +58,10 @@ typing-extensions = ">= 4, < 5"
[tool.poetry.group.optional.dependencies]
matplotlib = ">= 3.5, != 3.5.0, != 3.5.1"

[tool.poetry.group.ultralytics.dependencies]
click = { version = ">= 8, < 9", python = "< 3.12" }
ultralytics = { version = ">= 8, < 9", python = "< 3.12" }

[tool.poetry.group.dev.dependencies]
coverage = "!= 6.3"
coveralls = "*"
Expand Down Expand Up @@ -86,7 +93,7 @@ exclude = '''
'''

[tool.flynt]
line_length=999
line_length = 999

[tool.isort]
ensure_newline_before_comments = true
Expand Down Expand Up @@ -203,5 +210,5 @@ convention = "numpy"
"utilities/unicode_to_ascii.py" = ["RUF001"]

[build-system]
requires = [ "poetry_core>=1.0.0" ]
requires = ["poetry_core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
Loading

0 comments on commit cc7b857

Please sign in to comment.