Commit ae19d89 (1 parent: 961d1f3): 7 changed files with 1,051 additions and 0 deletions.
@@ -0,0 +1,64 @@
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import PIL
from matplotlib import patches
from transformers import AutoModelForTokenClassification, AutoProcessor

from datasets import load_dataset


def display_image_bounding_boxes_cord(
    image: PIL.Image.Image,
    bboxes: list[tuple[int, int, int, int]],
    ner_tags: Optional[list[int]] = None,
    colors: Optional[list[str]] = None,
    output_path: Optional[Path] = None,
):
    """Display an image with bounding boxes.

    :param image: the image
    :param bboxes: the bounding boxes, normalized to the 0-1000 range
    :param ner_tags: optional NER tag ids, used to index `colors`
    :param colors: optional list of matplotlib colors, one per tag id
    :param output_path: if provided, save the figure there instead of
        displaying it
    """
    fig, ax = plt.subplots()
    ax.imshow(image)

    for i, bbox in enumerate(bboxes):
        color = colors[ner_tags[i]] if ner_tags and colors else "r"
        x_min, y_min, x_max, y_max = bbox
        # Convert the 0-1000 normalized coordinates back to pixels
        x_min *= image.width / 1000
        y_min *= image.height / 1000
        x_max *= image.width / 1000
        y_max *= image.height / 1000
        rect = patches.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            linewidth=1,
            edgecolor=color,
            facecolor="none",
        )
        ax.add_patch(rect)

    if output_path:
        fig.savefig(output_path)
        plt.close()
    else:
        plt.show()


processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
model = AutoModelForTokenClassification.from_pretrained(
    "microsoft/layoutlmv3-base", num_labels=7
)
dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")


example = dataset[0]
image = example["image"]
words = example["tokens"]
boxes = example["bboxes"]
word_labels = example["ner_tags"]

display_image_bounding_boxes_cord(image, boxes)
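The processor and model are loaded above but never exercised in this script; a minimal sketch of how this example could be encoded and run through LayoutLMv3 (standard transformers processor call; the FUNSD boxes are already in the 0-1000 range the model expects):

encoding = processor(
    image,
    words,
    boxes=boxes,
    word_labels=word_labels,
    return_tensors="pt",
)
outputs = model(**encoding)
# With word_labels supplied, the forward pass also returns a token
# classification loss (computed against a randomly initialized head here,
# since the base checkpoint is not fine-tuned).
print(outputs.loss, outputs.logits.shape)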
@@ -0,0 +1,68 @@
from pathlib import Path
from typing import Optional

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import PIL
import tqdm

from datasets import DatasetDict


def display_image_bounding_boxes(
    image: PIL.Image.Image,
    bboxes: list[tuple[int, int, int, int]],
    ner_tags: Optional[list[int]] = None,
    colors: Optional[list[str]] = None,
    output_path: Optional[Path] = None,
):
    """Display an image with bounding boxes.

    :param image: the image
    :param bboxes: the bounding boxes, drawn as-is (no rescaling is applied)
    :param ner_tags: optional NER tag ids, used to index `colors`
    :param colors: optional list of matplotlib colors, one per tag id
    :param output_path: if provided, save the figure there instead of
        displaying it
    """
    fig, ax = plt.subplots()
    ax.imshow(image)

    for i, bbox in enumerate(bboxes):
        color = colors[ner_tags[i]] if ner_tags and colors else None
        x_min, y_min, x_max, y_max = bbox
        rect = patches.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            linewidth=1,
            edgecolor=color,
            facecolor="none",
        )
        ax.add_patch(rect)

    if output_path:
        fig.savefig(output_path)
        plt.close()
    else:
        plt.show()


base_ds = DatasetDict.load_from_disk("datasets/ingredient-detection-layout-dataset-v1")


# Useful for debugging (checking that the images with bounding boxes are correct)
for split_name in ("train", "test"):
    root_path = Path(
        f"datasets/ingredient-detection-layout-dataset-v1/_output_images/{split_name}"
    )
    root_path.mkdir(parents=True, exist_ok=True)

    for item in tqdm.tqdm(base_ds[split_name], desc="dataset items"):
        if item["offsets"]:
            barcode = item["meta"]["barcode"]
            image_id = item["meta"]["image_id"]
            output_path = root_path / f"{barcode}_{image_id}.png"
            display_image_bounding_boxes(
                item["image"],
                item["bboxes"],
                item["ner_tags"],
                output_path=output_path,
                colors=["r", "g", "b"],
            )
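The ner_tags stored in the dataset are ClassLabel ids, so the colors list is indexed positionally by tag id; a quick way to inspect the mapping (a sketch, assuming the features defined by the generation script):

label_names = base_ds["train"].features["ner_tags"].feature.names
# With colors=["r", "g", "b"]: O -> red, B-ING -> green, I-ING -> blue
print(dict(zip(label_names, ["r", "g", "b"])))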
@@ -0,0 +1,236 @@
"""Generate an image-text dataset compatible with LayoutLMv3 inputs from the | ||
NER-like dataset.""" | ||
|
||
import logging | ||
from io import BytesIO | ||
from typing import Optional | ||
from urllib.parse import urlparse | ||
|
||
import PIL | ||
import requests | ||
from openfoodfacts import OCRResult | ||
from openfoodfacts.ocr import Word | ||
from openfoodfacts.utils import get_logger, http_session | ||
from PIL import Image | ||
from requests.exceptions import ConnectionError, SSLError, Timeout | ||
|
||
import datasets | ||
from datasets import DatasetDict, load_dataset | ||
|
||
logger = get_logger() | ||
|
||
|
||
class ImageLoadingException(Exception): | ||
"""Exception raised by `get_image_from_url` when image cannot be fetched | ||
from URL or if loading failed. | ||
""" | ||
|
||
pass | ||
|
||
|
||
def _get_image_from_url( | ||
image_url: str, | ||
error_raise: bool = True, | ||
session: Optional[requests.Session] = None, | ||
) -> Optional[requests.Response]: | ||
auth = ( | ||
("off", "off") | ||
if urlparse(image_url).netloc.endswith("openfoodfacts.net") | ||
else None | ||
) | ||
try: | ||
if session: | ||
r = session.get(image_url, auth=auth) | ||
else: | ||
r = requests.get(image_url, auth=auth) | ||
except (ConnectionError, SSLError, Timeout) as e: | ||
error_message = "Cannot download image %s" | ||
if error_raise: | ||
raise ImageLoadingException(error_message % image_url) from e | ||
logger.info(error_message, image_url, exc_info=e) | ||
return None | ||
|
||
if not r.ok: | ||
error_message = "Cannot download image %s: HTTP %s" | ||
error_args = (image_url, r.status_code) | ||
if error_raise: | ||
raise ImageLoadingException(error_message % error_args) | ||
logger.log( | ||
logging.INFO if r.status_code < 500 else logging.WARNING, | ||
error_message, | ||
*error_args, | ||
) | ||
return None | ||
|
||
return r | ||
|
||
|
||
def get_image_from_url(image_url: str, error_raise: bool = True): | ||
s3_url = image_url.replace( | ||
"https://static.openfoodfacts.org/images/products/", | ||
"https://openfoodfacts-images.s3.eu-west-3.amazonaws.com/data/", | ||
) | ||
r = _get_image_from_url(s3_url, error_raise=error_raise) | ||
|
||
if r is None: | ||
logger.info( | ||
"Cannot download image from S3 (%s), falling back to Open Food Facts server", | ||
s3_url, | ||
) | ||
r = _get_image_from_url(image_url, error_raise=error_raise) | ||
|
||
if r is None: | ||
return None | ||
|
||
content_bytes = r.content | ||
try: | ||
return Image.open(BytesIO(content_bytes)) | ||
except PIL.UnidentifiedImageError: | ||
error_message = f"Cannot identify image {image_url}" | ||
if error_raise: | ||
raise ImageLoadingException(error_message) | ||
logger.info(error_message) | ||
except PIL.Image.DecompressionBombError: | ||
error_message = f"Decompression bomb error for image {image_url}" | ||
if error_raise: | ||
raise ImageLoadingException(error_message) | ||
logger.info(error_message) | ||
|
||
|
||
def generate_layoutlm_dataset_item(item): | ||
"""Generate a LayoutLMv3 dataset item from a NER-like dataset item. | ||
:param item: the NER-like dataset item | ||
:return: the LayoutLMv3 dataset item | ||
""" | ||
text = item["text"] | ||
offsets = item["offsets"] | ||
ocr_url = item["meta"]["url"] | ||
image_url = ocr_url.replace(".json", ".jpg") | ||
image = get_image_from_url(image_url, error_raise=False) | ||
new_item = { | ||
"text": text, | ||
"offsets": offsets, | ||
"meta": item["meta"], | ||
"words": [], | ||
"bboxes": [], | ||
"ner_tags": [], | ||
} | ||
|
||
if image is None: | ||
logger.info("Cannot load image from %s", image_url) | ||
return None | ||
|
||
if image.mode != "RGB": | ||
image = image.convert("RGB") | ||
|
||
ocr_result = OCRResult.from_url(ocr_url, http_session, error_raise=False) | ||
if ocr_result is None: | ||
logger.info("Cannot load OCR result from %s", ocr_url) | ||
return None | ||
|
||
first_words = set() | ||
selected_words = set() | ||
for i, (start_idx, end_idx) in enumerate(offsets): | ||
words: Optional[list[Word]] = ocr_result.get_words_from_indices( | ||
start_idx, end_idx | ||
) | ||
if words is None: | ||
logger.info( | ||
"Cannot get word indices #{%d} (%s) from OCR result %s", | ||
i, | ||
(start_idx, end_idx), | ||
ocr_url, | ||
) | ||
continue | ||
if len(words) == 0: | ||
raise ValueError("Empty word list") | ||
|
||
first_words.add(words[0]) | ||
selected_words |= set(words) | ||
|
||
width, height = image.size | ||
for page in ocr_result.full_text_annotation.pages: | ||
for block in page.blocks: | ||
for paragraph in block.paragraphs: | ||
for word in paragraph.words: | ||
if word in first_words: | ||
ner_tag = "B-ING" | ||
elif word in selected_words: | ||
ner_tag = "I-ING" | ||
else: | ||
ner_tag = "O" | ||
new_item["words"].append(word.text) | ||
new_item["ner_tags"].append(ner_tag) | ||
y_min = min([vertex[1] for vertex in word.bounding_poly.vertices]) | ||
x_min = min([vertex[0] for vertex in word.bounding_poly.vertices]) | ||
y_max = max([vertex[1] for vertex in word.bounding_poly.vertices]) | ||
x_max = max([vertex[0] for vertex in word.bounding_poly.vertices]) | ||
# Normalize bounding box coordinates: make sure that the | ||
# coordinates don't overflow and are normalized between 0 | ||
# and 999 | ||
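                    # Worked example (illustrative values): in a 500 px wide
                    # image, x_min = 250 px maps to int(250 * 1000 / 500)
                    # = 500, i.e. halfway along the 0-999 range.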
new_item["bboxes"].append( | ||
( | ||
max(0, min(999, int(x_min * 1000 / width))), | ||
max(0, min(999, int(y_min * 1000 / height))), | ||
max(0, min(999, int(x_max * 1000 / width))), | ||
max(0, min(999, int(y_max * 1000 / height))), | ||
) | ||
) | ||
new_item["image"] = image | ||
|
||
return new_item | ||
|
||
|
||
DATASET_VERSION = "v6" | ||
|
||
DATASET_URLS = { | ||
"train": f"https://static.openfoodfacts.org/data/datasets/ingredient_detection_dataset-alpha-{DATASET_VERSION}_train.jsonl.gz", | ||
"test": f"https://static.openfoodfacts.org/data/datasets/ingredient_detection_dataset-alpha-{DATASET_VERSION}_test.jsonl.gz", | ||
} | ||
|
||
base_ds = load_dataset("json", data_files=DATASET_URLS) | ||
|
||
# Useful for debugging (checking if the image with bounding boxes is correct) | ||
# for split_name in ("train", "test"): | ||
# for item in tqdm.tqdm(base_ds[split_name], desc="dataset items"): | ||
# new_item = generate_layoutlm_dataset_item(item) | ||
|
||
# if new_item["offsets"]: | ||
# display_image_bounding_boxes( | ||
# new_item["image"], new_item["bboxes"], new_item["ner_tags"] | ||
# ) | ||
|
||
features = datasets.Features( | ||
{ | ||
"ner_tags": datasets.Sequence( | ||
datasets.features.ClassLabel(names=["O", "B-ING", "I-ING"]) | ||
), | ||
"words": datasets.Sequence(datasets.Value("string")), | ||
"bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))), | ||
"image": datasets.features.Image(), | ||
"text": datasets.features.Value("string"), | ||
"offsets": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))), | ||
"meta": { | ||
"barcode": datasets.Value("string"), | ||
"image_id": datasets.Value("string"), | ||
"url": datasets.Value("string"), | ||
"id": datasets.Value("string"), | ||
"in_test_split": datasets.Value("bool"), | ||
}, | ||
} | ||
) | ||
new_ds_train = base_ds["train"].map( | ||
generate_layoutlm_dataset_item, | ||
features=features, | ||
remove_columns=["marked_text", "tokens"], | ||
) | ||
new_ds_test = base_ds["test"].map( | ||
generate_layoutlm_dataset_item, | ||
features=features, | ||
remove_columns=["marked_text", "tokens"], | ||
) | ||
|
||
new_ds = DatasetDict({"train": new_ds_train, "test": new_ds_test}) | ||
new_ds.save_to_disk("datasets/ingredient-detection-layout-dataset-v1") | ||
new_ds.push_to_hub("raphael0202/ingredient-detection-layout-dataset") |
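Once pushed, the dataset can be reloaded from the Hub for training; a minimal sanity check (repo id taken from the push_to_hub call above):

from datasets import load_dataset

ds = load_dataset("raphael0202/ingredient-detection-layout-dataset")
# The ClassLabel names should round-trip: ['O', 'B-ING', 'I-ING']
print(ds["train"].features["ner_tags"].feature.names)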
@@ -0,0 +1,5 @@
accelerate
seqeval
transformers[torch]
datasets
openfoodfacts==0.1.11