diff --git a/bananalyzer/data/fetch_schemas.py b/bananalyzer/data/fetch_schemas.py index 4ab8bc9b..209bddf3 100644 --- a/bananalyzer/data/fetch_schemas.py +++ b/bananalyzer/data/fetch_schemas.py @@ -1,68 +1,94 @@ +from typing import Dict, Type + +from pydantic import BaseModel, Field + """ -Mapping of fetch_id to fetch schema to avoid duplicate schemas in examples.json +This file contains mapping of fetch_id to fetch schema to avoid duplicate schemas in examples.json """ -fetch_goals = { - "job_posting": "Return the provided information about the job posting. For salaries, provide the range as ${lower} - ${upper} if available, otherwise just provide ${salary}", -} - -fetch_schemas = { - "contact": { - "name": "string", - "website": "string", - "phone": "string", - "fax": "string", - "address": "string", - "type": "string", # What kind of location / person is it? May not be available - }, - "job_posting": { - "job_id": "string", - "job_title": "string", - "job_category": "string", - "date_posted": "string", - "location": "string", - "job_description": "string", - "roles_and_responsibilities": "string", - "qualifications": "string", - "preferred_qualifications": "string", - "benefits": "string", - "salary": "string", - }, - "manufacturing_commerce": { - "mpn": "string", - "alias_mpns": ["string"], - "manufacturer": "string", - "classifications": ["string"], - "description": "string", - "hero_image": "string", - "series": "string", - "lifecycle_status": "string", - "country_of_origin": "string", - "aecq_status": "string", - "reach_status": "string", - "rohs_status": "string", - "export_control_class_number": "string", - "packaging": "string", - "power_rating": "string", - "voltage_rating": "string", - "mount_type": "string", - "moisture_sensitivity_level": "string", - "tolerance": "string", - "inductance": "string", - "capacitance": "string", - "resistance": "string", - "min_operating_temperature": "string", - "max_operating_temperature": "string", - "leadfree": "string", - "termination_type": "string", - "num_terminations": "int", - "specs": [{"label": "string", "value": "string"}], - "product_change_notification_documents": [ - {"url": "string", "filename": "string"} - ], - "reach_compliance_documents": [{"url": "string", "filename": "string"}], - "rohs_compliance_documents": [{"url": "string", "filename": "string"}], - "datasheets": [{"url": "string", "filename": "string"}], - "specsheets": [{"url": "string", "filename": "string"}], - "suggested_alternative_mpns": ["string"], - }, -} + + +class ContactSchema(BaseModel): + name: str + website: str = Field( + description="An external link to the website if the website provides a link" + ) + phone: str + fax: str = Field(description="Fax number of the location") + address: str + type: str = Field( + description="The type of clinic the location: Hospital, Clinic, etc." + ) + + +class JobPostingSchema(BaseModel): + job_id: str + job_title: str + job_category: str + date_posted: str + location: str + job_description: str + roles_and_responsibilities: str + qualifications: str + preferred_qualifications: str + benefits: str + salary: str + + +class Specification(BaseModel): + label: str + value: str + + +class Document(BaseModel): + url: str + filename: str + + +class ManufacturingCommerceSchema(BaseModel): + mpn: str + alias_mpns: list[str] = Field(description="Other MPNs that this part is known by") + manufacturer: str + classifications: list[str] + description: str + hero_image: str + series: str + lifecycle_status: str + country_of_origin: str + aecq_status: str + reach_status: str + rohs_status: str + export_control_class_number: str + packaging: str + power_rating: str + voltage_rating: str + mount_type: str + moisture_sensitivity_level: str + tolerance: str + inductance: str + capacitance: str + resistance: str + min_operating_temperature: str + max_operating_temperature: str + leadfree: str + termination_type: str + num_terminations: int + specs: list[Specification] + product_change_notification_documents: list[Document] + reach_compliance_documents: list[Document] + rohs_compliance_documents: list[Document] + datasheets: list[Document] + specsheets: list[Document] + suggested_alternative_mpns: list[str] + + +def get_fetch_schema(fetch_id: str) -> Type[BaseModel]: + fetch_schemas: Dict[str, Type[BaseModel]] = { + "contact": ContactSchema, + "job_posting": JobPostingSchema, + "manufacturing_commerce": ManufacturingCommerceSchema, + } + + if fetch_id not in fetch_schemas: + raise ValueError(f"Invalid fetch_id: {fetch_id}") + + return fetch_schemas[fetch_id] diff --git a/bananalyzer/data/schemas.py b/bananalyzer/data/schemas.py index fcd5c42c..96456f5d 100644 --- a/bananalyzer/data/schemas.py +++ b/bananalyzer/data/schemas.py @@ -3,7 +3,7 @@ from playwright.async_api import Page from pydantic import BaseModel, Field, model_validator -from bananalyzer.data.fetch_schemas import fetch_schemas +from bananalyzer.data.fetch_schemas import get_fetch_schema from bananalyzer.runner.evals import ( validate_end_url_match, validate_field_match, @@ -96,5 +96,5 @@ def set_goal_if_fetch_id_provided(cls, values: Dict[str, Any]) -> Dict[str, Any] if goal is not None: raise ValueError("goal must not be provided if fetch_id is provided") - values["goal"] = fetch_schemas[fetch_id] + values["goal"] = get_fetch_schema(fetch_id).model_json_schema() return values diff --git a/tests/test_example_eval.py b/tests/test_example_eval.py index ac8be5c2..bcbf11c1 100644 --- a/tests/test_example_eval.py +++ b/tests/test_example_eval.py @@ -5,7 +5,7 @@ from pydantic import ValidationError from pytest_mock import MockFixture -from bananalyzer.data.fetch_schemas import fetch_schemas +from bananalyzer.data.fetch_schemas import get_fetch_schema from bananalyzer.data.schemas import Eval, Example from bananalyzer.runner.evals import format_new_lines @@ -147,6 +147,7 @@ def test_fetch_with_fetch_id_and_goal_should_raise_validation_error() -> None: def test_fetch_with_fetch_id_and_no_goal_sets_default_goal() -> None: - example_data = create_default_example({"fetch_id": "job_posting", "goal": None}) + example_data = create_default_example({"fetch_id": "contact", "goal": None}) example = Example(**example_data) - assert example.goal == fetch_schemas["job_posting"] + print(get_fetch_schema("contact").model_json_schema()) + assert example.goal == get_fetch_schema("contact").model_json_schema()