diff --git a/src/sentry/utils/registry.py b/src/sentry/utils/registry.py index 0ff110e268e054..3eff6865e82051 100644 --- a/src/sentry/utils/registry.py +++ b/src/sentry/utils/registry.py @@ -15,9 +15,15 @@ class NoRegistrationExistsError(ValueError): class Registry(Generic[T]): - def __init__(self): + """ + A simple generic registry that allows for registering and retrieving items by key. Reverse lookup by value is enabled by default. + If you have duplicate values, you may want to disable reverse lookup. + """ + + def __init__(self, enable_reverse_lookup=True): self.registrations: dict[str, T] = {} self.reverse_lookup: dict[T, str] = {} + self.enable_reverse_lookup = enable_reverse_lookup def register(self, key: str): def inner(item: T) -> T: @@ -26,13 +32,14 @@ def inner(item: T) -> T: f"A registration already exists for {key}: {self.registrations[key]}" ) - if item in self.reverse_lookup: - raise AlreadyRegisteredError( - f"A registration already exists for {item}: {self.reverse_lookup[item]}" - ) + if self.enable_reverse_lookup: + if item in self.reverse_lookup: + raise AlreadyRegisteredError( + f"A registration already exists for {item}: {self.reverse_lookup[item]}" + ) + self.reverse_lookup[item] = key self.registrations[key] = item - self.reverse_lookup[item] = key return item @@ -44,6 +51,8 @@ def get(self, key: str) -> T: return self.registrations[key] def get_key(self, item: T) -> str: + if not self.enable_reverse_lookup: + raise NotImplementedError("Reverse lookup is not enabled") if item not in self.reverse_lookup: raise NoRegistrationExistsError(f"No registration exists for {item}") return self.reverse_lookup[item] diff --git a/tests/sentry/utils/test_registry.py b/tests/sentry/utils/test_registry.py index 2f3415c288fd04..cbb886a7884ca8 100644 --- a/tests/sentry/utils/test_registry.py +++ b/tests/sentry/utils/test_registry.py @@ -33,3 +33,23 @@ def unregistered_func(): test_registry.register("something else")(unregistered_func) assert test_registry.get("something else") == unregistered_func + + def test_allow_duplicate_values(self): + test_registry = Registry[str](enable_reverse_lookup=False) + + @test_registry.register("something") + @test_registry.register("something 2") + def registered_func(): + pass + + assert test_registry.get("something") == registered_func + assert test_registry.get("something 2") == registered_func + + with pytest.raises(NoRegistrationExistsError): + test_registry.get("something else") + + with pytest.raises(NotImplementedError): + test_registry.get_key(registered_func) + + test_registry.register("something else")(registered_func) + assert test_registry.get("something else") == registered_func