Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/base.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,11 @@ const producer = new Producer({
serializers: stringSerializers,
sasl: {
mechanism: 'PLAIN', // Also SCRAM-SHA-256, SCRAM-SHA-512 and OAUTHBEARER are supported
// username, password or token can also be (async) functions returning a string
// username, password, token and oauthBearerExtensions can also be (async) functions returning a value
username: 'username', // This is used from PLAIN, SCRAM-SHA-256 and SCRAM-SHA-512
password: 'password', // This is used from PLAIN, SCRAM-SHA-256 and SCRAM-SHA-512
token: 'token', // This is used from OAUTHBEARER
oauthBearerExtensions: {}, // This is used from OAUTHBEARER to add extension according to RFC 7628
// This is needed if your Kafka server returns a exitCode 0 when invalid credentials are sent and only stores
// authentication information in auth bytes.
//
Expand Down
3 changes: 3 additions & 0 deletions src/clients/base/options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ export const baseOptionsSchema = {
username: { oneOf: [{ type: 'string' }, { function: true }] },
password: { oneOf: [{ type: 'string' }, { function: true }] },
token: { oneOf: [{ type: 'string' }, { function: true }] },
oauthBearerExtensions: {
oneOf: [{ type: 'object', patternProperties: { '.+': { type: 'string' } } }, { function: true }]
},
authBytesValidator: { function: true }
},
required: ['mechanism'],
Expand Down
7 changes: 4 additions & 3 deletions src/network/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import { defaultCrypto, type ScramAlgorithm } from '../protocol/sasl/scram-sha.t
import { Writer } from '../protocol/writer.ts'
import { loggers } from '../utils.ts'

export type SASLCredentialProvider = () => string | Promise<string>
export type SASLCredentialProvider<T = string> = () => T | Promise<T>
export interface Broker {
host: string
port: number
Expand All @@ -44,6 +44,7 @@ export interface SASLOptions {
password?: string | SASLCredentialProvider
token?: string | SASLCredentialProvider
authBytesValidator?: (authBytes: Buffer, callback: CallbackWithPromise<Buffer>) => void
oauthBearerExtensions?: Record<string, string> | SASLCredentialProvider<Record<string, string>>
}

export interface ConnectionOptions {
Expand Down Expand Up @@ -384,7 +385,7 @@ export class Connection extends EventEmitter {
this.#status = ConnectionStatuses.AUTHENTICATING
}

const { mechanism, username, password, token } = this.#options.sasl!
const { mechanism, username, password, token, oauthBearerExtensions } = this.#options.sasl!

if (!allowedSASLMechanisms.includes(mechanism)) {
this.#onConnectionError(
Expand Down Expand Up @@ -414,7 +415,7 @@ export class Connection extends EventEmitter {
if (mechanism === SASLMechanisms.PLAIN) {
saslPlain.authenticate(saslAuthenticateV2.api, this, username!, password!, callback)
} else if (mechanism === SASLMechanisms.OAUTHBEARER) {
saslOAuthBearer.authenticate(saslAuthenticateV2.api, this, token!, callback)
saslOAuthBearer.authenticate(saslAuthenticateV2.api, this, token!, oauthBearerExtensions!, callback)
} else {
saslScramSha.authenticate(
saslAuthenticateV2.api,
Expand Down
46 changes: 21 additions & 25 deletions src/protocol/sasl/credential-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,54 @@ import { type Callback } from '../../apis/index.ts'
import { AuthenticationError } from '../../errors.ts'
import { type SASLCredentialProvider } from '../../network/connection.ts'

export function getCredential (
export function getCredential<T> (
label: string,
credentialOrProvider: string | SASLCredentialProvider,
callback: Callback<string>
credentialOrProvider: T | SASLCredentialProvider<T>,
callback: Callback<T>
): void {
if (typeof credentialOrProvider === 'string') {
callback(null, credentialOrProvider)
if (typeof credentialOrProvider === 'undefined') {
callback(new AuthenticationError(`The ${label} should be a value or a function.`), undefined as unknown as T)
return
} else if (typeof credentialOrProvider !== 'function') {
callback(new AuthenticationError(`The ${label} should be a string or a function.`), undefined as unknown as string)
callback(null, credentialOrProvider)
return
}

try {
const credential = credentialOrProvider()
const credential = (credentialOrProvider as SASLCredentialProvider<T>)()

if (typeof credential === 'string') {
callback(null, credential)
return
} else if (typeof (credential as Promise<string>)?.then !== 'function') {
if (credential == null) {
callback(
new AuthenticationError(`The ${label} provider should return a string or a promise that resolves to a string.`),
undefined as unknown as string
new AuthenticationError(`The ${label} provider should return a string or a promise that resolves to a value.`),
undefined as unknown as T
)

return
} else if (typeof (credential as Promise<string>)?.then !== 'function') {
callback(null, credential as T)
return
}

credential
.then(token => {
if (typeof token !== 'string') {
;(credential as Promise<T>)
.then((result: T) => {
if (result == null) {
process.nextTick(
callback,
new AuthenticationError(`The ${label} provider should resolve to a string.`),
new AuthenticationError(`The ${label} provider should resolve to a value.`),
undefined as unknown as string
)

return
}

process.nextTick(callback, null, token)
process.nextTick(callback, null, result)
})
.catch(error => {
process.nextTick(
callback,
new AuthenticationError(`The ${label} provider threw an error.`, { cause: error as Error })
)
.catch((error: Error) => {
process.nextTick(callback, new AuthenticationError(`The ${label} provider threw an error.`, { cause: error }))
})
} catch (error) {
callback(
new AuthenticationError(`The ${label} provider threw an error.`, { cause: error as Error }),
undefined as unknown as string
undefined as unknown as T
)
}
}
20 changes: 18 additions & 2 deletions src/protocol/sasl/oauth-bearer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,20 @@ export function authenticate (
authenticateAPI: SASLAuthenticationAPI,
connection: Connection,
tokenOrProvider: string | SASLCredentialProvider,
extensions: Record<string, string> | SASLCredentialProvider<Record<string, string>>,
callback: CallbackWithPromise<SaslAuthenticateResponse>
): void
export function authenticate (
authenticateAPI: SASLAuthenticationAPI,
connection: Connection,
tokenOrProvider: string | SASLCredentialProvider
tokenOrProvider: string | SASLCredentialProvider,
extensions: Record<string, string> | SASLCredentialProvider<Record<string, string>>
): Promise<SaslAuthenticateResponse>
export function authenticate (
authenticateAPI: SASLAuthenticationAPI,
connection: Connection,
tokenOrProvider: string | SASLCredentialProvider,
extensionsOrProvider: Record<string, string> | SASLCredentialProvider<Record<string, string>>,
callback?: CallbackWithPromise<SaslAuthenticateResponse>
): void | Promise<SaslAuthenticateResponse> {
if (!callback) {
Expand All @@ -55,7 +58,20 @@ export function authenticate (
return callback!(error, undefined as unknown as SaslAuthenticateResponse)
}

authenticateAPI(connection, Buffer.from(`n,,\x01auth=Bearer ${token}\x01\x01`), callback!)
getCredential('SASL/OAUTHBEARER extensions', extensionsOrProvider ?? {}, (error, extensionsMap) => {
if (error) {
return callback!(error, undefined as unknown as SaslAuthenticateResponse)
}

let extensions = ''
if (extensionsMap) {
for (const [key, value] of Object.entries(extensionsMap)) {
extensions += `\x01${key}=${value}`
}
}

authenticateAPI(connection, Buffer.from(`n,,\x01auth=Bearer ${token}${extensions}\x01\x01`), callback!)
})
})

return callback[kCallbackPromise]
Expand Down
30 changes: 28 additions & 2 deletions test/clients/base/sasl-oauthbearer.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,31 @@ test('should connect to SASL protected broker using SASL/OAUTHBEARER', async t =
deepStrictEqual(metadata.brokers.get(1), saslBroker)
})

test('should connect to SASL protected broker using SASL/OAUTHBEARER and custom extensions', async t => {
const signSync = createSigner({
algorithm: 'none',
iss: 'kafka',
aud: ['users'],
sub: 'admin',
expiresIn: '2h'
})
const token = signSync({ scope: 'test' })

const base = new Base({
clientId: 'clientId',
bootstrapBrokers: kafkaSaslBootstrapServers,
strict: true,
retries: 0,
sasl: { mechanism: SASLMechanisms.OAUTHBEARER, token, oauthBearerExtensions: { aaa: 'bbb', ccc: 'ddd' } }
})

t.after(() => base.close())

const metadata = await base.metadata({ topics: [] })

deepStrictEqual(metadata.brokers.get(1), saslBroker)
})

test('should handle authentication errors', async t => {
const base = new Base({
clientId: 'clientId',
Expand Down Expand Up @@ -178,7 +203,8 @@ test('should handle async credential provider errors', async t => {
retries: 0,
sasl: {
mechanism: 'OAUTHBEARER',
async token () {
token: 'token',
async oauthBearerExtensions () {
throw new Error('Kaboom!')
}
}
Expand All @@ -198,7 +224,7 @@ test('should handle async credential provider errors', async t => {

const authenticationError = networkError.cause
deepStrictEqual(authenticationError instanceof AuthenticationError, true)
deepStrictEqual(authenticationError.message, 'The SASL/OAUTHBEARER token provider threw an error.')
deepStrictEqual(authenticationError.message, 'The SASL/OAUTHBEARER extensions provider threw an error.')
deepStrictEqual(authenticationError.cause.message, 'Kaboom!')
}
})
Expand Down
4 changes: 2 additions & 2 deletions test/clients/consumer/consumer.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3375,7 +3375,7 @@ test('#heartbeat should emit events when it was cancelled while waiting for API

consumer.on('consumer:heartbeat:start', () => {
mockMetadata(consumer, 1, null, null, (original, options, callback) => {
consumer.leaveGroup()
consumer.leaveGroup(false, () => {})
original(options, callback)
})
})
Expand All @@ -3389,7 +3389,7 @@ test('#heartbeat should emit events when it was cancelled while waiting for Hear
const consumer = createConsumer(t)

mockAPI(consumer[kConnections], heartbeatV4.api.key, null, null, (original: Function, ...args: any[]) => {
consumer.leaveGroup()
consumer.leaveGroup(false, () => {})
original(...args)
})

Expand Down
16 changes: 8 additions & 8 deletions test/protocol/sasl/credential-provider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ test('getCredential with string credential', (_, done) => {
})

test('getCredential with invalid credential type', (_, done) => {
const credential = 123 as any
const credential = undefined as any

getCredential('username', credential, error => {
const authenticationError = error as AuthenticationError
deepStrictEqual(authenticationError instanceof AuthenticationError, true)
deepStrictEqual(authenticationError.message, 'The username should be a string or a function.')
deepStrictEqual(authenticationError.message, 'The username should be a value or a function.')
deepStrictEqual(authenticationError.code, 'PLT_KFK_AUTHENTICATION')
done()
})
Expand Down Expand Up @@ -47,30 +47,30 @@ test('getCredential with function provider returning promise', (_, done) => {
})
})

test('getCredential with function provider returning non-string', (_, done) => {
const provider = () => 123 as any
test('getCredential with function provider returning non-value', (_, done) => {
const provider = () => undefined

getCredential('username', provider, (error: Error | null) => {
const authenticationError = error as AuthenticationError

deepStrictEqual(authenticationError instanceof AuthenticationError, true)
deepStrictEqual(
authenticationError.message,
'The username provider should return a string or a promise that resolves to a string.'
'The username provider should return a string or a promise that resolves to a value.'
)
deepStrictEqual(authenticationError.code, 'PLT_KFK_AUTHENTICATION')
done()
})
})

test('getCredential with promise provider resolving to non-string', (_, done) => {
const provider = () => Promise.resolve(123 as any)
test('getCredential with promise provider resolving to non-value', (_, done) => {
const provider = () => Promise.resolve(null as any)

getCredential('password', provider, (error: Error | null) => {
const authenticationError = error as AuthenticationError

deepStrictEqual(authenticationError instanceof AuthenticationError, true)
deepStrictEqual(authenticationError.message, 'The password provider should resolve to a string.')
deepStrictEqual(authenticationError.message, 'The password provider should resolve to a value.')
deepStrictEqual(authenticationError.code, 'PLT_KFK_AUTHENTICATION')
done()
})
Expand Down
4 changes: 3 additions & 1 deletion test/protocol/sasl/oauthbearer.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ test('authenticate should create proper payload with the token - promise', async
const result = await saslOAuthBearer.authenticate(
api as unknown as saslAuthenticateV2.SASLAuthenticationAPI,
mockConnection as any,
'token'
'token',
{}
)

// Verify the function was called
Expand Down Expand Up @@ -78,6 +79,7 @@ test('authenticate should create proper payload with the token - callback', (_,
api as unknown as saslAuthenticateV2.SASLAuthenticationAPI,
mockConnection as any,
'token',
{},
(error, result) => {
ifError(error)

Expand Down