From 51e702a970a001315cab1f0c02c229ac237260e0 Mon Sep 17 00:00:00 2001
From: Daofeng Wu
Date: Wed, 16 Oct 2024 14:12:31 +0900
Subject: [PATCH] feat(scraper): support streaming

---
 npiai/tools/web/scraper/__test__/bardeen.py   |  10 +-
 .../web/scraper/__test__/column_inference.py  |   6 +-
 npiai/tools/web/scraper/app.py                | 135 ++++++++++++-------
 3 files changed, 99 insertions(+), 52 deletions(-)

diff --git a/npiai/tools/web/scraper/__test__/bardeen.py b/npiai/tools/web/scraper/__test__/bardeen.py
index 2bfdd879..a28d5c73 100644
--- a/npiai/tools/web/scraper/__test__/bardeen.py
+++ b/npiai/tools/web/scraper/__test__/bardeen.py
@@ -1,16 +1,18 @@
 import asyncio
+import json
+
 from npiai.tools.web.scraper import Scraper
 from npiai.utils.test_utils import DebugContext
 
 
 async def main():
     async with Scraper(headless=False, batch_size=10) as scraper:
-        await scraper.summarize(
+        stream = scraper.summarize_stream(
             ctx=DebugContext(),
             url="https://www.bardeen.ai/playbooks",
             ancestor_selector=".playbook_list",
             items_selector=".playbook_list .playbook_item-link",
-            output_file=".cache/bardeen.csv",
+            limit=42,
             output_columns=[
                 {
                     "name": "Apps Involved",
@@ -29,9 +31,11 @@ async def main():
                     "description": "The URL of the playbook",
                 },
             ],
-            limit=42,
         )
 
+        async for items in stream:
+            print("Chunk:", json.dumps(items, indent=2))
+
 
 if __name__ == "__main__":
     asyncio.run(main())
diff --git a/npiai/tools/web/scraper/__test__/column_inference.py b/npiai/tools/web/scraper/__test__/column_inference.py
index 161a566f..11e9626d 100644
--- a/npiai/tools/web/scraper/__test__/column_inference.py
+++ b/npiai/tools/web/scraper/__test__/column_inference.py
@@ -20,16 +20,18 @@
 
         print("Inferred columns:", json.dumps(columns, indent=2))
 
-        await scraper.summarize(
+        stream = scraper.summarize_stream(
             ctx=DebugContext(),
             url=url,
             ancestor_selector=ancestor_selector,
             items_selector=items_selector,
             output_columns=columns,
-            output_file=".cache/bardeen.csv",
             limit=10,
         )
 
+        async for items in stream:
+            print("Chunk:", json.dumps(items, indent=2))
+
 
 if __name__ == "__main__":
     asyncio.run(main())
diff --git a/npiai/tools/web/scraper/app.py b/npiai/tools/web/scraper/app.py
index f180254c..6518e89e 100644
--- a/npiai/tools/web/scraper/app.py
+++ b/npiai/tools/web/scraper/app.py
@@ -81,8 +81,7 @@ def from_context(cls, ctx: Context) -> "Scraper":
         )
         return cls()
 
-    @function
-    async def summarize(
+    async def summarize_stream(
         self,
         ctx: Context,
         url: str,
@@ -90,11 +89,10 @@ async def summarize(
         ancestor_selector: str | None = None,
         items_selector: str | None = None,
         pagination_button_selector: str | None = None,
-        output_file: str | None = None,
         limit: int = -1,
-    ) -> str:
+    ):
         """
-        Summarize the content of a webpage into a csv table.
+        Summarize the content of a webpage and yield the results as a stream of item batches.
 
         Args:
             ctx: NPi context.
@@ -103,71 +101,114 @@ async def summarize(
             ancestor_selector: The selector of the ancestor element containing the items to summarize. If None, the 'body' element is used.
             items_selector: The selector of the items to summarize. If None, all the children of the ancestor element are used.
             pagination_button_selector: The selector of the pagination button (e.g., the "Next Page" button) to load more items. Used when the items are paginated. By default, the tool will scroll to load more items.
-            output_file: The file path to save the output. If None, the output is saved to 'scraper_output.json'.
             limit: The maximum number of items to summarize. If -1, all items are summarized.
+
+        Returns:
+            An async stream of item batches. Each batch is a list of dictionaries whose keys are the column names and whose values are the corresponding column values.
         """
         if limit == 0:
-            return "No items to summarize"
+            return
 
         await self.playwright.page.goto(url)
 
         if not ancestor_selector:
             ancestor_selector = "body"
 
-        if not output_file:
-            output_file = "scraper_output.csv"
+        count = 0
 
-        os.makedirs(os.path.dirname(output_file), exist_ok=True)
+        while True:
+            remaining = min(self._batch_size, limit - count) if limit != -1 else -1
 
-        with open(output_file, "w", newline="") as f:
-            column_names = [column["name"] for column in output_columns]
-            writer = csv.DictWriter(f, fieldnames=column_names)
-            writer.writeheader()
-            f.flush()
+            md = await self._get_md(
+                ctx=ctx,
+                ancestor_selector=ancestor_selector,
+                items_selector=items_selector,
+                limit=remaining,
+            )
 
-            count = 0
+            if not md:
+                break
 
-            while True:
-                remaining = min(self._batch_size, limit - count) if limit != -1 else -1
+            items = await self._llm_summarize(ctx, md, output_columns)
 
-                md = await self._get_md(
-                    ctx=ctx,
-                    ancestor_selector=ancestor_selector,
-                    items_selector=items_selector,
-                    limit=remaining,
-                )
+            await ctx.send_debug_message(f"[{self.name}] Summarized {len(items)} items")
 
-                if not md:
-                    break
+            if not items:
+                break
 
-                items = await self._llm_summarize(ctx, md, output_columns)
+            items_slice = items[:remaining] if limit != -1 else items
+            count += len(items_slice)
 
-                await ctx.send_debug_message(
-                    f"[{self.name}] Summarized {len(items)} items"
-                )
+            yield items_slice
 
-                if not items:
-                    break
+            await ctx.send_debug_message(
+                f"[{self.name}] Summarized {count} items in total"
+            )
 
-                items_slice = items[:remaining] if limit != -1 else items
-                writer.writerows(items_slice)
-                f.flush()
+            if limit != -1 and count >= limit:
+                break
 
-                count += len(items_slice)
+            await self._load_more(
+                ctx,
+                ancestor_selector,
+                items_selector,
+                pagination_button_selector,
+            )
 
-                await ctx.send_debug_message(
-                    f"[{self.name}] Summarized {count} items in total"
-                )
+    @function
+    async def summarize(
+        self,
+        ctx: Context,
+        url: str,
+        output_columns: List[Column],
+        ancestor_selector: str | None = None,
+        items_selector: str | None = None,
+        pagination_button_selector: str | None = None,
+        output_file: str | None = None,
+        limit: int = -1,
+    ) -> str:
+        """
+        Summarize the content of a webpage into a csv table.
+
+        Args:
+            ctx: NPi context.
+            url: The URL to open.
+            output_columns: The columns of the output table. If not provided, use the `infer_columns` function to infer the columns.
+            ancestor_selector: The selector of the ancestor element containing the items to summarize. If None, the 'body' element is used.
+            items_selector: The selector of the items to summarize. If None, all the children of the ancestor element are used.
+            pagination_button_selector: The selector of the pagination button (e.g., the "Next Page" button) to load more items. Used when the items are paginated. By default, the tool will scroll to load more items.
+            output_file: The file path to save the output. If None, the output is saved to 'scraper_output.csv'.
+            limit: The maximum number of items to summarize. If -1, all items are summarized.
+ """ + os.makedirs(os.path.dirname(output_file), exist_ok=True) - if limit != -1 and count >= limit: - break + with open(output_file, "w", newline="") as f: + column_names = [column["name"] for column in output_columns] + writer = csv.DictWriter(f, fieldnames=column_names) + writer.writeheader() + f.flush() - await self._load_more( - ctx, - ancestor_selector, - items_selector, - pagination_button_selector, - ) + count = 0 + + stream = self.summarize_stream( + ctx=ctx, + url=url, + output_columns=output_columns, + ancestor_selector=ancestor_selector, + items_selector=items_selector, + pagination_button_selector=pagination_button_selector, + limit=limit, + ) + + async for items in stream: + writer.writerows(items) + count += len(items) + f.flush() return f"Saved {count} items to {output_file}"