Skip to content

Commit

Permalink
remove mlm_prefix in pydantic models
Browse files Browse the repository at this point in the history
  • Loading branch information
rbavery committed Mar 7, 2024
1 parent 9ddff24 commit 0bed29b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 33 deletions.
File renamed without changes.
File renamed without changes.
34 changes: 17 additions & 17 deletions stac_model/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def eurosat_resnet():
1231.58581042,
],
)
mlm_input = ModelInput(
input = ModelInput(
name="13 Band Sentinel-2 Batch",
bands=band_names,
input_array=input_array,
Expand All @@ -75,7 +75,7 @@ def eurosat_resnet():
statistics=stats,
pre_processing_function="https://github.com/microsoft/torchgeo/blob/545abe8326efc2848feae69d0212a15faba3eb00/torchgeo/datamodules/eurosat.py", # noqa: E501
)
mlm_runtime = Runtime(
runtime = Runtime(
framework="torch",
version="2.1.2+cu121",
asset=Asset(
Expand Down Expand Up @@ -107,28 +107,28 @@ def eurosat_resnet():
ClassObject(value=class_map[class_name], name=class_name)
for class_name in class_map
]
mlm_output = ModelOutput(
output = ModelOutput(
task="classification",
classification_classes=class_objects,
output_shape=[-1, 10],
result_array=[result_array],
)
ml_model_meta = MLModelProperties(
mlm_name="Resnet-18 Sentinel-2 ALL MOCO",
mlm_task="classification",
mlm_framework="pytorch",
mlm_framework_version="2.1.2+cu121",
mlm_file_size=43000000,
mlm_memory_size=1,
mlm_summary=(
name="Resnet-18 Sentinel-2 ALL MOCO",
task="classification",
framework="pytorch",
framework_version="2.1.2+cu121",
file_size=43000000,
memory_size=1,
summary=(
"Sourced from torchgeo python library,"
"identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO"
),
mlm_pretrained_source="EuroSat Sentinel-2",
mlm_total_parameters=11_700_000,
mlm_input=[mlm_input],
mlm_runtime=[mlm_runtime],
mlm_output=[mlm_output],
pretrained_source="EuroSat Sentinel-2",
total_parameters=11_700_000,
input=[input],
runtime=[runtime],
output=[output],
)
# TODO, this can't be serialized but pystac.item calls for a datetime
# in docs. start_datetime=datetime.strptime("1900-01-01", "%Y-%m-%d")
Expand All @@ -138,8 +138,8 @@ def eurosat_resnet():
geometry = None
bbox = [-90, -180, 90, 180]
name = (
"_".join(ml_model_meta.mlm_name.split(" ")).lower()
+ f"_{ml_model_meta.mlm_task}".lower()
"_".join(ml_model_meta.name.split(" ")).lower()
+ f"_{ml_model_meta.task}".lower()
)
item = pystac.Item(
id=name,
Expand Down
32 changes: 16 additions & 16 deletions stac_model/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,29 +38,29 @@
PREFIX = f"{get_args(SchemaName)[0]}:"


def mlm_prefix_replacer(field_name: str) -> str:
return field_name.replace("mlm_", "mlm:")
def mlm_prefix_adder(field_name: str) -> str:
return "mlm:" + field_name


class MLModelProperties(BaseModel):
mlm_name: str
mlm_task: TaskEnum
mlm_framework: str
mlm_framework_version: str
mlm_file_size: int
mlm_memory_size: int
mlm_input: List[ModelInput]
mlm_output: List[ModelOutput]
mlm_runtime: List[Runtime]
mlm_total_parameters: int
mlm_pretrained_source: str
mlm_summary: str
mlm_parameters: Optional[
name: str
task: TaskEnum
framework: str
framework_version: str
file_size: int
memory_size: int
input: List[ModelInput]
output: List[ModelOutput]
runtime: List[Runtime]
total_parameters: int
pretrained_source: str
summary: str
parameters: Optional[
Dict[str, Union[int, str, bool, List[Union[int, str, bool]]]]
] = None # noqa: E501

model_config = ConfigDict(
alias_generator=mlm_prefix_replacer, populate_by_name=True, extra="ignore"
alias_generator=mlm_prefix_adder, populate_by_name=True, extra="ignore"
)


Expand Down

0 comments on commit 0bed29b

Please sign in to comment.