Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add rerank model support #5

Merged
merged 1 commit into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ license = "MIT"
description = "A unofficial Rust library for the Ernie API"
homepage = "https://github.com/chenwanqq/erniebot-rs"
repository = "https://github.com/chenwanqq/erniebot-rs"
version = "0.1.1"
version = "0.2.0"
edition = "2021"
exclude = [".github/",".vscode/",".gitignore"]

Expand All @@ -17,12 +17,12 @@ strum_macros = "0.26.1"
serde = {version = "1.0.197", features = ["derive"]}
serde_json = "1.0.113"
url = "2.5.0"
reqwest = {version = "0.11.6", features = ["json","blocking"]}
reqwest = {version = "0.12.3", features = ["json","blocking"]}
thiserror = "1.0.57"
json_value_merge = "2.0"
reqwest-eventsource = "0.5.0"
reqwest-eventsource = "0.6.0"
tokio = { version = "1.36.0", features = ["full"] }
tokio-stream = "0.1.14"
base64 = "0.21.7"
image = "0.24.9"
base64 = "0.22.0"
image = "0.25.1"
schemars = "0.8"
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

Unofficial Baidu Ernie(Wenxin Yiyan, Qianfan) Rust SDK, currently supporting three modules: chat, text embedding (embedding), and text-to-image generation (text2image).

**update in 2024/04/09**: Add support for the bce-reranker-base-v1 rerank model

## Installation

Add the following to your Cargo.toml file:
Expand Down
2 changes: 2 additions & 0 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

非官方的百度千帆大模型(文心一言,或者是Ernie,随便啦)SDK, 目前支持对话(chat),文本嵌入(embedding)以及文生图(text2image)三个模块。

**2024/04/09更新**: 添加对bce-reranker-base-v1重排序模型的支持

## 安装

