1717Server that accepts JSON RPC requests and returns JSON RPC replies/errors.
1818"""
1919import asyncio
20+ from dataclasses import dataclass
2021import logging
2122from asyncio import AbstractEventLoop
2223from typing import Callable
2324from datetime import datetime
2425
2526import zmq .asyncio
27+ from zmq .auth .asyncio import AsyncioAuthenticator
2628
2729from rpcq ._base import to_msgpack , from_msgpack
2830from rpcq ._spec import RPCSpec
2931from rpcq .messages import RPCRequest
3032
3133_log = logging .getLogger (__name__ )
3234
35+ # Required values for ZeroMQ curve authentication, in lieu of a TypedDict
36+ @dataclass
37+ class ServerAuthConfig :
38+ server_secret_key : bytes
39+ server_public_key : bytes
40+ client_keys_directory : str = ""
41+
3342
3443class Server :
3544 """
3645 Server that accepts JSON RPC calls through a socket.
3746 """
3847 def __init__ (self , rpc_spec : RPCSpec = None , announce_timing : bool = False ,
39- serialize_exceptions : bool = True ):
48+ serialize_exceptions : bool = True , auth_config : ServerAuthConfig = None ):
4049 """
4150 Create a server that will be linked to a socket
4251
@@ -50,6 +59,9 @@ def __init__(self, rpc_spec: RPCSpec = None, announce_timing: bool = False,
5059
5160 IMPORTANT NOTE: When set to False, this *almost definitely* means an unrecoverable
5261 crash, and the Server should then be _shutdown().
62+ :param auth_config: The configuration values necessary to enable Curve ZeroMQ authentication.
63+ These must be provided at instantiation, so they are available between the creation of the
64+ context and socket.
5365 """
5466 self .announce_timing = announce_timing
5567 self .serialize_exceptions = serialize_exceptions
@@ -58,6 +70,9 @@ def __init__(self, rpc_spec: RPCSpec = None, announce_timing: bool = False,
5870 self ._exit_handlers = []
5971
6072 self ._socket = None
73+ self ._auth_config = auth_config
74+ self ._authenticator = None
75+ self ._preloaded_keys = None
6176
6277 def rpc_handler (self , f : Callable ):
6378 """
@@ -77,6 +92,56 @@ def exit_handler(self, f: Callable):
7792 """
7893 self ._exit_handlers .append (f )
7994
95+ async def recv_multipart (self ):
96+ if self .auth_enabled :
97+ return await self .recv_multipart_with_auth ()
98+ else :
99+ # If auth is not enabled, then the client "User-Id" will not be retrieved from
100+ # the frames received, and we return None for that value.
101+ return (* await self ._socket .recv_multipart (), None )
102+
103+ async def recv_multipart_with_auth (self ) -> (bytes , list , bytes , ):
104+ """
105+ Code taken from pyzmq itself: https://github.com/zeromq/pyzmq/blob/master/zmq/sugar/socket.py#L449
106+ and then adapted to allow us to access the information in the frames.
107+
108+ When copy=True, only the contents of the messages are returned, rather than the messages themselves.
109+ The message is necessary to be able to fetch the "User-Id", which is the public key the client used
110+ to connect to this socket.
111+
112+ When using auth, knowing which client sent which message is important for authentication, and so
113+ we reimplement recv_multipart here, and return the client key as the final member of a tuple
114+ """
115+
116+ copy = False
117+ # Given a ROUTER socket, the first frame will be the sender's identity.
118+ # While, per the docs, this _should_ be retrievable from any frame with
119+ # frame.get('Identity'), in practice this value was always blank.
120+ # If we don't record the identity value, messages cannot be returned to
121+ # the correct client.
122+ identity_frame = await self ._socket .recv (0 , copy = copy , track = False )
123+ identity = identity_frame .bytes
124+
125+ # The client_id is the public key the client used to establish this connection
126+ # It can be retrieved from all frames after the first. Here, we assume it
127+ # is the same among all frames, and set it to the value from the first frame
128+ client_key = None
129+
130+ # After the identity frame, we assemble all further frames in a single buffer.
131+ parts = bytearray (b'' )
132+ while self ._socket .getsockopt (zmq .RCVMORE ):
133+ part = await self ._socket .recv (0 , copy = copy , track = False )
134+ data = part .bytes
135+ if client_key is None :
136+ client_key = part .get ('User-Id' )
137+ if not isinstance (client_key , bytes ) and client_key is not None :
138+ client_key = client_key .encode ('utf-8' )
139+ parts += data
140+
141+ _log .debug (f'Received authenticated request from client_key { client_key } ' )
142+
143+ return (identity , parts , client_key )
144+
80145 async def run_async (self , endpoint : str ):
81146 """
82147 Run server main task (asynchronously).
@@ -86,7 +151,7 @@ async def run_async(self, endpoint: str):
86151 self ._connect (endpoint )
87152
88153 # spawn an initial listen task
89- listen_task = asyncio .ensure_future (self ._socket . recv_multipart ())
154+ listen_task = asyncio .ensure_future (self .recv_multipart ())
90155 task_list = [listen_task ]
91156
92157 while True :
@@ -102,8 +167,12 @@ async def run_async(self, endpoint: str):
102167 # empty_frame may either be:
103168 # 1. a single null frame if the client is a REQ socket
104169 # 2. an empty list (ie. no frames) if the client is a DEALER socket
105- identity , * empty_frame , msg = done .result ()
170+ identity , * empty_frame , msg , client_key = done .result ()
106171 request = from_msgpack (msg )
172+ try :
173+ request .params ['client_key' ] = client_key
174+ except Exception as e :
175+ _log .error (f'Failed to attach client_key to request: { e } ' )
107176
108177 # spawn a processing task
109178 task_list .append (asyncio .ensure_future (
@@ -116,7 +185,7 @@ async def run_async(self, endpoint: str):
116185 raise e
117186 finally :
118187 # spawn a new listen task
119- listen_task = asyncio .ensure_future (self ._socket . recv_multipart ())
188+ listen_task = asyncio .ensure_future (self .recv_multipart ())
120189 task_list .append (listen_task )
121190 else :
122191 # if there's been an exception during processing, consider reraising it
@@ -172,6 +241,7 @@ def _connect(self, endpoint: str):
172241
173242 context = zmq .asyncio .Context ()
174243 self ._socket = context .socket (zmq .ROUTER )
244+ self .start_auth (context )
175245 self ._socket .bind (endpoint )
176246
177247 _log .info ("Starting server, listening on endpoint {}" .format (endpoint ))
@@ -200,3 +270,69 @@ async def _process_request(self, identity: bytes, empty_frame: list, request: RP
200270 else :
201271 raise e
202272
273+ @property
274+ def auth_configured (self ) -> bool :
275+ return (self ._auth_config is not None ) and isinstance (self ._auth_config .server_secret_key , bytes ) and isinstance (self ._auth_config .server_public_key , bytes )
276+
277+ @property
278+ def auth_enabled (self ) -> bool :
279+ return bool (self ._socket and self ._socket .curve_server )
280+
281+ def start_auth (self , context : zmq .Context ) -> bool :
282+ """
283+ Starts the ZMQ auth service thread, enabling authorization on all sockets within this context
284+ """
285+ if not self .auth_configured :
286+ return False
287+ self ._socket .curve_secretkey = self ._auth_config .server_secret_key
288+ self ._socket .curve_publickey = self ._auth_config .server_public_key
289+ self ._socket .curve_server = True
290+ self ._authenticator = AsyncioAuthenticator (context )
291+ if self ._preloaded_keys :
292+ self .set_client_keys (self ._preloaded_keys )
293+ else :
294+ self .load_client_keys_from_directory ()
295+ self ._authenticator .start ()
296+ return True
297+
298+ def stop_auth (self ) -> bool :
299+ """
300+ Stops the ZMQ auth service thread, allowing NULL authenticated clients (only) to connect to
301+ all threads within its context
302+ """
303+ if self ._authenticator :
304+ self ._socket .curve_server = False
305+ self ._authenticator .stop ()
306+ return True
307+ else :
308+ return False
309+
310+ def load_client_keys_from_directory (self , directory : str = None ) -> bool :
311+ """
312+ Reset authorized public key list to those present in the specified directory
313+ """
314+
315+ # The directory must either be specified at class creation or on each method call
316+ if directory is None :
317+ if not self ._auth_config .client_keys_directory :
318+ raise Exception ("Server Auth: Client keys directory required" )
319+ else :
320+ directory = self ._auth_config .client_keys_directory
321+ if not self .auth_configured :
322+ return False
323+ self ._authenticator .configure_curve (domain = '*' , location = self ._auth_config .client_keys_directory )
324+
325+ def set_client_keys (self , client_keys : [bytes ]):
326+ """
327+ Reset authorized public key list to this set. Avoids the disk read required by configure_curve,
328+ and allows keys to be managed externally.
329+
330+ In some cases, keys may be preloaded before the authenticator is started. In this case, we
331+ cache those preloaded keys
332+ """
333+ if self ._authenticator :
334+ _log .debug (f"Authorizer: Setting client keys to { client_keys } " )
335+ self ._authenticator .certs ['*' ] = {key : True for key in client_keys }
336+ else :
337+ _log .debug (f"Authorizer: Preloading client keys to { client_keys } " )
338+ self ._preloaded_keys = client_keys
0 commit comments