diff --git a/Cargo.lock b/Cargo.lock index f99eba1732..95a2dea104 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6107,6 +6107,7 @@ dependencies = [ "prost", "prost-types", "rand", + "restate-core-derive", "restate-test-util", "restate-types", "schemars", @@ -6131,6 +6132,15 @@ dependencies = [ "xxhash-rust", ] +[[package]] +name = "restate-core-derive" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.85", +] + [[package]] name = "restate-errors" version = "1.1.4" diff --git a/Cargo.toml b/Cargo.toml index 7154a96451..1a5cc1e2fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "cli", "crates/*", + "crates/core/derive", "crates/codederror/derive", "server", "benchmarks", @@ -14,6 +15,7 @@ members = [ default-members = [ "cli", "crates/*", + "crates/core/derive", "crates/codederror/derive", "server", "tools/restatectl", @@ -38,6 +40,7 @@ restate-base64-util = { path = "crates/base64-util" } restate-bifrost = { path = "crates/bifrost" } restate-cli-util = { path = "crates/cli-util" } restate-core = { path = "crates/core" } +restate-core-derive = { path = "crates/core/derive" } restate-errors = { path = "crates/errors" } restate-fs-util = { path = "crates/fs-util" } restate-futures-util = { path = "crates/futures-util" } diff --git a/crates/bifrost/src/bifrost.rs b/crates/bifrost/src/bifrost.rs index beeffa1855..f04cf4acab 100644 --- a/crates/bifrost/src/bifrost.rs +++ b/crates/bifrost/src/bifrost.rs @@ -496,8 +496,8 @@ mod tests { use tracing::info; use tracing_test::traced_test; - use restate_core::TestCoreEnvBuilder; - use restate_core::{TaskCenter, TaskCenterFutureExt, TaskKind, TestCoreEnv}; + use restate_core::TestCoreEnvBuilder2; + use restate_core::{TaskCenter, TaskKind, TestCoreEnv2}; use restate_rocksdb::RocksDbManager; use restate_types::config::CommonOptions; use restate_types::live::Constant; @@ -510,379 +510,359 @@ mod tests { use crate::providers::memory_loglet::{self}; use crate::BifrostAdmin; - #[tokio::test] 
+ #[restate_core::test] #[traced_test] async fn test_append_smoke() -> googletest::Result<()> { let num_partitions = 5; - let node_env = TestCoreEnvBuilder::with_incoming_only_connector() + let _ = TestCoreEnvBuilder2::with_incoming_only_connector() .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, num_partitions, )) .build() .await; - async { - let bifrost = Bifrost::init_in_memory().await; - - let clean_bifrost_clone = bifrost.clone(); - - let mut appender_0 = bifrost.create_appender(LogId::new(0))?; - let mut appender_3 = bifrost.create_appender(LogId::new(3))?; - let mut max_lsn = Lsn::INVALID; - for i in 1..=5 { - // Append a record to memory - let lsn = appender_0.append("").await?; - info!(%lsn, "Appended record to log"); - assert_eq!(Lsn::from(i), lsn); - max_lsn = lsn; - } - // Append to a log that doesn't exist. - let invalid_log = LogId::from(num_partitions + 1); - let resp = bifrost.create_appender(invalid_log); - - assert_that!(resp, pat!(Err(pat!(Error::UnknownLogId(eq(invalid_log)))))); - - // use a cloned bifrost. - let cloned_bifrost = bifrost.clone(); - let mut second_appender_0 = cloned_bifrost.create_appender(LogId::new(0))?; - for _ in 1..=5 { - // Append a record to memory - let lsn = second_appender_0.append("").await?; - info!(%lsn, "Appended record to log"); - assert_eq!(max_lsn + Lsn::from(1), lsn); - max_lsn = lsn; - } + let bifrost = Bifrost::init_in_memory().await; - // Ensure original clone writes to the same underlying loglet. - let lsn = clean_bifrost_clone - .create_appender(LogId::new(0))? 
- .append("") - .await?; - assert_eq!(max_lsn + Lsn::from(1), lsn); + let clean_bifrost_clone = bifrost.clone(); + + let mut appender_0 = bifrost.create_appender(LogId::new(0))?; + let mut appender_3 = bifrost.create_appender(LogId::new(3))?; + let mut max_lsn = Lsn::INVALID; + for i in 1..=5 { + // Append a record to memory + let lsn = appender_0.append("").await?; + info!(%lsn, "Appended record to log"); + assert_eq!(Lsn::from(i), lsn); max_lsn = lsn; + } - // Writes to another log don't impact original log - let lsn = appender_3.append("").await?; - assert_eq!(Lsn::from(1), lsn); + // Append to a log that doesn't exist. + let invalid_log = LogId::from(num_partitions + 1); + let resp = bifrost.create_appender(invalid_log); - let lsn = appender_0.append("").await?; + assert_that!(resp, pat!(Err(pat!(Error::UnknownLogId(eq(invalid_log)))))); + + // use a cloned bifrost. + let cloned_bifrost = bifrost.clone(); + let mut second_appender_0 = cloned_bifrost.create_appender(LogId::new(0))?; + for _ in 1..=5 { + // Append a record to memory + let lsn = second_appender_0.append("").await?; + info!(%lsn, "Appended record to log"); assert_eq!(max_lsn + Lsn::from(1), lsn); max_lsn = lsn; - - let tail = bifrost.find_tail(LogId::new(0)).await?; - assert_eq!(max_lsn.next(), tail.offset()); - - // Initiate shutdown - TaskCenter::current().shutdown_node("completed", 0).await; - // appends cannot succeed after shutdown - let res = appender_0.append("").await; - assert!(matches!(res, Err(Error::Shutdown(_)))); - // Validate the watchdog has called the provider::start() function. - assert!(logs_contain("Shutting down in-memory loglet provider")); - assert!(logs_contain("Bifrost watchdog shutdown complete")); - Ok(()) } - .in_tc(&node_env.tc) - .await + + // Ensure original clone writes to the same underlying loglet. + let lsn = clean_bifrost_clone + .create_appender(LogId::new(0))? 
+ .append("") + .await?; + assert_eq!(max_lsn + Lsn::from(1), lsn); + max_lsn = lsn; + + // Writes to another log don't impact original log + let lsn = appender_3.append("").await?; + assert_eq!(Lsn::from(1), lsn); + + let lsn = appender_0.append("").await?; + assert_eq!(max_lsn + Lsn::from(1), lsn); + max_lsn = lsn; + + let tail = bifrost.find_tail(LogId::new(0)).await?; + assert_eq!(max_lsn.next(), tail.offset()); + + // Initiate shutdown + TaskCenter::current().shutdown_node("completed", 0).await; + // appends cannot succeed after shutdown + let res = appender_0.append("").await; + assert!(matches!(res, Err(Error::Shutdown(_)))); + // Validate the watchdog has called the provider::start() function. + assert!(logs_contain("Shutting down in-memory loglet provider")); + assert!(logs_contain("Bifrost watchdog shutdown complete")); + Ok(()) } - #[tokio::test(start_paused = true)] + #[restate_core::test(start_paused = true)] async fn test_lazy_initialization() -> googletest::Result<()> { - let node_env = TestCoreEnv::create_with_single_node(1, 1).await; - async { - let delay = Duration::from_secs(5); - // This memory provider adds a delay to its loglet initialization, we want - // to ensure that appends do not fail while waiting for the loglet; - let factory = memory_loglet::Factory::with_init_delay(delay); - let bifrost = Bifrost::init_with_factory(factory).await; - - let start = tokio::time::Instant::now(); - let lsn = bifrost.create_appender(LogId::new(0))?.append("").await?; - assert_eq!(Lsn::from(1), lsn); - // The append was properly delayed - assert_eq!(delay, start.elapsed()); - Ok(()) - } - .in_tc(&node_env.tc) - .await + let _ = TestCoreEnv2::create_with_single_node(1, 1).await; + let delay = Duration::from_secs(5); + // This memory provider adds a delay to its loglet initialization, we want + // to ensure that appends do not fail while waiting for the loglet; + let factory = memory_loglet::Factory::with_init_delay(delay); + let bifrost = 
Bifrost::init_with_factory(factory).await; + + let start = tokio::time::Instant::now(); + let lsn = bifrost.create_appender(LogId::new(0))?.append("").await?; + assert_eq!(Lsn::from(1), lsn); + // The append was properly delayed + assert_eq!(delay, start.elapsed()); + Ok(()) } - #[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] + #[test(restate_core::test(flavor = "multi_thread", worker_threads = 2))] async fn trim_log_smoke_test() -> googletest::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder::with_incoming_only_connector() + let node_env = TestCoreEnvBuilder2::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; - async { - RocksDbManager::init(Constant::new(CommonOptions::default())); - - let bifrost = Bifrost::init_local().await; - let bifrost_admin = BifrostAdmin::new( - &bifrost, - &node_env.metadata_writer, - &node_env.metadata_store_client, - ); + RocksDbManager::init(Constant::new(CommonOptions::default())); - assert_eq!(Lsn::OLDEST, bifrost.find_tail(LOG_ID).await?.offset()); + let bifrost = Bifrost::init_local().await; + let bifrost_admin = BifrostAdmin::new( + &bifrost, + &node_env.metadata_writer, + &node_env.metadata_store_client, + ); - assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); + assert_eq!(Lsn::OLDEST, bifrost.find_tail(LOG_ID).await?.offset()); - let mut appender = bifrost.create_appender(LOG_ID)?; - // append 10 records - for _ in 1..=10 { - appender.append("").await?; - } + assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); - bifrost_admin.trim(LOG_ID, Lsn::from(5)).await?; + let mut appender = bifrost.create_appender(LOG_ID)?; + // append 10 records + for _ in 1..=10 { + appender.append("").await?; + } - let tail = bifrost.find_tail(LOG_ID).await?; - assert_eq!(tail.offset(), Lsn::from(11)); - assert!(!tail.is_sealed()); - assert_eq!(Lsn::from(5), bifrost.get_trim_point(LOG_ID).await?); + bifrost_admin.trim(LOG_ID, 
Lsn::from(5)).await?; - // 5 itself is trimmed - for lsn in 1..=5 { - let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); + let tail = bifrost.find_tail(LOG_ID).await?; + assert_eq!(tail.offset(), Lsn::from(11)); + assert!(!tail.is_sealed()); + assert_eq!(Lsn::from(5), bifrost.get_trim_point(LOG_ID).await?); - assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); - assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(5)))); - } + // 5 itself is trimmed + for lsn in 1..=5 { + let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); - for lsn in 6..=10 { - let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); - assert!(record.is_data_record()); - } + assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); + assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(5)))); + } - // trimming beyond the release point will fall back to the release point - bifrost_admin.trim(LOG_ID, Lsn::MAX).await?; + for lsn in 6..=10 { + let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); + assert!(record.is_data_record()); + } - assert_eq!(Lsn::from(11), bifrost.find_tail(LOG_ID).await?.offset()); - let new_trim_point = bifrost.get_trim_point(LOG_ID).await?; - assert_eq!(Lsn::from(10), new_trim_point); + // trimming beyond the release point will fall back to the release point + bifrost_admin.trim(LOG_ID, Lsn::MAX).await?; - let record = bifrost.read(LOG_ID, Lsn::from(10)).await?.unwrap(); - assert!(record.is_trim_gap()); - assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(10)))); + assert_eq!(Lsn::from(11), bifrost.find_tail(LOG_ID).await?.offset()); + let new_trim_point = bifrost.get_trim_point(LOG_ID).await?; + assert_eq!(Lsn::from(10), new_trim_point); - // Add 10 more records - for _ in 0..10 { - appender.append("").await?; - } + let record = 
bifrost.read(LOG_ID, Lsn::from(10)).await?.unwrap(); + assert!(record.is_trim_gap()); + assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(10)))); - for lsn in 11..20 { - let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); - assert!(record.is_data_record()); - } + // Add 10 more records + for _ in 0..10 { + appender.append("").await?; + } - Ok(()) + for lsn in 11..20 { + let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); + assert!(record.is_data_record()); } - .in_tc(&node_env.tc) - .await + + Ok(()) } - #[tokio::test(start_paused = true)] + #[restate_core::test(start_paused = true)] async fn test_read_across_segments() -> googletest::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder::with_incoming_only_connector() + let node_env = TestCoreEnvBuilder2::with_incoming_only_connector() .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, 1, )) .build() .await; - async { - let bifrost = Bifrost::init_in_memory().await; - let bifrost_admin = BifrostAdmin::new( - &bifrost, - &node_env.metadata_writer, - &node_env.metadata_store_client, - ); - - let mut appender = bifrost.create_appender(LOG_ID)?; - // Lsns [1..5] - for i in 1..=5 { - // Append a record to memory - let lsn = appender.append(format!("segment-1-{i}")).await?; - assert_eq!(Lsn::from(i), lsn); - } - - // not sealed, tail is what we expect - assert_that!( - bifrost.find_tail(LOG_ID).await?, - pat!(TailState::Open(eq(Lsn::new(6)))) - ); - - let segment_1 = bifrost - .inner - .find_loglet_for_lsn(LOG_ID, Lsn::OLDEST) - .await? 
- .unwrap(); - - // seal the segment - bifrost_admin - .seal(LOG_ID, segment_1.segment_index()) - .await?; + let bifrost = Bifrost::init_in_memory().await; + let bifrost_admin = BifrostAdmin::new( + &bifrost, + &node_env.metadata_writer, + &node_env.metadata_store_client, + ); + + let mut appender = bifrost.create_appender(LOG_ID)?; + // Lsns [1..5] + for i in 1..=5 { + // Append a record to memory + let lsn = appender.append(format!("segment-1-{i}")).await?; + assert_eq!(Lsn::from(i), lsn); + } - // sealed, tail is what we expect - assert_that!( - bifrost.find_tail(LOG_ID).await?, - pat!(TailState::Sealed(eq(Lsn::new(6)))) - ); + // not sealed, tail is what we expect + assert_that!( + bifrost.find_tail(LOG_ID).await?, + pat!(TailState::Open(eq(Lsn::new(6)))) + ); + + let segment_1 = bifrost + .inner + .find_loglet_for_lsn(LOG_ID, Lsn::OLDEST) + .await? + .unwrap(); + + // seal the segment + bifrost_admin + .seal(LOG_ID, segment_1.segment_index()) + .await?; - println!("attempting to read during reconfiguration"); - // attempting to read from bifrost will result in a timeout since metadata sees this as an open - // segment but the segment itself is sealed. This means reconfiguration is in-progress - // and we can't confidently read records. - assert!(tokio::time::timeout( - Duration::from_secs(5), - bifrost.read(LOG_ID, Lsn::new(2)) + // sealed, tail is what we expect + assert_that!( + bifrost.find_tail(LOG_ID).await?, + pat!(TailState::Sealed(eq(Lsn::new(6)))) + ); + + println!("attempting to read during reconfiguration"); + // attempting to read from bifrost will result in a timeout since metadata sees this as an open + // segment but the segment itself is sealed. This means reconfiguration is in-progress + // and we can't confidently read records. 
+ assert!( + tokio::time::timeout(Duration::from_secs(5), bifrost.read(LOG_ID, Lsn::new(2))) + .await + .is_err() + ); + + let metadata = Metadata::current(); + let old_version = metadata.logs_version(); + + let mut builder = metadata.logs_ref().clone().into_builder(); + let mut chain_builder = builder.chain(LOG_ID).unwrap(); + assert_eq!(1, chain_builder.num_segments()); + let new_segment_params = new_single_node_loglet_params(ProviderKind::InMemory); + // deliberately skips Lsn::from(6) to create a zombie record in segment 1. Segment 1 now has 4 records. + chain_builder.append_segment(Lsn::new(5), ProviderKind::InMemory, new_segment_params)?; + + let new_metadata = builder.build(); + let new_version = new_metadata.version(); + assert_eq!(new_version, old_version.next()); + node_env + .metadata_store_client + .put( + BIFROST_CONFIG_KEY.clone(), + &new_metadata, + restate_metadata_store::Precondition::MatchesVersion(old_version), ) - .await - .is_err()); - - let metadata = Metadata::current(); - let old_version = metadata.logs_version(); - - let mut builder = metadata.logs_ref().clone().into_builder(); - let mut chain_builder = builder.chain(LOG_ID).unwrap(); - assert_eq!(1, chain_builder.num_segments()); - let new_segment_params = new_single_node_loglet_params(ProviderKind::InMemory); - // deliberately skips Lsn::from(6) to create a zombie record in segment 1. Segment 1 now has 4 records. - chain_builder.append_segment( - Lsn::new(5), - ProviderKind::InMemory, - new_segment_params, - )?; - - let new_metadata = builder.build(); - let new_version = new_metadata.version(); - assert_eq!(new_version, old_version.next()); - node_env - .metadata_store_client - .put( - BIFROST_CONFIG_KEY.clone(), - &new_metadata, - restate_metadata_store::Precondition::MatchesVersion(old_version), - ) - .await?; - - // make sure we have updated metadata. 
- metadata - .sync(MetadataKind::Logs, TargetVersion::Latest) - .await?; - assert_eq!(new_version, metadata.logs_version()); - - { - // validate that the stored metadata matches our expectations. - let new_metadata = metadata.logs_ref().clone(); - let chain_builder = new_metadata.chain(&LOG_ID).unwrap(); - assert_eq!(2, chain_builder.num_segments()); - } - - // find_tail() on the underlying loglet returns (6) but for bifrost it should be (5) after - // the new segment was created at tail of the chain with base_lsn=5 - assert_that!( - bifrost.find_tail(LOG_ID).await?, - pat!(TailState::Open(eq(Lsn::new(5)))) - ); - - // appends should go to the new segment - let mut appender = bifrost.create_appender(LOG_ID)?; - // Lsns [5..7] - for i in 5..=7 { - // Append a record to memory - let lsn = appender.append(format!("segment-2-{i}")).await?; - assert_eq!(Lsn::from(i), lsn); - } - - // tail is now 8 and open. - assert_that!( - bifrost.find_tail(LOG_ID).await?, - pat!(TailState::Open(eq(Lsn::new(8)))) - ); - - // validating that segment 1 is still sealed and has its own tail at Lsn (6) - assert_that!( - segment_1.find_tail().await?, - pat!(TailState::Sealed(eq(Lsn::new(6)))) - ); - - let segment_2 = bifrost - .inner - .find_loglet_for_lsn(LOG_ID, Lsn::new(5)) - .await? - .unwrap(); - - assert_ne!(segment_1, segment_2); - - // segment 2 is open and at 8 as previously validated through bifrost interface - assert_that!( - segment_2.find_tail().await?, - pat!(TailState::Open(eq(Lsn::new(8)))) - ); - - // Reading the log. 
(OLDEST) - let record = bifrost.read(LOG_ID, Lsn::OLDEST).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(1))); - assert!(record.is_data_record()); - assert_that!( - record.decode_unchecked::(), - eq("segment-1-1".to_owned()) - ); - - let record = bifrost.read(LOG_ID, Lsn::new(2)).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(2))); - assert!(record.is_data_record()); - assert_that!( - record.decode_unchecked::(), - eq("segment-1-2".to_owned()) - ); + .await?; - // border of segment 1 - let record = bifrost.read(LOG_ID, Lsn::new(4)).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(4))); - assert!(record.is_data_record()); - assert_that!( - record.decode_unchecked::(), - eq("segment-1-4".to_owned()) - ); + // make sure we have updated metadata. + metadata + .sync(MetadataKind::Logs, TargetVersion::Latest) + .await?; + assert_eq!(new_version, metadata.logs_version()); - // start of segment 2 - let record = bifrost.read(LOG_ID, Lsn::new(5)).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(5))); - assert!(record.is_data_record()); - assert_that!( - record.decode_unchecked::(), - eq("segment-2-5".to_owned()) - ); + { + // validate that the stored metadata matches our expectations. 
+ let new_metadata = metadata.logs_ref().clone(); + let chain_builder = new_metadata.chain(&LOG_ID).unwrap(); + assert_eq!(2, chain_builder.num_segments()); + } - // last record - let record = bifrost.read(LOG_ID, Lsn::new(7)).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(7))); - assert!(record.is_data_record()); - assert_that!( - record.decode_unchecked::(), - eq("segment-2-7".to_owned()) - ); + // find_tail() on the underlying loglet returns (6) but for bifrost it should be (5) after + // the new segment was created at tail of the chain with base_lsn=5 + assert_that!( + bifrost.find_tail(LOG_ID).await?, + pat!(TailState::Open(eq(Lsn::new(5)))) + ); + + // appends should go to the new segment + let mut appender = bifrost.create_appender(LOG_ID)?; + // Lsns [5..7] + for i in 5..=7 { + // Append a record to memory + let lsn = appender.append(format!("segment-2-{i}")).await?; + assert_eq!(Lsn::from(i), lsn); + } - // 8 doesn't exist yet. - assert!(bifrost.read(LOG_ID, Lsn::new(8)).await?.is_none()); + // tail is now 8 and open. + assert_that!( + bifrost.find_tail(LOG_ID).await?, + pat!(TailState::Open(eq(Lsn::new(8)))) + ); + + // validating that segment 1 is still sealed and has its own tail at Lsn (6) + assert_that!( + segment_1.find_tail().await?, + pat!(TailState::Sealed(eq(Lsn::new(6)))) + ); + + let segment_2 = bifrost + .inner + .find_loglet_for_lsn(LOG_ID, Lsn::new(5)) + .await? + .unwrap(); + + assert_ne!(segment_1, segment_2); + + // segment 2 is open and at 8 as previously validated through bifrost interface + assert_that!( + segment_2.find_tail().await?, + pat!(TailState::Open(eq(Lsn::new(8)))) + ); + + // Reading the log. 
(OLDEST) + let record = bifrost.read(LOG_ID, Lsn::OLDEST).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(1))); + assert!(record.is_data_record()); + assert_that!( + record.decode_unchecked::(), + eq("segment-1-1".to_owned()) + ); + + let record = bifrost.read(LOG_ID, Lsn::new(2)).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(2))); + assert!(record.is_data_record()); + assert_that!( + record.decode_unchecked::(), + eq("segment-1-2".to_owned()) + ); + + // border of segment 1 + let record = bifrost.read(LOG_ID, Lsn::new(4)).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(4))); + assert!(record.is_data_record()); + assert_that!( + record.decode_unchecked::(), + eq("segment-1-4".to_owned()) + ); + + // start of segment 2 + let record = bifrost.read(LOG_ID, Lsn::new(5)).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(5))); + assert!(record.is_data_record()); + assert_that!( + record.decode_unchecked::(), + eq("segment-2-5".to_owned()) + ); + + // last record + let record = bifrost.read(LOG_ID, Lsn::new(7)).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(7))); + assert!(record.is_data_record()); + assert_that!( + record.decode_unchecked::(), + eq("segment-2-7".to_owned()) + ); + + // 8 doesn't exist yet. 
+ assert!(bifrost.read(LOG_ID, Lsn::new(8)).await?.is_none()); - Ok(()) - } - .in_tc(&node_env.tc) - .await + Ok(()) } - #[tokio::test(start_paused = true)] + #[restate_core::test(start_paused = true)] #[traced_test] async fn test_appends_correctly_handle_reconfiguration() -> googletest::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder::with_incoming_only_connector() + let node_env = TestCoreEnvBuilder2::with_incoming_only_connector() .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, 1, @@ -890,118 +870,112 @@ mod tests { .set_provider_kind(ProviderKind::Local) .build() .await; - async { - RocksDbManager::init(Constant::new(CommonOptions::default())); - let bifrost = Bifrost::init_local().await; - let bifrost_admin = BifrostAdmin::new( - &bifrost, - &node_env.metadata_writer, - &node_env.metadata_store_client, - ); - - // create an appender - let stop_signal = Arc::new(AtomicBool::default()); - let append_counter = Arc::new(AtomicUsize::new(0)); - let _ = TaskCenter::current().spawn(TaskKind::TestRunner, "append-records", None, { - let append_counter = append_counter.clone(); - let stop_signal = stop_signal.clone(); - let bifrost = bifrost.clone(); - let mut appender = bifrost.create_appender(LOG_ID)?; - async move { - let mut i = 0; - while !stop_signal.load(Ordering::Relaxed) { - i += 1; - if i % 2 == 0 { - // append individual record - let lsn = appender.append(format!("record{}", i)).await?; - println!("Appended {}", lsn); - } else { - // append batch - let mut payloads = Vec::with_capacity(10); - for j in 1..=10 { - payloads.push(format!("record-in-batch{}-{}", i, j)); - } - let lsn = appender.append_batch(payloads).await?; - println!("Appended batch {}", lsn); + RocksDbManager::init(Constant::new(CommonOptions::default())); + let bifrost = Bifrost::init_local().await; + let bifrost_admin = BifrostAdmin::new( + &bifrost, + &node_env.metadata_writer, + &node_env.metadata_store_client, + ); + + 
// create an appender + let stop_signal = Arc::new(AtomicBool::default()); + let append_counter = Arc::new(AtomicUsize::new(0)); + let _ = TaskCenter::current().spawn(TaskKind::TestRunner, "append-records", None, { + let append_counter = append_counter.clone(); + let stop_signal = stop_signal.clone(); + let bifrost = bifrost.clone(); + let mut appender = bifrost.create_appender(LOG_ID)?; + async move { + let mut i = 0; + while !stop_signal.load(Ordering::Relaxed) { + i += 1; + if i % 2 == 0 { + // append individual record + let lsn = appender.append(format!("record{}", i)).await?; + println!("Appended {}", lsn); + } else { + // append batch + let mut payloads = Vec::with_capacity(10); + for j in 1..=10 { + payloads.push(format!("record-in-batch{}-{}", i, j)); } - append_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - tokio::time::sleep(Duration::from_millis(1)).await; + let lsn = appender.append_batch(payloads).await?; + println!("Appended batch {}", lsn); } - println!("Appender terminated"); - Ok(()) - } - })?; - - let mut append_counter_before_seal; - loop { - append_counter_before_seal = append_counter.load(Ordering::Relaxed); - if append_counter_before_seal < 100 { - tokio::time::sleep(Duration::from_millis(10)).await; - } else { - break; + append_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + tokio::time::sleep(Duration::from_millis(1)).await; } + println!("Appender terminated"); + Ok(()) } - - // seal and don't extend the chain. - let _ = bifrost_admin.seal(LOG_ID, SegmentIndex::from(0)).await?; - - // appends should stall! 
- tokio::time::sleep(Duration::from_millis(100)).await; - let append_counter_during_seal = append_counter.load(Ordering::Relaxed); - for _ in 0..5 { - tokio::time::sleep(Duration::from_millis(500)).await; - let counter_now = append_counter.load(Ordering::Relaxed); - assert_that!(counter_now, eq(append_counter_during_seal)); - println!("Appends are stalling, counter={}", counter_now); + })?; + + let mut append_counter_before_seal; + loop { + append_counter_before_seal = append_counter.load(Ordering::Relaxed); + if append_counter_before_seal < 100 { + tokio::time::sleep(Duration::from_millis(10)).await; + } else { + break; } + } + + // seal and don't extend the chain. + let _ = bifrost_admin.seal(LOG_ID, SegmentIndex::from(0)).await?; + + // appends should stall! + tokio::time::sleep(Duration::from_millis(100)).await; + let append_counter_during_seal = append_counter.load(Ordering::Relaxed); + for _ in 0..5 { + tokio::time::sleep(Duration::from_millis(500)).await; + let counter_now = append_counter.load(Ordering::Relaxed); + assert_that!(counter_now, eq(append_counter_during_seal)); + println!("Appends are stalling, counter={}", counter_now); + } - for i in 1..=5 { - let last_segment = bifrost + for i in 1..=5 { + let last_segment = bifrost + .inner + .writeable_loglet(LOG_ID) + .await? + .segment_index(); + // allow appender to run a little. + tokio::time::sleep(Duration::from_millis(500)).await; + // seal the loglet and extend with an in-memory one + let new_segment_params = new_single_node_loglet_params(ProviderKind::Local); + bifrost_admin + .seal_and_extend_chain( + LOG_ID, + None, + Version::MIN, + ProviderKind::Local, + new_segment_params, + ) + .await?; + println!("Seal {}", i); + assert_that!( + bifrost .inner .writeable_loglet(LOG_ID) .await? - .segment_index(); - // allow appender to run a little. 
- tokio::time::sleep(Duration::from_millis(500)).await; - // seal the loglet and extend with an in-memory one - let new_segment_params = new_single_node_loglet_params(ProviderKind::Local); - bifrost_admin - .seal_and_extend_chain( - LOG_ID, - None, - Version::MIN, - ProviderKind::Local, - new_segment_params, - ) - .await?; - println!("Seal {}", i); - assert_that!( - bifrost - .inner - .writeable_loglet(LOG_ID) - .await? - .segment_index(), - gt(last_segment) - ); - } - - // make sure that appends are still happening. - let mut append_counter_after_seal = append_counter.load(Ordering::Relaxed); - tokio::time::sleep(Duration::from_millis(100)).await; - assert_that!(append_counter_after_seal, gt(append_counter_before_seal)); - for _ in 0..5 { - tokio::time::sleep(Duration::from_millis(50)).await; - let counter_now = append_counter.load(Ordering::Relaxed); - assert_that!(counter_now, gt(append_counter_after_seal)); - append_counter_after_seal = counter_now; - } + .segment_index(), + gt(last_segment) + ); + } - googletest::Result::Ok(()) + // make sure that appends are still happening. + let mut append_counter_after_seal = append_counter.load(Ordering::Relaxed); + tokio::time::sleep(Duration::from_millis(100)).await; + assert_that!(append_counter_after_seal, gt(append_counter_before_seal)); + for _ in 0..5 { + tokio::time::sleep(Duration::from_millis(50)).await; + let counter_now = append_counter.load(Ordering::Relaxed); + assert_that!(counter_now, gt(append_counter_after_seal)); + append_counter_after_seal = counter_now; } - .in_tc(&node_env.tc) - .await?; - node_env.tc.shutdown_node("test completed", 0).await; + // questionable. 
RocksDbManager::get().shutdown().await; Ok(()) } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 2aaab89b0f..b406faee4b 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -9,11 +9,12 @@ publish = false [features] default = [] -test-util = ["tokio/test-util"] +test-util = ["tokio/test-util", "restate-core-derive"] options_schema = ["dep:schemars"] [dependencies] restate-types = { workspace = true } +restate-core-derive = { workspace = true, optional = true } anyhow = { workspace = true } axum = { workspace = true, default-features = false } @@ -67,6 +68,7 @@ tonic-build = { workspace = true } [dev-dependencies] restate-test-util = { workspace = true } restate-types = { workspace = true, features = ["test-util"] } +restate-core-derive = { workspace = true } googletest = { workspace = true } test-log = { workspace = true } diff --git a/crates/core/derive/Cargo.toml b/crates/core/derive/Cargo.toml new file mode 100644 index 0000000000..355f8beb61 --- /dev/null +++ b/crates/core/derive/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "restate-core-derive" +version = "0.1.0" +authors.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +publish = false + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0" +quote = "1" +syn = { version = "2.0", features = ["full"] } diff --git a/crates/core/derive/src/lib.rs b/crates/core/derive/src/lib.rs new file mode 100644 index 0000000000..b4cac8b07a --- /dev/null +++ b/crates/core/derive/src/lib.rs @@ -0,0 +1,39 @@ +// Copyright (c) 2023 - 2025 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. 
+ +extern crate proc_macro; + +mod tc_test; + +use proc_macro::TokenStream; + +/// Run tests within task-center +/// +/// +/// You can configure the underlying runtime(s) as you would do with tokio +/// ```no_run +/// #[restate_core::test(_args_of_tokio_test)] +/// async fn test_name() { +/// TaskCenter::current(); +/// } +/// ``` +/// +/// A generalised example is +/// ```no_run +/// #[restate_core::test(start_paused = true)]` +/// async fn test_name() { +/// TaskCenter::current(); +/// } +/// ``` +/// +#[proc_macro_attribute] +pub fn test(args: TokenStream, item: TokenStream) -> TokenStream { + tc_test::test(args.into(), item.into(), true).into() +} diff --git a/crates/core/derive/src/tc_test.rs b/crates/core/derive/src/tc_test.rs new file mode 100644 index 0000000000..d7d317a842 --- /dev/null +++ b/crates/core/derive/src/tc_test.rs @@ -0,0 +1,681 @@ +// Copyright (c) 2023 - 2025 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +//! Some parts of this codebase were taken from https://github.com/tokio-rs/tokio/blob/master/tokio-macros/src/entry.rs +//! 
MIT License + +use proc_macro2::{Span, TokenStream, TokenTree}; +use quote::{quote, quote_spanned, ToTokens}; +use syn::parse::{Parse, ParseStream, Parser}; +use syn::{braced, Attribute, Ident, Path, Signature, Visibility}; + +// syn::AttributeArgs does not implement syn::Parse +type AttributeArgs = syn::punctuated::Punctuated; + +#[derive(Clone, Copy, PartialEq)] +enum RuntimeFlavor { + CurrentThread, + Threaded, +} + +impl RuntimeFlavor { + fn from_str(s: &str) -> Result { + match s { + "current_thread" => Ok(RuntimeFlavor::CurrentThread), + "multi_thread" => Ok(RuntimeFlavor::Threaded), + "single_thread" => Err("The single threaded runtime flavor is called `current_thread`.".to_string()), + "basic_scheduler" => Err("The `basic_scheduler` runtime flavor has been renamed to `current_thread`.".to_string()), + "threaded_scheduler" => Err("The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`.".to_string()), + _ => Err(format!("No such runtime flavor `{}`. The runtime flavors are `current_thread` and `multi_thread`.", s)), + } + } +} + +#[derive(Clone, Copy, PartialEq)] +enum UnhandledPanic { + Ignore, + ShutdownRuntime, +} + +impl UnhandledPanic { + fn from_str(s: &str) -> Result { + match s { + "ignore" => Ok(UnhandledPanic::Ignore), + "shutdown_runtime" => Ok(UnhandledPanic::ShutdownRuntime), + _ => Err(format!("No such unhandled panic behavior `{}`. The unhandled panic behaviors are `ignore` and `shutdown_runtime`.", s)), + } + } + + fn into_tokens(self, crate_path: &TokenStream) -> TokenStream { + match self { + UnhandledPanic::Ignore => quote! { #crate_path::runtime::UnhandledPanic::Ignore }, + UnhandledPanic::ShutdownRuntime => { + quote! 
{ #crate_path::runtime::UnhandledPanic::ShutdownRuntime } + } + } + } +} + +struct FinalConfig { + flavor: RuntimeFlavor, + worker_threads: Option, + start_paused: Option, + crate_name: Option, + unhandled_panic: Option, +} + +/// Config used in case of the attribute not being able to build a valid config +const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig { + flavor: RuntimeFlavor::CurrentThread, + worker_threads: None, + start_paused: None, + crate_name: None, + unhandled_panic: None, +}; + +struct Configuration { + rt_multi_thread_available: bool, + default_flavor: RuntimeFlavor, + flavor: Option, + worker_threads: Option<(usize, Span)>, + start_paused: Option<(bool, Span)>, + crate_name: Option, + unhandled_panic: Option<(UnhandledPanic, Span)>, +} + +impl Configuration { + fn new(rt_multi_thread: bool) -> Self { + Configuration { + rt_multi_thread_available: rt_multi_thread, + default_flavor: RuntimeFlavor::CurrentThread, + flavor: None, + worker_threads: None, + start_paused: None, + crate_name: None, + unhandled_panic: None, + } + } + + fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> { + if self.flavor.is_some() { + return Err(syn::Error::new(span, "`flavor` set multiple times.")); + } + + let runtime_str = parse_string(runtime, span, "flavor")?; + let runtime = + RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?; + self.flavor = Some(runtime); + Ok(()) + } + + fn set_worker_threads( + &mut self, + worker_threads: syn::Lit, + span: Span, + ) -> Result<(), syn::Error> { + if self.worker_threads.is_some() { + return Err(syn::Error::new( + span, + "`worker_threads` set multiple times.", + )); + } + + let worker_threads = parse_int(worker_threads, span, "worker_threads")?; + if worker_threads == 0 { + return Err(syn::Error::new(span, "`worker_threads` may not be 0.")); + } + self.worker_threads = Some((worker_threads, span)); + Ok(()) + } + + fn set_start_paused(&mut self, start_paused: 
syn::Lit, span: Span) -> Result<(), syn::Error> { + if self.start_paused.is_some() { + return Err(syn::Error::new(span, "`start_paused` set multiple times.")); + } + + let start_paused = parse_bool(start_paused, span, "start_paused")?; + self.start_paused = Some((start_paused, span)); + Ok(()) + } + + fn set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> { + if self.crate_name.is_some() { + return Err(syn::Error::new(span, "`crate` set multiple times.")); + } + let name_path = parse_path(name, span, "crate")?; + self.crate_name = Some(name_path); + Ok(()) + } + + fn set_unhandled_panic( + &mut self, + unhandled_panic: syn::Lit, + span: Span, + ) -> Result<(), syn::Error> { + if self.unhandled_panic.is_some() { + return Err(syn::Error::new( + span, + "`unhandled_panic` set multiple times.", + )); + } + + let unhandled_panic = parse_string(unhandled_panic, span, "unhandled_panic")?; + let unhandled_panic = + UnhandledPanic::from_str(&unhandled_panic).map_err(|err| syn::Error::new(span, err))?; + self.unhandled_panic = Some((unhandled_panic, span)); + Ok(()) + } + + fn macro_name(&self) -> &'static str { + "restate_core::test" + } + + fn build(&self) -> Result { + use RuntimeFlavor as F; + + let flavor = self.flavor.unwrap_or(self.default_flavor); + let worker_threads = match (flavor, self.worker_threads) { + (F::CurrentThread, Some((_, worker_threads_span))) => { + let msg = format!( + "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[{}(flavor = \"multi_thread\")]`", + self.macro_name(), + ); + return Err(syn::Error::new(worker_threads_span, msg)); + } + (F::CurrentThread, None) => None, + (F::Threaded, worker_threads) if self.rt_multi_thread_available => { + worker_threads.map(|(val, _span)| val) + } + (F::Threaded, _) => { + let msg = if self.flavor.is_none() { + "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled." 
+ } else { + "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature." + }; + return Err(syn::Error::new(Span::call_site(), msg)); + } + }; + + let start_paused = match (flavor, self.start_paused) { + (F::Threaded, Some((_, start_paused_span))) => { + let msg = format!( + "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`", + self.macro_name(), + ); + return Err(syn::Error::new(start_paused_span, msg)); + } + (F::CurrentThread, Some((start_paused, _))) => Some(start_paused), + (_, None) => None, + }; + + let unhandled_panic = match (flavor, self.unhandled_panic) { + (F::Threaded, Some((_, unhandled_panic_span))) => { + let msg = format!( + "The `unhandled_panic` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`", + self.macro_name(), + ); + return Err(syn::Error::new(unhandled_panic_span, msg)); + } + (F::CurrentThread, Some((unhandled_panic, _))) => Some(unhandled_panic), + (_, None) => None, + }; + + Ok(FinalConfig { + crate_name: self.crate_name.clone(), + flavor, + worker_threads, + start_paused, + unhandled_panic, + }) + } +} + +fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result { + match int { + syn::Lit::Int(lit) => match lit.base10_parse::() { + Ok(value) => Ok(value), + Err(e) => Err(syn::Error::new( + span, + format!("Failed to parse value of `{}` as integer: {}", field, e), + )), + }, + _ => Err(syn::Error::new( + span, + format!("Failed to parse value of `{}` as integer.", field), + )), + } +} + +fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result { + match int { + syn::Lit::Str(s) => Ok(s.value()), + syn::Lit::Verbatim(s) => Ok(s.to_string()), + _ => Err(syn::Error::new( + span, + format!("Failed to parse value of `{}` as string.", field), + )), + } +} + +fn parse_path(lit: syn::Lit, span: Span, field: &str) -> Result { + match lit { + syn::Lit::Str(s) => { + let err = syn::Error::new( + span, + format!( 
+ "Failed to parse value of `{}` as path: \"{}\"", + field, + s.value() + ), + ); + s.parse::().map_err(|_| err.clone()) + } + _ => Err(syn::Error::new( + span, + format!("Failed to parse value of `{}` as path.", field), + )), + } +} + +fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result { + match bool { + syn::Lit::Bool(b) => Ok(b.value), + _ => Err(syn::Error::new( + span, + format!("Failed to parse value of `{}` as bool.", field), + )), + } +} + +fn build_config( + input: &ItemFn, + args: AttributeArgs, + rt_multi_thread: bool, +) -> Result { + if input.sig.asyncness.is_none() { + let msg = "the `async` keyword is missing from the function declaration"; + return Err(syn::Error::new_spanned(input.sig.fn_token, msg)); + } + + let mut config = Configuration::new(rt_multi_thread); + let macro_name = config.macro_name(); + + for arg in args { + match arg { + syn::Meta::NameValue(namevalue) => { + let ident = namevalue + .path + .get_ident() + .ok_or_else(|| { + syn::Error::new_spanned(&namevalue, "Must have specified ident") + })? + .to_string() + .to_lowercase(); + let lit = match &namevalue.value { + syn::Expr::Lit(syn::ExprLit { lit, .. 
}) => lit, + expr => return Err(syn::Error::new_spanned(expr, "Must be a literal")), + }; + match ident.as_str() { + "worker_threads" => { + config.set_worker_threads(lit.clone(), syn::spanned::Spanned::span(lit))?; + } + "flavor" => { + config.set_flavor(lit.clone(), syn::spanned::Spanned::span(lit))?; + } + "start_paused" => { + config.set_start_paused(lit.clone(), syn::spanned::Spanned::span(lit))?; + } + "core_threads" => { + let msg = "Attribute `core_threads` is renamed to `worker_threads`"; + return Err(syn::Error::new_spanned(namevalue, msg)); + } + "crate" => { + config.set_crate_name(lit.clone(), syn::spanned::Spanned::span(lit))?; + } + "unhandled_panic" => { + config + .set_unhandled_panic(lit.clone(), syn::spanned::Spanned::span(lit))?; + } + name => { + let msg = format!( + "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`, `unhandled_panic`", + name, + ); + return Err(syn::Error::new_spanned(namevalue, msg)); + } + } + } + syn::Meta::Path(path) => { + let name = path + .get_ident() + .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))? 
+ .to_string() + .to_lowercase(); + let msg = match name.as_str() { + "threaded_scheduler" | "multi_thread" => { + format!( + "Set the runtime flavor with #[{}(flavor = \"multi_thread\")].", + macro_name + ) + } + "basic_scheduler" | "current_thread" | "single_threaded" => { + format!( + "Set the runtime flavor with #[{}(flavor = \"current_thread\")].", + macro_name + ) + } + "flavor" | "worker_threads" | "start_paused" | "crate" | "unhandled_panic" => { + format!("The `{}` attribute requires an argument.", name) + } + name => { + format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`, `unhandled_panic`.", name) + } + }; + return Err(syn::Error::new_spanned(path, msg)); + } + other => { + return Err(syn::Error::new_spanned( + other, + "Unknown attribute inside the macro", + )); + } + } + } + + config.build() +} + +fn parse_knobs(mut input: ItemFn, config: FinalConfig) -> TokenStream { + input.sig.asyncness = None; + + // If type mismatch occurs, the current rustc points to the last statement. + let (last_stmt_start_span, last_stmt_end_span) = { + let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter(); + + // `Span` on stable Rust has a limitation that only points to the first + // token, not the whole tokens. We can work around this limitation by + // using the first/last span of the tokens like + // `syn::Error::new_spanned` does. + let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span()); + let end = last_stmt.last().map_or(start, |t| t.span()); + (start, end) + }; + + let crate_path = config + .crate_name + .map(ToTokens::into_token_stream) + .unwrap_or_else(|| Ident::new("restate_core", last_stmt_start_span).into_token_stream()); + + let mut tc_builder = quote_spanned! 
{last_stmt_start_span=> + #crate_path::TaskCenterBuilder::default() + .ingress_runtime_handle(rt.handle().clone()) + .default_runtime_handle(rt.handle().clone()) + }; + let mut rt = match config.flavor { + RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=> + ::tokio::runtime::Builder::new_current_thread() + }, + RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=> + ::tokio::runtime::Builder::new_multi_thread() + }, + }; + if let Some(v) = config.worker_threads { + rt = quote_spanned! {last_stmt_start_span=> #rt.worker_threads(#v) }; + } + if let Some(v) = config.start_paused { + rt = quote_spanned! {last_stmt_start_span=> #rt.start_paused(#v) }; + tc_builder = quote_spanned! {last_stmt_start_span=> #tc_builder.pause_time(#v) }; + } + if let Some(v) = config.unhandled_panic { + let unhandled_panic = v.into_tokens(&crate_path); + rt = quote_spanned! {last_stmt_start_span=> #rt.unhandled_panic(#unhandled_panic) }; + } + + let generated_attrs = quote! { + #[::core::prelude::v1::test] + }; + + let body_ident = quote! { body }; + let last_block = quote_spanned! {last_stmt_end_span=> + #[allow(clippy::expect_used, clippy::diverging_sub_expression)] + { + use restate_core::TaskCenterFutureExt as _; + // Make sure that panics exits the process. 
+ let orig_hook = std::panic::take_hook(); + std::panic::set_hook(Box::new(move |panic_info| { + // invoke the default handler and exit the process + orig_hook(panic_info); + std::process::exit(1); + })); + let rt = #rt + .enable_all() + .build() + .expect("Failed building the Runtime"); + + let task_center = #tc_builder + .build() + .expect("Failed building task-center"); + + let ret = rt.block_on(#body_ident.in_tc(&task_center)); + rt.block_on(task_center.shutdown_node("completed", 0)); + ret + } + }; + + let body = input.body(); + + // For test functions pin the body to the stack and use `Pin<&mut dyn + // Future>` to reduce the amount of `Runtime::block_on` (and related + // functions) copies we generate during compilation due to the generic + // parameter `F` (the future to block on). This could have an impact on + // performance, but because it's only for testing it's unlikely to be very + // large. + // + // We don't do this for the main function as it should only be used once so + // there will be no benefit. + let body = { + let output_type = match &input.sig.output { + // For functions with no return value syn doesn't print anything, + // but that doesn't work as `Output` for our boxed `Future`, so + // default to `()` (the same type as the function output). + syn::ReturnType::Default => quote! { () }, + syn::ReturnType::Type(_, ret_type) => quote! { #ret_type }, + }; + quote! { + let body = async #body; + ::tokio::pin!(body); + let body: ::core::pin::Pin<&mut dyn ::core::future::Future> = body; + } + }; + + input.into_tokens(generated_attrs, body, last_block) +} + +fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream { + tokens.extend(error.into_compile_error()); + tokens +} + +pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { + // If any of the steps for this macro fail, we still want to expand to an item that is as close + // to the expected output as possible. 
This helps out IDEs such that completions and other + // related features keep working. + let input: ItemFn = match syn::parse2(item.clone()) { + Ok(it) => it, + Err(e) => return token_stream_with_error(item, e), + }; + let config = if let Some(attr) = input.attrs().find(|attr| is_test_attribute(attr)) { + let msg = "second test attribute is supplied, consider removing or changing the order of your test attributes"; + Err(syn::Error::new_spanned(attr, msg)) + } else { + AttributeArgs::parse_terminated + .parse2(args) + .and_then(|args| build_config(&input, args, rt_multi_thread)) + }; + + match config { + Ok(config) => parse_knobs(input, config), + Err(e) => token_stream_with_error(parse_knobs(input, DEFAULT_ERROR_CONFIG), e), + } +} + +// Check whether given attribute is a test attribute of forms: +// * `#[test]` +// * `#[core::prelude::*::test]` or `#[::core::prelude::*::test]` +// * `#[std::prelude::*::test]` or `#[::std::prelude::*::test]` +fn is_test_attribute(attr: &Attribute) -> bool { + let path = match &attr.meta { + syn::Meta::Path(path) => path, + _ => return false, + }; + let candidates = [ + ["core", "prelude", "*", "test"], + ["std", "prelude", "*", "test"], + ]; + if path.leading_colon.is_none() + && path.segments.len() == 1 + && path.segments[0].arguments.is_none() + && path.segments[0].ident == "test" + { + return true; + } else if path.segments.len() != candidates[0].len() { + return false; + } + candidates.into_iter().any(|segments| { + path.segments.iter().zip(segments).all(|(segment, path)| { + segment.arguments.is_none() && (path == "*" || segment.ident == path) + }) + }) +} + +struct ItemFn { + outer_attrs: Vec, + vis: Visibility, + sig: Signature, + brace_token: syn::token::Brace, + inner_attrs: Vec, + stmts: Vec, +} + +impl ItemFn { + /// Access all attributes of the function item. 
    fn attrs(&self) -> impl Iterator<Item = &Attribute> {
+ + let outer_attrs = input.call(Attribute::parse_outer)?; + let vis: Visibility = input.parse()?; + let sig: Signature = input.parse()?; + + let content; + let brace_token = braced!(content in input); + let inner_attrs = Attribute::parse_inner(&content)?; + + let mut buf = proc_macro2::TokenStream::new(); + let mut stmts = Vec::new(); + + while !content.is_empty() { + if let Some(semi) = content.parse::>()? { + semi.to_tokens(&mut buf); + stmts.push(buf); + buf = proc_macro2::TokenStream::new(); + continue; + } + + // Parse a single token tree and extend our current buffer with it. + // This avoids parsing the entire content of the sub-tree. + buf.extend([content.parse::()?]); + } + + if !buf.is_empty() { + stmts.push(buf); + } + + Ok(Self { + outer_attrs, + vis, + sig, + brace_token, + inner_attrs, + stmts, + }) + } +} + +struct Body<'a> { + brace_token: syn::token::Brace, + // Statements, with terminating `;`. + stmts: &'a [TokenStream], +} + +impl ToTokens for Body<'_> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + self.brace_token.surround(tokens, |tokens| { + for stmt in self.stmts { + stmt.to_tokens(tokens); + } + }); + } +} diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 93856f6247..6c85c1a32d 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -18,6 +18,10 @@ mod task_center; pub mod worker_api; pub use error::*; +#[cfg(any(test, feature = "test-util"))] +#[doc(inline)] +pub use restate_core_derive::test; + pub use metadata::{ spawn_metadata_manager, Metadata, MetadataBuilder, MetadataKind, MetadataManager, MetadataWriter, SyncError, TargetVersion, @@ -27,5 +31,11 @@ pub use task_center::*; #[cfg(any(test, feature = "test-util"))] mod test_env; +#[cfg(any(test, feature = "test-util"))] +mod test_env2; + #[cfg(any(test, feature = "test-util"))] pub use test_env::{create_mock_nodes_config, NoOpMessageHandler, TestCoreEnv, TestCoreEnvBuilder}; + +#[cfg(any(test, feature = "test-util"))] +pub use 
test_env2::{TestCoreEnv2, TestCoreEnvBuilder2}; diff --git a/crates/core/src/task_center/builder.rs b/crates/core/src/task_center/builder.rs index dfb3787f62..7325bbfa9c 100644 --- a/crates/core/src/task_center/builder.rs +++ b/crates/core/src/task_center/builder.rs @@ -33,7 +33,6 @@ pub struct TaskCenterBuilder { ingress_runtime_handle: Option, ingress_runtime: Option, options: Option, - #[cfg(any(test, feature = "test-util"))] pause_time: bool, } @@ -67,13 +66,11 @@ impl TaskCenterBuilder { self } - #[cfg(any(test, feature = "test-util"))] pub fn pause_time(mut self, pause_time: bool) -> Self { self.pause_time = pause_time; self } - #[cfg(any(test, feature = "test-util"))] pub fn default_for_tests() -> Self { Self::default() .ingress_runtime_handle(tokio::runtime::Handle::current()) @@ -85,10 +82,6 @@ impl TaskCenterBuilder { let options = self.options.unwrap_or_default(); if self.default_runtime_handle.is_none() { let mut default_runtime_builder = tokio_builder("worker", &options); - #[cfg(any(test, feature = "test-util"))] - if self.pause_time { - default_runtime_builder.start_paused(self.pause_time); - } let default_runtime = default_runtime_builder.build()?; self.default_runtime_handle = Some(default_runtime.handle().clone()); self.default_runtime = Some(default_runtime); @@ -96,10 +89,6 @@ impl TaskCenterBuilder { if self.ingress_runtime_handle.is_none() { let mut ingress_runtime_builder = tokio_builder("ingress", &options); - #[cfg(any(test, feature = "test-util"))] - if self.pause_time { - ingress_runtime_builder.start_paused(self.pause_time); - } let ingress_runtime = ingress_runtime_builder.build()?; self.ingress_runtime_handle = Some(ingress_runtime.handle().clone()); self.ingress_runtime = Some(ingress_runtime); @@ -113,6 +102,7 @@ impl TaskCenterBuilder { self.ingress_runtime_handle.unwrap(), self.default_runtime, self.ingress_runtime, + self.pause_time, )) } } diff --git a/crates/core/src/task_center/mod.rs b/crates/core/src/task_center/mod.rs index 
1c48897d00..dd40530d69 100644 --- a/crates/core/src/task_center/mod.rs +++ b/crates/core/src/task_center/mod.rs @@ -72,7 +72,8 @@ pub enum RuntimeError { } /// Task center is used to manage long-running and background tasks and their lifecycle. -#[derive(Clone)] +#[derive(Clone, derive_more::Debug)] +#[debug("TaskCenter({})", inner.id)] pub struct TaskCenter { inner: Arc, } @@ -85,6 +86,9 @@ impl TaskCenter { ingress_runtime_handle: tokio::runtime::Handle, default_runtime: Option, ingress_runtime: Option, + // used in tests to start all runtimes with clock paused. Note that this only impacts + // partition processor runtimes + pause_time: bool, ) -> Self { metric_definitions::describe_metrics(); let root_task_context = TaskContext { @@ -96,6 +100,7 @@ impl TaskCenter { }; Self { inner: Arc::new(TaskCenterInner { + id: rand::random(), start_time: Instant::now(), default_runtime_handle, default_runtime, @@ -108,6 +113,7 @@ impl TaskCenter { global_metadata: OnceLock::new(), managed_runtimes: Mutex::new(HashMap::with_capacity(64)), root_task_context, + pause_time, }), } } @@ -508,6 +514,10 @@ impl TaskCenter { // todo: configure the runtime according to a new runtime kind perhaps? let thread_builder = std::thread::Builder::new().name(format!("rt:{}", runtime_name)); let mut builder = tokio::runtime::Builder::new_current_thread(); + + #[cfg(any(test, feature = "test-util"))] + builder.start_paused(self.inner.pause_time); + let rt = builder .enable_all() .build() @@ -886,6 +896,12 @@ impl TaskCenter { } struct TaskCenterInner { + #[allow(dead_code)] + /// used in Debug impl to distinguish between multiple task-centers + id: u16, + /// Should we start new runtimes with paused clock? 
+ #[allow(dead_code)] + pause_time: bool, default_runtime_handle: tokio::runtime::Handle, ingress_runtime_handle: tokio::runtime::Handle, managed_runtimes: Mutex>>, diff --git a/crates/core/src/test_env.rs b/crates/core/src/test_env.rs index 5ebdd702b9..1d7143aab8 100644 --- a/crates/core/src/test_env.rs +++ b/crates/core/src/test_env.rs @@ -217,7 +217,7 @@ impl TestCoreEnvBuilder { Precondition::None, ) .await - .expect("sot store scheduling plan in metadata store"); + .expect("to store scheduling plan in metadata store"); let _ = self .metadata @@ -227,6 +227,7 @@ impl TestCoreEnvBuilder { ) .await .unwrap(); + self.metadata_writer.set_my_node_id(self.my_node_id); TestCoreEnv { diff --git a/crates/core/src/test_env2.rs b/crates/core/src/test_env2.rs new file mode 100644 index 0000000000..17f9664467 --- /dev/null +++ b/crates/core/src/test_env2.rs @@ -0,0 +1,303 @@ +// Copyright (c) 2023 - 2025 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. 
+ +use std::marker::PhantomData; +use std::str::FromStr; +use std::sync::Arc; + +use futures::Stream; + +use restate_types::cluster_controller::{ReplicationStrategy, SchedulingPlan}; +use restate_types::config::NetworkingOptions; +use restate_types::logs::metadata::{bootstrap_logs_metadata, ProviderKind}; +use restate_types::metadata_store::keys::{ + BIFROST_CONFIG_KEY, NODES_CONFIG_KEY, PARTITION_TABLE_KEY, SCHEDULING_PLAN_KEY, +}; +use restate_types::net::codec::{Targeted, WireDecode}; +use restate_types::net::metadata::MetadataKind; +use restate_types::net::AdvertisedAddress; +use restate_types::nodes_config::{LogServerConfig, NodeConfig, NodesConfiguration, Role}; +use restate_types::partition_table::PartitionTable; +use restate_types::protobuf::node::Message; +use restate_types::{GenerationalNodeId, Version}; + +use crate::metadata_store::{MetadataStoreClient, Precondition}; +use crate::network::{ + ConnectionManager, FailingConnector, Incoming, MessageHandler, MessageRouterBuilder, + NetworkError, Networking, ProtocolError, TransportConnect, +}; +use crate::TaskCenter; +use crate::{spawn_metadata_manager, MetadataBuilder, TaskId}; +use crate::{Metadata, MetadataManager, MetadataWriter}; + +pub struct TestCoreEnvBuilder2 { + pub my_node_id: GenerationalNodeId, + pub metadata_manager: MetadataManager, + pub metadata_writer: MetadataWriter, + pub metadata: Metadata, + pub networking: Networking, + pub nodes_config: NodesConfiguration, + pub provider_kind: ProviderKind, + pub router_builder: MessageRouterBuilder, + pub partition_table: PartitionTable, + pub scheduling_plan: SchedulingPlan, + pub metadata_store_client: MetadataStoreClient, +} + +impl TestCoreEnvBuilder2 { + pub fn with_incoming_only_connector() -> Self { + let metadata_builder = MetadataBuilder::default(); + let net_opts = NetworkingOptions::default(); + let connection_manager = + ConnectionManager::new_incoming_only(metadata_builder.to_metadata()); + let networking = 
Networking::with_connection_manager( + metadata_builder.to_metadata(), + net_opts, + connection_manager, + ); + + TestCoreEnvBuilder2::with_networking(networking, metadata_builder) + } +} +impl TestCoreEnvBuilder2 { + pub fn with_transport_connector(connector: Arc) -> TestCoreEnvBuilder2 { + let metadata_builder = MetadataBuilder::default(); + let net_opts = NetworkingOptions::default(); + let connection_manager = + ConnectionManager::new(metadata_builder.to_metadata(), connector, net_opts.clone()); + let networking = Networking::with_connection_manager( + metadata_builder.to_metadata(), + net_opts, + connection_manager, + ); + + TestCoreEnvBuilder2::with_networking(networking, metadata_builder) + } + + pub fn with_networking(networking: Networking, metadata_builder: MetadataBuilder) -> Self { + let my_node_id = GenerationalNodeId::new(1, 1); + let metadata_store_client = MetadataStoreClient::new_in_memory(); + let metadata = metadata_builder.to_metadata(); + let metadata_manager = + MetadataManager::new(metadata_builder, metadata_store_client.clone()); + let metadata_writer = metadata_manager.writer(); + let router_builder = MessageRouterBuilder::default(); + let nodes_config = NodesConfiguration::new(Version::MIN, "test-cluster".to_owned()); + let partition_table = PartitionTable::with_equally_sized_partitions(Version::MIN, 10); + let scheduling_plan = + SchedulingPlan::from(&partition_table, ReplicationStrategy::OnAllNodes); + TaskCenter::try_set_global_metadata(metadata.clone()); + + // Use memory-loglet as a default if in test-mode + #[cfg(any(test, feature = "test-util"))] + let provider_kind = ProviderKind::InMemory; + #[cfg(not(any(test, feature = "test-util")))] + let provider_kind = ProviderKind::Local; + + TestCoreEnvBuilder2 { + my_node_id, + metadata_manager, + metadata_writer, + metadata, + networking, + nodes_config, + router_builder, + partition_table, + scheduling_plan, + metadata_store_client, + provider_kind, + } + } + + pub fn 
set_nodes_config(mut self, nodes_config: NodesConfiguration) -> Self { + self.nodes_config = nodes_config; + self + } + + pub fn set_partition_table(mut self, partition_table: PartitionTable) -> Self { + self.partition_table = partition_table; + self + } + + pub fn set_scheduling_plan(mut self, scheduling_plan: SchedulingPlan) -> Self { + self.scheduling_plan = scheduling_plan; + self + } + + pub fn set_my_node_id(mut self, my_node_id: GenerationalNodeId) -> Self { + self.my_node_id = my_node_id; + self + } + + pub fn set_provider_kind(mut self, provider_kind: ProviderKind) -> Self { + self.provider_kind = provider_kind; + self + } + + pub fn add_mock_nodes_config(mut self) -> Self { + self.nodes_config = + create_mock_nodes_config(self.my_node_id.raw_id(), self.my_node_id.raw_generation()); + self + } + + pub fn add_message_handler(mut self, handler: H) -> Self + where + H: MessageHandler + Send + Sync + 'static, + { + self.router_builder.add_message_handler(handler); + self + } + + pub async fn build(mut self) -> TestCoreEnv2 { + self.metadata_manager + .register_in_message_router(&mut self.router_builder); + self.networking + .connection_manager() + .set_message_router(self.router_builder.build()); + + let metadata_manager_task = + spawn_metadata_manager(self.metadata_manager).expect("metadata manager should start"); + + self.metadata_store_client + .put( + NODES_CONFIG_KEY.clone(), + &self.nodes_config, + Precondition::None, + ) + .await + .expect("to store nodes config in metadata store"); + self.metadata_writer + .submit(Arc::new(self.nodes_config.clone())); + + let logs = bootstrap_logs_metadata( + self.provider_kind, + None, + self.partition_table.num_partitions(), + ); + self.metadata_store_client + .put(BIFROST_CONFIG_KEY.clone(), &logs, Precondition::None) + .await + .expect("to store bifrost config in metadata store"); + self.metadata_writer.submit(Arc::new(logs)); + + self.metadata_store_client + .put( + PARTITION_TABLE_KEY.clone(), + 
&self.partition_table, + Precondition::None, + ) + .await + .expect("to store partition table in metadata store"); + self.metadata_writer.submit(Arc::new(self.partition_table)); + + self.metadata_store_client + .put( + SCHEDULING_PLAN_KEY.clone(), + &self.scheduling_plan, + Precondition::None, + ) + .await + .expect("to store scheduling plan in metadata store"); + + let _ = self + .metadata + .wait_for_version( + MetadataKind::NodesConfiguration, + self.nodes_config.version(), + ) + .await + .unwrap(); + + self.metadata_writer.set_my_node_id(self.my_node_id); + + TestCoreEnv2 { + metadata: self.metadata, + metadata_manager_task, + metadata_writer: self.metadata_writer, + networking: self.networking, + metadata_store_client: self.metadata_store_client, + } + } +} + +// This might need to be moved to a better place in the future. +pub struct TestCoreEnv2 { + pub metadata: Metadata, + pub metadata_writer: MetadataWriter, + pub networking: Networking, + pub metadata_manager_task: TaskId, + pub metadata_store_client: MetadataStoreClient, +} + +impl TestCoreEnv2 { + pub async fn create_with_single_node(node_id: u32, generation: u32) -> Self { + TestCoreEnvBuilder2::with_incoming_only_connector() + .set_my_node_id(GenerationalNodeId::new(node_id, generation)) + .add_mock_nodes_config() + .build() + .await + } +} + +impl TestCoreEnv2 { + pub async fn accept_incoming_connection( + &self, + incoming: S, + ) -> Result + Unpin + Send + 'static, NetworkError> + where + S: Stream> + Unpin + Send + 'static, + { + self.networking + .connection_manager() + .accept_incoming_connection(incoming) + .await + } +} + +pub fn create_mock_nodes_config(node_id: u32, generation: u32) -> NodesConfiguration { + let mut nodes_config = NodesConfiguration::new(Version::MIN, "test-cluster".to_owned()); + let address = AdvertisedAddress::from_str("http://127.0.0.1:5122/").unwrap(); + let node_id = GenerationalNodeId::new(node_id, generation); + let roles = Role::Admin | Role::Worker; + let my_node 
= NodeConfig::new( + format!("MyNode-{}", node_id), + node_id, + address, + roles, + LogServerConfig::default(), + ); + nodes_config.upsert_node(my_node); + nodes_config +} + +/// No-op message handler which simply drops the received messages. Useful if you don't want to +/// react to network messages. +pub struct NoOpMessageHandler { + phantom_data: PhantomData, +} + +impl Default for NoOpMessageHandler { + fn default() -> Self { + NoOpMessageHandler { + phantom_data: PhantomData, + } + } +} + +impl MessageHandler for NoOpMessageHandler +where + M: WireDecode + Targeted + Send + Sync, +{ + type MessageType = M; + + async fn on_message(&self, _msg: Incoming) { + // no-op + } +}