diff --git a/python/fusion_engine_client/parsers/decoder.py b/python/fusion_engine_client/parsers/decoder.py index 83bf0bff..b1e588e4 100644 --- a/python/fusion_engine_client/parsers/decoder.py +++ b/python/fusion_engine_client/parsers/decoder.py @@ -261,16 +261,12 @@ def on_data(self, data: Union[bytes, int]) -> List[Union[MessageTuple, MessageWi contents = cls() try: contents.unpack(buffer=self._buffer, offset=MessageHeader.calcsize()) + _logger.debug('Decoded FusionEngine message %s.', repr(contents)) except Exception as e: # unpack() may fail if the payload length in the header differs from the length expected by the # class, the payload contains an illegal value, etc. _logger.error('Error deserializing message %s payload: %s', self._header.get_type_string(), e) - self._header = None - self._msg_len = 0 - self._buffer.pop(0) - self._bytes_processed += 1 - continue - _logger.debug('Decoded FusionEngine message %s.', repr(contents)) + contents = bytes(self._buffer[MessageHeader.calcsize():self._msg_len]) # If cls is None, we don't have a class for the message type. Return a copy of the payload bytes. else: contents = bytes(self._buffer[MessageHeader.calcsize():self._msg_len]) diff --git a/python/fusion_engine_client/utils/construct_utils.py b/python/fusion_engine_client/utils/construct_utils.py index f3b326d6..4242b3d4 100644 --- a/python/fusion_engine_client/utils/construct_utils.py +++ b/python/fusion_engine_client/utils/construct_utils.py @@ -3,7 +3,7 @@ import re from typing import Optional -from construct import Adapter, Array, Enum, Float64l, Float32l, FormatField, Struct +from construct import Adapter, Array, Container, Enum, Float64l, Float32l, FormatField, Struct import numpy as np from .enum_utils import IntEnum @@ -117,8 +117,12 @@ def make_default(self): return self.tuple_cls() def _decode(self, obj, context, path): - # skip _io member - return self.tuple_cls(*list(obj.values())[1:]) + if isinstance(obj, Container): + # skip _io member + values = list(obj.values())[1:] + else: + values = [obj] + return self.tuple_cls(*values) def _encode(self, obj, context, path): return obj._asdict() diff --git a/python/fusion_engine_client/utils/transport_utils.py b/python/fusion_engine_client/utils/transport_utils.py index 2495421a..deace9df 100644 --- a/python/fusion_engine_client/utils/transport_utils.py +++ b/python/fusion_engine_client/utils/transport_utils.py @@ -3,7 +3,19 @@ from typing import Callable, Union try: - # pySerial is optional. + # WebSocket support is optional. To use, install with: + # pip install websockets + import websockets.sync.client as ws + ws_supported = True +except ImportError: + ws_supported = False + # Dummy stand-ins for type hinting if websockets is not installed. + class ws: + class ClientConnection: pass + +try: + # Serial port support is optional. To use, install with: + # pip install pyserial import serial serial_supported = True @@ -21,7 +33,7 @@ def __send(self, data, flags=None): serial.Serial.send = __send except ImportError: serial_supported = False - # Dummy stand-in if pySerial is not installed. + # Dummy stand-in for type hinting if pySerial is not installed. class serial: class Serial: pass class SerialException(Exception): pass @@ -34,6 +46,8 @@ class SerialException(Exception): pass udp://:12345) Note: When using UDP, you must configure the device to send data to your machine. +- ws://HOSTNAME:PORT - Connect to the specified hostname (or IP address) and + port over WebSocket (e.g., ws://192.168.0.3:30300) - unix://FILENAME - Connect to the specified UNIX domain socket file - [(serial|tty)://]DEVICE:BAUD - Connect to a serial device with the specified baud rate (e.g., tty:///dev/ttyUSB0:460800 or /dev/ttyUSB0:460800) @@ -41,7 +55,7 @@ class SerialException(Exception): pass def create_transport(descriptor: str, timeout_sec: float = None, print_func: Callable = None) -> \ - Union[socket.socket, serial.Serial]: + Union[socket.socket, serial.Serial, ws.ClientConnection]: m = re.match(r'^tcp://([a-zA-Z0-9-_.]+)?(?::([0-9]+))?$', descriptor) if m: hostname = m.group(1) @@ -72,6 +86,27 @@ def create_transport(descriptor: str, timeout_sec: float = None, print_func: Cal transport.bind(('', port)) return transport + m = re.match(r'^ws://([a-zA-Z0-9-_.]+):([0-9]+)$', descriptor) + if m: + hostname = m.group(1) + ip_address = socket.gethostbyname(hostname) + port = int(m.group(2)) + + url = f'ws://{ip_address}:{port}' + + if not ws_supported: + raise RuntimeError(f'Websocket support not found. Cannot connect to {url}. ' + f'Please install (pip install websockets) and run again.') + + if print_func is not None: + print_func(f'Connecting to {url}.') + + try: + transport = ws.connect(url, open_timeout=timeout_sec) + except TimeoutError: + raise TimeoutError(f'Timed out connecting to {url}.') + return transport + m = re.match(r'^unix://([a-zA-Z0-9-_./]+)$', descriptor) if m: path = m.group(1) @@ -84,21 +119,21 @@ def create_transport(descriptor: str, timeout_sec: float = None, print_func: Cal transport.connect(path) return transport - m = re.match(r'^(?:(?:serial|tty)://)?([^:]+)(:([0-9]+))?$', descriptor) + m = re.match(r'^(?:(?:serial|tty)://)?([^:]+)(?::([0-9]+))?$', descriptor) if m: - if serial_supported: - path = m.group(1) - if m.group(2) is None: - raise ValueError('Serial baud rate not specified.') - else: - baud_rate = int(m.group(2)) - if print_func is not None: - print_func(f'Connecting to tty://{path}:{baud_rate}.') - - transport = serial.Serial(port=path, baudrate=baud_rate, timeout=timeout_sec) - return transport + path = m.group(1) + if m.group(2) is None: + raise ValueError('Serial baud rate not specified.') else: - raise RuntimeError( - "This application requires pyserial. Please install (pip install pyserial) and run again.") + baud_rate = int(m.group(2)) + + if not serial_supported: + raise RuntimeError(f'Serial port support not found. Cannot connect to tty://{path}:{baud_rate}. ' + f'Please install (pip install pyserial) and run again.') + if print_func is not None: + print_func(f'Connecting to tty://{path}:{baud_rate}.') + + transport = serial.Serial(port=path, baudrate=baud_rate, timeout=timeout_sec) + return transport raise ValueError('Unsupported transport descriptor.')