Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Strip serialization from base #11

Merged
merged 1 commit into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions src/invrs_opt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import dataclasses
from typing import Any, Protocol

from totypes import json_utils

PyTree = Any


Expand Down Expand Up @@ -46,23 +44,3 @@ class Optimizer:
init: InitFn
params: ParamsFn
update: UpdateFn


# Additional custom types and prefixes used for serializing optimizer state.
CUSTOM_TYPES_AND_PREFIXES = ()


def serialize(tree: PyTree) -> str:
"""Serializes a pytree into a string."""
return json_utils.json_from_pytree(
tree,
extra_custom_types_and_prefixes=CUSTOM_TYPES_AND_PREFIXES,
)


def deserialize(serialized: str) -> PyTree:
"""Restores a pytree from a string."""
return json_utils.pytree_from_json(
serialized,
extra_custom_types_and_prefixes=CUSTOM_TYPES_AND_PREFIXES,
)
16 changes: 12 additions & 4 deletions tests/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import parameterized

import invrs_opt
from totypes import symmetry, types
from totypes import json_utils, symmetry, types

jax.config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -153,14 +153,22 @@ def _lists_to_tuple(pytree, max_depth=10):
return pytree


def serialize(pytree) -> str:
return json_utils.json_from_pytree(pytree=pytree)


def deserialize(serialized):
return json_utils.pytree_from_json(serialized=serialized)


class BasicOptimizerTest(unittest.TestCase):
@parameterized.parameterized.expand(itertools.product(PARAMS, OPTIMIZERS))
def test_state_is_serializable(self, params, opt):
state = opt.init(params)
leaves, treedef = jax.tree_util.tree_flatten(state)

serialized_state = invrs_opt.base.serialize(state)
restored_state = invrs_opt.base.deserialize(serialized_state)
serialized_state = serialize(state)
restored_state = deserialize(serialized_state)
# Serialization/deserialization unavoidably converts tuples to lists.
# Convert back to tuples to facilitate comparison.
restored_state = _lists_to_tuple(restored_state)
Expand Down Expand Up @@ -211,7 +219,7 @@ def loss_fn(params):
expected_grad_list.append(grad)

def serdes(x):
return invrs_opt.base.deserialize(invrs_opt.base.serialize(x))
return deserialize(serialize(x))

# Optimize with serialization.
params_list = []
Expand Down