Skip to content

Commit

Permalink
Merge pull request #8 from reworkd/schema_description
Browse files Browse the repository at this point in the history
🍌 Add more schema information to fetch models
  • Loading branch information
KhoomeiK authored Nov 28, 2023
2 parents cecf3cf + 914a53b commit f472675
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 71 deletions.
158 changes: 92 additions & 66 deletions bananalyzer/data/fetch_schemas.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,94 @@
from typing import Dict, Type

from pydantic import BaseModel, Field

"""
Mapping of fetch_id to fetch schema to avoid duplicate schemas in examples.json
This file contains a mapping of fetch_id to fetch schema to avoid duplicate schemas in examples.json
"""
# Natural-language goal text keyed by fetch_id; only "job_posting" has an entry.
# The value is a plain (non-f) string, so the ${lower} / ${upper} / ${salary}
# markers are literal text and are not interpolated by Python.
fetch_goals = {
    "job_posting": "Return the provided information about the job posting. For salaries, provide the range as ${lower} - ${upper} if available, otherwise just provide ${salary}",
}

# Dict-based description of the data to extract, keyed by fetch_id
# ("contact", "job_posting", "manufacturing_commerce").
# Leaf values are type names spelled as strings ("string", "int").
# A one-element list such as ["string"] or [{"url": ..., "filename": ...}]
# presumably denotes "a list of items with this shape" — confirm with consumers.
fetch_schemas = {
    # Contact details for a location or person.
    "contact": {
        "name": "string",
        "website": "string",
        "phone": "string",
        "fax": "string",
        "address": "string",
        "type": "string",  # What kind of location / person is it? May not be available
    },
    # A single job listing; every field is a string, including the salary.
    "job_posting": {
        "job_id": "string",
        "job_title": "string",
        "job_category": "string",
        "date_posted": "string",
        "location": "string",
        "job_description": "string",
        "roles_and_responsibilities": "string",
        "qualifications": "string",
        "preferred_qualifications": "string",
        "benefits": "string",
        "salary": "string",
    },
    # A manufactured part: identifiers, compliance statuses, electrical and
    # physical ratings, and lists of related documents.
    "manufacturing_commerce": {
        "mpn": "string",
        "alias_mpns": ["string"],
        "manufacturer": "string",
        "classifications": ["string"],
        "description": "string",
        "hero_image": "string",
        "series": "string",
        "lifecycle_status": "string",
        "country_of_origin": "string",
        "aecq_status": "string",
        "reach_status": "string",
        "rohs_status": "string",
        "export_control_class_number": "string",
        "packaging": "string",
        "power_rating": "string",
        "voltage_rating": "string",
        "mount_type": "string",
        "moisture_sensitivity_level": "string",
        "tolerance": "string",
        "inductance": "string",
        "capacitance": "string",
        "resistance": "string",
        "min_operating_temperature": "string",
        "max_operating_temperature": "string",
        "leadfree": "string",
        "termination_type": "string",
        # The only non-string scalar in this schema.
        "num_terminations": "int",
        # Free-form label/value pairs.
        "specs": [{"label": "string", "value": "string"}],
        # Every document list below shares the same {"url", "filename"} item shape.
        "product_change_notification_documents": [
            {"url": "string", "filename": "string"}
        ],
        "reach_compliance_documents": [{"url": "string", "filename": "string"}],
        "rohs_compliance_documents": [{"url": "string", "filename": "string"}],
        "datasheets": [{"url": "string", "filename": "string"}],
        "specsheets": [{"url": "string", "filename": "string"}],
        "suggested_alternative_mpns": ["string"],
    },
}


# Pydantic model describing the contact details to extract for a location;
# get_fetch_schema returns this class for the "contact" fetch_id.
# Documented with comments rather than a class docstring on purpose: pydantic
# copies a model's docstring into its generated JSON schema as "description",
# which would change what model_json_schema() returns.
# Every field is declared without a default, so all of them are required.
class ContactSchema(BaseModel):
    name: str
    # Field(description=...) text is carried into the generated JSON schema.
    website: str = Field(
        description="An external link to the website if the website provides a link"
    )
    phone: str
    fax: str = Field(description="Fax number of the location")
    address: str
    # NOTE(review): the description below reads as if a word is missing
    # ("The type of clinic the location") — confirm the intended wording
    # before changing it, since this text is part of the emitted schema.
    type: str = Field(
        description="The type of clinic the location: Hospital, Clinic, etc."
    )


# Pydantic model describing the fields to extract from a job posting;
# get_fetch_schema returns this class for the "job_posting" fetch_id.
# All fields are plain strings without defaults (i.e. required), and none
# carries a Field description.
# Comments are used instead of a class docstring because pydantic would emit
# the docstring as the "description" of the generated JSON schema.
class JobPostingSchema(BaseModel):
    job_id: str
    job_title: str
    job_category: str
    date_posted: str
    location: str
    job_description: str
    roles_and_responsibilities: str
    qualifications: str
    preferred_qualifications: str
    benefits: str
    # Salary is kept as free text rather than a numeric type.
    salary: str


