Skip to content

Commit c492329

Browse files
kalzoostylewarning
authored andcommitted
Add ZeroMQ Curve auth to Client and Server
1 parent 2119978 commit c492329

File tree

8 files changed

+519
-10
lines changed

8 files changed

+519
-10
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,6 @@ fabric.properties
179179

180180
### MyPi ###
181181
.mypy_cache
182+
183+
### VSCode ###
184+
.vscode/

rpcq/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from rpcq._client import Client
2-
from rpcq._server import Server
1+
from rpcq._client import Client, ClientAuthConfig
2+
from rpcq._server import Server, ServerAuthConfig
33
from rpcq.version import __version__

rpcq/_client.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import asyncio
1717
import logging
1818
import time
19+
from dataclasses import dataclass
1920
from typing import Dict, Union
2021
from warnings import warn
2122

@@ -29,22 +30,34 @@
2930
_log = logging.getLogger(__name__)
3031

3132

33+
# Required values for ZeroMQ curve authentication, in lieu of a TypedDict
34+
@dataclass
35+
class ClientAuthConfig:
36+
client_secret_key: bytes
37+
client_public_key: bytes
38+
server_public_key: bytes
39+
40+
3241
class Client:
3342
"""
3443
Client that executes methods on a remote server by sending JSON RPC requests to a socket.
3544
"""
36-
def __init__(self, endpoint: str, timeout: float = None):
45+
def __init__(self, endpoint: str, timeout: float = None, auth_config: ClientAuthConfig = None):
3746
"""
3847
Create a client that connects to a server at <endpoint>.
3948
4049
:param str endpoint: Socket endpoint, e.g. "tcp://localhost:1234"
4150
:param float timeout: Timeout in seconds for Server response, set to None to disable the timeout
51+
:param auth_config: The configuration values necessary to enable Curve ZeroMQ authentication.
52+
These must be provided at instantiation, so they are available when the socket is created.
4253
"""
4354
# TODO: leaving self.timeout for backwards compatibility; we should move towards using rpc_timeout only
4455
self.timeout = timeout
4556
self.rpc_timeout = timeout
4657
self.endpoint = endpoint
4758

59+
self._auth_config = auth_config
60+
4861
self._socket = self._connect_to_socket(zmq.Context(), endpoint)
4962
# The async socket can't be created yet because it's possible that the current event loop during Client creation
5063
# is different to the one used later to call a method, so we need to create the socket after the first call and
@@ -59,6 +72,7 @@ def __init__(self, endpoint: str, timeout: float = None):
5972
# Cache of replies so that different tasks can share results with each other
6073
self._replies: Dict[str, Union[RPCReply, RPCError]] = {}
6174

75+
6276
def __setattr__(self, key, value):
6377
"""
6478
Ensure rpc_timeout attribute gets update with timeout. Currently keeping self.timeout and
@@ -199,6 +213,7 @@ def _connect_to_socket(self, context: zmq.Context, endpoint: str):
199213
:return: Connected socket
200214
"""
201215
socket = context.socket(zmq.DEALER)
216+
self.enable_auth(socket)
202217
socket.connect(endpoint)
203218
socket.setsockopt(zmq.LINGER, 0)
204219
_log.debug("Client connected to endpoint %s", self.endpoint)
@@ -213,3 +228,19 @@ def _async_socket(self):
213228
self._async_socket_cache = self._connect_to_socket(zmq.asyncio.Context(), self.endpoint)
214229

215230
return self._async_socket_cache
231+
232+
@property
233+
def auth_configured(self) -> bool:
234+
return self._auth_config is not None
235+
236+
def enable_auth(self, socket=None) -> bool:
237+
"""
238+
Enables Curve ZeroMQ Authentication if the necessary configuration is present
239+
"""
240+
if not self.auth_configured:
241+
return False
242+
socket.curve_secretkey = self._auth_config.client_secret_key
243+
socket.curve_publickey = self._auth_config.client_public_key
244+
socket.curve_serverkey = self._auth_config.server_public_key
245+
return True
246+

