Skip to content

Commit

Permalink
Align method serialization in Python (#1824)
Browse files Browse the repository at this point in the history
* Align method serialization in Python;
Fix TimelockUnlockConditionDto

* Fix test_output_id

* Don't use dangerous default value []

* Remove early return, update comment

* Update bindings/python/iota_sdk/common.py

Co-authored-by: /alex/ <[email protected]>

---------

Co-authored-by: /alex/ <[email protected]>
Co-authored-by: Thibault Martinez <[email protected]>
  • Loading branch information
3 people authored Jan 12, 2024
1 parent 032f179 commit 9c8956a
Show file tree
Hide file tree
Showing 12 changed files with 121 additions and 169 deletions.
4 changes: 3 additions & 1 deletion bindings/python/iota_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from .external import *

from .client.client import Client, NodeIndexerAPI, ClientError
from .common import custom_encoder
from .client.client import Client, NodeIndexerAPI
from .client.common import ClientError
from .client._high_level_api import GenerateAddressesOptions, GenerateAddressOptions
from .utils import Utils
from .wallet.wallet import Wallet, WalletOptions
Expand Down
2 changes: 1 addition & 1 deletion bindings/python/iota_sdk/client/_node_core_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def post_block(self, block: Block) -> HexStr:
The block id of the posted block.
"""
return self._call_method('postBlock', {
'block': block.to_dict()
'block': block
})

def get_block(self, block_id: HexStr) -> Block:
Expand Down
28 changes: 7 additions & 21 deletions bindings/python/iota_sdk/client/_node_indexer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,8 @@ def output_ids(
The corresponding output IDs of the outputs.
"""

query_parameters_camelized = query_parameters.to_dict()

response = self._call_method('outputIds', {
'queryParameters': query_parameters_camelized,
'queryParameters': query_parameters,
})
return OutputIdsResponse(response)

Expand All @@ -276,10 +274,8 @@ def basic_output_ids(
The corresponding output IDs of the basic outputs.
"""

query_parameters_camelized = query_parameters.to_dict()

response = self._call_method('basicOutputIds', {
'queryParameters': query_parameters_camelized,
'queryParameters': query_parameters,
})
return OutputIdsResponse(response)

Expand All @@ -291,10 +287,8 @@ def account_output_ids(
The corresponding output IDs of the account outputs.
"""

query_parameters_camelized = query_parameters.to_dict()

response = self._call_method('accountOutputIds', {
'queryParameters': query_parameters_camelized,
'queryParameters': query_parameters,
})
return OutputIdsResponse(response)

Expand All @@ -316,10 +310,8 @@ def anchor_output_ids(
The corresponding output IDs of the anchor outputs.
"""

query_parameters_camelized = query_parameters.to_dict()

response = self._call_method('anchorOutputIds', {
'queryParameters': query_parameters_camelized,
'queryParameters': query_parameters,
})
return OutputIdsResponse(response)

Expand All @@ -341,10 +333,8 @@ def delegation_output_ids(
The corresponding output IDs of the delegation outputs.
"""

query_parameters_camelized = query_parameters.to_dict()

response = self._call_method('delegationOutputIds', {
'queryParameters': query_parameters_camelized,
'queryParameters': query_parameters,
})
return OutputIdsResponse(response)

Expand All @@ -366,10 +356,8 @@ def foundry_output_ids(
The corresponding output IDs of the foundry outputs.
"""

query_parameters_camelized = query_parameters.to_dict()

response = self._call_method('foundryOutputIds', {
'queryParameters': query_parameters_camelized,
'queryParameters': query_parameters,
})
return OutputIdsResponse(response)

Expand All @@ -391,10 +379,8 @@ def nft_output_ids(
The corresponding output IDs of the NFT outputs.
"""

query_parameters_camelized = query_parameters.to_dict()

response = self._call_method('nftOutputIds', {
'queryParameters': query_parameters_camelized,
'queryParameters': query_parameters,
})
return OutputIdsResponse(response)

Expand Down
97 changes: 13 additions & 84 deletions bindings/python/iota_sdk/client/client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# Copyright 2023 IOTA Stiftung
# SPDX-License-Identifier: Apache-2.0

from json import dumps, loads
from json import dumps
from datetime import timedelta
from typing import Any, Dict, List, Optional, Union
import humps

from iota_sdk.external import create_client, call_client_method, listen_mqtt
from iota_sdk.external import create_client, listen_mqtt
from iota_sdk.client._node_core_api import NodeCoreAPI
from iota_sdk.client._node_indexer_api import NodeIndexerAPI
from iota_sdk.client._high_level_api import HighLevelAPI
from iota_sdk.client._utils import ClientUtils
from iota_sdk.client.common import _call_client_method_routine
from iota_sdk.types.block.block import UnsignedBlock
from iota_sdk.types.common import HexStr, Node
from iota_sdk.types.feature import Feature
Expand All @@ -21,10 +22,6 @@
from iota_sdk.types.unlock_condition import UnlockCondition


class ClientError(Exception):
"""Represents a client error."""


class Client(NodeCoreAPI, NodeIndexerAPI, HighLevelAPI, ClientUtils):
"""Represents an IOTA client.
Expand Down Expand Up @@ -117,6 +114,7 @@ def get_remaining_nano_seconds(duration: timedelta):
else:
self.handle = client_handle

@_call_client_method_routine
def _call_method(self, name, data=None):
"""Dumps json string and calls `call_client_method()`
"""
Expand All @@ -125,20 +123,7 @@ def _call_method(self, name, data=None):
}
if data:
message['data'] = data
message = dumps(message)

# Send message to the Rust library
response = call_client_method(self.handle, message)

json_response = loads(response)

if "type" in json_response:
if json_response["type"] == "error":
raise ClientError(json_response['payload'])

if "payload" in json_response:
return json_response['payload']
return response
return message

def get_handle(self):
"""Get the client handle.
Expand Down Expand Up @@ -171,26 +156,11 @@ def build_account_output(self,
The account output as dict.
"""

unlock_conditions = [unlock_condition.to_dict()
for unlock_condition in unlock_conditions]

if features:
features = [feature.to_dict() for feature in features]
if immutable_features:
immutable_features = [immutable_feature.to_dict()
for immutable_feature in immutable_features]

if amount:
amount = str(amount)

if mana:
mana = str(mana)

return deserialize_output(self._call_method('buildAccountOutput', {
'accountId': account_id,
'unlockConditions': unlock_conditions,
'amount': amount,
'mana': mana,
'amount': None if amount is None else str(amount),
'mana': None if mana is None else str(mana),
'foundryCounter': foundry_counter,
'features': features,
'immutableFeatures': immutable_features
Expand All @@ -213,22 +183,10 @@ def build_basic_output(self,
The basic output as dict.
"""

unlock_conditions = [unlock_condition.to_dict()
for unlock_condition in unlock_conditions]

if features:
features = [feature.to_dict() for feature in features]

if amount:
amount = str(amount)

if mana:
mana = str(mana)

return deserialize_output(self._call_method('buildBasicOutput', {
'unlockConditions': unlock_conditions,
'amount': amount,
'mana': mana,
'amount': None if amount is None else str(amount),
'mana': None if mana is None else str(mana),
'features': features,
}))

Expand All @@ -253,23 +211,11 @@ def build_foundry_output(self,
The foundry output as dict.
"""

unlock_conditions = [unlock_condition.to_dict()
for unlock_condition in unlock_conditions]

if features:
features = [feature.to_dict() for feature in features]
if immutable_features:
immutable_features = [immutable_feature.to_dict()
for immutable_feature in immutable_features]

if amount:
amount = str(amount)

return deserialize_output(self._call_method('buildFoundryOutput', {
'serialNumber': serial_number,
'tokenScheme': token_scheme.to_dict(),
'tokenScheme': token_scheme,
'unlockConditions': unlock_conditions,
'amount': amount,
'amount': None if amount is None else str(amount),
'features': features,
'immutableFeatures': immutable_features
}))
Expand All @@ -295,26 +241,11 @@ def build_nft_output(self,
The NFT output as dict.
"""

unlock_conditions = [unlock_condition.to_dict()
for unlock_condition in unlock_conditions]

if features:
features = [feature.to_dict() for feature in features]
if immutable_features:
immutable_features = [immutable_feature.to_dict()
for immutable_feature in immutable_features]

if amount:
amount = str(amount)

if mana:
mana = str(mana)

return deserialize_output(self._call_method('buildNftOutput', {
'nftId': nft_id,
'unlockConditions': unlock_conditions,
'amount': amount,
'mana': mana,
'amount': None if amount is None else str(amount),
'mana': None if mana is None else str(mana),
'features': features,
'immutableFeatures': immutable_features
}))
Expand Down Expand Up @@ -358,8 +289,6 @@ def build_basic_block(
Returns:
An unsigned block.
"""
if payload is not None:
payload = payload.to_dict()
result = self._call_method('buildBasicBlock', {
'issuerId': issuer_id,
'payload': payload,
Expand Down
30 changes: 30 additions & 0 deletions bindings/python/iota_sdk/client/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2024 IOTA Stiftung
# SPDX-License-Identifier: Apache-2.0

import json
from iota_sdk import call_client_method
from iota_sdk.common import custom_encoder


def _call_client_method_routine(func):
"""The routine of dump json string and call call_client_method().
"""
def wrapper(*args, **kwargs):
message = custom_encoder(func, *args, **kwargs)
# Send message to the Rust library
response = call_client_method(args[0].handle, message)

json_response = json.loads(response)

if "type" in json_response:
if json_response["type"] == "error" or json_response["type"] == "panic":
raise ClientError(json_response['payload'])

if "payload" in json_response:
return json_response['payload']
return response
return wrapper


class ClientError(Exception):
"""A client error."""
51 changes: 51 additions & 0 deletions bindings/python/iota_sdk/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2024 IOTA Stiftung
# SPDX-License-Identifier: Apache-2.0

import json
from json import dumps, JSONEncoder
from enum import Enum
import humps


def custom_encoder(func, *args, **kwargs):
"""Converts the parameters to a JSON string and removes None values.
"""
class MyEncoder(JSONEncoder):
"""Custom encoder
"""
# pylint: disable=too-many-return-statements

def default(self, o):
to_dict_method = getattr(o, "to_dict", None)
if callable(to_dict_method):
return o.to_dict()
if isinstance(o, str):
return o
if isinstance(o, Enum):
return o.__dict__
if isinstance(o, dict):
return o
if hasattr(o, "__dict__"):
obj_dict = o.__dict__
items_method = getattr(self, "items", None)
if callable(items_method):
for k, v in obj_dict.items():
obj_dict[k] = dumps(v, cls=MyEncoder)
return obj_dict
return o
message = func(*args, **kwargs)
for k, v in message.items():
if not isinstance(v, str):
message[k] = json.loads(dumps(v, cls=MyEncoder))

def remove_none(obj):
if isinstance(obj, (list, tuple, set)):
return type(obj)(remove_none(x) for x in obj if x is not None)
if isinstance(obj, dict):
return type(obj)((remove_none(k), remove_none(v))
for k, v in obj.items() if k is not None and v is not None)
return obj

message_null_filtered = remove_none(message)
message = dumps(humps.camelize(message_null_filtered))
return message
2 changes: 1 addition & 1 deletion bindings/python/iota_sdk/types/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def deserialize_output(d: Dict[str, Any]) -> Output:
Arguments:
* `d`: A dictionary that is expected to have a key called 'type' which specifies the type of the returned value.
"""
output_type = dict['type']
output_type = d['type']
if output_type == OutputType.Basic:
return BasicOutput.from_dict(d)
if output_type == OutputType.Account:
Expand Down
8 changes: 3 additions & 5 deletions bindings/python/iota_sdk/types/output_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ def __init__(self, transaction_id: HexStr, output_index: int):
raise ValueError('transaction_id must start with 0x')
# Validate that it has only valid hex characters
int(transaction_id[2:], 16)
if output_index not in range(0, 129):
raise ValueError('output_index must be a value from 0 to 128')
output_index_hex = (output_index).to_bytes(2, "little").hex()
output_index_hex = (output_index).to_bytes(4, "little").hex()
self.output_id = transaction_id + output_index_hex
self.transaction_id = transaction_id
self.output_index = output_index
Expand All @@ -43,9 +41,9 @@ def from_string(cls, output_id: HexStr):
"""
obj = cls.__new__(cls)
super(OutputId, obj).__init__()
if len(output_id) != 70:
if len(output_id) != 78:
raise ValueError(
'output_id length must be 70 characters with 0x prefix')
'output_id length must be 78 characters with 0x prefix')
if not output_id.startswith('0x'):
raise ValueError('transaction_id must start with 0x')
# Validate that it has only valid hex characters
Expand Down
Loading

0 comments on commit 9c8956a

Please sign in to comment.