Skip to content

Commit

Permalink
Merge pull request #11 from zauberzeug/nodelib-12
Browse files Browse the repository at this point in the history
adapt to nodelib 0.12
  • Loading branch information
denniswittich authored Nov 26, 2024
2 parents 1c2b4f8 + 6c5bcaf commit f6d2226
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 112 deletions.
17 changes: 10 additions & 7 deletions trainer/app_code/model_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,22 @@ def get_all_weightfiles(training_path: Path) -> List[Path]:
return weightfiles


def _epoch_from_weightfile(weightfile: Path) -> int:
number = weightfile.name[5:-3]
if number == '':
def epoch_from_weightfile(weightfile: Path) -> int:
try:
number = weightfile.name[5:-3]
if number == '':
return 0
return int(number)
except ValueError:
return 0
return int(number)


def delete_older_epochs(training_path: Path, weightfile: Path):
all_weightfiles = get_all_weightfiles(training_path)

target_epoch = _epoch_from_weightfile(weightfile)
target_epoch = epoch_from_weightfile(weightfile)
for f in all_weightfiles:
if _epoch_from_weightfile(f) < target_epoch:
if epoch_from_weightfile(f) < target_epoch:
_try_remove(f)
delete_json_for_weightfile(f)

Expand All @@ -53,6 +56,6 @@ def _try_remove(file: Path):
def get_new(training_path: Path) -> Union[Path, None]:
all_weightfiles = get_all_weightfiles(training_path)
if all_weightfiles:
all_weightfiles.sort(key=_epoch_from_weightfile)
all_weightfiles.sort(key=epoch_from_weightfile)
return all_weightfiles[-1]
return None
140 changes: 88 additions & 52 deletions trainer/app_code/tests/test_yolov5.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,21 @@
import os
import shutil
from pathlib import Path
from typing import Dict
from typing import Dict, List, Tuple
from uuid import uuid4

import pytest
from learning_loop_node.data_classes import (Category, Context, Hyperparameter,
Training, TrainingData)
from learning_loop_node.data_classes import (Category, Context, TrainerState,
Training)
from learning_loop_node.data_exchanger import DataExchanger
from learning_loop_node.helpers.misc import create_image_folder
from learning_loop_node.loop_communication import LoopCommunicator
from learning_loop_node.trainer.downloader import TrainingsDownloader
from learning_loop_node.trainer.executor import Executor
from ruamel.yaml import YAML

from .. import model_files, yolov5_format
from ..yolov5_format import update_hyps
from ..yolov5_format import set_hyperparameters_in_file
from ..yolov5_trainer import Yolov5TrainerLogic

# pylint: disable=protected-access,unused-argument
Expand All @@ -37,15 +38,22 @@ class TestWithLoop:
"""This test environment sets up the environment vars and
a test project in the loop which is used for testing."""

async def test_training_creates_model(self, use_training_dir, data_exchanger: DataExchanger, glc: LoopCommunicator):
@pytest.mark.usefixtures('use_training_dir')
async def test_training_creates_model(self, data_exchanger: DataExchanger, glc: LoopCommunicator):
"""Test if training creates a model"""

