From f46050552a90cbf6f8bc6c384b6c83f3d2a08970 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Thu, 22 Feb 2024 02:39:07 -0800 Subject: [PATCH 1/2] samples -> num_samples. --- simulation/core/sim_dataset.py | 2 +- streaming/dataset.py | 2 +- streaming/format/base/shard/base.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/simulation/core/sim_dataset.py b/simulation/core/sim_dataset.py index b782c45bf..338c4df2b 100644 --- a/simulation/core/sim_dataset.py +++ b/simulation/core/sim_dataset.py @@ -363,7 +363,7 @@ def __init__( shutil.rmtree(local_foldernames[stream_idx]) # Build the shard index (for partitioning and mapping samples to shards). - self.samples_per_shard = np.array([shard.samples for shard in self.shards], np.int64) + self.samples_per_shard = np.array([shard.num_samples for shard in self.shards], np.int64) self.sample_offset_per_shard = self.samples_per_shard.cumsum() - self.samples_per_shard self.spanner = SimulationSpanner(self.samples_per_shard) diff --git a/streaming/dataset.py b/streaming/dataset.py index 467824bd0..cde26b1e9 100644 --- a/streaming/dataset.py +++ b/streaming/dataset.py @@ -516,7 +516,7 @@ def __init__( self.cache_limit = None # Build the shard index (for partitioning and mapping samples to shards). - self.samples_per_shard = np.array([shard.samples for shard in self.shards], np.int64) + self.samples_per_shard = np.array([shard.num_samples for shard in self.shards], np.int64) self.sample_offset_per_shard = self.samples_per_shard.cumsum() - self.samples_per_shard self.spanner = Spanner(self.samples_per_shard) diff --git a/streaming/format/base/shard/base.py b/streaming/format/base/shard/base.py index 03ddf5e43..c4366fae2 100644 --- a/streaming/format/base/shard/base.py +++ b/streaming/format/base/shard/base.py @@ -37,7 +37,7 @@ def __init__( ) -> None: self.conf = conf self.stream = stream - self.num_samples = self.samples = num_samples + self.num_samples = num_samples self.files = files @classmethod From 47d5d88280bc4597897c46903f9019da41fb11ec Mon Sep 17 00:00:00 2001 From: James Knighton Date: Thu, 22 Feb 2024 10:52:12 -0800 Subject: [PATCH 2/2] Do logical column/field types -> implement allow_schema_mismatch. --- streaming/dataset.py | 17 +++ streaming/format/base/shard/base.py | 15 +++ streaming/format/base/shard/dual_row.py | 5 +- streaming/format/base/shard/mono_row.py | 5 +- streaming/format/base/type.py | 167 ++++++++++++++++++++++++ streaming/format/jsonl/encodings.py | 32 ++++- streaming/format/jsonl/shard.py | 22 +++- streaming/format/mds/encodings.py | 97 +++++++++++++- streaming/format/mds/shard.py | 7 +- streaming/format/xsv/encodings.py | 32 ++++- streaming/format/xsv/shard.py | 6 +- 11 files changed, 390 insertions(+), 15 deletions(-) create mode 100644 streaming/format/base/type.py diff --git a/streaming/dataset.py b/streaming/dataset.py index cde26b1e9..90495761f 100644 --- a/streaming/dataset.py +++ b/streaming/dataset.py @@ -8,6 +8,7 @@ import os import sys import warnings +from collections import Counter, defaultdict from concurrent.futures import ThreadPoolExecutor, wait from concurrent.futures._base import Future from enum import IntEnum @@ -479,6 +480,22 @@ def __init__( self.stream_per_shard = np.array(stream_per_shard, np.int64) self.num_shards = len(self.shards) + # Maybe check all schemas match. + if not allow_schema_mismatch: + sigs = [shard.get_logical_type_signature() for shard in self.shards] + sig2count = Counter(sigs) + if len(sig2count) != 1: + count2sigs = defaultdict(list) + for sig, count in sorted(sig2count.items()): + count2sigs[count].append(sig) + parts = [] + for count, sigs in sorted(count2sigs.items()): + for sig in sorted(sigs): + part = f'{count} {sig}' + parts.append(part) + text = ', '.join(parts) + raise ValueError(f'Expected the columns of every shard to match, but got: {text}.') + # Wait for the pool workers (stream index download processes) to finish. if pool is not None: pool.join() diff --git a/streaming/format/base/shard/base.py b/streaming/format/base/shard/base.py index c4366fae2..3f767e0cf 100644 --- a/streaming/format/base/shard/base.py +++ b/streaming/format/base/shard/base.py @@ -10,6 +10,7 @@ from streaming.array import Array from streaming.format.base.file import ShardFile +from streaming.format.base.type import Type as LogicalType from streaming.stream.dir_conf import StreamDirConf __all__ = ['Shard'] @@ -33,11 +34,13 @@ def __init__( conf: Optional[Any] = None, stream: StreamDirConf, num_samples: int, + logical_columns: Dict[str, LogicalType], files: List[ShardFile], ) -> None: self.conf = conf self.stream = stream self.num_samples = num_samples + self.logical_columns = logical_columns self.files = files @classmethod @@ -77,6 +80,18 @@ def set_stream(self, stream: StreamDirConf) -> None: for file in self.files: file.set_stream(stream) + def get_logical_type_signature(self) -> str: + """Get a string encoding our logical column info. + + Returns: + str: Logical type signature. + """ + parts = [] + for name, logical_type in sorted(self.logical_columns.items()): + sig = logical_type.get_signature() + parts += f'{name}:{sig}', + return ','.join(parts) + def inventory_local(self, listing: Set[str]) -> Optional[int]: """Normalize what files/phases of files are present to a coherent state. diff --git a/streaming/format/base/shard/dual_row.py b/streaming/format/base/shard/dual_row.py index 9efeabc51..e6bdb2fbb 100644 --- a/streaming/format/base/shard/dual_row.py +++ b/streaming/format/base/shard/dual_row.py @@ -3,10 +3,11 @@ """Streaming shard abstract base classes.""" -from typing import Any, Optional +from typing import Any, Dict, Optional from streaming.format.base.file import ShardFile from streaming.format.base.shard.row import RowShard +from streaming.format.base.type import Type as LogicalType from streaming.stream.dir_conf import StreamDirConf __all__ = ['DualRowShard'] @@ -31,6 +32,7 @@ def __init__( conf: Optional[Any] = None, stream: StreamDirConf, num_samples: int, + logical_columns: Dict[str, LogicalType], data_file: ShardFile, meta_file: ShardFile, ) -> None: @@ -38,6 +40,7 @@ def __init__( conf=conf, stream=stream, num_samples=num_samples, + logical_columns=logical_columns, files=[data_file, meta_file], ) self.data_file = data_file diff --git a/streaming/format/base/shard/mono_row.py b/streaming/format/base/shard/mono_row.py index eebb3c2b4..20663a863 100644 --- a/streaming/format/base/shard/mono_row.py +++ b/streaming/format/base/shard/mono_row.py @@ -3,10 +3,11 @@ """Streaming shard abstract base classes.""" -from typing import Any, Optional +from typing import Any, Dict, Optional from streaming.format.base.file import ShardFile from streaming.format.base.shard.row import RowShard +from streaming.format.base.type import Type as LogicalType from streaming.stream.dir_conf import StreamDirConf __all__ = ['MonoRowShard'] @@ -30,12 +31,14 @@ def __init__( conf: Optional[Any] = None, stream: StreamDirConf, num_samples: int, + logical_columns: Dict[str, LogicalType], file: ShardFile, ) -> None: super().__init__( conf=conf, stream=stream, num_samples=num_samples, + logical_columns=logical_columns, files=[file], ) self.file = file diff --git a/streaming/format/base/type.py b/streaming/format/base/type.py new file mode 100644 index 000000000..53d39af33 --- /dev/null +++ b/streaming/format/base/type.py @@ -0,0 +1,167 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""The Streaming logical type hierarchy. + +This is a common language of types which the type systems of all Streaming shard formats are mapped +to. A field is stored as its shard format-specific physical type, and loaded and returned as its +logical type. +""" + +from typing import Optional, Tuple + +import numpy as np +from numpy.typing import DTypeLike + +__all__ = [ + 'Type', 'Bytes', 'Str', 'Number', 'Decimal', 'Float', 'Float64', 'Float32', 'Float16', 'Int', + 'Int64', 'Int32', 'Int16', 'Int8', 'UInt64', 'UInt32', 'UInt16', 'UInt8', 'Bool', 'NDArray', + 'Image', 'JSON', 'Pickle' +] + + +class Type: + """Logical type.""" + + def get_signature(self) -> str: + """Get a string representation of this logical type. + + Returns: + str: String representation. + """ + return self.__class__.__name__ + + +class Bytes(Type): + """Bytes logical type.""" + pass + + +class Str(Type): + """UTF-8 string logical type.""" + pass + + +class Number(Type): + """Number logical type.""" + pass + + +class Decimal(Number): + """Decimal logical type.""" + pass + + +class Float(Number): + """Native floating point logical type. + + This logical type refers to your programming language's default floating point type. Presumably + the value will have been serialized at that precision or higher. + + For example, in Python/CPython, the language has its own ``float`` type, which is internally + backed by a ``double`` in the implementation. + """ + pass + + +class Float64(Float): + """Float64 logical type.""" + pass + + +class Float32(Float64): + """Float32 logical type.""" + pass + + +class Float16(Float32): + """Float16 logical type.""" + pass + + +class Int(Number): + """Arbitrary-precision integer logical type.""" + pass + + +class Int64(Int): + """``int64`` logical type.""" + pass + + +class Int32(Int64): + """``int32`` logical type.""" + pass + + +class Int16(Int32): + """``int16`` logical type.""" + pass + + +class Int8(Int16): + """``int8`` logical type.""" + pass + + +class UInt64(Int): + """``uint64`` logical type.""" + pass + + +class UInt32(UInt64): + """``uint32`` logical type.""" + pass + + +class UInt16(UInt32): + """``uint16`` logical type.""" + pass + + +class UInt8(UInt16): + """``uint8`` logical type.""" + pass + + +class Bool(UInt8): + """``bool`` logical type.""" + pass + + +class NDArray(Type): + """Numpy ndarray logical type. + + Args: + shape (Tuple[int], optional): Optional shape requirement. + dtype (DTypeLike, optional): Optional dtype requirement. + """ + + def __init__( + self, + shape: Optional[Tuple[int]] = None, + dtype: Optional[DTypeLike] = None, + ) -> None: + self.shape = shape + self.dtype = np.dtype(dtype) if dtype else None + + def get_signature(self) -> str: + logical_type = self.__class__.__name__ + shape = ','.join(map(str, self.shape)) if self.shape else '' + dtype = self.dtype.name if self.dtype else '' + return ':'.join([logical_type, shape, dtype]) + + +class Image(Type): + """PIL Image logical type.""" + pass + + +class JSON(Type): + """JSON logical type.""" + pass + + +class Pickle(Type): + """Pickle logical type.""" + pass diff --git a/streaming/format/jsonl/encodings.py b/streaming/format/jsonl/encodings.py index c16f3ef52..903c01a92 100644 --- a/streaming/format/jsonl/encodings.py +++ b/streaming/format/jsonl/encodings.py @@ -6,7 +6,12 @@ from abc import ABC, abstractmethod from typing import Any -__all__ = ['is_jsonl_encoded', 'is_jsonl_encoding'] +from streaming.format.base.type import Float as LogicalFloat +from streaming.format.base.type import Int as LogicalInt +from streaming.format.base.type import Str as LogicalStr +from streaming.format.base.type import Type as LogicalType + +__all__ = ['is_jsonl_encoded', 'is_jsonl_encoding', 'jsonl_encoding_to_logical_type'] class Encoding(ABC): @@ -36,6 +41,8 @@ def _validate(data: Any, expected_type: Any) -> bool: class Str(Encoding): """Store str.""" + logical_type = LogicalStr + @classmethod def is_encoded(cls, obj: Any) -> bool: return cls._validate(obj, str) @@ -44,6 +51,8 @@ def is_encoded(cls, obj: Any) -> bool: class Int(Encoding): """Store int.""" + logical_type = LogicalInt + @classmethod def is_encoded(cls, obj: Any) -> bool: return cls._validate(obj, int) @@ -52,12 +61,18 @@ def is_encoded(cls, obj: Any) -> bool: class Float(Encoding): """Store float.""" + logical_type = LogicalFloat + @classmethod def is_encoded(cls, obj: Any) -> bool: return cls._validate(obj, float) -_encodings = {'str': Str, 'int': Int, 'float': Float} +_encodings = { + 'str': Str, + 'int': Int, + 'float': Float, +} def is_jsonl_encoded(encoding: str, value: Any) -> bool: @@ -84,3 +99,16 @@ def is_jsonl_encoding(encoding: str) -> bool: bool: Whether encoding is supported. """ return encoding in _encodings + + +def jsonl_encoding_to_logical_type(encoding: str) -> LogicalType: + """Get the logical type of the given encoding. + + Args: + encoding (str): Encoding. + + Returns: + LogicalType: Its logical type. + """ + cls = _encodings[encoding] + return cls.logical_type() diff --git a/streaming/format/jsonl/shard.py b/streaming/format/jsonl/shard.py index 2e35daeba..af163b2ff 100644 --- a/streaming/format/jsonl/shard.py +++ b/streaming/format/jsonl/shard.py @@ -12,6 +12,7 @@ from streaming.format.base.file import ShardFile from streaming.format.base.phase import ShardFilePhase from streaming.format.base.shard.dual_row import DualRowShard +from streaming.format.jsonl.encodings import jsonl_encoding_to_logical_type from streaming.stream.dir_conf import StreamDirConf __all__ = ['JSONLShard'] @@ -43,20 +44,29 @@ def __init__( columns: Dict[str, str], newline: str, ) -> None: + col_names = [] + col_encodings = [] + col_logical_types = [] + for col_name, col_encoding in sorted(columns.items()): + col_names.append(col_name) + col_encodings.append(col_encoding) + col_logical_type = jsonl_encoding_to_logical_type(col_encoding) + col_logical_types.append(col_logical_type) + logical_columns = dict(zip(col_names, col_logical_types)) + super().__init__( conf=conf, stream=stream, num_samples=num_samples, + logical_columns=logical_columns, data_file=data_file, meta_file=meta_file, ) + self.columns = columns - self.column_names = [] - self.column_encodings = [] - for col_name in sorted(self.columns): - self.column_names.append(col_name) - col_encoding = columns[col_name] - self.column_encodings.append(col_encoding) + self.column_names = col_names + self.column_encodings = col_encodings + self.column_logical_types = col_logical_types self.newline = newline @classmethod diff --git a/streaming/format/mds/encodings.py b/streaming/format/mds/encodings.py index 2af5048a0..de86a23d7 100644 --- a/streaming/format/mds/encodings.py +++ b/streaming/format/mds/encodings.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from decimal import Decimal from io import BytesIO -from typing import Any, Optional, Set, Tuple +from typing import Any, Callable, Optional, Set, Tuple import numpy as np from numpy import typing as npt @@ -16,6 +16,28 @@ from PIL.JpegImagePlugin import JpegImageFile from typing_extensions import Self +from streaming.format.base.type import JSON as LogicalJSON +from streaming.format.base.type import Bytes as LogicalBytes +from streaming.format.base.type import Decimal as LogicalDecimal +from streaming.format.base.type import Float as LogicalFloat +from streaming.format.base.type import Float16 as LogicalFloat16 +from streaming.format.base.type import Float32 as LogicalFloat32 +from streaming.format.base.type import Float64 as LogicalFloat64 +from streaming.format.base.type import Image as LogicalImage +from streaming.format.base.type import Int as LogicalInt +from streaming.format.base.type import Int8 as LogicalInt8 +from streaming.format.base.type import Int16 as LogicalInt16 +from streaming.format.base.type import Int32 as LogicalInt32 +from streaming.format.base.type import Int64 as LogicalInt64 +from streaming.format.base.type import NDArray as LogicalNDArray +from streaming.format.base.type import Pickle as LogicalPickle +from streaming.format.base.type import Str as LogicalStr +from streaming.format.base.type import Type as LogicalType +from streaming.format.base.type import UInt8 as LogicalUInt8 +from streaming.format.base.type import UInt16 as LogicalUInt16 +from streaming.format.base.type import UInt32 as LogicalUInt32 +from streaming.format.base.type import UInt64 as LogicalUInt64 + __all__ = [ 'get_mds_encoded_size', 'get_mds_encodings', @@ -23,6 +45,7 @@ 'mds_decode', 'mds_encode', 'is_mds_encoding_safe', + 'mds_encoding_to_logical_type', ] @@ -31,6 +54,18 @@ class Encoding(ABC): size: Optional[int] = None # Fixed size in bytes of encoded data (None if variable size). + def __init__(self) -> None: + self.get_logical_type: Callable[[], LogicalType] # Logical type of this MDS type. + + @property + def logical_type(self) -> LogicalType: + """Access the logical type corresponding to this MDS type. + + Returns: + LogicalType: Its logical type. + """ + return self.get_logical_type() + @abstractmethod def encode(self, obj: Any) -> bytes: """Encode the given data from the original object to bytes. @@ -65,6 +100,8 @@ def _validate(data: Any, expected_type: Any) -> None: class Bytes(Encoding): """Store bytes (no-op encoding).""" + get_logical_type = LogicalBytes + def encode(self, obj: bytes) -> bytes: self._validate(obj, bytes) return obj @@ -76,6 +113,8 @@ def decode(self, data: bytes) -> bytes: class Str(Encoding): """Store UTF-8.""" + get_logical_type = LogicalStr + def encode(self, obj: str) -> bytes: self._validate(obj, str) return obj.encode('utf-8') @@ -87,6 +126,8 @@ def decode(self, data: bytes) -> str: class Int(Encoding): """Store int64.""" + get_logical_type = LogicalInt + size = 8 def encode(self, obj: int) -> bytes: @@ -172,6 +213,7 @@ def __init__(self, dtype: Optional[str] = None, shape: Optional[Tuple[int]] = No self.dtype = dtype self.shape = shape self.size = self._get_static_size(dtype, shape) + self.get_logical_type = lambda: LogicalNDArray(shape, dtype) @classmethod def from_str(cls, text: str) -> Self: @@ -317,6 +359,8 @@ def decode(self, data: bytes) -> Any: class UInt8(Scalar): """Store uint8.""" + get_logical_type = LogicalUInt8 + def __init__(self): super().__init__(np.uint8) @@ -324,6 +368,8 @@ def __init__(self): class UInt16(Scalar): """Store uint16.""" + get_logical_type = LogicalUInt16 + def __init__(self): super().__init__(np.uint16) @@ -331,6 +377,8 @@ def __init__(self): class UInt32(Scalar): """Store uint32.""" + get_logical_type = LogicalUInt32 + def __init__(self): super().__init__(np.uint32) @@ -338,6 +386,8 @@ def __init__(self): class UInt64(Scalar): """Store uint64.""" + get_logical_type = LogicalUInt64 + def __init__(self): super().__init__(np.uint64) @@ -345,6 +395,8 @@ def __init__(self): class Int8(Scalar): """Store int8.""" + get_logical_type = LogicalInt8 + def __init__(self): super().__init__(np.int8) @@ -352,6 +404,8 @@ def __init__(self): class Int16(Scalar): """Store int16.""" + get_logical_type = LogicalInt16 + def __init__(self): super().__init__(np.int16) @@ -359,6 +413,8 @@ def __init__(self): class Int32(Scalar): """Store int32.""" + get_logical_type = LogicalInt32 + def __init__(self): super().__init__(np.int32) @@ -366,6 +422,8 @@ def __init__(self): class Int64(Scalar): """Store int64.""" + get_logical_type = LogicalInt64 + def __init__(self): super().__init__(np.int64) @@ -373,6 +431,8 @@ def __init__(self): class Float16(Scalar): """Store float16.""" + get_logical_type = LogicalFloat16 + def __init__(self): super().__init__(np.float16) @@ -380,6 +440,8 @@ def __init__(self): class Float32(Scalar): """Store float32.""" + get_logical_type = LogicalFloat32 + def __init__(self): super().__init__(np.float32) @@ -387,6 +449,8 @@ def __init__(self): class Float64(Scalar): """Store float64.""" + get_logical_type = LogicalFloat64 + def __init__(self): super().__init__(np.float64) @@ -405,6 +469,8 @@ class StrEncoding(Encoding): class StrInt(StrEncoding): """Store int as variable-length digits str.""" + get_logical_type = LogicalInt + def encode(self, obj: int) -> bytes: self._validate(obj, int) return str(obj).encode('utf-8') @@ -416,6 +482,8 @@ def decode(self, data: bytes) -> int: class StrFloat(Encoding): """Store float as variable-length digits str.""" + get_logical_type = LogicalFloat + def encode(self, obj: float) -> bytes: self._validate(obj, float) return str(obj).encode('utf-8') @@ -427,6 +495,8 @@ def decode(self, data: bytes) -> float: class StrDecimal(Encoding): """Store decimal as variable-length digits str.""" + get_logical_type = LogicalDecimal + def encode(self, obj: Decimal) -> bytes: self._validate(obj, Decimal) return str(obj).encode('utf-8') @@ -441,6 +511,8 @@ class PIL(Encoding): Format: [width: 4] [height: 4] [mode size: 4] [mode] [raw image]. """ + get_logical_type = LogicalImage + def encode(self, obj: Image.Image) -> bytes: self._validate(obj, Image.Image) mode = obj.mode.encode('utf-8') @@ -462,6 +534,8 @@ def decode(self, data: bytes) -> Image.Image: class JPEG(Encoding): """Store PIL image as JPEG.""" + get_logical_type = LogicalImage + def encode(self, obj: Image.Image) -> bytes: self._validate(obj, Image.Image) if isinstance(obj, JpegImageFile) and hasattr(obj, 'filename'): @@ -481,6 +555,8 @@ def decode(self, data: bytes) -> Image.Image: class PNG(Encoding): """Store PIL image as PNG.""" + get_logical_type = LogicalImage + def encode(self, obj: Image.Image) -> bytes: self._validate(obj, Image.Image) out = BytesIO() @@ -495,6 +571,8 @@ def decode(self, data: bytes) -> Image.Image: class Pickle(Encoding): """Store arbitrary data as pickle.""" + get_logical_type = LogicalPickle + def encode(self, obj: Any) -> bytes: return pickle.dumps(obj) @@ -505,6 +583,8 @@ def decode(self, data: bytes) -> Any: class JSON(Encoding): """Store arbitrary data as JSON.""" + get_logical_type = LogicalJSON + def encode(self, obj: Any) -> bytes: data = json.dumps(obj) self._is_valid(obj, data) @@ -652,3 +732,18 @@ def get_mds_encoded_size(encoding: str) -> Optional[int]: if coder is None: raise ValueError(f'Unsupported encoding: {encoding}.') return coder.size + + +def mds_encoding_to_logical_type(encoding: str) -> LogicalType: + """Get the logical type for the given MDS encoding. + + Args: + encoding (str): Encoding. + + Returns: + LogicalType: Its logical type. + """ + coder = _get_coder(encoding) + if coder is None: + raise ValueError(f'Unsupported encoding: {encoding}.') + return coder.get_logical_type() diff --git a/streaming/format/mds/shard.py b/streaming/format/mds/shard.py index 2497809d6..9d050ba9f 100644 --- a/streaming/format/mds/shard.py +++ b/streaming/format/mds/shard.py @@ -14,7 +14,8 @@ from streaming.format.base.file import ShardFile from streaming.format.base.phase import ShardFilePhase from streaming.format.base.shard.mono_row import MonoRowShard -from streaming.format.mds.encodings import is_mds_encoding_safe, mds_decode +from streaming.format.mds.encodings import (is_mds_encoding_safe, mds_decode, + mds_encoding_to_logical_type) from streaming.stream.dir_conf import StreamDirConf __all__ = ['MDSShard'] @@ -51,10 +52,14 @@ def __init__( file: ShardFile, columns: List[MDSColumn], ) -> None: + col_names = [col.name for col in columns] + col_logical_types = [mds_encoding_to_logical_type(col.encoding) for col in columns] + logical_columns = dict(zip(col_names, col_logical_types)) super().__init__( conf=conf, stream=stream, num_samples=num_samples, + logical_columns=logical_columns, file=file, ) self.columns = columns diff --git a/streaming/format/xsv/encodings.py b/streaming/format/xsv/encodings.py index b05cb0dfa..23fa6fc51 100644 --- a/streaming/format/xsv/encodings.py +++ b/streaming/format/xsv/encodings.py @@ -6,7 +6,12 @@ from abc import ABC, abstractmethod from typing import Any -__all__ = ['is_xsv_encoding', 'xsv_decode', 'xsv_encode'] +from streaming.format.base.type import Float as LogicalFloat +from streaming.format.base.type import Int as LogicalInt +from streaming.format.base.type import Str as LogicalStr +from streaming.format.base.type import Type as LogicalType + +__all__ = ['is_xsv_encoding', 'xsv_encoding_to_logical_type', 'xsv_decode', 'xsv_encode'] class Encoding(ABC): @@ -48,6 +53,8 @@ def _validate(data: Any, expected_type: Any) -> None: class Str(Encoding): """Store str.""" + logical_type = LogicalStr + @classmethod def encode(cls, obj: Any) -> str: cls._validate(obj, str) @@ -61,6 +68,8 @@ def decode(cls, obj: str) -> Any: class Int(Encoding): """Store int.""" + logical_type = LogicalInt + @classmethod def encode(cls, obj: Any) -> str: cls._validate(obj, int) @@ -74,6 +83,8 @@ def decode(cls, obj: str) -> Any: class Float(Encoding): """Store float.""" + logical_type = LogicalFloat + @classmethod def encode(cls, obj: Any) -> str: cls._validate(obj, float) @@ -84,7 +95,11 @@ def decode(cls, obj: str) -> Any: return float(obj) -_encodings = {'str': Str, 'int': Int, 'float': Float} +_encodings = { + 'str': Str, + 'int': Int, + 'float': Float, +} def is_xsv_encoding(encoding: str) -> bool: @@ -99,6 +114,19 @@ def is_xsv_encoding(encoding: str) -> bool: return encoding in _encodings +def xsv_encoding_to_logical_type(encoding: str) -> LogicalType: + """Get the logical type of the given encoding. + + Args: + encoding (str): Encoding. + + Returns: + LogicalType: Its logical type. + """ + cls = _encodings[encoding] + return cls.logical_type() + + def xsv_encode(encoding: str, value: Any) -> str: """Encode the given data from the original object to string. diff --git a/streaming/format/xsv/shard.py b/streaming/format/xsv/shard.py index fa8db2d06..a5f7d4824 100644 --- a/streaming/format/xsv/shard.py +++ b/streaming/format/xsv/shard.py @@ -11,7 +11,7 @@ from streaming.format.base.file import ShardFile from streaming.format.base.phase import ShardFilePhase from streaming.format.base.shard.dual_row import DualRowShard -from streaming.format.xsv.encodings import xsv_decode +from streaming.format.xsv.encodings import xsv_decode, xsv_encoding_to_logical_type from streaming.stream.dir_conf import StreamDirConf __all__ = ['XSVShard'] @@ -47,15 +47,19 @@ def __init__( newline: str, separator: Optional[str], ) -> None: + column_logical_types = list(map(xsv_encoding_to_logical_type, column_encodings)) + logical_columns = dict(zip(column_names, column_logical_types)) super().__init__( conf=conf, stream=stream, num_samples=num_samples, + logical_columns=logical_columns, data_file=data_file, meta_file=meta_file, ) self.column_names = column_names self.column_encodings = column_encodings + self.column_logical_types = column_logical_types self.newline = newline self.separator = separator