Skip to content

Commit

Permalink
Migrate tff.federated_computation to federated_language.federated_com…
Browse files Browse the repository at this point in the history
…putation.

...in .py code locations.

PiperOrigin-RevId: 707583133
  • Loading branch information
Chloé Kiddon authored and copybara-github committed Dec 23, 2024
1 parent 9337518 commit e05d869
Show file tree
Hide file tree
Showing 22 changed files with 160 additions and 78 deletions.
2 changes: 2 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ and this project adheres to
* Moved the tests of compatibility from `DPQuantileAggregator::MergeWith` to
`DPQuantileAggregator::IsCompatible`.
* Updated `MeasuredProcessOutput` to be a `NamedTuple`.
* Migrate tff.federated_computation to
federated_language.federated_computation in Python locations.

### Removed

Expand Down
5 changes: 4 additions & 1 deletion examples/program/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ licenses(["notice"])
py_library(
name = "computations",
srcs = ["computations.py"],
deps = ["//tensorflow_federated"],
deps = [
"//tensorflow_federated",
"@federated_language//federated_language",
],
)

py_library(
Expand Down
7 changes: 4 additions & 3 deletions examples/program/computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@

import collections

import federated_language
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

METRICS_TOTAL_SUM = 'total_sum'


@tff.federated_computation()
@federated_language.federated_computation()
def initialize():
"""Returns the initial state."""
return tff.federated_value(0, tff.SERVER)
Expand All @@ -43,7 +44,7 @@ def _sum_integers(x: int, y: int) -> int:
return x + y


@tff.federated_computation(
@federated_language.federated_computation(
tff.FederatedType(np.int32, tff.SERVER),
tff.FederatedType(tff.SequenceType(np.int32), tff.CLIENTS),
)
Expand Down Expand Up @@ -75,7 +76,7 @@ def train(server_state: int, client_data: tf.data.Dataset):
return updated_state, metrics


@tff.federated_computation(
@federated_language.federated_computation(
tff.FederatedType(np.int32, tff.SERVER),
tff.FederatedType(tff.SequenceType(np.int32), tff.CLIENTS),
)
Expand Down
1 change: 1 addition & 0 deletions examples/simple_fedavg/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ py_library(
deps = [
":simple_fedavg_tf",
"//tensorflow_federated",
"@federated_language//federated_language",
],
)

Expand Down
7 changes: 4 additions & 3 deletions examples/simple_fedavg/simple_fedavg_tff.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
https://arxiv.org/abs/1602.05629
"""

import federated_language
import tensorflow as tf
import tensorflow_federated as tff

Expand Down Expand Up @@ -95,7 +96,7 @@ def initialize_client_data(server_message):
model = model_fn()
return init_client_ouput(model, server_message)

@tff.federated_computation(
@federated_language.federated_computation(
server_message_type, tff.SequenceType(batch_type)
)
def client_update_weights_fn(server_message, batches):
Expand Down Expand Up @@ -169,7 +170,7 @@ def server_message_fn(server_state):
federated_server_state_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

@tff.federated_computation(
@federated_language.federated_computation(
federated_server_state_type, federated_dataset_type
)
def run_one_round(server_state, federated_dataset):
Expand Down Expand Up @@ -215,7 +216,7 @@ def run_one_round(server_state, federated_dataset):
)
return server_state, aggregated_outputs

@tff.federated_computation
@federated_language.federated_computation
def server_init_tff():
"""Orchestration logic for server model initialization."""
return tff.federated_eval(server_init_tf, tff.SERVER)
Expand Down
1 change: 1 addition & 0 deletions examples/stateful_clients/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ py_library(
deps = [
":stateful_fedavg_tf",
"//tensorflow_federated",
"@federated_language//federated_language",
],
)

Expand Down
5 changes: 3 additions & 2 deletions examples/stateful_clients/stateful_fedavg_tff.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
https://arxiv.org/abs/1602.05629
"""

import federated_language
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
Expand Down Expand Up @@ -132,7 +133,7 @@ def client_update_fn(tf_dataset, client_state, server_message):
client_state_type, tff.CLIENTS
)

@tff.federated_computation(
@federated_language.federated_computation(
federated_server_state_type,
federated_dataset_type,
federated_client_state_type,
Expand Down Expand Up @@ -173,7 +174,7 @@ def run_one_round(server_state, federated_dataset, client_states):

return server_state, round_loss_metric, client_outputs.client_state

@tff.federated_computation
@federated_language.federated_computation
def server_init_tff():
"""Orchestration logic for server model initialization."""
return tff.federated_value(server_init_tf(), tff.SERVER)
Expand Down
6 changes: 4 additions & 2 deletions tensorflow_federated/python/aggregators/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,7 +1254,8 @@ def _build_test_sum_fn_py_bounds(value_type, lower_bound, upper_bound):
upper_bound: A Python numeric constant or a numpy array.
Returns:
A `tff.federated_computation` with type signature `(value_type@CLIENTS ->
A `federated_language.federated_computation` with type signature
`(value_type@CLIENTS ->
value_type@SERVER)`.
"""

Expand Down Expand Up @@ -1285,7 +1286,8 @@ def _build_test_sum_fn_tff_bounds(
upper_bound_type: A `tff.Type` of upper_bound to be used.
Returns:
A `tff.federated_computation` with type signature `((value_type@CLIENTS,
A `federated_language.federated_computation` with type signature
`((value_type@CLIENTS,
lower_bound_type@SERVER, upper_bound_type@SERVER) -> value_type@SERVER)`.
"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
`MapReduceForm` maps to a single federated round.
```python
@tff.federated_computation
@federated_language.federated_computation
def round_comp(server_state, client_data):
# The server prepares an input to be broadcast to all clients that controls
Expand Down Expand Up @@ -208,7 +208,7 @@ def round_comp(server_state, client_data):
single federated round.
```python
@tff.federated_computation
@federated_language.federated_computation
def round_comp(server_state, client_data):
# The server prepares an input to be broadcast to all clients and generates
# a temporary state that may be used by later parts of the computation.
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_federated/python/core/backends/mapreduce/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def _check_tensorflow_computation(label, comp):


def _check_lambda_computation(label, comp):
"""Validates a Lambda computation."""
py_typecheck.check_type(
comp, federated_language.framework.ConcreteComputation, label
)
Expand Down Expand Up @@ -131,7 +132,7 @@ class BroadcastForm:
```
server_data_type = self.compute_server_context.type_signature.parameter
client_data_type = self.client_processing.type_signature.parameter[1]
@tff.federated_computation(server_data_type, client_data_type)
@federated_language.federated_computation(server_data_type, client_data_type)
def _(server_data, client_data):
# Select out the bit of server context to send to the clients.
context_at_server = tff.federated_map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class MergeableCompForm:
without changing the results of the computation, a computation of the form:
```
@tff.federated_computation(...)
@federated_language.federated_computation(...)
def single_aggregation(arg):
result_at_clients = work(arg)
agg_result = tff.federated_aggregate(
Expand All @@ -78,19 +78,19 @@ def single_aggregation(arg):
can be represented as the `MergeableCompForm` triplet:
```
@tff.federated_computation(tff.AbstractType('T'))
@federated_language.federated_computation(tff.AbstractType('T'))
def up_to_merge(arg):
result_at_clients = work(arg)
agg_result = tff.federated_aggregate(
result_at_clients, accumulate_zero, accumulate, merge, identity_report)
return agg_result
@tff.federated_computation([up_to_merge.type_signature.result.member,
@federated_language.federated_computation([up_to_merge.type_signature.result.member,
up_to_merge.type_signature.result.member])
def merge(arg):
return merge(arg[0], arg[1])
@tff.federated_computation(
@federated_language.federated_computation(
tff.AbstractType('T'),
tff.FederatedType(merge.type_signature.result, tff.SERVER),
)
Expand Down
9 changes: 6 additions & 3 deletions tensorflow_federated/python/learning/metrics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def sum_then_finalize(
the polymorphic method is invoked on.
Note: invoking this computation outside of a federated context (a method
decorated with `tff.federated_computation`) will require first wrapping it in
decorated with `federated_language.federated_computation`) will require first
wrapping it in
a concrete, non-polymorphic `tff.Computation` with appropriate federated
types.
Expand Down Expand Up @@ -147,7 +148,8 @@ def secure_sum_then_finalize(
the polymorphic method is invoked on.
Note: invoking this computation outside of a federated context (a method
decorated with `tff.federated_computation`) will require first wrapping it in
decorated with `federated_language.federated_computation`) will require first
wrapping it in
a concrete, non-polymorphic `tff.Computation` with appropriate federated
types.
Expand Down Expand Up @@ -317,7 +319,8 @@ def finalize_then_sample(
the polymorphic method is invoked on.
Note: invoking this computation outside of a federated context (a method
decoratedc with `tff.federated_computation`) will require first wrapping it in
decoratedc with `federated_language.federated_computation`) will require first
wrapping it in
a concrete, non-polymorphic `tff.Computation` with appropriate federated
types.
Expand Down
36 changes: 29 additions & 7 deletions tensorflow_federated/python/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@ py_test(
name = "ast_generation_test",
size = "small",
srcs = ["ast_generation_test.py"],
deps = ["//tensorflow_federated"],
deps = [
"//tensorflow_federated",
"@federated_language//federated_language",
],
)

py_test(
name = "async_execution_context_integration_test",
srcs = ["async_execution_context_integration_test.py"],
deps = ["//tensorflow_federated"],
deps = [
"//tensorflow_federated",
"@federated_language//federated_language",
],
)

py_test(
Expand All @@ -34,13 +40,17 @@ py_test(
":temperature_sensor_example",
":test_contexts",
"//tensorflow_federated",
"@federated_language//federated_language",
],
)

py_cpu_gpu_test(
name = "backend_accelerators_test",
srcs = ["backend_accelerators_test.py"],
deps = ["//tensorflow_federated"],
deps = [
"//tensorflow_federated",
"@federated_language//federated_language",
],
)

py_test(
Expand All @@ -55,7 +65,10 @@ py_test(
size = "small",
timeout = "long",
srcs = ["mergeable_comp_execution_context_integration_test.py"],
deps = ["//tensorflow_federated"],
deps = [
"//tensorflow_federated",
"@federated_language//federated_language",
],
)

py_test(
Expand All @@ -73,20 +86,29 @@ py_test(
timeout = "moderate",
srcs = ["remote_runtime_stream_structs_test.py"],
tags = ["requires-mem:20g"],
deps = ["//tensorflow_federated"],
deps = [
"//tensorflow_federated",
"@federated_language//federated_language",
],
)

py_test(
name = "sync_local_cpp_execution_context_test",
size = "small",
srcs = ["sync_local_cpp_execution_context_test.py"],
deps = ["//tensorflow_federated"],
deps = [
"//tensorflow_federated",
"@federated_language//federated_language",
],
)

py_library(
name = "temperature_sensor_example",
srcs = ["temperature_sensor_example.py"],
deps = ["//tensorflow_federated"],
deps = [
"//tensorflow_federated",
"@federated_language//federated_language",
],
)

py_library(
Expand Down
5 changes: 3 additions & 2 deletions tensorflow_federated/python/tests/ast_generation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from absl.testing import absltest
import federated_language
import tensorflow as tf
import tensorflow_federated as tff

Expand All @@ -36,7 +37,7 @@ def test_flattens_to_tf_computation(self):
def five():
return 5

@tff.federated_computation
@federated_language.federated_computation
def federated_five():
return five()

Expand All @@ -56,7 +57,7 @@ def test_only_one_random_only_generates_a_single_call_to_random(self):
def rand():
return tf.random.normal([])

@tff.federated_computation
@federated_language.federated_computation
def same_rand_tuple():
single_random_number = rand()
return (single_random_number, single_random_number)
Expand Down
Loading

0 comments on commit e05d869

Please sign in to comment.