Skip to content

Commit

Permalink
Merge pull request #150 from lovit/feature/149
Browse files Browse the repository at this point in the history
[#149] Pipeline - task template
  • Loading branch information
lovit authored Feb 4, 2025
2 parents cfbb394 + a027da7 commit d734c4a
Show file tree
Hide file tree
Showing 12 changed files with 128 additions and 7 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ repos:
hooks:
- id: pyright
entry: pyright
additional_dependencies: []
additional_dependencies:
- dacite==1.9.1
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.4
hooks:
Expand Down
4 changes: 4 additions & 0 deletions examples/dummy/pipeline.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pipeline:
- name: Dummy
args:
name: soynlp
1 change: 1 addition & 0 deletions examples/dummy/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Run the soynlp pipeline CLI with the task configuration in pipeline.yaml.
soynlp pipeline -c pipeline.yaml
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ soynlp = "soynlp.__main__:main"

[tool.poetry.dependencies]
python = "~3.12.8"
dacite = "^1.9.1"

[tool.poetry.group.dev.dependencies]
pytest = "^8.3.0"
Expand Down
7 changes: 4 additions & 3 deletions soynlp/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from importlib.metadata import metadata
from typing import Callable

from soynlp.dummy import dummy
from soynlp.pipeline.pipeline import Pipeline


def main():
Expand All @@ -18,8 +18,9 @@ def main():
parser.set_defaults(func=lambda: parser.print_usage())
subparsers = parser.add_subparsers()

sp_dummy = subparsers.add_parser("dummy", help="dummy function")
sp_dummy.set_defaults(func=dummy)
sp_pipeline = subparsers.add_parser("pipeline", help="Run pipeline with config")
sp_pipeline.add_argument("-c", "--config_file", type=str, required=True, help="Config file path")
sp_pipeline.set_defaults(func=Pipeline.run)

args = parser.parse_args()
func = args.func
Expand Down
25 changes: 25 additions & 0 deletions soynlp/configs/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from dataclasses import dataclass

import dacite
import yaml


@dataclass
class TaskConfig:
    """Configuration for a single pipeline task entry."""

    # Task class name prefix; resolved to "<name>Task" in soynlp.pipeline.tasks.
    name: str
    # Raw keyword arguments later converted into the task's TaskArgs dataclass.
    args: dict


@dataclass
class Config:
    """Top-level pipeline configuration: an ordered list of task configs."""

    # Tasks are executed in list order by the Pipeline.
    pipeline: list[TaskConfig]


def from_yaml(path: str) -> Config:
    """Load a pipeline ``Config`` from a YAML file.

    Args:
        path: Filesystem path to the YAML config file.

    Returns:
        The parsed ``Config`` dataclass.
    """
    with open(path, encoding="utf-8") as file:
        # safe_load is sufficient for plain-mapping configs and, unlike
        # full_load, cannot construct arbitrary Python objects from the file.
        data = yaml.safe_load(file)
    return from_dict(data)  # type: ignore


def from_dict(data: dict) -> Config:
    """Build a ``Config`` from an already-parsed mapping."""
    return dacite.from_dict(Config, data)  # type: ignore
2 changes: 0 additions & 2 deletions soynlp/dummy.py

This file was deleted.

32 changes: 32 additions & 0 deletions soynlp/pipeline/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import importlib

import dacite

from soynlp.configs.config import Config, from_yaml
from soynlp.pipeline.tasks import Task


class Pipeline:
    """Runs a sequence of tasks described by a ``Config``.

    Each task receives the accumulated parameter dict and returns a dict
    that is merged back in, so later tasks can read earlier tasks' outputs.
    """

    def __call__(self, config: Config):
        """Execute every configured task in order and return the merged parameters."""
        tasks: list[Task] = self._load_tasks(config)
        parameters: dict = {}

        for task in tasks:
            # Merge each task's output over the accumulated parameters.
            parameters |= task(parameters)
        return parameters

    def _load_tasks(self, config: Config) -> list[Task]:
        """Instantiate a task object for every entry in ``config.pipeline``.

        Raises:
            AttributeError: If a config name has no matching task class.
        """
        # The task registry module is loop-invariant; import it once.
        task_module = importlib.import_module("soynlp.pipeline.tasks")
        tasks = []
        for task_config in config.pipeline:
            # Config name "<X>" maps to class "<X>Task" in the registry module.
            task_class: type[Task] = getattr(task_module, f"{task_config.name}Task")
            task_args = dacite.from_dict(data_class=task_class.args(), data=task_config.args)  # type: ignore
            tasks.append(task_class(task_args))
        return tasks

    @classmethod
    def run(cls, config_file: str):
        """CLI entry point: load the YAML config at ``config_file`` and run it."""
        config = from_yaml(config_file)
        # Use cls() so subclasses invoking run() construct their own type.
        pipeline = cls()
        print(pipeline(config))
5 changes: 5 additions & 0 deletions soynlp/pipeline/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Re-export the task base class and built-in tasks; the colon in "noqa: F401"
# is required — without it the directive suppresses ALL lint codes, not just
# the unused-import warning.
from soynlp.pipeline.tasks.dummy import DummyTask  # noqa: F401
from soynlp.pipeline.tasks.task import Task  # noqa: F401


__all__ = ("Task", "DummyTask")
14 changes: 14 additions & 0 deletions soynlp/pipeline/tasks/dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from dataclasses import dataclass

from soynlp.pipeline.tasks.task import Task, TaskArgs


@dataclass
class DummyTaskArgs(TaskArgs):
    """Arguments for DummyTask."""

    # Display name echoed by DummyTask when it runs.
    name: str = "Dummy Task"


class DummyTask(Task[DummyTaskArgs]):
    """No-op task that announces itself and passes parameters through unchanged."""

    def __call__(self, parameters: dict) -> dict:
        task_name = self._args.name
        print(f"Called in DummyTask({task_name})")
        return parameters
25 changes: 25 additions & 0 deletions soynlp/pipeline/tasks/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar, get_args


@dataclass
class TaskArgs:
    """Marker base class for per-task argument dataclasses."""

    # Intentionally empty: concrete tasks subclass this with their own fields.
    pass


TaskArgsType = TypeVar("TaskArgsType", bound=TaskArgs)


class Task(ABC, Generic[TaskArgsType]):
    """Abstract base class for pipeline tasks, generic over their args type."""

    def __init__(self, args: TaskArgsType):
        # Validated/typed arguments for this task instance.
        self._args = args

    @classmethod
    def args(cls) -> type[TaskArgs]:
        """Return the concrete ``TaskArgs`` subclass this task was parameterized with.

        Raises:
            TypeError: If the subclass does not parameterize ``Task[...]``.
        """
        # Search all original (unsubscripted) bases instead of assuming the
        # Task[...] parameterization is first, so subclasses with extra or
        # mixin bases still resolve correctly and failures are explicit.
        for base in getattr(cls, "__orig_bases__", ()):
            type_args = get_args(base)
            if type_args and isinstance(type_args[0], type) and issubclass(type_args[0], TaskArgs):
                return type_args[0]
        raise TypeError(f"{cls.__name__} must subclass Task[<TaskArgs subclass>]")

    @abstractmethod
    def __call__(self, parameters: dict) -> dict:
        """Run the task: consume the parameter dict, return the dict to merge back."""
        raise NotImplementedError

0 comments on commit d734c4a

Please sign in to comment.