Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(core): support primitive values in websocket payload #172

Merged
merged 3 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/docs/05_considerations/02_compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,13 @@ Below are listed transports with which it is confirmed to work:
### Websockets

_Websocket Gateways_ don't respect globally bound enhancers, therefore it is required to bind the `ClsGuard` or `ClsInterceptor` manually on the `WebsocketGateway`. Special care is also needed for the `handleConnection` method (See [#8](https://github.com/Papooch/nestjs-cls/issues/8))

```ts
@WebSocketGateway()
// highlight-start
@UseInterceptors(ClsInterceptor)
// highlight-end
export class Gateway {
// ...
}
```
9 changes: 9 additions & 0 deletions docs/docusaurus.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ const config = {
footer: {
style: 'dark',
links: [
{
title: 'Docs',
items: [
{
label: 'NestJS Documentation',
href: 'https://docs.nestjs.com/',
},
],
},
{
title: 'Community',
items: [
Expand Down
7 changes: 6 additions & 1 deletion packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,28 @@
"@nestjs/mercurius": "^12.1.1",
"@nestjs/platform-express": "^10.3.7",
"@nestjs/platform-fastify": "^10.3.7",
"@nestjs/platform-ws": "^10.3.10",
"@nestjs/schematics": "^10.0.1",
"@nestjs/testing": "^10.3.7",
"@nestjs/websockets": "^10.3.10",
"@types/express": "^4.17.13",
"@types/jest": "^28.1.2",
"@types/node": "^18.0.0",
"@types/supertest": "^2.0.12",
"@types/ws": "^8",
"graphql": "^16.5.0",
"jest": "^29.7.0",
"mercurius": "^13.0.0",
"reflect-metadata": "^0.1.13",
"rimraf": "^3.0.2",
"rxjs": "^7.5.5",
"supertest": "^6.2.3",
"superwstest": "^2.0.4",
"ts-jest": "^29.1.2",
"ts-loader": "^9.3.0",
"ts-node": "^10.8.1",
"tsconfig-paths": "^4.0.0",
"typescript": "5.0"
"typescript": "5.0",
"ws": "^8.18.0"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ export class ContextClsStoreMap {
case 'http':
return context.switchToHttp().getRequest();
case 'ws':
return context.switchToWs().getData();
return context.switchToWs();
case 'rpc':
return context.switchToRpc().getContext();
case 'graphql':
Expand Down
39 changes: 39 additions & 0 deletions packages/core/test/websockets/expect-ids-websockets.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { INestApplication } from '@nestjs/common';
import request from 'superwstest';

export const expectOkIdsWs =
(path = '', event = 'hello', data = {}) =>
async (app: INestApplication) =>
request(await app.getUrl())
.ws(path)
.sendJson({
event,
data,
})
.expectJson((body) => {
const id = body.fromGuard ?? body.fromInterceptor;
expect(body.fromInterceptor).toEqual(id);
expect(body.fromInterceptorAfter).toEqual(id);
expect(body.fromGateway).toEqual(id);
expect(body.fromService).toEqual(id);
expect(body.data).toEqual(data);
})
.close();

export const expectErrorIdsWs =
(path = '', event = 'error', data = {}) =>
(app: INestApplication) =>
request(app.getHttpServer())
.ws(path)
.sendJson({
event,
data,
})
.expectJson((body) => {
const id = body.fromGuard ?? body.fromInterceptor;
expect(body.fromInterceptor).toEqual(id);
expect(body.fromGateway).toEqual(id);
expect(body.fromService).toEqual(id);
expect(body.fromFilter).toEqual(id);
})
.close();
18 changes: 18 additions & 0 deletions packages/core/test/websockets/test-ws.filter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { ArgumentsHost, Catch, ExceptionFilter } from '@nestjs/common';
import { WebSocket } from 'ws';
import { ClsService } from '../../src';
import { TestException } from '../common/test.exception';

@Catch(TestException)
export class TestWsExceptionFilter implements ExceptionFilter {
constructor(private readonly cls: ClsService) {}

catch(exception: TestException, host: ArgumentsHost) {
const client = host.switchToWs().getClient<WebSocket>();
const response = {
...exception.response,
fromFilter: this.cls.getId(),
};
client.send(JSON.stringify(response));
}
}
55 changes: 55 additions & 0 deletions packages/core/test/websockets/websockets.app.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import {
Injectable,
UseFilters,
UseGuards,
UseInterceptors,
} from '@nestjs/common';
import { MessageBody, SubscribeMessage } from '@nestjs/websockets';

import { ClsService } from '../../src';
import { TestException } from '../common/test.exception';
import { TestGuard } from '../common/test.guard';
import { TestInterceptor } from '../common/test.interceptor';
import { TestWsExceptionFilter } from './test-ws.filter';

@Injectable()
export class TestWebsocketService {
constructor(private readonly cls: ClsService) {}

async hello(data?: unknown) {
return {
fromGuard: this.cls.get('FROM_GUARD'),
fromInterceptor: this.cls.get('FROM_INTERCEPTOR'),
fromInterceptorAfter: this.cls.get('FROM_INTERCEPTOR'),
fromGateway: this.cls.get('FROM_GATEWAY'),
fromService: this.cls.getId(),
data,
};
}
}

@Injectable()
@UseFilters(TestWsExceptionFilter)
export class TestWebsocketGateway {
constructor(
private readonly service: TestWebsocketService,
private readonly cls: ClsService,
) {}

@SubscribeMessage('hello')
@UseGuards(TestGuard)
@UseInterceptors(TestInterceptor)
hello(@MessageBody() data: unknown) {
this.cls.set('FROM_GATEWAY', this.cls.getId());
return this.service.hello(data);
}

@UseInterceptors(TestInterceptor)
@UseGuards(TestGuard)
@SubscribeMessage('error')
async error(@MessageBody() data: unknown) {
this.cls.set('FROM_GATEWAY', this.cls.getId());
const response = await this.service.hello(data);
throw new TestException(response);
}
}
76 changes: 76 additions & 0 deletions packages/core/test/websockets/ws.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import {
INestApplication,
Module,
UseGuards,
UseInterceptors,
} from '@nestjs/common';
import { WsAdapter } from '@nestjs/platform-ws';
import { Test } from '@nestjs/testing';
import { WebSocketGateway } from '@nestjs/websockets';
import { ClsGuard, ClsInterceptor, ClsModule } from '../../src';
import { expectErrorIdsWs, expectOkIdsWs } from './expect-ids-websockets';
import { TestWebsocketGateway, TestWebsocketService } from './websockets.app';

describe('Websockets - WS', () => {
let app: INestApplication;

@WebSocketGateway({ path: 'interceptor' })
@UseInterceptors(ClsInterceptor)
class WebsocketGatewayWithClsInterceptor extends TestWebsocketGateway {}

@WebSocketGateway({ path: 'guard' })
@UseGuards(ClsGuard)
class WebsocketGatewayWithClsGuard extends TestWebsocketGateway {}

@Module({
imports: [
ClsModule.forRoot({
interceptor: { mount: false, generateId: true },
guard: { mount: false, generateId: true },
}),
],
providers: [
TestWebsocketService,
WebsocketGatewayWithClsInterceptor,
WebsocketGatewayWithClsGuard,
],
})
class TestWebsocketModule {}

beforeAll(async () => {
const moduleFixture = await Test.createTestingModule({
imports: [TestWebsocketModule],
}).compile();
app = moduleFixture.createNestApplication();
app.useWebSocketAdapter(new WsAdapter(app));
await app.listen(3125);
});

afterAll(async () => {
await app?.close();
});

describe.each(['guard', 'interceptor'])(
'when using an %s to initialize the context',
(name) => {
const path = '/' + name;

it.each([
['ok', 'object', expectOkIdsWs(path, 'hello', { value: 12 })],
['ok', 'primitive', expectOkIdsWs(path, 'hello', 'primitive')],
[
'error',
'object',
expectErrorIdsWs(path, 'error', { value: 12 }),
],
[
'error',
'primitive',
expectErrorIdsWs(path, 'error', 'primitive'),
],
])('works with %s response and %s payload', async (_, __, func) => {
await func(app);
});
},
);
});
Loading
Loading