project_folder = os.getcwd()
images_folder = create_image_folder(project_folder)
categories, image_data = await download_training_data(images_folder, data_exchanger, glc)
training = Training(id=str(uuid4()),
project_folder=os.getcwd(),
training_folder=os.getcwd() + '/training',
images_folder=os.getcwd() + '/images',
base_model_uuid_or_name='model.pt',
context=Context(project='pytest_yolo5det', organization='zauberzeug'))
training.data = await create_training_data(training, data_exchanger, glc)
project_folder=project_folder,
training_folder=project_folder + '/training',
images_folder=images_folder,
model_variant='',
context=Context(project='pytest_yolo5det', organization='zauberzeug'),
categories=categories, hyperparameters={}, training_number=1,
training_state=TrainerState.Initialized.value,
image_data=image_data)
yolov5_format.create_file_structure(training)
executor = Executor(os.getcwd())
# from https://github.com/WongKinYiu/yolor#training
Expand All @@ -59,19 +67,25 @@ async def test_training_creates_model(self, use_training_dir, data_exchanger: Da
best = training.training_folder + '/result/weights/best.pt'
assert os.path.isfile(best)

async def test_parse_progress_from_log(self, use_training_dir, data_exchanger: DataExchanger, glc: LoopCommunicator):
@pytest.mark.usefixtures('use_training_dir')
async def test_parse_progress_from_log(self, data_exchanger: DataExchanger, glc: LoopCommunicator):
"""Test if progress is parsed correctly from log"""
trainer = Yolov5TrainerLogic()
trainer.epochs = 2
project_folder = os.getcwd()
images_folder = create_image_folder(project_folder)
categories, image_data = await download_training_data(images_folder, data_exchanger, glc)
trainer._training = Training(
id=str(uuid4()),
project_folder=os.getcwd(),
training_folder=os.getcwd() + '/training',
images_folder=os.getcwd() + '/images',
base_model_uuid_or_name='model.pt',
project_folder=project_folder,
training_folder=project_folder + '/training',
images_folder=images_folder,
model_variant='',
context=Context(project='pytest_yolo5det', organization='zauberzeug'),
categories=categories, hyperparameters={}, training_number=1,
training_state=TrainerState.Initialized.value,
image_data=image_data,
)
trainer.training.data = await create_training_data(trainer.training, data_exchanger, glc)
yolov5_format.create_file_structure(trainer.training)

trainer._executor = Executor(os.getcwd())
Expand All @@ -93,16 +107,21 @@ async def test_parse_progress_from_log(self, use_training_dir, data_exchanger: D
@pytest.mark.environment(organization='', project='', mode='DETECTION')
class TestWithDetection:

async def test_create_file_structure_box_size(self, use_training_dir):
@pytest.mark.usefixtures('use_training_dir')
async def test_create_file_structure_box_size(self):
categories = [Category(name='point_category_1', id='uuid_of_class_1'),
Category(name='point_category_2', id='uuid_of_class_2', point_size=30)]
image_data = [{'set': 'train', 'id': 'image_1', 'width': 100, 'height': 100, 'box_annotations': [],
'point_annotations': [{'category_id': 'uuid_of_class_1', 'x': 50, 'y': 60},
{'category_id': 'uuid_of_class_2', 'x': 60, 'y': 70}]}]
trainer = Yolov5TrainerLogic()
trainer._training = Training(id='someid', context=Context(organization='o', project='p'),
project_folder='./', images_folder='./', training_folder='./')
trainer.training.data = TrainingData(image_data=image_data, categories=categories)
trainer._training = Training(
id='someid', context=Context(organization='o', project='p'),
project_folder='./', images_folder='./', training_folder='./',
image_data=image_data, categories=categories, hyperparameters={},
model_variant='', training_number=1,
training_state=TrainerState.Initialized.value)

yolov5_format.create_file_structure(trainer.training)

with open('./train/image_1.txt', 'r') as f:
Expand All @@ -111,13 +130,17 @@ async def test_create_file_structure_box_size(self, use_training_dir):
assert '0 0.500000 0.600000 0.200000 0.200000' in lines[0]
assert '1 0.600000 0.700000 0.300000 0.300000' in lines[1]

async def test_new_model_discovery(self, use_training_dir):
@pytest.mark.usefixtures('use_training_dir')
async def test_new_model_discovery(self):
"""This test also triggers the creation of a wts file"""
trainer = Yolov5TrainerLogic()
trainer._training = Training(id='someid', context=Context(organization='o', project='p'),
project_folder='./', images_folder='./', training_folder='./')
trainer.training.data = TrainingData(image_data=[], categories=[
Category(name='class_a', id='uuid_of_class_a', type='box')])
trainer._training = Training(
id='someid', context=Context(organization='o', project='p'),
project_folder='./', images_folder='./', training_folder='./', image_data=[],
categories=[Category(name='class_a', id='uuid_of_class_a', type='box')],
hyperparameters={}, model_variant='', training_number=1,
training_state=TrainerState.Initialized.value)

assert trainer._get_new_best_training_state() is None, 'should not find any models'

model_path = 'result/weights/published/latest.pt'
Expand Down Expand Up @@ -152,12 +175,15 @@ async def test_new_model_discovery(self, use_training_dir):
# 'yolov5_pytorch': ['/tmp/model.pt', '/tmp/test_training/hyp.yaml'],
# 'yolov5_wts': ['/tmp/model.wts']}

def test_newest_model_is_used(self, use_training_dir):
@pytest.mark.usefixtures('use_training_dir')
def test_newest_model_is_used(self):
trainer = Yolov5TrainerLogic()
trainer._training = Training(id='someid', context=Context(organization='o', project='p'),
project_folder='./', images_folder='./', training_folder='./')
trainer.training.data = TrainingData(image_data=[], categories=[
Category(name='class_a', id='uuid_of_class_a', type='box')])
trainer._training = Training(
id='someid', context=Context(organization='o', project='p'),
project_folder='./', images_folder='./', training_folder='./', image_data=[],
categories=[Category(name='class_a', id='uuid_of_class_a', type='box')],
hyperparameters={}, model_variant='', training_number=1,
training_state=TrainerState.Initialized.value)

# create some models.
mock_epoch(10, {})
Expand All @@ -168,12 +194,16 @@ def test_newest_model_is_used(self, use_training_dir):
assert 'epoch10.pt' not in new_model.meta_information['weightfile']
assert 'epoch200.pt' in new_model.meta_information['weightfile']

def test_old_model_files_are_deleted_on_publish(self, use_training_dir):
@pytest.mark.usefixtures('use_training_dir')
def test_old_model_files_are_deleted_on_publish(self):
trainer = Yolov5TrainerLogic()
trainer._training = Training(id='someid', context=Context(organization='o', project='p'),
project_folder='./', images_folder='./', training_folder='./')
trainer.training.data = TrainingData(image_data=[], categories=[
Category(name='class_a', id='uuid_of_class_a', type='box')])
trainer._training = Training(
id='someid', context=Context(organization='o', project='p'),
project_folder='./', images_folder='./', training_folder='./', image_data=[],
categories=[Category(name='class_a', id='uuid_of_class_a', type='box')],
hyperparameters={}, model_variant='', training_number=1,
training_state=TrainerState.Initialized.value)

assert trainer._get_new_best_training_state() is None, 'should not find any models'

mock_epoch(1, {'class_a': {'fp': 0, 'tp': 1, 'fn': 0}})
Expand All @@ -196,12 +226,15 @@ def test_old_model_files_are_deleted_on_publish(self, use_training_dir):
_, _, files = next(os.walk("result/weights"))
assert len(files) == 0

def test_newer_model_files_are_kept_during_deleting(self, use_training_dir):
@pytest.mark.usefixtures('use_training_dir')
def test_newer_model_files_are_kept_during_deleting(self):
trainer = Yolov5TrainerLogic()
trainer._training = Training(id='someid', context=Context(organization='o', project='p'),
project_folder='./', images_folder='./', training_folder='./')
trainer.training.data = TrainingData(image_data=[], categories=[
Category(name='class_a', id='uuid_of_class_a', type='box')])
trainer._training = Training(
id='someid', context=Context(organization='o', project='p'),
project_folder='./', images_folder='./', training_folder='./', image_data=[],
categories=[Category(name='class_a', id='uuid_of_class_a', type='box')],
hyperparameters={}, model_variant='', training_number=1,
training_state=TrainerState.Initialized.value)

# create some models.
mock_epoch(10, {})
Expand All @@ -217,10 +250,13 @@ def test_newer_model_files_are_kept_during_deleting(self, use_training_dir):
assert len(all_model_files) == 1
assert 'epoch201.pt' in str(all_model_files[0]), 'Epoch201 is not yed synced. It should not be deleted.'

async def test_clear_training_data(self, use_training_dir):
@pytest.mark.usefixtures('use_training_dir')
async def test_clear_training_data(self):
trainer = Yolov5TrainerLogic()
trainer._training = Training(id='someid', context=Context(organization='o', project='p'),
project_folder='./', images_folder='./', training_folder='./')
project_folder='./', images_folder='./', training_folder='./',
categories=[], hyperparameters={}, model_variant='',
image_data=[], training_number=1, training_state=TrainerState.Initialized.value)
os.makedirs(f'{trainer.training.training_folder}/result/weights/', exist_ok=True)
os.makedirs(f'{trainer.training.training_folder}/result/weights/published/', exist_ok=True)

Expand Down Expand Up @@ -255,30 +291,30 @@ def assert_yaml_content(yaml_path, **kwargs):
assert content[key] == value

shutil.copy('app_code/tests/test_data/hyp.yaml', '/tmp')
hyperparameter = Hyperparameter(resolution=600, flip_rl=True, flip_ud=True)
hyperparameter = {'resolution': 600,
'flip_rl': True,
'flip_ud': True}

assert_yaml_content('/tmp/hyp.yaml', fliplr=0, flipud=0)
update_hyps('/tmp/hyp.yaml', hyperparameter)
set_hyperparameters_in_file('/tmp/hyp.yaml', hyperparameter)
assert_yaml_content('/tmp/hyp.yaml', fliplr=0.5, flipud=0.5)

# =======================================================================================================================
# ---------------------------------------------- HELPERS ----------------------------------------------------------------
# =======================================================================================================================


async def create_training_data(training: Training, data_exchanger: DataExchanger, glc: LoopCommunicator) -> TrainingData:
training_data = TrainingData()
async def download_training_data(images_folder: str, data_exchanger: DataExchanger, glc: LoopCommunicator
) -> Tuple[List[Category], List[Dict]]:

image_data, _ = await TrainingsDownloader(data_exchanger).download_training_data(training.images_folder)
logging.info(f'got {len(image_data)} images')
image_data, _ = await TrainingsDownloader(data_exchanger).download_training_data(images_folder)

response = await glc.get(f"/{os.environ['LOOP_ORGANIZATION']}/projects/{os.environ['LOOP_PROJECT']}/data")
assert response.status_code != 401, 'Authentification error - did you set LOOP_USERNAME and LOOP_PASSWORD in your environment?'
assert response.status_code == 200
data = response.json()
training_data.categories = Category.from_list(data['categories'])
training_data.image_data = image_data
return training_data
categories = Category.from_list(data['categories'])
return categories, image_data


def mock_epoch(number: int, confusion_matrix: Dict):
Expand Down
Loading

0 comments on commit f6d2226

Please sign in to comment.