Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Align method serialization in Python #1824

Merged
merged 11 commits into from
Jan 12, 2024
4 changes: 3 additions & 1 deletion bindings/python/iota_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from .external import *

from .client.client import Client, NodeIndexerAPI, ClientError
from .common import custom_encoder
from .client.client import Client, NodeIndexerAPI
from .client.common import ClientError
from .client._high_level_api import GenerateAddressesOptions, GenerateAddressOptions
from .utils import Utils
from .wallet.wallet import Wallet, WalletOptions
Expand Down
2 changes: 1 addition & 1 deletion bindings/python/iota_sdk/client/_node_core_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def post_block(self, block: Block) -> HexStr:
The block id of the posted block.
"""
return self._call_method('postBlock', {
'block': block.to_dict()
'block': block
})

def get_block(self, block_id: HexStr) -> Block:
Expand Down
28 changes: 7 additions & 21 deletions bindings/python/iota_sdk/client/_node_indexer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,8 @@ def output_ids(
The corresponding output IDs of the outputs.
"""

query_parameters_camelized = query_parameters.to_dict()

response = self._call_method('outputIds', {
'queryParameters': query_parameters_camelized,
'queryParameters': query_parameters,
})
return OutputIdsResponse(response)

Expand All @@ -276,10 +274,8 @@ def basic_output_ids(
The corresponding output IDs of the basic outputs.
"""

query_parameters_camelized = query_parameters.to_dict()

response = self._call_method('basicOutputIds', {
'queryParameters': query_parameters_camelized,
'queryParameters': query_parameters,
})
return OutputIdsResponse(response)

Expand All @@ -291,10 +287,8 @@ def account_output_ids(
The corresponding output IDs of the account outputs.
"""

query_parameters_camelized = query_parameters.to_dict()

response = self._call_method('accountOutputIds', {
'queryParameters': query_parameters_camelized,
'queryParameters': query_parameters,
})
return OutputIdsResponse(response)

Expand All @@ -316,10 +310,8 @@ def anchor_output_ids(
The corresponding output IDs of the anchor outputs.
"""

query_parameters_camelized = query_parameters.to_dict()

response = self._call_method('anchorOutputIds', {
'queryParameters': query_parameters_camelized,
'queryParameters': query_parameters,
})
return OutputIdsResponse(response)

Expand All @@ -341,10 +333,8 @@ def delegation_output_ids(
The corresponding output IDs of the delegation outputs.
"""

query_parameters_camelized = query_parameters.to_dict()

response = self._call_method('delegationOutputIds', {
'queryParameters': query_parameters_camelized,
'queryParameters': query_parameters,
})
return OutputIdsResponse(response)

Expand All @@ -366,10 +356,8 @@ def foundry_output_ids(
The corresponding output IDs of the foundry outputs.
"""

query_parameters_camelized = query_parameters.to_dict()

response = self._call_method('foundryOutputIds', {
'queryParameters': query_parameters_camelized,
'queryParameters': query_parameters,
})
return OutputIdsResponse(response)

Expand All @@ -391,10 +379,8 @@ def nft_output_ids(
The corresponding output IDs of the NFT outputs.
"""

query_parameters_camelized = query_parameters.to_dict()

response = self._call_method('nftOutputIds', {
'queryParameters': query_parameters_camelized,
'queryParameters': query_parameters,
})
return OutputIdsResponse(response)

Expand Down
97 changes: 13 additions & 84 deletions bindings/python/iota_sdk/client/client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# Copyright 2023 IOTA Stiftung
# SPDX-License-Identifier: Apache-2.0

from json import dumps, loads
from json import dumps
from datetime import timedelta
from typing import Any, Dict, List, Optional, Union
import humps

from iota_sdk.external import create_client, call_client_method, listen_mqtt
from iota_sdk.external import create_client, listen_mqtt
from iota_sdk.client._node_core_api import NodeCoreAPI
from iota_sdk.client._node_indexer_api import NodeIndexerAPI
from iota_sdk.client._high_level_api import HighLevelAPI
from iota_sdk.client._utils import ClientUtils
from iota_sdk.client.common import _call_client_method_routine
from iota_sdk.types.block.block import UnsignedBlock
from iota_sdk.types.common import HexStr, Node
from iota_sdk.types.feature import Feature
Expand All @@ -21,10 +22,6 @@
from iota_sdk.types.unlock_condition import UnlockCondition


class ClientError(Exception):
"""Represents a client error."""


class Client(NodeCoreAPI, NodeIndexerAPI, HighLevelAPI, ClientUtils):
"""Represents an IOTA client.

Expand Down Expand Up @@ -117,6 +114,7 @@ def get_remaining_nano_seconds(duration: timedelta):
else:
self.handle = client_handle

@_call_client_method_routine
def _call_method(self, name, data=None):
"""Dumps json string and calls `call_client_method()`
"""
Expand All @@ -125,20 +123,7 @@ def _call_method(self, name, data=None):
}
if data:
message['data'] = data
message = dumps(message)

# Send message to the Rust library
response = call_client_method(self.handle, message)

json_response = loads(response)

if "type" in json_response:
if json_response["type"] == "error":
raise ClientError(json_response['payload'])

if "payload" in json_response:
return json_response['payload']
return response
return message

def get_handle(self):
"""Get the client handle.
Expand Down Expand Up @@ -171,26 +156,11 @@ def build_account_output(self,
The account output as dict.
"""

unlock_conditions = [unlock_condition.to_dict()
for unlock_condition in unlock_conditions]

if features:
features = [feature.to_dict() for feature in features]
if immutable_features:
immutable_features = [immutable_feature.to_dict()
for immutable_feature in immutable_features]

if amount:
amount = str(amount)

if mana:
mana = str(mana)

return deserialize_output(self._call_method('buildAccountOutput', {
'accountId': account_id,
'unlockConditions': unlock_conditions,
'amount': amount,
'mana': mana,
'amount': None if amount is None else str(amount),
'mana': None if mana is None else str(mana),
'foundryCounter': foundry_counter,
'features': features,
'immutableFeatures': immutable_features
Expand All @@ -213,22 +183,10 @@ def build_basic_output(self,
The basic output as dict.
"""

unlock_conditions = [unlock_condition.to_dict()
for unlock_condition in unlock_conditions]

if features:
features = [feature.to_dict() for feature in features]

if amount:
amount = str(amount)

if mana:
mana = str(mana)

return deserialize_output(self._call_method('buildBasicOutput', {
'unlockConditions': unlock_conditions,
'amount': amount,
'mana': mana,
'amount': None if amount is None else str(amount),
'mana': None if mana is None else str(mana),
'features': features,
}))

Expand All @@ -253,23 +211,11 @@ def build_foundry_output(self,
The foundry output as dict.
"""

unlock_conditions = [unlock_condition.to_dict()
for unlock_condition in unlock_conditions]

if features:
features = [feature.to_dict() for feature in features]
if immutable_features:
immutable_features = [immutable_feature.to_dict()
for immutable_feature in immutable_features]

if amount:
amount = str(amount)

return deserialize_output(self._call_method('buildFoundryOutput', {
'serialNumber': serial_number,
'tokenScheme': token_scheme.to_dict(),
'tokenScheme': token_scheme,
'unlockConditions': unlock_conditions,
'amount': amount,
'amount': None if amount is None else str(amount),
'features': features,
'immutableFeatures': immutable_features
}))
Expand All @@ -295,26 +241,11 @@ def build_nft_output(self,
The NFT output as dict.
"""

unlock_conditions = [unlock_condition.to_dict()
for unlock_condition in unlock_conditions]

if features:
features = [feature.to_dict() for feature in features]
if immutable_features:
immutable_features = [immutable_feature.to_dict()
for immutable_feature in immutable_features]

if amount:
amount = str(amount)

if mana:
mana = str(mana)

return deserialize_output(self._call_method('buildNftOutput', {
'nftId': nft_id,
'unlockConditions': unlock_conditions,
'amount': amount,
'mana': mana,
'amount': None if amount is None else str(amount),
'mana': None if mana is None else str(mana),
'features': features,
'immutableFeatures': immutable_features
}))
Expand Down Expand Up @@ -358,8 +289,6 @@ def build_basic_block(
Returns:
An unsigned block.
"""
if payload is not None:
payload = payload.to_dict()
result = self._call_method('buildBasicBlock', {
'issuerId': issuer_id,
'payload': payload,
Expand Down
30 changes: 30 additions & 0 deletions bindings/python/iota_sdk/client/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2024 IOTA Stiftung
# SPDX-License-Identifier: Apache-2.0

import json
from iota_sdk import call_client_method
from iota_sdk.common import custom_encoder


def _call_client_method_routine(func):
"""The routine of dump json string and call call_client_method().
"""
def wrapper(*args, **kwargs):
message = custom_encoder(func, *args, **kwargs)
# Send message to the Rust library
response = call_client_method(args[0].handle, message)

json_response = json.loads(response)

if "type" in json_response:
if json_response["type"] == "error" or json_response["type"] == "panic":
raise ClientError(json_response['payload'])

if "payload" in json_response:
return json_response['payload']
return response
return wrapper


class ClientError(Exception):
"""A client error."""
51 changes: 51 additions & 0 deletions bindings/python/iota_sdk/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2024 IOTA Stiftung
# SPDX-License-Identifier: Apache-2.0

import json
from json import dumps, JSONEncoder
from enum import Enum
import humps


def custom_encoder(func, *args, **kwargs):
"""Converts the parameters to a JSON string and removes None values.
"""
class MyEncoder(JSONEncoder):
"""Custom encoder
"""
# pylint: disable=too-many-return-statements
Alex6323 marked this conversation as resolved.
Show resolved Hide resolved

def default(self, o):
to_dict_method = getattr(o, "to_dict", None)
if callable(to_dict_method):
return o.to_dict()
if isinstance(o, str):
return o
if isinstance(o, Enum):
return o.__dict__
if isinstance(o, dict):
return o
if hasattr(o, "__dict__"):
obj_dict = o.__dict__
items_method = getattr(self, "items", None)
if callable(items_method):
for k, v in obj_dict.items():
obj_dict[k] = dumps(v, cls=MyEncoder)
return obj_dict
return o
message = func(*args, **kwargs)
for k, v in message.items():
if not isinstance(v, str):
message[k] = json.loads(dumps(v, cls=MyEncoder))

def remove_none(obj):
if isinstance(obj, (list, tuple, set)):
return type(obj)(remove_none(x) for x in obj if x is not None)
if isinstance(obj, dict):
return type(obj)((remove_none(k), remove_none(v))
for k, v in obj.items() if k is not None and v is not None)
Alex6323 marked this conversation as resolved.
Show resolved Hide resolved
return obj

message_null_filtered = remove_none(message)
Thoralf-M marked this conversation as resolved.
Show resolved Hide resolved
message = dumps(humps.camelize(message_null_filtered))
return message
2 changes: 1 addition & 1 deletion bindings/python/iota_sdk/types/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def deserialize_output(d: Dict[str, Any]) -> Output:
Arguments:
* `d`: A dictionary that is expected to have a key called 'type' which specifies the type of the returned value.
"""
output_type = dict['type']
output_type = d['type']
if output_type == OutputType.Basic:
return BasicOutput.from_dict(d)
if output_type == OutputType.Account:
Expand Down
8 changes: 3 additions & 5 deletions bindings/python/iota_sdk/types/output_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ def __init__(self, transaction_id: HexStr, output_index: int):
raise ValueError('transaction_id must start with 0x')
# Validate that it has only valid hex characters
int(transaction_id[2:], 16)
if output_index not in range(0, 129):
raise ValueError('output_index must be a value from 0 to 128')
output_index_hex = (output_index).to_bytes(2, "little").hex()
output_index_hex = (output_index).to_bytes(4, "little").hex()
self.output_id = transaction_id + output_index_hex
self.transaction_id = transaction_id
self.output_index = output_index
Expand All @@ -43,9 +41,9 @@ def from_string(cls, output_id: HexStr):
"""
obj = cls.__new__(cls)
super(OutputId, obj).__init__()
if len(output_id) != 70:
if len(output_id) != 78:
raise ValueError(
'output_id length must be 70 characters with 0x prefix')
'output_id length must be 78 characters with 0x prefix')
if not output_id.startswith('0x'):
raise ValueError('transaction_id must start with 0x')
# Validate that it has only valid hex characters
Expand Down
Loading