diff --git a/Cargo.toml b/Cargo.toml index 44dd349..0719566 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,3 +31,8 @@ default = ["python-support"] [[bench]] name = "all_windows" harness = false + + +[[bench]] +name = "expansions" +harness = false diff --git a/benches/expansions.rs b/benches/expansions.rs new file mode 100644 index 0000000..e010ebb --- /dev/null +++ b/benches/expansions.rs @@ -0,0 +1,153 @@ +// Copyright 2021-2023 SecureDNA Stiftung (SecureDNA Foundation) +// SPDX-License-Identifier: MIT OR Apache-2.0 + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use quickdna::{ + expansions::Expansions, BaseSequence, DnaSequenceStrict, Nucleotide, NucleotideAmbiguous, +}; +use rand::{rngs::OsRng, seq::SliceRandom}; + +// Vec-based expansions code taken from DNA screening tools + +const MAX_EXPANSIONS_FOR_WINDOW: usize = 16; + +fn vec_based_expansions( + wildcards: [NucleotideAmbiguous; WINDOW_SIZE], +) -> Vec<[Nucleotide; WINDOW_SIZE]> { + if !are_expansions_feasible(wildcards) { + panic!( + "Too many possible expansions of window {:?}; ignoring window", + wildcards + ); + } + + let mut expanded_sequence_buffers: Vec = + vec![DnaSequenceStrict::new(vec![])]; + + for wildcard in wildcards { + let expansions = wildcard.possibilities(); + if expansions.len() == 1 { + for sequence_buffer in expanded_sequence_buffers.iter_mut() { + sequence_buffer.push(expansions[0]) + } + } else { + let mut new_buffers = Vec::new(); + for nucleotide_from_expansion in expansions.iter() { + for existing_sequence_buffer in expanded_sequence_buffers.iter() { + let mut new_sequence_bufer = existing_sequence_buffer.clone(); + new_sequence_bufer.push(*nucleotide_from_expansion); + new_buffers.push(new_sequence_bufer); + } + } + expanded_sequence_buffers = new_buffers; + } + } + + expanded_sequence_buffers + .into_iter() + .map(|s| s.as_slice().try_into().unwrap()) + .collect() +} + +/// each expansion creates n variants +/// for example AAR becomes [AAA, AAG] +/// these expansions are exponential +/// AARR becomes [AAAA, AAAG, AAGA, AAGG] and so on +/// we do not want to generate more than a limited number of these +fn are_expansions_feasible(ws: impl IntoIterator) -> bool { + let mut acc = 1; + for w in ws { + acc *= w.possibilities().len(); + if acc > MAX_EXPANSIONS_FOR_WINDOW { + return false; + } + } + acc <= MAX_EXPANSIONS_FOR_WINDOW +} + +fn semi_ambiguous_dna(dna_len: usize, num_ambiguities: usize) -> Vec { + let nucleotides = Nucleotide::ALL.map(|nuc| nuc.into()); + let ambiguous_nucleotides = [ + NucleotideAmbiguous::W, + NucleotideAmbiguous::M, + NucleotideAmbiguous::Y, + NucleotideAmbiguous::H, + NucleotideAmbiguous::R, + NucleotideAmbiguous::K, + NucleotideAmbiguous::D, + NucleotideAmbiguous::S, + NucleotideAmbiguous::V, + NucleotideAmbiguous::B, + NucleotideAmbiguous::N, + ]; + let mut dna: Vec<_> = (0..dna_len) + .map(|_| *nucleotides.choose(&mut OsRng).unwrap()) + .collect(); + for i in rand::seq::index::sample(&mut OsRng, dna_len, num_ambiguities) { + dna[i] = *ambiguous_nucleotides.choose(&mut OsRng).unwrap(); + } + dna +} + +pub fn criterion_benchmark(c: &mut Criterion) { + let num_windows = 1000; + let num_ambiguities = 2; + const WINDOW_LEN: usize = 42; + let windows: Vec<[_; WINDOW_LEN]> = (0..num_windows) + .map(|_| { + semi_ambiguous_dna(WINDOW_LEN, num_ambiguities) + .try_into() + .unwrap() + }) + .collect(); + + let num_windows_desc = format!("{num_windows} windows"); + + let mut group = c.benchmark_group("Expansion window content"); + group.throughput(Throughput::Elements(num_windows)); + group.bench_with_input( + BenchmarkId::new("vec-based", &num_windows_desc), + &windows, + |b, windows| { + b.iter(|| { + for window in windows { + for expansion in vec_based_expansions(*window) { + black_box(expansion); + } + } + }) + }, + ); + group.bench_with_input( + BenchmarkId::new("iter-based", &num_windows_desc), + &windows, + |b, windows| { + b.iter(|| { + for window in windows { + for expansion in Expansions::new(window) { + black_box(expansion); + } + } + }) + }, + ); + group.finish(); + + let mut group = c.benchmark_group("Expansion window size-hint"); + group.throughput(Throughput::Elements(num_windows)); + group.bench_with_input( + BenchmarkId::new("iter-based", &num_windows_desc), + &windows, + |b, windows| { + b.iter(|| { + for window in windows { + black_box(Expansions::new(window).size_hint()); + } + }) + }, + ); + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/src/expansions.rs b/src/expansions.rs new file mode 100644 index 0000000..27f76e3 --- /dev/null +++ b/src/expansions.rs @@ -0,0 +1,288 @@ +// Copyright 2021-2023 SecureDNA Stiftung (SecureDNA Foundation) +// SPDX-License-Identifier: MIT OR Apache-2.0 + +use std::sync::Arc; + +use smallvec::SmallVec; + +use crate::{ + BaseSequence, DnaSequenceAmbiguous, DnaSequenceStrict, Nucleotide, NucleotideAmbiguous, + NucleotideLike, +}; + +/// Iterator of all unambiguous expansions of ambiguous DNA. +/// +/// Expansions are returned in lexicographic order based on the ordering of [`Nucleotide`] +/// (not currently alphabetical). +#[derive(Clone)] +pub struct Expansions { + ambiguities: SmallVec<[Ambiguity; 8]>, + buf: Arc<[Nucleotide]>, + began: bool, +} + +#[derive(Clone)] +struct Ambiguity { + index: usize, + digit: u8, + nucleotide: NucleotideAmbiguous, +} + +/// Unambiguous DNA expansion produced by [`Expansions`] iterator. +#[derive(Clone, PartialEq, Eq)] +pub struct Expansion(Arc<[Nucleotide]>); + +impl Expansions { + // Construct new [`Expansions`] iterator + pub fn new(dna: &[NucleotideAmbiguous]) -> Self { + let ambiguities = dna + .iter() + .enumerate() + .filter(|(_, nuc)| nuc.bits().count_ones() > 1) + .map(|(index, &nucleotide)| Ambiguity { + index, + digit: 0, + nucleotide, + }) + .collect(); + let buf: SmallVec<[_; 64]> = dna + .iter() + .map(|nuc| *nuc.possibilities().first().unwrap()) + .collect(); + let buf = Arc::from(buf.as_slice()); + Expansions { + ambiguities, + buf, + began: false, + } + } +} + +impl Iterator for Expansions { + type Item = Expansion; + + fn next(&mut self) -> Option { + if !self.began { + self.began = true; + return Some(Expansion(self.buf.clone())); + } + + // Can't use Arc::make_mut because [T] is unsized + let buf = match Arc::get_mut(&mut self.buf) { + Some(buf) => buf, + None => { + self.buf = Arc::from(&*self.buf); + Arc::get_mut(&mut self.buf).unwrap() + } + }; + + for amb in self.ambiguities.iter_mut().rev() { + amb.digit = (amb.digit + 1) % (amb.nucleotide.bits().count_ones() as u8); + buf[amb.index] = amb.nucleotide.possibilities()[amb.digit as usize]; + if amb.digit > 0 { + return Some(Expansion(self.buf.clone())); + } + } + None + } + + fn size_hint(&self) -> (usize, Option) { + let size = (|| { + let mut size: usize = 0; + for amb in &self.ambiguities { + let num_digit_states = amb.nucleotide.bits().count_ones() as usize; + let remaining = num_digit_states - (amb.digit as usize) - 1; + size = size.checked_mul(num_digit_states)?.checked_add(remaining)?; + } + size.checked_add(!self.began as usize) + })(); + (size.unwrap_or(usize::MAX), size) + } +} + +impl std::fmt::Debug for Expansions { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let dna = self.buf.iter().map(|&nuc| nuc.into()).collect(); + let mut dna = DnaSequenceAmbiguous::new(dna); + for amb in &self.ambiguities { + dna[amb.index] = amb.nucleotide; + } + f.debug_tuple("Expansions").field(&dna.to_string()).finish() + } +} + +impl Expansion { + /// Produces a [`DnaSequenceStrict`] from [`Expansion`]. + /// + /// This takes *O*(*N*) time. + pub fn to_dna(&self) -> DnaSequenceStrict { + DnaSequenceStrict::new(self.to_vec()) + } +} + +impl From for DnaSequenceStrict { + fn from(expansion: Expansion) -> Self { + expansion.to_dna() + } +} + +impl std::ops::Deref for Expansion { + type Target = [Nucleotide]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl AsRef<[Nucleotide]> for Expansion { + fn as_ref(&self) -> &[Nucleotide] { + self + } +} + +impl std::cmp::PartialEq<&[Nucleotide]> for Expansion { + fn eq(&self, other: &&[Nucleotide]) -> bool { + self.as_ref() == *other + } +} + +impl std::cmp::PartialEq for Expansion { + fn eq(&self, other: &DnaSequenceStrict) -> bool { + self.as_ref() == other.as_slice() + } +} + +impl std::fmt::Debug for Expansion { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_tuple("Expansion") + .field(&self.to_dna().to_string()) + .finish() + } +} + +#[cfg(test)] +mod test { + use super::*; + + use crate::{BaseSequence, DnaSequenceAmbiguous}; + + fn amb_dna(dna: &str) -> DnaSequenceAmbiguous { + dna.parse().unwrap() + } + + fn dna(dna: &str) -> DnaSequenceStrict { + dna.parse().unwrap() + } + + #[test] + fn basic_end_to_end() { + let src_dna = amb_dna("ATBCGYAC"); // AT{TCG}CG{TC}AC + let mut expansions = src_dna.expansions(); + assert_eq!(expansions.size_hint(), (6, Some(6))); + assert_eq!(expansions.next().unwrap(), dna("ATTCGTAC")); + assert_eq!(expansions.size_hint(), (5, Some(5))); + assert_eq!(expansions.next().unwrap(), dna("ATTCGCAC")); + assert_eq!(expansions.size_hint(), (4, Some(4))); + assert_eq!(expansions.next().unwrap(), dna("ATCCGTAC")); + assert_eq!(expansions.size_hint(), (3, Some(3))); + assert_eq!(expansions.next().unwrap(), dna("ATCCGCAC")); + assert_eq!(expansions.size_hint(), (2, Some(2))); + assert_eq!(expansions.next().unwrap(), dna("ATGCGTAC")); + assert_eq!(expansions.size_hint(), (1, Some(1))); + assert_eq!(expansions.next().unwrap(), dna("ATGCGCAC")); + assert_eq!(expansions.size_hint(), (0, Some(0))); + assert!(expansions.next().is_none()); + } + + #[test] + fn verify_basic_counting() { + let actual: Vec<_> = amb_dna("WWW").expansions().collect(); + let expected = ["AAA", "AAT", "ATA", "ATT", "TAA", "TAT", "TTA", "TTT"].map(dna); + assert_eq!(actual, expected); + } + + #[test] + fn verify_size_hints() { + let testcases = [ + ("", 1), + ("A", 1), + ("ATCG", 1), + ("W", 2), + ("B", 3), + ("N", 4), + ("MR", 4), + ("YV", 6), + ("DS", 6), + ("KN", 8), + ("NW", 8), + ("MRY", 8), + ("HB", 9), + ("ACGTYTAGCNCGATVGCTA", 24), + ]; + for (src_dna, mut expected_len) in testcases { + let mut iter = amb_dna(src_dna).expansions(); + assert_eq!( + iter.size_hint(), + (expected_len, Some(expected_len)), + "Wrong initial size_hint for {src_dna:?}" + ); + while iter.next().is_some() { + assert!(expected_len > 0, "Iter too long for {src_dna:?}"); + expected_len -= 1; + assert_eq!( + iter.size_hint(), + (expected_len, Some(expected_len)), + "Wrong size_hint during iteration of {src_dna:?}" + ); + } + assert_eq!(expected_len, 0, "Iterator too short for {src_dna:?}"); + } + } + + #[test] + fn unambiguous_sequences_have_single_element() { + let src_dna = amb_dna("ATCGATATCGCGAATTCCGG"); + let mut expansions = src_dna.expansions(); + assert_eq!(expansions.size_hint(), (1, Some(1))); + assert_eq!(expansions.next().unwrap(), dna("ATCGATATCGCGAATTCCGG")); + assert_eq!(expansions.size_hint(), (0, Some(0))); + assert!(expansions.next().is_none()); + } + + #[test] + fn size_hint_handles_overflow() { + let src_dna = amb_dna("NNNNNNNNNNNNNNNNNNBBBBBBBBBBBBBBBBBY"); + assert_eq!( + src_dna.expansions().size_hint(), + (17748888853923495936, Some(17748888853923495936)) + ); + let src_dna = amb_dna("NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN"); + assert_eq!(src_dna.expansions().size_hint(), (usize::MAX, None)); + } + + #[test] + fn works_even_if_expansions_are_held() { + let src_dna = amb_dna("ATCGNGCTA"); + let expansions: Vec<_> = src_dna.expansions().collect(); + let expected = [ + dna("ATCGAGCTA"), + dna("ATCGTGCTA"), + dna("ATCGCGCTA"), + dna("ATCGGGCTA"), + ]; + assert_eq!(expansions, expected); + } + + #[test] + fn debug_expansion() { + let expansion = Expansion(Arc::from(dna("ATCG").as_slice())); + assert_eq!(format!("{expansion:?}"), "Expansion(\"ATCG\")"); + } + + #[test] + fn debug_expansions() { + let src_dna = amb_dna("ATNCG"); + let expansions = Expansions::new(src_dna.as_slice()); + assert_eq!(format!("{expansions:?}"), "Expansions(\"ATNCG\")"); + } +} diff --git a/src/lib.rs b/src/lib.rs index e39c70a..cacf280 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,8 @@ pub mod trans_table; // needs to be public for bin/gen_table mod extendable; pub use extendable::*; +pub mod expansions; + mod fasta; pub use fasta::*; diff --git a/src/rust_api.rs b/src/rust_api.rs index 92386ec..44b1472 100644 --- a/src/rust_api.rs +++ b/src/rust_api.rs @@ -14,6 +14,7 @@ pub use crate::nucleotide::{ pub use crate::trans_table::TranslationTable; use crate::Extendable; +use crate::expansions::Expansions; use crate::trans_table::reverse_complement; #[cfg(feature = "serde")] @@ -337,6 +338,16 @@ impl FromStr for DnaSequence { } } +impl DnaSequence { + /// Return all unambiguous expansions. + /// + /// Expansions are returned in lexicographic order based on the ordering of [`Nucleotide`] + /// (not currently alphabetical). + pub fn expansions(&self) -> Expansions { + Expansions::new(self.as_slice()) + } +} + #[cfg(test)] mod tests { use super::*;