Skip to content

Commit

Permalink
Merge pull request #7 from epoch8/3.6.5-e8
Browse files Browse the repository at this point in the history
3.6.5 e8
  • Loading branch information
rustam810 authored May 8, 2024
2 parents a748166 + 4a876c0 commit 55173fa
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 6 deletions.
34 changes: 34 additions & 0 deletions rasa/core/lock_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,38 @@ def save_lock(self, lock: TicketLock) -> None:
self.red.set(self.key_prefix + lock.conversation_id, lock.dumps())


class RedisClusterLockStore(RedisLockStore):
"""Redis store for ticket locks."""

def __init__(
self,
host: Text = "localhost",
port: int = 6379,
db: int = 1,
password: Optional[Text] = None,
use_ssl: bool = False,
key_prefix: Optional[Text] = None,
socket_timeout: float = DEFAULT_SOCKET_TIMEOUT_IN_SECONDS,
) -> None:
"""Create a lock store which uses Redis Cluster for persistence.
"""
import redis

self.red = redis.cluster.RedisCluster(
host=host,
port=int(port),
password=password,
ssl=use_ssl,
socket_timeout=socket_timeout,
)

self.key_prefix = DEFAULT_REDIS_LOCK_STORE_KEY_PREFIX
if key_prefix:
logger.debug(f"Setting non-default redis key prefix: '{key_prefix}'.")
self._set_key_prefix(key_prefix)

super(RedisLockStore, self).__init__()

class InMemoryLockStore(LockStore):
"""In-memory store for ticket locks."""

Expand Down Expand Up @@ -304,6 +336,8 @@ def _create_from_endpoint_config(
lock_store: LockStore = InMemoryLockStore()
elif endpoint_config.type == "redis":
lock_store = RedisLockStore(host=endpoint_config.url, **endpoint_config.kwargs)
elif endpoint_config.type == "redis_cluster":
lock_store = RedisClusterLockStore(host=endpoint_config.url, **endpoint_config.kwargs)
else:
lock_store = _load_from_module_name_in_endpoint_config(endpoint_config)

Expand Down
48 changes: 48 additions & 0 deletions rasa/core/tracker_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,47 @@ def _merge_trackers(
return merged


class RedisClusterTrackerStore(RedisTrackerStore):
"""Stores conversation history in Redis."""

def __init__(
self,
domain: Domain,
host: Text = "localhost",
port: int = 6379,
db: int = 0,
password: Optional[Text] = None,
event_broker: Optional[EventBroker] = None,
record_exp: Optional[float] = None,
key_prefix: Optional[Text] = None,
use_ssl: bool = False,
ssl_keyfile: Optional[Text] = None,
ssl_certfile: Optional[Text] = None,
ssl_ca_certs: Optional[Text] = None,
**kwargs: Dict[Text, Any],
) -> None:
"""Initializes the tracker store."""
import redis

self.red = redis.cluster.RedisCluster(
host=host,
port=port,
password=password,
ssl=use_ssl,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
ssl_ca_certs=ssl_ca_certs,
decode_responses=True,
)
self.record_exp = record_exp

self.key_prefix = DEFAULT_REDIS_TRACKER_STORE_KEY_PREFIX
if key_prefix:
logger.debug(f"Setting non-default redis key prefix: '{key_prefix}'.")
self._set_key_prefix(key_prefix)

super(RedisTrackerStore, self).__init__(domain, event_broker, **kwargs)

class DynamoTrackerStore(TrackerStore, SerializedTrackerAsDict):
"""Stores conversation history in DynamoDB."""

Expand Down Expand Up @@ -1504,6 +1545,13 @@ def _create_from_endpoint_config(
event_broker=event_broker,
**endpoint_config.kwargs,
)
elif endpoint_config.type.lower() == "redis_cluster":
tracker_store = RedisClusterTrackerStore(
domain=domain,
host=endpoint_config.url,
event_broker=event_broker,
**endpoint_config.kwargs,
)
elif endpoint_config.type.lower() == "mongod":
tracker_store = MongoTrackerStore(
domain=domain,
Expand Down
19 changes: 14 additions & 5 deletions rasa/engine/recipes/default_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,10 @@ def decorator(registered_class: Type[GraphComponent]) -> Type[GraphComponent]:
else:
unique_types = set(component_types)

cls._registered_components[
registered_class.__name__
] = cls.RegisteredComponent(
registered_class, unique_types, is_trainable, model_from
cls._registered_components[registered_class.__name__] = (
cls.RegisteredComponent(
registered_class, unique_types, is_trainable, model_from
)
)
return registered_class

Expand Down Expand Up @@ -581,12 +581,21 @@ def _add_core_train_nodes(
config={"exclusion_percentage": cli_parameters.get("exclusion_percentage")},
is_input=True,
)

training_tracker_provider_name = train_config.get("training_tracker_provider")
if training_tracker_provider_name is not None:
training_tracker_provider_cls = self._from_registry(
training_tracker_provider_name
).clazz
else:
training_tracker_provider_cls = TrainingTrackerProvider

train_nodes["training_tracker_provider"] = SchemaNode(
needs={
"story_graph": "story_graph_provider",
"domain": "domain_for_core_training_provider",
},
uses=TrainingTrackerProvider,
uses=training_tracker_provider_cls,
constructor_name="create",
fn="provide",
config={
Expand Down
2 changes: 1 addition & 1 deletion version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.6.5-e8-0.2.1
3.6.5-e8-0.4.0

0 comments on commit 55173fa

Please sign in to comment.