
optim(rng): Use binarized formats for de/serialization #92

Merged · 5 commits into master · May 4, 2024

Conversation

@eddiebergman (Contributor) commented May 4, 2024

This PR does three things:

  • Most importantly, save and load the RNG state of python/numpy/torch as binary files. This saves time, as stated in [Optim] Consider pickle for optimizer state file in run (with option to toggle) #64 (see the sketch after this list).
    • We use .npy for the python and numpy seed keys, which works well across versions, and we disable pickle loading so no pickle injections are possible.
    • We use .pt for the torch/CUDA RNG state and likewise disable arbitrary pickle injections.
  • Removes the need for optimizers to know about random state; the neps.runtime manages it for them.
  • Drops the shared-state polling time for the lock from 1 second to 0.1 seconds. I imagine this was previously high due to issues like the 16 == 15 issue in Metahyper sampling extra config #42, but I do not know of a concrete reason to have it as high as 1 second.
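
As a rough illustration of the first bullet point, here is a minimal sketch of how RNG state can be persisted in binary form with pickle loading disabled. The function and file names are illustrative only; this is not the actual SeedState implementation, which is exercised in the test below.

import random
from pathlib import Path

import numpy as np
import torch

def dump_rng_state(directory: Path) -> None:
    directory.mkdir(parents=True, exist_ok=True)

    # random.getstate() -> (version, 625-int Mersenne Twister state, gauss_next);
    # the cached gaussian is dropped here for brevity.
    py_version, py_keys, _ = random.getstate()
    np.save(directory / "py_rng.npy", np.array([py_version, *py_keys], dtype=np.int64))

    # np.random.get_state() -> ('MT19937', keys, pos, has_gauss, cached_gaussian)
    _, np_keys, np_pos, _, _ = np.random.get_state()
    np.save(directory / "np_rng.npy", np.concatenate([np_keys, [np_pos]]))

    # torch.get_rng_state() is a plain ByteTensor, so .pt stores it compactly.
    torch.save(torch.get_rng_state(), directory / "torch_rng.pt")

def load_rng_state(directory: Path) -> None:
    # allow_pickle=False means only raw arrays can come back out of the .npy files.
    py_state = np.load(directory / "py_rng.npy", allow_pickle=False)
    random.setstate((int(py_state[0]), tuple(int(x) for x in py_state[1:]), None))

    np_state = np.load(directory / "np_rng.npy", allow_pickle=False)
    np.random.set_state(("MT19937", np_state[:-1].astype(np.uint32), int(np_state[-1]), 0, 0.0))

    # weights_only=True stops torch.load from unpickling arbitrary objects.
    torch.set_rng_state(torch.load(directory / "torch_rng.pt", weights_only=True))

CUDA generator state (torch.cuda.get_rng_state_all()) can be handled the same way through torch.save/torch.load.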

Impact

With the time.sleep(2) removed from neps_examples/basic_usage/hyperparameters.py, this change reduced the runtime from 9.3 seconds to 3.9 seconds on my machine. Half of the program duration was previously spent just serializing and deserializing random state.

I'm hoping this also halves the time taken to run the tests, meaning we could just run them all locally instead of having to deal with marked tests.


This is the new test file; previously there was no test that serialization actually worked as intended:

import random
from pathlib import Path
from typing import Callable

import numpy as np
import pytest
import torch

# SeedState is the PR's helper for capturing and restoring global RNG state
# (imported from the neps codebase in the actual test file).

@pytest.mark.parametrize(
    "make_ints", (
        lambda: [random.randint(0, 100) for _ in range(10)],
        lambda: list(np.random.randint(0, 100, (10,))),
        lambda: list(torch.randint(0, 100, (10,))),
    )
)
def test_randomstate_consistent(tmp_path: Path, make_ints: Callable[[], list[int]]) -> None:
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    seed_dir = tmp_path / "seed_dir"

    seed_state = SeedState.get()
    integers_1 = make_ints()

    seed_state.set_as_global_state()
    integers_2 = make_ints()

    assert integers_1 == integers_2

    SeedState.get().dump(seed_dir)
    integers_3 = make_ints()

    assert integers_3 != integers_2, "Ensure we have actually changed random state"

    SeedState.load(seed_dir).set_as_global_state()
    integers_4 = make_ints()

    assert integers_3 == integers_4
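
To illustrate the second bullet point above (optimizers no longer needing to know about random state), here is a hypothetical sketch of the pattern; sample_with_shared_seed_state and optimizer.ask are illustrative names rather than the actual neps.runtime API, but the SeedState calls match the test above.

from pathlib import Path

def sample_with_shared_seed_state(optimizer, seed_dir: Path):
    # Restore the shared global RNG state saved by the previous worker, if any.
    if seed_dir.exists():
        SeedState.load(seed_dir).set_as_global_state()

    # The optimizer can freely use random/np.random/torch without ever
    # touching the seed files itself.
    config = optimizer.ask()

    # Persist the advanced state so the next sampling call continues from here.
    SeedState.get().dump(seed_dir)
    return config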

@eddiebergman merged commit 5bdeced into master on May 4, 2024
12 checks passed
@eddiebergman deleted the optim-serialize-randomstate branch on May 4, 2024 at 22:07
Successfully merging this pull request may close this issue: [Optim] Consider pickle for optimizer state file in run (with option to toggle) #64