Skip to content

Commit

Permalink
Update to PR to reduce RAM usage
Browse files Browse the repository at this point in the history
  • Loading branch information
MatinF committed Nov 4, 2024
1 parent bf59725 commit e2958c7
Showing 1 changed file with 149 additions and 132 deletions.
281 changes: 149 additions & 132 deletions can/io/mf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from hashlib import md5
from io import BufferedIOBase, BytesIO
from pathlib import Path
from typing import Any, BinaryIO, Iterator, Optional, Union, cast
from typing import Any, BinaryIO, Generator, Iterable, Optional, Union, cast

import can
from ..message import Message
from ..typechecking import StringPathLike
from ..util import channel2int, len2dlc
Expand Down Expand Up @@ -276,129 +275,153 @@ class MF4Reader(BinaryIOMessageReader):

# NOTE: Readout based on the bus logging code from asammdf GUI

def _extract_can_data_frame(self, data: Signal):
num_records = len(data)
names = set(data.samples[0].dtype.names)
required_names = {"CAN_DataFrame.ID", "CAN_DataFrame.DataBytes", "CAN_DataFrame.DataLength"}
if not required_names & names:
raise ValueError("Missing required columns")

column_timestamps = data.timestamps
column_id = data["CAN_DataFrame.ID"].tolist()
column_data = data["CAN_DataFrame.DataBytes"]
column_data_length = data["CAN_DataFrame.DataLength"].tolist()

default_object = can.Message()
class FrameIterator(object):
"""
Iterator helper class for common handling among CAN DataFrames, ErrorFrames and RemoteFrames.
"""

column_channel = data["CAN_DataFrame.BusChannel"].tolist() if "CAN_DataFrame.BusChannel" in names else [default_object.channel for _ in range(num_records)]
column_ide = data["CAN_DataFrame.IDE"].astype(bool).tolist() if "CAN_DataFrame.IDE" in names else [default_object.is_extended_id for _ in range(num_records)]
column_dir = data["CAN_DataFrame.Dir"].astype(bool).tolist() if "CAN_DataFrame.Dir" in names else [default_object.is_rx for _ in range(num_records)]
column_edl = data["CAN_DataFrame.EDL"].astype(bool).tolist() if "CAN_DataFrame.EDL" in names else [default_object.is_fd for _ in range(num_records)]
column_brs = data["CAN_DataFrame.BRS"].astype(bool).tolist() if "CAN_DataFrame.BRS" in names else [default_object.bitrate_switch for _ in range(num_records)]
column_esi = data["CAN_DataFrame.ESI"].astype(bool).tolist() if "CAN_DataFrame.ESI" in names else [default_object.error_state_indicator for _ in range(num_records)]
# Number of records to request for each asammdf call
_chunk_size = 1000

if "CAN_DataFrame.Dir" in names:
for i in range(num_records):
column_dir[i] = not column_dir[i]
def __init__(self, mdf: MDF, group_index: int, start_timestamp: float, name: str):
self._mdf = mdf
self._group_index = group_index
self._start_timestamp = start_timestamp
self._name = name

return