# A single label/value pair; used as the element type of
# ManufacturingCommerceSchema.specs.
class Specification(BaseModel):
    label: str
    value: str


# A reference to a file by URL and filename; used as the element type of the
# document list fields on ManufacturingCommerceSchema (datasheets, specsheets,
# compliance and change-notification documents).
class Document(BaseModel):
    url: str
    filename: str


# Pydantic model describing the fields to extract for a manufactured part;
# get_fetch_schema returns this class for the "manufacturing_commerce" fetch_id.
# No field has a default, so every field is required.
# Comments are used instead of a class docstring because pydantic would emit
# the docstring as the "description" of the generated JSON schema.
class ManufacturingCommerceSchema(BaseModel):
    # Part identity. "mpn" presumably stands for manufacturer part number — TODO confirm.
    mpn: str
    alias_mpns: list[str] = Field(description="Other MPNs that this part is known by")
    manufacturer: str
    classifications: list[str]
    description: str
    hero_image: str
    series: str
    # Lifecycle, origin and compliance statuses, all free-text strings.
    lifecycle_status: str
    country_of_origin: str
    aecq_status: str
    reach_status: str
    rohs_status: str
    export_control_class_number: str
    # Packaging plus electrical / physical characteristics, kept as strings.
    packaging: str
    power_rating: str
    voltage_rating: str
    mount_type: str
    moisture_sensitivity_level: str
    tolerance: str
    inductance: str
    capacitance: str
    resistance: str
    min_operating_temperature: str
    max_operating_temperature: str
    leadfree: str
    termination_type: str
    # The only integer field in this model.
    num_terminations: int
    # Nested models: label/value specifications and url/filename documents.
    specs: list[Specification]
    product_change_notification_documents: list[Document]
    reach_compliance_documents: list[Document]
    rohs_compliance_documents: list[Document]
    datasheets: list[Document]
    specsheets: list[Document]
    suggested_alternative_mpns: list[str]


def get_fetch_schema(fetch_id: str) -> Type[BaseModel]:
    """Return the pydantic model class registered for ``fetch_id``.

    Args:
        fetch_id: Identifier of the schema to look up. Known values are
            "contact", "job_posting" and "manufacturing_commerce".

    Returns:
        The model class itself (not an instance).

    Raises:
        ValueError: If ``fetch_id`` is not one of the known identifiers.
    """
    registry: Dict[str, Type[BaseModel]] = {
        "contact": ContactSchema,
        "job_posting": JobPostingSchema,
        "manufacturing_commerce": ManufacturingCommerceSchema,
    }

    # Registered values are classes and never None, so None reliably
    # signals an unknown identifier.
    schema = registry.get(fetch_id)
    if schema is None:
        raise ValueError(f"Invalid fetch_id: {fetch_id}")
    return schema
4 changes: 2 additions & 2 deletions bananalyzer/data/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from playwright.async_api import Page
from pydantic import BaseModel, Field, model_validator

from bananalyzer.data.fetch_schemas import fetch_schemas
from bananalyzer.data.fetch_schemas import get_fetch_schema
from bananalyzer.runner.evals import (
validate_end_url_match,
validate_field_match,
Expand Down Expand Up @@ -96,5 +96,5 @@ def set_goal_if_fetch_id_provided(cls, values: Dict[str, Any]) -> Dict[str, Any]
if goal is not None:
raise ValueError("goal must not be provided if fetch_id is provided")

values["goal"] = fetch_schemas[fetch_id]
values["goal"] = get_fetch_schema(fetch_id).model_json_schema()
return values
7 changes: 4 additions & 3 deletions tests/test_example_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import ValidationError
from pytest_mock import MockFixture

from bananalyzer.data.fetch_schemas import fetch_schemas
from bananalyzer.data.fetch_schemas import get_fetch_schema
from bananalyzer.data.schemas import Eval, Example
from bananalyzer.runner.evals import format_new_lines

Expand Down Expand Up @@ -147,6 +147,7 @@ def test_fetch_with_fetch_id_and_goal_should_raise_validation_error() -> None:


def test_fetch_with_fetch_id_and_no_goal_sets_default_goal() -> None:
example_data = create_default_example({"fetch_id": "job_posting", "goal": None})
example_data = create_default_example({"fetch_id": "contact", "goal": None})
example = Example(**example_data)
assert example.goal == fetch_schemas["job_posting"]
print(get_fetch_schema("contact").model_json_schema())
assert example.goal == get_fetch_schema("contact").model_json_schema()

0 comments on commit f472675

Please sign in to comment.