Skip to content

Commit

Permalink
Merge pull request #5 from chenwanqq/rerank
Browse files Browse the repository at this point in the history
add rerank model support
  • Loading branch information
chenwanqq authored Apr 9, 2024
2 parents 5d5608f + 82984f0 commit e191934
Show file tree
Hide file tree
Showing 10 changed files with 236 additions and 5 deletions.
10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ license = "MIT"
description = "A unofficial Rust library for the Ernie API"
homepage = "https://github.com/chenwanqq/erniebot-rs"
repository = "https://github.com/chenwanqq/erniebot-rs"
version = "0.1.1"
version = "0.2.0"
edition = "2021"
exclude = [".github/",".vscode/",".gitignore"]

Expand All @@ -17,12 +17,12 @@ strum_macros = "0.26.1"
serde = {version = "1.0.197", features = ["derive"]}
serde_json = "1.0.113"
url = "2.5.0"
reqwest = {version = "0.11.6", features = ["json","blocking"]}
reqwest = {version = "0.12.3", features = ["json","blocking"]}
thiserror = "1.0.57"
json_value_merge = "2.0"
reqwest-eventsource = "0.5.0"
reqwest-eventsource = "0.6.0"
tokio = { version = "1.36.0", features = ["full"] }
tokio-stream = "0.1.14"
base64 = "0.21.7"
image = "0.24.9"
base64 = "0.22.0"
image = "0.25.1"
schemars = "0.8"
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

Unofficial Baidu Ernie(Wenxin Yiyan, Qianfan) Rust SDK, currently supporting three modules: chat, text embedding (embedding), and text-to-image generation (text2image).

**update in 2024/04/09**: Add support for the bce-reranker-base-v1 rerank model

## Installation

Add the following to your Cargo.toml file:
Expand Down
2 changes: 2 additions & 0 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

非官方的百度千帆大模型(文心一言,或者是Ernie,随便啦)SDK, 目前支持对话(chat),文本嵌入(embedding)以及文生图(text2image)三个模块。

**2024/04/09更新**: 添加对bce-reranker-base-v1重排序模型的支持

## 安装

