Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Column logical (not physical) type and allow_schema_mismatch #606

Open
wants to merge 2 commits into
base: delta
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion simulation/core/sim_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def __init__(
shutil.rmtree(local_foldernames[stream_idx])

# Build the shard index (for partitioning and mapping samples to shards).
self.samples_per_shard = np.array([shard.samples for shard in self.shards], np.int64)
self.samples_per_shard = np.array([shard.num_samples for shard in self.shards], np.int64)
self.sample_offset_per_shard = self.samples_per_shard.cumsum() - self.samples_per_shard
self.spanner = SimulationSpanner(self.samples_per_shard)

Expand Down
19 changes: 18 additions & 1 deletion streaming/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import sys
import warnings
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, wait
from concurrent.futures._base import Future
from enum import IntEnum
Expand Down Expand Up @@ -479,6 +480,22 @@ def __init__(
self.stream_per_shard = np.array(stream_per_shard, np.int64)
self.num_shards = len(self.shards)

# Maybe check all schemas match.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: why the word maybe?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because the check may happen or it may not

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but it depends on whether the flag allow_schema_mismatch is true or not. So, maybe, check all schemas match?

if not allow_schema_mismatch:
sigs = [shard.get_logical_type_signature() for shard in self.shards]
sig2count = Counter(sigs)
if len(sig2count) != 1:
Copy link
Contributor

@XiaohanZhangCMU XiaohanZhangCMU Feb 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are mixed logical types of the same kind (like int32 + int64) not allowed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not allowed at this time for MVP. Maybe in the future, we would go with the wider type for all shards? But that would be kind of magical and might interact with custom getitem work?

count2sigs = defaultdict(list)
for sig, count in sorted(sig2count.items()):
count2sigs[count].append(sig)
parts = []
for count, sigs in sorted(count2sigs.items()):
for sig in sorted(sigs):
part = f'{count} {sig}'
parts.append(part)
text = ', '.join(parts)
raise ValueError(f'Expected the columns of every shard to match, but got: {text}.')

# Wait for the pool workers (stream index download processes) to finish.
if pool is not None:
pool.join()
Expand Down Expand Up @@ -516,7 +533,7 @@ def __init__(
self.cache_limit = None

# Build the shard index (for partitioning and mapping samples to shards).
self.samples_per_shard = np.array([shard.samples for shard in self.shards], np.int64)
self.samples_per_shard = np.array([shard.num_samples for shard in self.shards], np.int64)
self.sample_offset_per_shard = self.samples_per_shard.cumsum() - self.samples_per_shard
self.spanner = Spanner(self.samples_per_shard)

Expand Down
17 changes: 16 additions & 1 deletion streaming/format/base/shard/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from streaming.array import Array
from streaming.format.base.file import ShardFile
from streaming.format.base.type import Type as LogicalType
from streaming.stream.dir_conf import StreamDirConf

__all__ = ['Shard']
Expand All @@ -33,11 +34,13 @@ def __init__(
conf: Optional[Any] = None,
stream: StreamDirConf,
num_samples: int,
logical_columns: Dict[str, LogicalType],
files: List[ShardFile],
) -> None:
self.conf = conf
self.stream = stream
self.num_samples = self.samples = num_samples
self.num_samples = num_samples
self.logical_columns = logical_columns
self.files = files

@classmethod
Expand Down Expand Up @@ -77,6 +80,18 @@ def set_stream(self, stream: StreamDirConf) -> None:
for file in self.files:
file.set_stream(stream)

def get_logical_type_signature(self) -> str:
    """Get a string encoding our logical column info.

    Columns are visited in sorted order of ``self.logical_columns.items()``, so the result does
    not depend on the insertion order of the mapping. Each column is rendered as
    ``<name>:<signature>``, where the signature comes from the column's ``get_signature()``, and
    the columns are joined with commas.

    Returns:
        str: Logical type signature (the empty string if there are no columns).
    """
    return ','.join(
        f'{name}:{logical_type.get_signature()}'
        for name, logical_type in sorted(self.logical_columns.items()))

def inventory_local(self, listing: Set[str]) -> Optional[int]:
"""Normalize what files/phases of files are present to a coherent state.

Expand Down
5 changes: 4 additions & 1 deletion streaming/format/base/shard/dual_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

"""Streaming shard abstract base classes."""

from typing import Any, Optional
from typing import Any, Dict, Optional

from streaming.format.base.file import ShardFile
from streaming.format.base.shard.row import RowShard
from streaming.format.base.type import Type as LogicalType
from streaming.stream.dir_conf import StreamDirConf

__all__ = ['DualRowShard']
Expand All @@ -31,13 +32,15 @@ def __init__(
conf: Optional[Any] = None,
stream: StreamDirConf,
num_samples: int,
logical_columns: Dict[str, LogicalType],
data_file: ShardFile,
meta_file: ShardFile,
) -> None:
super().__init__(
conf=conf,
stream=stream,
num_samples=num_samples,
logical_columns=logical_columns,
files=[data_file, meta_file],
)
self.data_file = data_file
Expand Down
5 changes: 4 additions & 1 deletion streaming/format/base/shard/mono_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

"""Streaming shard abstract base classes."""

from typing import Any, Optional
from typing import Any, Dict, Optional

from streaming.format.base.file import ShardFile
from streaming.format.base.shard.row import RowShard
from streaming.format.base.type import Type as LogicalType
from streaming.stream.dir_conf import StreamDirConf

__all__ = ['MonoRowShard']
Expand All @@ -30,12 +31,14 @@ def __init__(
conf: Optional[Any] = None,
stream: StreamDirConf,
num_samples: int,
logical_columns: Dict[str, LogicalType],
file: ShardFile,
) -> None:
super().__init__(
conf=conf,
stream=stream,
num_samples=num_samples,
logical_columns=logical_columns,
files=[file],
)
self.file = file
167 changes: 167 additions & 0 deletions streaming/format/base/type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""The Streaming logical type hierarchy.

This is a common language of types which the type systems of all Streaming shard formats are mapped
to. A field is stored as its shard format-specific physical type, and loaded and returned as its
logical type.
"""

from typing import Optional, Tuple

import numpy as np
from numpy.typing import DTypeLike

__all__ = [
'Type', 'Bytes', 'Str', 'Number', 'Decimal', 'Float', 'Float64', 'Float32', 'Float16', 'Int',
'Int64', 'Int32', 'Int16', 'Int8', 'UInt64', 'UInt32', 'UInt16', 'UInt8', 'Bool', 'NDArray',
'Image', 'JSON', 'Pickle'
]


class Type:
    """Base class of the logical types."""

    def get_signature(self) -> str:
        """Render this logical type as a string.

        Returns:
            str: The name of this instance's class.
        """
        return type(self).__name__


class Bytes(Type):
    """Logical type for bytes."""


class Str(Type):
    """Logical type for UTF-8 strings."""


class Number(Type):
    """Logical type for numbers (parent of ``Decimal``, ``Float``, and ``Int``)."""


class Decimal(Number):
    """Logical type for decimals."""


class Float(Number):
    """Logical type for the native floating point type.

    "Native" means the default floating point type of your programming language. Presumably the
    value will have been serialized at that precision or higher.

    In Python/CPython, for instance, that is the built-in ``float``, which the implementation
    backs with a ``double``.
    """


class Float64(Float):
    """Logical type for ``float64``."""


class Float32(Float64):
    """Logical type for ``float32``."""


class Float16(Float32):
    """Logical type for ``float16``."""


class Int(Number):
    """Logical type for arbitrary-precision integers."""


class Int64(Int):
    """Logical type for ``int64``."""


class Int32(Int64):
    """Logical type for ``int32``."""


class Int16(Int32):
    """Logical type for ``int16``."""


class Int8(Int16):
    """Logical type for ``int8``."""


class UInt64(Int):
    """Logical type for ``uint64``."""


class UInt32(UInt64):
    """Logical type for ``uint32``."""


class UInt16(UInt32):
    """Logical type for ``uint16``."""


class UInt8(UInt16):
    """Logical type for ``uint8``."""


class Bool(UInt8):
    """Logical type for ``bool``."""


class NDArray(Type):
    """Numpy ndarray logical type.

    Args:
        shape (Tuple[int, ...], optional): Optional shape requirement. Defaults to ``None``.
        dtype (DTypeLike, optional): Optional dtype requirement. Defaults to ``None``.
    """

    def __init__(
        self,
        shape: Optional[Tuple[int, ...]] = None,
        dtype: Optional[DTypeLike] = None,
    ) -> None:
        self.shape = shape
        # Test against None explicitly instead of relying on truthiness: only a missing dtype
        # means "no requirement" (and must not reach ``np.dtype``, which maps None to float64).
        self.dtype = np.dtype(dtype) if dtype is not None else None

    def get_signature(self) -> str:
        """Get a string representation of this logical type.

        The format is ``<class name>:<shape>:<dtype>``, where the shape is its comma-separated
        dimensions and the dtype is the numpy dtype name. An unset shape or dtype is rendered as
        an empty field.

        Returns:
            str: String representation.
        """
        logical_type = self.__class__.__name__
        # NOTE(review): ``shape=()`` and ``shape=None`` both render as an empty shape field, so
        # the signature cannot tell "scalar" apart from "no shape requirement" -- confirm that
        # this is acceptable for the schema match check.
        shape = ','.join(map(str, self.shape)) if self.shape is not None else ''
        # ``is not None`` rather than truthiness: presumably some NumPy versions treat dtype
        # instances without fields as falsy (``len(dtype) == 0``) -- TODO confirm.
        dtype = self.dtype.name if self.dtype is not None else ''
        return ':'.join([logical_type, shape, dtype])


class Image(Type):
    """Logical type for PIL Images."""


class JSON(Type):
    """Logical type for JSON."""


class Pickle(Type):
    """Logical type for Pickle."""
32 changes: 30 additions & 2 deletions streaming/format/jsonl/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from abc import ABC, abstractmethod
from typing import Any

__all__ = ['is_jsonl_encoded', 'is_jsonl_encoding']
from streaming.format.base.type import Float as LogicalFloat
from streaming.format.base.type import Int as LogicalInt
from streaming.format.base.type import Str as LogicalStr
from streaming.format.base.type import Type as LogicalType

__all__ = ['is_jsonl_encoded', 'is_jsonl_encoding', 'jsonl_encoding_to_logical_type']


class Encoding(ABC):
Expand Down Expand Up @@ -36,6 +41,8 @@ def _validate(data: Any, expected_type: Any) -> bool:
class Str(Encoding):
"""Store str."""

logical_type = LogicalStr
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering, where is this getting used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

streaming.format.mds.encoding.MDSEncoding.logical_type (start here)
-> streaming.format.base.type.Type (then go to the logical type class)
-> streaming.format.base.type.Type.get_signature (it has a stringify single column method)
-> streaming.format.base.shard.base.Shard.get_logical_type_signature (which is used by shard to stringify all columns for equality comparison)
-> streaming.dataset.StreamingDataset.__init__ (which is needed for allow_schema_mismatch impl)


@classmethod
def is_encoded(cls, obj: Any) -> bool:
return cls._validate(obj, str)
Expand All @@ -44,6 +51,8 @@ def is_encoded(cls, obj: Any) -> bool:
class Int(Encoding):
    """JSONL encoding that stores ``int`` values."""

    # Logical type class that this encoding maps to.
    logical_type = LogicalInt

    @classmethod
    def is_encoded(cls, obj: Any) -> bool:
        """Validate ``obj`` against ``int`` by delegating to ``cls._validate``."""
        return cls._validate(obj, int)
Expand All @@ -52,12 +61,18 @@ def is_encoded(cls, obj: Any) -> bool:
class Float(Encoding):
    """JSONL encoding that stores ``float`` values."""

    # Logical type class that this encoding maps to.
    logical_type = LogicalFloat

    @classmethod
    def is_encoded(cls, obj: Any) -> bool:
        """Validate ``obj`` against ``float`` by delegating to ``cls._validate``."""
        return cls._validate(obj, float)


_encodings = {'str': Str, 'int': Int, 'float': Float}
_encodings = {
'str': Str,
'int': Int,
'float': Float,
}


def is_jsonl_encoded(encoding: str, value: Any) -> bool:
Expand All @@ -84,3 +99,16 @@ def is_jsonl_encoding(encoding: str) -> bool:
bool: Whether encoding is supported.
"""
return encoding in _encodings


def jsonl_encoding_to_logical_type(encoding: str) -> LogicalType:
    """Look up a JSONL encoding by name and instantiate its logical type.

    Args:
        encoding (str): Name of the encoding, used as a key into ``_encodings``.

    Returns:
        LogicalType: A new instance of that encoding's ``logical_type`` class.
    """
    return _encodings[encoding].logical_type()
Loading
Loading