在`Cargo.toml`文件中添加以下内容:
Expand Down
2 changes: 2 additions & 0 deletions examples/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,7 @@ fn test_async_embedding() {

fn main() {
test_embedding();
//sleep to avoid qps
std::thread::sleep(std::time::Duration::from_secs(1));
test_async_embedding();
}
36 changes: 36 additions & 0 deletions examples/rerank.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use erniebot_rs::reranker::{RerankerEndpoint, RerankerModel};
use tokio::runtime::Runtime;

fn test_reranker() {
let reranker = RerankerEndpoint::new(RerankerModel::BceRerankerBaseV1).unwrap();
let query = "你好".to_string();
let documents = vec![
"你好".to_string(),
"你叫什么名字".to_string(),
"你是谁".to_string(),
];
let reranker_response = reranker.invoke(query, documents, None, None).unwrap();
let reranker_results = reranker_response.get_reranker_response().unwrap();
println!("{},{:?}", reranker_results.len(), reranker_results);
}

fn test_async_reranker() {
let reranker = RerankerEndpoint::new(RerankerModel::BceRerankerBaseV1).unwrap();
let query = "你好".to_string();
let documents = vec![
"你好".to_string(),
"你叫什么名字".to_string(),
"你是谁".to_string(),
];
let rt = Runtime::new().unwrap();
let reranker_response = rt
.block_on(reranker.ainvoke(query, documents, None, None))
.unwrap();
let reranker_results = reranker_response.get_reranker_response().unwrap();
println!("{},{:?}", reranker_results.len(), reranker_results);
}

fn main() {
test_reranker();
test_async_reranker();
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod chat;
/// Toolset to interact with embedding model in Qianfan platform
pub mod embedding;
pub mod errors;
pub mod reranker;
/// Toolset to interact with text2image model in Qianfan platform
pub mod text2image;
pub mod utils;
99 changes: 99 additions & 0 deletions src/reranker/endpoint.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use super::model::RerankerModel;
use super::response::RerankerResponse;
use crate::errors::ErnieError;
use crate::utils::{build_url, get_access_token};
use json_value_merge::Merge;
use serde_json::Value;
use url::Url;

static RERANKER_BASE_URL: &str =
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/reranker/";

/** ChatEndpoint is a struct that represents the chat endpoint of erniebot API
*/
pub struct RerankerEndpoint {
url: Url,
access_token: String,
}

impl RerankerEndpoint {
// create a new embedding instance using pre-defined model
pub fn new(model: RerankerModel) -> Result<Self, ErnieError> {
Ok(RerankerEndpoint {
url: build_url(RERANKER_BASE_URL, model.to_string().as_str())?,
access_token: get_access_token()?,
})
}
/// sync invoke
pub fn invoke(
&self,
query: String,
documents: Vec<String>,
top_n: Option<u64>,
user_id: Option<String>,
) -> Result<RerankerResponse, ErnieError> {
let mut body = serde_json::json!({
"query": query,
"documents": documents,
});
if let Some(top_n) = top_n {
body.merge(&serde_json::json!({"top_n": top_n}));
}
if let Some(user_id) = user_id {
body.merge(&serde_json::json!({"user_id": user_id}));
}
let client = reqwest::blocking::Client::new();
let response: Value = client
.post(self.url.as_str())
.query(&[("access_token", self.access_token.as_str())])
.json(&body)
.send()
.map_err(|e| ErnieError::InvokeError(e.to_string()))?
.json()
.map_err(|e| ErnieError::InvokeError(e.to_string()))?;

//if error_code key in response, means RemoteAPIError
if response.get("error_code").is_some() {
return Err(ErnieError::RemoteAPIError(response.to_string()));
}

Ok(RerankerResponse::new(response))
}
///async invoke
pub async fn ainvoke(
&self,
query: String,
documents: Vec<String>,
top_n: Option<u64>,
user_id: Option<String>,
) -> Result<RerankerResponse, ErnieError> {
let mut body = serde_json::json!({
"query": query,
"documents": documents,
});
if let Some(top_n) = top_n {
body.merge(&serde_json::json!({"top_n": top_n}));
}
if let Some(user_id) = user_id {
body.merge(&serde_json::json!({"user_id": user_id}));
}
let client = reqwest::Client::new();
let response: Value = client
.post(self.url.as_str())
.query(&[("access_token", self.access_token.as_str())])
.json(&body)
.send()
.await
.map_err(|e| ErnieError::InvokeError(e.to_string()))?
.json()
.await
.map_err(|e| ErnieError::InvokeError(e.to_string()))?;

//if error_code key in response, means RemoteAPIError
if response.get("error_code").is_some() {
return Err(ErnieError::RemoteAPIError(response.to_string()));
}

Ok(RerankerResponse::new(response))
}
}
7 changes: 7 additions & 0 deletions src/reranker/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod endpoint;
mod model;
mod response;

pub use endpoint::RerankerEndpoint;
pub use model::RerankerModel;
pub use response::RerankerResponse;
11 changes: 11 additions & 0 deletions src/reranker/model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use serde::{Deserialize, Serialize};
use strum_macros::{Display, EnumString};

#[derive(Debug, Default, Clone, Serialize, Deserialize, EnumString, Display, PartialEq, Eq)]
#[non_exhaustive]
pub enum RerankerModel {
#[default]
#[strum(serialize = "bce_reranker_base")]
#[serde(rename = "bce_reranker_base")]
BceRerankerBaseV1,
}
71 changes: 71 additions & 0 deletions src/reranker/response.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use crate::errors::ErnieError;
use serde::{Deserialize, Serialize};
use serde_json::value;

/// Response is using for non-stream response
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RerankerResponse {
raw_response: value::Value,
}

impl RerankerResponse {
pub fn new(raw_response: value::Value) -> Self {
RerankerResponse { raw_response }
}

pub fn get_raw_response(&self) -> &value::Value {
&self.raw_response
}

pub fn get(&self, key: &str) -> Option<&value::Value> {
self.raw_response.get(key)
}

pub fn get_mut(&mut self, key: &str) -> Option<&mut value::Value> {
self.raw_response.get_mut(key)
}

/// get the result of reranker response
pub fn get_reranker_response(&self) -> Result<Vec<RerankData>, ErnieError> {
match self.raw_response.get("results") {
Some(data) => {
let data_array = data
.as_array()
.ok_or(ErnieError::GetResponseError(
"reranker results is not an array".to_string(),
))?
.clone();
let results = data_array
.into_iter()
.map(|x| {
serde_json::from_value(x)
.map_err(|e| ErnieError::GetResponseError(e.to_string()))
})
.collect::<Result<Vec<RerankData>, ErnieError>>()?;
Ok(results)
}
None => Err(ErnieError::GetResponseError(
"reranker results is not found".to_string(),
)),
}
}
/// get tokens used by prompt
pub fn get_prompt_tokens(&self) -> Option<u64> {
let usage = self.get("usage")?.as_object()?;
let prompt_tokens = usage.get("prompt_tokens")?.as_u64()?;
Some(prompt_tokens)
}
/// get tokens used by completion
pub fn get_total_tokens(&self) -> Option<u64> {
let usage = self.get("usage")?.as_object()?;
let total_tokens = usage.get("total_tokens")?.as_u64()?;
Some(total_tokens)
}
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RerankData {
document: String,
relevance_score: f64,
index: u64,
}
Loading