# Transform to python-can Messages
for i in range(num_records):
self._samples.append(
Message(
timestamp=float(column_timestamps[i]) + self._start_timestamp,
arbitration_id=column_id[i] & 0x1FFFFFFF,
is_extended_id=column_ide[i],
channel=column_channel[i],
is_rx=column_dir[i],
is_fd=column_edl[i],
bitrate_switch=column_brs[i],
error_state_indicator=column_esi[i],
data=column_data[i][:column_data_length[i]].tobytes(),
)
def _get_data(self, current_offset: int) -> asammdf.Signal:
# NOTE: asammdf suggests using select instead of get. Select seem to miss converting some channels which
# get does convert as expected.
data_raw = self._mdf.get(
self._name,
self._group_index,
record_offset=current_offset,
record_count=self._chunk_size,
raw=False
)

return data_raw

return
pass

def _extract_can_error_frame(self, data: Signal):
num_records = len(data)
names = set(data.samples[0].dtype.names)
column_timestamps = data.timestamps
default_object = can.Message()

column_id = data["CAN_ErrorFrame.ID"].tolist() if "CAN_ErrorFrame.ID" in names else [default_object.arbitration_id for _ in range(num_records)]
column_data = data["CAN_ErrorFrame.DataBytes"] if "CAN_ErrorFrame.DataBytes" in names else [default_object.data for _ in range(num_records)]
column_data_length = data["CAN_ErrorFrame.DataLength"].tolist() if "CAN_ErrorFrame.DataLength" in names else [default_object.dlc for _ in range(num_records)]
column_channel = data["CAN_ErrorFrame.BusChannel"].tolist() if "CAN_ErrorFrame.BusChannel" in names else [default_object.channel for _ in range(num_records)]
column_ide = data["CAN_ErrorFrame.IDE"].astype(bool).tolist() if "CAN_ErrorFrame.IDE" in names else [default_object.is_extended_id for _ in range(num_records)]
column_dir = data["CAN_ErrorFrame.Dir"].astype(bool).tolist() if "CAN_ErrorFrame.Dir" in names else [default_object.is_rx for _ in range(num_records)]
column_rtr = data["CAN_ErrorFrame.RTR"].astype(bool).tolist() if "CAN_ErrorFrame.RTR" in names else [default_object.is_remote_frame for _ in range(num_records)]
column_edl = data["CAN_ErrorFrame.EDL"].astype(bool).tolist() if "CAN_ErrorFrame.EDL" in names else [default_object.is_fd for _ in range(num_records)]
column_brs = data["CAN_ErrorFrame.BRS"].astype(bool).tolist() if "CAN_ErrorFrame.BRS" in names else [default_object.bitrate_switch for _ in range(num_records)]
column_esi = data["CAN_ErrorFrame.ESI"].astype(bool).tolist() if "CAN_ErrorFrame.ESI" in names else [default_object.error_state_indicator for _ in range(num_records)]

if "CAN_ErrorFrame.Dir" in names:
for i in range(num_records):
column_dir[i] = not column_dir[i]
class CANDataFrameIterator(FrameIterator):

# Transform to python-can Messages
for i in range(num_records):
message = Message(
timestamp=float(column_timestamps[i]) + self._start_timestamp,
arbitration_id=column_id[i] & 0x1FFFFFFF,
is_extended_id=column_ide[i],
is_error_frame=True,
is_remote_frame=column_rtr[i],
channel=column_channel[i],
is_rx=column_dir[i],
is_fd=column_edl[i],
bitrate_switch=column_brs[i],
error_state_indicator=column_esi[i],
dlc=column_data_length[i]
)
def __init__(self, mdf: MDF, group_index: int, start_timestamp: float):
super().__init__(mdf, group_index, start_timestamp, "CAN_DataFrame")

if column_data[i] is not None:
message.data = column_data[i][:column_data_length[i]].tobytes()
return

def __iter__(self) -> Generator[Message, None, None]:
for current_offset in range(0, self._mdf.groups[self._group_index].channel_group.cycles_nr, self._chunk_size):
data = self._get_data(current_offset)
names = data.samples[0].dtype.names

for i in range(len(data)):
data_length = int(data["CAN_DataFrame.DataLength"][i])

kv = {
"timestamp": float(data.timestamps[i]) + self._start_timestamp,
"arbitration_id": int(data["CAN_DataFrame.ID"][i]) & 0x1FFFFFFF,
"data": data["CAN_DataFrame.DataBytes"][i][:data_length].tobytes(),
}

if "CAN_DataFrame.BusChannel" in names:
kv["channel"] = int(data["CAN_DataFrame.BusChannel"][i])
if "CAN_DataFrame.Dir" in names:
kv["is_rx"] = int(data["CAN_DataFrame.Dir"][i]) == 0
if "CAN_DataFrame.IDE" in names:
kv["is_extended_id"] = bool(data["CAN_DataFrame.IDE"][i])
if "CAN_DataFrame.EDL" in names:
kv["is_fd"] = bool(data["CAN_DataFrame.EDL"][i])
if "CAN_DataFrame.BRS" in names:
kv["bitrate_switch"] = bool(data["CAN_DataFrame.BRS"][i])
if "CAN_DataFrame.ESI" in names:
kv["error_state_indicator"] = bool(data["CAN_DataFrame.ESI"][i])

yield Message(**kv)

self._samples.append(message)
return None

return
pass

def _extract_can_remote_frame(self, data: Signal):
num_records = len(data)
names = set(data.samples[0].dtype.names)
required_names = {"CAN_RemoteFrame.ID", "CAN_RemoteFrame.DLC"}
if not required_names & names:
raise ValueError("Missing required columns")
class CANErrorFrameIterator(FrameIterator):

column_timestamps = data.timestamps
column_id = data["CAN_RemoteFrame.ID"].tolist()
column_dlc = data["CAN_RemoteFrame.DataLength"].tolist()
def __init__(self, mdf: MDF, group_index: int, start_timestamp: float):
super().__init__(mdf, group_index, start_timestamp, "CAN_ErrorFrame")

return

default_object = can.Message()
def __iter__(self) -> Generator[Message, None, None]:
for current_offset in range(0, self._mdf.groups[self._group_index].channel_group.cycles_nr, self._chunk_size):
data = self._get_data(current_offset)
names = data.samples[0].dtype.names

for i in range(len(data)):
kv = {
"timestamp": float(data.timestamps[i]) + self._start_timestamp,
"is_error_frame": True,
}

if "CAN_ErrorFrame.BusChannel" in names:
kv["channel"] = int(data["CAN_ErrorFrame.BusChannel"][i])
if "CAN_ErrorFrame.Dir" in names:
kv["is_rx"] = int(data["CAN_ErrorFrame.Dir"][i]) == 0
if "CAN_ErrorFrame.ID" in names:
kv["arbitration_id"] = int(data["CAN_ErrorFrame.ID"][i]) & 0x1FFFFFFF
if "CAN_ErrorFrame.IDE" in names:
kv["is_extended_id"] = bool(data["CAN_ErrorFrame.IDE"][i])
if "CAN_ErrorFrame.EDL" in names:
kv["is_fd"] = bool(data["CAN_ErrorFrame.EDL"][i])
if "CAN_ErrorFrame.BRS" in names:
kv["bitrate_switch"] = bool(data["CAN_ErrorFrame.BRS"][i])
if "CAN_ErrorFrame.ESI" in names:
kv["error_state_indicator"] = bool(data["CAN_ErrorFrame.ESI"][i])
if "CAN_ErrorFrame.RTR" in names:
kv["is_remote_frame"] = bool(data["CAN_ErrorFrame.RTR"][i])
if "CAN_ErrorFrame.DataLength" in names and "CAN_ErrorFrame.DataBytes" in names:
data_length = int(data["CAN_ErrorFrame.DataLength"][i])
kv["data"] = data["CAN_ErrorFrame.DataBytes"][i][:data_length].tobytes()

yield Message(**kv)

return None

column_channel = data["CAN_RemoteFrame.BusChannel"].tolist() if "CAN_RemoteFrame.BusChannel" in names else [default_object.channel for _ in range(num_records)]
column_ide = data["CAN_RemoteFrame.IDE"].astype(bool).tolist() if "CAN_RemoteFrame.IDE" in names else [default_object.is_extended_id for _ in range(num_records)]
column_dir = data["CAN_RemoteFrame.Dir"].astype(bool).tolist() if "CAN_RemoteFrame.Dir" in names else [default_object.is_rx for _ in range(num_records)]
pass

class CANRemoteFrameIterator(FrameIterator):

if "CAN_RemoteFrame.Dir" in names:
for i in range(num_records):
column_dir[i] = not column_dir[i]
def __init__(self, mdf: MDF, group_index: int, start_timestamp: float):
super().__init__(mdf, group_index, start_timestamp, "CAN_RemoteFrame")

return

# Transform to python-can Messages
for i in range(num_records):
self._samples.append(
Message(
timestamp=float(column_timestamps[i]) + self._start_timestamp,
arbitration_id=column_id[i] & 0x1FFFFFFF,
is_extended_id=column_ide[i],
is_remote_frame=True,
channel=column_channel[i],
is_rx=column_dir[i],
dlc=int(column_dlc[i]),
)
)
def __iter__(self) -> Generator[Message, None, None]:
for current_offset in range(0, self._mdf.groups[self._group_index].channel_group.cycles_nr, self._chunk_size):
data = self._get_data(current_offset)
names = data.samples[0].dtype.names

for i in range(len(data)):
kv = {
"timestamp": float(data.timestamps[i]) + self._start_timestamp,
"arbitration_id": int(data["CAN_RemoteFrame.ID"][i]) & 0x1FFFFFFF,
"dlc": int(data["CAN_RemoteFrame.DLC"][i]),
"is_remote_frame": True,
}

if "CAN_RemoteFrame.BusChannel" in names:
kv["channel"] = int(data["CAN_RemoteFrame.BusChannel"][i])
if "CAN_RemoteFrame.Dir" in names:
kv["is_rx"] = int(data["CAN_RemoteFrame.Dir"][i]) == 0
if "CAN_RemoteFrame.IDE" in names:
kv["is_extended_id"] = bool(data["CAN_RemoteFrame.IDE"][i])

yield Message(**kv)

return None

return
pass

def __init__(
self,
Expand All @@ -418,17 +441,21 @@ def __init__(

super().__init__(file, mode="rb")

m: MDF4
self._mdf: MDF
if isinstance(file, BufferedIOBase):
m = MDF(BytesIO(file.read()))
self._mdf = MDF(BytesIO(file.read()))
else:
m = MDF(file)
self._mdf = MDF(file)

self._start_timestamp = self._mdf.header.start_time.timestamp()

self._start_timestamp = m.header.start_time.timestamp()
self._samples = []
def __iter__(self) -> Iterable[Message]:
import heapq

# Extract all data to a common list
for i, group in enumerate(m.groups):
# To handle messages split over multiple channel groups, create a single iterator per channel group and merge
# these iterators into a single iterator using heapq.
iterators = []
for group_index, group in enumerate(self._mdf.groups):
channel_group: ChannelGroup = group.channel_group

if not channel_group.flags & FLAG_CG_BUS_EVENT:
Expand All @@ -439,7 +466,6 @@ def __init__(
# No data, skip
continue

# Get a handle to the acquisition source
acquisition_source: Optional[Source] = channel_group.acq_source

if acquisition_source is None:
Expand All @@ -453,28 +479,19 @@ def __init__(

if acquisition_source.bus_type == Source.BUS_TYPE_CAN:
if "CAN_DataFrame" in channel_names:
# Ensure all required fields are present
data = m.get("CAN_DataFrame", group=i, raw=False)
self._extract_can_data_frame(data)
iterators.append(self.CANDataFrameIterator(self._mdf, group_index, self._start_timestamp))
elif "CAN_ErrorFrame" in channel_names:
data = m.get("CAN_ErrorFrame", group=i, raw=False)
self._extract_can_error_frame(data)
iterators.append(self.CANErrorFrameIterator(self._mdf, group_index, self._start_timestamp))
elif "CAN_RemoteFrame" in channel_names:
data = m.get("CAN_RemoteFrame", group=i, raw=False)
self._extract_can_remote_frame(data)
iterators.append(self.CANRemoteFrameIterator(self._mdf, group_index, self._start_timestamp))
else:
# Unknown bus type, skip
continue

pass

# Ensure the samples are sorted according to timestamp
self._samples.sort(key=lambda x: x.timestamp)

m.close()

def __iter__(self) -> Iterator[Message]:
return iter(self._samples)
# Create merged iterator over all the groups, using the timestamps as comparison key
return heapq.merge(*iterators, key=lambda x: x.timestamp)

def stop(self) -> None:
self._mdf.close()
self._mdf = None
super().stop()

0 comments on commit e2958c7

Please sign in to comment.