From d308eb690530df95e8fb10dc7fb109d62181fc1e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 25 Nov 2023 21:55:13 -0800 Subject: [PATCH] Adds unique_id property to SparseCoreStackedTableTrackable PiperOrigin-RevId: 585365755 --- tensorflow/python/tpu/tpu_embedding_v3_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/python/tpu/tpu_embedding_v3_utils.py b/tensorflow/python/tpu/tpu_embedding_v3_utils.py index 08ed796c54b492..276731051be54f 100644 --- a/tensorflow/python/tpu/tpu_embedding_v3_utils.py +++ b/tensorflow/python/tpu/tpu_embedding_v3_utils.py @@ -170,6 +170,12 @@ def __init__(self, stacked_layouts, table_to_config): shape=variable_shape, dtype=dtypes.float32, ) + # TODO(b/312743130): This is a workaround. During checkpoint restoration + # optimizer expects the trackable to provide a `_unique_id` or equivalent. + # Remove this when the bug is fixed. + @property + def _unique_id(self): + return self.vars[self._stacked_layouts[0].table_name]._unique_id def _serialize_to_tensors(self) -> Any: return {