diff --git a/src/AttestationRegistry.sol b/src/AttestationRegistry.sol index f0a922f2..57176b7e 100644 --- a/src/AttestationRegistry.sol +++ b/src/AttestationRegistry.sol @@ -2,6 +2,7 @@ pragma solidity 0.8.21; import { OwnableUpgradeable } from "openzeppelin-contracts-upgradeable/contracts/access/OwnableUpgradeable.sol"; +import { ERC1155Upgradeable } from "openzeppelin-contracts-upgradeable/contracts/token/ERC1155/ERC1155Upgradeable.sol"; import { Attestation, AttestationPayload } from "./types/Structs.sol"; import { PortalRegistry } from "./PortalRegistry.sol"; import { SchemaRegistry } from "./SchemaRegistry.sol"; @@ -12,7 +13,7 @@ import { IRouter } from "./interface/IRouter.sol"; * @author Consensys * @notice This contract stores a registry of all attestations */ -contract AttestationRegistry is OwnableUpgradeable { +contract AttestationRegistry is OwnableUpgradeable, ERC1155Upgradeable { IRouter public router; uint16 private version; @@ -262,4 +263,37 @@ contract AttestationRegistry is OwnableUpgradeable { function getAttestationIdCounter() public view returns (uint32) { return attestationIdCounter; } + + /** + * @notice Checks if an address owns a given attestation + * @param account The address of the token holder + * @param id ID of the attestation + * @return The _owner's balance of the attestations on a given attestation ID + */ + function balanceOf(address account, uint256 id) public view override returns (uint256) { + bytes32 attestationId = bytes32(abi.encode(id)); + Attestation memory attestation = attestations[attestationId]; + if (keccak256(attestation.subject) == keccak256(abi.encode(account))) return 1; + return 0; + } + + /** + * @notice Get the balance of multiple account/attestation pairs + * @param accounts The addresses of the attestation holders + * @param ids ID of the attestations + * @return The _owner's balance of the attestation for a given address (i.e. balance for each (owner, id) pair) + */ + function balanceOfBatch( + address[] memory accounts, + uint256[] memory ids + ) public view override returns (uint256[] memory) { + if (accounts.length != ids.length) revert ArrayLengthMismatch(); + uint256[] memory result = new uint256[](accounts.length); + for (uint256 i = 0; i < accounts.length; i++) { + bytes32 attestationId = bytes32(abi.encode(ids[i])); + Attestation memory attestation = attestations[attestationId]; + if (keccak256(attestation.subject) == keccak256(abi.encode(accounts[i]))) result[i] = 1; + } + return result; + } } diff --git a/test/AttestationRegistry.t.sol b/test/AttestationRegistry.t.sol index f6d3f1b2..88ac3085 100644 --- a/test/AttestationRegistry.t.sol +++ b/test/AttestationRegistry.t.sol @@ -484,25 +484,76 @@ contract AttestationRegistryTest is Test { SchemaRegistryMock schemaRegistryMock = SchemaRegistryMock(router.getSchemaRegistry()); attestationPayload.schemaId = schemaRegistryMock.getIdFromSchemaString("schemaString"); schemaRegistryMock.createSchema("name", "description", "context", "schemaString"); - uint32 version = attestationRegistry.getAttestationIdCounter(); + uint32 attestationIdCounter = attestationRegistry.getAttestationIdCounter(); - assertEq(version, 0); + assertEq(attestationIdCounter, 0); vm.startPrank(portal); attestationRegistry.attest(attestationPayload, attester); - version = attestationRegistry.getAttestationIdCounter(); - assertEq(version, 1); + attestationIdCounter = attestationRegistry.getAttestationIdCounter(); + assertEq(attestationIdCounter, 1); attestationRegistry.attest(attestationPayload, attester); - version = attestationRegistry.getAttestationIdCounter(); - assertEq(version, 2); + attestationIdCounter = attestationRegistry.getAttestationIdCounter(); + assertEq(attestationIdCounter, 2); vm.stopPrank(); } + function test_balanceOf(AttestationPayload memory attestationPayload) public { + vm.assume(attestationPayload.subject.length != 0); + vm.assume(attestationPayload.attestationData.length != 0); + SchemaRegistryMock schemaRegistryMock = SchemaRegistryMock(router.getSchemaRegistry()); + attestationPayload.schemaId = schemaRegistryMock.getIdFromSchemaString("schemaString"); + schemaRegistryMock.createSchema("name", "description", "context", "schemaString"); + + vm.startPrank(portal); + attestationPayload.subject = abi.encode(address(1)); + attestationRegistry.attest(attestationPayload, attester); + + uint256 balance = attestationRegistry.balanceOf(address(1), 1); + assertEq(balance, 1); + } + + function test_balanceOfBatch(AttestationPayload memory attestationPayload) public { + vm.assume(attestationPayload.subject.length != 0); + vm.assume(attestationPayload.attestationData.length != 0); + SchemaRegistryMock schemaRegistryMock = SchemaRegistryMock(router.getSchemaRegistry()); + attestationPayload.schemaId = schemaRegistryMock.getIdFromSchemaString("schemaString"); + schemaRegistryMock.createSchema("name", "description", "context", "schemaString"); + + address[] memory owners = new address[](2); + owners[0] = address(1); + owners[1] = address(2); + + vm.startPrank(portal); + attestationPayload.subject = abi.encode(owners[0]); + attestationRegistry.attest(attestationPayload, attester); + + attestationPayload.subject = abi.encode(owners[1]); + attestationRegistry.attest(attestationPayload, attester); + + uint256[] memory ids = new uint256[](2); + ids[0] = 1; + ids[1] = 2; + uint256[] memory balance = attestationRegistry.balanceOfBatch(owners, ids); + assertEq(balance[0], 1); + assertEq(balance[1], 1); + } + + function test_balanceOfBatch_ArrayLengthMismatch() public { + address[] memory owners = new address[](2); + owners[0] = address(1); + owners[1] = address(2); + uint256[] memory ids = new uint256[](1); + ids[0] = 1; + vm.expectRevert(AttestationRegistry.ArrayLengthMismatch.selector); + attestationRegistry.balanceOfBatch(owners, ids); + } + function _createAttestation( AttestationPayload memory attestationPayload, uint256 id