`Cargo.toml`文件中添加以下内容:
Expand Down
2 changes: 2 additions & 0 deletions examples/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,7 @@ fn test_async_embedding() {

fn main() {
test_embedding();
//sleep to avoid qps
std::thread::sleep(std::time::Duration::from_secs(1));
test_async_embedding();
}
36 changes: 36 additions & 0 deletions examples/rerank.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use erniebot_rs::reranker::{RerankerEndpoint, RerankerModel};
use tokio::runtime::Runtime;

fn test_reranker() {
let reranker = RerankerEndpoint::new(RerankerModel::BceRerankerBaseV1).unwrap();
let query = "你好".to_string();
let documents = vec![
"你好".to_string(),
"你叫什么名字".to_string(),
"你是谁".to_string(),
];
let reranker_response = reranker.invoke(query, documents, None, None).unwrap();
let reranker_results = reranker_response.get_reranker_response().unwrap();
println!("{},{:?}", reranker_results.len(), reranker_results);
}

fn test_async_reranker() {
let reranker = RerankerEndpoint::new(RerankerModel::BceRerankerBaseV1).unwrap();
let query = "你好".to_string();
let documents = vec![
"你好".to_string(),
"你叫什么名字".to_string(),
"你是谁".to_string(),
];
let rt = Runtime::new().unwrap();
let reranker_response = rt
.block_on(reranker.ainvoke(query, documents, None, None))
.unwrap();
let reranker_results = reranker_response.get_reranker_response().unwrap();
println!("{},{:?}", reranker_results.len(), reranker_results);
}

fn main() {
test_reranker();
test_async_reranker();
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod chat;
/// Toolset to interact with embedding model in Qianfan platform
pub mod embedding;
pub mod errors;
pub mod reranker;
/// Toolset to interact with text2image model in Qianfan platform
pub mod text2image;
pub mod utils;
99 changes: 99 additions & 0 deletions src/reranker/endpoint.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use super::model::RerankerModel;
use super::response::RerankerResponse;
use crate::errors::ErnieError;
use crate::utils::{build_url, get_access_token};
use json_value_merge::Merge;
use serde_json::Value;
use url::Url;

static RERANKER_BASE_URL: &str =
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/reranker/";

/** ChatEndpoint is a struct that represents the chat endpoint of erniebot API
*/
pub struct RerankerEndpoint {
url: Url,
access_token: String,
}

impl RerankerEndpoint {
// create a new embedding instance using pre-defined model
pub fn new(model: RerankerModel) -> Result<Self, ErnieError> {
Ok(RerankerEndpoint {
url: build_url(RERANKER_BASE_URL, model.to_string().as_str())?,
access_token: get_access_token()?,
})
}
/// sync invoke
pub fn invoke(
&self,
query: String,
documents: Vec<String>,
top_n: Option<u64>,
user_id: Option<String>,
) -> Result<RerankerResponse, ErnieError> {
let mut body = serde_json::json!({
"query": query,
"documents": documents,
});
if let Some(top_n) = top_n {
body.merge(&serde_json::json!({"top_n": top_n}));
}
if let Some(user_id) = user_id {
body.merge(&serde_json::json!({"user_id": user_id}));
}
let client = reqwest::blocking::Client::new();
let response: Value = client
.post(self.url.as_str())
.query(&[("access_token", self.access_token.as_str())])
.json(&body)
.send()
.map_err(|e| ErnieError::InvokeError(e.to_string()))?
.json()
.map_err(|e| ErnieError::InvokeError(e.to_string()))?;

//if error_code key in response, means RemoteAPIError
if response.get("error_code").is_some() {
return Err(ErnieError::RemoteAPIError(response.to_string()));
}

Ok(RerankerResponse::new(response))
}
///async invoke
pub async fn ainvoke(
&self,
query: String,
documents: Vec<String>,
top_n: Option<u64>,
user_id: Option<String>,
) -> Result<RerankerResponse, ErnieError> {
let mut body = serde_json::json!({
"query": query,
"documents": documents,
});
if let Some(top_n) = top_n {
body.merge(&serde_json::json!({"top_n": top_n}));
}
if let Some(user_id) = user_id {
body.merge(&serde_json::json!({"user_id": user_id}));
}
let client = reqwest::Client::new();
let response: Value = client
.post(self.url.as_str())
.query(&[("access_token", self.access_token.as_str())])
.json(&body)
.send()
.await
.map_err(|e| ErnieError::InvokeError(e.to_string()))?
.json()
.await
.map_err(|e| ErnieError::InvokeError(e.to_string()))?;

//if error_code key in response, means RemoteAPIError
if response.get("error_code").is_some() {
return Err(ErnieError::RemoteAPIError(response.to_string()));
}

Ok(RerankerResponse::new(response))
}
}
7 changes: 7 additions & 0 deletions src/reranker/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod endpoint;
mod model;
mod response;

pub use endpoint::RerankerEndpoint;
pub use model::RerankerModel;
pub use response::RerankerResponse;
11 changes: 11 additions & 0 deletions src/reranker/model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use serde::{Deserialize, Serialize};
use strum_macros::{Display, EnumString};

#[derive(Debug, Default, Clone, Serialize, Deserialize, EnumString, Display, PartialEq, Eq)]
#[non_exhaustive]
pub enum RerankerModel {
#[default]
#[strum(serialize = "bce_reranker_base")]
#[serde(rename = "bce_reranker_base")]
BceRerankerBaseV1,
}
71 changes: 71 additions & 0 deletions src/reranker/response.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use crate::errors::ErnieError;
use serde::{Deserialize, Serialize};
use serde_json::value;

/// Response is using for non-stream response
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RerankerResponse {
raw_response: value::Value,
}

impl RerankerResponse {
pub fn new(raw_response: value::Value) -> Self {
RerankerResponse { raw_response }
}

pub fn get_raw_response(&self) -> &value::Value {
&self.raw_response
}

pub fn get(&self, key: &str) -> Option<&value::Value> {
self.raw_response.get(key)
}

pub fn get_mut(&mut self, key: &str) -> Option<&mut value::Value> {
self.raw_response.get_mut(key)
}

/// get the result of reranker response
pub fn get_reranker_response(&self) -> Result<Vec<RerankData>, ErnieError> {
match self.raw_response.get("results") {
Some(data) => {
let data_array = data
.as_array()
.ok_or(ErnieError::GetResponseError(
"reranker results is not an array".to_string(),
))?
.clone();
let results = data_array
.into_iter()
.map(|x| {
serde_json::from_value(x)
.map_err(|e| ErnieError::GetResponseError(e.to_string()))
})
.collect::<Result<Vec<RerankData>, ErnieError>>()?;
Ok(results)
}
None => Err(ErnieError::GetResponseError(
"reranker results is not found".to_string(),
)),
}
}
/// get tokens used by prompt
pub fn get_prompt_tokens(&self) -> Option<u64> {
let usage = self.get("usage")?.as_object()?;
let prompt_tokens = usage.get("prompt_tokens")?.as_u64()?;
Some(prompt_tokens)
}
/// get tokens used by completion
pub fn get_total_tokens(&self) -> Option<u64> {
let usage = self.get("usage")?.as_object()?;
let total_tokens = usage.get("total_tokens")?.as_u64()?;
Some(total_tokens)
}
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RerankData {
document: String,
relevance_score: f64,
index: u64,
}

0 comments on commit e191934

Please sign in to comment.