From 9bb43af07176ddf11a3a07f3d7ef4fe1ff788551 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ivan=20Sch=C3=BCtz?= Date: Wed, 22 Jul 2020 15:14:25 +0200 Subject: [PATCH] Contact duration and distance (#123) * Compile Android specifics on desktop build too * Add duration and distance to DB, add FFI distance parameter * Revert "Compile Android specifics on desktop build too" Causing problems with CI, reason unclear as it's supposed to also use MacOs This reverts commit b48a649625b30908579d57c10db852d8ec3cb32c. * Clear db before each instrumentation test * Exposure grouping * Batched updates * Cleanup * Fix timer not firing * Set actual time * Regenerate iOS headers * Adjust JNI * Set timer period to 10s * Implement TCNs IN query * Separate exposures in DB, merge TCNs in batch * Remove not needed Arc * Add comments * Fix transaction function error doesn't make operation fail * Add average distance * Remove outdated function * Replace iterators with loop Should be better for performance * Add avg distance to Android interface * Release mutex lock when done accessing tcns * Remove unnecessary lock * More idiomatic way to release the lock * Improve logs --- Cargo.toml | 3 +- android/core/core/build.gradle | 6 + .../core/JNIInterfaceBootstrappedTests.kt | 2 +- .../java/org/coepi/core/JNIInterfaceTests.kt | 6 +- .../java/org/coepi/core/domain/model/Alert.kt | 5 +- .../main/java/org/coepi/core/jni/JniApi.kt | 7 +- .../org/coepi/core/services/AlertsFetcher.kt | 18 +- .../core/services/ObservedTcnsRecorder.kt | 6 +- src/android/android_interface.rs | 34 +- src/android/jni_domain_tests.rs | 5 +- src/composition_root.rs | 13 +- src/errors.rs | 13 +- src/ios/c_headers/coepicore.h | 80 +- src/ios/ios_interface.rs | 8 +- src/preferences.rs | 36 +- src/reports_updater.rs | 1333 +++++++++++++++-- 16 files changed, 1418 insertions(+), 157 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 25d9052..97d7502 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,8 @@ rand = "0.7" hex = "0.4.2" serde-big-array = "0.3.0" rayon = "1.1" -rusqlite = {version = "0.23.1", features = ["bundled"]} +rusqlite = {version = "0.23.1", features = ["bundled", "vtab", "array"]} +timer = "0.2.0" [dependencies.reqwest] default-features = false # do not include the default features, and optionally diff --git a/android/core/core/build.gradle b/android/core/core/build.gradle index 3f71a01..dd34000 100644 --- a/android/core/core/build.gradle +++ b/android/core/core/build.gradle @@ -16,9 +16,14 @@ android { versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + testInstrumentationRunnerArguments clearPackageData: 'true' consumerProguardFiles 'consumer-rules.pro' } + testOptions { + execution 'androidx_test_orchestrator' + } + buildTypes { release { minifyEnabled false @@ -42,6 +47,7 @@ dependencies { testImplementation 'junit:junit:4.12' androidTestImplementation 'androidx.test.ext:junit:1.1.1' androidTestImplementation 'androidx.test.espresso:espresso-core:3.2.0' + androidTestUtil 'androidx.test:orchestrator:1.2.0' androidTestImplementation 'org.jetbrains.kotlinx:kotlinx-coroutines-test:1.3.0' diff --git a/android/core/core/src/androidTest/java/org/coepi/core/JNIInterfaceBootstrappedTests.kt b/android/core/core/src/androidTest/java/org/coepi/core/JNIInterfaceBootstrappedTests.kt index 0d92f31..936669f 100644 --- a/android/core/core/src/androidTest/java/org/coepi/core/JNIInterfaceBootstrappedTests.kt +++ b/android/core/core/src/androidTest/java/org/coepi/core/JNIInterfaceBootstrappedTests.kt @@ -68,7 +68,7 @@ class JNIInterfaceBootstrappedTests { @Test fun recordTcn() { - val value = JniApi().recordTcn("2485a64b57addcaea3ed1b538d07dbce") + val value = JniApi().recordTcn("2485a64b57addcaea3ed1b538d07dbce", 34.03f) assertEquals(JniVoidResult(1, ""), value) } diff --git a/android/core/core/src/androidTest/java/org/coepi/core/JNIInterfaceTests.kt b/android/core/core/src/androidTest/java/org/coepi/core/JNIInterfaceTests.kt index 386964f..a7bc551 100644 --- a/android/core/core/src/androidTest/java/org/coepi/core/JNIInterfaceTests.kt +++ b/android/core/core/src/androidTest/java/org/coepi/core/JNIInterfaceTests.kt @@ -48,7 +48,7 @@ class JNIInterfaceTests { runnyNose = true, other = false, noSymptoms = true - ), 1592567315 + ), 1592567315, 1592567335, 1.2f, 2.1f ) ), value @@ -75,7 +75,7 @@ class JNIInterfaceTests { runnyNose = true, other = false, noSymptoms = true - ), 1592567315 + ), 1592567315, 1592567335, 1.2f, 2.1f ), JniAlert( "343356", JniPublicReport( @@ -90,7 +90,7 @@ class JNIInterfaceTests { runnyNose = true, other = false, noSymptoms = true - ), 1592567315 + ), 1592567315, 1592567335, 1.2f, 2.1f ) ) ), diff --git a/android/core/core/src/main/java/org/coepi/core/domain/model/Alert.kt b/android/core/core/src/main/java/org/coepi/core/domain/model/Alert.kt index 3a4d712..edd9db1 100644 --- a/android/core/core/src/main/java/org/coepi/core/domain/model/Alert.kt +++ b/android/core/core/src/main/java/org/coepi/core/domain/model/Alert.kt @@ -17,7 +17,10 @@ data class Alert( val runnyNose: Boolean, val other: Boolean, val noSymptoms: Boolean, // https://github.com/Co-Epi/app-ios/issues/268#issuecomment-645583717 - var contactTime: UnixTime + var contactStart: UnixTime, + var contactEnd: UnixTime, + var minDistance: Float, + var avgDistance: Float ) : Parcelable enum class FeverSeverity { diff --git a/android/core/core/src/main/java/org/coepi/core/jni/JniApi.kt b/android/core/core/src/main/java/org/coepi/core/jni/JniApi.kt index 4e7b35c..d05bbfb 100644 --- a/android/core/core/src/main/java/org/coepi/core/jni/JniApi.kt +++ b/android/core/core/src/main/java/org/coepi/core/jni/JniApi.kt @@ -21,7 +21,7 @@ class JniApi { external fun generateTcn(): String - external fun recordTcn(tcn: String): JniVoidResult + external fun recordTcn(tcn: String, distance: Float): JniVoidResult // TODO test: external fun setBreathlessnessCause(cause: String): JniVoidResult @@ -133,7 +133,10 @@ data class JniAlertsArrayResult( data class JniAlert( var id: String, var report: JniPublicReport, - var contactTime: Long + var contactStart: Long, + var contactEnd: Long, + var minDistance: Float, + var avgDistance: Float ) data class JniPublicReport( diff --git a/android/core/core/src/main/java/org/coepi/core/services/AlertsFetcher.kt b/android/core/core/src/main/java/org/coepi/core/services/AlertsFetcher.kt index 0f8e500..e2d9302 100644 --- a/android/core/core/src/main/java/org/coepi/core/services/AlertsFetcher.kt +++ b/android/core/core/src/main/java/org/coepi/core/services/AlertsFetcher.kt @@ -36,9 +36,21 @@ class AlertsFetcherImpl(private val api: JniApi) : private fun JniAlert.toAlert() = Alert( id = id, - contactTime = when { - contactTime < 0 -> error("Invalid contact time: $contactTime") - else -> UnixTime.fromValue(contactTime) + contactStart = when { + contactStart < 0 -> error("Invalid contact start: $contactStart") + else -> UnixTime.fromValue(contactStart) + }, + contactEnd = when { + contactEnd < 0 -> error("Invalid contact end: $contactEnd") + else -> UnixTime.fromValue(contactEnd) + }, + minDistance = when { + minDistance < 0 -> error("Invalid min distance: $minDistance") + else -> minDistance + }, + avgDistance = when { + avgDistance < 0 -> error("Invalid avg distance: $avgDistance") + else -> avgDistance }, reportTime = when { report.reportTime < 0 -> error("Invalid report time: ${report.reportTime}") diff --git a/android/core/core/src/main/java/org/coepi/core/services/ObservedTcnsRecorder.kt b/android/core/core/src/main/java/org/coepi/core/services/ObservedTcnsRecorder.kt index 2524d94..3518ef5 100644 --- a/android/core/core/src/main/java/org/coepi/core/services/ObservedTcnsRecorder.kt +++ b/android/core/core/src/main/java/org/coepi/core/services/ObservedTcnsRecorder.kt @@ -6,11 +6,11 @@ import org.coepi.core.domain.model.Tcn import org.coepi.core.domain.common.Result interface ObservedTcnsRecorder { - fun recordTcn(tcn: Tcn): Result + fun recordTcn(tcn: Tcn, distance: Float): Result } class ObservedTcnsRecorderImpl(private val api: JniApi) : ObservedTcnsRecorder { - override fun recordTcn(tcn: Tcn): Result = - api.recordTcn(tcn.toHex()).asResult() + override fun recordTcn(tcn: Tcn, distance: Float): Result = + api.recordTcn(tcn.toHex(), distance).asResult() } diff --git a/src/android/android_interface.rs b/src/android/android_interface.rs index a4c2be3..0fff1b3 100644 --- a/src/android/android_interface.rs +++ b/src/android/android_interface.rs @@ -80,8 +80,9 @@ pub unsafe extern "C" fn Java_org_coepi_core_jni_JniApi_recordTcn( env: JNIEnv, _: JClass, tcn: JString, + distance: jfloat, ) -> jobject { - recordTcn(&env, tcn).to_void_jni(&env) + recordTcn(&env, tcn, distance).to_void_jni(&env) } // NOTE: Returns directly success string @@ -262,11 +263,13 @@ fn fetch_new_reports(env: &JNIEnv) -> Result { alerts_to_jobject_array(result, &env) } -fn recordTcn(env: &JNIEnv, tcn: JString) -> Result<(), ServicesError> { +fn recordTcn(env: &JNIEnv, tcn: JString, distance: jfloat) -> Result<(), ServicesError> { let tcn_java_str = env.get_string(tcn)?; let tcn_str = tcn_java_str.to_str()?; - let result = dependencies().observed_tcn_processor.save(tcn_str); + let result = dependencies() + .observed_tcn_processor + .save(tcn_str, distance as f32); info!("Recording TCN result {:?}", result); result @@ -485,8 +488,8 @@ impl LogCallbackWrapper for LogCallbackWrapperImpl { // Note that if we panic, LogCat will also not show a message, or location. // TODO consider writing to file. Otherwise it's impossible to notice this. Err(e) => println!( - "Couldn't get env: Can't send log: level: {}, text: {}", - level, text, + "Couldn't get env: Can't send log: level: {}, text: {}, e: {}", + level, text, e ), } } @@ -566,7 +569,10 @@ fn placeholder_alert() -> Alert { Alert { id: "0".to_owned(), report, - contact_time: 0, + contact_start: 0, + contact_end: 0, + min_distance: 0.0, + avg_distance: 0.0, } } @@ -627,16 +633,22 @@ pub fn alert_to_jobject(alert: Alert, env: &JNIEnv) -> Result = env .new_object( jni_alert_class, - "(Ljava/lang/String;Lorg/coepi/core/jni/JniPublicReport;J)V", + "(Ljava/lang/String;Lorg/coepi/core/jni/JniPublicReport;JJFF)V", &[ id_j_value, JValue::from(jni_public_report_obj), - earliest_time_j_value, + contact_start_j_value, + contact_end_j_value, + min_distance_j_value, + avg_distance_j_value, ], ) .map(|o| o.into_inner()); @@ -682,6 +694,10 @@ impl JniErrorMappable for ServicesError { status: 5, message: msg.to_owned(), }, + ServicesError::NotFound => JniError { + status: 6, + message: "Not found".to_owned(), + }, } } } diff --git a/src/android/jni_domain_tests.rs b/src/android/jni_domain_tests.rs index bfb81ef..d3189b2 100644 --- a/src/android/jni_domain_tests.rs +++ b/src/android/jni_domain_tests.rs @@ -81,6 +81,9 @@ fn create_test_alert(id: &str, report_time: u64) -> Alert { Alert { id: id.to_owned(), report, - contact_time: 1592567315, + contact_start: 1592567315, + contact_end: 1592567335, + min_distance: 1.2, + avg_distance: 2.1, } } diff --git a/src/composition_root.rs b/src/composition_root.rs index 6c08661..e0d5901 100644 --- a/src/composition_root.rs +++ b/src/composition_root.rs @@ -1,7 +1,7 @@ use crate::networking::{TcnApi, TcnApiImpl}; use crate::reports_updater::{ - ObservedTcnProcessor, ObservedTcnProcessorImpl, ReportsUpdater, TcnDao, TcnDaoImpl, TcnMatcher, - TcnMatcherRayon, + ExposureGrouper, ObservedTcnProcessor, ObservedTcnProcessorImpl, ReportsUpdater, + TcnBatchesManager, TcnDao, TcnDaoImpl, TcnMatcher, TcnMatcherRayon, }; use crate::{ errors::ServicesError, @@ -152,6 +152,7 @@ fn create_comp_root( }; let tcn_dao = Arc::new(TcnDaoImpl::new(database.clone())); + let exposure_grouper = ExposureGrouper { threshold: 3600 }; CompositionRoot { api, @@ -161,6 +162,7 @@ fn create_comp_root( tcn_matcher: TcnMatcherRayon {}, api, memo_mapper, + exposure_grouper: exposure_grouper.clone(), }, symptom_inputs_processor: SymptomInputsProcessorImpl { inputs_manager: SymptomInputsManagerImpl { @@ -168,9 +170,10 @@ fn create_comp_root( inputs_submitter: symptom_inputs_submitter, }, }, - observed_tcn_processor: ObservedTcnProcessorImpl { - tcn_dao: tcn_dao.clone(), - }, + observed_tcn_processor: ObservedTcnProcessorImpl::new(TcnBatchesManager::new( + tcn_dao.clone(), + exposure_grouper.clone(), + )), tcn_keys: tcn_keys.clone(), } } diff --git a/src/errors.rs b/src/errors.rs index 0c6c4c7..03af877 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,4 +1,5 @@ use crate::networking::NetworkingError; +use rusqlite::Error::QueryReturnedNoRows; use std::{error, fmt, io::Error as StdError, io::ErrorKind}; use tcn::Error as TcnError; pub type Error = Box; @@ -8,6 +9,7 @@ pub enum ServicesError { Networking(NetworkingError), Error(Error), FFIParameters(String), + NotFound, General(String), } @@ -79,10 +81,13 @@ impl From for ServicesError { impl From for ServicesError { fn from(error: rusqlite::Error) -> Self { - ServicesError::Error(Box::new(StdError::new( - ErrorKind::Other, - format!("{}", error), - ))) + match error { + QueryReturnedNoRows => ServicesError::NotFound, + _ => ServicesError::Error(Box::new(StdError::new( + ErrorKind::Other, + format!("{}", error), + ))), + } } } diff --git a/src/ios/c_headers/coepicore.h b/src/ios/c_headers/coepicore.h index 3f0dfa0..57de0ab 100644 --- a/src/ios/c_headers/coepicore.h +++ b/src/ios/c_headers/coepicore.h @@ -3,96 +3,150 @@ #define TCK_SIZE_IN_BYTES 66 enum CoreLogLevel { - Trace, - Debug, - Info, - Warn, - Error, + Trace = 0, + Debug = 1, + Info = 2, + Warn = 3, + Error = 4, }; typedef uint8_t CoreLogLevel; +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) typedef struct { uint8_t my_u8; } FFINestedReturnStruct; +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) typedef struct { int32_t my_int; CFStringRef my_str; FFINestedReturnStruct my_nested; } FFIReturnStruct; +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) typedef struct { uint8_t my_u8; } FFINestedParameterStruct; +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) typedef struct { int32_t my_int; const char *my_str; FFINestedParameterStruct my_nested; } FFIParameterStruct; +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) typedef struct { CoreLogLevel level; CFStringRef text; int64_t time; } CoreLogMessage; +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef bootstrap_core(const char *db_path, CoreLogLevel level, bool coepi_only); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) int32_t call_callback(void (*callback)(int32_t, bool, CFStringRef)); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef clear_symptoms(void); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef fetch_new_reports(void); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef generate_tcn(void); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) FFIReturnStruct pass_and_return_struct(const FFIParameterStruct *par); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) int32_t pass_struct(const FFIParameterStruct *par); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef post_report(const char *c_report); +#endif -CFStringRef record_tcn(const char *c_tcn); +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) +CFStringRef record_tcn(const char *c_tcn, float distance); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) int32_t register_callback(void (*callback)(int32_t, bool, CFStringRef)); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) int32_t register_log_callback(void (*log_callback)(CoreLogMessage)); - -FFIReturnStruct return_struct(void); - -#if defined(TARGET_OS_ANDROID) -char *rust_greeting(const char *to); #endif -#if defined(TARGET_OS_ANDROID) -char *rust_greeting2(const char *to); +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) +FFIReturnStruct return_struct(void); #endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef set_breathlessness_cause(const char *c_cause); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef set_cough_days(uint8_t c_is_set, uint32_t c_days); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef set_cough_status(const char *c_status); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef set_cough_type(const char *c_cough_type); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef set_earliest_symptom_started_days_ago(uint8_t c_is_set, uint32_t c_days); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef set_fever_days(uint8_t c_is_set, uint32_t c_days); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef set_fever_highest_temperature_taken(uint8_t c_is_set, float c_temp); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef set_fever_taken_temperature_spot(const char *c_cause); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef set_fever_taken_temperature_today(uint8_t c_is_set, uint8_t c_taken); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef set_symptom_ids(const char *c_ids); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) int32_t setup_logger(CoreLogLevel level, bool coepi_only); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) CFStringRef submit_symptoms(void); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) int32_t trigger_callback(const char *my_str); +#endif +#if (defined(TARGET_OS_IOS) || defined(TARGET_OS_MACOS)) int32_t trigger_logging_macros(void); +#endif diff --git a/src/ios/ios_interface.rs b/src/ios/ios_interface.rs index daacb62..80c79da 100644 --- a/src/ios/ios_interface.rs +++ b/src/ios/ios_interface.rs @@ -65,9 +65,13 @@ pub unsafe extern "C" fn fetch_new_reports() -> CFStringRef { } #[no_mangle] -pub unsafe extern "C" fn record_tcn(c_tcn: *const c_char) -> CFStringRef { +pub unsafe extern "C" fn record_tcn(c_tcn: *const c_char, distance: f32) -> CFStringRef { let tcn_str = cstring_to_str(&c_tcn); - let result = tcn_str.and_then(|tcn_str| dependencies().observed_tcn_processor.save(tcn_str)); + let result = tcn_str.and_then(|tcn_str| { + dependencies() + .observed_tcn_processor + .save(tcn_str, distance) + }); info!("Recording TCN result {:?}", result); return to_result_str(result); } diff --git a/src/preferences.rs b/src/preferences.rs index 484c6d2..6661b6d 100644 --- a/src/preferences.rs +++ b/src/preferences.rs @@ -1,6 +1,9 @@ -use crate::{byte_vec_to_32_byte_array, expect_log, reports_interval::ReportsInterval}; +use crate::{ + byte_vec_to_32_byte_array, errors::ServicesError, expect_log, reports_interval::ReportsInterval, +}; use log::*; -use rusqlite::{params, Connection, Row, ToSql}; +use rusqlite::{params, Connection, Result}; +use rusqlite::{Row, ToSql, Transaction}; use serde::{Deserialize, Serialize}; use std::fmt; use std::{ @@ -152,7 +155,36 @@ impl Database { conn.query_row(sql, params, f) } + pub fn transaction(&self, f: F) -> Result<(), ServicesError> + where + F: FnOnce(&Transaction) -> Result<(), ServicesError>, + { + let conn_res = self.conn.lock(); + let mut conn = expect_log!(conn_res, "Couldn't lock connection"); + + let t = conn.transaction()?; + match f(&t) { + Ok(_) => t.commit().map_err(ServicesError::from), + Err(commit_error) => { + let rollback_res = t.rollback(); + if rollback_res.is_err() { + // As we're already returning error status, show only a log for rollback error. + error!( + "There was an error committing and rollback failed too with: {:?}", + rollback_res + ); + } + Err(commit_error) + } + } + } + pub fn new(conn: Connection) -> Database { + let load_array_mod_res = rusqlite::vtab::array::load_module(&conn); + expect_log!( + load_array_mod_res, + "Couldn't load array module (needed for IN query)" + ); Database { conn: Mutex::new(conn), } diff --git a/src/reports_updater.rs b/src/reports_updater.rs index 69df60b..f34e1d2 100644 --- a/src/reports_updater.rs +++ b/src/reports_updater.rs @@ -10,15 +10,20 @@ use crate::{ }, reports_interval, }; -use chrono::Utc; +use exposure::Exposure; use log::*; use rayon::prelude::*; use reports_interval::{ReportsInterval, UnixTime}; -use rusqlite::{params, Row, NO_PARAMS}; +use rusqlite::{params, Row, NO_PARAMS, types::Value}; use serde::Serialize; use std::collections::HashMap; -use std::{io::Cursor, sync::Arc, time::Instant}; +use std::{ + io::Cursor, + sync::{Arc, Mutex}, + time::Instant, rc::Rc, +}; use tcn::{SignedReport, TemporaryContactNumber}; +use timer::{Guard, Timer}; pub trait TcnMatcher { fn match_reports( @@ -31,7 +36,7 @@ pub trait TcnMatcher { #[derive(Debug, Clone)] pub struct MatchedReport { report: SignedReport, - contact_time: UnixTime, + tcns: Vec, } pub struct TcnMatcherRayon {} @@ -76,17 +81,20 @@ impl TcnMatcherRayon { let rep = report.clone().verify(); match rep { Ok(rep) => { - let mut out: Option = None; + let mut tcns: Vec = vec![]; for tcn in rep.temporary_contact_numbers() { - if let Some(entry) = observed_tcns_map.get(&tcn.0) { - out = Some(MatchedReport { - report: report.clone(), - contact_time: entry.time.clone(), - }); - break; + if let Some(observed_tcn) = observed_tcns_map.get(&tcn.0) { + tcns.push(observed_tcn.to_owned()); } } - out + if tcns.is_empty() { + None + } else { + Some(MatchedReport { + report: report.clone(), + tcns, + }) + } } Err(error) => { error!("Report can't be matched. Verification failed: {:?}", error); @@ -96,45 +104,244 @@ impl TcnMatcherRayon { } } -#[derive(Debug, Eq, PartialEq, Clone)] +#[derive(Debug, PartialEq, Clone)] pub struct ObservedTcn { tcn: TemporaryContactNumber, - time: UnixTime, + contact_start: UnixTime, + contact_end: UnixTime, + min_distance: f32, + avg_distance: f32, + total_count: usize // Needed to calculate correctly average of averages (= average of single values) } pub trait ObservedTcnProcessor { - fn save(&self, tcn_str: &str) -> Result<(), ServicesError>; + fn save(&self, tcn_str: &str, distance: f32) -> Result<(), ServicesError>; } -pub struct ObservedTcnProcessorImpl +#[derive(Copy, Clone)] +struct Element {} + +pub struct TcnBatchesManager where T: TcnDao, { - pub tcn_dao: Arc, + tcn_dao: Arc, + tcns_batch: Mutex>, + exposure_grouper: ExposureGrouper, +} + +impl TcnBatchesManager +where + T: 'static + TcnDao, +{ + pub fn new(tcn_dao: Arc, exposure_grouper: ExposureGrouper) -> TcnBatchesManager { + TcnBatchesManager { + tcn_dao, + tcns_batch: Mutex::new(HashMap::new()), + exposure_grouper, + } + } + + pub fn flush(&self) -> Result<(), ServicesError> { + + let tcns = { + let res = self.tcns_batch.lock(); + let mut tcns = expect_log!(res, "Couldn't lock tcns batch"); + let clone = tcns.clone(); + tcns.clear(); + clone + }; + + debug!("Flushing TCN batch into database: {:?}", tcns); + + // Do an in-memory merge with the DB TCNs and overwrite stored exposures with result. + let merged = self.merge_with_db(tcns)?; + self.tcn_dao.overwrite(merged)?; + + Ok(()) + } + + pub fn push(&self, tcn: ObservedTcn) { + let res = self.tcns_batch.lock(); + let mut tcns = expect_log!(res, "Couldn't lock tcns batch"); + + // TCNs in batch are merged to save memory and simplify processing / reduce logs. + let merged_tcn = match tcns.get(&tcn.tcn.0) { + Some(existing_tcn) => match Self::merge_tcns(&self.exposure_grouper, existing_tcn.to_owned(), tcn.clone()) { + Some(merged) => merged, + None => tcn + }, + None => tcn + }; + tcns.insert(merged_tcn.tcn.0, merged_tcn); + + // debug!("Updated TCNs batch: {:?}", tcns); + } + + // Used only in tests + #[allow(dead_code)] + pub fn len(&self) -> Result { + match self.tcns_batch.lock() { + Ok(tcns) => Ok(tcns.len()), + Err(_) => Err(ServicesError::General("Couldn't lock tcns batch".to_owned())) + } + } + + // Retrieves possible existing exposures from DB with same TCNs and does an in-memory merge. + fn merge_with_db(&self, tcns: HashMap<[u8; 16], ObservedTcn>) -> Result, ServicesError> { + let tcns_vec: Vec = tcns.iter().map(|(_, observed_tcn)| observed_tcn.clone()).collect(); + + let mut db_tcns = self + .tcn_dao + .find_tcns(tcns_vec.clone().into_iter().map(|tcn| tcn.tcn).collect())?; + db_tcns.sort_by_key(|tcn| tcn.contact_start.value); + + let db_tcns_map: HashMap<[u8; 16], Vec> = Self::to_hash_map(db_tcns); + + Ok(tcns_vec.into_iter().map(|tcn| + // Values in db_tcns_map can't be empty: we built the map based on existing TCNs + Self::determine_tcns_to_write(self.exposure_grouper.clone(), db_tcns_map.clone(), tcn) + ).flatten().collect()) + } + + // Expects: + // - Values in db_tcns_map not empty + // - db_tcns_map sorted by contact_start (ascending) + fn determine_tcns_to_write(exposure_grouper: ExposureGrouper, db_tcns_map: HashMap<[u8; 16], Vec>, tcn: ObservedTcn) -> Vec { + let db_tcns = db_tcns_map.get(&tcn.tcn.0); + + match db_tcns { + // Matching exposures in DB + Some(db_tcns) => { + if let Some(last) = db_tcns.last() { + + // If contiguous to last DB exposure, merge with it, otherwise append. + let tail = match Self::merge_tcns(&exposure_grouper, last.to_owned(), tcn.clone()) { + Some(merged ) => vec![merged], + None => vec![last.to_owned(), tcn] + }; + let mut head: Vec = db_tcns.to_owned().into_iter().take(db_tcns.len() - 1).collect(); + head.extend(tail); + head + + } else { + error!("Illegal state: value in db_tcns_map is empty"); + panic!(); + } + } + // No matching exposures in DB: insert new TCN + None => vec![tcn] + } + } + + fn to_hash_map(tcns: Vec) -> HashMap<[u8; 16], Vec> { + let mut map: HashMap<[u8; 16], Vec> = HashMap::new(); + for tcn in tcns { + match map.get_mut(&tcn.tcn.0) { + Some(tcns) => tcns.push(tcn), + None => { + map.insert(tcn.tcn.0, vec![tcn]); + } + } + } + map + } + + // Returns a merged TCN, if the TCNs are contiguous, None otherwise. + // Assumes: tcn contact_start after db_tcn contact_start + fn merge_tcns( + exposure_grouper: &ExposureGrouper, + db_tcn: ObservedTcn, + tcn: ObservedTcn, + ) -> Option { + if exposure_grouper.is_contiguous(&db_tcn, &tcn) { + // Put db TCN and new TCN in an exposure as convenience to re-calculate measurements. + let mut exposure = Exposure::create(db_tcn); + exposure.push(tcn.clone()); + let measurements = exposure.measurements(); + Some(ObservedTcn { + tcn: tcn.tcn, + contact_start: measurements.contact_start, + contact_end: measurements.contact_end, + min_distance: measurements.min_distance, + avg_distance: measurements.avg_distance, + total_count: measurements.total_count + }) + } else { + None + } + } +} + +pub struct ObservedTcnProcessorImpl +where + T: 'static + TcnDao, +{ + tcn_batches_manager: Arc>, + _timer_data: TimerData +} + +struct TimerData { + _timer: Arc>, + _guard: Guard +} + +impl ObservedTcnProcessorImpl +where + T: 'static + TcnDao, +{ + pub fn new(tcn_batches_manager: TcnBatchesManager) -> ObservedTcnProcessorImpl { + let tcn_batches_manager = Arc::new(tcn_batches_manager); + let instance = ObservedTcnProcessorImpl { + tcn_batches_manager: tcn_batches_manager.clone(), + _timer_data: Self::schedule_process_batches(tcn_batches_manager) + }; + instance + } + + fn schedule_process_batches(tcn_batches_manager: Arc>) -> TimerData { + let timer = Arc::new(Mutex::new(Timer::new())); + TimerData { + _timer: timer.clone(), + _guard: timer.clone().lock().unwrap().schedule_repeating(chrono::Duration::seconds(10), move || { + let flush_res = tcn_batches_manager.flush(); + expect_log!(flush_res, "Couldn't flush TCNs"); + }) + } + } } impl ObservedTcnProcessor for ObservedTcnProcessorImpl where - T: TcnDao, + T: TcnDao + Sync + Send, { - fn save(&self, tcn_str: &str) -> Result<(), ServicesError> { - info!("Recording a TCN {:?}", tcn_str); + fn save(&self, tcn_str: &str, distance: f32) -> Result<(), ServicesError> { + info!("Recording a TCN {:?}, distance: {}", tcn_str, distance); let bytes_vec: Vec = hex::decode(tcn_str)?; let observed_tcn = ObservedTcn { tcn: TemporaryContactNumber(byte_vec_to_16_byte_array(bytes_vec)), - time: UnixTime { - value: Utc::now().timestamp() as u64, - }, + contact_start: UnixTime::now(), + contact_end: UnixTime::now(), + min_distance: distance, + avg_distance: distance, + total_count: 1, }; - self.tcn_dao.save(&observed_tcn) + self.tcn_batches_manager.push(observed_tcn); + + Ok(()) } } -pub trait TcnDao { +pub trait TcnDao: Send + Sync { fn all(&self) -> Result, ServicesError>; - fn save(&self, observed_tcn: &ObservedTcn) -> Result<(), ServicesError>; + fn find_tcns( + &self, + with: Vec, + ) -> Result, ServicesError>; + // Removes all matching TCNs (same TCN bytes) and stores observed_tcns + fn overwrite(&self, observed_tcns: Vec) -> Result<(), ServicesError>; } pub struct TcnDaoImpl { @@ -148,7 +355,11 @@ impl TcnDaoImpl { let res = db.execute_sql( "create table if not exists tcn( tcn text not null, - contact_time integer not null + contact_start integer not null, + contact_end integer not null, + min_distance real not null, + avg_distance real not null, + total_count integer not null )", params![], ); @@ -157,20 +368,46 @@ impl TcnDaoImpl { fn to_tcn(row: &Row) -> ObservedTcn { let tcn: Result = row.get(0); - let contact_time = row.get(1); let tcn_value = expect_log!(tcn, "Invalid row: no TCN"); - let tcn_value_bytes_vec_res = hex::decode(tcn_value); - let tcn_value_bytes_vec = expect_log!(tcn_value_bytes_vec_res, "Invalid stored TCN format"); - let tcn_value_bytes = byte_vec_to_16_byte_array(tcn_value_bytes_vec); - let contact_time_value: i64 = expect_log!(contact_time, "Invalid row: no contact time"); + let tcn = Self::db_tcn_str_to_tcn(tcn_value); + + let contact_start_res = row.get(1); + let contact_start: i64 = expect_log!(contact_start_res, "Invalid row: no contact start"); + + let contact_end_res = row.get(2); + let contact_end: i64 = expect_log!(contact_end_res, "Invalid row: no contact end"); + + let min_distance_res = row.get(3); + let min_distance: f64 = expect_log!(min_distance_res, "Invalid row: no min distance"); + + let avg_distance_res = row.get(4); + let avg_distance: f64 = expect_log!(avg_distance_res, "Invalid row: no avg distance"); + + let total_count_res = row.get(5); + let total_count: i64 = expect_log!(total_count_res, "Invalid row: no total count"); + ObservedTcn { - tcn: TemporaryContactNumber(tcn_value_bytes), - time: UnixTime { - value: contact_time_value as u64, + tcn, + contact_start: UnixTime { + value: contact_start as u64, + }, + contact_end: UnixTime { + value: contact_end as u64, }, + min_distance: min_distance as f32, + avg_distance: avg_distance as f32, + total_count: total_count as usize, } } + // TCN string loaded from DB is assumed to be valid + fn db_tcn_str_to_tcn(str: String) -> TemporaryContactNumber { + let tcn_value_bytes_vec_res = hex::decode(str); + let tcn_value_bytes_vec = expect_log!(tcn_value_bytes_vec_res, "Invalid stored TCN format"); + let tcn_value_bytes = byte_vec_to_16_byte_array(tcn_value_bytes_vec); + TemporaryContactNumber(tcn_value_bytes) + } + pub fn new(db: Arc) -> TcnDaoImpl { Self::create_table_if_not_exists(&db); TcnDaoImpl { db } @@ -180,22 +417,67 @@ impl TcnDaoImpl { impl TcnDao for TcnDaoImpl { fn all(&self) -> Result, ServicesError> { self.db - .query("select tcn, contact_time from tcn", NO_PARAMS, |row| { - Self::to_tcn(row) - }) + .query( + "select tcn, contact_start, contact_end, min_distance, avg_distance, total_count from tcn", + NO_PARAMS, + |row| Self::to_tcn(row), + ) .map_err(ServicesError::from) } - fn save(&self, observed_tcn: &ObservedTcn) -> Result<(), ServicesError> { - let tcn_str = hex::encode(observed_tcn.tcn.0); + fn find_tcns( + &self, + with: Vec, + ) -> Result, ServicesError> { + let tcn_strs: Vec = with.into_iter().map(|tcn| + Value::Text(hex::encode(tcn.0)) + ) + .collect(); - let res = self.db.execute_sql( - "insert or replace into tcn(tcn, contact_time) values(?1, ?2)", - // conversion to signed timestamp is safe, for obvious reasons. - params![tcn_str, observed_tcn.time.value as i64], - ); - expect_log!(res, "Couldn't insert tcn"); - Ok(()) + self.db + .query( + "select tcn, contact_start, contact_end, min_distance, avg_distance, total_count from tcn where tcn in rarray(?);", + params![Rc::new(tcn_strs)], + |row| Self::to_tcn(row), + ) + .map_err(ServicesError::from) + } + + fn overwrite(&self, observed_tcns: Vec) -> Result<(), ServicesError> { + debug!("Overwriting db exposures with same TCNs, with: {:?}", observed_tcns); + + let tcn_strs: Vec = observed_tcns.clone().into_iter().map(|tcn| + Value::Text(hex::encode(tcn.tcn.0)) + ) + .collect(); + + self.db.transaction(|t| { + // Delete all the exposures for TCNs + let delete_res = t.execute("delete from tcn where tcn in rarray(?);", params![Rc::new(tcn_strs)]); + if delete_res.is_err() { + return Err(ServicesError::General("Delete TCNs failed".to_owned())) + } + + // Insert up to date exposures + for tcn in observed_tcns { + let tcn_str = hex::encode(tcn.tcn.0); + let insert_res = t.execute("insert into tcn(tcn, contact_start, contact_end, min_distance, avg_distance, total_count) values(?1, ?2, ?3, ?4, ?5, ?6)", + params![ + tcn_str, + tcn.contact_start.value as i64, + tcn.contact_end.value as i64, + tcn.min_distance as f64, // db requires f64 / real + tcn.avg_distance as f64, // db requires f64 / real + tcn.total_count as i64 + ]); + + if insert_res.is_err() { + return Err(ServicesError::General("Insert TCN failed".to_owned())) + } + } + + Ok(()) + }) } } @@ -220,15 +502,125 @@ impl ByteArrayMappable for u64 { pub struct Alert { pub id: String, pub report: PublicReport, - pub contact_time: u64, + pub contact_start: u64, + pub contact_end: u64, + pub min_distance: f32, + pub avg_distance: f32, } -pub struct ReportsUpdater<'a, T: Preferences, U: TcnDao, V: TcnMatcher, W: TcnApi, X: MemoMapper> { - pub preferences: Arc, - pub tcn_dao: Arc, - pub tcn_matcher: V, - pub api: &'a W, - pub memo_mapper: &'a X, +mod exposure { + use super::ObservedTcn; + use crate::{errors::ServicesError, reports_interval::UnixTime}; + + #[derive(PartialEq, Debug)] + pub struct Exposure { + // Can't be empty + tcns: Vec, + } + + impl Exposure { + pub fn create(tcn: ObservedTcn) -> Exposure { + Exposure { tcns: vec![tcn] } + } + + // Only used in tests + #[allow(dead_code)] + pub fn create_with_tcns(tcns: Vec) -> Result { + if tcns.is_empty() { + Err(ServicesError::General( + "Exposure can't be created without TCNs.".to_owned(), + )) + } else { + Ok(Exposure { tcns }) + } + } + + pub fn push(&mut self, tcn: ObservedTcn) { + self.tcns.push(tcn); + } + + pub fn last(&self) -> ObservedTcn { + // Unwrap: struct guarantees that tcns can't be empty. + self.tcns.last().unwrap().clone() + } + + pub fn measurements(&self) -> ExposureMeasurements { + let mut tcns = self.tcns.clone(); + tcns.sort_by_key(|tcn| tcn.contact_start.value); + + let first_tcn = tcns + .first() + .expect("Invalid state: struct guarantees that tcns can't be empty"); + + let contact_start = first_tcn.contact_start.value; + let contact_end = tcns.last().unwrap_or(first_tcn).contact_end.value; + + let mut min_distance = std::f32::MAX; + let mut total_count: usize = 0; + let mut avg_distance = 0.0; + for tcn in tcns { + min_distance = f32::min(min_distance, tcn.min_distance); + total_count += tcn.total_count; + avg_distance += tcn.avg_distance * tcn.total_count as f32; + } + // Note: this struct (Exposure) guarantees that TCNs can't be empty, + // so don't have to check for 0 division. + avg_distance /= total_count as f32; + + ExposureMeasurements { + contact_start: UnixTime { + value: contact_start, + }, + contact_end: UnixTime { value: contact_end }, + min_distance, + avg_distance, + total_count + } + } + } + pub struct ExposureMeasurements { + pub contact_start: UnixTime, + pub contact_end: UnixTime, + pub min_distance: f32, + pub avg_distance: f32, + pub total_count: usize + } +} + +// Groups TCNs by contiguity. +#[derive(Clone)] +pub struct ExposureGrouper { + pub threshold: u64, +} + +impl ExposureGrouper { + fn group(&self, mut tcns: Vec) -> Vec { + tcns.sort_by_key(|tcn| tcn.contact_start.value); + + let mut exposures: Vec = vec![]; + for tcn in tcns { + match exposures.last_mut() { + Some(last_group) => { + if self.is_contiguous(&last_group.last(), &tcn) { + last_group.push(tcn) + } else { + exposures.push(Exposure::create(tcn)); + } + } + None => exposures.push(Exposure::create(tcn)), + } + } + exposures + } + + // Notes: + // - Expects tcn2.start > tcn1.start. If will return otherwise always true. + // - Overlapping is considered contiguous. + // (Note that depending on the implementation of writes, overlaps may not be possible.) + fn is_contiguous(&self, tcn1: &ObservedTcn, tcn2: &ObservedTcn) -> bool { + // Signed: overlap (start2 < end1) considered contiguous. + (tcn2.contact_start.value as i64 - tcn1.contact_end.value as i64) < self.threshold as i64 + } } trait SignedReportExt { @@ -248,9 +640,17 @@ trait SignedReportExt { .ok() } } - impl SignedReportExt for SignedReport {} +pub struct ReportsUpdater<'a, T: Preferences, U: TcnDao, V: TcnMatcher, W: TcnApi, X: MemoMapper> { + pub preferences: Arc, + pub tcn_dao: Arc, + pub tcn_matcher: V, + pub api: &'a W, + pub memo_mapper: &'a X, + pub exposure_grouper: ExposureGrouper, +} + impl<'a, T, U, V, W, X> ReportsUpdater<'a, T, U, V, W, X> where T: Preferences, @@ -263,24 +663,43 @@ where self.retrieve_and_match_new_reports().map(|signed_reports| { signed_reports .into_iter() - .filter_map(|matched_report| self.to_ffi_alert(matched_report).ok()) + .filter_map(|matched_report| self.to_ffi_alerts(matched_report).ok()) + .flatten() .collect() }) } // Note: For now we will not create an FFI layer to handle JSON conversions, since it may be possible // to use directly the data structures. - fn to_ffi_alert(&self, matched_report: MatchedReport) -> Result { - let report = matched_report.report.clone().verify()?; + fn to_ffi_alerts(&self, matched_report: MatchedReport) -> Result, ServicesError> { + let exposures = self.exposure_grouper.group(matched_report.clone().tcns); + + exposures + .into_iter() + .map(|exposure_tcns| self.to_alert(matched_report.report.clone(), exposure_tcns)) + .collect() + } + + fn to_alert( + &self, + signed_report: SignedReport, + exposure: Exposure, + ) -> Result { + let report = signed_report.clone().verify()?; let public_report = self.memo_mapper.to_report(Memo { bytes: report.memo_data().to_vec(), }); + let measurements = exposure.measurements(); + Ok(Alert { - id: format!("{:?}", matched_report.report.sig), + id: format!("{:?}", signed_report.sig), report: public_report, - contact_time: matched_report.contact_time.value, + contact_start: measurements.contact_start.value, + contact_end: measurements.contact_end.value, + min_distance: measurements.min_distance, + avg_distance: measurements.avg_distance }) } @@ -574,7 +993,6 @@ mod tests { } #[test] - #[ignore] fn saves_and_loads_observed_tcn() { let database = Arc::new(Database::new( Connection::open_in_memory().expect("Couldn't create database!"), @@ -585,10 +1003,14 @@ mod tests { tcn: TemporaryContactNumber([ 24, 229, 125, 245, 98, 86, 219, 221, 172, 25, 232, 150, 206, 66, 164, 173, ]), - time: UnixTime { value: 1590528300 }, + contact_start: UnixTime { value: 1590528300 }, + contact_end: UnixTime { value: 1590528301 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, }; - let save_res = tcn_dao.save(&observed_tcn); + let save_res = tcn_dao.overwrite(vec![observed_tcn.clone()]); assert!(save_res.is_ok()); let loaded_tcns_res = tcn_dao.all(); @@ -601,7 +1023,6 @@ mod tests { } #[test] - #[ignore] fn saves_and_loads_multiple_tcns() { let database = Arc::new(Database::new( Connection::open_in_memory().expect("Couldn't create database!"), @@ -612,24 +1033,36 @@ mod tests { tcn: TemporaryContactNumber([ 24, 229, 125, 245, 98, 86, 219, 221, 172, 25, 232, 150, 206, 66, 164, 173, ]), - time: UnixTime { value: 1590528300 }, + contact_start: UnixTime { value: 1590528300 }, + contact_end: UnixTime { value: 1590528301 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, }; let observed_tcn_2 = ObservedTcn { tcn: TemporaryContactNumber([ 43, 229, 125, 245, 98, 86, 100, 1, 172, 25, 0, 150, 123, 66, 34, 12, ]), - time: UnixTime { value: 1590518190 }, + contact_start: UnixTime { value: 1590518190 }, + contact_end: UnixTime { value: 1590518191 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, }; let observed_tcn_3 = ObservedTcn { tcn: TemporaryContactNumber([ 11, 246, 125, 123, 102, 86, 100, 1, 34, 25, 21, 150, 99, 66, 34, 0, ]), - time: UnixTime { value: 2230522104 }, + contact_start: UnixTime { value: 2230522104 }, + contact_end: UnixTime { value: 2230522105 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, }; - let save_res_1 = tcn_dao.save(&observed_tcn_1); - let save_res_2 = tcn_dao.save(&observed_tcn_2); - let save_res_3 = tcn_dao.save(&observed_tcn_3); + let save_res_1 = tcn_dao.overwrite(vec![observed_tcn_1.clone()]); + let save_res_2 = tcn_dao.overwrite(vec![observed_tcn_2.clone()]); + let save_res_3 = tcn_dao.overwrite(vec![observed_tcn_3.clone()]); assert!(save_res_1.is_ok()); assert!(save_res_2.is_ok()); assert!(save_res_3.is_ok()); @@ -645,39 +1078,6 @@ mod tests { assert_eq!(loaded_tcns[2], observed_tcn_3); } - #[test] - #[ignore] - // Currently there's no unique, as for current use case it doesn't seem necessary/critical and - // it affects negatively performance. - // TODO revisit - fn saves_and_loads_repeated_tcns() { - let database = Arc::new(Database::new( - Connection::open_in_memory().expect("Couldn't create database!"), - )); - let tcn_dao = TcnDaoImpl::new(database.clone()); - - let observed_tcn_1 = ObservedTcn { - tcn: TemporaryContactNumber([ - 24, 229, 125, 245, 98, 86, 219, 221, 172, 25, 232, 150, 206, 66, 164, 173, - ]), - time: UnixTime { value: 1590528300 }, - }; - - let save_res_1 = tcn_dao.save(&observed_tcn_1); - let save_res_2 = tcn_dao.save(&observed_tcn_1); - assert!(save_res_1.is_ok()); - assert!(save_res_2.is_ok()); - - let loaded_tcns_res = tcn_dao.all(); - assert!(loaded_tcns_res.is_ok()); - - let loaded_tcns = loaded_tcns_res.unwrap(); - - assert_eq!(loaded_tcns.len(), 2); - assert_eq!(loaded_tcns[0], observed_tcn_1); - assert_eq!(loaded_tcns[1], observed_tcn_1); - } - // Utility to see quickly all TCNs (hex) for a report #[test] #[ignore] @@ -691,6 +1091,67 @@ mod tests { } #[test] + fn one_report_matches() { + let verification_report_str = "D7Z8XrufMgfsFH3K5COnv17IFG2ahDb4VM/UMK/5y0+/OtUVVTh7sN0DQ5+R+ocecTilR+SIIpPHzujeJdJzugEAECcAFAEAmmq5XgAAAACaarleAAAAACEBo8p1WdGeXb5O5/3kN6x7GSylgiYGIGsABl3NrxhJu9XHwsN3f6yvRwUxs2fhP4oU5E3+JWabBP6v09pGV1xRCw=="; + let verification_report_tcn: [u8; 16] = [ + 24, 229, 125, 245, 98, 86, 219, 221, 172, 25, 232, 150, 206, 66, 164, 173, + ]; // belongs to report + let verification_contact_start = UnixTime { value: 1590528300 }; + let verification_contact_end = UnixTime { value: 1590528301 }; + let verification_min_distance = 2.3; + let verification_avg_distance = 3.0; + let verification_total_count = 3; + let verification_report = SignedReport::with_str(verification_report_str).unwrap(); + + let mut reports: Vec = vec![0; 20] + .into_iter() + .map(|_| create_test_report()) + .collect(); + reports.push(verification_report); + + // let matcher = TcnMatcherStdThreadSpawn {}; // 20 -> 1s, 200 -> 16s, 1000 -> 84s, 10000 -> + let matcher = TcnMatcherRayon {}; // 20 -> 1s, 200 -> 7s, 1000 -> 87s, 10000 -> 927s + + let tcns = vec![ + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1590528300 }, + contact_end: UnixTime { value: 1590528301 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ObservedTcn { + tcn: TemporaryContactNumber(verification_report_tcn), + contact_start: verification_contact_start.clone(), + contact_end: verification_contact_end.clone(), + min_distance: verification_min_distance, + avg_distance: verification_avg_distance, + total_count: verification_total_count, + }, + ObservedTcn { + tcn: TemporaryContactNumber([1; 16]), + contact_start: UnixTime { value: 1590528300 }, + contact_end: UnixTime { value: 1590528301 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ]; + + let res = matcher.match_reports(tcns, reports); + let matches = res.unwrap(); + assert_eq!(matches.len(), 1); + + let matched_report_str = base64::encode(signed_report_to_bytes(matches[0].report.clone())); + assert_eq!(matched_report_str, verification_report_str); + assert_eq!(matches[0].tcns[0].contact_start, verification_contact_start); + assert_eq!(matches[0].tcns[0].contact_end, verification_contact_end); + assert_eq!(matches[0].tcns[0].min_distance, verification_min_distance); + } + + #[test] + #[ignore] fn matching_benchmark() { let verification_report_str = "D7Z8XrufMgfsFH3K5COnv17IFG2ahDb4VM/UMK/5y0+/OtUVVTh7sN0DQ5+R+ocecTilR+SIIpPHzujeJdJzugEAECcAFAEAmmq5XgAAAACaarleAAAAACEBo8p1WdGeXb5O5/3kN6x7GSylgiYGIGsABl3NrxhJu9XHwsN3f6yvRwUxs2fhP4oU5E3+JWabBP6v09pGV1xRCw=="; let verification_report_tcn: [u8; 16] = [ @@ -711,15 +1172,27 @@ mod tests { let tcns = vec![ ObservedTcn { tcn: TemporaryContactNumber([0; 16]), - time: UnixTime { value: 1590528300 }, + contact_start: UnixTime { value: 1590528300 }, + contact_end: UnixTime { value: 1590528301 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, }, ObservedTcn { tcn: TemporaryContactNumber(verification_report_tcn), - time: verification_contact_time.clone(), + contact_start: verification_contact_time.clone(), + contact_end: verification_contact_time.clone(), + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, }, ObservedTcn { tcn: TemporaryContactNumber([1; 16]), - time: UnixTime { value: 1590528300 }, + contact_start: UnixTime { value: 1590528300 }, + contact_end: UnixTime { value: 1590528301 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, }, ]; @@ -736,7 +1209,6 @@ mod tests { // Short verification that matching is working let matched_report_str = base64::encode(signed_report_to_bytes(matches[0].report.clone())); assert_eq!(matched_report_str, verification_report_str); - assert_eq!(matches[0].contact_time, verification_contact_time); } #[test] @@ -754,6 +1226,653 @@ mod tests { assert!(SignedReport::with_str("slkdjfslfd").is_none()) } + #[test] + fn test_group_in_exposures_empty() { + let tcns = vec![]; + let groups = ExposureGrouper { threshold: 1000 }.group(tcns.clone()); + assert_eq!(groups.len(), 0); + } + + #[test] + fn test_group_in_exposures_same_group() { + let tcns = vec![ + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1000 }, + contact_end: UnixTime { value: 1001 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1500 }, + contact_end: UnixTime { value: 1501 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ]; + let groups = ExposureGrouper { threshold: 1000 }.group(tcns.clone()); + assert_eq!(groups.len(), 1); + assert_eq!(groups[0], Exposure::create_with_tcns(tcns).unwrap()); + } + + #[test] + fn test_group_in_exposures_identical_tcns() { + // Passing same TCN 2x (normally will not happen) + let tcns = vec![ + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1000 }, + contact_end: UnixTime { value: 1001 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1000 }, + contact_end: UnixTime { value: 1001 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ]; + + let groups = ExposureGrouper { threshold: 1000 }.group(tcns.clone()); + assert_eq!(groups.len(), 1); + assert_eq!(groups[0], Exposure::create_with_tcns(tcns).unwrap()); + } + + #[test] + fn test_group_in_exposures_different_groups() { + let tcns = vec![ + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1000 }, + contact_end: UnixTime { value: 1001 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 2002 }, + contact_end: UnixTime { value: 2501 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ]; + + let groups = ExposureGrouper { threshold: 1000 }.group(tcns.clone()); + assert_eq!(groups.len(), 2); + assert_eq!( + groups[0], + Exposure::create_with_tcns(vec![tcns[0].clone()]).unwrap() + ); + assert_eq!( + groups[1], + Exposure::create_with_tcns(vec![tcns[1].clone()]).unwrap() + ); + } + + #[test] + fn test_group_in_exposures_sort() { + let tcns = vec![ + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 2002 }, + contact_end: UnixTime { value: 2501 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1000 }, + contact_end: UnixTime { value: 1001 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ]; + + let groups = ExposureGrouper { threshold: 1000 }.group(tcns.clone()); + assert_eq!(groups.len(), 2); + assert_eq!( + groups[0], + Exposure::create_with_tcns(vec![tcns[1].clone()]).unwrap() + ); + assert_eq!( + groups[1], + Exposure::create_with_tcns(vec![tcns[0].clone()]).unwrap() + ); + } + + #[test] + fn test_group_in_exposures_overlap() { + let tcns = vec![ + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1000 }, + contact_end: UnixTime { value: 2000 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + // starts before previous TCN ends + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1600 }, + contact_end: UnixTime { value: 2600 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ]; + + let groups = ExposureGrouper { threshold: 1000 }.group(tcns.clone()); + assert_eq!(groups.len(), 1); + assert_eq!(groups[0], Exposure::create_with_tcns(tcns).unwrap()); + } + + #[test] + fn test_group_in_exposures_mixed() { + let tcns = vec![ + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 3000 }, + contact_end: UnixTime { value: 3001 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1 }, + contact_end: UnixTime { value: 2 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 3900 }, + contact_end: UnixTime { value: 4500 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 500 }, + contact_end: UnixTime { value: 501 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1589209754 }, + contact_end: UnixTime { value: 1589209755 }, + min_distance: 0.0, + avg_distance: 0.0, + total_count: 1, + }, + ]; + + let groups = ExposureGrouper { threshold: 1000 }.group(tcns.clone()); + + assert_eq!(groups.len(), 3); + assert_eq!( + groups[0], + Exposure::create_with_tcns(vec![tcns[1].clone(), tcns[3].clone()]).unwrap() + ); + assert_eq!( + groups[1], + Exposure::create_with_tcns(vec![tcns[0].clone(), tcns[2].clone()]).unwrap() + ); + assert_eq!( + groups[2], + Exposure::create_with_tcns(vec![tcns[4].clone().clone()]).unwrap() + ); + } + + #[test] + fn test_exposure_measurements_correct() { + let tcns = vec![ + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1600 }, + contact_end: UnixTime { value: 2600 }, + min_distance: 2.3, + avg_distance: 2.7, // (2.3 + 3.1) / 2 + total_count: 2, + }, + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 2601 }, + contact_end: UnixTime { value: 3223 }, + min_distance: 0.845, + avg_distance: 0.948333333, // (0.845 + 0.5 + 1.5) / 3 + total_count: 3, + }, + ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1000 }, + contact_end: UnixTime { value: 2000 }, + min_distance: 0.846, + avg_distance: 0.846, + total_count: 1, + }, + ]; + + let measurements = Exposure::create_with_tcns(tcns).unwrap().measurements(); + + assert_eq!(measurements.contact_start.value, 1000); + assert_eq!(measurements.contact_end.value, 3223); + assert_eq!(measurements.min_distance, 0.845); + let avg_rounded = (measurements.avg_distance * 10000.0).floor() / 10000.0; + assert_eq!(avg_rounded, 1.5151); // (2.3 + 3.1 + 0.845 + 0.5 + 1.5 + 0.846) / (2 + 3 + 1) + assert_eq!(measurements.total_count, 6); // 2 + 3 + 1 + } + + #[test] + fn test_push_merges_existing_tcn_in_batch_manager() { + let database = Arc::new(Database::new( + Connection::open_in_memory().expect("Couldn't create database!"), + )); + let tcn_dao = TcnDaoImpl::new(database.clone()); + + let batches_manager = TcnBatchesManager::new(Arc::new(tcn_dao), ExposureGrouper{ threshold: 1000}); + + batches_manager.push(ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1600 }, + contact_end: UnixTime { value: 2600 }, + min_distance: 2.3, + avg_distance: 0.506, // (0.1 + 0.62 + 0.8 + 0.21 + 0.8) / 5 + total_count: 5 + }); + + batches_manager.push(ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 3000 }, + contact_end: UnixTime { value: 5000 }, + min_distance: 2.0, + avg_distance: 0.7, // (1.2 + 0.5 + 0.4) / 3 + total_count: 3, + }); + + let len_res = batches_manager.len(); + assert!(len_res.is_ok()); + assert_eq!(1, len_res.unwrap()); + + let tcns = batches_manager.tcns_batch.lock().unwrap(); + assert_eq!(tcns[&[0; 16]], ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1600 }, + contact_end: UnixTime { value: 5000 }, + min_distance: 2.0, + avg_distance: 0.57875, // (0.1 + 0.62 + 0.8 + 0.21 + 0.8 + 1.2 + 0.5 + 0.4) / (5 + 3) + total_count: 8 // 5 + 3 + }); + } + + #[test] + fn test_flush_clears_tcns() { + let database = Arc::new(Database::new( + Connection::open_in_memory().expect("Couldn't create database!"), + )); + let tcn_dao = TcnDaoImpl::new(database.clone()); + + let batches_manager = TcnBatchesManager::new(Arc::new(tcn_dao), ExposureGrouper{ threshold: 1000}); + + batches_manager.push(ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1600 }, + contact_end: UnixTime { value: 2600 }, + min_distance: 2.3, + avg_distance: 2.3, + total_count: 1, + }); + let flush_res = batches_manager.flush(); + assert!(flush_res.is_ok()); + + let len_res = batches_manager.len(); + assert!(len_res.is_ok()); + assert_eq!(0, len_res.unwrap()) + } + + #[test] + fn test_flush_adds_entries_to_db() { + let database = Arc::new(Database::new( + Connection::open_in_memory().expect("Couldn't create database!"), + )); + let tcn_dao = Arc::new(TcnDaoImpl::new(database.clone())); + + let batches_manager = TcnBatchesManager::new(tcn_dao.clone(), ExposureGrouper{ threshold: 1000}); + + let tcn = ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1600 }, + contact_end: UnixTime { value: 2600 }, + min_distance: 2.3, + avg_distance: 2.3, + total_count: 1, + }; + batches_manager.push(tcn.clone()); + + let flush_res = batches_manager.flush(); + assert!(flush_res.is_ok()); + + let stored_tcns_res = tcn_dao.all(); + assert!(stored_tcns_res.is_ok()); + + let stored_tcns = stored_tcns_res.unwrap(); + assert_eq!(1, stored_tcns.len()); + assert_eq!(tcn, stored_tcns[0]); + } + + #[test] + fn test_flush_updates_correctly_existing_entry() { + let database = Arc::new(Database::new( + Connection::open_in_memory().expect("Couldn't create database!"), + )); + let tcn_dao = Arc::new(TcnDaoImpl::new(database.clone())); + + let batches_manager = TcnBatchesManager::new(tcn_dao.clone(), ExposureGrouper{ threshold: 1000}); + + let stored_tcn = ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1600 }, + contact_end: UnixTime { value: 2600 }, + min_distance: 2.3, + avg_distance: 1.25,// (2.3 + 0.7 + 1 + 1) / 4 + total_count: 4, + }; + let save_res = tcn_dao.overwrite(vec![stored_tcn]); + assert!(save_res.is_ok()); + + let tcn = ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 3000 }, + contact_end: UnixTime { value: 5000 }, + min_distance: 1.12, + avg_distance: 1.0,// (1.12 + 0.88 + 1) / 3 + total_count: 3, + }; + batches_manager.push(tcn.clone()); + + let flush_res = batches_manager.flush(); + assert!(flush_res.is_ok()); + + let loaded_tcns_res = tcn_dao.all(); + assert!(loaded_tcns_res.is_ok()); + + let loaded_tcns = loaded_tcns_res.unwrap(); + assert_eq!(1, loaded_tcns.len()); + + assert_eq!(loaded_tcns[0], ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1600 }, + contact_end: UnixTime { value: 5000 }, + min_distance: 1.12, + avg_distance: 1.14285714, // (2.3 + 0.7 + 1 + 1 + 1.12 + 0.88 + 1) / (4 + 3) + total_count: 7, + }); + } + + #[test] + fn test_flush_does_not_affect_different_stored_tcn() { + let database = Arc::new(Database::new( + Connection::open_in_memory().expect("Couldn't create database!"), + )); + let tcn_dao = Arc::new(TcnDaoImpl::new(database.clone())); + + let batches_manager = TcnBatchesManager::new(tcn_dao.clone(), ExposureGrouper{ threshold: 1000}); + + let stored_tcn = ObservedTcn { + tcn: TemporaryContactNumber([1; 16]), + contact_start: UnixTime { value: 1600 }, + contact_end: UnixTime { value: 2600 }, + min_distance: 2.3, + avg_distance: 2.3, + total_count: 1, + }; + let save_res = tcn_dao.overwrite(vec![stored_tcn]); + assert!(save_res.is_ok()); + + let tcn = ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 3000 }, + contact_end: UnixTime { value: 5000 }, + min_distance: 1.12, + avg_distance: 1.12, + total_count: 1 + }; + batches_manager.push(tcn.clone()); + + let loaded_tcns_res = tcn_dao.all(); + assert!(loaded_tcns_res.is_ok()); + + let flush_res = batches_manager.flush(); + assert!(flush_res.is_ok()); + + let loaded_tcns_res = tcn_dao.all(); + assert!(loaded_tcns_res.is_ok()); + + let loaded_tcns = loaded_tcns_res.unwrap(); + assert_eq!(2, loaded_tcns.len()); + + assert_eq!(loaded_tcns[0], ObservedTcn { + tcn: TemporaryContactNumber([1; 16]), + contact_start: UnixTime { value: 1600 }, + contact_end: UnixTime { value: 2600 }, + min_distance: 2.3, + avg_distance: 2.3, + total_count: 1, + }); + assert_eq!(loaded_tcns[1], ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 3000 }, + contact_end: UnixTime { value: 5000 }, + min_distance: 1.12, + avg_distance: 1.12, + total_count: 1 + }); + } + + #[test] + fn test_flush_updates_correctly_2_stored_1_updated() { + let database = Arc::new(Database::new( + Connection::open_in_memory().expect("Couldn't create database!"), + )); + let tcn_dao = Arc::new(TcnDaoImpl::new(database.clone())); + + let batches_manager = TcnBatchesManager::new(tcn_dao.clone(), ExposureGrouper{ threshold: 1000}); + + let stored_tcn1 = ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1000 }, + contact_end: UnixTime { value: 6000 }, + min_distance: 0.4, + avg_distance: 0.4, + total_count: 1, + }; + + let stored_tcn2 = ObservedTcn { + tcn: TemporaryContactNumber([1; 16]), + contact_start: UnixTime { value: 1600 }, + contact_end: UnixTime { value: 2600 }, + min_distance: 2.3, + avg_distance: 2.3, + total_count: 1, + }; + let save_res = tcn_dao.overwrite(vec![stored_tcn1.clone(), stored_tcn2.clone()]); + assert!(save_res.is_ok()); + + let tcn = ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 3000 }, + contact_end: UnixTime { value: 7000 }, + min_distance: 1.12, + avg_distance: 1.12, + total_count: 1, + }; + batches_manager.push(tcn.clone()); + + let flush_res = batches_manager.flush(); + assert!(flush_res.is_ok()); + + let loaded_tcns_res = tcn_dao.all(); + assert!(loaded_tcns_res.is_ok()); + + let mut loaded_tcns = loaded_tcns_res.unwrap(); + assert_eq!(2, loaded_tcns.len()); + + // Sqlite doesn't guarantee insertion order, so sort + // start value not meaningul here, other than for reproducible sorting + loaded_tcns.sort_by_key(|tcn| tcn.contact_start.value); + + assert_eq!(loaded_tcns[0], ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1000 }, + contact_end: UnixTime { value: 7000 }, + min_distance: 0.4, + avg_distance: 0.76, // (0.4, + 1.12) / (1 + 1) + total_count: 2 // 1 + 1 + }); + assert_eq!(loaded_tcns[1], stored_tcn2); + } + + + #[test] + fn test_finds_tcn() { + let database = Arc::new(Database::new( + Connection::open_in_memory().expect("Couldn't create database!"), + )); + let tcn_dao = Arc::new(TcnDaoImpl::new(database.clone())); + + let stored_tcn1 = ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1000 }, + contact_end: UnixTime { value: 6000 }, + min_distance: 0.4, + avg_distance: 0.4, + total_count: 1, + }; + + let stored_tcn2 = ObservedTcn { + tcn: TemporaryContactNumber([1; 16]), + contact_start: UnixTime { value: 2000 }, + contact_end: UnixTime { value: 3000 }, + min_distance: 1.8, + avg_distance: 1.8, + total_count: 1, + }; + + let stored_tcn3 = ObservedTcn { + tcn: TemporaryContactNumber([2; 16]), + contact_start: UnixTime { value: 1600 }, + contact_end: UnixTime { value: 2600 }, + min_distance: 2.3, + avg_distance: 2.3, + total_count: 1, + }; + + let save_res = tcn_dao.overwrite(vec![stored_tcn1.clone(), stored_tcn2.clone(), stored_tcn3.clone()]); + assert!(save_res.is_ok()); + + let res = tcn_dao.find_tcns(vec![TemporaryContactNumber([0; 16]), TemporaryContactNumber([2; 16])]); + assert!(res.is_ok()); + + let mut tcns = res.unwrap(); + + // Sqlite doesn't guarantee insertion order, so sort + // start value not meaningul here, other than for reproducible sorting + tcns.sort_by_key(|tcn| tcn.contact_start.value); + + assert_eq!(2, tcns.len()); + assert_eq!(stored_tcn1, tcns[0]); + assert_eq!(stored_tcn3, tcns[1]); + } + + #[test] + fn test_multiple_exposures_updated_correctly() { + let database = Arc::new(Database::new( + Connection::open_in_memory().expect("Couldn't create database!"), + )); + let tcn_dao = Arc::new(TcnDaoImpl::new(database.clone())); + + let batches_manager = TcnBatchesManager::new(tcn_dao.clone(), ExposureGrouper{ threshold: 1000}); + + let stored_tcn1 = ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1000 }, + contact_end: UnixTime { value: 3000 }, + min_distance: 0.4, + avg_distance: 0.4, + total_count: 1 + }; + + let stored_tcn2 = ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 5000 }, + contact_end: UnixTime { value: 7000 }, + min_distance: 2.0, + avg_distance: 2.0, + total_count: 1 + }; + let save_res = tcn_dao.overwrite(vec![stored_tcn1.clone(), stored_tcn2.clone()]); + assert!(save_res.is_ok()); + + let tcn = ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 7500 }, + contact_end: UnixTime { value: 9000 }, + min_distance: 1.0, + avg_distance: 1.0, + total_count: 1 + }; + + batches_manager.push(tcn.clone()); + + let flush_res = batches_manager.flush(); + assert!(flush_res.is_ok()); + + let loaded_tcns_res = tcn_dao.all(); + assert!(loaded_tcns_res.is_ok()); + + let mut loaded_tcns = loaded_tcns_res.unwrap(); + assert_eq!(2, loaded_tcns.len()); + + // Sqlite doesn't guarantee insertion order, so sort + // start value not meaningul here, other than for reproducible sorting + loaded_tcns.sort_by_key(|tcn| tcn.contact_start.value); + + assert_eq!(loaded_tcns[0], ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 1000 }, + contact_end: UnixTime { value: 3000 }, + min_distance: 0.4, + avg_distance: 0.4, + total_count: 1 + }); + // The new TCN was merged with stored_tcn2 + assert_eq!(loaded_tcns[1], ObservedTcn { + tcn: TemporaryContactNumber([0; 16]), + contact_start: UnixTime { value: 5000 }, + contact_end: UnixTime { value: 9000 }, + min_distance: 1.0, + avg_distance: 1.5, // (2.0 + 1.0) / (1 + 1), + total_count: 2 // 1 + 1 + }); + } + + fn create_test_report() -> SignedReport { let memo_mapper = MemoMapperImpl {}; let public_report = PublicReport {