-
Notifications
You must be signed in to change notification settings - Fork 149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Column logical (not physical) type and allow_schema_mismatch #606
base: delta
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
import os | ||
import sys | ||
import warnings | ||
from collections import Counter, defaultdict | ||
from concurrent.futures import ThreadPoolExecutor, wait | ||
from concurrent.futures._base import Future | ||
from enum import IntEnum | ||
|
@@ -479,6 +480,22 @@ def __init__( | |
self.stream_per_shard = np.array(stream_per_shard, np.int64) | ||
self.num_shards = len(self.shards) | ||
|
||
# Maybe check all schemas match. | ||
if not allow_schema_mismatch: | ||
sigs = [shard.get_logical_type_signature() for shard in self.shards] | ||
sig2count = Counter(sigs) | ||
if len(sig2count) != 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are mixed logical types of the same kind (like int32 + int64) not allowed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not allowed at this time for MVP. Maybe in the future, we would go with the wider type for all shards? But that would be kind of magical and might interact with custom getitem work? |
||
count2sigs = defaultdict(list) | ||
for sig, count in sorted(sig2count.items()): | ||
count2sigs[count].append(sig) | ||
parts = [] | ||
for count, sigs in sorted(count2sigs.items()): | ||
for sig in sorted(sigs): | ||
part = f'{count} {sig}' | ||
parts.append(part) | ||
text = ', '.join(parts) | ||
raise ValueError(f'Expected the columns of every shard to match, but got: {text}.') | ||
|
||
# Wait for the pool workers (stream index download processes) to finish. | ||
if pool is not None: | ||
pool.join() | ||
|
@@ -516,7 +533,7 @@ def __init__( | |
self.cache_limit = None | ||
|
||
# Build the shard index (for partitioning and mapping samples to shards). | ||
self.samples_per_shard = np.array([shard.samples for shard in self.shards], np.int64) | ||
self.samples_per_shard = np.array([shard.num_samples for shard in self.shards], np.int64) | ||
self.sample_offset_per_shard = self.samples_per_shard.cumsum() - self.samples_per_shard | ||
self.spanner = Spanner(self.samples_per_shard) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
# Copyright 2022-2024 MosaicML Streaming authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""The Streaming logical type hierarchy. | ||
|
||
This is a common language of types which the type systems of all Streaming shard formats are mapped | ||
to. A field is stored as its shard format-specific physical type, and loaded and returned as its | ||
logical type. | ||
""" | ||
|
||
from typing import Optional, Tuple | ||
|
||
import numpy as np | ||
from numpy.typing import DTypeLike | ||
|
||
__all__ = [ | ||
'Type', 'Bytes', 'Str', 'Number', 'Decimal', 'Float', 'Float64', 'Float32', 'Float16', 'Int', | ||
'Int64', 'Int32', 'Int16', 'Int8', 'UInt64', 'UInt32', 'UInt16', 'UInt8', 'Bool', 'NDArray', | ||
'Image', 'JSON', 'Pickle' | ||
] | ||
|
||
|
||
class Type: | ||
"""Logical type.""" | ||
|
||
def get_signature(self) -> str: | ||
"""Get a string representation of this logical type. | ||
|
||
Returns: | ||
str: String representation. | ||
""" | ||
return self.__class__.__name__ | ||
|
||
|
||
class Bytes(Type): | ||
"""Bytes logical type.""" | ||
pass | ||
|
||
|
||
class Str(Type): | ||
"""UTF-8 string logical type.""" | ||
pass | ||
|
||
|
||
class Number(Type): | ||
"""Number logical type.""" | ||
pass | ||
|
||
|
||
class Decimal(Number): | ||
"""Decimal logical type.""" | ||
pass | ||
|
||
|
||
class Float(Number): | ||
"""Native floating point logical type. | ||
|
||
This logical type refers to your programming language's default floating point type. Presumably | ||
the value will have been serialized at that precision or higher. | ||
|
||
For example, in Python/CPython, the language has its own ``float`` type, which is internally | ||
backed by a ``double`` in the implementation. | ||
""" | ||
pass | ||
|
||
|
||
class Float64(Float): | ||
"""Float64 logical type.""" | ||
pass | ||
|
||
|
||
class Float32(Float64): | ||
"""Float32 logical type.""" | ||
pass | ||
|
||
|
||
class Float16(Float32): | ||
"""Float16 logical type.""" | ||
pass | ||
|
||
|
||
class Int(Number): | ||
"""Arbitrary-precision integer logical type.""" | ||
pass | ||
|
||
|
||
class Int64(Int): | ||
"""``int64`` logical type.""" | ||
pass | ||
|
||
|
||
class Int32(Int64): | ||
"""``int32`` logical type.""" | ||
pass | ||
|
||
|
||
class Int16(Int32): | ||
"""``int16`` logical type.""" | ||
pass | ||
|
||
|
||
class Int8(Int16): | ||
"""``int8`` logical type.""" | ||
pass | ||
|
||
|
||
class UInt64(Int): | ||
"""``uint64`` logical type.""" | ||
pass | ||
|
||
|
||
class UInt32(UInt64): | ||
"""``uint32`` logical type.""" | ||
pass | ||
|
||
|
||
class UInt16(UInt32): | ||
"""``uint16`` logical type.""" | ||
pass | ||
|
||
|
||
class UInt8(UInt16): | ||
"""``uint8`` logical type.""" | ||
pass | ||
|
||
|
||
class Bool(UInt8): | ||
"""``bool`` logical type.""" | ||
pass | ||
|
||
|
||
class NDArray(Type): | ||
"""Numpy ndarray logical type. | ||
|
||
Args: | ||
shape (Tuple[int], optional): Optional shape requirement. | ||
dtype (DTypeLike, optional): Optional dtype requirement. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
shape: Optional[Tuple[int]] = None, | ||
dtype: Optional[DTypeLike] = None, | ||
) -> None: | ||
self.shape = shape | ||
self.dtype = np.dtype(dtype) if dtype else None | ||
|
||
def get_signature(self) -> str: | ||
logical_type = self.__class__.__name__ | ||
shape = ','.join(map(str, self.shape)) if self.shape else '' | ||
dtype = self.dtype.name if self.dtype else '' | ||
return ':'.join([logical_type, shape, dtype]) | ||
|
||
|
||
class Image(Type): | ||
"""PIL Image logical type.""" | ||
pass | ||
|
||
|
||
class JSON(Type): | ||
"""JSON logical type.""" | ||
pass | ||
|
||
|
||
class Pickle(Type): | ||
"""Pickle logical type.""" | ||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,12 @@ | |
from abc import ABC, abstractmethod | ||
from typing import Any | ||
|
||
__all__ = ['is_jsonl_encoded', 'is_jsonl_encoding'] | ||
from streaming.format.base.type import Float as LogicalFloat | ||
from streaming.format.base.type import Int as LogicalInt | ||
from streaming.format.base.type import Str as LogicalStr | ||
from streaming.format.base.type import Type as LogicalType | ||
|
||
__all__ = ['is_jsonl_encoded', 'is_jsonl_encoding', 'jsonl_encoding_to_logical_type'] | ||
|
||
|
||
class Encoding(ABC): | ||
|
@@ -36,6 +41,8 @@ def _validate(data: Any, expected_type: Any) -> bool: | |
class Str(Encoding): | ||
"""Store str.""" | ||
|
||
logical_type = LogicalStr | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wondering, where is this getting used? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
@classmethod | ||
def is_encoded(cls, obj: Any) -> bool: | ||
return cls._validate(obj, str) | ||
|
@@ -44,6 +51,8 @@ def is_encoded(cls, obj: Any) -> bool: | |
class Int(Encoding): | ||
"""Store int.""" | ||
|
||
logical_type = LogicalInt | ||
|
||
@classmethod | ||
def is_encoded(cls, obj: Any) -> bool: | ||
return cls._validate(obj, int) | ||
|
@@ -52,12 +61,18 @@ def is_encoded(cls, obj: Any) -> bool: | |
class Float(Encoding): | ||
"""Store float.""" | ||
|
||
logical_type = LogicalFloat | ||
|
||
@classmethod | ||
def is_encoded(cls, obj: Any) -> bool: | ||
return cls._validate(obj, float) | ||
|
||
|
||
_encodings = {'str': Str, 'int': Int, 'float': Float} | ||
_encodings = { | ||
'str': Str, | ||
'int': Int, | ||
'float': Float, | ||
} | ||
|
||
|
||
def is_jsonl_encoded(encoding: str, value: Any) -> bool: | ||
|
@@ -84,3 +99,16 @@ def is_jsonl_encoding(encoding: str) -> bool: | |
bool: Whether encoding is supported. | ||
""" | ||
return encoding in _encodings | ||
|
||
|
||
def jsonl_encoding_to_logical_type(encoding: str) -> LogicalType: | ||
"""Get the logical type of the given encoding. | ||
|
||
Args: | ||
encoding (str): Encoding. | ||
|
||
Returns: | ||
LogicalType: Its logical type. | ||
""" | ||
cls = _encodings[encoding] | ||
return cls.logical_type() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: why the word
maybe
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
because the check may happen or it may not
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, but it depends on whether the flag
allow_schema_mismatch
is true or not. So, maybe,check all schemas match
?