Skip to content

Commit

Permalink
Migrate tff.federated_computation to federated_language.federated_com…
Browse files Browse the repository at this point in the history
…putation.

...in .py code locations.

This API has moved to the new federated_language project.

PiperOrigin-RevId: 707583133
  • Loading branch information
Chloé Kiddon authored and copybara-github committed Jan 3, 2025
1 parent bf3843c commit 342ce81
Show file tree
Hide file tree
Showing 18 changed files with 88 additions and 71 deletions.
6 changes: 3 additions & 3 deletions examples/program/computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
METRICS_TOTAL_SUM = 'total_sum'


@tff.federated_computation()
@federated_language.federated_computation()
def initialize():
"""Returns the initial state."""
return tff.federated_value(0, federated_language.SERVER)
Expand All @@ -44,7 +44,7 @@ def _sum_integers(x: int, y: int) -> int:
return x + y


@tff.federated_computation(
@federated_language.federated_computation(
federated_language.FederatedType(np.int32, federated_language.SERVER),
federated_language.FederatedType(
federated_language.SequenceType(np.int32), federated_language.CLIENTS
Expand Down Expand Up @@ -78,7 +78,7 @@ def train(server_state: int, client_data: tf.data.Dataset):
return updated_state, metrics


@tff.federated_computation(
@federated_language.federated_computation(
federated_language.FederatedType(np.int32, federated_language.SERVER),
federated_language.FederatedType(
federated_language.SequenceType(np.int32), federated_language.CLIENTS
Expand Down
6 changes: 3 additions & 3 deletions examples/simple_fedavg/simple_fedavg_tff.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def initialize_client_data(server_message):
model = model_fn()
return init_client_ouput(model, server_message)

@tff.federated_computation(
@federated_language.federated_computation(
server_message_type, federated_language.SequenceType(batch_type)
)
def client_update_weights_fn(server_message, batches):
Expand Down Expand Up @@ -174,7 +174,7 @@ def server_message_fn(server_state):
tf_dataset_type, federated_language.CLIENTS
)

@tff.federated_computation(
@federated_language.federated_computation(
federated_server_state_type, federated_dataset_type
)
def run_one_round(server_state, federated_dataset):
Expand Down Expand Up @@ -221,7 +221,7 @@ def run_one_round(server_state, federated_dataset):
)
return server_state, aggregated_outputs

@tff.federated_computation
@federated_language.federated_computation
def server_init_tff():
"""Orchestration logic for server model initialization."""
return tff.federated_eval(server_init_tf, federated_language.SERVER)
Expand Down
4 changes: 2 additions & 2 deletions examples/stateful_clients/stateful_fedavg_tff.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def client_update_fn(tf_dataset, client_state, server_message):
client_state_type, federated_language.CLIENTS
)

@tff.federated_computation(
@federated_language.federated_computation(
federated_server_state_type,
federated_dataset_type,
federated_client_state_type,
Expand Down Expand Up @@ -178,7 +178,7 @@ def run_one_round(server_state, federated_dataset, client_states):

return server_state, round_loss_metric, client_outputs.client_state

@tff.federated_computation
@federated_language.federated_computation
def server_init_tff():
"""Orchestration logic for server model initialization."""
return tff.federated_value(server_init_tf(), federated_language.SERVER)
Expand Down
6 changes: 4 additions & 2 deletions tensorflow_federated/python/aggregators/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,7 +1254,8 @@ def _build_test_sum_fn_py_bounds(value_type, lower_bound, upper_bound):
upper_bound: A Python numeric constant or a numpy array.
Returns:
A `tff.federated_computation` with type signature `(value_type@CLIENTS ->
A `federated_language.federated_computation` with type signature
`(value_type@CLIENTS ->
value_type@SERVER)`.
"""

Expand Down Expand Up @@ -1286,7 +1287,8 @@ def _build_test_sum_fn_tff_bounds(
upper_bound_type: A `federated_language.Type` of upper_bound to be used.
Returns:
A `tff.federated_computation` with type signature `((value_type@CLIENTS,
A `federated_language.federated_computation` with type signature
`((value_type@CLIENTS,
lower_bound_type@SERVER, upper_bound_type@SERVER) -> value_type@SERVER)`.
"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
`MapReduceForm` maps to a single federated round.
```python
@tff.federated_computation
@federated_language.federated_computation
def round_comp(server_state, client_data):
# The server prepares an input to be broadcast to all clients that controls
Expand Down Expand Up @@ -208,7 +208,7 @@ def round_comp(server_state, client_data):
single federated round.
```python
@tff.federated_computation
@federated_language.federated_computation
def round_comp(server_state, client_data):
# The server prepares an input to be broadcast to all clients and generates
# a temporary state that may be used by later parts of the computation.
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_federated/python/core/backends/mapreduce/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def _check_tensorflow_computation(label, comp):


def _check_lambda_computation(label, comp):
"""Validates a Lambda computation."""
py_typecheck.check_type(
comp, federated_language.framework.ConcreteComputation, label
)
Expand Down Expand Up @@ -131,7 +132,7 @@ class BroadcastForm:
```
server_data_type = self.compute_server_context.type_signature.parameter
client_data_type = self.client_processing.type_signature.parameter[1]
@tff.federated_computation(server_data_type, client_data_type)
@federated_language.federated_computation(server_data_type, client_data_type)
def _(server_data, client_data):
# Select out the bit of server context to send to the clients.
context_at_server = tff.federated_map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class MergeableCompForm:
without changing the results of the computation, a computation of the form:
```
@tff.federated_computation(...)
@federated_language.federated_computation(...)
def single_aggregation(arg):
result_at_clients = work(arg)
agg_result = tff.federated_aggregate(
Expand All @@ -80,19 +80,19 @@ def single_aggregation(arg):
can be represented as the `MergeableCompForm` triplet:
```
@tff.federated_computation(federated_language.AbstractType('T'))
@federated_language.federated_computation(federated_language.AbstractType('T'))
def up_to_merge(arg):
result_at_clients = work(arg)
agg_result = tff.federated_aggregate(
result_at_clients, accumulate_zero, accumulate, merge, identity_report)
return agg_result
@tff.federated_computation([up_to_merge.type_signature.result.member,
@federated_language.federated_computation([up_to_merge.type_signature.result.member,
up_to_merge.type_signature.result.member])
def merge(arg):
return merge(arg[0], arg[1])
@tff.federated_computation(
@federated_language.federated_computation(
federated_language.AbstractType('T'),
federated_language.FederatedType(merge.type_signature.result,
federated_language.SERVER),
Expand Down
9 changes: 6 additions & 3 deletions tensorflow_federated/python/learning/metrics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def sum_then_finalize(
the polymorphic method is invoked on.
Note: invoking this computation outside of a federated context (a method
decorated with `tff.federated_computation`) will require first wrapping it in
decorated with `federated_language.federated_computation`) will require first
wrapping it in
a concrete, non-polymorphic `federated_language.Computation` with appropriate
federated
types.
Expand Down Expand Up @@ -148,7 +149,8 @@ def secure_sum_then_finalize(
the polymorphic method is invoked on.
Note: invoking this computation outside of a federated context (a method
decorated with `tff.federated_computation`) will require first wrapping it in
decorated with `federated_language.federated_computation`) will require first
wrapping it in
a concrete, non-polymorphic `federated_language.Computation` with appropriate
federated
types.
Expand Down Expand Up @@ -319,7 +321,8 @@ def finalize_then_sample(
the polymorphic method is invoked on.
Note: invoking this computation outside of a federated context (a method
decoratedc with `tff.federated_computation`) will require first wrapping it in
decoratedc with `federated_language.federated_computation`) will require first
wrapping it in
a concrete, non-polymorphic `federated_language.Computation` with appropriate
federated
types.
Expand Down
5 changes: 4 additions & 1 deletion tensorflow_federated/python/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ py_test(
name = "ast_generation_test",
size = "small",
srcs = ["ast_generation_test.py"],
deps = ["//tensorflow_federated"],
deps = [
"//tensorflow_federated",
"@federated_language//federated_language",
],
)

py_test(
Expand Down
5 changes: 3 additions & 2 deletions tensorflow_federated/python/tests/ast_generation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from absl.testing import absltest
import federated_language
import tensorflow as tf
import tensorflow_federated as tff

Expand All @@ -36,7 +37,7 @@ def test_flattens_to_tf_computation(self):
def five():
return 5

@tff.federated_computation
@federated_language.federated_computation
def federated_five():
return five()

Expand All @@ -56,7 +57,7 @@ def test_only_one_random_only_generates_a_single_call_to_random(self):
def rand():
return tf.random.normal([])

@tff.federated_computation
@federated_language.federated_computation
def same_rand_tuple():
single_random_number = rand()
return (single_random_number, single_random_number)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async def test_install_and_execute_in_context(self):
factory = tff.framework.local_cpp_executor_factory()
context = tff.framework.AsyncExecutionContext(factory)

@tff.federated_computation(np.int32)
@federated_language.federated_computation(np.int32)
def identity(x):
return x

Expand All @@ -42,7 +42,7 @@ async def test_install_and_execute_computations_with_different_cardinalities(
factory = tff.framework.local_cpp_executor_factory()
context = tff.framework.AsyncExecutionContext(factory)

@tff.federated_computation(
@federated_language.federated_computation(
federated_language.FederatedType(np.int32, federated_language.CLIENTS)
)
def repackage_arg(x):
Expand All @@ -60,7 +60,7 @@ async def test_runs_cardinality_free(self):
factory, cardinality_inference_fn=(lambda x, y: {})
)

@tff.federated_computation(np.int32)
@federated_language.federated_computation(np.int32)
def identity(x):
return x

Expand All @@ -85,7 +85,7 @@ def _cardinality_fn(x, y):
np.int32, federated_language.CLIENTS
)

@tff.federated_computation(arg_type)
@federated_language.federated_computation(arg_type)
def identity(x):
return x

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def dataset_reduce_fn_wrapper(ds):
initial_val = tf.Variable(np.int64(1.0))
return dataset_reduce_fn(ds, initial_val)

@tff.federated_computation(
@federated_language.federated_computation(
tff.at_clients(federated_language.SequenceType(np.int64))
)
def parallel_client_run(client_datasets):
Expand All @@ -59,7 +59,7 @@ def dataset_reduce_fn_wrapper(ds):
initial_val = tf.Variable(np.int64(1.0))
return dataset_reduce_fn(ds, initial_val)

@tff.federated_computation(
@federated_language.federated_computation(
tff.at_clients(federated_language.SequenceType(np.int64))
)
def parallel_client_run(client_datasets):
Expand Down
Loading

0 comments on commit 342ce81

Please sign in to comment.