Skip to content

Commit

Permalink
feat: 🎸 add /is-valid endpoint (#134)
Browse files Browse the repository at this point in the history
fixes #129
  • Loading branch information
severo authored Feb 3, 2022
1 parent 48ac19e commit ec64b3b
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 2 deletions.
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,28 @@ Responses:
}
```

### /is-valid

> Tells if a dataset is valid. A dataset is considered valid if `/splits` and `/rows` for all the splits return a valid response. Note that stalled cache entries are considered valid.
Example: https://datasets-preview.huggingface.tech/is-valid?dataset=glue

Method: `GET`

Parameters:

- `dataset` (required): the dataset ID

Responses:

- `200`: JSON content which tells if the dataset is valid or not

```json
{
"valid": true
}
```

### /queue

> Give statistics about the content of the queue
Expand Down
6 changes: 5 additions & 1 deletion src/datasets_preview_backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from datasets_preview_backend.routes.refresh_split import refresh_split_endpoint
from datasets_preview_backend.routes.rows import rows_endpoint
from datasets_preview_backend.routes.splits import splits_endpoint
from datasets_preview_backend.routes.valid import valid_datasets_endpoint
from datasets_preview_backend.routes.valid import (
is_valid_endpoint,
valid_datasets_endpoint,
)
from datasets_preview_backend.routes.webhook import webhook_endpoint


Expand All @@ -41,6 +44,7 @@ def create_app() -> Starlette:
Route("/cache-reports", endpoint=cache_reports_endpoint),
Route("/healthcheck", endpoint=healthcheck_endpoint),
Route("/hf_datasets", endpoint=hf_datasets_endpoint),
Route("/is-valid", endpoint=is_valid_endpoint),
Route("/queue", endpoint=queue_stats_endpoint),
Route("/queue-dump", endpoint=queue_dump_endpoint),
Route("/refresh-split", endpoint=refresh_split_endpoint, methods=["POST"]),
Expand Down
9 changes: 9 additions & 0 deletions src/datasets_preview_backend/io/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,15 @@ def is_dataset_valid_or_stalled(dataset: DbDataset) -> bool:
return all(split.status in [Status.VALID, Status.STALLED] for split in splits)


def is_dataset_name_valid_or_stalled(dataset_name: str) -> bool:
try:
dataset = DbDataset.objects(dataset_name=dataset_name).get()
return is_dataset_valid_or_stalled(dataset)
except DoesNotExist:
return False
# ^ can also raise MultipleObjectsReturned, which should not occur -> we let the exception raise


def get_valid_or_stalled_dataset_names() -> List[str]:
return [d.dataset_name for d in DbDataset.objects() if is_dataset_valid_or_stalled(d)]

Expand Down
20 changes: 19 additions & 1 deletion src/datasets_preview_backend/routes/valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from starlette.responses import Response

from datasets_preview_backend.config import MAX_AGE_LONG_SECONDS
from datasets_preview_backend.io.cache import get_valid_or_stalled_dataset_names
from datasets_preview_backend.exceptions import Status400Error, StatusError
from datasets_preview_backend.io.cache import (
get_valid_or_stalled_dataset_names,
is_dataset_name_valid_or_stalled,
)
from datasets_preview_backend.routes._utils import get_response

logger = logging.getLogger(__name__)
Expand All @@ -18,3 +22,17 @@ async def valid_datasets_endpoint(_: Request) -> Response:
"created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
}
return get_response(content, 200, MAX_AGE_LONG_SECONDS)


async def is_valid_endpoint(request: Request) -> Response:
dataset_name = request.query_params.get("dataset")
logger.info(f"/is-valid, dataset={dataset_name}")
try:
if not isinstance(dataset_name, str):
raise Status400Error("Parameter 'dataset' is required")
content = {
"valid": is_dataset_name_valid_or_stalled(dataset_name),
}
return get_response(content, 200, MAX_AGE_LONG_SECONDS)
except StatusError as err:
return get_response(err.as_content(), err.status_code, MAX_AGE_LONG_SECONDS)
21 changes: 21 additions & 0 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,27 @@ def test_get_valid_datasets(client: TestClient) -> None:
assert "valid" in json


def test_get_is_valid(client: TestClient) -> None:
response = client.get("/is-valid")
assert response.status_code == 400

dataset = "acronym_identification"
split_full_names = refresh_dataset_split_full_names(dataset)
for split_full_name in split_full_names:
refresh_split(split_full_name["dataset_name"], split_full_name["config_name"], split_full_name["split_name"])
response = client.get("/is-valid", params={"dataset": "acronym_identification"})
assert response.status_code == 200
json = response.json()
assert "valid" in json
assert json["valid"] is True

response = client.get("/is-valid", params={"dataset": "doesnotexist"})
assert response.status_code == 200
json = response.json()
assert "valid" in json
assert json["valid"] is False


def test_get_healthcheck(client: TestClient) -> None:
response = client.get("/healthcheck")
assert response.status_code == 200
Expand Down

0 comments on commit ec64b3b

Please sign in to comment.