Skip to content

Commit

Permalink
Use Bytes Trie to track warm addresses (hyperledger#6069)
Browse files Browse the repository at this point in the history
* Use Bytes Trie to track warm addresses

Move from a java HashSet to a custom Trie based on bytes to store the
warm addresses, creates, and self-destructs.

This avoids needing to calculate java hashes or engage in using custom
Comparators.

Signed-off-by: Danno Ferrin <[email protected]>

* codeql scan

Signed-off-by: Danno Ferrin <[email protected]>

---------

Signed-off-by: Danno Ferrin <[email protected]>
Signed-off-by: Sally MacFarlane <[email protected]>
Co-authored-by: Sally MacFarlane <[email protected]>
  • Loading branch information
2 people authored and daniellehrner committed Oct 25, 2023
1 parent d6c7040 commit 28904e1
Show file tree
Hide file tree
Showing 6 changed files with 499 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.hyperledger.besu.ethereum.mainnet.PrivateStateUtils.KEY_TRANSACTION;
import static org.hyperledger.besu.ethereum.mainnet.PrivateStateUtils.KEY_TRANSACTION_HASH;

import org.hyperledger.besu.collections.trie.BytesTrieSet;
import org.hyperledger.besu.datatypes.AccessListEntry;
import org.hyperledger.besu.datatypes.Address;
import org.hyperledger.besu.datatypes.Wei;
Expand All @@ -43,7 +44,6 @@
import org.hyperledger.besu.evm.worldstate.WorldUpdater;

import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -318,7 +318,7 @@ public TransactionProcessingResult processTransaction(
final List<AccessListEntry> accessListEntries = transaction.getAccessList().orElse(List.of());
// we need to keep a separate hash set of addresses in case they specify no storage.
// No-storage is a common pattern, especially for Externally Owned Accounts
final Set<Address> addressList = new HashSet<>();
final Set<Address> addressList = new BytesTrieSet<>(Address.SIZE);
final Multimap<Address, Bytes32> storageList = HashMultimap.create();
int accessListStorageCount = 0;
for (final var entry : accessListEntries) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,312 @@
/*
* Copyright contributors to Hyperledger Besu
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
*
*/
package org.hyperledger.besu.collections.trie;

import java.util.AbstractSet;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Objects;

import org.apache.tuweni.bytes.Bytes;

/**
* A Bytes optimized set that stores values in a trie by byte
*
* @param <E> Type of trie
*/
public class BytesTrieSet<E extends Bytes> extends AbstractSet<E> {

record Node<E extends Bytes>(byte[] leafArray, E leafObject, Node<E>[] children) {

@Override
public boolean equals(final Object o) {
if (this == o) return true;
if (!(o instanceof Node<?> node)) return false;
return Arrays.equals(leafArray, node.leafArray)
&& Objects.equals(leafObject, node.leafObject)
&& Arrays.equals(children, node.children);
}

@Override
public int hashCode() {
int result = Objects.hash(leafObject);
result = 31 * result + Arrays.hashCode(leafArray);
result = 31 * result + Arrays.hashCode(children);
return result;
}

@Override
public String toString() {
final StringBuilder sb = new StringBuilder("Node{");
sb.append("leaf=");
if (leafObject == null) sb.append("null");
else {
sb.append('[');
System.out.println(leafObject.toHexString());
sb.append(']');
}
sb.append(", children=");
if (children == null) sb.append("null");
else {
sb.append('[');
for (int i = 0; i < children.length; ++i) {
if (children[i] == null) {
continue;
}
sb.append(i == 0 ? "" : ", ").append(i).append("=").append(children[i]);
}
sb.append(']');
}
sb.append('}');
return sb.toString();
}
}

Node<E> root;

int size = 0;
final int byteLength;

/**
* Create a BytesTrieSet with a fixed length
*
* @param byteLength length in bytes of the stored types
*/
public BytesTrieSet(final int byteLength) {
this.byteLength = byteLength;
}

static class NodeWalker<E extends Bytes> {
final Node<E> node;
int lastRead;

NodeWalker(final Node<E> node) {
this.node = node;
this.lastRead = -1;
}

NodeWalker<E> nextNodeWalker() {
if (node.children == null) {
return null;
}
while (lastRead < 255) {
lastRead++;
Node<E> child = node.children[lastRead];
if (child != null) {
return new NodeWalker<>(child);
}
}
return null;
}

E thisNode() {
return node.leafObject;
}
}

@Override
public Iterator<E> iterator() {
var result =
new Iterator<E>() {
final Deque<NodeWalker<E>> stack = new ArrayDeque<>();
E next;
E last;

@Override
public boolean hasNext() {
return next != null;
}

@Override
public E next() {
if (next == null) {
throw new NoSuchElementException();
}
last = next;
advance();
return last;
}

@Override
public void remove() {
BytesTrieSet.this.remove(last);
}

void advance() {
while (!stack.isEmpty()) {
NodeWalker<E> thisStep = stack.peek();
var nextStep = thisStep.nextNodeWalker();
if (nextStep == null) {
stack.pop();
if (thisStep.thisNode() != null) {
next = thisStep.thisNode();
return;
}
} else {
stack.push(nextStep);
}
}
next = null;
}
};
if (root != null) {
result.stack.add(new NodeWalker<>(root));
}
result.advance();
return result;
}

@Override
public int size() {
return size;
}

@Override
public boolean contains(final Object o) {
if (!(o instanceof Bytes bytes)) {
throw new IllegalArgumentException(
"Expected Bytes, got " + (o == null ? "null" : o.getClass().getName()));
}
byte[] array = bytes.toArrayUnsafe();
if (array.length != byteLength) {
throw new IllegalArgumentException(
"Byte array is size " + array.length + " but set is size " + byteLength);
}
if (root == null) {
return false;
}
int level = 0;
Node<E> current = root;
while (current != null) {
if (current.leafObject != null) {
return Arrays.compare(current.leafArray, array) == 0;
}
current = current.children[array[level] & 0xff];
level++;
}
return false;
}

@Override
public boolean remove(final Object o) {
// Two base cases, size==0 and size==1;
if (!(o instanceof Bytes bytes)) {
throw new IllegalArgumentException(
"Expected Bytes, got " + (o == null ? "null" : o.getClass().getName()));
}
byte[] array = bytes.toArrayUnsafe();
if (array.length != byteLength) {
throw new IllegalArgumentException(
"Byte array is size " + array.length + " but set is size " + byteLength);
}
// Two base cases, size==0 and size==1;
if (root == null) {
// size==0 is easy, empty
return false;
}
if (root.leafObject != null) {
// size==1 just check and possibly remove the root
if (Arrays.compare(array, root.leafArray) == 0) {
root = null;
size--;
return true;
} else {
return false;
}
}
int level = 0;
Node<E> current = root;
do {
int index = array[level] & 0xff;
Node<E> next = current.children[index];
if (next == null) {
return false;
}
if (next.leafObject != null) {
if (Arrays.compare(array, next.leafArray) == 0) {
// TODO there is no cleanup of empty branches
current.children[index] = null;
size--;
return true;
} else {
return false;
}
}
current = next;

level++;
} while (true);
}

@SuppressWarnings({"unchecked", "rawtypes"})
@Override
public boolean add(final E bytes) {
byte[] array = bytes.toArrayUnsafe();
if (array.length != byteLength) {
throw new IllegalArgumentException(
"Byte array is size " + array.length + " but set is size " + byteLength);
}
// Two base cases, size==0 and size==1;
if (root == null) {
// size==0 is easy, just add
root = new Node<>(array, bytes, null);
size++;
return true;
}
if (root.leafObject != null) {
// size==1 first check then if no match make it look like n>1
if (Arrays.compare(array, root.leafArray) == 0) {
return false;
}
Node<E> oldRoot = root;
root = new Node<>(null, null, new Node[256]);
root.children[oldRoot.leafArray[0] & 0xff] = oldRoot;
}
int level = 0;
Node<E> current = root;
do {
int index = array[level] & 0xff;
Node<E> next = current.children[index];
if (next == null) {
next = new Node<>(array, bytes, null);
current.children[index] = next;
size++;
return true;
}
if (next.leafObject != null) {
if (Arrays.compare(array, next.leafArray) == 0) {
return false;
}
Node<E> newLeaf = new Node<>(null, null, new Node[256]);
newLeaf.children[next.leafArray[level + 1] & 0xff] = next;
current.children[index] = newLeaf;
next = newLeaf;
}
level++;

current = next;

} while (true);
}

@Override
public void clear() {
root = null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import static com.google.common.base.Preconditions.checkNotNull;

import org.hyperledger.besu.collections.trie.BytesTrieSet;
import org.hyperledger.besu.datatypes.Address;
import org.hyperledger.besu.datatypes.Hash;
import org.hyperledger.besu.datatypes.VersionedHash;
Expand All @@ -41,7 +42,6 @@
import java.math.BigInteger;
import java.util.Collection;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
Expand Down Expand Up @@ -80,7 +80,7 @@ public class EVMExecutor {
List.of(MaxCodeSizeRule.of(0x6000), PrefixCodeRule.of());
private long initialNonce = 1;
private Collection<Address> forceCommitAddresses = List.of(Address.fromHexString("0x03"));
private Set<Address> accessListWarmAddresses = new HashSet<>();
private Set<Address> accessListWarmAddresses = new BytesTrieSet<>(Address.SIZE);
private Multimap<Address, Bytes32> accessListWarmStorage = HashMultimap.create();
private MessageCallProcessor messageCallProcessor = null;
private ContractCreationProcessor contractCreationProcessor = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import static com.google.common.base.Preconditions.checkState;
import static java.util.Collections.emptySet;

import org.hyperledger.besu.collections.trie.BytesTrieSet;
import org.hyperledger.besu.collections.undo.UndoSet;
import org.hyperledger.besu.collections.undo.UndoTable;
import org.hyperledger.besu.datatypes.Address;
Expand All @@ -38,7 +39,6 @@
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -1707,7 +1707,7 @@ public MessageFrame build() {
new TxValues(
blockHashLookup,
maxStackSize,
UndoSet.of(new HashSet<>()),
UndoSet.of(new BytesTrieSet<>(Address.SIZE)),
UndoTable.of(HashBasedTable.create()),
originator,
gasPrice,
Expand All @@ -1717,8 +1717,8 @@ public MessageFrame build() {
miningBeneficiary,
versionedHashes,
UndoTable.of(HashBasedTable.create()),
UndoSet.of(new HashSet<>()),
UndoSet.of(new HashSet<>()));
UndoSet.of(new BytesTrieSet<>(Address.SIZE)),
UndoSet.of(new BytesTrieSet<>(Address.SIZE)));
updater = worldUpdater;
newStatic = isStatic;
} else {
Expand Down
Loading

0 comments on commit 28904e1

Please sign in to comment.