From 8404833f975287f039b88183e7e41ada7ef3a500 Mon Sep 17 00:00:00 2001 From: Blake Freer <59676067+BlakeFreer@users.noreply.github.com> Date: Mon, 15 Apr 2024 17:36:35 -0400 Subject: [PATCH] refactor(can-gen): :recycle: Improve CAN generation and build support. (#109) * refactor(can-gen): :recycle: Improve CAN generation and build support. Convert ETL from an include path to a proper target in CMake. Refactor cangen Python code to use a functional paradigm and be more "Pythonic" * undoes unnecessary indent change * adds requiremens.txt for cangen --- firmware/CMakeLists.txt | 6 - firmware/mcal/raspi/periph/can.h | 3 +- firmware/mcal/windows/periph/can.h | 4 +- firmware/projects/DemoCan/CMakeLists.txt | 5 + .../DemoCan/platforms/windows/bindings.cc | 2 +- scripts/cangen/can_generator.py | 553 ++++++++---------- scripts/cangen/main.py | 80 +-- scripts/cangen/requirements.txt | 15 + 8 files changed, 315 insertions(+), 353 deletions(-) create mode 100644 scripts/cangen/requirements.txt diff --git a/firmware/CMakeLists.txt b/firmware/CMakeLists.txt index b67c7599..eb9b5a09 100644 --- a/firmware/CMakeLists.txt +++ b/firmware/CMakeLists.txt @@ -111,12 +111,6 @@ add_subdirectory(${DIR_MCAL}) # provides "mcal-" (library) add_subdirectory(${DIR_PROJECT}) # modifies "main" (executable) add_subdirectory(${DIR_PLATFORM}) # modifies "bindings" (library) -set(DIR_ETL "${CMAKE_CURRENT_SOURCE_DIR}/third-party/etl/include") -target_include_directories(main - PUBLIC - ${DIR_ETL} -) - target_link_libraries(main PRIVATE shared) target_link_libraries(main PRIVATE bindings) diff --git a/firmware/mcal/raspi/periph/can.h b/firmware/mcal/raspi/periph/can.h index 28c05e33..85e494bd 100644 --- a/firmware/mcal/raspi/periph/can.h +++ b/firmware/mcal/raspi/periph/can.h @@ -30,7 +30,8 @@ class CanBase : public shared::periph::CanBase { void Setup() { // Create a socket - sock_ = socket(PF_CAN, SOCK_RAW, CAN_RAW) if (sock_ < 0) { + sock_ = socket(PF_CAN, SOCK_RAW, CAN_RAW); + if (sock_ < 0) { perror("Error creating socket"); return; } diff --git a/firmware/mcal/windows/periph/can.h b/firmware/mcal/windows/periph/can.h index 057b86ad..de03b6bd 100644 --- a/firmware/mcal/windows/periph/can.h +++ b/firmware/mcal/windows/periph/can.h @@ -15,7 +15,7 @@ #include "shared/comms/can/raw_can_msg.h" #include "shared/periph/can.h" -namespace mcal::periph { +namespace mcal::windows::periph { class CanBase : public shared::periph::CanBase { public: @@ -48,4 +48,4 @@ class CanBase : public shared::periph::CanBase { std::string iface_; }; -} // namespace mcal::periph +} // namespace mcal::windows::periph diff --git a/firmware/projects/DemoCan/CMakeLists.txt b/firmware/projects/DemoCan/CMakeLists.txt index 7a8acf14..0884b2f9 100644 --- a/firmware/projects/DemoCan/CMakeLists.txt +++ b/firmware/projects/DemoCan/CMakeLists.txt @@ -29,10 +29,15 @@ target_sources(main ) add_dependencies(main generated_can) + target_include_directories(main PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/inc ) +# Link ETL which is needed for generated/can/msg_registry.h +add_subdirectory(${CMAKE_SOURCE_DIR}/third-party/etl ${CMAKE_BINARY_DIR}/third-party) +target_link_libraries(main PRIVATE etl) + # Notice that we don't include any mcal/ subdirectory in this CMake file. # The master CMakeLists handles platform selection and library linking. \ No newline at end of file diff --git a/firmware/projects/DemoCan/platforms/windows/bindings.cc b/firmware/projects/DemoCan/platforms/windows/bindings.cc index 19f52f42..56749015 100644 --- a/firmware/projects/DemoCan/platforms/windows/bindings.cc +++ b/firmware/projects/DemoCan/platforms/windows/bindings.cc @@ -5,7 +5,7 @@ #include "shared/periph/can.h" namespace mcal { -using namespace periph::windows; +using namespace windows::periph; CanBase veh_can_base{"vcan0"}; } // namespace mcal diff --git a/scripts/cangen/can_generator.py b/scripts/cangen/can_generator.py index 06eb7b92..d552ab5f 100644 --- a/scripts/cangen/can_generator.py +++ b/scripts/cangen/can_generator.py @@ -3,330 +3,275 @@ Date: 2024-04-13 """ -import os import logging import math +import os import re -import cantools -import datetime -from typing import List, Tuple, Dict +import time +from typing import Dict, List, Tuple + +import numpy as np +from cantools.database import Database, Message, Signal from jinja2 import Environment logger = logging.getLogger(__name__) +EIGHT_BITS = 8 +EIGHT_BYTES = 8 +TOTAL_BITS = EIGHT_BITS * EIGHT_BYTES +MSG_REGISTRY_FILE_NAME = "msg_registry.h" +CAN_MESSAGES_FILE_NAME = "can_messages.h" -class CanGenerator: - """ - This class provides functionalities for CAN code generation based on DBC files. - """ - def __init__( - self, - msg_registry_template_path: str, - can_messages_template_path: str, - ): - """ - Initializes the CanGen object. - - Args: - dbc_files (list): List of paths to DBC files. - our_node (str): The CAN node for which to generate code. - log_level (int, optional): Logging level for the class. - Defaults to logging.DEBUG. - """ - - self.msg_registry_template_path = msg_registry_template_path - self.can_messages_template_path = can_messages_template_path - - # Constants - self.EIGHT_BITS = 8 - self.EIGHT_BYTES = 8 - self.TOTAL_BITS = self.EIGHT_BITS * self.EIGHT_BYTES - self.MSG_REGISTRY_FILE_NAME = "msg_registry.h" - self.CAN_MESSAGES_FILE_NAME = "can_messages.h" - - def __parse_dbc_files(self, dbc_files: List[str]) -> cantools.database.Database: - logger.info(f"Parsing DBC files: {dbc_files}") - can_db = cantools.database.Database() - - for dbc_file in dbc_files: - if not str(dbc_file).endswith(".dbc"): - logger.error(f"File provided is not a .dbc: {dbc_file}") - continue - if not os.path.isfile(dbc_file): - logger.error(f"File provided not found: {dbc_file}") - continue - with open(dbc_file, "r") as f: - can_db.add_dbc(f) - logger.info(f"Successfully added DBC: {dbc_file}") - - return can_db - - def __filter_messages_by_node( - self, messages: List[cantools.database.Message], node: str - ) -> Tuple[List[cantools.database.Message], List[cantools.database.Message]]: - rx_msgs, tx_msgs = [], [] - for msg in messages: - if node in msg.senders: - tx_msgs.append(msg) - if node in msg.receivers: - rx_msgs.append(msg) - - logger.info( - f"Filtered messages by node: {node}. Num msgs: rx = {len(rx_msgs)}, tx = " - f"{len(tx_msgs)}" - ) - - return rx_msgs, tx_msgs - - def __get_masks_shifts( - self, msgs: List[cantools.database.Message] - ) -> Dict[str, Dict[str, Tuple[List[int], List[int]]]]: - masks_shifts_dict = {} - - for msg in msgs: - masks_shifts_dict[msg.name] = {} - for sig in msg.signals: - if sig.byte_order == "little_endian": - logger.debug( - f"Using little-endian byte order " - f"(msg: {msg.name}, sig: {sig.name})" - ) - mask, num_trailing_zeros = self.__little_endian_mask( - sig.length, sig.start - ) - shift_amounts = self.__little_endian_shift_amounts( - num_trailing_zeros - ) - elif sig.byte_order == "big_endian": - logger.debug( - f"Using big-endian byte order " - f"(msg: {msg.name}, sig: {sig.name})" - ) - mask, num_trailing_zeros = self.__big_endian_mask( - sig.length, sig.start - ) - shift_amounts = self.__big_endian_shift_amounts(num_trailing_zeros) - else: - logger.error(f"Invalid byte order ({sig.byte_order})") - exit(1) - masks_shifts_dict[msg.name][sig.name] = ( - [int(byte) for byte in mask], - shift_amounts, - ) - - return masks_shifts_dict - - def __big_endian_mask(self, length: int, start: int) -> Tuple[bytearray, int]: - start_flipped = (start // self.EIGHT_BITS) * self.EIGHT_BITS + ( - self.EIGHT_BITS - 1 - (start % self.EIGHT_BITS) - ) # Flip bit direction (less confusing) - pos = start_flipped - end = pos + length - 1 - - mask = bytearray(self.EIGHT_BYTES) - while pos < self.TOTAL_BITS: - if pos <= end: - mask[pos // self.EIGHT_BYTES] |= 1 << ( - self.EIGHT_BITS - 1 - pos % self.EIGHT_BITS - ) - pos += 1 - else: - break +def _assert_valid_dbc(filename: str): + """Raise an error if filename is not a valid and existant dbc file.""" + + if not os.path.isfile(filename): + raise FileNotFoundError(f"Could not find a file at {filename}.") + + _, extension = os.path.splitext(filename) + if extension != ".dbc": + raise ValueError(f"{filename} is not a .dbc file.") + + logger.debug(f"{filename} is a valid dbc file.") + + +def _parse_dbc_files(dbc_files: List[str]) -> Database: + logger.info(f"Parsing DBC files: {dbc_files}") + can_db = Database() + + for dbc_file in dbc_files: + _assert_valid_dbc(dbc_file) + with open(dbc_file, "r") as f: + can_db.add_dbc(f) + logger.info(f"Successfully added DBC: {dbc_file}") + + return can_db + + +def _filter_messages_by_node( + messages: List[Message], node: str +) -> Tuple[List[Message], List[Message]]: + tx_msgs = [msg for msg in messages if node in msg.senders] + rx_msgs = [msg for msg in messages if node in msg.receivers] + + logger.info( + f"Filtered messages by node: {node}. " + f"Num msgs: rx = {len(rx_msgs)}, tx = {len(tx_msgs)}" + ) + + return rx_msgs, tx_msgs + + +def _get_mask_shift_big( + length: int, start: int +) -> Tuple[np.ndarray[int], np.ndarray[int]]: + q, r = divmod(start, EIGHT_BITS) + start_flipped = q * EIGHT_BITS + EIGHT_BITS - r - 1 + end = start_flipped + length + + idx = np.arange(64) + mask_bool = (idx >= start_flipped) & (idx < end) + mask_bytes = np.packbits(mask_bool, bitorder="big") + logger.info("Big endian mask generated.") + + num_zeros = TOTAL_BITS - end + shift_amounts = np.arange(0, 64, 8, dtype=int)[::-1] - num_zeros + + logger.info("Big endian shift amounts calculated.") + + return mask_bytes, shift_amounts + + +def _get_mask_shift_little( + length: int, start: int +) -> Tuple[np.ndarray[int], np.ndarray[int]]: + + idx = np.arange(64) + mask_bool = (idx >= start) & (idx < start + length) + mask_bytes = np.packbits(mask_bool, bitorder="little") + logger.info("Little endian mask generated.") + + num_zeros = start + shift_amounts = np.arange(0, 64, 8, dtype=int) - num_zeros + logger.info("Little endian shift amounts calculated.") + + return mask_bytes, shift_amounts + + +def _get_masks_shifts( + msgs: List[Message], +) -> Dict[str, Dict[str, Tuple[List[int], List[int]]]]: - logger.info("Big endian mask generated") + # Create a dictionary of empty dictionaries, indexed by message names + masks_shifts_dict = {msg.name: {} for msg in msgs} - return mask, (self.TOTAL_BITS - 1 - end) + for msg in msgs: + for sig in msg.signals: + logger.debug(f"Processing (msg: {msg.name}, sig: {sig.name})") - def __little_endian_mask(self, length: int, start: int) -> Tuple[bytearray, int]: - mask = bytearray(self.EIGHT_BYTES) - pos = start - end = start + length - while pos < self.TOTAL_BITS: - if pos < end: - mask[pos // self.EIGHT_BYTES] |= 1 << (pos % self.EIGHT_BITS) - pos += 1 + if sig.byte_order == "little_endian": + mask, shift = _get_mask_shift_little(sig.length, sig.start) + elif sig.byte_order == "big_endian": + mask, shift = _get_mask_shift_big(sig.length, sig.start) else: - break - - logger.info("Little endian mask generated") - - return mask, start + raise ValueError(f"Invalid byteorder {sig.byte_order}.") - def __big_endian_shift_amounts(self, num_zeroes: int) -> List[int]: - shift_amounts = [0] * self.EIGHT_BYTES - for i in range(0, self.EIGHT_BYTES): - shift_amounts[self.EIGHT_BYTES - 1 - i] = (i * self.EIGHT_BITS) - num_zeroes - - logger.info("Big endian shift amounts calculated") - - return shift_amounts + # Evaluate that function on the signal start and length + masks_shifts_dict[msg.name][sig.name] = mask, shift - def __little_endian_shift_amounts(self, num_zeroes: int) -> List[int]: - shift_amounts = [0] * self.EIGHT_BYTES - for i in range(0, self.EIGHT_BYTES): - shift_amounts[i] = i * self.EIGHT_BITS - num_zeroes - - logger.info("Little endian shift amounts calculated") + return masks_shifts_dict - return shift_amounts - - def __get_signal_types(self, can_db, allow_floating_point=True): - sig_types = {} - for message in can_db.messages: - sig_types[message.name] = {} - for signal in message.signals: - num_bits = signal.length - sign = "" +def _get_signal_datatype(signal: Signal, allow_floating_point: bool = True) -> str: + """Get the datatype of a signal.""" + num_bits = signal.length + if signal.scale > 1: + num_bits += math.ceil(math.log2(signal.scale)) - if signal.scale > 1: - num_bits = signal.length + math.ceil(math.log2(signal.scale)) - - if not signal.is_signed: - sign = "u" + is_float = signal.is_float or isinstance(signal.scale, float) + if is_float and allow_floating_point: + return "float" if num_bits <= 32 else "double" - if ( - isinstance(signal.scale, float) or signal.is_float - ) and allow_floating_point: - if num_bits <= 32: - sig_types[message.name][signal.name] = "float" - continue - else: - sig_types[message.name][signal.name] = "double" - continue - if num_bits == 1: - sig_types[message.name][signal.name] = "bool" - elif num_bits <= 8: - sig_types[message.name][signal.name] = sign + "int8_t" - continue - elif num_bits <= 16: - sig_types[message.name][signal.name] = sign + "int16_t" - continue - elif num_bits <= 32: - sig_types[message.name][signal.name] = sign + "int32_t" - continue - else: - sig_types[message.name][signal.name] = sign + "int64_t" - continue + if num_bits == 1: + return "bool" - logger.info("Signal types retrieved") - - return sig_types - - def __camel_to_snake(self, text): - """ - Converts UpperCamelCase to snake_case. - """ - - s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", text) - - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - - def __decimal_to_hex(self, decimal_value): - """ - Converts a non-negative decimal integer to a lowercase hexadecimal string. - - Args: - decimal_value: The non-negative decimal integer to convert. - - Returns: - The hexadecimal representation of the decimal value as a lowercase string. - - Raises: - TypeError: If the input is not an integer. - ValueError: If the input is negative. - """ - - if not isinstance(decimal_value, int): - raise TypeError("Input must be an integer") - if decimal_value < 0: - raise ValueError("Input must be a non-negative integer") - - hex_digits = "0123456789ABCDEF" - hex_string = "" - while decimal_value > 0: - remainder = decimal_value % 16 - hex_string = hex_digits[remainder] + hex_string - decimal_value //= 16 - - return "0x" + hex_string.lower() - - def __generate_from_jinja2_template( - self, template_path: str, output_path: str, context_dict: dict - ): - # Read the template string from a file - with open(template_path, "r") as file: - template_str = file.read() - - # Create the environment with trim_blocks and lstrip_blocks settings - env = Environment(trim_blocks=True, lstrip_blocks=True) - - # Register the camel_to_snake filter - env.filters["camel_to_snake"] = self.__camel_to_snake - env.filters["decimal_to_hex"] = self.__decimal_to_hex - - # Load the template from the string content - template = env.from_string(template_str) - - # Render the template with the context - rendered_code = template.render(**context_dict) - - # Write the rendered code to a file - os.makedirs(os.path.dirname(output_path), exist_ok=True) - with open(output_path, "w") as output_file: - output_file.write(rendered_code) - - logger.info(f"Rendered code written to '{output_path}'") - - def generate_code( - self, dbc_files: List[str], our_node: str, bus_name: str, output_dir: str - ): - """ - Parses DBC files, extracts information, and generates code using Jinja2 - templates. - - This method performs all the necessary steps to generate CAN code based on the - provided DBC files and node information. It utilizes helper functions and Jinja2 - templates (not included here) to create the final code. - """ - - logger.info("Generating code") - - can_db = self.__parse_dbc_files(dbc_files) - - signal_types = self.__get_signal_types(can_db) - temp_signal_types = self.__get_signal_types(can_db, allow_floating_point=False) - - rx_msgs, tx_msgs = self.__filter_messages_by_node(can_db.messages, our_node) - - unpack_masks_shifts = self.__get_masks_shifts(rx_msgs) - pack_masks_shifts = self.__get_masks_shifts(tx_msgs) + sign = "" if signal.is_signed else "u" + # Check all possible integer sizes + for size in [8, 16, 32, 64]: + if num_bits <= size: + return "{}int{}_t".format(sign, size) - context = { - "date": datetime.date.today().strftime("%Y-%m-%d"), - "rx_msgs": rx_msgs, - "tx_msgs": tx_msgs, - "signal_types": signal_types, - "temp_signal_types": temp_signal_types, - "unpack_info": unpack_masks_shifts, - "pack_info": pack_masks_shifts, - "bus_name": bus_name, + raise ValueError("Signal does not have a valid datatype.") + + +def _get_signal_types(can_db: Database, allow_floating_point=True): + + # Create a dictionary (indexed by message name) of dictionaries (indexed by signal + # name) corresponding to the datatype of each signal within each message. + sig_types = { + message.name: { + signal.name: _get_signal_datatype(signal, allow_floating_point) + for signal in message.signals } + for message in can_db.messages + } + + logger.info("Signal types retrieved") + + return sig_types + + +def _camel_to_snake(text): + """Converts UpperCamelCase to snake_case.""" + + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", text) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def _decimal_to_hex(decimal_value): + """ + Converts a non-negative decimal integer to a lowercase hexadecimal string. + + Raises: + ValueError: If the input is negative. + """ + + if decimal_value < 0: + raise ValueError("Input must be a non-negative integer") + + return hex(decimal_value) + + +def _generate_from_jinja2_template( + template_path: str, output_path: str, context_dict: dict +): + # Read the template string from a file + with open(template_path, "r") as file: + template_str = file.read() + + # Create the environment with trim_blocks and lstrip_blocks settings + env = Environment(trim_blocks=True, lstrip_blocks=True) + + # Register the camel_to_snake filter + env.filters["camel_to_snake"] = _camel_to_snake + env.filters["decimal_to_hex"] = _decimal_to_hex + + # Load the template from the string content + template = env.from_string(template_str) + + # Render the template with the context + rendered_code = template.render(**context_dict) + + # Write the rendered code to a file + output_dir = os.path.dirname(output_path) + os.makedirs(output_dir, exist_ok=True) + + # Create a git ignore for everything in the generated path. Ignore everything. + gitignore_path = os.path.join(output_dir, ".gitignore") + if not os.path.exists(gitignore_path): + with open(gitignore_path, "w") as f: + f.write("*") + + with open(output_path, "w") as output_file: + output_file.write(rendered_code) + + logger.info(f"Rendered code written to '{output_path}'") + + +def generate_code( + dbc_files: List[str], + our_node: str, + bus_name: str, + output_dir: str, + can_messages_template_path: str, + msg_registry_template_path: str, +): + """ + Parses DBC files, extracts information, and generates code using Jinja2 + templates. + + This method performs all the necessary steps to generate CAN code based on the + provided DBC files and node information. It utilizes helper functions and Jinja2 + templates (not included here) to create the final code. + """ - # Replace these lines with your Jinja2 template logic - logger.info("Generating code for can messages") - self.__generate_from_jinja2_template( - self.can_messages_template_path, - output_dir + "/" + self.CAN_MESSAGES_FILE_NAME, - context, - ) - - logger.info("Generating code for msg registry") - self.__generate_from_jinja2_template( - self.msg_registry_template_path, - output_dir + "/" + self.MSG_REGISTRY_FILE_NAME, - context, - ) - - logger.info("Code generation complete") + logger.info("Generating code") + + can_db = _parse_dbc_files(dbc_files) + + signal_types = _get_signal_types(can_db) + temp_signal_types = _get_signal_types(can_db, allow_floating_point=False) + + rx_msgs, tx_msgs = _filter_messages_by_node(can_db.messages, our_node) + + unpack_masks_shifts = _get_masks_shifts(rx_msgs) + pack_masks_shifts = _get_masks_shifts(tx_msgs) + + context = { + "date": time.strftime("%Y-%m-%d"), + "rx_msgs": rx_msgs, + "tx_msgs": tx_msgs, + "signal_types": signal_types, + "temp_signal_types": temp_signal_types, + "unpack_info": unpack_masks_shifts, + "pack_info": pack_masks_shifts, + "bus_name": bus_name, + } + + # Replace these lines with your Jinja2 template logic + logger.info("Generating code for can messages") + _generate_from_jinja2_template( + can_messages_template_path, + os.path.join(output_dir, CAN_MESSAGES_FILE_NAME), + context, + ) + + logger.info("Generating code for msg registry") + _generate_from_jinja2_template( + msg_registry_template_path, + os.path.join(output_dir, MSG_REGISTRY_FILE_NAME), + context, + ) + + logger.info("Code generation complete") diff --git a/scripts/cangen/main.py b/scripts/cangen/main.py index 684e6b07..6c0c846b 100644 --- a/scripts/cangen/main.py +++ b/scripts/cangen/main.py @@ -3,69 +3,71 @@ Date: 2024-04-13 """ -from can_generator import CanGenerator -import yaml import argparse import logging import os -SCRIPTS_TO_PROJECTS_DIR = "../../firmware/projects" -SCRIPTS_TO_DBCS = "../../firmware/dbcs" -SCRIPTS_TO_TEMPLATES_DIR = "templates" -CONFIG_FILE_PATH = "config.yaml" -CAN_MESSAGES_TEMPLATE_FILENAME = "can_messages.h.jinja2" -MSG_REGISTRY_FILENAME = "msg_registry.h.jinja2" +import can_generator +import yaml + +# Generate a set of directory paths, all based on this file's location +DIR_THIS_FILE = os.path.abspath(os.path.dirname(__file__)) + +DIR_FIRMWARE = os.path.join(DIR_THIS_FILE, os.pardir, os.pardir, "firmware") +DIR_PROJECTS = os.path.join(DIR_FIRMWARE, "projects") +DIR_DBCS = os.path.join(DIR_FIRMWARE, "dbcs") + +CONFIG_FILE_NAME = "config.yaml" OUTPUT_DIR = "generated/can" -logging.basicConfig(level="INFO", format="%(levelname)-8s| (%(name)s) %(message)s") +DIR_TEMPLATES = os.path.join(DIR_THIS_FILE, "templates") +CAN_MESSAGES_TEMPLATE_FILENAME = "can_messages.h.jinja2" +MSG_REGISTRY_TEMPLATE_FILENAME = "msg_registry.h.jinja2" -def read_yaml_file(yaml_file): - with open(yaml_file, "r") as file: - return yaml.safe_load(file) +logging.basicConfig(level="INFO", format="%(levelname)-8s| (%(name)s) %(message)s") -if __name__ == "__main__": + +def parse(): parser = argparse.ArgumentParser(description="DBC to C code generator") - parser.add_argument("--project", type=str, required=True, help="Name of the project") + parser.add_argument( + "--project", type=str, required=True, help="Name of the project" + ) - try: - args = parser.parse_args() - except argparse.ArgumentError: - parser.print_help() - exit(1) + # If parsing fails (ex incorrect or no arguments provided) then this exits with + # code 2. + return parser.parse_args() - project_dir = args.project - this_script_dir = os.path.dirname(os.path.abspath(__file__)) + +if __name__ == "__main__": # Change directory to the project folder - os.chdir(os.path.join(this_script_dir, SCRIPTS_TO_PROJECTS_DIR, project_dir)) + args = parse() + project_folder_name = args.project + os.chdir(os.path.join(DIR_PROJECTS, project_folder_name)) - config = read_yaml_file(CONFIG_FILE_PATH) + # Read & Parse the config file + with open(CONFIG_FILE_NAME, "r") as file: + config = yaml.safe_load(file) our_node = config["canGen"]["ourNode"].upper() bus_name = config["canGen"]["busName"].capitalize() dbc_files = config["canGen"]["dbcFiles"] - output_dir = OUTPUT_DIR - dbc_file_paths = [ - os.path.join(this_script_dir, SCRIPTS_TO_DBCS, dbc) for dbc in dbc_files - ] + dbc_file_paths = [os.path.join(DIR_DBCS, dbc) for dbc in dbc_files] - can_msg_template_path = os.path.join( - this_script_dir, SCRIPTS_TO_TEMPLATES_DIR, CAN_MESSAGES_TEMPLATE_FILENAME + can_messages_template_path = os.path.join( + DIR_TEMPLATES, CAN_MESSAGES_TEMPLATE_FILENAME ) msg_registry_template_path = os.path.join( - this_script_dir, SCRIPTS_TO_TEMPLATES_DIR, MSG_REGISTRY_FILENAME - ) - - can_generator = CanGenerator( - can_messages_template_path=can_msg_template_path, - msg_registry_template_path=msg_registry_template_path, + DIR_TEMPLATES, MSG_REGISTRY_TEMPLATE_FILENAME ) can_generator.generate_code( - dbc_files=dbc_file_paths, - our_node=our_node, - bus_name=bus_name, - output_dir=output_dir, + dbc_file_paths, + our_node, + bus_name, + OUTPUT_DIR, + can_messages_template_path, + msg_registry_template_path, ) diff --git a/scripts/cangen/requirements.txt b/scripts/cangen/requirements.txt new file mode 100644 index 00000000..15f21053 --- /dev/null +++ b/scripts/cangen/requirements.txt @@ -0,0 +1,15 @@ +argparse-addons==0.12.0 +bitstruct==8.19.0 +cantools==39.4.5 +crccheck==1.3.0 +diskcache==5.6.3 +Jinja2==3.1.3 +MarkupSafe==2.1.5 +numpy==1.26.4 +packaging==24.0 +python-can==4.3.1 +pywin32==306 +PyYAML==6.0.1 +textparser==0.24.0 +typing_extensions==4.11.0 +wrapt==1.16.0