diff --git a/docs/base.md b/docs/base.md index 4f5b212d..e395c448 100644 --- a/docs/base.md +++ b/docs/base.md @@ -119,10 +119,11 @@ const producer = new Producer({ serializers: stringSerializers, sasl: { mechanism: 'PLAIN', // Also SCRAM-SHA-256, SCRAM-SHA-512 and OAUTHBEARER are supported - // username, password or token can also be (async) functions returning a string + // username, password, token and oauthBearerExtensions can also be (async) functions returning a value username: 'username', // This is used from PLAIN, SCRAM-SHA-256 and SCRAM-SHA-512 password: 'password', // This is used from PLAIN, SCRAM-SHA-256 and SCRAM-SHA-512 token: 'token', // This is used from OAUTHBEARER + oauthBearerExtensions: {}, // This is used from OAUTHBEARER to add extension according to RFC 7628 // This is needed if your Kafka server returns a exitCode 0 when invalid credentials are sent and only stores // authentication information in auth bytes. // diff --git a/src/clients/base/options.ts b/src/clients/base/options.ts index c6665be9..b9ae484f 100644 --- a/src/clients/base/options.ts +++ b/src/clients/base/options.ts @@ -46,6 +46,9 @@ export const baseOptionsSchema = { username: { oneOf: [{ type: 'string' }, { function: true }] }, password: { oneOf: [{ type: 'string' }, { function: true }] }, token: { oneOf: [{ type: 'string' }, { function: true }] }, + oauthBearerExtensions: { + oneOf: [{ type: 'object', patternProperties: { '.+': { type: 'string' } } }, { function: true }] + }, authBytesValidator: { function: true } }, required: ['mechanism'], diff --git a/src/network/connection.ts b/src/network/connection.ts index 06e062fb..89e3834b 100644 --- a/src/network/connection.ts +++ b/src/network/connection.ts @@ -32,7 +32,7 @@ import { defaultCrypto, type ScramAlgorithm } from '../protocol/sasl/scram-sha.t import { Writer } from '../protocol/writer.ts' import { loggers } from '../utils.ts' -export type SASLCredentialProvider = () => string | Promise +export type SASLCredentialProvider = () => T | Promise export interface Broker { host: string port: number @@ -44,6 +44,7 @@ export interface SASLOptions { password?: string | SASLCredentialProvider token?: string | SASLCredentialProvider authBytesValidator?: (authBytes: Buffer, callback: CallbackWithPromise) => void + oauthBearerExtensions?: Record | SASLCredentialProvider> } export interface ConnectionOptions { @@ -384,7 +385,7 @@ export class Connection extends EventEmitter { this.#status = ConnectionStatuses.AUTHENTICATING } - const { mechanism, username, password, token } = this.#options.sasl! + const { mechanism, username, password, token, oauthBearerExtensions } = this.#options.sasl! if (!allowedSASLMechanisms.includes(mechanism)) { this.#onConnectionError( @@ -414,7 +415,7 @@ export class Connection extends EventEmitter { if (mechanism === SASLMechanisms.PLAIN) { saslPlain.authenticate(saslAuthenticateV2.api, this, username!, password!, callback) } else if (mechanism === SASLMechanisms.OAUTHBEARER) { - saslOAuthBearer.authenticate(saslAuthenticateV2.api, this, token!, callback) + saslOAuthBearer.authenticate(saslAuthenticateV2.api, this, token!, oauthBearerExtensions!, callback) } else { saslScramSha.authenticate( saslAuthenticateV2.api, diff --git a/src/protocol/sasl/credential-provider.ts b/src/protocol/sasl/credential-provider.ts index f8505783..7555983b 100644 --- a/src/protocol/sasl/credential-provider.ts +++ b/src/protocol/sasl/credential-provider.ts @@ -2,58 +2,54 @@ import { type Callback } from '../../apis/index.ts' import { AuthenticationError } from '../../errors.ts' import { type SASLCredentialProvider } from '../../network/connection.ts' -export function getCredential ( +export function getCredential ( label: string, - credentialOrProvider: string | SASLCredentialProvider, - callback: Callback + credentialOrProvider: T | SASLCredentialProvider, + callback: Callback ): void { - if (typeof credentialOrProvider === 'string') { - callback(null, credentialOrProvider) + if (typeof credentialOrProvider === 'undefined') { + callback(new AuthenticationError(`The ${label} should be a value or a function.`), undefined as unknown as T) return } else if (typeof credentialOrProvider !== 'function') { - callback(new AuthenticationError(`The ${label} should be a string or a function.`), undefined as unknown as string) + callback(null, credentialOrProvider) return } try { - const credential = credentialOrProvider() + const credential = (credentialOrProvider as SASLCredentialProvider)() - if (typeof credential === 'string') { - callback(null, credential) - return - } else if (typeof (credential as Promise)?.then !== 'function') { + if (credential == null) { callback( - new AuthenticationError(`The ${label} provider should return a string or a promise that resolves to a string.`), - undefined as unknown as string + new AuthenticationError(`The ${label} provider should return a string or a promise that resolves to a value.`), + undefined as unknown as T ) - + return + } else if (typeof (credential as Promise)?.then !== 'function') { + callback(null, credential as T) return } - credential - .then(token => { - if (typeof token !== 'string') { + ;(credential as Promise) + .then((result: T) => { + if (result == null) { process.nextTick( callback, - new AuthenticationError(`The ${label} provider should resolve to a string.`), + new AuthenticationError(`The ${label} provider should resolve to a value.`), undefined as unknown as string ) return } - process.nextTick(callback, null, token) + process.nextTick(callback, null, result) }) - .catch(error => { - process.nextTick( - callback, - new AuthenticationError(`The ${label} provider threw an error.`, { cause: error as Error }) - ) + .catch((error: Error) => { + process.nextTick(callback, new AuthenticationError(`The ${label} provider threw an error.`, { cause: error })) }) } catch (error) { callback( new AuthenticationError(`The ${label} provider threw an error.`, { cause: error as Error }), - undefined as unknown as string + undefined as unknown as T ) } } diff --git a/src/protocol/sasl/oauth-bearer.ts b/src/protocol/sasl/oauth-bearer.ts index 02e252fb..3139956c 100644 --- a/src/protocol/sasl/oauth-bearer.ts +++ b/src/protocol/sasl/oauth-bearer.ts @@ -33,17 +33,20 @@ export function authenticate ( authenticateAPI: SASLAuthenticationAPI, connection: Connection, tokenOrProvider: string | SASLCredentialProvider, + extensions: Record | SASLCredentialProvider>, callback: CallbackWithPromise ): void export function authenticate ( authenticateAPI: SASLAuthenticationAPI, connection: Connection, - tokenOrProvider: string | SASLCredentialProvider + tokenOrProvider: string | SASLCredentialProvider, + extensions: Record | SASLCredentialProvider> ): Promise export function authenticate ( authenticateAPI: SASLAuthenticationAPI, connection: Connection, tokenOrProvider: string | SASLCredentialProvider, + extensionsOrProvider: Record | SASLCredentialProvider>, callback?: CallbackWithPromise ): void | Promise { if (!callback) { @@ -55,7 +58,20 @@ export function authenticate ( return callback!(error, undefined as unknown as SaslAuthenticateResponse) } - authenticateAPI(connection, Buffer.from(`n,,\x01auth=Bearer ${token}\x01\x01`), callback!) + getCredential('SASL/OAUTHBEARER extensions', extensionsOrProvider ?? {}, (error, extensionsMap) => { + if (error) { + return callback!(error, undefined as unknown as SaslAuthenticateResponse) + } + + let extensions = '' + if (extensionsMap) { + for (const [key, value] of Object.entries(extensionsMap)) { + extensions += `\x01${key}=${value}` + } + } + + authenticateAPI(connection, Buffer.from(`n,,\x01auth=Bearer ${token}${extensions}\x01\x01`), callback!) + }) }) return callback[kCallbackPromise] diff --git a/test/clients/base/sasl-oauthbearer.test.ts b/test/clients/base/sasl-oauthbearer.test.ts index 7514261a..9016e1a6 100644 --- a/test/clients/base/sasl-oauthbearer.test.ts +++ b/test/clients/base/sasl-oauthbearer.test.ts @@ -53,6 +53,31 @@ test('should connect to SASL protected broker using SASL/OAUTHBEARER', async t = deepStrictEqual(metadata.brokers.get(1), saslBroker) }) +test('should connect to SASL protected broker using SASL/OAUTHBEARER and custom extensions', async t => { + const signSync = createSigner({ + algorithm: 'none', + iss: 'kafka', + aud: ['users'], + sub: 'admin', + expiresIn: '2h' + }) + const token = signSync({ scope: 'test' }) + + const base = new Base({ + clientId: 'clientId', + bootstrapBrokers: kafkaSaslBootstrapServers, + strict: true, + retries: 0, + sasl: { mechanism: SASLMechanisms.OAUTHBEARER, token, oauthBearerExtensions: { aaa: 'bbb', ccc: 'ddd' } } + }) + + t.after(() => base.close()) + + const metadata = await base.metadata({ topics: [] }) + + deepStrictEqual(metadata.brokers.get(1), saslBroker) +}) + test('should handle authentication errors', async t => { const base = new Base({ clientId: 'clientId', @@ -178,7 +203,8 @@ test('should handle async credential provider errors', async t => { retries: 0, sasl: { mechanism: 'OAUTHBEARER', - async token () { + token: 'token', + async oauthBearerExtensions () { throw new Error('Kaboom!') } } @@ -198,7 +224,7 @@ test('should handle async credential provider errors', async t => { const authenticationError = networkError.cause deepStrictEqual(authenticationError instanceof AuthenticationError, true) - deepStrictEqual(authenticationError.message, 'The SASL/OAUTHBEARER token provider threw an error.') + deepStrictEqual(authenticationError.message, 'The SASL/OAUTHBEARER extensions provider threw an error.') deepStrictEqual(authenticationError.cause.message, 'Kaboom!') } }) diff --git a/test/clients/consumer/consumer.test.ts b/test/clients/consumer/consumer.test.ts index d71bca37..fc77f981 100644 --- a/test/clients/consumer/consumer.test.ts +++ b/test/clients/consumer/consumer.test.ts @@ -3375,7 +3375,7 @@ test('#heartbeat should emit events when it was cancelled while waiting for API consumer.on('consumer:heartbeat:start', () => { mockMetadata(consumer, 1, null, null, (original, options, callback) => { - consumer.leaveGroup() + consumer.leaveGroup(false, () => {}) original(options, callback) }) }) @@ -3389,7 +3389,7 @@ test('#heartbeat should emit events when it was cancelled while waiting for Hear const consumer = createConsumer(t) mockAPI(consumer[kConnections], heartbeatV4.api.key, null, null, (original: Function, ...args: any[]) => { - consumer.leaveGroup() + consumer.leaveGroup(false, () => {}) original(...args) }) diff --git a/test/protocol/sasl/credential-provider.test.ts b/test/protocol/sasl/credential-provider.test.ts index 243f787b..b39918bd 100644 --- a/test/protocol/sasl/credential-provider.test.ts +++ b/test/protocol/sasl/credential-provider.test.ts @@ -14,12 +14,12 @@ test('getCredential with string credential', (_, done) => { }) test('getCredential with invalid credential type', (_, done) => { - const credential = 123 as any + const credential = undefined as any getCredential('username', credential, error => { const authenticationError = error as AuthenticationError deepStrictEqual(authenticationError instanceof AuthenticationError, true) - deepStrictEqual(authenticationError.message, 'The username should be a string or a function.') + deepStrictEqual(authenticationError.message, 'The username should be a value or a function.') deepStrictEqual(authenticationError.code, 'PLT_KFK_AUTHENTICATION') done() }) @@ -47,8 +47,8 @@ test('getCredential with function provider returning promise', (_, done) => { }) }) -test('getCredential with function provider returning non-string', (_, done) => { - const provider = () => 123 as any +test('getCredential with function provider returning non-value', (_, done) => { + const provider = () => undefined getCredential('username', provider, (error: Error | null) => { const authenticationError = error as AuthenticationError @@ -56,21 +56,21 @@ test('getCredential with function provider returning non-string', (_, done) => { deepStrictEqual(authenticationError instanceof AuthenticationError, true) deepStrictEqual( authenticationError.message, - 'The username provider should return a string or a promise that resolves to a string.' + 'The username provider should return a string or a promise that resolves to a value.' ) deepStrictEqual(authenticationError.code, 'PLT_KFK_AUTHENTICATION') done() }) }) -test('getCredential with promise provider resolving to non-string', (_, done) => { - const provider = () => Promise.resolve(123 as any) +test('getCredential with promise provider resolving to non-value', (_, done) => { + const provider = () => Promise.resolve(null as any) getCredential('password', provider, (error: Error | null) => { const authenticationError = error as AuthenticationError deepStrictEqual(authenticationError instanceof AuthenticationError, true) - deepStrictEqual(authenticationError.message, 'The password provider should resolve to a string.') + deepStrictEqual(authenticationError.message, 'The password provider should resolve to a value.') deepStrictEqual(authenticationError.code, 'PLT_KFK_AUTHENTICATION') done() }) diff --git a/test/protocol/sasl/oauthbearer.test.ts b/test/protocol/sasl/oauthbearer.test.ts index 151d04b5..6c2a8b3f 100644 --- a/test/protocol/sasl/oauthbearer.test.ts +++ b/test/protocol/sasl/oauthbearer.test.ts @@ -36,7 +36,8 @@ test('authenticate should create proper payload with the token - promise', async const result = await saslOAuthBearer.authenticate( api as unknown as saslAuthenticateV2.SASLAuthenticationAPI, mockConnection as any, - 'token' + 'token', + {} ) // Verify the function was called @@ -78,6 +79,7 @@ test('authenticate should create proper payload with the token - callback', (_, api as unknown as saslAuthenticateV2.SASLAuthenticationAPI, mockConnection as any, 'token', + {}, (error, result) => { ifError(error)