From efe51836155263ea3cec0fe5376871b9dd37bef5 Mon Sep 17 00:00:00 2001 From: Hyunbin Kim Date: Fri, 4 Oct 2024 22:38:57 +0900 Subject: [PATCH] [IN PROGRESS] optimizing query --- Cargo.lock | 1 - Cargo.toml | 1 - src/cli/workflows/query_pdb.rs | 6 +- src/controller/mode.rs | 77 +++++++++ src/controller/oldrank.rs | 298 +++++++++++++++++++++++++++++++++ src/controller/rank.rs | 255 +++++++++++++--------------- src/index/indextable.rs | 4 +- src/utils/combination.rs | 2 + src/utils/convert.rs | 4 - 9 files changed, 499 insertions(+), 149 deletions(-) create mode 100644 src/controller/oldrank.rs diff --git a/Cargo.lock b/Cargo.lock index d9a2d27..291ebf1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -195,7 +195,6 @@ dependencies = [ "cmake", "dashmap", "flate2", - "lazy_static", "libc", "memmap2", "peak_alloc", diff --git a/Cargo.toml b/Cargo.toml index 6e9f62f..d0508e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,6 @@ toml = "0.8.12" regex = "1.10.4" petgraph = "0.6.4" libc = "0.2.155" -lazy_static = "1.5.0" dashmap = { version = "6.1.0", features = ["rayon"] } diff --git a/src/cli/workflows/query_pdb.rs b/src/cli/workflows/query_pdb.rs index 2428f28..133442c 100644 --- a/src/cli/workflows/query_pdb.rs +++ b/src/cli/workflows/query_pdb.rs @@ -15,7 +15,7 @@ use rayon::prelude::*; use crate::cli::config::{read_index_config_from_file, IndexConfig}; use crate::controller::map::SimpleHashMap; -use crate::controller::mode::{parse_path_by_id_type, IdType, IndexMode}; +use crate::controller::mode::{parse_path_by_id_type, parse_path_by_id_type_with_string, IdType, IndexMode}; use crate::cli::*; use crate::controller::io::{read_compact_structure, read_u16_vector}; use crate::controller::query::{check_and_get_indices, get_offset_value_lookup_type, make_query_map, parse_threshold_string}; @@ -389,8 +389,10 @@ pub fn query_pdb(env: AppArgs) { if header { println!("{}", QUERY_RESULT_HEADER); } + let mut id_container = String::new(); for (_k, v) in queried_from_indices.iter_mut() { - v.id = parse_path_by_id_type(&v.id, &id_type); + parse_path_by_id_type_with_string(v.id, &id_type, &mut id_container); + v.id = Box::leak(id_container.clone().into_boxed_str()); println!("{:?}\t{}\t{}\t{}", v, query_string, pdb_path, index_path.clone().unwrap()); } } diff --git a/src/controller/mode.rs b/src/controller/mode.rs index b267c1b..d53b734 100644 --- a/src/controller/mode.rs +++ b/src/controller/mode.rs @@ -123,6 +123,83 @@ pub fn parse_path_by_id_type(path: &str, id_type: &IdType) -> String { } } + +pub fn parse_path_by_id_type_with_string(path: &str, id_type: &IdType, string: &mut String) { + // TODO: 2024-04-04 15:07:54 Fill in this function to ease benchmarking + string.clear(); + let afdb_regex = regex::Regex::new(r"AF-.+-model_v\d").unwrap(); + match id_type { + IdType::Pdb => { + // Get the basename of the path + let path = Path::new(path); + let file_name = path.file_stem().unwrap(); + // Remove extension + let file_name = file_name.to_str().unwrap(); + // Remove extension, If startswith "pdb" remove "pdb" from the start + if file_name.starts_with("pdb") { + // &file_name[3..] 
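+                // e.g. a legacy PDB download "pdb1tim.ent" has file_stem "pdb1tim",
+                // so the resulting id is "1tim"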
+                string.push_str(&file_name[3..]);
+            } else {
+                // file_name
+                string.push_str(file_name);
+            }
+        }
+        IdType::Afdb => {
+            let path = Path::new(path);
+            let file_name = path.file_stem().unwrap().to_str().unwrap();
+            // Find the matching pattern
+            let afdb_id = afdb_regex.find(file_name);
+            if afdb_id.is_none() {
+                // return file_name;
+                string.push_str(file_name);
+            } else {
+                // &file_name[afdb_id.unwrap().start()..afdb_id.unwrap().end()]
+                string.push_str(&file_name[afdb_id.unwrap().start()..afdb_id.unwrap().end()]);
+            }
+        }
+        IdType::UniProt => {
+            let path = Path::new(path);
+            let file_name = path.file_stem().unwrap().to_str().unwrap();
+            // Find the matching pattern
+            let afdb_id = afdb_regex.find(file_name);
+            if afdb_id.is_none() {
+                string.push_str(file_name);
+                return;
+            }
+            let afdb_id = file_name[afdb_id.unwrap().start()..afdb_id.unwrap().end()].to_string();
+            let afdb_id = afdb_id.split("-").collect::<Vec<&str>>();
+            // afdb_id[1]
+            string.push_str(afdb_id[1]);
+        }
+        IdType::BasenameWithoutExt => {
+            let path = Path::new(path);
+            let file_name = path.file_stem().unwrap().to_str().unwrap();
+            // file_name
+            string.push_str(file_name);
+        }
+        IdType::BasenameWithExt => {
+            let path = Path::new(path);
+            let file_name = path.file_name().unwrap().to_str().unwrap();
+            // file_name
+            string.push_str(file_name);
+        }
+        IdType::AbsPath => {
+            let path = fs::canonicalize(path).unwrap();
+            // path.to_str().unwrap()
+            string.push_str(path.to_str().unwrap());
+        }
+        IdType::RelPath => {
+            // path
+            string.push_str(path);
+        }
+        IdType::Other => {
+            // path
+            string.push_str(path);
+        }
+    }
+}
+
+
 pub fn parse_path_vec_by_id_type(path_vec: &Vec<String>, id_type: IdType) -> Vec<String> {
     let mut parsed_path_vec = Vec::with_capacity(path_vec.len());
     for path in path_vec {
diff --git a/src/controller/oldrank.rs b/src/controller/oldrank.rs
new file mode 100644
index 0000000..2e39c57
--- /dev/null
+++ b/src/controller/oldrank.rs
@@ -0,0 +1,298 @@
+// Functions for ranking queried results
+
+use std::collections::HashMap;
+use std::fmt;
+use crate::index::indextable::FolddiscoIndex;
+use crate::prelude::GeometricHash;
+
+
+use super::io::get_values_with_offset_u16;
+use super::map::SimpleHashMap;
+
+#[derive(Clone)]
+pub struct QueryResult {
+    pub id: String,
+    pub nid: usize,
+    pub total_match_count: usize,
+    pub node_count: usize,
+    pub edge_count: usize,
+    pub exact_match_count: usize,
+    pub overflow_count: usize,
+    pub idf: f32,
+    pub nres: usize,
+    pub plddt: f32,
+    pub node_set: HashMap<usize, usize>,
+    pub edge_set: HashMap<(usize, usize), usize>,
+    pub pos_set: HashMap<(u16, u16), usize>,
+    pub matching_residues: Vec<(String, f32)>,
+    pub matching_residues_processed: Vec<(String, f32)>,
+}
+
+impl QueryResult {
+    pub fn new(
+        id: String, nid: usize, total_match_count: usize, node_count: usize, edge_count: usize,
+        exact_match_count: usize, overflow_count: usize, idf: f32, nres: usize, plddt: f32
+    ) -> Self {
+        Self {
+            id,
+            nid,
+            total_match_count,
+            node_count,
+            edge_count,
+            exact_match_count,
+            overflow_count,
+            idf,
+            nres,
+            plddt,
+            node_set: HashMap::new(),
+            edge_set: HashMap::new(),
+            pos_set: HashMap::new(),
+            matching_residues: Vec::new(),
+            matching_residues_processed: Vec::new(),
+        }
+    }
+}
+
+impl fmt::Display for QueryResult {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let matching_residues_with_score = if self.matching_residues.len() == 0 {
+            "NA".to_string()
+        } else {
+            self.matching_residues.iter().map(
+                // Only print score with 4 decimal places
+                |(x, y)| format!("{}:{:.4}", x, y)
+            ).collect::<Vec<String>>().join(";")
+        };
+        let matching_residues_processed_with_score = if self.matching_residues_processed.len() == 0 {
+            "NA".to_string()
+        } else {
+            self.matching_residues_processed.iter().map(
+                // Only print score with 4 decimal places
+                |(x, y)| format!("{}:{:.4}", x, y)
+            ).collect::<Vec<String>>().join(";")
+        };
+        write!(
+            f, "{}\t{:.4}\t{}\t{}\t{}\t{}\t{}\t{}\t{:.4}\t{}\t{}",
+            self.id, self.idf, self.total_match_count, self.node_count, self.edge_count,
+            self.exact_match_count, self.overflow_count,
+            self.nres, self.plddt, matching_residues_with_score, matching_residues_processed_with_score
+            // self.pos_set.len(),
+            // self.node_set, self.edge_set, self.grid_set, self.pos_set
+        )
+    }
+}
+
+impl fmt::Debug for QueryResult {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let matching_residues_with_score = if self.matching_residues.len() == 0 {
+            "NA".to_string()
+        } else {
+            self.matching_residues.iter().map(
+                // Only print score with 4 decimal places
+                |(x, y)| format!("{}:{:.4}", x, y)
+            ).collect::<Vec<String>>().join(";")
+        };
+        let matching_residues_processed_with_score = if self.matching_residues_processed.len() == 0 {
+            "NA".to_string()
+        } else {
+            self.matching_residues_processed.iter().map(
+                // Only print score with 4 decimal places
+                |(x, y)| format!("{}:{:.4}", x, y)
+            ).collect::<Vec<String>>().join(";")
+        };
+        write!(
+            f, "{}\t{:.4}\t{}\t{}\t{}\t{}\t{}\t{}\t{:.4}\t{}\t{}",
+            self.id, self.idf, self.total_match_count, self.node_count, self.edge_count,
+            self.exact_match_count, self.overflow_count,
+            self.nres, self.plddt, matching_residues_with_score, matching_residues_processed_with_score
+            // self.pos_set.len(),
+            // self.node_set, self.edge_set, self.grid_set, self.pos_set
+        )
+    }
+}
+// write_fmt
+impl QueryResult {
+    pub fn write_fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let matching_residues_with_score = if self.matching_residues.len() == 0 {
+            "NA".to_string()
+        } else {
+            self.matching_residues.iter().map(
+                // Only print score with 4 decimal places
+                |(x, y)| format!("{}:{:.4}", x, y)
+            ).collect::<Vec<String>>().join(";")
+        };
+        let matching_residues_processed_with_score = if self.matching_residues_processed.len() == 0 {
+            "NA".to_string()
+        } else {
+            self.matching_residues_processed.iter().map(
+                // Only print score with 4 decimal places
+                |(x, y)| format!("{}:{:.4}", x, y)
+            ).collect::<Vec<String>>().join(";")
+        };
+        write!(
+            f, "{}\t{:.4}\t{}\t{}\t{}\t{}\t{}\t{}\t{:.4}\t{}\t{}",
+            self.id, self.idf, self.total_match_count, self.node_count, self.edge_count,
+            self.exact_match_count, self.overflow_count,
+            self.nres, self.plddt, matching_residues_with_score, matching_residues_processed_with_score
+            // self.pos_set.len(),
+            // self.node_set, self.edge_set, self.grid_set, self.pos_set
+        )
+    }
+}
+
+
+pub fn count_query_idmode(
+    queries: &Vec<GeometricHash>, query_map: &HashMap<GeometricHash, ((usize, usize), bool)>,
+    offset_table: &SimpleHashMap, value_vec: &[u16],
+    lookup: &(Vec<String>, Vec<usize>, Vec<usize>, Vec<f32>)
+) -> HashMap<usize, QueryResult> {
+    let mut query_count_map = HashMap::new();
+    for (_i, query) in queries.iter().enumerate() {
+        let offset = offset_table.get(query);
+        if offset.is_none() {
+            continue;
+        }
+        let offset = offset.unwrap();
+        let single_queried_values = get_values_with_offset_u16(value_vec, offset.0, offset.1);
+        let edge_info = query_map.get(query).unwrap();
+        let is_exact = edge_info.1;
+        let edge = edge_info.0;
+        let hash_count = offset.1;
+        for j in 0..single_queried_values.len() {
+            let id = lookup.0[single_queried_values[j] as usize].clone();
+            let nid = lookup.1[single_queried_values[j] as usize];
+            let nres = lookup.2[single_queried_values[j] as usize];
+            let plddt = lookup.3[single_queried_values[j] as usize];
+
+            let result = query_count_map.get_mut(&nid);
+            let idf = (lookup.0.len() as f32 / hash_count as f32).log2();
+            let nres_norm = (nres as f32).log2() * -1.0 + 12.0;
+
+            if result.is_none() {
+                let mut node_set = HashMap::new();
+                node_set.insert(edge.0, 1);
+                node_set.insert(edge.1, 1);
+                let mut edge_set = HashMap::new();
+                edge_set.insert(edge, 1);
+                let exact_match_count = if is_exact { 1usize } else { 0usize };
+                let overflow_count = 0usize;
+                let total_match_count = 1usize;
+                let mut query_result = QueryResult::new(
+                    id, nid, total_match_count, 2, 1, exact_match_count,
+                    overflow_count, idf + nres_norm, nres, plddt
+                );
+                query_result.node_set = node_set;
+                query_result.edge_set = edge_set;
+                query_count_map.insert(nid, query_result);
+            } else {
+                let result = result.unwrap();
+                if result.node_set.contains_key(&edge.0) {
+                    let count = result.node_set.get_mut(&edge.0).unwrap();
+                    *count += 1;
+                } else {
+                    result.node_set.insert(edge.0, 1);
+                    result.node_count += 1;
+                }
+                if result.node_set.contains_key(&edge.1) {
+                    let count = result.node_set.get_mut(&edge.1).unwrap();
+                    *count += 1;
+                } else {
+                    result.node_set.insert(edge.1, 1);
+                    result.node_count += 1;
+                }
+                let is_overflow = result.edge_set.contains_key(&edge);
+                if is_overflow {
+                    let count = result.edge_set.get_mut(&edge).unwrap();
+                    *count += 1;
+                    result.overflow_count += 1;
+                } else {
+                    result.edge_set.insert(edge, 1);
+                    result.edge_count += 1;
+                }
+                result.total_match_count += 1;
+                result.exact_match_count += if is_exact { 1usize } else { 0usize };
+                result.idf += idf + nres_norm;
+            }
+        }
+    }
+    query_count_map
+}
+
+pub fn count_query_bigmode(
+    queries: &Vec<GeometricHash>, query_map: &HashMap<GeometricHash, ((usize, usize), bool)>,
+    big_index: &FolddiscoIndex,
+    lookup: &(Vec<String>, Vec<usize>, Vec<usize>, Vec<f32>)
+) -> HashMap<usize, QueryResult> {
+    let mut query_count_map = HashMap::new();
+    for (_i, query) in queries.iter().enumerate() {
+
+        let single_queried_values = big_index.get_entries(query.as_u32());
+        let edge_info = query_map.get(query).unwrap();
+        let is_exact = edge_info.1;
+        let edge = edge_info.0;
+        let hash_count = single_queried_values.len();
+        for j in 0..single_queried_values.len() {
+            if single_queried_values[j] >= lookup.0.len() {
+                println!("Error: {} >= {}", single_queried_values[j], lookup.0.len());
+                println!("Error query: {:?}", query);
+            }
+            let id = lookup.0[single_queried_values[j]].clone();
+            let nid = lookup.1[single_queried_values[j]];
+            let nres = lookup.2[single_queried_values[j]];
+            let plddt = lookup.3[single_queried_values[j]];
+
+            let result = query_count_map.get_mut(&nid);
+            let idf = (lookup.0.len() as f32 / hash_count as f32).log2();
+            let nres_norm = (nres as f32).log2() * -1.0 + 12.0;
+
+            if result.is_none() {
+                let mut node_set = HashMap::new();
+                node_set.insert(edge.0, 1);
+                node_set.insert(edge.1, 1);
+                let mut edge_set = HashMap::new();
+                edge_set.insert(edge, 1);
+                let exact_match_count = if is_exact { 1usize } else { 0usize };
+                let overflow_count = 0usize;
+                let total_match_count = 1usize;
+                let mut query_result = QueryResult::new(
+                    id, nid, total_match_count, 2, 1, exact_match_count,
+                    overflow_count, idf + nres_norm, nres, plddt
+                );
+                query_result.node_set = node_set;
+                query_result.edge_set = edge_set;
+                query_count_map.insert(nid, query_result);
+            } else {
+                let result = result.unwrap();
+                if result.node_set.contains_key(&edge.0) {
+                    let count = result.node_set.get_mut(&edge.0).unwrap();
+                    *count += 1;
+                } else {
+                    result.node_set.insert(edge.0, 1);
+                    result.node_count += 1;
+                }
+                if result.node_set.contains_key(&edge.1) {
+                    let count = result.node_set.get_mut(&edge.1).unwrap();
+                    *count += 1;
+                } else {
+                    result.node_set.insert(edge.1, 1);
+                    result.node_count += 1;
+                }
+                let is_overflow = result.edge_set.contains_key(&edge);
+                if is_overflow {
+                    let count = result.edge_set.get_mut(&edge).unwrap();
+                    *count += 1;
+                    result.overflow_count += 1;
+                } else {
+                    result.edge_set.insert(edge, 1);
+                    result.edge_count += 1;
+                }
+                result.total_match_count += 1;
+                result.exact_match_count += if is_exact { 1usize } else { 0usize };
+                result.idf += idf + nres_norm;
+            }
+        }
+    }
+    query_count_map
+}
+
diff --git a/src/controller/rank.rs b/src/controller/rank.rs
index 2e39c57..2f14499 100644
--- a/src/controller/rank.rs
+++ b/src/controller/rank.rs
@@ -1,17 +1,18 @@
 // Functions for ranking queried results
 
-use std::collections::HashMap;
+use dashmap::DashMap;
+use rayon::prelude::*; // Import rayon for parallel iterators
+
+use std::collections::{HashMap, HashSet};
 use std::fmt;
 use crate::index::indextable::FolddiscoIndex;
 use crate::prelude::GeometricHash;
 
-
 use super::io::get_values_with_offset_u16;
 use super::map::SimpleHashMap;
 
-#[derive(Clone)]
-pub struct QueryResult {
-    pub id: String,
+pub struct QueryResult<'a> {
+    pub id: &'a str,
     pub nid: usize,
     pub total_match_count: usize,
     pub node_count: usize,
@@ -21,18 +22,23 @@ pub struct QueryResult {
     pub idf: f32,
     pub nres: usize,
     pub plddt: f32,
-    pub node_set: HashMap<usize, usize>,
-    pub edge_set: HashMap<(usize, usize), usize>,
-    pub pos_set: HashMap<(u16, u16), usize>,
+    pub node_set: HashSet<usize>,
+    pub edge_set: HashSet<(usize, usize)>,
     pub matching_residues: Vec<(String, f32)>,
     pub matching_residues_processed: Vec<(String, f32)>,
 }
 
-impl QueryResult {
+impl<'a> QueryResult<'a> {
     pub fn new(
-        id: String, nid: usize, total_match_count: usize, node_count: usize, edge_count: usize,
-        exact_match_count: usize, overflow_count: usize, idf: f32, nres: usize, plddt: f32
+        id: &'a str, nid: usize, total_match_count: usize, node_count: usize, edge_count: usize,
+        exact_match_count: usize, overflow_count: usize, idf: f32, nres: usize, plddt: f32,
+        edge: &(usize, usize),
     ) -> Self {
+        let mut node_set = HashSet::new();
+        node_set.insert(edge.0);
+        node_set.insert(edge.1);
+        let mut edge_set = HashSet::new();
+        edge_set.insert(*edge);
         Self {
             id,
             nid,
@@ -44,16 +50,15 @@ impl QueryResult {
             idf,
             nres,
             plddt,
-            node_set: HashMap::new(),
-            edge_set: HashMap::new(),
-            pos_set: HashMap::new(),
+            node_set: node_set,
+            edge_set: edge_set,
             matching_residues: Vec::new(),
             matching_residues_processed: Vec::new(),
         }
     }
 }
 
-impl fmt::Display for QueryResult {
+impl<'a> fmt::Display for QueryResult<'a> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         let matching_residues_with_score = if self.matching_residues.len() == 0 {
             "NA".to_string()
@@ -82,7 +87,7 @@ impl fmt::Display for QueryResult {
     }
 }
 
-impl fmt::Debug for QueryResult {
+impl<'a> fmt::Debug for QueryResult<'a> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         let matching_residues_with_score = if self.matching_residues.len() == 0 {
             "NA".to_string()
@@ -111,7 +116,7 @@ impl fmt::Debug for QueryResult {
     }
 }
 // write_fmt
-impl QueryResult {
+impl<'a> QueryResult<'a> {
     pub fn write_fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         let matching_residues_with_score = if self.matching_residues.len() == 0 {
             "NA".to_string()
@@ -140,159 +145,131 @@ impl QueryResult {
     }
 }
 
-
-pub fn count_query_idmode(
+pub fn count_query_idmode<'a>(
     queries: &Vec<GeometricHash>, query_map: &HashMap<GeometricHash, ((usize, usize), bool)>,
     offset_table: &SimpleHashMap, value_vec: &[u16],
-    lookup: &(Vec<String>, Vec<usize>, Vec<usize>, Vec<f32>)
-) -> HashMap<usize, QueryResult> {
-    let mut query_count_map = HashMap::new();
-    for (_i, query) in queries.iter().enumerate() {
-        let offset = offset_table.get(query);
-        if offset.is_none() {
-            continue;
-        }
-        let offset = offset.unwrap();
-        let single_queried_values = get_values_with_offset_u16(value_vec, offset.0, offset.1);
-        let edge_info = query_map.get(query).unwrap();
-        let is_exact = edge_info.1;
-        let edge = edge_info.0;
-        let hash_count = offset.1;
-        for j in 0..single_queried_values.len() {
-            let id = lookup.0[single_queried_values[j] as usize].clone();
-            let nid = lookup.1[single_queried_values[j] as usize];
-            let nres = lookup.2[single_queried_values[j] as usize];
-            let plddt = lookup.3[single_queried_values[j] as usize];
-
-            let result = query_count_map.get_mut(&nid);
-            let idf = (lookup.0.len() as f32 / hash_count as f32).log2();
-            let nres_norm = (nres as f32).log2() * -1.0 + 12.0;
-
-            if result.is_none() {
-                let mut node_set = HashMap::new();
-                node_set.insert(edge.0, 1);
-                node_set.insert(edge.1, 1);
-                let mut edge_set = HashMap::new();
-                edge_set.insert(edge, 1);
-                let exact_match_count = if is_exact { 1usize } else { 0usize };
-                let overflow_count = 0usize;
-                let total_match_count = 1usize;
-                let mut query_result = QueryResult::new(
-                    id, nid, total_match_count, 2, 1, exact_match_count,
-                    overflow_count, idf + nres_norm, nres, plddt
-                );
-                query_result.node_set = node_set;
-                query_result.edge_set = edge_set;
-                query_count_map.insert(nid, query_result);
-            } else {
-                let result = result.unwrap();
-                if result.node_set.contains_key(&edge.0) {
-                    let count = result.node_set.get_mut(&edge.0).unwrap();
-                    *count += 1;
-                } else {
-                    result.node_set.insert(edge.0, 1);
-                    result.node_count += 1;
-                }
-                if result.node_set.contains_key(&edge.1) {
-                    let count = result.node_set.get_mut(&edge.1).unwrap();
-                    *count += 1;
-                } else {
-                    result.node_set.insert(edge.1, 1);
-                    result.node_count += 1;
-                }
-                let is_overflow = result.edge_set.contains_key(&edge);
-                if is_overflow {
-                    let count = result.edge_set.get_mut(&edge).unwrap();
-                    *count += 1;
-                    result.overflow_count += 1;
-                } else {
-                    result.edge_set.insert(edge, 1);
-                    result.edge_count += 1;
+    lookup: &'a (Vec<String>, Vec<usize>, Vec<usize>, Vec<f32>)
+) -> DashMap<usize, QueryResult<'a>> {
+    let query_count_map = DashMap::new(); // Use DashMap instead of HashMap
+
+    queries.par_iter().for_each(|query| { // Use parallel iterator
+        if let Some(offset) = offset_table.get(query) {
+            let single_queried_values = get_values_with_offset_u16(value_vec, offset.0, offset.1);
+            let edge_info = query_map.get(query).unwrap();
+            let is_exact = edge_info.1;
+            let edge = edge_info.0;
+            let hash_count = offset.1;
+
+            for &value in single_queried_values.iter() {
+                let id = &lookup.0[value as usize];
+                let nid = lookup.1[value as usize];
+                let nres = lookup.2[value as usize];
+                let plddt = lookup.3[value as usize];
+
+                let idf = (lookup.0.len() as f32 / hash_count as f32).log2();
+                let nres_norm = (nres as f32).log2() * -1.0 + 12.0;
+                let mut is_new: bool = false;
+                let entry = query_count_map.entry(nid);
+                // Not consuming the entry, so we can modify it
+                let mut ref_mut = entry.or_insert_with(|| {
+                    let exact_match_count = if is_exact { 1usize } else { 0usize };
+                    let overflow_count = 0usize;
+                    let total_match_count = 1usize;
+                    is_new = true;
+                    QueryResult::new(
+                        id, nid, total_match_count, 2, 1, exact_match_count,
+                        overflow_count, idf + nres_norm, nres, plddt, &edge
+                    )
+                });
+
+                // Modify with ref_mut
+                let result = ref_mut.value_mut();
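+                // NOTE: DashMap's RefMut keeps the shard locked while `result`
+                // is alive, so these updates are race-free under par_iter.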
+                if !is_new {
+                    // Now modify the `result` directly
+                    // Check if node_set has edge.0 and edge.1
+                    result.node_set.insert(edge.0);
+                    result.node_set.insert(edge.1);
+                    result.node_count = result.node_set.len();
+                    let has_edge = result.edge_set.contains(&edge);
+                    if !has_edge {
+                        result.edge_set.insert(edge);
+                        result.edge_count += 1;
+                    } else {
+                        result.overflow_count += 1;
+                    }
+                    result.total_match_count += 1;
+                    result.exact_match_count += if is_exact { 1usize } else { 0usize };
+                    result.idf += idf + nres_norm;
                 }
             }
-            result.total_match_count += 1;
-            result.exact_match_count += if is_exact { 1usize } else { 0usize };
-            result.idf += idf + nres_norm;
-        }
-    }
+    });
+
     query_count_map
 }
 
-pub fn count_query_bigmode(
+pub fn count_query_bigmode<'a>(
     queries: &Vec<GeometricHash>, query_map: &HashMap<GeometricHash, ((usize, usize), bool)>,
     big_index: &FolddiscoIndex,
-    lookup: &(Vec<String>, Vec<usize>, Vec<usize>, Vec<f32>)
-) -> HashMap<usize, QueryResult> {
-    let mut query_count_map = HashMap::new();
-    for (_i, query) in queries.iter().enumerate() {
-
+    lookup: &'a (Vec<String>, Vec<usize>, Vec<usize>, Vec<f32>)
+) -> DashMap<usize, QueryResult<'a>> {
+    let query_count_map = DashMap::new(); // Use DashMap instead of HashMap
+
+    queries.par_iter().for_each(|query| { // Use parallel iterator
         let single_queried_values = big_index.get_entries(query.as_u32());
         let edge_info = query_map.get(query).unwrap();
         let is_exact = edge_info.1;
         let edge = edge_info.0;
         let hash_count = single_queried_values.len();
-        for j in 0..single_queried_values.len() {
-            if single_queried_values[j] >= lookup.0.len() {
-                println!("Error: {} >= {}", single_queried_values[j], lookup.0.len());
+
+        for &value in single_queried_values.iter() {
+            if value >= lookup.0.len() {
+                println!("Error: {} >= {}", value, lookup.0.len());
                 println!("Error query: {:?}", query);
+                continue;
             }
-            let id = lookup.0[single_queried_values[j]].clone();
-            let nid = lookup.1[single_queried_values[j]];
-            let nres = lookup.2[single_queried_values[j]];
-            let plddt = lookup.3[single_queried_values[j]];
-
-            let result = query_count_map.get_mut(&nid);
+            let id = &lookup.0[value];
+            let nid = lookup.1[value];
+            let nres = lookup.2[value];
+            let plddt = lookup.3[value];
+
             let idf = (lookup.0.len() as f32 / hash_count as f32).log2();
             let nres_norm = (nres as f32).log2() * -1.0 + 12.0;
-
-            if result.is_none() {
-                let mut node_set = HashMap::new();
-                node_set.insert(edge.0, 1);
-                node_set.insert(edge.1, 1);
-                let mut edge_set = HashMap::new();
-                edge_set.insert(edge, 1);
+            let mut is_new: bool = false;
+            let entry = query_count_map.entry(nid);
+            // Not consuming the entry, so we can modify it
+            let mut ref_mut = entry.or_insert_with(|| {
                 let exact_match_count = if is_exact { 1usize } else { 0usize };
                 let overflow_count = 0usize;
                 let total_match_count = 1usize;
-                let mut query_result = QueryResult::new(
+                is_new = true;
+                QueryResult::new(
                     id, nid, total_match_count, 2, 1, exact_match_count,
-                    overflow_count, idf + nres_norm, nres, plddt
-                );
-                query_result.node_set = node_set;
-                query_result.edge_set = edge_set;
-                query_count_map.insert(nid, query_result);
-            } else {
-                let result = result.unwrap();
-                if result.node_set.contains_key(&edge.0) {
-                    let count = result.node_set.get_mut(&edge.0).unwrap();
-                    *count += 1;
-                } else {
-                    result.node_set.insert(edge.0, 1);
-                    result.node_count += 1;
-                }
-                if result.node_set.contains_key(&edge.1) {
-                    let count = result.node_set.get_mut(&edge.1).unwrap();
-                    *count += 1;
+                    overflow_count, idf + nres_norm, nres, plddt, &edge
+                )
+            });
+
+            // Modify with ref_mut
+            let result = ref_mut.value_mut();
+            if !is_new {
+                // Now modify the `result` directly
+                // Check if node_set has edge.0 and edge.1
+                result.node_set.insert(edge.0);
+                result.node_set.insert(edge.1);
+                result.node_count = result.node_set.len();
+                let has_edge = result.edge_set.contains(&edge);
+                if !has_edge {
+                    result.edge_set.insert(edge);
+                    result.edge_count += 1;
                 } else {
-                    result.node_set.insert(edge.1, 1);
-                    result.node_count += 1;
-                }
-                let is_overflow = result.edge_set.contains_key(&edge);
-                if is_overflow {
-                    let count = result.edge_set.get_mut(&edge).unwrap();
-                    *count += 1;
                     result.overflow_count += 1;
-                } else {
-                    result.edge_set.insert(edge, 1);
-                    result.edge_count += 1;
                 }
                 result.total_match_count += 1;
                 result.exact_match_count += if is_exact { 1usize } else { 0usize };
                 result.idf += idf + nres_norm;
             }
         }
-    }
+    });
     query_count_map
 }
-
diff --git a/src/index/indextable.rs b/src/index/indextable.rs
index c557eda..b058655 100644
--- a/src/index/indextable.rs
+++ b/src/index/indextable.rs
@@ -37,7 +37,7 @@ impl FolddiscoIndex {
         }
     }
 
-    #[inline(always)]
+
     pub fn count_single_entry(&self, hash: u32, id: usize) {
         let last_id = unsafe { &mut *self.last_id.get() };
         // let atomic_offsets = unsafe { &mut *self.atomic_offsets.get() };
@@ -121,7 +121,7 @@ impl FolddiscoIndex {
         }
     }
 
-    #[inline(always)]
+
    pub fn add_single_entry(&self, hash: u32, id: usize, bit_container: &mut Vec) {
         let last_id = unsafe { &mut *self.last_id.get() };
         // let atomic_offsets = unsafe { &mut *self.atomic_offsets.get() };
diff --git a/src/utils/combination.rs b/src/utils/combination.rs
index 5e44046..51152de 100644
--- a/src/utils/combination.rs
+++ b/src/utils/combination.rs
@@ -1,3 +1,4 @@
+
 pub struct CombinationIterator {
     n: usize,
     i: usize,
@@ -65,6 +66,7 @@ mod tests {
             println!("{} {}", i, j);
         });
     }
+
 }
 
 #[derive(Hash, PartialEq, Eq)]
diff --git a/src/utils/convert.rs b/src/utils/convert.rs
index ed8acd5..8571a28 100644
--- a/src/utils/convert.rs
+++ b/src/utils/convert.rs
@@ -1,7 +1,3 @@
-
-use std::collections::HashMap;
-use lazy_static::lazy_static;
-
 // Constants
 // 1. for cb_dist
 pub const MIN_DIST: f32 = 2.0;