Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update orbax handler to use bulk read APIs. #33

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 92 additions & 5 deletions pathwaysutils/persistence/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
"""Helper functions for persistence."""

import base64
import concurrent.futures
import datetime
import json
from typing import Sequence, Union
from typing import Sequence, Tuple, Union

import jax
from jax import core
Expand Down Expand Up @@ -82,6 +83,7 @@ def get_write_request(
name: str,
jax_array: jax.Array,
timeout: datetime.timedelta,
return_dict: bool = False,
) -> str:
"""Returns a string representation of the plugin program which writes the given jax_array to the given location."""
sharding = jax_array.sharding
Expand All @@ -91,7 +93,7 @@ def get_write_request(
timeout.total_seconds(), 1
)
timeout_nanoseconds = timeout_fractional_seconds * 1e9
return json.dumps({
d = {
"persistenceWriteRequest": {
"b64_location": string_to_base64(location_path),
"b64_name": string_to_base64(name),
Expand All @@ -112,7 +114,29 @@ def get_write_request(
"nanos": int(timeout_nanoseconds),
},
}
})
}

if return_dict:
return d
return json.dumps(d)


def get_bulk_write_request(
    location_path: str,
    names: Sequence[str],
    jax_arrays: Sequence[jax.Array],
    timeout: datetime.timedelta,
) -> str:
  """Returns a string representation of a bulk write request, writes multiple arrays with one call.

  Args:
    location_path: Directory the arrays are persisted under.
    names: One name per array; paired element-wise with `jax_arrays`.
    jax_arrays: The arrays to persist.
    timeout: Per-request timeout forwarded to each inner write request.

  Returns:
    A JSON string wrapping one inner write request per (name, array) pair.
  """
  inner_requests = []
  for array_name, array in zip(names, jax_arrays):
    # Reuse the single-array builder in dict mode and keep only the
    # inner request payload.
    single = get_write_request(location_path, array_name, array, timeout, True)
    inner_requests.append(single["persistenceWriteRequest"])
  return json.dumps(
      {"bulk_persistence_write_request": {"write_requests": inner_requests}}
  )


