Skip to content

Commit

Permalink
fix(scrape/infer_column): force structured output
Browse files Browse the repository at this point in the history
  • Loading branch information
idiotWu committed Oct 16, 2024
1 parent dac9ea5 commit 9759093
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 23 deletions.
9 changes: 9 additions & 0 deletions npiai/tools/web/scraper/__test__/column_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ async def main():

print("Inferred columns:", columns)

await scraper.summarize(
ctx=DebugContext(),
url="https://www.bardeen.ai/playbooks",
ancestor_selector=".playbook_list",
items_selector=".playbook_list .playbook_item-link",
output_columns=columns,
limit=10,
)


if __name__ == "__main__":
asyncio.run(main())
51 changes: 28 additions & 23 deletions npiai/tools/web/scraper/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from npiai import function, BrowserTool, Context
from npiai.core import NavigatorAgent
from npiai.utils import is_cloud_env, parse_json_response
from npiai.utils import is_cloud_env, llm_tool_call


class NonBase64ImageConverter(MarkdownConverter):
Expand Down Expand Up @@ -180,32 +180,37 @@ async def infer_columns(
limit=10,
)

messages = [
ChatCompletionSystemMessageParam(
role="system",
content=dedent(
"""
Imagine you are summarizing the content of a webpage into a table. Find the common nature of the provided items and suggest the columns for the output table. Respond with the columns in a list format: ['column1', 'column2', ...]
"""
),
),
ChatCompletionUserMessageParam(
role="user",
content=md,
),
]
def callback(columns: List[str]):
"""
Callback with the inferred columns.
response = await ctx.llm.completion(
messages=messages,
max_tokens=4096,
Args:
columns: The inferred columns.
"""
return columns

res = await llm_tool_call(
llm=ctx.llm,
tool=callback,
messages=[
ChatCompletionSystemMessageParam(
role="system",
content=dedent(
"""
Imagine you are summarizing the content of a webpage into a table. Find the common nature of the provided items and suggest the columns for the output table. Respond with the columns in a list format: ['column1', 'column2', ...]
"""
),
),
ChatCompletionUserMessageParam(
role="user",
content=md,
),
],
)
content = response.choices[0].message.content

await ctx.send_debug_message(
f"[{self.name}] Columns inference response: {content}"
)
await ctx.send_debug_message(f"[{self.name}] Columns inference response: {res}")

return parse_json_response(content)
return callback(**res.model_dump())

async def _get_md(
self,
Expand Down

0 comments on commit 9759093

Please sign in to comment.