Skip to content

Commit

Permalink
Make ID generation thread-safe (#1472)
Browse files Browse the repository at this point in the history
Resolves #1473 

While there is currently a lock for ID generation, it is not thread-safe
as the start value of ID values can be modified by different threads.
e.g. thread A starts generating IDs, thread B resets the ID generation
start value, thread A repeats IDs.

This PR:
* Adds a test to verify consistent ID generation for SQL with multiple
threads.
* Changes the ID generator to use thread-local state.
* Removes patching of state and replaces it with an explicit stack.
* Assorted organizational / naming changes.
  • Loading branch information
plypaul authored Oct 21, 2024
1 parent c050718 commit c0589ef
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 36 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20241021-120748.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Make ID generation thread-safe
time: 2024-10-21T12:07:48.313324-07:00
custom:
Author: plypaul
Issue: "1473"
94 changes: 64 additions & 30 deletions metricflow-semantics/metricflow_semantics/dag/sequential_id.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from __future__ import annotations

import dataclasses
import logging
import threading
from contextlib import ExitStack, contextmanager
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, Generator
from unittest.mock import patch

from typing_extensions import override

from metricflow_semantics.dag.id_prefix import IdPrefix

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class SequentialId:
Expand All @@ -27,46 +30,77 @@ def __repr__(self) -> str:
return self.str_value


@dataclass
class _IdGenerationState:
"""A thread-local class that keeps track of the next IDs to return for a given prefix."""

default_start_value: int
prefix_to_next_value: Dict[IdPrefix, int] = dataclasses.field(default_factory=dict)


class _IdGenerationStateStack(threading.local):
"""A thread-local stack that keeps track of the state of ID generation.
The stack allows the use of context managers to enter sections where ID generation starts at a configured value.
When entering the section, a new `_IdGenerationState` is pushed on to the stack, and the state at the top of the
stack is used to generate IDs. When exiting a section, the state is popped off so that ID generation resumes from
the previous values.
This stack is thread-local so that ID generation is consistent for a given thread.
"""

def __init__(self, initial_default_start_value: int = 0) -> None: # noqa: D107
self._state_stack = [_IdGenerationState(initial_default_start_value)]

def push_state(self, id_generation_state: _IdGenerationState) -> None:
self._state_stack.append(id_generation_state)

def pop_state(self) -> None:
state_stack_size = len(self._state_stack)
if state_stack_size <= 1:
logger.error(
f"Attempted to pop the stack when {state_stack_size=}. Since sequential ID generation may not "
f"be absolutely critical for resolving queries, logging this as an error but it should be "
f"investigated.",
stack_info=True,
)
return

self._state_stack.pop(-1)

@property
def current_state(self) -> _IdGenerationState:
return self._state_stack[-1]


class SequentialIdGenerator:
"""Generates sequential ID values based on a prefix."""

_default_start_value = 0
_state_lock = threading.Lock()
_prefix_to_next_value: Dict[IdPrefix, int] = {}
_THREAD_LOCAL_ID_GENERATION_STATE_STACK = _IdGenerationStateStack()

@classmethod
def create_next_id(cls, id_prefix: IdPrefix) -> SequentialId: # noqa: D102
with cls._state_lock:
if id_prefix not in cls._prefix_to_next_value:
cls._prefix_to_next_value[id_prefix] = cls._default_start_value
index = cls._prefix_to_next_value[id_prefix]
cls._prefix_to_next_value[id_prefix] = index + 1

return SequentialId(id_prefix, index)
id_generation_state = cls._THREAD_LOCAL_ID_GENERATION_STATE_STACK.current_state
if id_prefix not in id_generation_state.prefix_to_next_value:
id_generation_state.prefix_to_next_value[id_prefix] = id_generation_state.default_start_value
index = id_generation_state.prefix_to_next_value[id_prefix]
id_generation_state.prefix_to_next_value[id_prefix] = index + 1
return SequentialId(id_prefix, index)

@classmethod
def reset(cls, default_start_value: int = 0) -> None:
"""Resets the numbering of the generated IDs so that it starts at the given value."""
with cls._state_lock:
cls._prefix_to_next_value = {}
cls._default_start_value = default_start_value
id_generation_state = cls._THREAD_LOCAL_ID_GENERATION_STATE_STACK.current_state
id_generation_state.prefix_to_next_value = {}
id_generation_state.default_start_value = default_start_value

@classmethod
@contextmanager
def patch_id_generators_helper(cls, start_value: int) -> Generator[None, None, None]:
"""Replace ID generators in IdGeneratorRegistry with one that has the given start value.
def id_number_space(cls, start_value: int) -> Generator[None, None, None]:
"""Open a context where ID generation starts with the given start value.
TODO: This method will be modified in a later PR.
On exit, resume ID numbering from prior to entering the context.
"""
# Create patch context managers for all ID generators in the registry with introspection magic.
patch_context_managers = [
patch.object(SequentialIdGenerator, "_prefix_to_next_value", {}),
patch.object(SequentialIdGenerator, "_default_start_value", start_value),
]

# Enter the patch context for the patches above.
with ExitStack() as stack:
for patch_context_manager in patch_context_managers:
stack.enter_context(patch_context_manager) # type: ignore
# This will un-patch when done with the test.
yield None
SequentialIdGenerator._THREAD_LOCAL_ID_GENERATION_STATE_STACK.push_state(_IdGenerationState(start_value))
yield None
SequentialIdGenerator._THREAD_LOCAL_ID_GENERATION_STATE_STACK.pop_state()
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@


@pytest.fixture(autouse=True, scope="function")
def patch_id_generators() -> Generator[None, None, None]:
"""Patch ID generators with a new one to get repeatability in plan outputs before every test.
def setup_id_generators() -> Generator[None, None, None]:
"""Setup ID generation to start numbering at a specific value to get repeatability in generated IDs.
Plan outputs contain IDs, so if the IDs are not consistent from run to run, there will be diffs in the actual vs.
expected outputs during a test.
Fixtures may generate IDs, so this needs to be done before every test.
"""
with SequentialIdGenerator.patch_id_generators_helper(start_value=IdNumberSpace.for_test_start().start_value):
with SequentialIdGenerator.id_number_space(start_value=IdNumberSpace.for_test_start().start_value):
yield None


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# These imports are required to properly set up pytest fixtures.
from __future__ import annotations

from metricflow_semantics.test_helpers.id_helpers import patch_id_generators # noqa: F401
from metricflow_semantics.test_helpers.id_helpers import setup_id_generators # noqa: F401

from tests_metricflow_semantics.fixtures.manifest_fixtures import * # noqa: F401, F403
from tests_metricflow_semantics.fixtures.setup_fixtures import * # noqa: F401, F403
2 changes: 1 addition & 1 deletion tests_metricflow/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# These imports are required to properly set up pytest fixtures.
from __future__ import annotations

from metricflow_semantics.test_helpers.id_helpers import patch_id_generators # noqa: F401
from metricflow_semantics.test_helpers.id_helpers import setup_id_generators # noqa: F401

from tests_metricflow.fixtures.cli_fixtures import * # noqa: F401, F403
from tests_metricflow.fixtures.dataflow_fixtures import * # noqa: F401, F403
Expand Down
Empty file.
31 changes: 31 additions & 0 deletions tests_metricflow/engine/test_explain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Mapping, Sequence

from metricflow.engine.metricflow_engine import MetricFlowEngine, MetricFlowExplainResult, MetricFlowQueryRequest
from tests_metricflow.fixtures.manifest_fixtures import MetricFlowEngineTestFixture, SemanticManifestSetup


def _explain_one_query(mf_engine: MetricFlowEngine) -> str:
explain_result: MetricFlowExplainResult = mf_engine.explain(
MetricFlowQueryRequest.create_with_random_request_id(saved_query_name="p0_booking")
)
return explain_result.rendered_sql.sql_query


def test_concurrent_explain_consistency(
mf_engine_test_fixture_mapping: Mapping[SemanticManifestSetup, MetricFlowEngineTestFixture]
) -> None:
"""Tests that concurrent requests for the same query generate the same SQL.
Prior to consistency fixes for ID generation, this test would fail due to issues with sequentially numbered aliases.
"""
mf_engine = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].metricflow_engine

request_count = 4
with ThreadPoolExecutor(max_workers=2) as executor:
futures: Sequence[Future] = [executor.submit(_explain_one_query, mf_engine) for _ in range(request_count)]
results = [future.result() for future in futures]
for result in results:
assert result == results[0], "Expected only one unique result / results to be the same"
2 changes: 1 addition & 1 deletion tests_metricflow/fixtures/manifest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def mf_engine_test_fixture_mapping(
"""Returns a mapping for all semantic manifests used in testing to the associated test fixture."""
fixture_mapping: Dict[SemanticManifestSetup, MetricFlowEngineTestFixture] = {}
for semantic_manifest_setup in SemanticManifestSetup:
with SequentialIdGenerator.patch_id_generators_helper(semantic_manifest_setup.id_number_space.start_value):
with SequentialIdGenerator.id_number_space(semantic_manifest_setup.id_number_space.start_value):
fixture_mapping[semantic_manifest_setup] = MetricFlowEngineTestFixture.from_parameters(
sql_client, load_semantic_manifest(semantic_manifest_setup.yaml_file_dir, template_mapping)
)
Expand Down

0 comments on commit c0589ef

Please sign in to comment.