Skip to content

Commit

Permalink
added missing python bindings and fixed bug in netlist preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonKlx committed Nov 8, 2024
1 parent 7274a74 commit 8a2485e
Show file tree
Hide file tree
Showing 4 changed files with 337 additions and 102 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
/**
* @file gate_pair_label.h
* @brief This file contains classes and functions for labeling pairs of gates within a machine learning context.
*/

#pragma once

#include "hal_core/defines.h"
Expand All @@ -9,63 +14,153 @@

namespace hal
{
/* Forward declaration */
/* Forward declarations */
class Gate;
class Netlist;

namespace machine_learning
{
namespace gate_pair_label
{
/**
* @struct MultiBitInformation
* @brief Holds mappings between word labels and gates, and gates and word labels.
*
* This struct provides a bi-directional mapping between specific word pairs and their corresponding gates,
* as well as between gates and associated word pairs.
*/
struct MultiBitInformation
{
/**
* @brief Maps word pairs to corresponding gates.
*/
std::map<const std::pair<const std::string, const std::string>, std::vector<const Gate*>> word_to_gates;

/**
* @brief Maps gates to associated word pairs.
*/
std::map<const Gate*, std::vector<std::pair<const std::string, const std::string>>> gate_to_words;
};

/**
* @struct LabelContext
* @brief Provides context for gate-pair labeling within a netlist.
*
* This struct is initialized with a reference to the netlist and the gates involved in the labeling.
* It also provides access to multi-bit information for use in labeling calculations.
*/
struct LabelContext
{
/**
* @brief Deleted default constructor to enforce initialization with parameters.
*/
LabelContext() = delete;

/**
* @brief Constructs a `LabelContext` with the specified netlist and gates.
* @param[in] netlist - The netlist to which the gates belong.
* @param[in] gates - The gates to be labeled.
*/
LabelContext(const Netlist* netlist, const std::vector<Gate*>& gates) : nl(netlist), gates{gates} {};

/**
* @brief Retrieves the multi-bit information, initializing it if not already done.
* @returns A constant reference to the `MultiBitInformation` object.
*/
const MultiBitInformation& get_multi_bit_information();

/**
* @brief The netlist to which the gates belong.
*/
const Netlist* nl;

/**
* @brief The gates that are part of this labeling context.
*/
const std::vector<Gate*> gates;

/**
* @brief Optional storage for multi-bit information, initialized on demand.
*/
std::optional<MultiBitInformation> mbi;
};

/**
* @class GatePairLabel
* @brief Base class for calculating gate pairs and labels for machine learning models.
*
* This abstract class provides methods for calculating gate pairs and labels based on various criteria.
*/
class GatePairLabel
{
public:
/**
* @brief Calculate gate pairs based on the provided labeling context and netlist.
* @param[in] lc - The labeling context.
* @param[in] nl - The netlist to operate on.
* @param[in] gates - The gates to be paired.
* @returns A vector of gate pairs on success, an error otherwise.
*/
virtual std::vector<std::pair<const Gate*, const Gate*>> calculate_gate_pairs(LabelContext& lc, const Netlist* nl, const std::vector<Gate*>& gates) const = 0;
virtual std::vector<u32> calculate_label(LabelContext& lc, const Gate* g_a, const Gate* g_b) const = 0;
virtual std::vector<std::vector<u32>> calculate_labels(LabelContext& lc, const std::vector<std::pair<Gate*, Gate*>>& gate_pairs) const = 0;

/**
* @brief Calculate labels for a given gate pair.
* @param[in] lc - The labeling context.
* @param[in] g_a - The first gate in the pair.
* @param[in] g_b - The second gate in the pair.
* @returns A vector of labels on success, an error otherwise.
*/
virtual std::vector<u32> calculate_label(LabelContext& lc, const Gate* g_a, const Gate* g_b) const = 0;

/**
* @brief Calculate labels for multiple gate pairs.
* @param[in] lc - The labeling context.
* @param[in] gate_pairs - The gate pairs to label.
* @returns A vector of label vectors for each pair on success, an error otherwise.
*/
virtual std::vector<std::vector<u32>> calculate_labels(LabelContext& lc, const std::vector<std::pair<Gate*, Gate*>>& gate_pairs) const = 0;

/**
* @brief Calculate both gate pairs and their labels within the labeling context.
* @param[in] lc - The labeling context.
* @returns A pair containing gate pairs and corresponding labels on success, an error otherwise.
*/
virtual std::pair<std::vector<std::pair<const Gate*, const Gate*>>, std::vector<std::vector<u32>>> calculate_labels(LabelContext& lc) const = 0;
};

/**
* @class SharedSignalGroup
* @brief Labels gate pairs based on shared signal groups.
*/
class SharedSignalGroup : public GatePairLabel
{
public:
/**
* @brief Default constructor.
*/
SharedSignalGroup() {};

std::vector<std::pair<const Gate*, const Gate*>> calculate_gate_pairs(LabelContext& lc, const Netlist* nl, const std::vector<Gate*>& gates) const override;
std::vector<u32> calculate_label(LabelContext& lc, const Gate* g_a, const Gate* g_b) const override;
std::vector<std::vector<u32>> calculate_labels(LabelContext& lc, const std::vector<std::pair<Gate*, Gate*>>& gate_pairs) const override;

std::pair<std::vector<std::pair<const Gate*, const Gate*>>, std::vector<std::vector<u32>>> calculate_labels(LabelContext& lc) const override;
};

/**
* @class SharedConnection
* @brief Labels gate pairs based on shared connections.
*/
class SharedConnection : public GatePairLabel
{
public:
/**
* @brief Default constructor.
*/
SharedConnection() {};

std::vector<std::pair<const Gate*, const Gate*>> calculate_gate_pairs(LabelContext& lc, const Netlist* nl, const std::vector<Gate*>& gates) const override;
std::vector<u32> calculate_label(LabelContext& lc, const Gate* g_a, const Gate* g_b) const override;
std::vector<std::vector<u32>> calculate_labels(LabelContext& lc, const std::vector<std::pair<Gate*, Gate*>>& gate_pairs) const override;

std::pair<std::vector<std::pair<const Gate*, const Gate*>>, std::vector<std::vector<u32>>> calculate_labels(LabelContext& lc) const override;
};
} // namespace gate_pair_label
Expand Down
Loading

0 comments on commit 8a2485e

Please sign in to comment.