Skip to content

Commit

Permalink
refactor(scrapers): make init_data() an abstract method
Browse files Browse the repository at this point in the history
  • Loading branch information
idiotWu committed Feb 3, 2025
1 parent 65643b6 commit a9a956c
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 5 deletions.
6 changes: 3 additions & 3 deletions npiai/tools/scrapers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ class BaseScraper(FunctionTool, ABC):
infer_prompt: str = DEFAULT_COLUMN_INFERENCE_PROMPT

@abstractmethod
async def next_items(self, ctx: Context, count: int) -> List[SourceItem] | None: ...
async def init_data(self, ctx: Context): ...

async def init_data(self, ctx: Context):
pass
@abstractmethod
async def next_items(self, ctx: Context, count: int) -> List[SourceItem] | None: ...

async def summarize_stream(
self,
Expand Down
4 changes: 4 additions & 0 deletions npiai/tools/scrapers/instagram/comments_scraper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ def __init__(
media_pk = client.media_pk_from_url(url)
self._media_id = client.media_id(media_pk)

async def init_data(self, ctx: Context):
self._pagination_code = None
self._remaining_comments = []

async def next_items(self, ctx: Context, count: int) -> List[SourceItem] | None:
all_comments = self._fetch_more_comments(count)

Expand Down
5 changes: 4 additions & 1 deletion npiai/tools/scrapers/instagram/media_scraper.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def __init__(
# see: https://github.com/subzeroid/instagrapi/issues/1802#issuecomment-1944589499
self._user_id = client.user_info_by_username_v1(username).pk

async def init_data(self, ctx: Context):
self._pagination_code = None

async def next_items(self, ctx: Context, count: int) -> List[SourceItem] | None:
async with self._load_media_lock:
all_media, pagination_code = self._client.user_medias_paginated(
Expand Down Expand Up @@ -85,7 +88,7 @@ def _parse_media(self, media: Media) -> SourceItem:
data=res,
)

def _get_media_type(self, media: Media) -> str:
def _get_media_type(self, media: Media) -> str | None:
if media.media_type == 1:
return "photo"

Expand Down
7 changes: 6 additions & 1 deletion npiai/tools/scrapers/youtube/comments_scraper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@ class YouTubeCommentsScraper(BaseScraper):
description = "Scrape comments from a YouTube video"
system_prompt = "You are a YouTube comments scraper tasked with extracting comments from the given video."

_url: str
_downloader: YoutubeCommentDownloader
_comments_generator: Generator[dict, None, None]

def __init__(self, url: str):
super().__init__()
self._url = url
self._downloader = YoutubeCommentDownloader()

async def init_data(self, ctx: Context):
self._comments_generator = self._downloader.get_comments_from_url(
url, language="en"
youtube_url=self._url,
language="en",
)

async def next_items(self, ctx: Context, count: int) -> List[SourceItem] | None:
Expand Down

0 comments on commit a9a956c

Please sign in to comment.