FIX Remove aria2c dependency from HuggingFace Target (#530)
nina-msft authored Nov 11, 2024
1 parent d8c32d1 commit 7e9a658
Showing 6 changed files with 264 additions and 162 deletions.
129 changes: 90 additions & 39 deletions doc/code/orchestrators/use_huggingface_chat_target.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "d623d73a",
"id": "066bb566",
"metadata": {
"lines_to_next_cell": 2
},
@@ -17,7 +17,7 @@
" - This notebook supports the following **instruct models** that follow a structured chat template. These are examples, and more instruct models are available on Hugging Face:\n",
" - `HuggingFaceTB/SmolLM-360M-Instruct`\n",
" - `microsoft/Phi-3-mini-4k-instruct`\n",
" \n",
"\n",
" - `...`\n",
"\n",
"2. **Excluded Models**:\n",
@@ -37,63 +37,116 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0d61a68",
"metadata": {},
"outputs": [],
"execution_count": 1,
"id": "940f8d8a",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-11T22:41:35.643730Z",
"iopub.status.busy": "2024-11-11T22:41:35.643730Z",
"iopub.status.idle": "2024-11-11T22:43:23.863745Z",
"shell.execute_reply": "2024-11-11T22:43:23.862727Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running model: HuggingFaceTB/SmolLM-135M-Instruct\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average response time for HuggingFaceTB/SmolLM-135M-Instruct: 37.12 seconds\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[22m\u001b[39mConversation ID: 5223e15e-f21c-4d15-88af-8c02d6558182\n",
"\u001b[1m\u001b[34muser: What is 4*4? Give me the solution.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[22m\u001b[33massistant: What a great question!\n",
"\n",
"The number 4*4 is a special number because it can be expressed as a product of two numbers,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[22m\u001b[39mConversation ID: b0238d3e-ce2e-48c3-a5e1-eaebf2c58e6f\n",
"\u001b[1m\u001b[34muser: What is 3*3? Give me the solution.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[22m\u001b[33massistant: What a great question!\n",
"\n",
"The number 3*3 is a fascinating number that has been a subject of fascination for mathematicians and computer scientists for\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"HuggingFaceTB/SmolLM-135M-Instruct: 37.12 seconds\n"
]
}
],
"source": [
"import time\n",
"from pyrit.prompt_target import HuggingFaceChatTarget \n",
"from pyrit.prompt_target import HuggingFaceChatTarget\n",
"from pyrit.orchestrator import PromptSendingOrchestrator\n",
"\n",
"# models to test\n",
"model_id = \"HuggingFaceTB/SmolLM-135M-Instruct\" \n",
"model_id = \"HuggingFaceTB/SmolLM-135M-Instruct\"\n",
"\n",
"# List of prompts to send\n",
"prompt_list = [\n",
" \"What is 3*3? Give me the solution.\",\n",
" \"What is 4*4? Give me the solution.\"\n",
" ]\n",
"prompt_list = [\"What is 3*3? Give me the solution.\", \"What is 4*4? Give me the solution.\"]\n",
"\n",
"# Dictionary to store average response times\n",
"model_times = {}\n",
" \n",
"\n",
"print(f\"Running model: {model_id}\")\n",
" \n",
"\n",
"try:\n",
" # Initialize HuggingFaceChatTarget with the current model\n",
" target = HuggingFaceChatTarget(\n",
" model_id=model_id, \n",
" use_cuda=False, \n",
" tensor_format=\"pt\",\n",
" max_new_tokens=30 \n",
" )\n",
" \n",
" target = HuggingFaceChatTarget(model_id=model_id, use_cuda=False, tensor_format=\"pt\", max_new_tokens=30)\n",
"\n",
" # Initialize the orchestrator\n",
" orchestrator = PromptSendingOrchestrator(\n",
" prompt_target=target,\n",
" verbose=False\n",
" )\n",
" \n",
" orchestrator = PromptSendingOrchestrator(prompt_target=target, verbose=False)\n",
"\n",
" # Record start time\n",
" start_time = time.time()\n",
" \n",
"\n",
" # Send prompts asynchronously\n",
" responses = await orchestrator.send_prompts_async(prompt_list=prompt_list) # type: ignore\n",
" \n",
" responses = await orchestrator.send_prompts_async(prompt_list=prompt_list) # type: ignore\n",
"\n",
" # Record end time\n",
" end_time = time.time()\n",
" \n",
"\n",
" # Calculate total and average response time\n",
" total_time = end_time - start_time\n",
" avg_time = total_time / len(prompt_list)\n",
" model_times[model_id] = avg_time\n",
" \n",
"\n",
" print(f\"Average response time for {model_id}: {avg_time:.2f} seconds\\n\")\n",
" \n",
"\n",
" # Print the conversations\n",
" await orchestrator.print_conversations() # type: ignore\n",
" \n",
" await orchestrator.print_conversations() # type: ignore\n",
"\n",
"except Exception as e:\n",
" print(f\"An error occurred with model {model_id}: {e}\\n\")\n",
" model_times[model_id] = None\n",
@@ -108,14 +161,12 @@
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"main_language": "python",
"notebook_metadata_filter": "-all"
"cell_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "pyrit-dev",
"language": "python",
"name": "python3"
"name": "pyrit-dev"
},
"language_info": {
"codemirror_mode": {
112 changes: 112 additions & 0 deletions pyrit/common/download_hf_model.py
@@ -0,0 +1,112 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import logging
import os
import httpx
from pathlib import Path

from huggingface_hub import HfApi


logger = logging.getLogger(__name__)


def get_available_files(model_id: str, token: str):
"""Fetches available files for a model from the Hugging Face repository."""
api = HfApi()
try:
model_info = api.model_info(model_id, token=token)
available_files = [file.rfilename for file in model_info.siblings]

# Perform simple validation: raise a ValueError if no files are available
if not len(available_files):
raise ValueError(f"No available files found for the model: {model_id}")

return available_files
except Exception as e:
logger.info(f"Error fetching model files for {model_id}: {e}")
return []


async def download_specific_files(model_id: str, file_patterns: list, token: str, cache_dir: Path):
"""
Downloads specific files from a Hugging Face model repository.
If file_patterns is None, downloads all files.
Returns:
List of URLs for the downloaded files.
"""
os.makedirs(cache_dir, exist_ok=True)

available_files = get_available_files(model_id, token)
# If no file patterns are provided, download all available files
if file_patterns is None:
files_to_download = available_files
logger.info(f"Downloading all files for model {model_id}.")
else:
# Filter files based on the patterns provided
files_to_download = [file for file in available_files if any(pattern in file for pattern in file_patterns)]
if not files_to_download:
logger.info(f"No files matched the patterns provided for model {model_id}.")
return

# Generate download URLs directly
base_url = f"https://huggingface.co/{model_id}/resolve/main/"
urls = [base_url + file for file in files_to_download]

# Download the files
await download_files(urls, token, cache_dir)


async def download_chunk(url, headers, start, end, client):
"""Download a chunk of the file with a specified byte range."""
range_header = {"Range": f"bytes={start}-{end}", **headers}
response = await client.get(url, headers=range_header)
response.raise_for_status()
return response.content


async def download_file(url, token, download_dir, num_splits):
"""Download a file in multiple segments (splits) using byte-range requests."""
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient(follow_redirects=True) as client:
# Get the file size to determine chunk size
response = await client.head(url, headers=headers)
response.raise_for_status()
file_size = int(response.headers["Content-Length"])
chunk_size = file_size // num_splits

# Prepare tasks for each chunk
tasks = []
file_name = url.split("/")[-1]
file_path = Path(download_dir, file_name)

for i in range(num_splits):
start = i * chunk_size
end = start + chunk_size - 1 if i < num_splits - 1 else file_size - 1
tasks.append(download_chunk(url, headers, start, end, client))

# Download all chunks concurrently
chunks = await asyncio.gather(*tasks)

# Write chunks to the file in order
with open(file_path, "wb") as f:
for chunk in chunks:
f.write(chunk)
logger.info(f"Downloaded {file_name} to {file_path}")


async def download_files(urls: list[str], token: str, download_dir: Path, num_splits=3, parallel_downloads=4):
"""Download multiple files with parallel downloads and segmented downloading."""

# Limit the number of parallel downloads
semaphore = asyncio.Semaphore(parallel_downloads)

async def download_with_limit(url):
async with semaphore:
await download_file(url, token, download_dir, num_splits)

# Run downloads concurrently, but limit to parallel_downloads at a time
await asyncio.gather(*(download_with_limit(url) for url in urls))
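
A minimal usage sketch for the helpers above (not part of this commit); the environment variable name, cache path, and file patterns below are illustrative assumptions.

# Hypothetical driver for download_specific_files; names marked below are assumptions.
import asyncio
import os
from pathlib import Path

from pyrit.common.download_hf_model import download_specific_files


async def main():
    model_id = "HuggingFaceTB/SmolLM-135M-Instruct"
    token = os.environ.get("HUGGINGFACE_TOKEN", "")  # assumed env var name
    cache_dir = Path.home() / ".cache" / "pyrit" / model_id  # assumed location

    # file_patterns are matched as substrings against the repo's file list,
    # so these patterns pull the weights plus tokenizer/config JSON files.
    await download_specific_files(
        model_id=model_id,
        file_patterns=[".safetensors", ".json"],
        token=token,
        cache_dir=cache_dir,
    )


if __name__ == "__main__":
    asyncio.run(main())

Because the new module runs on asyncio and httpx alone, no external download binary is required, which is the point of removing aria2c.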
94 changes: 0 additions & 94 deletions pyrit/common/download_hf_model_with_aria2.py

This file was deleted.
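
As the commit title and filename suggest, the deleted module delegated segmented downloads to the external aria2c binary; the httpx-based module above reimplements that idea with HTTP Range requests. A quick sanity check of the chunk arithmetic in download_file, with illustrative numbers:

# Illustrative check of the byte-range math in download_file: a 1000-byte
# file split three ways is covered exactly once, with the last split
# absorbing the remainder.
file_size = 1000
num_splits = 3
chunk_size = file_size // num_splits  # 333

ranges = []
for i in range(num_splits):
    start = i * chunk_size
    end = start + chunk_size - 1 if i < num_splits - 1 else file_size - 1
    ranges.append((start, end))

print(ranges)  # [(0, 332), (333, 665), (666, 999)]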

