Skip to content

Commit

Permalink
update cli example. still getting datetime serialization issue
Browse files Browse the repository at this point in the history
  • Loading branch information
rbavery committed Feb 28, 2024
1 parent 4fc2e8e commit 30269d4
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 172 deletions.
163 changes: 5 additions & 158 deletions example.json
Original file line number Diff line number Diff line change
@@ -1,159 +1,6 @@
{
"mlm:name": "Resnet-18 Sentinel-2 ALL MOCO",
"mlm:task": "classification",
"mlm:framework": "pytorch",
"mlm:framework_version": "2.1.2+cu121",
"mlm:file_size": 1,
"mlm:memory_size": 1,
"mlm:input": [
{
"name": "13 Band Sentinel-2 Batch",
"bands": [
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B10",
"B11",
"B12"
],
"input_array": {
"shape": [
-1,
13,
64,
64
],
"dim_order": "bchw",
"data_type": "float32"
},
"norm_by_channel": true,
"norm_type": "z_score",
"statistics": {
"mean": [
1354.40546513,
1118.24399958,
1042.92983953,
947.62620298,
1199.47283961,
1999.79090914,
2369.22292565,
2296.82608323,
732.08340178,
12.11327804,
1819.01027855,
1118.92391149,
2594.14080798
],
"stddev": [
245.71762908,
333.00778264,
395.09249139,
593.75055589,
566.4170017,
861.18399006,
1086.63139075,
1117.98170791,
404.91978886,
4.77584468,
1002.58768311,
761.30323499,
1231.58581042
]
},
"pre_processing_function": "https://github.com/microsoft/torchgeo/blob/545abe8326efc2848feae69d0212a15faba3eb00/torchgeo/datamodules/eurosat.py"
}
],
"mlm:output": [
{
"task": "classification",
"result_array": [
{
"shape": [
-1,
10
],
"dim_names": [
"batch",
"class"
],
"data_type": "float32"
}
],
"classification_classes": [
{
"value": 0,
"name": "Annual Crop",
"nodata": false
},
{
"value": 1,
"name": "Forest",
"nodata": false
},
{
"value": 2,
"name": "Herbaceous Vegetation",
"nodata": false
},
{
"value": 3,
"name": "Highway",
"nodata": false
},
{
"value": 4,
"name": "Industrial Buildings",
"nodata": false
},
{
"value": 5,
"name": "Pasture",
"nodata": false
},
{
"value": 6,
"name": "Permanent Crop",
"nodata": false
},
{
"value": 7,
"name": "Residential Buildings",
"nodata": false
},
{
"value": 8,
"name": "River",
"nodata": false
},
{
"value": 9,
"name": "SeaLake",
"nodata": false
}
]
}
],
"mlm:runtime": [
{
"asset": {
"href": "."
},
"source_code": {
"href": "."
},
"accelerator": "cuda",
"accelerator_constrained": false,
"hardware_summary": "Unknown"
}
],
"mlm:total_parameters": 11700000,
"mlm:pretrained_source": "EuroSat Sentinel-2",
"mlm:summary": "Sourced from torchgeo python library,identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO"
}
"type": "Feature",
"stac_version": "1.0.0",
"id": "resnet-18_sentinel-2_all_moco_classification",
"properties": {
"start_datetime":
8 changes: 3 additions & 5 deletions stac_model/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typer
from rich.console import Console

import json
from stac_model import __version__
from stac_model.examples import eurosat_resnet

Expand Down Expand Up @@ -35,10 +35,8 @@ def main(
) -> None:
"""Generate example spec."""
ml_model_meta = eurosat_resnet()
json_str = ml_model_meta.model_dump_json(indent=2, exclude_none=True, by_alias=True)
with open("example.json", "w") as file:
file.write(json_str)
print(ml_model_meta.model_dump_json(indent=2, exclude_none=True, by_alias=True))
with open("example.json", "w") as json_file:
json.dump(ml_model_meta.item.to_dict(), json_file, indent=4)
print("Example model metadata written to ./example.json.")
return ml_model_meta

Expand Down
15 changes: 6 additions & 9 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,15 @@


@pytest.fixture
def metadata_json():
def mlmodel_metadata_item():
from stac_model.examples import eurosat_resnet

model_metadata_stac_item = eurosat_resnet()
return model_metadata_stac_item

def test_model_metadata_to_dict(mlmodel_metadata_item):
assert mlmodel_metadata_item.item.to_dict()

def test_model_metadata_to_dict(metadata_json):
assert metadata_json.to_dict()


def test_model_metadata_json_operations(metadata_json):
from stac_model.schema import MLModelExtension

assert MLModelExtension(metadata_json.to_dict())
def test_validate_model_metadata(mlmodel_metadata_item):
import pystac
assert pystac.read_dict(mlmodel_metadata_item.item.to_dict())

0 comments on commit 30269d4

Please sign in to comment.