Skip to content

Commit

Permalink
fix: prevent context from leaking with ClsGuard (#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
Papooch authored Mar 14, 2024
1 parent cec3015 commit 7026fdf
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 21 deletions.
1 change: 1 addition & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export * from './lib/cls-initializers/cls.middleware';
export * from './lib/cls-initializers/cls.interceptor';
export * from './lib/cls-initializers/cls.guard';
export * from './lib/cls-initializers/use-cls.decorator';
export * from './lib/cls-initializers/utils/context-cls-store-map';
export * from './lib/cls.module';
export * from './lib/cls.service';
export * from './lib/cls.decorators';
Expand Down
9 changes: 8 additions & 1 deletion packages/core/src/lib/cls-initializers/cls.guard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
import { ClsServiceManager } from '../cls-service-manager';
import { CLS_GUARD_OPTIONS, CLS_ID } from '../cls.constants';
import { ClsGuardOptions } from '../cls.options';
import { ContextClsStoreMap } from './utils/context-cls-store-map';

@Injectable()
export class ClsGuard implements CanActivate {
Expand All @@ -21,7 +22,13 @@ export class ClsGuard implements CanActivate {

async canActivate(context: ExecutionContext): Promise<boolean> {
const cls = ClsServiceManager.getClsService();
cls.enter({ ifNested: 'reuse' });
const existingStore = ContextClsStoreMap.get(context);
if (existingStore) {
cls.enter({ ifNested: 'reuse' });
} else {
cls.enterWith({});
ContextClsStoreMap.set(context, cls.get());
}
if (this.options.generateId) {
const id = await this.options.idGenerator?.(context);
cls.setIfUndefined<any>(CLS_ID, id);
Expand Down
23 changes: 3 additions & 20 deletions packages/core/src/lib/cls-initializers/cls.interceptor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { Observable } from 'rxjs';
import { ClsServiceManager } from '../cls-service-manager';
import { CLS_ID, CLS_INTERCEPTOR_OPTIONS } from '../cls.constants';
import { ClsInterceptorOptions } from '../cls.options';
import { ClsService } from '../cls.service';
import { ContextClsStoreMap } from './utils/context-cls-store-map';

@Injectable()
export class ClsInterceptor implements NestInterceptor {
Expand All @@ -24,7 +24,8 @@ export class ClsInterceptor implements NestInterceptor {

intercept(context: ExecutionContext, next: CallHandler): Observable<any> {
const cls = ClsServiceManager.getClsService<any>();
const clsStore = this.createOrReuseStore(context, cls);
const clsStore = ContextClsStoreMap.get(context) ?? {};
ContextClsStoreMap.set(context, clsStore);
return new Observable((subscriber) => {
cls.runWith(clsStore, async () => {
if (this.options.generateId) {
Expand Down Expand Up @@ -54,22 +55,4 @@ export class ClsInterceptor implements NestInterceptor {
});
});
}

createOrReuseStore(context: ExecutionContext, cls: ClsService) {
let store = (cls.isActive() && cls.get()) || {};
// NestJS triggers the interceptor for all queries within the same
// call individually, so each query would be wrapped in a different
// CLS context.
// The solution is to store the CLS store in the GQL context and re-use
// it each time the interceptor is triggered within the same request.
if ((context.getType() as string) == 'graphql') {
const gqlContext = context.getArgByIndex(2);
if (!gqlContext.__CLS_STORE__) {
gqlContext.__CLS_STORE__ = store;
} else {
store = gqlContext.__CLS_STORE__;
}
}
return store;
}
}
2 changes: 2 additions & 0 deletions packages/core/src/lib/cls-initializers/cls.middleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
CLS_RES,
} from '../cls.constants';
import { ClsMiddlewareOptions } from '../cls.options';
import { ContextClsStoreMap } from './utils/context-cls-store-map';

@Injectable()
export class ClsMiddleware implements NestMiddleware {
Expand All @@ -23,6 +24,7 @@ export class ClsMiddleware implements NestMiddleware {
const callback = async () => {
try {
this.options.useEnterWith && cls.enter();
ContextClsStoreMap.setByRaw(req, cls.get());
if (this.options.generateId) {
const id = await this.options.idGenerator?.(req);
cls.setIfUndefined<any>(CLS_ID, id);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import { ContextType, ExecutionContext } from '@nestjs/common';
import { ClsStore } from '../../cls.options';

/**
* This static class can be used to save the CLS store in a WeakMap based on the ExecutionContext
* or any object that is passed to the `setByRawContext` method.
*
* It is used internally by the `ClsMiddleware`, `ClsInterceptor` and `ClsGuard` to prevent
* instantiating the context multiple times for the same request.
*
* It can also be used as an escape hatch to retrieve the CLS store based on the ExecutionContext
* or the "raw context" when the ExecutionContext is not available.
* * For HTTP, it is the Request (@Req) object
* * For WS, it is the data object
* * For RPC (microservices), it is the RpcContext (@Ctx) object
* * For GraphQL, it is the GqlContext object
*/
export class ContextClsStoreMap {
private static readonly contextMap = new WeakMap<any, ClsStore>();
// eslint-disable-next-line @typescript-eslint/no-empty-function
private constructor() {}
static set(context: ExecutionContext, value: ClsStore): void {
const ctx = this.getContextByType(context);
this.contextMap.set(ctx, value);
}
static get(context: ExecutionContext): ClsStore | undefined {
const ctx = this.getContextByType(context);
return this.contextMap.get(ctx);
}
static setByRaw(ctx: any, value: ClsStore): void {
this.contextMap.set(ctx, value);
}
static getByRaw(ctx: any): ClsStore | undefined {
return this.contextMap.get(ctx);
}

private static getContextByType(context: ExecutionContext): any {
switch (context.getType() as ContextType | 'graphql') {
case 'http':
return context.switchToHttp().getRequest();
case 'ws':
return context.switchToWs().getData();
case 'rpc':
return context.switchToRpc().getContext();
case 'graphql':
// As per the GqlExecutionContext, the context is the second argument
return context.getArgByIndex(2);
default:
return {};
}
}
}

0 comments on commit 7026fdf

Please sign in to comment.