Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🍌 Add more schema information to fetch models #8

Merged
merged 5 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 92 additions & 66 deletions bananalyzer/data/fetch_schemas.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,94 @@
from typing import Dict, Type

from pydantic import BaseModel, Field

"""
Mapping of fetch_id to fetch schema to avoid duplicate schemas in examples.json
This file contains mapping of fetch_id to fetch schema to avoid duplicate schemas in examples.json
"""
fetch_goals = {
"job_posting": "Return the provided information about the job posting. For salaries, provide the range as ${lower} - ${upper} if available, otherwise just provide ${salary}",
}

fetch_schemas = {
"contact": {
"name": "string",
"website": "string",
"phone": "string",
"fax": "string",
"address": "string",
"type": "string", # What kind of location / person is it? May not be available
},
"job_posting": {
"job_id": "string",
"job_title": "string",
"job_category": "string",
"date_posted": "string",
"location": "string",
"job_description": "string",
"roles_and_responsibilities": "string",
"qualifications": "string",
"preferred_qualifications": "string",
"benefits": "string",
"salary": "string",
},
"manufacturing_commerce": {
"mpn": "string",
"alias_mpns": ["string"],
"manufacturer": "string",
"classifications": ["string"],
"description": "string",
"hero_image": "string",
"series": "string",
"lifecycle_status": "string",
"country_of_origin": "string",
"aecq_status": "string",
"reach_status": "string",
"rohs_status": "string",
"export_control_class_number": "string",
"packaging": "string",
"power_rating": "string",
"voltage_rating": "string",
"mount_type": "string",
"moisture_sensitivity_level": "string",
"tolerance": "string",
"inductance": "string",
"capacitance": "string",
"resistance": "string",
"min_operating_temperature": "string",
"max_operating_temperature": "string",
"leadfree": "string",
"termination_type": "string",
"num_terminations": "int",
"specs": [{"label": "string", "value": "string"}],
"product_change_notification_documents": [
{"url": "string", "filename": "string"}
],
"reach_compliance_documents": [{"url": "string", "filename": "string"}],
"rohs_compliance_documents": [{"url": "string", "filename": "string"}],
"datasheets": [{"url": "string", "filename": "string"}],
"specsheets": [{"url": "string", "filename": "string"}],
"suggested_alternative_mpns": ["string"],
},
}


class ContactSchema(BaseModel):
name: str
website: str = Field(
description="An external link to the website if the website provides a link"
)
phone: str
fax: str = Field(description="Fax number of the location")
address: str
type: str = Field(
description="The type of clinic the location: Hospital, Clinic, etc."
)


class JobPostingSchema(BaseModel):
job_id: str
job_title: str
job_category: str
date_posted: str
location: str
job_description: str
roles_and_responsibilities: str
qualifications: str
preferred_qualifications: str
benefits: str
salary: str


class Specification(BaseModel):
label: str
value: str


class Document(BaseModel):
url: str
filename: str


class ManufacturingCommerceSchema(BaseModel):
mpn: str
alias_mpns: list[str] = Field(description="Other MPNs that this part is known by")
manufacturer: str
classifications: list[str]
description: str
hero_image: str
series: str
lifecycle_status: str
country_of_origin: str
aecq_status: str
reach_status: str
rohs_status: str
export_control_class_number: str
packaging: str
power_rating: str
voltage_rating: str
mount_type: str
moisture_sensitivity_level: str
tolerance: str
inductance: str
capacitance: str
resistance: str
min_operating_temperature: str
max_operating_temperature: str
leadfree: str
termination_type: str
num_terminations: int
specs: list[Specification]
product_change_notification_documents: list[Document]
reach_compliance_documents: list[Document]
rohs_compliance_documents: list[Document]
datasheets: list[Document]
specsheets: list[Document]
suggested_alternative_mpns: list[str]


def get_fetch_schema(fetch_id: str) -> Type[BaseModel]:
fetch_schemas: Dict[str, Type[BaseModel]] = {
"contact": ContactSchema,
"job_posting": JobPostingSchema,
"manufacturing_commerce": ManufacturingCommerceSchema,
}

if fetch_id not in fetch_schemas:
raise ValueError(f"Invalid fetch_id: {fetch_id}")

return fetch_schemas[fetch_id]
4 changes: 2 additions & 2 deletions bananalyzer/data/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from playwright.async_api import Page
from pydantic import BaseModel, Field, model_validator

from bananalyzer.data.fetch_schemas import fetch_schemas
from bananalyzer.data.fetch_schemas import get_fetch_schema
from bananalyzer.runner.evals import (
validate_end_url_match,
validate_field_match,
Expand Down Expand Up @@ -96,5 +96,5 @@ def set_goal_if_fetch_id_provided(cls, values: Dict[str, Any]) -> Dict[str, Any]
if goal is not None:
raise ValueError("goal must not be provided if fetch_id is provided")

values["goal"] = fetch_schemas[fetch_id]
values["goal"] = get_fetch_schema(fetch_id).model_json_schema()
return values
7 changes: 4 additions & 3 deletions tests/test_example_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import ValidationError
from pytest_mock import MockFixture

from bananalyzer.data.fetch_schemas import fetch_schemas
from bananalyzer.data.fetch_schemas import get_fetch_schema
from bananalyzer.data.schemas import Eval, Example
from bananalyzer.runner.evals import format_new_lines

Expand Down Expand Up @@ -147,6 +147,7 @@ def test_fetch_with_fetch_id_and_goal_should_raise_validation_error() -> None:


def test_fetch_with_fetch_id_and_no_goal_sets_default_goal() -> None:
example_data = create_default_example({"fetch_id": "job_posting", "goal": None})
example_data = create_default_example({"fetch_id": "contact", "goal": None})
example = Example(**example_data)
assert example.goal == fetch_schemas["job_posting"]
print(get_fetch_schema("contact").model_json_schema())
assert example.goal == get_fetch_schema("contact").model_json_schema()
Loading