From f3400892bcc9a185c9654910d60075eb3c48b445 Mon Sep 17 00:00:00 2001 From: WatcherBox Date: Wed, 28 Feb 2024 12:00:03 +0100 Subject: [PATCH] add support for numpy array writes --- setup.cfg | 1 + simple_rpc/io.py | 29 ++++++++++++++++++++++++++-- simple_rpc/protocol.py | 41 ++++++++++++++++++++++++++++++++++++---- simple_rpc/simple_rpc.py | 9 ++++++++- 4 files changed, 73 insertions(+), 7 deletions(-) diff --git a/setup.cfg b/setup.cfg index f4d4663..6b7778b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,6 +21,7 @@ packages = find: install_requires = PyYAML>=5.3.1 pyserial>=3.4 + numpy>=1.19.4 [options.entry_points] console_scripts = diff --git a/simple_rpc/io.py b/simple_rpc/io.py index 9718a7b..7c70256 100644 --- a/simple_rpc/io.py +++ b/simple_rpc/io.py @@ -1,3 +1,4 @@ +import numpy as np from typing import Any, BinaryIO from struct import calcsize, pack, unpack @@ -43,9 +44,12 @@ def _write_basic( if basic_type == 's': stream.write(value + b'\0') return + elif isinstance(basic_type, np.ndarray): + print(f"writing array: {basic_type.itemsize}, {value.size}") + stream.write(value.tobytes()) + return full_type = (endianness + basic_type).encode('utf-8') - stream.write(pack(full_type, cast(basic_type)(value))) @@ -77,14 +81,29 @@ def read( :returns: Object of type {obj_type}. """ + + print(f"reading: {obj_type}") + + if isinstance(obj_type, np.ndarray): + # print(f"reading array: {size_t}, {obj_type}, {obj_type.itemsize}") + length = _read_basic(stream, endianness, size_t) + # print(f"length: {length}") + return np.frombuffer( + stream.read(length * obj_type.itemsize), obj_type.dtype) + if isinstance(obj_type, list): + # print(f"reading list: {obj_type}") length = _read_basic(stream, endianness, size_t) + # print(f"length: {length}") + return [ read(stream, endianness, size_t, item) for _ in range(length) for item in obj_type] + if isinstance(obj_type, tuple): return tuple( read(stream, endianness, size_t, item) for item in obj_type) + return _read_basic(stream, endianness, obj_type) @@ -103,9 +122,15 @@ def write( :arg obj_type: Type object. :arg obj: Object of type {obj_type}. """ + if isinstance(obj_type, list): + # print(f" size_t: {size_t}, len:{len(obj) // len(obj_type)}") _write_basic(stream, endianness, size_t, len(obj) // len(obj_type)) - if isinstance(obj_type, list) or isinstance(obj_type, tuple): + if isinstance(obj_type, np.ndarray): + # print(f"writing array: {size_t}, {obj.size}, {obj.dtype}, obj_tpye: {obj_type}") + _write_basic(stream, endianness, size_t, obj.size) + _write_basic(stream, endianness, obj_type, obj) + elif isinstance(obj_type, list) or isinstance(obj_type, tuple): for item_type, item in zip(obj_type * len(obj), obj): write(stream, endianness, size_t, item_type, item) else: diff --git a/simple_rpc/protocol.py b/simple_rpc/protocol.py index 30ed264..24ea9f3 100644 --- a/simple_rpc/protocol.py +++ b/simple_rpc/protocol.py @@ -1,7 +1,25 @@ from typing import Any, BinaryIO - +import numpy as np from .io import cast, read_byte_string - +from itertools import chain + + +dtype_map = { + 'b': np.int8, + 'B': np.uint8, + 'h': np.int16, + 'H': np.uint16, + 'i': np.int32, + 'I': np.uint32, + 'l': np.int32, + 'L': np.uint32, + 'q': np.int64, + 'Q': np.uint64, + 'f': np.float32, + 'd': np.float64, + '?': np.bool_, + 'c': np.byte # Note: 'c' in struct is a single byte; for strings, consider np.bytes_ or np.chararray. +} def _parse_type(type_str: bytes) -> Any: """Parse a type definition string. @@ -15,14 +33,20 @@ def _construct_type(tokens: tuple): for token in tokens: if token == b'[': - obj_type.append(_construct_type(tokens)) + next_token = next(tokens, None) + if next_token not in (b'[', b'(') and next_token is not None: + dtype = _get_dtype(next_token) + obj_type.append(np.array([], dtype=dtype)) + assert next(tokens, None) == b']', "Expected closing bracket" + else: + tokens = chain([next_token], tokens) + obj_type.extend(_construct_type(tokens)) elif token == b'(': obj_type.append(tuple(_construct_type(tokens))) elif token in (b')', b']'): break else: obj_type.append(token.decode()) - return obj_type obj_type = _construct_type((bytes([char]) for char in type_str)) @@ -33,6 +57,15 @@ def _construct_type(tokens: tuple): return '' return obj_type[0] +def _get_dtype(type_str: bytes) -> Any: + """Get the NumPy data type of a type definition string. + + :arg type_str: Type definition string. + + :returns: NumPy data type. + """ + return dtype_map.get(type_str, np.byte) + def _type_name(obj_type: Any) -> str: """Python type name of a C object type. diff --git a/simple_rpc/simple_rpc.py b/simple_rpc/simple_rpc.py index e42a7d3..50d2c11 100644 --- a/simple_rpc/simple_rpc.py +++ b/simple_rpc/simple_rpc.py @@ -1,5 +1,6 @@ from functools import wraps from time import sleep +import numpy as np from types import MethodType from typing import Any, TextIO @@ -86,6 +87,7 @@ def _write(self: object, obj_type: Any, obj: Any) -> None: :arg obj_type: Type of the parameter. :arg obj: Value of the parameter. """ + # print(f"write obj_type: {obj_type}") write( self._connection, self.device['endianness'], self.device['size_t'], obj_type, obj) @@ -100,6 +102,7 @@ def _read(self: object, obj_type: Any) -> Any: :returns: Return value. """ + # print(f"read obj_type: {obj_type}") return read( self._connection, self.device['endianness'], self.device['size_t'], obj_type) @@ -180,11 +183,15 @@ def call_method(self: object, name: str, *args: Any) -> Any: # Provide parameters (if any). if method['parameters']: + # print(f"method['name']: {method['name']}") + # print(f"method['parameters']: {method['parameters']}") for index, parameter in enumerate(method['parameters']): self._write(parameter['fmt'], args[index]) # Read return value (if any). - if method['return']['fmt']: + if method['return']['fmt'] or isinstance(method['return']['fmt'], np.ndarray) : + # print(f"method['name']: {method['name']}") + # print(f"method['return']['fmt']: {method['return']['fmt']}") return self._read(method['return']['fmt']) return None