Skip to content

Commit

Permalink
py(core): extend CLI to Push command
Browse files Browse the repository at this point in the history
  • Loading branch information
tarilabs committed Jul 29, 2024
1 parent a7ec3e5 commit 791fe55
Show file tree
Hide file tree
Showing 8 changed files with 270 additions and 156 deletions.
13 changes: 13 additions & 0 deletions omlmd/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Using this to scope CLI targets
import click
from omlmd.helpers import Helper
from omlmd.model_metadata import deserialize_mdfile

@click.group()
def cli():
Expand Down Expand Up @@ -29,8 +30,20 @@ def config(target):
def crawl(targets):
"""Crawls configuration for the given list of OCI Artifact for ML model and metadata."""
click.echo(Helper().crawl(targets))

@click.command()
@click.argument('target', required=True)
@click.argument('path', required=True, type=click.Path())
@click.option('-m', '--metadata', required=True, type=click.Path())
def push(target, path, metadata):
"""Pushes an OCI Artifact containing ML model and metadata, supplying metadata from file as necessary"""
import logging
logging.basicConfig(level=logging.DEBUG)
md = deserialize_mdfile(metadata)
click.echo(Helper().push(target, path, **md))

cli.add_command(pull)
cli.add_command(get)
get.add_command(config)
cli.add_command(crawl)
cli.add_command(push)
3 changes: 2 additions & 1 deletion omlmd/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def download_file(uri):
class Helper:
def __init__(self, registry: Optional[OMLMDRegistry] = None):
if registry is None:
self._registry = OMLMDRegistry(insecure=True)
self._registry = OMLMDRegistry(insecure=True) # TODO: this is a bit limiting when used from CLI, to be refactored
else:
self._registry = registry

Expand Down Expand Up @@ -61,6 +61,7 @@ def push(
"model_metadata.omlmd.yaml:application/x-config",
]
try:
# print(target, files, model_metadata.to_annotations_dict())
return self._registry.push(
target=target,
files=files,
Expand Down
16 changes: 16 additions & 0 deletions omlmd/model_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,19 @@ def to_yaml(self) -> str:
def from_yaml(yaml_str: str) -> 'ModelMetadata':
data = yaml.safe_load(yaml_str)
return ModelMetadata(**data)


def deserialize_mdfile(file):
with open(file, 'r') as file:
content = file.read()
try:
return json.loads(content)
except json.JSONDecodeError:
pass

try:
return yaml.safe_load(content)
except yaml.YAMLError:
pass

raise ValueError(f"The file at {file} is neither a valid JSON nor a valid YAML file.")
329 changes: 175 additions & 154 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ nbconvert = "^7.16.4"
click = "^8.1.7"

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.2"
pytest = "^8.3.2"
pytest-mock = "^3.14.0"
jq = "^1.7.0"
scikit-learn = "^1.5.0"
ipykernel = "^6.29.4"
Expand Down
7 changes: 7 additions & 0 deletions tests/data/md.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"name": "Model Example",
"description": "Lorem ipsum",
"author": "John Doe",
"accuracy": 0.987,
"license": "Apache-2.0"
}
29 changes: 29 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from omlmd.helpers import Helper
from omlmd.model_metadata import deserialize_mdfile
import tempfile
import json

def test_call_push_using_md_from_file(mocker):
helper = Helper()
mocker.patch.object(helper, "push", return_value=None)

md = {
"name": "mnist",
"description": "Lorem ipsum",
"author": "John Doe",
"accuracy": .987
}
with tempfile.NamedTemporaryFile(delete=True, mode="w") as f:
f.write(json.dumps(md))
f.flush()
md = deserialize_mdfile(f.name)

helper.push("localhost:8080/mmortari/ml-iris:v1", "some-file", **md)
helper.push.assert_called_once_with(
"localhost:8080/mmortari/ml-iris:v1",
"some-file",
name="mnist",
description="Lorem ipsum",
author="John Doe",
accuracy=0.987
)
26 changes: 26 additions & 0 deletions tests/test_omlmd.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from omlmd.model_metadata import ModelMetadata
from omlmd.model_metadata import deserialize_mdfile
import tempfile
import json
import yaml

def test_dry_run_model_metadata_json_yaml_conversions():
metadata = ModelMetadata(name="Example Model", author="John Doe")
Expand All @@ -16,3 +20,25 @@ def test_dry_run_model_metadata_json_yaml_conversions():

assert metadata == metadata_from_json
assert metadata == metadata_from_yaml


def test_deserialize_file_json():
md_dict = ModelMetadata(name="Example Model", author="John Doe", model_format_name="onnx", model_format_version="1", customProperties={"accuracy": .987}).to_dict()
json_str = json.dumps(md_dict)

with tempfile.NamedTemporaryFile(delete=True, mode="w") as f:
f.write(json_str)
f.flush()
metadata_from_json = deserialize_mdfile(f.name)
assert md_dict == metadata_from_json


def test_deserialize_file_yaml():
md_dict = ModelMetadata(name="Example Model", author="John Doe", model_format_name="onnx", model_format_version="1", customProperties={"accuracy": .987}).to_dict()
yaml_str = yaml.dump(md_dict)

with tempfile.NamedTemporaryFile(delete=True, mode="w") as f:
f.write(yaml_str)
f.flush()
metadata_from_yaml = deserialize_mdfile(f.name)
assert md_dict == metadata_from_yaml

0 comments on commit 791fe55

Please sign in to comment.