Skip to content

Commit

Permalink
Added is_owner to online conversion (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 authored and ptoupas committed Dec 12, 2024
1 parent 887ed47 commit b181785
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 78 deletions.
4 changes: 3 additions & 1 deletion modelconverter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import pkg_resources
from luxonis_ml.utils import PUT_FILE_REGISTRY

from .hub import *
from .hub import convert

__version__ = "0.3.1"

__all__ = ["convert"]


def load_put_file_plugins() -> None:
"""Registers any external put file plugins."""
Expand Down
21 changes: 6 additions & 15 deletions modelconverter/cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,6 @@ class Status(str, Enum):
),
]

TeamIDOption = Annotated[
Optional[str],
typer.Option(help="The team ID", show_default=False),
]

RepositoryUrlOption = Annotated[
Optional[str],
typer.Option(help="The repository URL", show_default=False),
Expand Down Expand Up @@ -294,9 +289,12 @@ class Status(str, Enum):
str, typer.Argument(help="Name of the model", show_default=False)
]

UserIDOption = Annotated[
Optional[str],
typer.Option(help="The user ID", show_default=False),
IsOwnerOption = Annotated[
bool,
typer.Option(
help="Whether the user is the owner of the resource",
show_default=False,
),
]

ArchitectureIDOption = Annotated[
Expand Down Expand Up @@ -408,13 +406,6 @@ class Status(str, Enum):
typer.Option(help="The project ID", show_default=False),
]

FilterPublicEntityByTeamIDOption = Annotated[
Optional[bool],
typer.Option(
help="Whether to filter public entity by team ID", show_default=False
),
]

LuxonisOnlyOption = Annotated[
bool,
typer.Option(help="Whether Luxonis only models", show_default=False),
Expand Down
31 changes: 13 additions & 18 deletions modelconverter/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def hub_ls(
**kwargs,
) -> List[Dict[str, Any]]:
rename = rename or {}
data = Request.get(f"{endpoint}/", params=kwargs).json()
data = Request.get(f"{endpoint}/", params=kwargs)
table = Table(row_styles=["yellow", "cyan"], box=ROUNDED)
for key in keys:
table.add_column(rename.get(key, key), header_style="magenta i")
Expand Down Expand Up @@ -285,7 +285,7 @@ def slug_to_id(
"is_public": is_public,
"slug": slug,
}
data = Request.get(f"{endpoint}/", params=params).json()
data = Request.get(f"{endpoint}/", params=params)
if data:
return data[0]["id"]
raise ValueError(f"Model with slug '{slug}' not found.")
Expand All @@ -307,7 +307,7 @@ def request_info(
resource_id = get_resource_id(identifier, endpoint)

try:
return Request.get(f"{endpoint}/{resource_id}/").json()
return Request.get(f"{endpoint}/{resource_id}/")
except HTTPError:
typer.echo(f"Resource with ID '{resource_id}' not found.")
exit(1)
Expand All @@ -333,26 +333,21 @@ def get_variant_name(


def get_version_number(model_id: str) -> str:
versions = Request.get(
"modelVersions/", params={"model_id": model_id}
).json()
versions = Request.get("modelVersions/", params={"model_id": model_id})
if not versions:
version = "0.1.0"
else:
max_version = Version(versions[0]["version"])
for v in versions[1:]:
max_version = max(max_version, Version(v["version"]))
max_version = str(max_version)
version_numbers = max_version.split(".")
version_numbers[-1] = str(int(version_numbers[-1]) + 1)
version = ".".join(version_numbers)
return version
return "0.1.0"
max_version = Version(versions[0]["version"])
for v in versions[1:]:
max_version = max(max_version, Version(v["version"]))
max_version = str(max_version)
version_numbers = max_version.split(".")
version_numbers[-1] = str(int(version_numbers[-1]) + 1)
return ".".join(version_numbers)


def wait_for_export(run_id: str) -> None:
def _get_run(run_id: str) -> Dict[str, Any]:
run = Request.dag_get(f"runs/{run_id}").json()
return run
return Request.dag_get(f"runs/{run_id}")

def _clean_logs(logs: str) -> str:
pattern = r"\[.*?\] \{.*?\} INFO - \[base\] logs:\s*"
Expand Down
46 changes: 17 additions & 29 deletions modelconverter/hub/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
DescriptionOption,
DescriptionShortOption,
DomainOption,
FilterPublicEntityByTeamIDOption,
HashOption,
HubVersionOption,
HubVersionOptionRequired,
IdentifierArgument,
IsOwnerOption,
IsPublicOption,
JSONOption,
LicenseTypeOption,
Expand Down Expand Up @@ -58,8 +58,6 @@
TagsOption,
TargetPrecisionOption,
TasksOption,
TeamIDOption,
UserIDOption,
VariantSlugOption,
VersionOption,
get_configs,
Expand Down Expand Up @@ -140,14 +138,12 @@ def login(

@model.command(name="ls")
def model_ls(
team_id: TeamIDOption = None,
tasks: TasksOption = None,
user_id: UserIDOption = None,
license_type: LicenseTypeOption = None,
is_public: IsPublicOption = None,
is_owner: IsOwnerOption = True,
slug: SlugOption = None,
project_id: ProjectIDOption = None,
filter_public_entity_by_team_id: FilterPublicEntityByTeamIDOption = None,
luxonis_only: LuxonisOnlyOption = False,
limit: LimitOption = 50,
sort: SortOption = "updated",
Expand All @@ -156,14 +152,12 @@ def model_ls(
"""Lists models."""
return hub_ls(
"models",
team_id=team_id,
tasks=[task for task in tasks] if tasks else [],
user_id=user_id,
license_type=license_type,
is_public=is_public,
is_owner=is_owner,
slug=slug,
project_id=project_id,
filter_public_entity_by_team_id=filter_public_entity_by_team_id,
luxonis_only=luxonis_only,
limit=limit,
sort=sort,
Expand Down Expand Up @@ -224,7 +218,7 @@ def model_create(
"links": links or [],
}
try:
res = Request.post("models", json=data).json()
res = Request.post("models", json=data)
except requests.HTTPError as e:
if (
e.response is not None
Expand All @@ -248,24 +242,22 @@ def model_delete(identifier: IdentifierArgument):

@variant.command(name="ls")
def variant_ls(
team_id: TeamIDOption = None,
user_id: UserIDOption = None,
model_id: ModelIDOption = None,
slug: SlugOption = None,
variant_slug: VariantSlugOption = None,
version: HubVersionOption = None,
is_public: IsPublicOption = None,
is_owner: IsOwnerOption = True,
limit: LimitOption = 50,
sort: SortOption = "updated",
order: OrderOption = Order.DESC,
) -> List[Dict[str, Any]]:
"""Lists model versions."""
return hub_ls(
"modelVersions",
team_id=team_id,
user_id=user_id,
model_id=model_id,
is_public=is_public,
is_owner=is_owner,
slug=slug,
variant_slug=variant_slug,
version=version,
Expand Down Expand Up @@ -324,7 +316,7 @@ def variant_create(
"tags": tags or [],
}
try:
res = Request.post("modelVersions", json=data).json()
res = Request.post("modelVersions", json=data)
except requests.HTTPError as e:
if str(e).startswith("{'detail': 'Unique constraint error."):
raise ValueError(
Expand All @@ -348,8 +340,6 @@ def variant_delete(identifier: IdentifierArgument):
@instance.command(name="ls")
def instance_ls(
platforms: PlatformsOption = None,
team_id: TeamIDOption = None,
user_id: UserIDOption = None,
model_id: ModelIDOption = None,
variant_id: ModelVersionIDOption = None,
model_type: ModelTypeOption = None,
Expand All @@ -359,6 +349,7 @@ def instance_ls(
hash: HashOption = None,
status: StatusOption = None,
is_public: IsPublicOption = None,
is_owner: IsOwnerOption = True,
compression_level: CompressionLevelOption = None,
optimization_level: OptimizationLevelOption = None,
slug: SlugOption = None,
Expand All @@ -382,9 +373,8 @@ def instance_ls(
status=status,
compression_level=compression_level,
optimization_level=optimization_level,
team_id=team_id,
user_id=user_id,
is_public=is_public,
is_owner=is_owner,
slug=slug,
limit=limit,
sort=sort,
Expand Down Expand Up @@ -434,7 +424,7 @@ def instance_download(
dest = Path(output_dir) if output_dir else None
model_instance_id = get_resource_id(identifier, "modelInstances")
downloaded_path = None
urls = Request.get(f"modelInstances/{model_instance_id}/download").json()
urls = Request.get(f"modelInstances/{model_instance_id}/download")
if not urls:
raise ValueError("No files to download")

Expand All @@ -445,9 +435,9 @@ def instance_download(
filename = unquote(Path(urlparse(url).path).name)
if dest is None:
dest = Path(
Request.get(f"modelInstances/{model_instance_id}")
.json()
.get("slug", model_instance_id)
Request.get(f"modelInstances/{model_instance_id}").get(
"slug", model_instance_id
)
)
dest.mkdir(parents=True, exist_ok=True)

Expand Down Expand Up @@ -487,7 +477,7 @@ def instance_create(
"quantization_data": quantization_data,
"is_deployable": is_deployable,
}
res = Request.post("modelInstances", json=data).json()
res = Request.post("modelInstances", json=data)
print(f"Model instance '{res['name']}' created with ID '{res['id']}'")
if not silent:
instance_info(res["id"])
Expand All @@ -506,16 +496,14 @@ def instance_delete(identifier: IdentifierArgument):
def config(identifier: IdentifierArgument):
"""Prints the configuration of a model instance."""
model_instance_id = get_resource_id(identifier, "modelInstances")
res = Request.get(f"modelInstances/{model_instance_id}/config")
print(res.json())
print(Request.get(f"modelInstances/{model_instance_id}/config"))


@instance.command()
def files(identifier: IdentifierArgument):
"""Prints the configuration of a model instance."""
model_instance_id = get_resource_id(identifier, "modelInstances")
res = Request.get(f"modelInstances/{model_instance_id}/files")
print(res.json())
print(Request.get(f"modelInstances/{model_instance_id}/files"))


@instance.command()
Expand Down Expand Up @@ -548,7 +536,7 @@ def _export(
res = Request.post(
f"modelInstances/{model_instance_id}/export/{target.lower()}",
json=json,
).json()
)
print(
f"Model instance '{name}' created for {target} export with ID '{res['id']}'"
)
Expand Down
40 changes: 25 additions & 15 deletions modelconverter/hub/hub_requests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from json import JSONDecodeError
from typing import Dict, Final, Optional
from typing import Any, Dict, Final, Optional

import requests
from requests import HTTPError, Response
Expand All @@ -8,26 +8,36 @@


class Request:
URL: Final[str] = f"{environ.HUBAI_URL.rstrip('/')}/api/v1"
URL: Final[str] = f"{environ.HUBAI_URL.rstrip('/')}/models/api/v1"
DAG_URL: Final[str] = URL.replace("models", "dags")
HEADERS: Final[Dict[str, str]] = {
"accept": "application/json",
"Authorization": f"Bearer {environ.HUBAI_API_KEY}",
}

@staticmethod
def _process_response(response: Response) -> Any:
return Request._get_json(Request._check_response(response))

@staticmethod
def _check_response(response: Response) -> Response:
if response.status_code >= 400:
try:
json = response.json()
raise HTTPError(json, response=response)
except JSONDecodeError as e:
raise HTTPError(response.text) from e
raise HTTPError(Request._get_json(response), response=response)
return response

@staticmethod
def get(endpoint: str = "", **kwargs) -> requests.Response:
return Request._check_response(
def _get_json(response: Response) -> Any:
try:
return response.json()
except JSONDecodeError as e:
raise HTTPError(
f"Unexpected response from the server:\n{response.text}",
response=response,
) from e

@staticmethod
def get(endpoint: str = "", **kwargs) -> Any:
return Request._process_response(
requests.get(
Request._get_url(endpoint),
headers=Request.HEADERS,
Expand All @@ -36,8 +46,8 @@ def get(endpoint: str = "", **kwargs) -> requests.Response:
)

@staticmethod
def dag_get(endpoint: str = "", **kwargs) -> requests.Response:
return Request._check_response(
def dag_get(endpoint: str = "", **kwargs) -> Any:
return Request._process_response(
requests.get(
Request._get_url(endpoint, Request.DAG_URL),
headers=Request.HEADERS,
Expand All @@ -46,19 +56,19 @@ def dag_get(endpoint: str = "", **kwargs) -> requests.Response:
)

@staticmethod
def post(endpoint: str = "", **kwargs) -> requests.Response:
def post(endpoint: str = "", **kwargs) -> Any:
headers = Request.HEADERS
if "headers" in kwargs:
headers = {**Request.HEADERS, **kwargs.pop("headers")}
return Request._check_response(
return Request._process_response(
requests.post(
Request._get_url(endpoint), headers=headers, **kwargs
)
)

@staticmethod
def delete(endpoint: str = "", **kwargs) -> requests.Response:
return Request._check_response(
def delete(endpoint: str = "", **kwargs) -> Any:
return Request._process_response(
requests.delete(
Request._get_url(endpoint), headers=Request.HEADERS, **kwargs
)
Expand Down

0 comments on commit b181785

Please sign in to comment.