diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b0de811 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.pyc +dist +build +*.egg-info diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..b16bd94 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,33 @@ +# How to contribute + +We'd love to accept your patches and contributions to this project. + +## Before you begin + +### Sign our Contributor License Agreement + +Contributions to this project must be accompanied by a +[Contributor License Agreement](https://cla.developers.google.com/about) (CLA). +You (or your employer) retain the copyright to your contribution; this simply +gives us permission to use and redistribute your contributions as part of the +project. + +If you or your current employer have already signed the Google CLA (even if it +was for a different project), you probably don't need to do it again. + +Visit to see your current agreements or to +sign a new one. + +### Review our community guidelines + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google/conduct/). + +## Contribution process + +### Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..3784424 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +Package of Pathways-on-Cloud utilities + +For customers to utilize Pathways-on-Cloud, there are several in changes that need to be made to the user job. We want to encapsulate these changes to a single Python package for two primary reasons. First, most of the changes are temporary patches that will not be needed long-term. Second, several of the changes follow anti-patterns and we want to confine them to a single repository. diff --git a/pathwaysutils/__init__.py b/pathwaysutils/__init__.py new file mode 100644 index 0000000..f22638a --- /dev/null +++ b/pathwaysutils/__init__.py @@ -0,0 +1,42 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Package of Pathways-on-Cloud utilities.""" + +import datetime +from absl import logging +import jax +from pathwaysutils import cloud_logging +from pathwaysutils import profiling +from pathwaysutils import proxy_backend +from pathwaysutils.persistence import pathways_orbax_handler + + +# This is a brittle implementation since the platforms value is not necessarily +# which backend is ultimately selected +def _is_pathways_used(): + return jax.config.jax_platforms and "proxy" in jax.config.jax_platforms + + +if _is_pathways_used(): + logging.warning("pathwaysutils: Detected Pathways-on-Cloud backend. Applying changes.") + proxy_backend.register_backend_factory() + profiling.monkey_patch_jax() + # pathways_orbax_handler.register_pathways_handlers( + # datetime.timedelta(minutes=10) + # ) + cloud_logging.setup() +else: + logging.warning( + "pathwaysutils: Did not detect Pathways-on-Cloud backend. No changes applied." + ) diff --git a/pathwaysutils/cloud_logging.py b/pathwaysutils/cloud_logging.py new file mode 100644 index 0000000..8064aa9 --- /dev/null +++ b/pathwaysutils/cloud_logging.py @@ -0,0 +1,21 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Package for setting up Cloud Logging for Python.""" + +import google.cloud.logging + + +def setup(): + client = google.cloud.logging.Client() + client.setup_logging() diff --git a/pathwaysutils/persistence/__init__.py b/pathwaysutils/persistence/__init__.py new file mode 100644 index 0000000..10c9675 --- /dev/null +++ b/pathwaysutils/persistence/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/pathwaysutils/persistence/helper.py b/pathwaysutils/persistence/helper.py new file mode 100644 index 0000000..5c2d252 --- /dev/null +++ b/pathwaysutils/persistence/helper.py @@ -0,0 +1,178 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helper functions for persistence.""" + +import base64 +import datetime +import json +from typing import Sequence, Union + +import jax +from jax import core +from jax._src.lib import xla_client as xc +import numpy as np +from pathwaysutils import plugin_executable + + +def base64_utf8_stringify(bs: bytes) -> str: + """Converts bytes to a base64-encoded utf-8 string. + + Args: + bs: The bytes to convert. + + Returns: + The base64-encoded utf-8 string. + """ + return base64.b64encode(bs).decode("utf-8") + + +def string_to_base64(text: str) -> str: + """Encodes a string to base64 format. + + Args: + text: The string to encode. + + Returns: + The base64-encoded string. + """ + return base64_utf8_stringify(text.encode("utf-8")) + + +def get_hlo_sharding_string( + sharding: jax.sharding.XLACompatibleSharding, + num_dimensions: int, +) -> str: + """Serializes the sharding to an hlo-sharding, encodes it to base64 and returns the base-64 as an utf-8 string.""" + return base64_utf8_stringify( + # pylint:disable=protected-access + sharding._to_xla_hlo_sharding(num_dimensions) # pytype: disable=attribute-error + # pylint:enable=protected-access + .to_proto().SerializeToString() + ) + + +def get_shape_string( + dtype: np.dtype, + shape: Sequence[int], +) -> str: + """Serializes the shape, encodes it to base64 and returns the base-64 as an utf-8 string.""" + return base64_utf8_stringify( + xc.Shape.array_shape( + xc.PrimitiveType(xc.dtype_to_etype(dtype)), + shape, + ) + .with_major_to_minor_layout_if_absent() + .to_serialized_proto() + ) + + +def get_write_request( + location_path: str, + name: str, + jax_array: jax.Array, + timeout: int, +) -> str: + """Returns a string representation of the plugin program which writes the given jax_array to the given location.""" + sharding = jax_array.sharding + assert isinstance(sharding, jax.sharding.XLACompatibleSharding), sharding + return json.dumps({ + "persistenceWriteRequest": { + "b64_location": string_to_base64(location_path), + "b64_name": string_to_base64(name), + "b64_hlo_sharding_string": get_hlo_sharding_string( + jax_array.sharding, len(jax_array.shape) + ), + "shape": jax_array.shape, + "devices": { + "device_ids": [ + # pylint:disable=protected-access + device.id + for device in sharding._device_assignment + # pylint:enable=protected-access + ], + }, + "timeout": {"seconds": timeout}, + } + }) + + +def get_read_request( + location_path: str, + name: str, + dtype: np.dtype, + shape: Sequence[int], + sharding: jax.sharding.XLACompatibleSharding, + devices: Sequence[jax.Device], + timeout_seconds: int, +) -> str: + """Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding.""" + if not isinstance(devices, np.ndarray): + devices = np.array(devices) + return json.dumps({ + "persistenceReadRequest": { + "b64_location": string_to_base64(location_path), + "b64_shape_proto_string": get_shape_string(dtype, shape), + "b64_name": string_to_base64(name), + "b64_hlo_sharding_string": get_hlo_sharding_string( + sharding, len(shape) + ), + "devices": { + "device_ids": [device.id for device in devices.flatten()] + }, + "timeout": {"seconds": timeout_seconds}, + } + }) + + +def write_one_array( + location: str, + name: str, + value: jax.Array, + timeout: datetime.timedelta, +): + """Creates the write array plugin program string, compiles it to an executable, calls it and returns an awaitable future.""" + write_request = get_write_request( + location, name, value, timeout.total_seconds() + ) + write_executable = plugin_executable.PluginExecutable(write_request) + _, write_future = write_executable.call([value]) + return write_future + + +def read_one_array( + location: str, + name: str, + dtype: np.dtype, + shape: Sequence[int], + shardings: jax.sharding.XLACompatibleSharding, + devices: Union[Sequence[jax.Device], np.ndarray], + timeout: datetime.timedelta, +): + """Creates the read array plugin program string, compiles it to an executable, calls it and returns the result.""" + read_request = get_read_request( + location, + name, + dtype, + shape, + shardings, + devices, + timeout.total_seconds(), + ) + read_executable = plugin_executable.PluginExecutable(read_request) + out_aval = core.ShapedArray(shape, dtype) + read_array, read_future = read_executable.call( + out_shardings=[shardings], out_avals=[out_aval] + ) + read_future.result() + return read_array[0] diff --git a/pathwaysutils/persistence/pathways_orbax_handler.py b/pathwaysutils/persistence/pathways_orbax_handler.py new file mode 100644 index 0000000..7fa3f50 --- /dev/null +++ b/pathwaysutils/persistence/pathways_orbax_handler.py @@ -0,0 +1,192 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TypeHandlers supporting Pathways backend.""" + +import collections +import datetime +import functools +import typing +from typing import Optional, Sequence + +from absl import logging +import jax +from orbax.checkpoint import future +from orbax.checkpoint import type_handlers +from pathwaysutils.persistence import helper + +ParamInfo = type_handlers.ParamInfo +SaveArgs = type_handlers.SaveArgs +RestoreArgs = type_handlers.RestoreArgs +ArrayRestoreArgs = type_handlers.ArrayRestoreArgs + + +def extract_parent_dir_and_name( + infos: Sequence[ParamInfo], +) -> tuple[Sequence[str], Sequence[str]]: + """Extracts names and locations from ParamInfos.""" + parent_dirs = [str(info.parent_dir) for info in infos] + names = [str(info.name) for info in infos] + return parent_dirs, names + + +class CloudPathwaysArrayHandler(type_handlers.ArrayHandler): + """A TypeHandler for array types when using Pathways.""" + + def __init__( + self, + read_timeout: Optional[datetime.timedelta] = None, + use_ocdbt: bool = False, + ): + """Constructor. + + Args: + read_timeout: Duration indicating the timeout for reading arrays + use_ocdbt: allows using Tensorstore OCDBT driver. + """ + self._read_timeout = read_timeout + + if use_ocdbt: + raise ValueError('OCDBT not supported for Pathways.') + super().__init__() + + async def serialize( + self, + values: Sequence[jax.Array], + infos: Sequence[ParamInfo], + args: Optional[Sequence[SaveArgs]] = None, + ) -> Sequence[future.Future]: + """Uses Pathways Persistence API to serialize a jax array.""" + type_handlers.check_input_arguments(values, infos, args) + + if any([arg.dtype is not None for arg in args]): + raise ValueError('Casting during save not supported for Pathways.') + + locations, names = extract_parent_dir_and_name(infos) + f = functools.partial( + helper.write_one_array, timeout=self._read_timeout + ) + return list(map(f, locations, names, values)) + + async def deserialize( + self, + infos: Sequence[ParamInfo], + args: Optional[Sequence[RestoreArgs]] = None, + ) -> Sequence[jax.Array]: + """Uses Pathways Persistence API to deserialize a jax array.""" + if args is None: + raise ValueError('Must provide ArrayRestoreArgs to restore as jax.Array.') + type_handlers.check_input_arguments(infos, args) + + global_meshes = [] + mesh_axes = [] + global_shapes = [] + dtypes = [] + shardings = [] + + should_open_metadata = False + for arg in args: + if not isinstance(arg, ArrayRestoreArgs): + raise ValueError( + 'To restore jax.Array, provide ArrayRestoreArgs; found' + f' {type(arg).__name__}' + ) + arg = typing.cast(ArrayRestoreArgs, arg) + if arg.sharding is None and (arg.mesh is None or arg.mesh_axes is None): + raise ValueError( + 'Sharding of jax.Array cannot be None. Provide `mesh`' + ' and `mesh_axes` OR `sharding`.' + ) + if arg.sharding is None: + global_meshes.append(arg.mesh) + mesh_axes.append(arg.mesh_axes) + shardings.append( + jax.sharding.NamedSharding(mesh=arg.mesh, spec=arg.mesh_axes) + ) + else: + if not isinstance(arg.sharding, jax.sharding.NamedSharding): + raise ValueError('Pathways only supports jax.sharding.NamedSharding.') + sharding = typing.cast(jax.sharding.NamedSharding, arg.sharding) + global_meshes.append(sharding.mesh) + mesh_axes.append(sharding.spec) + shardings.append(sharding) + if arg.global_shape is None or arg.dtype is None: + logging.warning( + 'Shape or dtype not provided for restoration. Provide these' + ' properties for improved performance.' + ) + should_open_metadata = True + global_shapes.append(arg.global_shape) + dtypes.append(arg.dtype) + + if should_open_metadata: + metadatas = await self.metadata(infos) + global_shapes = [ + m.shape if s is None else s for m, s in zip(metadatas, global_shapes) + ] + dtypes = [m.dtype if d is None else d for m, d in zip(metadatas, dtypes)] + + # Group inputs by global_mesh so that we can perform batched Array + # construction for each global_mesh. + inputs_by_global_mesh = collections.defaultdict(list) + for i, global_mesh in enumerate(global_meshes): + inputs_by_global_mesh[global_mesh].append(i) + + results = [None] * len(infos) + + for global_mesh, idxs in inputs_by_global_mesh.items(): + grouped_infos = [infos[idx] for idx in idxs] + grouped_global_shapes = [global_shapes[idx] for idx in idxs] + grouped_dtypes = [dtypes[idx] for idx in idxs] + grouped_shardings = [shardings[idx] for idx in idxs] + locations, names = extract_parent_dir_and_name(grouped_infos) + f = functools.partial( + helper.read_one_array, + devices=global_mesh.devices, + timeout=self._read_timeout, + ) + grouped_arrays = [ + f( + location=location, + name=name, + dtype=dtype, + shape=shape, + shardings=sharding, + ) + for location, name, dtype, shape, sharding in zip( + locations, + names, + grouped_dtypes, + grouped_global_shapes, + grouped_shardings, + ) + ] + for idx, arr in zip(idxs, grouped_arrays): + results[idx] = arr + return results # pytype: disable=bad-return-type + + +def register_pathways_handlers( + read_timeout: Optional[datetime.timedelta] = None, +): + """Function that must be called before saving or restoring with Pathways.""" + logging.warning( + 'Registering CloudPathwaysArrayHandler (Pathways Persistence API).' + ) + type_handlers.register_type_handler( + jax.Array, + CloudPathwaysArrayHandler( + read_timeout=read_timeout, + ), + override=True, + ) diff --git a/pathwaysutils/plugin_executable.py b/pathwaysutils/plugin_executable.py new file mode 100644 index 0000000..1f494c5 --- /dev/null +++ b/pathwaysutils/plugin_executable.py @@ -0,0 +1,65 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PluginExecutable is a class for executing plugin programs.""" + +import concurrent.futures +import threading +from typing import List, Sequence, Tuple + +import jax +from jax._src.interpreters import pxla +from jaxlib import xla_client + + +class PluginExecutable: + """Class for running compiled IFRT program over the IFRT Proxy.""" + + def __init__(self, prog_str: str): + ifrt_client = jax.local_devices()[0].client + program = xla_client.ifrt_programs.make_plugin_program(prog_str) + options = xla_client.ifrt_programs.make_plugin_compile_options() + self.compiled = ifrt_client.compile_ifrt_program(program, options) + + def call( + self, + in_arr: Sequence[List[jax.Array]] = (), + out_shardings: Sequence[jax.sharding.XLACompatibleSharding] = (), + out_avals: Sequence[jax.core.ShapedArray] = (), + out_committed: bool = True, + ) -> Tuple[Sequence[jax.Array], concurrent.futures.Future[None]]: + """Runs the compiled IFRT program and returns the result and a future.""" + results_with_token = self.compiled.execute_sharded(in_arr, with_tokens=True) + + out_arr = results_with_token.consume_with_handlers( + pxla.global_avals_to_results_handler( + out_avals, out_shardings, out_committed + ).handlers + ) + + out_fut = concurrent.futures.Future() + + def call_on_done(): + try: + results_with_token.consume_token().block_until_ready() + except Exception as e: # pylint: disable=broad-exception-caught + out_fut.set_exception(e) + return + out_fut.set_result(None) + + t = threading.Thread( + target=call_on_done, name="plugin_executable_call_on_done" + ) + t.start() + + return (out_arr, out_fut) diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py new file mode 100644 index 0000000..7fd2a16 --- /dev/null +++ b/pathwaysutils/profiling.py @@ -0,0 +1,183 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Profiling utilites.""" + +import dataclasses +import threading +import time + +from absl import logging +from fastapi import FastAPI +import jax +from jax import numpy as jnp +from pathwaysutils import plugin_executable +import uvicorn + +logging.set_verbosity(logging.INFO) + + +class _ProfileState: + def __init__(self): + self.executable = None + self.lock = threading.Lock() + + def reset(self): + self.executable = None + + +_profile_state = _ProfileState() +_original_start_trace = jax.profiler.start_trace +_original_stop_trace = jax.profiler.stop_trace + + +def toy_computation(): + """A toy computation to run before the first profile.""" + x = jax.jit(lambda x: x + 1)(jnp.array(1)) + x.block_until_ready() + + +def start_trace(gcs_bucket: str): + """Starts a profiler trace.""" + with _profile_state.lock: + if start_trace._first_profile_start: # pylint: disable=protected-access, attribute-error + start_trace._first_profile_start = False # pylint: disable=protected-access + toy_computation() + + if _profile_state.executable is not None: + raise ValueError( + "start_trace called while a trace is already being taken!" + ) + _profile_state.executable = plugin_executable.PluginExecutable( + f"{{profileRequest: {{traceLocation: '{gcs_bucket}'}}}}" + ) + try: + _profile_state.executable.call()[1].result() + except: + _profile_state.reset() + raise + + _original_start_trace(gcs_bucket) + + +start_trace._first_profile_start = True # pylint: disable=protected-access + + +def stop_trace(): + """Stops the currently-running profiler trace.""" + with _profile_state.lock: + if _profile_state.executable is None: + raise ValueError("stop_trace called before a trace is being taken!") + try: + _profile_state.executable.call()[1].result() + except: + _profile_state.reset() + raise + _profile_state.reset() + + _original_stop_trace() + + +_profiler_thread = None + + +def start_server(port: int): + """Starts the profiling server on port `port`. + + The signature is slightly different from `jax.profiler.start_server` + because no handle to the server is returned because there is no + `xla_client.profiler.ProfilerServer` to return. + + Args: + port : The port to start the server on. + """ + def server_loop(port: int): + logging.info("Starting JAX profiler server on port %s", port) + app = FastAPI() + + @dataclasses.dataclass + class ProfilingConfig: + duration_ms: int + repository_path: str + + @app.post("/profiling") + async def profiling(pc: ProfilingConfig): + logging.info("Capturing profiling data for %s ms", pc.duration_ms) + logging.info("Writing profiling data to %s", pc.repository_path) + jax.profiler.start_trace(pc.repository_path) + time.sleep(pc.duration_ms / 1e3) + jax.profiler.stop_trace() + return {"response": "profiling completed"} + + uvicorn.run(app, host="0.0.0.0", port=port, log_level="info") + + global _profiler_thread + if _profiler_thread is not None: + raise ValueError("Only one profiler server can be active at a time.") + + _profiler_thread = threading.Thread(target=server_loop, args=(port,)) + _profiler_thread.start() + + +def stop_server(): + """Raises an error if there is not an active profiler server but otherwise does nothing. + + Pathways profiling servers are not stoppable at this time. + """ + if _profiler_thread is None: + raise ValueError("No active profiler server.") + + +def monkey_patch_jax(): + """Monkey patches JAX with Pathways versions of functions. + + The signatures in patched functions should match the original. + + Patched functions are: + - `jax.profiler.start_trace` + https://jax.readthedocs.io/en/latest/_autosummary/jax.profiler.start_trace.html + - `jax.profiler.stop_trace` + https://jax.readthedocs.io/en/latest/_autosummary/jax.profiler.stop_trace.html + - `jax.profiler.start_server` + https://jax.readthedocs.io/en/latest/_autosummary/jax.profiler.start_server.html + - `jax.profiler.stop_server` + """ + + def start_trace_patch( + log_dir, + create_perfetto_link: bool = False, # pylint: disable=unused-argument + create_perfetto_trace: bool = False, # pylint: disable=unused-argument + ) -> None: + logging.info("jax.profile.start_trace patched with pathways' start_trace") + return start_trace(log_dir) + + jax.profiler.start_trace = start_trace_patch + + def stop_trace_patch() -> None: + logging.info("jax.profile.stop_trace patched with pathways' stop_trace") + return stop_trace() + + jax.profiler.stop_trace = stop_trace_patch + + def start_server_patch(port: int): + logging.info("jax.profile.start_server patched with pathways' start_server") + return start_server(port) + + jax.profiler.start_server = start_server_patch + + def stop_server_patch(): + logging.info("jax.profile.stop_server patched with pathways' stop_server") + return stop_server() + + jax.profiler.stop_server = stop_server_patch + diff --git a/pathwaysutils/proxy_backend.py b/pathwaysutils/proxy_backend.py new file mode 100644 index 0000000..1cac7f9 --- /dev/null +++ b/pathwaysutils/proxy_backend.py @@ -0,0 +1,29 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Register the IFRT Proxy as a backend for JAX.""" + +import jax +from jax._src import xla_bridge +from jaxlib.xla_extension import ifrt_proxy + + +def register_backend_factory(): + xla_bridge.register_backend_factory( + "proxy", + lambda: ifrt_proxy.get_client( + jax.config.read("jax_backend_target"), + ifrt_proxy.ClientConnectionOptions(), + ), + priority=-1, + ) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a40a00c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +jax[cpu]>=0.4.26 +absl-py +orbax +uvicorn +fastapi +google-cloud-logging diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..9b5844b --- /dev/null +++ b/setup.py @@ -0,0 +1,32 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Setup file for pathwaysutils.""" + +import setuptools + +setuptools.setup( + name='pathwaysutils', + version='0.0.5', + description='Pathways-on-Cloud utilities', + packages=setuptools.find_packages(), # Automatically find packages + # Add any dependencies your package needs here: + install_requires=[ + 'google-cloud-logging', + 'jax[cpu]>=0.4.26', + 'absl-py', + 'orbax', + 'uvicorn', + 'fastapi', + ], +)