From a9a956c0449fad2f9b5952a577efe86299569b84 Mon Sep 17 00:00:00 2001 From: Daofeng Wu Date: Mon, 3 Feb 2025 19:45:07 +0900 Subject: [PATCH] refactor(scrapers): make init_data() an abstract method --- npiai/tools/scrapers/base.py | 6 +++--- npiai/tools/scrapers/instagram/comments_scraper.py | 4 ++++ npiai/tools/scrapers/instagram/media_scraper.py | 5 ++++- npiai/tools/scrapers/youtube/comments_scraper.py | 7 ++++++- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/npiai/tools/scrapers/base.py b/npiai/tools/scrapers/base.py index e689959f..d004ce34 100644 --- a/npiai/tools/scrapers/base.py +++ b/npiai/tools/scrapers/base.py @@ -45,10 +45,10 @@ class BaseScraper(FunctionTool, ABC): infer_prompt: str = DEFAULT_COLUMN_INFERENCE_PROMPT @abstractmethod - async def next_items(self, ctx: Context, count: int) -> List[SourceItem] | None: ... + async def init_data(self, ctx: Context): ... - async def init_data(self, ctx: Context): - pass + @abstractmethod + async def next_items(self, ctx: Context, count: int) -> List[SourceItem] | None: ... async def summarize_stream( self, diff --git a/npiai/tools/scrapers/instagram/comments_scraper.py b/npiai/tools/scrapers/instagram/comments_scraper.py index 5d235a2b..b50bb99d 100644 --- a/npiai/tools/scrapers/instagram/comments_scraper.py +++ b/npiai/tools/scrapers/instagram/comments_scraper.py @@ -33,6 +33,10 @@ def __init__( media_pk = client.media_pk_from_url(url) self._media_id = client.media_id(media_pk) + async def init_data(self, ctx: Context): + self._pagination_code = None + self._remaining_comments = [] + async def next_items(self, ctx: Context, count: int) -> List[SourceItem] | None: all_comments = self._fetch_more_comments(count) diff --git a/npiai/tools/scrapers/instagram/media_scraper.py b/npiai/tools/scrapers/instagram/media_scraper.py index 713990c8..dee3518c 100644 --- a/npiai/tools/scrapers/instagram/media_scraper.py +++ b/npiai/tools/scrapers/instagram/media_scraper.py @@ -38,6 +38,9 @@ def __init__( # see: https://github.com/subzeroid/instagrapi/issues/1802#issuecomment-1944589499 self._user_id = client.user_info_by_username_v1(username).pk + async def init_data(self, ctx: Context): + self._pagination_code = None + async def next_items(self, ctx: Context, count: int) -> List[SourceItem] | None: async with self._load_media_lock: all_media, pagination_code = self._client.user_medias_paginated( @@ -85,7 +88,7 @@ def _parse_media(self, media: Media) -> SourceItem: data=res, ) - def _get_media_type(self, media: Media) -> str: + def _get_media_type(self, media: Media) -> str | None: if media.media_type == 1: return "photo" diff --git a/npiai/tools/scrapers/youtube/comments_scraper.py b/npiai/tools/scrapers/youtube/comments_scraper.py index 85aaa601..967d505b 100644 --- a/npiai/tools/scrapers/youtube/comments_scraper.py +++ b/npiai/tools/scrapers/youtube/comments_scraper.py @@ -11,14 +11,19 @@ class YouTubeCommentsScraper(BaseScraper): description = "Scrape comments from a YouTube video" system_prompt = "You are a YouTube comments scraper tasked with extracting comments from the given video." + _url: str _downloader: YoutubeCommentDownloader _comments_generator: Generator[dict, None, None] def __init__(self, url: str): super().__init__() + self._url = url self._downloader = YoutubeCommentDownloader() + + async def init_data(self, ctx: Context): self._comments_generator = self._downloader.get_comments_from_url( - url, language="en" + youtube_url=self._url, + language="en", ) async def next_items(self, ctx: Context, count: int) -> List[SourceItem] | None: