diff --git a/shinkai-bin/shinkai-node/src/llm_provider/providers/ollama.rs b/shinkai-bin/shinkai-node/src/llm_provider/providers/ollama.rs
index 1f1d9fc9f..54a6c1bc9 100644
--- a/shinkai-bin/shinkai-node/src/llm_provider/providers/ollama.rs
+++ b/shinkai-bin/shinkai-node/src/llm_provider/providers/ollama.rs
@@ -1,7 +1,7 @@
 use crate::llm_provider::execution::chains::inference_chain_trait::{FunctionCall, LLMInferenceResponse};
 use crate::llm_provider::llm_stopper::LLMStopper;
 use crate::llm_provider::providers::shared::ollama_api::{
-    ollama_prepare_messages, ollama_conversation_prepare_messages_with_tooling, OllamaAPIStreamingResponse,
+    ollama_conversation_prepare_messages_with_tooling, ollama_prepare_messages, OllamaAPIStreamingResponse,
 };
 use crate::managers::model_capabilities_manager::PromptResultEnum;
@@ -105,8 +105,11 @@ impl LLMService for Ollama {
             }
         }

-        // Conditionally add functions to the payload if tools_json is not empty
-        if !tools_json.is_empty() {
+        // Extract the use_tools flag from the config, defaulting to enabled
+        let use_tools = config.as_ref().map_or(true, |c| c.use_tools.unwrap_or(true));
+
+        // Conditionally add tools to the payload if use_tools is true
+        if use_tools && !tools_json.is_empty() {
             payload["tools"] = serde_json::Value::Array(tools_json);
         }
diff --git a/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs b/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs
index 89862c43f..2351619ef 100644
--- a/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs
+++ b/shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs
@@ -4,9 +4,7 @@ use std::sync::Arc;
 use super::super::error::LLMProviderError;
 use super::shared::openai_api::{openai_prepare_messages, MessageContent, OpenAIResponse};
 use super::LLMService;
-use crate::llm_provider::execution::chains::inference_chain_trait::{
-    LLMInferenceResponse, FunctionCall,
-};
+use crate::llm_provider::execution::chains::inference_chain_trait::{FunctionCall, LLMInferenceResponse};
 use crate::llm_provider::llm_stopper::LLMStopper;
 use crate::managers::model_capabilities_manager::PromptResultEnum;
 use async_trait::async_trait;
@@ -100,8 +98,11 @@ impl LLMService for OpenAI {
             "max_tokens": result.remaining_tokens,
         });

-        // Conditionally add functions to the payload if tools_json is not empty
-        if !tools_json.is_empty() {
+        // Extract the use_tools flag from the config, defaulting to enabled
+        let use_tools = config.as_ref().map_or(true, |c| c.use_tools.unwrap_or(true));
+
+        // Conditionally add functions to the payload if use_tools is true
+        if use_tools && !tools_json.is_empty() {
             payload["functions"] = serde_json::Value::Array(tools_json);
         }
diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs
index 3c02f0457..ba6bd62ed 100644
--- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs
+++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs
@@ -704,6 +704,8 @@ impl Node {
             top_p: None,
             stream: None,
             other_model_params: None,
+            use_tools: None,
+            web_search: None,
         });
         let _ = res.send(Ok(config)).await;
         Ok(())
diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_prompts.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_prompts.rs
index cc254a4de..d204820b6 100644
--- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_prompts.rs
+++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_prompts.rs
@@ -179,7 +179,7 @@ impl Node {
         let start_time = Instant::now();
         // Perform the internal search using LanceShinkaiDb
-        match lance_db.read().await.prompt_vector_search(&query, 5).await {
+        match lance_db.read().await.prompt_vector_search(&query, 10).await {
             Ok(prompts) => {
                 // Set embeddings to None before returning
                 let prompts_without_embeddings: Vec<CustomPrompt> = prompts
diff --git a/shinkai-libs/shinkai-lancedb/Cargo.toml b/shinkai-libs/shinkai-lancedb/Cargo.toml
index 0a415f877..afb954fd3 100644
--- a/shinkai-libs/shinkai-lancedb/Cargo.toml
+++ b/shinkai-libs/shinkai-lancedb/Cargo.toml
@@ -11,7 +11,7 @@ uuid = { version = "1.6.1", features = ["v4"] }
 shinkai_tools_primitives = { workspace = true }
 shinkai_vector_resources = { workspace = true }
 shinkai_message_primitives = { workspace = true }
-shinkai_tools_runner = { version = "0.7.14", features = ["built-in-tools"] }
+shinkai_tools_runner = { version = "0.7.15", features = ["built-in-tools"] }
 regex = "1"
 base64 = "0.22.0"
 lancedb = "0.10.0"
diff --git a/shinkai-libs/shinkai-lancedb/src/lance_db/shinkai_prompt_db.rs b/shinkai-libs/shinkai-lancedb/src/lance_db/shinkai_prompt_db.rs
index 6865746e6..a89231fc5 100644
--- a/shinkai-libs/shinkai-lancedb/src/lance_db/shinkai_prompt_db.rs
+++ b/shinkai-libs/shinkai-lancedb/src/lance_db/shinkai_prompt_db.rs
@@ -90,7 +90,7 @@ impl LanceShinkaiDb {
         Ok(())
     }

-    fn convert_batch_to_prompt(batch: &RecordBatch) -> Option<CustomPrompt> {
+    fn convert_batch_to_prompt(batch: &RecordBatch) -> Vec<CustomPrompt> {
         let name_array = batch
             .column_by_name(ShinkaiPromptSchema::name_field())
             .unwrap()
@@ -134,13 +134,15 @@ impl LanceShinkaiDb {
             .downcast_ref::<FixedSizeListArray>()
             .unwrap();

-        if name_array.len() > 0 {
-            let embedding = if vector_array.is_null(0) {
+        let mut prompts = Vec::new();
+
+        for i in 0..name_array.len() {
+            let embedding = if vector_array.is_null(i) {
                 None
             } else {
                 Some(
                     vector_array
-                        .value(0)
+                        .value(i)
                         .as_any()
                         .downcast_ref::<Float32Array>()
                         .unwrap()
@@ -149,18 +151,18 @@ impl LanceShinkaiDb {
                 )
             };

-            Some(CustomPrompt {
-                name: name_array.value(0).to_string(),
-                prompt: prompt_array.value(0).to_string(),
-                is_system: is_system_array.value(0),
-                is_enabled: is_enabled_array.value(0),
-                version: version_array.value(0).to_string(),
-                is_favorite: is_favorite_array.value(0),
+            prompts.push(CustomPrompt {
+                name: name_array.value(i).to_string(),
+                prompt: prompt_array.value(i).to_string(),
+                is_system: is_system_array.value(i),
+                is_enabled: is_enabled_array.value(i),
+                version: version_array.value(i).to_string(),
+                is_favorite: is_favorite_array.value(i),
                 embedding,
-            })
-        } else {
-            None
+            });
         }
+
+        prompts
     }

     pub async fn get_prompt(&self, name: &str) -> Result<Option<CustomPrompt>, ShinkaiLanceDBError> {
@@ -188,7 +190,8 @@ impl LanceShinkaiDb {
             .map_err(|e| ShinkaiLanceDBError::DatabaseError(e.to_string()))?;

         for batch in results {
-            if let Some(prompt) = Self::convert_batch_to_prompt(&batch) {
+            let prompts = Self::convert_batch_to_prompt(&batch);
+            for prompt in prompts {
                 return Ok(Some(prompt));
             }
         }
@@ -281,11 +284,8 @@ impl LanceShinkaiDb {
         let mut prompts = Vec::new();
         let mut res = query;
         while let Some(Ok(batch)) = res.next().await {
-            for _i in 0..batch.num_rows() {
-                if let Some(prompt) = Self::convert_batch_to_prompt(&batch) {
-                    prompts.push(prompt);
-                }
-            }
+            let prompts_batch = Self::convert_batch_to_prompt(&batch);
+            prompts.extend(prompts_batch);
         }

         Ok(prompts)
@@ -311,7 +311,12 @@ impl LanceShinkaiDb {
         let fts_query_builder = self
             .prompt_table
             .query()
-            .full_text_search(FullTextSearchQuery::new(query.to_owned()))
+            .full_text_search(FullTextSearchQuery {
+                columns: vec!["prompt".to_string()],
+                query: query.to_owned(),
+                limit: Some(num_results as i64),
+                wand_factor: Some(1.0),
+            })
             .select(Select::columns(&[
                 ShinkaiPromptSchema::name_field(),
                 ShinkaiPromptSchema::prompt_field(),
@@ -357,16 +362,14 @@ impl LanceShinkaiDb {
         let mut fts_res = fts_query;
         while let Some(Ok(batch)) = fts_res.next().await {
-            if let Some(prompt) = Self::convert_batch_to_prompt(&batch) {
-                fts_results.push(prompt);
-            }
+            let prompts = Self::convert_batch_to_prompt(&batch);
+            fts_results.extend(prompts);
         }

         let mut vector_res = vector_query;
         while let Some(Ok(batch)) = vector_res.next().await {
-            if let Some(prompt) = Self::convert_batch_to_prompt(&batch) {
-                vector_results.push(prompt);
-            }
+            let prompts = Self::convert_batch_to_prompt(&batch);
+            vector_results.extend(prompts);
         }

         // Merge results using interleave and remove duplicates
@@ -399,6 +402,21 @@ impl LanceShinkaiDb {
             }
         }

+        // Continue to add results from the remaining iterator if needed
+        while combined_results.len() < num_results as usize {
+            if let Some(fts_item) = fts_iter.next() {
+                if seen.insert(fts_item.name.clone()) {
+                    combined_results.push(fts_item);
+                }
+            } else if let Some(vector_item) = vector_iter.next() {
+                if seen.insert(vector_item.name.clone()) {
+                    combined_results.push(vector_item);
+                }
+            } else {
+                break;
+            }
+        }
+
         Ok(combined_results)
     }
@@ -422,11 +440,8 @@ impl LanceShinkaiDb {
         let mut prompts = Vec::new();
         let mut res = query;
         while let Some(Ok(batch)) = res.next().await {
-            for _i in 0..batch.num_rows() {
-                if let Some(prompt) = Self::convert_batch_to_prompt(&batch) {
-                    prompts.push(prompt);
-                }
-            }
+            let prompts_batch = Self::convert_batch_to_prompt(&batch);
+            prompts.extend(prompts_batch);
         }

         Ok(prompts)
diff --git a/shinkai-libs/shinkai-message-primitives/src/schemas/job_config.rs b/shinkai-libs/shinkai-message-primitives/src/schemas/job_config.rs
index 6163af911..855b39803 100644
--- a/shinkai-libs/shinkai-message-primitives/src/schemas/job_config.rs
+++ b/shinkai-libs/shinkai-message-primitives/src/schemas/job_config.rs
@@ -13,4 +13,6 @@ pub struct JobConfig {
     pub top_p: Option<f64>,
     pub stream: Option<bool>,
     pub other_model_params: Option<serde_json::Value>,
+    pub use_tools: Option<bool>,
+    pub web_search: Option<bool>,
 }
\ No newline at end of file