From b3a3075e591c7bdd49162a9c4f7b13d4a857e2b9 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 1 Oct 2024 14:30:39 -0700 Subject: [PATCH] Delete jax.lib.xla_client.execute_with_python_values. Nothing under jax.lib.xla_client is public, so there's no deprecation period required. PiperOrigin-RevId: 681166972 --- .../python/core/environments/xla_backend/runtime.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow_federated/python/core/environments/xla_backend/runtime.py b/tensorflow_federated/python/core/environments/xla_backend/runtime.py index adcd21a880..559fbe7e3f 100644 --- a/tensorflow_federated/python/core/environments/xla_backend/runtime.py +++ b/tensorflow_federated/python/core/environments/xla_backend/runtime.py @@ -15,6 +15,7 @@ from jax.lib import xla_client from jax.lib import xla_extension +import jax.numpy as jnp import numpy as np from tensorflow_federated.proto.v0 import computation_pb2 as pb @@ -166,12 +167,13 @@ def __call__(self, *args, **kwargs): flat_py_args = structure.flatten(positional_arg) reordered_flat_py_args = [ - flat_py_args[idx] for idx in self._inverted_parameter_tensor_indexes + jnp.asarray(flat_py_args[idx]) + for idx in self._inverted_parameter_tensor_indexes ] - unordered_result = xla_client.execute_with_python_values( - self._executable, reordered_flat_py_args, self._backend - ) + unordered_result = [ + np.asarray(x) for x in self._executable.execute(reordered_flat_py_args) + ] py_typecheck.check_type(unordered_result, list) result = [unordered_result[idx] for idx in self._result_tensor_indexes] result_type = self.type_signature.result