From b541f0507aa18af146fbdde1c04fda9c8abae584 Mon Sep 17 00:00:00 2001 From: Parker Timmerman Date: Thu, 3 Apr 2025 14:01:58 -0400 Subject: [PATCH] start, update the mz_sql_server_util::Client to support tunnelling * introduces a new Config type that wraps tiberius::Config * adds a 'resources' field to mz_sql_server_util::Client that can be used to keep tunnels open * updates the 'cdc' example to use the new Client config --- Cargo.lock | 1 + src/sql-server-util/BUILD.bazel | 3 + src/sql-server-util/Cargo.toml | 10 +-- src/sql-server-util/examples/cdc.rs | 7 +- src/sql-server-util/src/config.rs | 122 ++++++++++++++++++++++++++++ src/sql-server-util/src/lib.rs | 87 ++++++++++++++++---- src/testdrive/Cargo.toml | 2 +- 7 files changed, 210 insertions(+), 22 deletions(-) create mode 100644 src/sql-server-util/src/config.rs diff --git a/Cargo.lock b/Cargo.lock index 85d1bfd62065c..61b7bcc5dad78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7074,6 +7074,7 @@ dependencies = [ "mz-ore", "mz-proto", "mz-repr", + "mz-ssh-util", "ordered-float 5.0.0", "proptest", "proptest-derive", diff --git a/src/sql-server-util/BUILD.bazel b/src/sql-server-util/BUILD.bazel index a2500bc9fa1dd..17784e85baaab 100644 --- a/src/sql-server-util/BUILD.bazel +++ b/src/sql-server-util/BUILD.bazel @@ -35,6 +35,7 @@ rust_library( "//src/ore:mz_ore", "//src/proto:mz_proto", "//src/repr:mz_repr", + "//src/ssh-util:mz_ssh_util", ] + all_crate_deps(normal = True), ) @@ -69,6 +70,7 @@ rust_test( "//src/ore:mz_ore", "//src/proto:mz_proto", "//src/repr:mz_repr", + "//src/ssh-util:mz_ssh_util", ] + all_crate_deps( normal = True, normal_dev = True, @@ -82,6 +84,7 @@ rust_doc_test( "//src/ore:mz_ore", "//src/proto:mz_proto", "//src/repr:mz_repr", + "//src/ssh-util:mz_ssh_util", ] + all_crate_deps( normal = True, normal_dev = True, diff --git a/src/sql-server-util/Cargo.toml b/src/sql-server-util/Cargo.toml index 03616dc0cc7f9..dcbbf4091905f 100644 --- a/src/sql-server-util/Cargo.toml +++ b/src/sql-server-util/Cargo.toml @@ -9,6 +9,9 @@ publish = false [lints] workspace = true +[[example]] +name = "cdc" + [dependencies] anyhow = "1.0.66" async-stream = "0.3.3" @@ -22,6 +25,7 @@ itertools = "0.12.1" mz-ore = { path = "../ore", features = ["async"] } mz-proto = { path = "../proto" } mz-repr = { path = "../repr" } +mz-ssh-util = { path = "../ssh-util" } ordered-float = { version = "5.0.0", features = ["serde"] } proptest = { version = "1.6.0", default-features = false, features = ["std"] } proptest-derive = { version = "0.5.1", features = ["boxed_union"] } @@ -30,11 +34,7 @@ serde = { version = "1.0.218", features = ["derive"] } smallvec = { version = "1.14.0", features = ["union"] } static_assertions = "1.1" thiserror = "2.0.11" -tiberius = { version = "0.12", features = [ - "chrono", - "sql-browser-tokio", - "tds73", -], default-features = false } +tiberius = { version = "0.12", features = [ "chrono", "sql-browser-tokio", "tds73"], default-features = false } timely = "0.20.0" tokio = { version = "1.44.1", features = ["net"] } tokio-stream = "0.1.17" diff --git a/src/sql-server-util/examples/cdc.rs b/src/sql-server-util/examples/cdc.rs index 5ae2077721ce2..02291c681b660 100644 --- a/src/sql-server-util/examples/cdc.rs +++ b/src/sql-server-util/examples/cdc.rs @@ -36,7 +36,9 @@ //! 4. Watch CDC events come streaming in as you change the table! use futures::StreamExt; -use mz_sql_server_util::Client; +use mz_ore::future::InTask; +use mz_sql_server_util::config::TunnelConfig; +use mz_sql_server_util::{Client, Config}; use tracing_subscriber::EnvFilter; #[tokio::main] @@ -53,6 +55,7 @@ async fn main() -> Result<(), anyhow::Error> { config.authentication(tiberius::AuthMethod::sql_server("SA", "password123?")); config.trust_cert(); + let config = Config::new(config, TunnelConfig::Direct, InTask::No); let (mut client, connection) = Client::connect(config).await?; mz_ore::task::spawn(|| "sql-server connection", async move { connection.await }); tracing::info!("connection successful!"); @@ -74,7 +77,7 @@ async fn main() -> Result<(), anyhow::Error> { for instance in capture_instances { cdc_handle = cdc_handle.start_lsn(instance, lsn); } - // Get a stream of changes from the table with the provided LSN. + // Get a stream of changes from the table. let changes = cdc_handle.into_stream(); let mut changes = std::pin::pin!(changes); while let Some(change) = changes.next().await { diff --git a/src/sql-server-util/src/config.rs b/src/sql-server-util/src/config.rs new file mode 100644 index 0000000000000..18de7a6167711 --- /dev/null +++ b/src/sql-server-util/src/config.rs @@ -0,0 +1,122 @@ +// Copyright Materialize, Inc. and contributors. All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use anyhow::Context; +use mz_ore::future::InTask; +use mz_repr::CatalogItemId; +use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig}; +use mz_ssh_util::tunnel_manager::SshTunnelManager; +use proptest_derive::Arbitrary; +use serde::{Deserialize, Serialize}; + +/// Materialize specific configuration for SQL Server connections. +/// +/// This wraps a [`tiberius::Config`] so we can configure a tunnel over SSH, AWS +/// PrivateLink, or various other techniques, via a [`TunnelConfig`] +#[derive(Clone, Debug)] +pub struct Config { + /// SQL Server specific configuration. + pub(crate) inner: tiberius::Config, + /// Details of how we'll connect to the upstream SQL Server instance. + pub(crate) tunnel: TunnelConfig, + /// If all of the I/O for this connection will be done in a separate task. + /// + /// Note: This is used to prevent accidentally doing I/O in timely threads. + pub(crate) in_task: InTask, +} + +impl Config { + pub fn new(inner: tiberius::Config, tunnel: TunnelConfig, in_task: InTask) -> Self { + Config { + inner, + tunnel, + in_task, + } + } + + /// Create a new [`Config`] from an ActiveX Data Object. + /// + /// Generally this is only used in test environments, see [`Config::new`] + /// for regular/production use cases. + pub fn from_ado_string(s: &str) -> Result { + let inner = tiberius::Config::from_ado_string(s).context("tiberius config")?; + Ok(Config { + inner, + tunnel: TunnelConfig::Direct, + in_task: InTask::No, + }) + } +} + +/// Configures an optional tunnel for use when connecting to a SQL Server database. +/// +/// TODO(sql_server2): De-duplicate this with MySQL and Postgres sources. +#[derive(Debug, Clone)] +pub enum TunnelConfig { + /// No tunnelling. + Direct, + /// Establish a TCP connection to the database via an SSH tunnel. + Ssh { + /// Config for opening the SSH tunnel. + config: SshTunnelConfig, + /// Global manager of SSH tunnels. + manager: SshTunnelManager, + /// Timeout config for the SSH tunnel. + timeout: SshTimeoutConfig, + // TODO(sql_server1): Remove these fields by forking the `tiberius` + // crate and expose the `get_host` and `get_port` methods. + // + // See: + host: String, + port: u16, + }, + /// Establish a TCP connection to the database via an AWS PrivateLink service. + AwsPrivatelink { + /// The ID of the AWS PrivateLink service. + connection_id: CatalogItemId, + }, +} + +/// Level of encryption to use with a SQL Server connection. +/// +/// Mirror of [`tiberius::EncryptionLevel`] but we define our own so we can +/// implement traits like [`Serialize`]. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Arbitrary, Serialize, Deserialize)] +pub enum EncryptionLevel { + /// Do not use encryption at all. + None, + /// Only use encryption for the login procedure. + Login, + /// Use encryption for everything, if possible. + Preferred, + /// Require encryption, failing if not possible. + Required, +} + +impl From for EncryptionLevel { + fn from(value: tiberius::EncryptionLevel) -> Self { + match value { + tiberius::EncryptionLevel::NotSupported => EncryptionLevel::None, + tiberius::EncryptionLevel::Off => EncryptionLevel::Login, + tiberius::EncryptionLevel::On => EncryptionLevel::Preferred, + tiberius::EncryptionLevel::Required => EncryptionLevel::Required, + } + } +} + +impl From for tiberius::EncryptionLevel { + fn from(value: EncryptionLevel) -> Self { + match value { + EncryptionLevel::None => tiberius::EncryptionLevel::NotSupported, + EncryptionLevel::Login => tiberius::EncryptionLevel::Off, + EncryptionLevel::Preferred => tiberius::EncryptionLevel::On, + EncryptionLevel::Required => tiberius::EncryptionLevel::Required, + } + } +} diff --git a/src/sql-server-util/src/lib.rs b/src/sql-server-util/src/lib.rs index 870a96c6b5ab6..e44c0882a0a79 100644 --- a/src/sql-server-util/src/lib.rs +++ b/src/sql-server-util/src/lib.rs @@ -7,6 +7,7 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +use std::any::Any; use std::borrow::Cow; use std::future::IntoFuture; use std::pin::Pin; @@ -25,11 +26,14 @@ use tokio::sync::oneshot; use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; pub mod cdc; +pub mod config; pub mod desc; pub mod inspect; -// Re-export tiberius' Config type since it's needed by our Client wrapper. -pub use tiberius::Config; +pub use config::Config; +pub use desc::{ProtoSqlServerColumnDesc, ProtoSqlServerTableDesc}; + +use crate::config::TunnelConfig; /// Higher level wrapper around a [`tiberius::Client`] that models transaction /// management like other database clients. @@ -48,22 +52,73 @@ pub struct Client { static_assertions::assert_not_impl_all!(Client: Clone); impl Client { - pub async fn connect(config: tiberius::Config) -> Result<(Self, Connection), SqlServerError> { - let tcp = TcpStream::connect(config.get_addr()).await?; + /// Connect to the specified SQL Server instance, returning a [`Client`] + /// that can be used to query it and a [`Connection`] that must be polled + /// to send and receive results. + /// + /// TODO(sql_server2): Maybe return a `ClientBuilder` here that implements + /// IntoFuture and does the default good thing of moving the `Connection` + /// into a tokio task? And a `.raw()` option that will instead return both + /// the Client and Connection for manual polling. + pub async fn connect(config: Config) -> Result<(Self, Connection), SqlServerError> { + // Setup our tunnelling and return any resources that need to be kept + // alive for the duration of the connection. + let (tcp, resources): (_, Option>) = match &config.tunnel { + TunnelConfig::Direct => { + let tcp = TcpStream::connect(config.inner.get_addr()) + .await + .context("direct")?; + (tcp, None) + } + TunnelConfig::Ssh { + config: ssh_config, + manager, + timeout, + host, + port, + } => { + // N.B. If this tunnel is dropped it will close so we need to + // keep it alive for the duration of the connection. + let tunnel = manager + .connect(ssh_config.clone(), host, *port, *timeout, config.in_task) + .await?; + let tcp = TcpStream::connect(tunnel.local_addr()) + .await + .context("ssh tunnel")?; + + (tcp, Some(Box::new(tunnel))) + } + TunnelConfig::AwsPrivatelink { connection_id: _ } => { + // TODO(sql_server1): Getting this right is tricky because + // there is some subtle logic with hostname validation. + return Err(SqlServerError::Generic(anyhow::anyhow!( + "TODO(sql_server1): Support PrivateLink connections" + ))); + } + }; + tcp.set_nodelay(true)?; - Self::connect_raw(config, tcp).await + Self::connect_raw(config, tcp, resources).await } pub async fn connect_raw( - config: tiberius::Config, + config: Config, tcp: tokio::net::TcpStream, + resources: Option>, ) -> Result<(Self, Connection), SqlServerError> { - let client = tiberius::Client::connect(config, tcp.compat_write()) - .await - .context("connecting to SQL Server")?; + let client = tiberius::Client::connect(config.inner, tcp.compat_write()).await?; let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - Ok((Client { tx }, Connection { rx, client })) + // TODO(sql_server2): Add a lot more logging here like the Postgres and MySQL clients have. + + Ok(( + Client { tx }, + Connection { + rx, + client, + _resources: resources, + }, + )) } /// Executes SQL statements in SQL Server, returning the number of rows effected. @@ -93,7 +148,7 @@ impl Client { .send(Request { tx, kind }) .context("sending request")?; - let response = rx.await.context("channel")?.context("execute")?; + let response = rx.await.context("channel")??; match response { Response::Execute { rows_affected } => Ok(rows_affected), other @ Response::Rows(_) | other @ Response::RowStream { .. } => { @@ -131,7 +186,7 @@ impl Client { .send(Request { tx, kind }) .context("sending request")?; - let response = rx.await.context("channel")?.context("query")?; + let response = rx.await.context("channel")??; match response { Response::Rows(rows) => Ok(rows), other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err( @@ -203,7 +258,7 @@ impl Client { .send(Request { tx, kind }) .context("sending request")?; - let response = rx.await.context("channel")?.context("simple_query")?; + let response = rx.await.context("channel")??; match response { Response::Rows(rows) => Ok(rows), other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err( @@ -443,14 +498,18 @@ enum RequestKind { } pub struct Connection { + /// Other end of the channel that [`Client`] holds. rx: UnboundedReceiver, + /// Actual client that we use to send requests. client: tiberius::Client>, + /// Resources (e.g. SSH tunnel) that need to be held open for the life of this connection. + _resources: Option>, } impl Connection { async fn run(mut self) { while let Some(Request { tx, kind }) = self.rx.recv().await { - tracing::debug!(?kind, "processing SQL Server query"); + tracing::trace!(?kind, "processing SQL Server query"); let result = Connection::handle_request(&mut self.client, kind).await; let (response, maybe_extra_work) = match result { Ok((response, work)) => (Ok(response), work), diff --git a/src/testdrive/Cargo.toml b/src/testdrive/Cargo.toml index 237049d13ce0b..dd946fd161ce8 100644 --- a/src/testdrive/Cargo.toml +++ b/src/testdrive/Cargo.toml @@ -68,7 +68,7 @@ serde_json = { version = "1.0.125", features = ["raw_value"] } similar = "2.7.0" tempfile = "3.19.0" termcolor = "1.4.1" -tiberius = { version = "0.12", default-features = false } +tiberius = { version = "0.12", features = ["sql-browser-tokio", "tds73"], default-features = false } time = "0.3.17" tracing = "0.1.37" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }