Skip to content

Commit

Permalink
Replace AtlasRegister with Sparta registers
Browse files Browse the repository at this point in the history
  • Loading branch information
cnyce committed Dec 10, 2024
1 parent c61cce5 commit c6544d4
Show file tree
Hide file tree
Showing 13 changed files with 275 additions and 344 deletions.
10 changes: 10 additions & 0 deletions arch/RegisterDefnsJSON.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ namespace atlas
cached_initial_values_.emplace_back(item["initial_value"].GetString());
initial_value = cached_initial_values_.back().raw();
}
else
{
cached_initial_values_.emplace_back(bytes);
initial_value = cached_initial_values_.back().raw();
}

constexpr sparta::RegisterBase::Definition::HintsT hints = 0;
constexpr sparta::RegisterBase::Definition::RegDomainT regdomain = 0;
Expand Down Expand Up @@ -193,6 +198,11 @@ namespace atlas
}
}

InitialValueRef(size_t num_bytes)
{
hex_bytes_ = std::vector<unsigned char>(num_bytes, 0);
}

const unsigned char* raw() const { return hex_bytes_.data(); }

private:
Expand Down
25 changes: 25 additions & 0 deletions arch/RegisterSet.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,31 @@

namespace atlas
{
inline RegId getRegId(const sparta::Register* reg)
{
RegId reg_id;

switch (reg->getGroupNum()) {
case 0:
reg_id.reg_type = RegType::INTEGER;
break;
case 1:
reg_id.reg_type = RegType::FLOATING_POINT;
break;
case 2:
reg_id.reg_type = RegType::VECTOR;
break;
case 3:
reg_id.reg_type = RegType::CSR;
break;
default:
sparta_assert(false, "Invalid register group number");
}

reg_id.reg_num = reg->getID();
reg_id.reg_name = reg->getName();
return reg_id;
}

class RegisterSet : public sparta::RegisterSet
{
Expand Down
2 changes: 2 additions & 0 deletions arch/register_macros.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include "core/inst_handlers/inst_helpers.hpp"

#define READ_INT_REG(reg_name) \
(atlas::INT::reg_name::reg_num == 0) \
? 0 \
Expand Down
27 changes: 10 additions & 17 deletions core/AtlasInst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
namespace atlas
{
template <mavis::InstMetaData::OperandFieldID OperandFieldId>
AtlasRegisterPtr getAtlasReg(const AtlasState* state,
const mavis::OperandInfo::ElementList & operand_list)
sparta::Register* getSpartaReg(AtlasState* state,
const mavis::OperandInfo::ElementList & operand_list)
{
const auto operand = std::find_if(operand_list.begin(), operand_list.end(),
[](const mavis::OperandInfo::Element & operand)
Expand All @@ -17,29 +17,22 @@ namespace atlas
return nullptr;
}

// Determine Atlas register type
RegType reg_type = RegType::INVALID;
switch (operand->operand_type)
{
case mavis::InstMetaData::OperandTypes::WORD:
case mavis::InstMetaData::OperandTypes::LONG:
reg_type = RegType::INTEGER;
break;
return state->getIntRegister(operand->field_value);
case mavis::InstMetaData::OperandTypes::SINGLE:
case mavis::InstMetaData::OperandTypes::DOUBLE:
case mavis::InstMetaData::OperandTypes::QUAD:
reg_type = RegType::FLOATING_POINT;
break;
return state->getFpRegister(operand->field_value);
case mavis::InstMetaData::OperandTypes::VECTOR:
reg_type = RegType::VECTOR;
break;
return state->getVecRegister(operand->field_value);
case mavis::InstMetaData::OperandTypes::NONE:
sparta_assert(false, "Invalid Mavis Operand Type!");
}

// TODO AtlasRegister: Update this function to use new AtlasState methods for accessing
// registers
return state->getAtlasRegister(reg_type, operand->field_value);
return nullptr;
}

AtlasInst::AtlasInst(const mavis::OpcodeInfo::PtrType & opcode_info,
Expand All @@ -48,12 +41,12 @@ namespace atlas
opcode_info_(opcode_info),
extractor_info_(extractor_info),
opcode_size_(((getOpcode() & 0x3) != 0x3) ? 2 : 4),
rs1_(getAtlasReg<mavis::InstMetaData::OperandFieldID::RS1>(
rs1_(getSpartaReg<mavis::InstMetaData::OperandFieldID::RS1>(
state, opcode_info->getSourceOpInfoList())),
rs2_(getAtlasReg<mavis::InstMetaData::OperandFieldID::RS2>(
rs2_(getSpartaReg<mavis::InstMetaData::OperandFieldID::RS2>(
state, opcode_info->getSourceOpInfoList())),
rd_(getAtlasReg<mavis::InstMetaData::OperandFieldID::RD>(state,
opcode_info->getDestOpInfoList())),
rd_(getSpartaReg<mavis::InstMetaData::OperandFieldID::RD>(
state, opcode_info->getDestOpInfoList())),
inst_action_group_(extractor_info_->inst_action_group_)
{
}
Expand Down
20 changes: 11 additions & 9 deletions core/AtlasInst.hpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#pragma once

#include "core/AtlasExtractor.hpp"
#include "core/AtlasRegister.hpp"

#include "mavis/OpcodeInfo.h"

#include "sparta/utils/SpartaSharedPointerAllocator.hpp"

namespace sparta
{
class Register;
}

namespace atlas
{
class AtlasState;
Expand Down Expand Up @@ -52,19 +54,19 @@ namespace atlas

uint32_t getOpcodeSize() const { return opcode_size_; }

AtlasRegisterPtr & getRs1()
sparta::Register* getRs1()
{
sparta_assert(rs1_, "Operand RS1 is a nullptr! " << *this);
return rs1_;
}

AtlasRegisterPtr & getRs2()
sparta::Register* getRs2()
{
sparta_assert(rs2_, "Operand RS2 is a nullptr! " << *this);
return rs2_;
}

AtlasRegisterPtr & getRd()
sparta::Register* getRd()
{
sparta_assert(rd_, "Operand RD is a nullptr! " << *this);
return rd_;
Expand Down Expand Up @@ -94,9 +96,9 @@ namespace atlas
Addr next_pc_;

// Registers
AtlasRegisterPtr rs1_;
AtlasRegisterPtr rs2_;
AtlasRegisterPtr rd_;
sparta::Register *rs1_;
sparta::Register *rs2_;
sparta::Register *rd_;

ActionGroup inst_action_group_;

Expand Down
83 changes: 0 additions & 83 deletions core/AtlasRegister.hpp

This file was deleted.

24 changes: 0 additions & 24 deletions core/AtlasState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,6 @@ namespace atlas
csr_rset_ =
RegisterSet::create(core_node, json_dir + std::string("/reg_csr.json"), "csr_regs");

// Initialize integer registers
for (uint32_t reg_num = 0; reg_num < p->num_int_regs; ++reg_num)
{
const std::string reg_name = "x" + std::to_string(reg_num);
const RegId reg_id{RegType::INTEGER, reg_num, reg_name};
int_regs_.emplace_back(new AtlasRegister(reg_id));
}

// Initialize floating point registers
for (uint32_t reg_num = 0; reg_num < p->num_fp_regs; ++reg_num)
{
const std::string reg_name = "f" + std::to_string(reg_num);
const RegId reg_id{RegType::FLOATING_POINT, reg_num, reg_name};
fp_regs_.emplace_back(new AtlasRegister(reg_id));
}

// Initialize vector registers
for (uint32_t reg_num = 0; reg_num < p->num_vec_regs; ++reg_num)
{
const std::string reg_name = "v" + std::to_string(reg_num);
const RegId reg_id{RegType::VECTOR, reg_num, reg_name};
vec_regs_.emplace_back(new AtlasRegister(reg_id));
}

// Increment PC Action
increment_pc_action_ =
atlas::Action::createAction<&AtlasState::incrementPc_>(this, "increment pc");
Expand Down
30 changes: 0 additions & 30 deletions core/AtlasState.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once

#include "core/ActionGroup.hpp"
#include "core/AtlasRegister.hpp"
#include "core/AtlasTranslationState.hpp"
#include "core/observers/InstructionLogger.hpp"
#include "arch/RegisterSet.hpp"
Expand Down Expand Up @@ -43,9 +42,6 @@ namespace atlas
AtlasStateParameters(sparta::TreeNode* node) : sparta::ParameterSet(node) {}

PARAMETER(uint32_t, hart_id, 0, "Hart ID")
PARAMETER(uint32_t, num_int_regs, 32, "Number of integer registers")
PARAMETER(uint32_t, num_fp_regs, 32, "Number of floating point registers")
PARAMETER(uint32_t, num_vec_regs, 32, "Number of vector registers")
PARAMETER(bool, stop_sim_on_wfi, false, "Executing a WFI instruction stops simulation")
};

Expand Down Expand Up @@ -109,27 +105,6 @@ namespace atlas

Translate* getTranslateUnit() const { return translate_unit_; }

// TODO AtlasRegister: Replace with methods to access registers in the Sparta RegisterSet
// Probably best to have multiple methods for each register type (e.g.
// getIntegerRegister(uint32_t reg_num))
AtlasRegisterPtr getAtlasRegister(RegType reg_type, uint32_t reg_num) const
{
switch (reg_type)
{
case RegType::INTEGER:
return int_regs_.at(reg_num);
case RegType::FLOATING_POINT:
return fp_regs_.at(reg_num);
case RegType::VECTOR:
return vec_regs_.at(reg_num);
case RegType::CSR:
case RegType::INVALID:
sparta_assert(false, "Invalid Atlas Register Type!");
}

return nullptr;
}

atlas::RegisterSet* getIntRegisterSet() { return int_rset_.get(); }

atlas::RegisterSet* getFpRegisterSet() { return fp_rset_.get(); }
Expand Down Expand Up @@ -215,11 +190,6 @@ namespace atlas
// Translate Unit
Translate* translate_unit_ = nullptr;

// TODO AtlasRegister: Replace with Sparta RegisterSet (int, fp, vec and csrs)
std::vector<AtlasRegisterPtr> int_regs_;
std::vector<AtlasRegisterPtr> fp_regs_;
std::vector<AtlasRegisterPtr> vec_regs_;

// Register set holding all Sparta registers from all generated JSON files
std::unique_ptr<RegisterSet> int_rset_;
std::unique_ptr<RegisterSet> fp_rset_;
Expand Down
36 changes: 36 additions & 0 deletions core/inst_handlers/inst_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,39 @@ inline int64_t mulh(int64_t a, int64_t b)
uint64_t res = mulhu(a < 0 ? -a : a, b < 0 ? -b : b);
return negate ? ~res + ((uint64_t)a * (uint64_t)b == 0) : res;
}

namespace atlas
{
template <uint64_t Mask> struct RegisterBitMask
{
static uint64_t mask(const uint64_t old_val, const uint64_t new_val)
{
// The 'Mask' template parameter is a bit mask that specifies which bits are writable.
// We need to preserve the 'old_val' bits that are not writable, and replace the
// writable bits with 'new_val'.
return (old_val & ~Mask) | (new_val & Mask);
}
};

template <> struct RegisterBitMask<0>;

template <> struct RegisterBitMask<0xffffffffffffffff>
{
static uint64_t mask(const uint64_t old_val, const uint64_t new_val)
{
(void)old_val;
return new_val;
}
};

template <typename FieldT, typename Enable = void> struct CSRFields
{
static uint64_t readField(const uint64_t reg_val)
{
constexpr uint64_t num_field_bits = FieldT::high_bit - FieldT::low_bit + 1;
constexpr uint64_t mask = (1 << num_field_bits) - 1;
return (reg_val >> FieldT::low_bit) & mask;
}
};

} // namespace atlas
Loading

0 comments on commit c6544d4

Please sign in to comment.