Skip to content

Commit

Permalink
ISA: handle restricted addresses (#1526)
Browse files Browse the repository at this point in the history
* ISA: handle restricted addresses

* Avoid clone

* Nit

* Add 2 more restricted tests

---------

Co-authored-by: /alex/ <[email protected]>
  • Loading branch information
thibault-martinez and Alex6323 authored Oct 30, 2023
1 parent 86a839a commit 50a5139
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 2 deletions.
7 changes: 6 additions & 1 deletion sdk/src/client/api/block_builder/input_selection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ impl InputSelection {
Address::Account(account_address) => Ok(Some(Requirement::Account(*account_address.account_id()))),
Address::Nft(nft_address) => Ok(Some(Requirement::Nft(*nft_address.nft_id()))),
Address::Anchor(_) => Err(Error::UnsupportedAddressType(AnchorAddress::KIND)),
Address::Restricted(_) => Ok(None),
_ => todo!("What do we do here?"),
}
}
Expand Down Expand Up @@ -234,7 +235,11 @@ impl InputSelection {
.unwrap()
.0;

self.addresses.contains(&required_address)
if let Address::Restricted(restricted_address) = required_address {
self.addresses.contains(restricted_address.address())
} else {
self.addresses.contains(&required_address)
}
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ impl InputSelection {
Err(e) => Err(e),
}
}
Address::Restricted(restricted_address) => {
log::debug!("Forwarding {address:?} sender requirement to inner address");

self.fulfill_sender_requirement(restricted_address.address())
}
_ => Err(Error::UnsupportedAddressType(address.kind())),
}
}
Expand Down
180 changes: 179 additions & 1 deletion sdk/tests/client/input_selection/basic_outputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::str::FromStr;
use iota_sdk::{
client::api::input_selection::{Error, InputSelection, Requirement},
types::block::{
address::{AccountAddress, Address, Bech32Address, NftAddress},
address::{AccountAddress, Address, Bech32Address, NftAddress, RestrictedAddress, ToBech32Ext},
output::{AccountId, NftId},
protocol::protocol_parameters,
},
Expand Down Expand Up @@ -1389,3 +1389,181 @@ fn too_many_outputs_with_remainder() {
iota_sdk::client::api::input_selection::Error::InvalidOutputCount(129)
)
}

#[test]
fn restricted_ed25519() {
let protocol_parameters = protocol_parameters();
let address = Address::try_from_bech32(BECH32_ADDRESS_ED25519_1).unwrap();
let restricted = RestrictedAddress::new(address.clone()).unwrap();
let restricted_bech32 = restricted.to_bech32_unchecked("rms").to_string();

let inputs = build_inputs([
Basic(1_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None),
Basic(1_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None),
Basic(1_000_000, &restricted_bech32, None, None, None, None, None, None),
Basic(1_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None),
Basic(1_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None),
]);
let outputs = build_outputs([Basic(
1_000_000,
BECH32_ADDRESS_ED25519_0,
None,
None,
None,
None,
None,
None,
)]);

let selected = InputSelection::new(
inputs.clone(),
outputs.clone(),
addresses([BECH32_ADDRESS_ED25519_1]),
protocol_parameters,
)
.select()
.unwrap();

assert_eq!(selected.inputs.len(), 1);
assert_eq!(selected.inputs, [inputs[2].clone()]);
assert!(unsorted_eq(&selected.outputs, &outputs));
}

#[test]
fn restricted_nft() {
let protocol_parameters = protocol_parameters();
let nft_id_1 = NftId::from_str(NFT_ID_1).unwrap();
let nft_address = Address::from(nft_id_1);
let restricted = RestrictedAddress::new(nft_address.clone()).unwrap();
let restricted_bech32 = restricted.to_bech32_unchecked("rms").to_string();

let inputs = build_inputs([
Basic(2_000_000, &restricted_bech32, None, None, None, None, None, None),
Nft(
2_000_000,
nft_id_1,
BECH32_ADDRESS_ED25519_0,
None,
None,
None,
None,
None,
None,
),
]);
let outputs = build_outputs([Basic(
3_000_000,
BECH32_ADDRESS_ED25519_0,
None,
None,
None,
None,
None,
None,
)]);

let selected = InputSelection::new(
inputs.clone(),
outputs.clone(),
addresses([BECH32_ADDRESS_ED25519_0]),
protocol_parameters,
)
.select()
.unwrap();

assert!(unsorted_eq(&selected.inputs, &inputs));
assert_eq!(selected.outputs.len(), 2);
assert!(selected.outputs.contains(&outputs[0]));
}

#[test]
fn restricted_account() {
let protocol_parameters = protocol_parameters();
let account_id_1 = AccountId::from_str(ACCOUNT_ID_1).unwrap();
let account_address = Address::from(account_id_1);
let restricted = RestrictedAddress::new(account_address.clone()).unwrap();
let restricted_bech32 = restricted.to_bech32_unchecked("rms").to_string();

let inputs = build_inputs([
Basic(2_000_000, &restricted_bech32, None, None, None, None, None, None),
Account(
2_000_000,
account_id_1,
BECH32_ADDRESS_ED25519_0,
None,
None,
None,
None,
),
]);

let outputs = build_outputs([Basic(
3_000_000,
BECH32_ADDRESS_ED25519_0,
None,
None,
None,
None,
None,
None,
)]);

let selected = InputSelection::new(
inputs.clone(),
outputs.clone(),
addresses([BECH32_ADDRESS_ED25519_0]),
protocol_parameters,
)
.select()
.unwrap();

assert!(unsorted_eq(&selected.inputs, &inputs));
assert_eq!(selected.outputs.len(), 2);
assert!(selected.outputs.contains(&outputs[0]));
}

#[test]
fn restricted_ed25519_sender() {
let protocol_parameters = protocol_parameters();
let sender = Address::try_from_bech32(BECH32_ADDRESS_ED25519_1).unwrap();
let restricted_sender = RestrictedAddress::new(sender.clone()).unwrap();
let restricted_sender_bech32 = restricted_sender.to_bech32_unchecked("rms").to_string();

let inputs = build_inputs([
Basic(2_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None),
Basic(2_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None),
Basic(1_000_000, BECH32_ADDRESS_ED25519_1, None, None, None, None, None, None),
Basic(2_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None),
Basic(2_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None),
]);
let outputs = build_outputs([Basic(
2_000_000,
BECH32_ADDRESS_ED25519_0,
None,
Some(&restricted_sender_bech32),
None,
None,
None,
None,
)]);

let selected = InputSelection::new(
inputs.clone(),
outputs.clone(),
addresses([BECH32_ADDRESS_ED25519_0, BECH32_ADDRESS_ED25519_1]),
protocol_parameters,
)
.select()
.unwrap();

// Sender + another for amount
assert_eq!(selected.inputs.len(), 2);
assert!(
selected
.inputs
.iter()
.any(|input| *input.output.as_basic().address() == sender)
);
// Provided output + remainder
assert_eq!(selected.outputs.len(), 2);
}

0 comments on commit 50a5139

Please sign in to comment.