Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added workspace oauth source for UC #3152

Merged
merged 2 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions crates/catalog-unity/src/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::client::token::{TemporaryToken, TokenCache};
// https://learn.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/authentication

const DATABRICKS_RESOURCE_SCOPE: &str = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d";
const DATABRICKS_WORKSPACE_SCOPE: &str = "all-apis";
const CONTENT_TYPE_JSON: &str = "application/json";
const MSI_SECRET_ENV_KEY: &str = "IDENTITY_HEADER";
const MSI_API_VERSION: &str = "2019-08-01";
Expand Down Expand Up @@ -56,6 +57,58 @@ struct TokenResponse {
expires_in: u64,
}

/// The same thing as the azure oauth provider, but uses the databricks api to
/// get tokens directly from the workspace.
#[derive(Debug, Clone)]
pub struct WorkspaceOAuthProvider {
token_url: String,
client_id: String,
client_secret: String,
}

impl WorkspaceOAuthProvider {
pub fn new(
client_id: impl Into<String>,
client_secret: impl Into<String>,
workspace_host: impl Into<String>,
) -> Self {
Self {
token_url: format!("{}/oidc/v1/token", workspace_host.into()),
client_id: client_id.into(),
client_secret: client_secret.into(),
}
}
}

#[async_trait::async_trait]
impl TokenCredential for WorkspaceOAuthProvider {
async fn fetch_token(
&self,
client: &ClientWithMiddleware,
) -> Result<TemporaryToken<String>, UnityCatalogError> {
let response: TokenResponse = client
.request(Method::POST, &self.token_url)
.header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
.form(&[
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.as_str()),
("scope", DATABRICKS_WORKSPACE_SCOPE),
("grant_type", "client_credentials"),
])
.send()
.await
.map_err(UnityCatalogError::from)?
.json()
.await
.map_err(UnityCatalogError::from)?;

Ok(TemporaryToken {
token: response.access_token,
expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
})
}
}

/// Encapsulates the logic to perform an OAuth token challenge
#[derive(Debug, Clone)]
pub struct ClientSecretOAuthProvider {
Expand Down
25 changes: 21 additions & 4 deletions crates/catalog-unity/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ compile_error!(
use reqwest::header::{HeaderValue, InvalidHeaderValue, AUTHORIZATION};
use std::str::FromStr;

use crate::credential::{AzureCliCredential, ClientSecretOAuthProvider, CredentialProvider};
use crate::credential::{
AzureCliCredential, ClientSecretOAuthProvider, CredentialProvider, WorkspaceOAuthProvider,
};
use crate::models::{
ErrorResponse, GetSchemaResponse, GetTableResponse, ListCatalogsResponse, ListSchemasResponse,
ListTableSummariesResponse, TableTempCredentialsResponse, TemporaryTableCredentialsRequest,
Expand Down Expand Up @@ -240,9 +242,10 @@ impl FromStr for UnityCatalogConfigKey {
"use_azure_cli" | "unity_use_azure_cli" | "databricks_use_azure_cli" => {
Ok(UnityCatalogConfigKey::UseAzureCli)
}
"workspace_url" | "unity_workspace_url" | "databricks_workspace_url" => {
Ok(UnityCatalogConfigKey::WorkspaceUrl)
}
"workspace_url"
| "unity_workspace_url"
| "databricks_workspace_url"
| "databricks_host" => Ok(UnityCatalogConfigKey::WorkspaceUrl),
_ => Err(DataCatalogError::UnknownConfigKey {
catalog: "unity",
key: s.to_string(),
Expand Down Expand Up @@ -371,6 +374,7 @@ impl UnityCatalogBuilder {
if let Ok(config_key) =
UnityCatalogConfigKey::from_str(&key.to_ascii_lowercase())
{
tracing::debug!("Trying: {} with {}", key, value);
builder = builder.try_with_option(config_key, value).unwrap();
}
}
Expand Down Expand Up @@ -432,6 +436,19 @@ impl UnityCatalogBuilder {
return Some(CredentialProvider::BearerToken(token.clone()));
}

if let (Some(client_id), Some(client_secret), Some(workspace_host)) =
(&self.client_id, &self.client_secret, &self.workspace_url)
{
return Some(CredentialProvider::TokenCredential(
Default::default(),
Box::new(WorkspaceOAuthProvider::new(
client_id,
client_secret,
workspace_host,
)),
));
}

if let (Some(client_id), Some(client_secret), Some(authority_id)) = (
self.client_id.as_ref(),
self.client_secret.as_ref(),
Expand Down
5 changes: 2 additions & 3 deletions crates/core/src/delta_datafusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ use datafusion::execution::context::{SessionConfig, SessionContext, SessionState
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::execution::FunctionRegistry;
use datafusion::physical_optimizer::pruning::PruningPredicate;
use datafusion_common::project_schema;
use datafusion_common::scalar::ScalarValue;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
use datafusion_common::{
Expand All @@ -59,7 +58,7 @@ use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::logical_plan::CreateExternalTable;
use datafusion_expr::utils::conjunction;
use datafusion_expr::{col, Expr, Extension, LogicalPlan, TableProviderFilterPushDown, Volatility};
use datafusion_physical_expr::{create_physical_expr, create_physical_exprs, PhysicalExpr};
use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
use datafusion_physical_plan::filter::FilterExec;
use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
Expand Down Expand Up @@ -886,7 +885,7 @@ impl TableProvider for LazyTableProvider {

async fn scan(
&self,
session: &dyn Session,
_session: &dyn Session,
projection: Option<&Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
Expand Down
1 change: 0 additions & 1 deletion crates/core/src/operations/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,6 @@ mod tests {
use crate::TableProperty;
use arrow_array::{Int32Array, StringArray, TimestampMicrosecondArray};
use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
use datafusion::prelude::*;
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
use itertools::Itertools;
use serde_json::{json, Value};
Expand Down
11 changes: 3 additions & 8 deletions crates/core/tests/integration_datafusion.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#![cfg(feature = "datafusion")]
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;
use std::error::Error;
use std::path::PathBuf;
use std::sync::Arc;
Expand All @@ -26,7 +26,6 @@ use datafusion_proto::bytes::{
};
use deltalake_core::delta_datafusion::DeltaScan;
use deltalake_core::kernel::{DataType, MapType, PrimitiveType, StructField, StructType};
use deltalake_core::logstore::logstore_for;
use deltalake_core::operations::create::CreateBuilder;
use deltalake_core::protocol::SaveMode;
use deltalake_core::writer::{DeltaWriter, RecordBatchWriter};
Expand All @@ -41,14 +40,10 @@ use serial_test::serial;
use url::Url;

mod local {
use datafusion::{
common::stats::Precision, datasource::provider_as_source, prelude::DataFrame,
};
use datafusion::{common::stats::Precision, datasource::provider_as_source};
use datafusion_expr::LogicalPlanBuilder;
use deltalake_core::{
delta_datafusion::{DeltaLogicalCodec, DeltaScanConfigBuilder, DeltaTableProvider},
logstore::default_logstore,
writer::JsonWriter,
delta_datafusion::DeltaLogicalCodec, logstore::default_logstore, writer::JsonWriter,
};
use itertools::Itertools;
use object_store::local::LocalFileSystem;
Expand Down
Loading