Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/sql-server-util/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ rust_library(
"//src/ore:mz_ore",
"//src/proto:mz_proto",
"//src/repr:mz_repr",
"//src/ssh-util:mz_ssh_util",
] + all_crate_deps(normal = True),
)

Expand Down Expand Up @@ -69,6 +70,7 @@ rust_test(
"//src/ore:mz_ore",
"//src/proto:mz_proto",
"//src/repr:mz_repr",
"//src/ssh-util:mz_ssh_util",
] + all_crate_deps(
normal = True,
normal_dev = True,
Expand All @@ -82,6 +84,7 @@ rust_doc_test(
"//src/ore:mz_ore",
"//src/proto:mz_proto",
"//src/repr:mz_repr",
"//src/ssh-util:mz_ssh_util",
] + all_crate_deps(
normal = True,
normal_dev = True,
Expand Down
10 changes: 5 additions & 5 deletions src/sql-server-util/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ publish = false
[lints]
workspace = true

[[example]]
name = "cdc"

[dependencies]
anyhow = "1.0.66"
async-stream = "0.3.3"
Expand All @@ -22,6 +25,7 @@ itertools = "0.12.1"
mz-ore = { path = "../ore", features = ["async"] }
mz-proto = { path = "../proto" }
mz-repr = { path = "../repr" }
mz-ssh-util = { path = "../ssh-util" }
ordered-float = { version = "5.0.0", features = ["serde"] }
proptest = { version = "1.6.0", default-features = false, features = ["std"] }
proptest-derive = { version = "0.5.1", features = ["boxed_union"] }
Expand All @@ -30,11 +34,7 @@ serde = { version = "1.0.218", features = ["derive"] }
smallvec = { version = "1.14.0", features = ["union"] }
static_assertions = "1.1"
thiserror = "2.0.11"
tiberius = { version = "0.12", features = [
"chrono",
"sql-browser-tokio",
"tds73",
], default-features = false }
tiberius = { version = "0.12", features = [ "chrono", "sql-browser-tokio", "tds73"], default-features = false }
timely = "0.20.0"
tokio = { version = "1.44.1", features = ["net"] }
tokio-stream = "0.1.17"
Expand Down
7 changes: 5 additions & 2 deletions src/sql-server-util/examples/cdc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
//! 4. Watch CDC events come streaming in as you change the table!

use futures::StreamExt;
use mz_sql_server_util::Client;
use mz_ore::future::InTask;
use mz_sql_server_util::config::TunnelConfig;
use mz_sql_server_util::{Client, Config};
use tracing_subscriber::EnvFilter;

#[tokio::main]
Expand All @@ -53,6 +55,7 @@ async fn main() -> Result<(), anyhow::Error> {
config.authentication(tiberius::AuthMethod::sql_server("SA", "password123?"));
config.trust_cert();

let config = Config::new(config, TunnelConfig::Direct, InTask::No);
let (mut client, connection) = Client::connect(config).await?;
mz_ore::task::spawn(|| "sql-server connection", async move { connection.await });
tracing::info!("connection successful!");
Expand All @@ -74,7 +77,7 @@ async fn main() -> Result<(), anyhow::Error> {
for instance in capture_instances {
cdc_handle = cdc_handle.start_lsn(instance, lsn);
}
// Get a stream of changes from the table with the provided LSN.
// Get a stream of changes from the table.
let changes = cdc_handle.into_stream();
let mut changes = std::pin::pin!(changes);
while let Some(change) = changes.next().await {
Expand Down
122 changes: 122 additions & 0 deletions src/sql-server-util/src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use anyhow::Context;
use mz_ore::future::InTask;
use mz_repr::CatalogItemId;
use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
use mz_ssh_util::tunnel_manager::SshTunnelManager;
use proptest_derive::Arbitrary;
use serde::{Deserialize, Serialize};

/// Materialize specific configuration for SQL Server connections.
///
/// This wraps a [`tiberius::Config`] so we can configure a tunnel over SSH, AWS
/// PrivateLink, or various other techniques, via a [`TunnelConfig`].
///
/// Build one with [`Config::new`], or with [`Config::from_ado_string`] when a
/// direct (un-tunnelled) connection described by a connection string is enough.
#[derive(Clone, Debug)]
pub struct Config {
/// SQL Server specific configuration.
pub(crate) inner: tiberius::Config,
/// Details of how we'll connect to the upstream SQL Server instance.
pub(crate) tunnel: TunnelConfig,
/// If all of the I/O for this connection will be done in a separate task.
///
/// Note: This is used to prevent accidentally doing I/O in timely threads.
/// The value is handed to the SSH tunnel manager when a tunnel is opened.
pub(crate) in_task: InTask,
}

impl Config {
    /// Bundle a [`tiberius::Config`] together with the tunnel to connect
    /// through and whether the connection's I/O happens in a separate task.
    pub fn new(inner: tiberius::Config, tunnel: TunnelConfig, in_task: InTask) -> Self {
        Self {
            inner,
            tunnel,
            in_task,
        }
    }

    /// Parse an ADO (ActiveX Data Objects) connection string into a [`Config`].
    ///
    /// The resulting configuration always connects directly
    /// ([`TunnelConfig::Direct`]) and uses [`InTask::No`]. That makes it mostly
    /// suitable for test environments; regular/production callers should
    /// assemble the pieces themselves and call [`Config::new`].
    ///
    /// # Errors
    ///
    /// Returns an error, with the context `"tiberius config"`, when `tiberius`
    /// rejects the connection string.
    pub fn from_ado_string(s: &str) -> Result<Self, anyhow::Error> {
        tiberius::Config::from_ado_string(s)
            .context("tiberius config")
            .map(|parsed| Self::new(parsed, TunnelConfig::Direct, InTask::No))
    }
}

/// Configures an optional tunnel for use when connecting to a SQL Server database.
///
/// Each variant describes one way of reaching the upstream database; the
/// variant is inspected when a connection is opened to decide how the TCP
/// stream is established.
///
/// TODO(sql_server2): De-duplicate this with MySQL and Postgres sources.
#[derive(Debug, Clone)]
pub enum TunnelConfig {
/// No tunnelling.
Direct,
/// Establish a TCP connection to the database via an SSH tunnel.
Ssh {
/// Config for opening the SSH tunnel.
config: SshTunnelConfig,
/// Global manager of SSH tunnels.
manager: SshTunnelManager,
/// Timeout config for the SSH tunnel.
timeout: SshTimeoutConfig,
// TODO(sql_server1): Remove these fields by forking the `tiberius`
// crate and exposing the `get_host` and `get_port` methods.
//
// See: <https://github.com/MaterializeInc/tiberius/blob/406ad2780d206617bd41689b1b638bddf4538f89/src/client/config.rs#L174-L191>
/// Host handed to the SSH tunnel manager when the tunnel is opened.
///
/// NOTE(review): presumably the upstream SQL Server host that the tunnel
/// forwards to, duplicated here because `tiberius::Config` does not expose
/// it (see the TODO above) -- confirm.
host: String,
/// Port handed to the SSH tunnel manager alongside `host`.
port: u16,
},
/// Establish a TCP connection to the database via an AWS PrivateLink service.
AwsPrivatelink {
/// The ID of the AWS PrivateLink service.
connection_id: CatalogItemId,
},
}

/// Level of encryption to use with a SQL Server connection.
///
/// Mirror of [`tiberius::EncryptionLevel`] but we define our own so we can
/// implement traits like [`Serialize`].
///
/// `From` conversions exist in both directions between this type and
/// [`tiberius::EncryptionLevel`]; the variant docs below note the counterpart
/// each one maps to.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Arbitrary, Serialize, Deserialize)]
pub enum EncryptionLevel {
/// Do not use encryption at all.
///
/// Counterpart of [`tiberius::EncryptionLevel::NotSupported`].
None,
/// Only use encryption for the login procedure.
///
/// Counterpart of [`tiberius::EncryptionLevel::Off`].
Login,
/// Use encryption for everything, if possible.
///
/// Counterpart of [`tiberius::EncryptionLevel::On`].
Preferred,
/// Require encryption, failing if not possible.
///
/// Counterpart of [`tiberius::EncryptionLevel::Required`].
Required,
}

impl From<tiberius::EncryptionLevel> for EncryptionLevel {
    /// Convert the `tiberius` encryption level into our serializable mirror.
    ///
    /// The mapping is one-to-one: `NotSupported` becomes `None`, `Off` becomes
    /// `Login`, `On` becomes `Preferred`, and `Required` stays `Required`.
    fn from(level: tiberius::EncryptionLevel) -> Self {
        match level {
            tiberius::EncryptionLevel::NotSupported => Self::None,
            tiberius::EncryptionLevel::Off => Self::Login,
            tiberius::EncryptionLevel::On => Self::Preferred,
            tiberius::EncryptionLevel::Required => Self::Required,
        }
    }
}

impl From<EncryptionLevel> for tiberius::EncryptionLevel {
    /// Convert our serializable encryption level back into the `tiberius` one.
    ///
    /// The mapping is one-to-one: `None` becomes `NotSupported`, `Login`
    /// becomes `Off`, `Preferred` becomes `On`, and `Required` stays `Required`.
    fn from(level: EncryptionLevel) -> Self {
        match level {
            EncryptionLevel::None => Self::NotSupported,
            EncryptionLevel::Login => Self::Off,
            EncryptionLevel::Preferred => Self::On,
            EncryptionLevel::Required => Self::Required,
        }
    }
}
87 changes: 73 additions & 14 deletions src/sql-server-util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use std::any::Any;
use std::borrow::Cow;
use std::future::IntoFuture;
use std::pin::Pin;
Expand All @@ -25,11 +26,14 @@ use tokio::sync::oneshot;
use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};

pub mod cdc;
pub mod config;
pub mod desc;
pub mod inspect;

// Re-export tiberius' Config type since it's needed by our Client wrapper.
pub use tiberius::Config;
pub use config::Config;
pub use desc::{ProtoSqlServerColumnDesc, ProtoSqlServerTableDesc};

use crate::config::TunnelConfig;

/// Higher level wrapper around a [`tiberius::Client`] that models transaction
/// management like other database clients.
Expand All @@ -48,22 +52,73 @@ pub struct Client {
static_assertions::assert_not_impl_all!(Client: Clone);

impl Client {
pub async fn connect(config: tiberius::Config) -> Result<(Self, Connection), SqlServerError> {
let tcp = TcpStream::connect(config.get_addr()).await?;
/// Connect to the specified SQL Server instance, returning a [`Client`]
/// that can be used to query it and a [`Connection`] that must be polled
/// to send and receive results.
///
/// TODO(sql_server2): Maybe return a `ClientBuilder` here that implements
/// IntoFuture and does the default good thing of moving the `Connection`
/// into a tokio task? And a `.raw()` option that will instead return both
/// the Client and Connection for manual polling.
pub async fn connect(config: Config) -> Result<(Self, Connection), SqlServerError> {
// Setup our tunnelling and return any resources that need to be kept
// alive for the duration of the connection.
let (tcp, resources): (_, Option<Box<dyn Any + Send + Sync>>) = match &config.tunnel {
TunnelConfig::Direct => {
let tcp = TcpStream::connect(config.inner.get_addr())
.await
.context("direct")?;
(tcp, None)
}
TunnelConfig::Ssh {
config: ssh_config,
manager,
timeout,
host,
port,
} => {
// N.B. If this tunnel is dropped it will close so we need to
// keep it alive for the duration of the connection.
let tunnel = manager
.connect(ssh_config.clone(), host, *port, *timeout, config.in_task)
.await?;
let tcp = TcpStream::connect(tunnel.local_addr())
.await
.context("ssh tunnel")?;

(tcp, Some(Box::new(tunnel)))
}
TunnelConfig::AwsPrivatelink { connection_id: _ } => {
// TODO(sql_server1): Getting this right is tricky because
// there is some subtle logic with hostname validation.
return Err(SqlServerError::Generic(anyhow::anyhow!(
"TODO(sql_server1): Support PrivateLink connections"
)));
}
};

tcp.set_nodelay(true)?;
Self::connect_raw(config, tcp).await
Self::connect_raw(config, tcp, resources).await
}

pub async fn connect_raw(
config: tiberius::Config,
config: Config,
tcp: tokio::net::TcpStream,
resources: Option<Box<dyn Any + Send + Sync>>,
) -> Result<(Self, Connection), SqlServerError> {
let client = tiberius::Client::connect(config, tcp.compat_write())
.await
.context("connecting to SQL Server")?;
let client = tiberius::Client::connect(config.inner, tcp.compat_write()).await?;
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

Ok((Client { tx }, Connection { rx, client }))
// TODO(sql_server2): Add a lot more logging here like the Postgres and MySQL clients have.

Ok((
Client { tx },
Connection {
rx,
client,
_resources: resources,
},
))
}

/// Executes SQL statements in SQL Server, returning the number of rows affected.
Expand Down Expand Up @@ -93,7 +148,7 @@ impl Client {
.send(Request { tx, kind })
.context("sending request")?;

let response = rx.await.context("channel")?.context("execute")?;
let response = rx.await.context("channel")??;
match response {
Response::Execute { rows_affected } => Ok(rows_affected),
other @ Response::Rows(_) | other @ Response::RowStream { .. } => {
Expand Down Expand Up @@ -131,7 +186,7 @@ impl Client {
.send(Request { tx, kind })
.context("sending request")?;

let response = rx.await.context("channel")?.context("query")?;
let response = rx.await.context("channel")??;
match response {
Response::Rows(rows) => Ok(rows),
other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
Expand Down Expand Up @@ -203,7 +258,7 @@ impl Client {
.send(Request { tx, kind })
.context("sending request")?;

let response = rx.await.context("channel")?.context("simple_query")?;
let response = rx.await.context("channel")??;
match response {
Response::Rows(rows) => Ok(rows),
other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
Expand Down Expand Up @@ -443,14 +498,18 @@ enum RequestKind {
}

pub struct Connection {
/// Other end of the channel that [`Client`] holds.
rx: UnboundedReceiver<Request>,
/// Actual client that we use to send requests.
client: tiberius::Client<Compat<TcpStream>>,
/// Resources (e.g. SSH tunnel) that need to be held open for the life of this connection.
_resources: Option<Box<dyn Any + Send + Sync>>,
}

impl Connection {
async fn run(mut self) {
while let Some(Request { tx, kind }) = self.rx.recv().await {
tracing::debug!(?kind, "processing SQL Server query");
tracing::trace!(?kind, "processing SQL Server query");
let result = Connection::handle_request(&mut self.client, kind).await;
let (response, maybe_extra_work) = match result {
Ok((response, work)) => (Ok(response), work),
Expand Down
2 changes: 1 addition & 1 deletion src/testdrive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ serde_json = { version = "1.0.125", features = ["raw_value"] }
similar = "2.7.0"
tempfile = "3.19.0"
termcolor = "1.4.1"
tiberius = { version = "0.12", default-features = false }
tiberius = { version = "0.12", features = ["sql-browser-tokio", "tds73"], default-features = false }
time = "0.3.17"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
Expand Down