diff --git a/tests/test_api.py b/tests/test_api.py index 4374b9d..400304e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -187,6 +187,12 @@ def test_integer(): assert isinstance(i, Integer) +def test_decimal(): + d = tomlkit.decimal("34.56") + + assert isinstance(d, String) + + def test_float(): i = tomlkit.float_("34.56") diff --git a/tests/test_items.py b/tests/test_items.py index 45aea25..2fd0566 100644 --- a/tests/test_items.py +++ b/tests/test_items.py @@ -2,6 +2,7 @@ import math import pickle +from decimal import Decimal from datetime import date from datetime import datetime from datetime import time @@ -90,6 +91,16 @@ def test_integer_unwrap(): elementary_test(item(666), int) +def test_decimal_unwrap(): + """Ensure a decimal unwraps as a string + after TOML encode. + """ + elementary_test( + item(Decimal("0.001")), + str, + ) + + def test_float_unwrap(): elementary_test(item(2.78), float) diff --git a/tomlkit/__init__.py b/tomlkit/__init__.py index acc7046..0075b97 100644 --- a/tomlkit/__init__.py +++ b/tomlkit/__init__.py @@ -5,6 +5,7 @@ from tomlkit.api import comment from tomlkit.api import date from tomlkit.api import datetime +from tomlkit.api import decimal from tomlkit.api import document from tomlkit.api import dump from tomlkit.api import dumps @@ -33,6 +34,7 @@ "comment", "date", "datetime", + "decimal", "document", "dump", "dumps", diff --git a/tomlkit/api.py b/tomlkit/api.py index 8ec5653..b655d98 100644 --- a/tomlkit/api.py +++ b/tomlkit/api.py @@ -3,6 +3,7 @@ import datetime as _datetime from collections.abc import Mapping +from decimal import Decimal from typing import IO from typing import Iterable @@ -99,6 +100,10 @@ def float_(raw: str | float) -> Float: """Create an float item from a number or string.""" return item(float(raw)) +def decimal(raw: str | Decimal | float) -> String: + """Create an string item from a ``decimal.Decimal``.""" + return item(Decimal(raw)) + def boolean(raw: str) -> Bool: """Turn `true` or `false` into a boolean item.""" diff --git a/tomlkit/items.py b/tomlkit/items.py index 683c189..933e408 100644 --- a/tomlkit/items.py +++ b/tomlkit/items.py @@ -10,12 +10,14 @@ from datetime import datetime from datetime import time from datetime import tzinfo +from decimal import Decimal from enum import Enum from typing import TYPE_CHECKING from typing import Any from typing import Collection from typing import Iterable from typing import Iterator +from typing import Optional from typing import Sequence from typing import TypeVar from typing import cast @@ -103,6 +105,20 @@ def item( ... +@overload +def item( + value: Decimal, _parent: Optional["Item"] = ..., _sort_keys: bool = ... +) -> "String": + ... + + +@overload +def item( + value: float, _parent: Optional["Item"] = ..., _sort_keys: bool = ... +) -> "Float": + ... + + @overload def item(value: Sequence, _parent: Item | None = ..., _sort_keys: bool = ...) -> Array: ... @@ -146,16 +162,21 @@ def item(value: Any, _parent: Item | None = None, _sort_keys: bool = False) -> I return Bool(value, Trivia()) elif isinstance(value, int): return Integer(value, Trivia(), str(value)) + elif isinstance(value, Decimal): + return String.from_raw(str(value)) elif isinstance(value, float): return Float(value, Trivia(), str(value)) - elif isinstance(value, dict): + elif isinstance(value, MutableMapping): table_constructor = ( InlineTable if isinstance(_parent, (Array, InlineTable)) else Table ) val = table_constructor(Container(), Trivia(), False) for k, v in sorted( value.items(), - key=lambda i: (isinstance(i[1], dict), i[0]) if _sort_keys else 1, + key=lambda i: ( + isinstance(i[1], MutableMapping), i[0]) + if _sort_keys else 1 + , ): val[k] = item(v, _parent=val, _sort_keys=_sort_keys) @@ -163,7 +184,7 @@ def item(value: Any, _parent: Item | None = None, _sort_keys: bool = False) -> I elif isinstance(value, (list, tuple)): if ( value - and all(isinstance(v, dict) for v in value) + and all(isinstance(v, MutableMapping) for v in value) and (_parent is None or isinstance(_parent, Table)) ): a = AoT([]) @@ -173,12 +194,15 @@ def item(value: Any, _parent: Item | None = None, _sort_keys: bool = False) -> I table_constructor = InlineTable for v in value: - if isinstance(v, dict): + if isinstance(v, MutableMapping): table = table_constructor(Container(), Trivia(), True) for k, _v in sorted( v.items(), - key=lambda i: (isinstance(i[1], dict), i[0] if _sort_keys else 1), + key=lambda i: ( + isinstance(i[1], MutableMapping), + i[0] if _sort_keys else 1 + ), ): i = item(_v, _parent=table, _sort_keys=_sort_keys) if isinstance(table, InlineTable): @@ -1104,7 +1128,8 @@ def __init__( ) self._index_map: dict[int, int] = {} self._value = self._group_values(value) - self._multiline = multiline + self._multiline: bool = multiline + self._multiline_indent: str = " "*4 self._reindex() def _group_values(self, value: list[Item]) -> list[_ArrayItemGroup]: @@ -1154,7 +1179,12 @@ def _iter_items(self) -> Iterator[Item]: for v in self._value: yield from v - def multiline(self, multiline: bool) -> Array: + def multiline( + self, + multiline: bool, + indent: str = " "*4, + + ) -> Array: """Change the array to display in multiline or not. :Example: @@ -1170,6 +1200,7 @@ def multiline(self, multiline: bool) -> Array: ] """ self._multiline = multiline + self._multiline_indent = indent return self @@ -1180,7 +1211,7 @@ def as_string(self) -> str: s = "[\n" s += "".join( self.trivia.indent - + " " * 4 + + self._multiline_indent + v.value.as_string() + ("," if not isinstance(v.value, Null) else "") + (v.comment.as_string() if v.comment is not None else "")