rpcq/_server.py

Lines changed: 140 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,35 @@
1717
Server that accepts JSON RPC requests and returns JSON RPC replies/errors.
1818
"""
1919
import asyncio
20+
from dataclasses import dataclass
2021
import logging
2122
from asyncio import AbstractEventLoop
2223
from typing import Callable
2324
from datetime import datetime
2425

2526
import zmq.asyncio
27+
from zmq.auth.asyncio import AsyncioAuthenticator
2628

2729
from rpcq._base import to_msgpack, from_msgpack
2830
from rpcq._spec import RPCSpec
2931
from rpcq.messages import RPCRequest
3032

3133
_log = logging.getLogger(__name__)
3234

35+
# Required values for ZeroMQ curve authentication, in lieu of a TypedDict
36+
@dataclass
37+
class ServerAuthConfig:
38+
server_secret_key: bytes
39+
server_public_key: bytes
40+
client_keys_directory: str = ""
41+
3342

3443
class Server:
3544
"""
3645
Server that accepts JSON RPC calls through a socket.
3746
"""
3847
def __init__(self, rpc_spec: RPCSpec = None, announce_timing: bool = False,
39-
serialize_exceptions: bool = True):
48+
serialize_exceptions: bool = True, auth_config: ServerAuthConfig = None):
4049
"""
4150
Create a server that will be linked to a socket
4251
@@ -50,6 +59,9 @@ def __init__(self, rpc_spec: RPCSpec = None, announce_timing: bool = False,
5059
5160
IMPORTANT NOTE: When set to False, this *almost definitely* means an unrecoverable
5261
crash, and the Server should then be _shutdown().
62+
:param auth_config: The configuration values necessary to enable Curve ZeroMQ authentication.
63+
These must be provided at instantiation, so they are available between the creation of the
64+
context and socket.
5365
"""
5466
self.announce_timing = announce_timing
5567
self.serialize_exceptions = serialize_exceptions
@@ -58,6 +70,9 @@ def __init__(self, rpc_spec: RPCSpec = None, announce_timing: bool = False,
5870
self._exit_handlers = []
5971

6072
self._socket = None
73+
self._auth_config = auth_config
74+
self._authenticator = None
75+
self._preloaded_keys = None
6176

6277
def rpc_handler(self, f: Callable):
6378
"""
@@ -77,6 +92,56 @@ def exit_handler(self, f: Callable):
7792
"""
7893
self._exit_handlers.append(f)
7994

95+
async def recv_multipart(self):
96+
if self.auth_enabled:
97+
return await self.recv_multipart_with_auth()
98+
else:
99+
# If auth is not enabled, then the client "User-Id" will not be retrieved from
100+
# the frames received, and we return None for that value.
101+
return (*await self._socket.recv_multipart(), None)
102+
103+
async def recv_multipart_with_auth(self) -> (bytes, list, bytes, ):
104+
"""
105+
Code taken from pyzmq itself: https://github.com/zeromq/pyzmq/blob/master/zmq/sugar/socket.py#L449
106+
and then adapted to allow us to access the information in the frames.
107+
108+
When copy=True, only the contents of the messages are returned, rather than the messages themselves.
109+
The message is necessary to be able to fetch the "User-Id", which is the public key the client used
110+
to connect to this socket.
111+
112+
When using auth, knowing which client sent which message is important for authentication, and so
113+
we reimplement recv_multipart here, and return the client key as the final member of a tuple
114+
"""
115+
116+
copy = False
117+
# Given a ROUTER socket, the first frame will be the sender's identity.
118+
# While, per the docs, this _should_ be retrievable from any frame with
119+
# frame.get('Identity'), in practice this value was always blank.
120+
# If we don't record the identity value, messages cannot be returned to
121+
# the correct client.
122+
identity_frame = await self._socket.recv(0, copy=copy, track=False)
123+
identity = identity_frame.bytes
124+
125+
# The client_id is the public key the client used to establish this connection
126+
# It can be retrieved from all frames after the first. Here, we assume it
127+
# is the same among all frames, and set it to the value from the first frame
128+
client_key = None
129+
130+
# After the identity frame, we assemble all further frames in a single buffer.
131+
parts = bytearray(b'')
132+
while self._socket.getsockopt(zmq.RCVMORE):
133+
part = await self._socket.recv(0, copy=copy, track=False)
134+
data = part.bytes
135+
if client_key is None:
136+
client_key = part.get('User-Id')
137+
if not isinstance(client_key, bytes) and client_key is not None:
138+
client_key = client_key.encode('utf-8')
139+
parts += data
140+
141+
_log.debug(f'Received authenticated request from client_key {client_key}')
142+
143+
return (identity, parts, client_key)
144+
80145
async def run_async(self, endpoint: str):
81146
"""
82147
Run server main task (asynchronously).
@@ -86,7 +151,7 @@ async def run_async(self, endpoint: str):
86151
self._connect(endpoint)
87152

88153
# spawn an initial listen task
89-
listen_task = asyncio.ensure_future(self._socket.recv_multipart())
154+
listen_task = asyncio.ensure_future(self.recv_multipart())
90155
task_list = [listen_task]
91156

92157
while True:
@@ -102,8 +167,12 @@ async def run_async(self, endpoint: str):
102167
# empty_frame may either be:
103168
# 1. a single null frame if the client is a REQ socket
104169
# 2. an empty list (ie. no frames) if the client is a DEALER socket
105-
identity, *empty_frame, msg = done.result()
170+
identity, *empty_frame, msg, client_key = done.result()
106171
request = from_msgpack(msg)
172+
try:
173+
request.params['client_key'] = client_key
174+
except Exception as e:
175+
_log.error(f'Failed to attach client_key to request: {e}')
107176

108177
# spawn a processing task
109178
task_list.append(asyncio.ensure_future(
@@ -116,7 +185,7 @@ async def run_async(self, endpoint: str):
116185
raise e
117186
finally:
118187
# spawn a new listen task
119-
listen_task = asyncio.ensure_future(self._socket.recv_multipart())
188+
listen_task = asyncio.ensure_future(self.recv_multipart())
120189
task_list.append(listen_task)
121190
else:
122191
# if there's been an exception during processing, consider reraising it
@@ -172,6 +241,7 @@ def _connect(self, endpoint: str):
172241

173242
context = zmq.asyncio.Context()
174243
self._socket = context.socket(zmq.ROUTER)
244+
self.start_auth(context)
175245
self._socket.bind(endpoint)
176246

177247
_log.info("Starting server, listening on endpoint {}".format(endpoint))
@@ -200,3 +270,69 @@ async def _process_request(self, identity: bytes, empty_frame: list, request: RP
200270
else:
201271
raise e
202272

273+
@property
274+
def auth_configured(self) -> bool:
275+
return (self._auth_config is not None) and isinstance(self._auth_config.server_secret_key, bytes) and isinstance(self._auth_config.server_public_key, bytes)
276+
277+
@property
278+
def auth_enabled(self) -> bool:
279+
return bool(self._socket and self._socket.curve_server)
280+
281+
def start_auth(self, context: zmq.Context) -> bool:
282+
"""
283+
Starts the ZMQ auth service thread, enabling authorization on all sockets within this context
284+
"""
285+
if not self.auth_configured:
286+
return False
287+
self._socket.curve_secretkey = self._auth_config.server_secret_key
288+
self._socket.curve_publickey = self._auth_config.server_public_key
289+
self._socket.curve_server = True
290+
self._authenticator = AsyncioAuthenticator(context)
291+
if self._preloaded_keys:
292+
self.set_client_keys(self._preloaded_keys)
293+
else:
294+
self.load_client_keys_from_directory()
295+
self._authenticator.start()
296+
return True
297+
298+
def stop_auth(self) -> bool:
299+
"""
300+
Stops the ZMQ auth service thread, allowing NULL authenticated clients (only) to connect to
301+
all threads within its context
302+
"""
303+
if self._authenticator:
304+
self._socket.curve_server = False
305+
self._authenticator.stop()
306+
return True
307+
else:
308+
return False
309+
310+
def load_client_keys_from_directory(self, directory: str = None) -> bool:
311+
"""
312+
Reset authorized public key list to those present in the specified directory
313+
"""
314+
315+
# The directory must either be specified at class creation or on each method call
316+
if directory is None:
317+
if not self._auth_config.client_keys_directory:
318+
raise Exception("Server Auth: Client keys directory required")
319+
else:
320+
directory = self._auth_config.client_keys_directory
321+
if not self.auth_configured:
322+
return False
323+
self._authenticator.configure_curve(domain='*', location=self._auth_config.client_keys_directory)
324+
325+
def set_client_keys(self, client_keys: [bytes]):
326+
"""
327+
Reset authorized public key list to this set. Avoids the disk read required by configure_curve,
328+
and allows keys to be managed externally.
329+
330+
In some cases, keys may be preloaded before the authenticator is started. In this case, we
331+
cache those preloaded keys
332+
"""
333+
if self._authenticator:
334+
_log.debug(f"Authorizer: Setting client keys to {client_keys}")
335+
self._authenticator.certs['*'] = {key: True for key in client_keys}
336+
else:
337+
_log.debug(f"Authorizer: Preloading client keys to {client_keys}")
338+
self._preloaded_keys = client_keys

rpcq/_spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import traceback
2222
from typing import Union
2323

24-
from rpcq._utils import rpc_reply, rpc_error, RPCMethodError, get_input, \
24+
from rpcq._utils import rpc_reply, rpc_error, RPCMethodError, get_input, get_safe_input, \
2525
catch_warnings
2626
from rpcq.messages import RPCRequest, RPCReply, RPCError
2727

@@ -114,7 +114,7 @@ async def run_handler(self, request: RPCRequest) -> Union[RPCReply, RPCError]:
114114

115115
try:
116116
# Run RPC and get result
117-
args, kwargs = get_input(request.params)
117+
args, kwargs = get_safe_input(request.params, rpc_handler)
118118
result = rpc_handler(*args, **kwargs)
119119

120120
if asyncio.iscoroutine(result):

rpcq/_utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
"""Utils for message passing"""
1717
import uuid
1818
import warnings
19-
from typing import Optional, Tuple, Union, List, Any
19+
from inspect import signature
20+
from typing import Callable, Optional, Tuple, Union, List, Any
2021

2122
import rpcq.messages
2223

@@ -103,6 +104,23 @@ def get_input(params: Union[dict, list]) -> Tuple[list, dict]:
103104
return args, kwargs
104105

105106

107+
def get_safe_input(params: Union[dict, list], handler: Callable) -> Tuple[list, dict]:
108+
"""
109+
Get positional or keyword arguments from JSON RPC params,
110+
filtering out kwargs that aren't in the function signature
111+
112+
:param params: Parameters passed through JSON RPC
113+
:param handler: RPC handler function
114+
:return: args, kwargs
115+
"""
116+
args, kwargs = get_input(params)
117+
118+
handler_signature = signature(handler)
119+
kwargs = { k: v for k, v in kwargs.items() if k in handler_signature.parameters }
120+
121+
return args, kwargs
122+
123+
106124
class RPCErrorError(IOError):
107125
"""JSON RPC error that is raised by a Client when it receives an RPCError message"""
108126

rpcq/messages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ class RPCRequest(Message):
7373
jsonrpc: str = "2.0"
7474
"""The JSONRPC version."""
7575

76+
client_key: str = ""
77+
"""The ZeroMQ CURVE public key used to make the request. Blank if no key is used"""
78+
7679

7780
@dataclass(eq=False, repr=False)
7881
class RPCWarning(Message):

0 commit comments

Comments
 (0)