Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pe 7319 fyeo rd 4 reintegration #280

Merged
merged 9 commits into from
Dec 17, 2024
14 changes: 14 additions & 0 deletions spec/utils_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@ describe("utils", function()
end)
end)

describe("isValidUnformattedEthAddress", function()
it("should return true on a valid unformatted ETH address", function()
assert.is_true(utils.isValidUnformattedEthAddress(testEthAddress))
end)

it("should return false on a non-string value", function()
assert.is_false(utils.isValidUnformattedEthAddress(3))
end)

it("should return false on an invalid unformatted ETH address", function()
assert.is_false(utils.isValidUnformattedEthAddress("ZxFCAd0B19bB29D4674531d6f115237E16AfCE377C"))
end)
end)

describe("formatAddress", function()
it("should format ETH address to lowercase", function()
assert.is.equal(testEthAddress, utils.formatAddress(testEthAddress))
Expand Down
100 changes: 41 additions & 59 deletions src/main.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ NameRegistry = NameRegistry or {}
Epochs = Epochs or {}
LastTickedEpochIndex = LastTickedEpochIndex or -1
LastGracePeriodEntryEndTimestamp = LastGracePeriodEntryEndTimestamp or 0
LastKnownMessageTimestamp = LastKnownMessageTimestamp or 0
LastKnownMessageId = LastKnownMessageId or ""

local utils = require("utils")
local json = require("json")
Expand Down Expand Up @@ -383,8 +385,34 @@ local function addPrimaryNameRequestData(ioEvent, primaryNameResult)
addPrimaryNameCounts(ioEvent)
end

local function addEventingHandler(handlerName, pattern, handleFn, critical)
-- Sanitize inputs before every interaction
local function assertAndSanitizeInputs(msg)
assert(
msg.Timestamp and msg.Timestamp >= LastKnownMessageTimestamp,
"Timestamp must be greater than or equal to the last known message timestamp of "
.. LastKnownMessageTimestamp
.. " but was "
.. msg.Timestamp
)
assert(msg.From, "From is required")
assert(msg.Id, "Id is required")
assert(msg.Tags and type(msg.Tags) == "table", "Tags are required")

msg.Tags = utils.validateAndSanitizeInputs(msg.Tags)
msg.From = utils.formatAddress(msg.From)
msg.Timestamp = msg.Timestamp and tonumber(msg.Timestamp) or tonumber(msg.Tags.Timestamp) or nil
end

local function updateLastKnownMessage(msg)
if msg.Timestamp >= LastKnownMessageTimestamp then
LastKnownMessageTimestamp = msg.Timestamp
LastKnownMessageId = msg.Id
end
end

local function addEventingHandler(handlerName, pattern, handleFn, critical, printEvent)
critical = critical or false
printEvent = printEvent == nil and true or printEvent
Handlers.add(handlerName, pattern, function(msg)
-- add an IOEvent to the message if it doesn't exist
msg.ioEvent = msg.ioEvent or ARIOEvent(msg)
Expand All @@ -408,21 +436,24 @@ local function addEventingHandler(handlerName, pattern, handleFn, critical)
local errorWithEvent = tostring(resultOrError) .. "\n" .. errorEvent:toJSON()
error(errorWithEvent, 0) -- 0 ensures not to include this line number in the error message
end
-- isolate out prune handler here when printing
if handlerName ~= "prune" then
if printEvent then
msg.ioEvent:printEvent()
end
end)
end

-- prune state before every interaction
addEventingHandler("sanitize", function()
return "continue"
end, function(msg)
assertAndSanitizeInputs(msg)
updateLastKnownMessage(msg)
end, CRITICAL, false)

-- NOTE: THIS IS A CRITICAL HANDLER AND WILL DISCARD THE MEMORY ON ERROR
addEventingHandler("prune", function()
return "continue" -- continue is a pattern that matches every message and continues to the next handler that matches the tags
end, function(msg)
local msgTimestamp = tonumber(msg.Timestamp or msg.Tags.Timestamp)
assert(msgTimestamp, "Timestamp is required for a tick interaction")
local epochIndex = epochs.getEpochIndexForTimestamp(msgTimestamp)
local epochIndex = epochs.getEpochIndexForTimestamp(msg.Timestamp)
msg.ioEvent:addField("epochIndex", epochIndex)

local previousStateSupplies = {
Expand All @@ -435,54 +466,6 @@ end, function(msg)
lastKnownTotalSupply = token.lastKnownTotalTokenSupply(),
}

msg.From = utils.formatAddress(msg.From)
msg.Timestamp = msg.Timestamp and tonumber(msg.Timestamp) or nil

local knownAddressTags = {
"Recipient",
"Initiator",
"Target",
"Source",
"Address",
"Vault-Id",
"Process-Id",
"Observer-Address",
}

for _, tagName in ipairs(knownAddressTags) do
-- Format all incoming addresses
msg.Tags[tagName] = msg.Tags[tagName] and utils.formatAddress(msg.Tags[tagName]) or nil
end

local knownNumberTags = {
"Quantity",
"Lock-Length",
"Operator-Stake",
"Delegated-Stake",
"Withdraw-Stake",
"Timestamp",
"Years",
"Min-Delegated-Stake",
"Port",
"Extend-Length",
"Delegate-Reward-Share-Ratio",
"Epoch-Index",
"Price-Interval-Ms",
"Block-Height",
}
for _, tagName in ipairs(knownNumberTags) do
-- Format all incoming numbers
msg.Tags[tagName] = msg.Tags[tagName] and tonumber(msg.Tags[tagName]) or nil
end

local knownBooleanTags = {
"Allow-Unsafe-Addresses",
"Force-Prune",
}
for _, tagName in ipairs(knownBooleanTags) do
msg.Tags[tagName] = utils.booleanOrBooleanStringToBoolean(msg.Tags[tagName])
end

if msg.Tags["Force-Prune"] then
gar.scheduleNextGatewaysPruning(0)
gar.scheduleNextRedelegationsPruning(0)
Expand All @@ -492,9 +475,8 @@ end, function(msg)
vaults.scheduleNextVaultsPruning(0)
end

local msgId = msg.Id
print("Pruning state at timestamp: " .. msgTimestamp)
local prunedStateResult = prune.pruneState(msgTimestamp, msgId, LastGracePeriodEntryEndTimestamp)
print("Pruning state at timestamp: " .. msg.Timestamp)
local prunedStateResult = prune.pruneState(msg.Timestamp, msg.Id, LastGracePeriodEntryEndTimestamp)

if prunedStateResult then
local prunedRecordsCount = utils.lengthOfTable(prunedStateResult.prunedRecords or {})
Expand Down Expand Up @@ -573,7 +555,7 @@ end, function(msg)
end

return prunedStateResult
end, CRITICAL)
end, CRITICAL, false)

-- Write handlers
addEventingHandler(ActionMap.Transfer, utils.hasMatchingTag("Action", ActionMap.Transfer), function(msg)
Expand Down
89 changes: 86 additions & 3 deletions src/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,18 @@ function utils.isValidArweaveAddress(address)
return type(address) == "string" and #address == 43 and string.match(address, "^[%w-_]+$") ~= nil
end

--- Checks if an address is a valid Ethereum address
--- Checks if an address looks like an unformatted Ethereum address
--- @param address string The address to check
--- @return boolean isValidUnformattedEthAddress - whether the address is a valid unformatted Ethereum address
function utils.isValidUnformattedEthAddress(address)
return type(address) == "string" and #address == 42 and string.match(address, "^0x[%x]+$") ~= nil
end

--- Checks if an address is a valid Ethereum address and is in EIP-55 checksum format
--- @param address string The address to check
--- @return boolean isValidEthAddress - whether the address is a valid Ethereum address
function utils.isValidEthAddress(address)
return type(address) == "string" and #address == 42 and string.match(address, "^0x[%x]+$") ~= nil
return utils.isValidUnformattedEthAddress(address) and address == utils.formatEIP55Address(address)
end

function utils.isValidUnsafeAddress(address)
Expand Down Expand Up @@ -318,7 +325,7 @@ end
--- @param address string The address to format
--- @return string formattedAddress - the EIP-55 checksum formatted address
function utils.formatAddress(address)
if utils.isValidEthAddress(address) then
if utils.isValidUnformattedEthAddress(address) then
return utils.formatEIP55Address(address)
end
return address
Expand Down Expand Up @@ -611,6 +618,82 @@ function utils.filterDictionary(tbl, predicate)
return filtered
end

--- Sanitizes inputs to ensure they are valid strings
--- @param table table The table to sanitize
--- @return table sanitizedTable - the sanitized table
function utils.validateAndSanitizeInputs(table)
assert(type(table) == "table", "Table must be a table")
local sanitizedTable = {}
for key, value in pairs(table) do
assert(type(key) == "string", "Key must be a string")
assert(
type(value) == "string" or type(value) == "number" or type(value) == "boolean",
"Value must be a string, integer, or boolean"
)
if type(value) == "string" then
assert(#key > 0, "Key cannot be empty")
assert(#value > 0, "Value cannot be empty")
assert(not string.match(key, "^%s+$"), "Key cannot be only whitespace")
assert(not string.match(value, "^%s+$"), "Value cannot be only whitespace")
end
if type(value) == "boolean" then
assert(value == true or value == false, "Boolean value must be true or false")
end
if type(value) == "number" then
assert(utils.isInteger(value), "Number must be an integer")
end
sanitizedTable[key] = value
end

local knownAddressTags = {
"Recipient",
"Initiator",
"Target",
"Source",
"Address",
"Vault-Id",
"Process-Id",
"Observer-Address",
}

for _, tagName in ipairs(knownAddressTags) do
-- Format all incoming addresses
sanitizedTable[tagName] = sanitizedTable[tagName] and utils.formatAddress(sanitizedTable[tagName]) or nil
end

local knownNumberTags = {
"Quantity",
"Lock-Length",
"Operator-Stake",
"Delegated-Stake",
"Withdraw-Stake",
"Timestamp",
"Years",
"Min-Delegated-Stake",
"Port",
"Extend-Length",
"Delegate-Reward-Share-Ratio",
"Epoch-Index",
"Price-Interval-Ms",
"Block-Height",
}
for _, tagName in ipairs(knownNumberTags) do
-- Format all incoming numbers
sanitizedTable[tagName] = sanitizedTable[tagName] and tonumber(sanitizedTable[tagName]) or nil
end

local knownBooleanTags = {
"Allow-Unsafe-Addresses",
"Force-Prune",
}
for _, tagName in ipairs(knownBooleanTags) do
sanitizedTable[tagName] = sanitizedTable[tagName]
and utils.booleanOrBooleanStringToBoolean(sanitizedTable[tagName])
or nil
end
return sanitizedTable
end

--- @param value string|boolean
--- @return boolean
function utils.booleanOrBooleanStringToBoolean(value)
Expand Down
Loading
Loading