Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/fusion_engine_client/analysis/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2282,7 +2282,7 @@ def plot_events(self):
rows = []
system_t0_ns = self.reader.get_system_t0_ns()
max_bytes = 128
for message, message_bytes in zip(data.messages, data.messages_bytes):
for message, message_bytes in zip(data.messages, data.message_bytes):
system_time_ns = message.get_system_time_ns()
if isinstance(message, EventNotificationMessage):
event_type = message.event_type
Expand Down
39 changes: 29 additions & 10 deletions python/fusion_engine_client/analysis/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,20 @@ def __init__(self, message_type, params):
self.message_class = message_type_to_class.get(self.message_type, None)
self.params = params
self.messages = []
self.messages_bytes = []
self.message_bytes = []
self.message_index = []
self.num_messages = 0

def add_message(self, payload: MessagePayload, message_bytes: int = None, message_index: int = None):
self.messages.append(payload)
self.num_messages += 1
if message_bytes is not None:
self.messages_bytes.append(message_bytes)
self.message_bytes.append(message_bytes)
if message_index is not None:
self.message_index.append(message_index)

def to_numpy(self, remove_nan_times: bool = True, keep_messages: bool = True):
def to_numpy(self, remove_nan_times: bool = True,
keep_messages: bool = True, keep_message_bytes: bool = True, keep_message_index: bool = True):
"""!
@brief Convert the raw FusionEngine message data into numpy arrays that can be used for data analysis.

Expand All @@ -59,6 +60,13 @@ def to_numpy(self, remove_nan_times: bool = True, keep_messages: bool = True):

@param remove_nan_times If `True`, remove entries whose P1 timestamps are `NaN` (if P1 time is available for
this message type).
@param keep_messages If `True`, return the @ref MessagePayload elements in the @ref self.messages after
successful conversion if numpy conversion is supported. Otherwise, clear @ref self.messages to free
memory.
@param keep_message_bytes Similar to `keep_messages`, if `False`, free the contents of the @ref
self.message_bytes.
@param keep_message_index Similar to `keep_messages`, if `False`, free the contents of the @ref
self.message_index.
@param keep_messages If `True`, keep the original @ref MessagePayload class instances in @ref self.messages in
addition to the populated numpy arrays. Otherwise, clear @ref self.messages.
"""
Expand All @@ -78,7 +86,7 @@ def to_numpy(self, remove_nan_times: bool = True, keep_messages: bool = True):

if do_conversion:
self.__dict__.update(self.message_class.to_numpy(self.messages))
self.messages_bytes = np.array(self.messages_bytes, dtype=np.uint64)
self.message_bytes = np.array(self.message_bytes, dtype=np.uint64)
self.message_index = np.array(self.message_index, dtype=int)

if remove_nan_times and 'p1_time' in self.__dict__:
Expand Down Expand Up @@ -122,7 +130,9 @@ def to_numpy(self, remove_nan_times: bool = True, keep_messages: bool = True):

if not keep_messages:
self.messages = []
self.messages_bytes = []
if not keep_message_bytes:
self.message_bytes = []
if not keep_message_index:
self.message_index = []
else:
raise ValueError('Message type %s does not support numpy conversion.' %
Expand Down Expand Up @@ -614,7 +624,9 @@ def _read(self,

# Convert the resulting message data to numpy (if supported).
if return_numpy:
DataLoader.to_numpy(result, keep_messages=keep_messages, remove_nan_times=remove_nan_times)
DataLoader.to_numpy(result, remove_nan_times=remove_nan_times,
keep_messages=keep_messages, keep_message_bytes=return_bytes,
keep_message_index=return_message_index)

# Done.
return result
Expand Down Expand Up @@ -912,20 +924,27 @@ def _get_value(i):
return data

@classmethod
def to_numpy(cls, data: dict, keep_messages: bool = True, remove_nan_times: bool = True):
def to_numpy(cls, data: dict, remove_nan_times: bool = True,
keep_messages: bool = True, keep_message_bytes: bool = True, keep_message_index: bool = True):
"""!
@brief Convert all (supported) messages in a data dictionary to numpy for analysis.

See @ref MessageData.to_numpy().

@param data A data `dict` as returned by @ref read().
@param keep_messages If `False`, the raw data in the `messages` field will be cleared for each @ref
MessagePayload object for which numpy conversion is supported.
@param keep_messages If `True`, return the @ref MessagePayload elements in the `messages` field for each message
type object for which numpy conversion is supported. Otherwise, delete the elements to free memory after
successful numpy conversion.
@param keep_message_bytes Similar to `keep_messages`, if `False`, free the contents of the `message_bytes`
field after successful conversion of message types that support numpy conversion.
@param keep_message_index Similar to `keep_messages`, if `False`, free the contents of the `message_index`
field after successful conversion of message types that support numpy conversion.
@param remove_nan_times If `True`, remove entries whose P1 timestamps are `NaN` (if P1 time is available for
this message type).
"""
for entry in data.values():
try:
entry.to_numpy(remove_nan_times=remove_nan_times, keep_messages=keep_messages)
entry.to_numpy(remove_nan_times=remove_nan_times, keep_messages=keep_messages,
keep_message_bytes=keep_message_bytes, keep_message_index=keep_message_index)
except ValueError:
pass
Loading