def get_read_request(
Expand All @@ -123,6 +147,7 @@ def get_read_request(
sharding: jax.sharding.Sharding,
devices: Sequence[jax.Device],
timeout: datetime.timedelta,
return_dict: bool = False,
) -> str:
"""Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
if not isinstance(devices, np.ndarray):
Expand All @@ -132,7 +157,7 @@ def get_read_request(
timeout.total_seconds(), 1
)
timeout_nanoseconds = timeout_fractional_seconds * 1e9
return json.dumps({
d = {
"persistenceReadRequest": {
"b64_location": string_to_base64(location_path),
"b64_shape_proto_string": get_shape_string(dtype, shape),
Expand All @@ -148,7 +173,32 @@ def get_read_request(
"nanos": int(timeout_nanoseconds),
},
}
})
}

if return_dict:
return d
return json.dumps(d)


def get_bulk_read_request(
    location_path: str,
    names: Sequence[str],
    dtypes: Sequence[np.dtype],
    shapes: Sequence[Sequence[int]],
    shardings: Sequence[jax.sharding.Sharding],
    devices: Sequence[jax.Device],
    timeout: datetime.timedelta,
) -> str:
  """Returns a string representation of a bulk read request, reads multiple arrays with one call.

  Args:
    location_path: Directory the arrays were persisted under.
    names: One name per array; zipped element-wise with `dtypes`, `shapes`,
      and `shardings`.
    dtypes: Per-array dtypes.
    shapes: Per-array global shapes.
    shardings: Per-array output shardings.
    devices: Devices shared by every inner read request.
    timeout: Per-request timeout forwarded to each inner read request.

  Returns:
    A JSON string wrapping one inner read request per array.
  """
  # Build each inner request via the single-array helper in dict mode and
  # keep only the inner payload.
  read_requests = [
      get_read_request(
          location_path, name, dtype, shape, sharding, devices, timeout, True
      )["persistenceReadRequest"]
      for name, dtype, shape, sharding in zip(names, dtypes, shapes, shardings)
  ]
  return json.dumps(
      {"bulk_persistence_read_request": {"read_requests": read_requests}}
  )


def write_one_array(
Expand All @@ -164,6 +214,19 @@ def write_one_array(
return write_future


def write_arrays(
    location: str,
    names: Sequence[str],
    values: Sequence[jax.Array],
    timeout: datetime.timedelta,
):
  """Creates the write array plugin program string, compiles it to an executable, calls it and returns an awaitable future.

  Args:
    location: Directory to persist the arrays under.
    names: One name per array, paired element-wise with `values`.
    values: The arrays to persist.
    timeout: Timeout forwarded to the bulk write request.

  Returns:
    The future produced by the executable call; await it for completion.
  """
  request = get_bulk_write_request(location, names, values, timeout)
  executable = plugin_executable.PluginExecutable(request)
  # The executable consumes the arrays; only the future is of interest here.
  _, future = executable.call(values)
  return future


def read_one_array(
location: str,
name: str,
Expand All @@ -190,3 +253,27 @@ def read_one_array(
)
read_future.result()
return read_array[0]


def read_arrays(
    location: str,
    names: Sequence[str],
    dtypes: Sequence[np.dtype],
    shapes: Sequence[Sequence[int]],
    shardings: Sequence[jax.sharding.Sharding],
    devices: Union[Sequence[jax.Device], np.ndarray],
    timeout: datetime.timedelta,
) -> Tuple[Sequence[jax.Array], concurrent.futures.Future[None]]:
  """Creates the read array plugin program string, compiles it to an executable, calls it and returns the result.

  Args:
    location: Directory the arrays were persisted under.
    names: One name per array; zipped element-wise with `dtypes`, `shapes`,
      and `shardings`.
    dtypes: Per-array dtypes.
    shapes: Per-array global shapes (each element is itself a shape).
    shardings: Per-array output shardings.
    devices: Devices to restore onto.
    timeout: Timeout forwarded to the bulk read request.

  Returns:
    A tuple of (restored arrays, future); await the future for completion.
  """
  bulk_read_request = get_bulk_read_request(
      location, names, dtypes, shapes, shardings, devices, timeout
  )
  bulk_read_executable = plugin_executable.PluginExecutable(bulk_read_request)
  # One abstract value per requested array so the executable knows the
  # expected output shapes/dtypes.
  out_avals = [
      core.ShapedArray(shape, dtype) for shape, dtype in zip(shapes, dtypes)
  ]
  arrays, read_future = bulk_read_executable.call(
      out_shardings=shardings, out_avals=out_avals
  )
  return (arrays, read_future)
54 changes: 21 additions & 33 deletions pathwaysutils/persistence/pathways_orbax_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
self._read_timeout = read_timeout

if use_ocdbt:
raise ValueError('OCDBT not supported for Pathways.')
raise ValueError("OCDBT not supported for Pathways.")
super().__init__()

async def serialize(
Expand All @@ -73,12 +73,10 @@ async def serialize(
type_handlers.check_input_arguments(values, infos, args)

if any([arg.dtype is not None for arg in args]):
raise ValueError('Casting during save not supported for Pathways.')
raise ValueError("Casting during save not supported for Pathways.")

locations, names = extract_parent_dir_and_name(infos)
f = functools.partial(
helper.write_one_array, timeout=self._read_timeout
)
f = functools.partial(helper.write_one_array, timeout=self._read_timeout)
return list(map(f, locations, names, values))

async def deserialize(
Expand All @@ -88,7 +86,7 @@ async def deserialize(
) -> Sequence[jax.Array]:
"""Uses Pathways Persistence API to deserialize a jax array."""
if args is None:
raise ValueError('Must provide ArrayRestoreArgs to restore as jax.Array.')
raise ValueError("Must provide ArrayRestoreArgs to restore as jax.Array.")
type_handlers.check_input_arguments(infos, args)

global_meshes = []
Expand All @@ -101,14 +99,14 @@ async def deserialize(
for arg in args:
if not isinstance(arg, ArrayRestoreArgs):
raise ValueError(
'To restore jax.Array, provide ArrayRestoreArgs; found'
f' {type(arg).__name__}'
"To restore jax.Array, provide ArrayRestoreArgs; found"
f" {type(arg).__name__}"
)
arg = typing.cast(ArrayRestoreArgs, arg)
if arg.sharding is None and (arg.mesh is None or arg.mesh_axes is None):
raise ValueError(
'Sharding of jax.Array cannot be None. Provide `mesh`'
' and `mesh_axes` OR `sharding`.'
"Sharding of jax.Array cannot be None. Provide `mesh`"
" and `mesh_axes` OR `sharding`."
)
if arg.sharding is None:
global_meshes.append(arg.mesh)
Expand All @@ -118,15 +116,15 @@ async def deserialize(
)
else:
if not isinstance(arg.sharding, jax.sharding.NamedSharding):
raise ValueError('Pathways only supports jax.sharding.NamedSharding.')
raise ValueError("Pathways only supports jax.sharding.NamedSharding.")
sharding = typing.cast(jax.sharding.NamedSharding, arg.sharding)
global_meshes.append(sharding.mesh)
mesh_axes.append(sharding.spec)
shardings.append(sharding)
if arg.global_shape is None or arg.dtype is None:
logger.warning(
'Shape or dtype not provided for restoration. Provide these'
' properties for improved performance.'
"Shape or dtype not provided for restoration. Provide these"
" properties for improved performance."
)
should_open_metadata = True
global_shapes.append(arg.global_shape)
Expand All @@ -153,27 +151,17 @@ async def deserialize(
grouped_dtypes = [dtypes[idx] for idx in idxs]
grouped_shardings = [shardings[idx] for idx in idxs]
locations, names = extract_parent_dir_and_name(grouped_infos)
f = functools.partial(
helper.read_one_array,
devices=global_mesh.devices,
grouped_arrays, read_future = helper.read_arrays(
locations[0],
names,
grouped_dtypes,
grouped_global_shapes,
grouped_shardings,
global_mesh.devices,
timeout=self._read_timeout,
)
grouped_arrays = [
f(
location=location,
name=name,
dtype=dtype,
shape=shape,
shardings=sharding,
)
for location, name, dtype, shape, sharding in zip(
locations,
names,
grouped_dtypes,
grouped_global_shapes,
grouped_shardings,
)
]
# each persistence call is awaited serially.
read_future.result()
for idx, arr in zip(idxs, grouped_arrays):
results[idx] = arr
return results # pytype: disable=bad-return-type
Expand All @@ -184,7 +172,7 @@ def register_pathways_handlers(
):
"""Function that must be called before saving or restoring with Pathways."""
logger.debug(
'Registering CloudPathwaysArrayHandler (Pathways Persistence API).'
"Registering CloudPathwaysArrayHandler (Pathways Persistence API)."
)
type_handlers.register_type_handler(
jax.Array,
Expand Down
Loading