From b39a57c704744759edf33af4f09f19b4d0530026 Mon Sep 17 00:00:00 2001 From: rustam810 Date: Tue, 13 Feb 2024 12:53:25 +0100 Subject: [PATCH 1/2] Add RedisClusterTrackerStore and RedisClusterLockStore --- rasa/core/lock_store.py | 34 +++++++++++++++++++++++++++ rasa/core/tracker_store.py | 48 ++++++++++++++++++++++++++++++++++++++ version | 2 +- 3 files changed, 83 insertions(+), 1 deletion(-) diff --git a/rasa/core/lock_store.py b/rasa/core/lock_store.py index ab3c539b3853..72f3d11276bb 100644 --- a/rasa/core/lock_store.py +++ b/rasa/core/lock_store.py @@ -266,6 +266,38 @@ def save_lock(self, lock: TicketLock) -> None: self.red.set(self.key_prefix + lock.conversation_id, lock.dumps()) +class RedisClusterLockStore(RedisLockStore): + """Redis store for ticket locks.""" + + def __init__( + self, + host: Text = "localhost", + port: int = 6379, + db: int = 1, + password: Optional[Text] = None, + use_ssl: bool = False, + key_prefix: Optional[Text] = None, + socket_timeout: float = DEFAULT_SOCKET_TIMEOUT_IN_SECONDS, + ) -> None: + """Create a lock store which uses Redis Cluster for persistence. + """ + import redis + + self.red = redis.cluster.RedisCluster( + host=host, + port=int(port), + password=password, + ssl=use_ssl, + socket_timeout=socket_timeout, + ) + + self.key_prefix = DEFAULT_REDIS_LOCK_STORE_KEY_PREFIX + if key_prefix: + logger.debug(f"Setting non-default redis key prefix: '{key_prefix}'.") + self._set_key_prefix(key_prefix) + + super(RedisLockStore, self).__init__() + class InMemoryLockStore(LockStore): """In-memory store for ticket locks.""" @@ -304,6 +336,8 @@ def _create_from_endpoint_config( lock_store: LockStore = InMemoryLockStore() elif endpoint_config.type == "redis": lock_store = RedisLockStore(host=endpoint_config.url, **endpoint_config.kwargs) + elif endpoint_config.type == "redis_cluster": + lock_store = RedisClusterLockStore(host=endpoint_config.url, **endpoint_config.kwargs) else: lock_store = _load_from_module_name_in_endpoint_config(endpoint_config) diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 72559d8a3712..cefe645bfbbc 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -615,6 +615,47 @@ def _merge_trackers( return merged +class RedisClusterTrackerStore(RedisTrackerStore): + """Stores conversation history in Redis.""" + + def __init__( + self, + domain: Domain, + host: Text = "localhost", + port: int = 6379, + db: int = 0, + password: Optional[Text] = None, + event_broker: Optional[EventBroker] = None, + record_exp: Optional[float] = None, + key_prefix: Optional[Text] = None, + use_ssl: bool = False, + ssl_keyfile: Optional[Text] = None, + ssl_certfile: Optional[Text] = None, + ssl_ca_certs: Optional[Text] = None, + **kwargs: Dict[Text, Any], + ) -> None: + """Initializes the tracker store.""" + import redis + + self.red = redis.cluster.RedisCluster( + host=host, + port=port, + password=password, + ssl=use_ssl, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_ca_certs=ssl_ca_certs, + decode_responses=True, + ) + self.record_exp = record_exp + + self.key_prefix = DEFAULT_REDIS_TRACKER_STORE_KEY_PREFIX + if key_prefix: + logger.debug(f"Setting non-default redis key prefix: '{key_prefix}'.") + self._set_key_prefix(key_prefix) + + super(RedisTrackerStore, self).__init__(domain, event_broker, **kwargs) + class DynamoTrackerStore(TrackerStore, SerializedTrackerAsDict): """Stores conversation history in DynamoDB.""" @@ -1504,6 +1545,13 @@ def _create_from_endpoint_config( event_broker=event_broker, **endpoint_config.kwargs, ) + elif endpoint_config.type.lower() == "redis_cluster": + tracker_store = RedisClusterTrackerStore( + domain=domain, + host=endpoint_config.url, + event_broker=event_broker, + **endpoint_config.kwargs, + ) elif endpoint_config.type.lower() == "mongod": tracker_store = MongoTrackerStore( domain=domain, diff --git a/version b/version index b08f45eeb180..836c6cf5acbc 100644 --- a/version +++ b/version @@ -1 +1 @@ -3.6.5-e8-0.2.1 \ No newline at end of file +3.6.5-e8-0.3.0 \ No newline at end of file From a99a85920f2b8339e808a7c689a14c6dc21414b3 Mon Sep 17 00:00:00 2001 From: Anton Grechkin Date: Sun, 31 Mar 2024 23:44:27 +0400 Subject: [PATCH 2/2] Allow using custom training_tracker_provider --- rasa/engine/recipes/default_recipe.py | 19 ++++++++++++++----- version | 2 +- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/rasa/engine/recipes/default_recipe.py b/rasa/engine/recipes/default_recipe.py index c0355c10a064..b553bed28992 100644 --- a/rasa/engine/recipes/default_recipe.py +++ b/rasa/engine/recipes/default_recipe.py @@ -150,10 +150,10 @@ def decorator(registered_class: Type[GraphComponent]) -> Type[GraphComponent]: else: unique_types = set(component_types) - cls._registered_components[ - registered_class.__name__ - ] = cls.RegisteredComponent( - registered_class, unique_types, is_trainable, model_from + cls._registered_components[registered_class.__name__] = ( + cls.RegisteredComponent( + registered_class, unique_types, is_trainable, model_from + ) ) return registered_class @@ -581,12 +581,21 @@ def _add_core_train_nodes( config={"exclusion_percentage": cli_parameters.get("exclusion_percentage")}, is_input=True, ) + + training_tracker_provider_name = train_config.get("training_tracker_provider") + if training_tracker_provider_name is not None: + training_tracker_provider_cls = self._from_registry( + training_tracker_provider_name + ).clazz + else: + training_tracker_provider_cls = TrainingTrackerProvider + train_nodes["training_tracker_provider"] = SchemaNode( needs={ "story_graph": "story_graph_provider", "domain": "domain_for_core_training_provider", }, - uses=TrainingTrackerProvider, + uses=training_tracker_provider_cls, constructor_name="create", fn="provide", config={ diff --git a/version b/version index 836c6cf5acbc..9fd21eac004b 100644 --- a/version +++ b/version @@ -1 +1 @@ -3.6.5-e8-0.3.0 \ No newline at end of file +3.6.5-e8-0.4.0 \ No newline at end of file