diff --git a/crates/torii/graphql/src/object/component.rs b/crates/torii/graphql/src/object/component.rs index 0a7bcf4e14..7cb038d9b6 100644 --- a/crates/torii/graphql/src/object/component.rs +++ b/crates/torii/graphql/src/object/component.rs @@ -1,5 +1,5 @@ use async_graphql::dynamic::{ - Field, FieldFuture, FieldValue, InputValue, SubscriptionField, SubscriptionFieldFuture, TypeRef, + Field, FieldFuture, InputValue, SubscriptionField, SubscriptionFieldFuture, TypeRef, }; use async_graphql::{Name, Value}; use indexmap::IndexMap; @@ -97,18 +97,32 @@ impl ObjectTrait for ComponentObject { fn subscriptions(&self) -> Option> { let name = format!("{}Registered", self.name()); - Some(vec![SubscriptionField::new(name, TypeRef::named_nn(self.type_name()), |_| { - { - SubscriptionFieldFuture::new(async { - Result::Ok(SimpleBroker::::subscribe().map( - |component: Component| { - Result::Ok(FieldValue::owned_any(ComponentObject::value_mapping( - component, - ))) - }, - )) - }) - } - })]) + Some(vec![ + SubscriptionField::new(name, TypeRef::named_nn(self.type_name()), |ctx| { + { + SubscriptionFieldFuture::new(async move { + let id = match ctx.args.get("id") { + Some(id) => Some(id.string()?.to_string()), + None => None, + }; + // if id is None, then subscribe to all components + // if id is Some, then subscribe to only the component with that id + Ok(SimpleBroker::::subscribe().filter_map( + move |component: Component| { + if id.is_none() || id == Some(component.id.clone()) { + Some(Ok(Value::Object(ComponentObject::value_mapping( + component, + )))) + } else { + // id != component.id, so don't send anything, still listening + None + } + }, + )) + }) + } + }) + .argument(InputValue::new("id", TypeRef::named(TypeRef::ID))), + ]) } } diff --git a/crates/torii/graphql/src/object/entity.rs b/crates/torii/graphql/src/object/entity.rs index 86d688e220..2890c5df34 100644 --- a/crates/torii/graphql/src/object/entity.rs +++ b/crates/torii/graphql/src/object/entity.rs @@ -153,13 +153,27 @@ impl ObjectTrait for EntityObject { fn subscriptions(&self) -> Option> { let name = format!("{}Updated", self.name()); - Some(vec![SubscriptionField::new(name, TypeRef::named_nn(self.type_name()), |_| { - SubscriptionFieldFuture::new(async { - Ok(SimpleBroker::::subscribe().map(|entity: Entity| { - Ok(FieldValue::owned_any(EntityObject::value_mapping(entity))) - })) + Some(vec![ + SubscriptionField::new(name, TypeRef::named_nn(self.type_name()), |ctx| { + SubscriptionFieldFuture::new(async move { + let id = match ctx.args.get("id") { + Some(id) => Some(id.string()?.to_string()), + None => None, + }; + // if id is None, then subscribe to all entities + // if id is Some, then subscribe to only the entity with that id + Ok(SimpleBroker::::subscribe().filter_map(move |entity: Entity| { + if id.is_none() || id == Some(entity.id.clone()) { + Some(Ok(Value::Object(EntityObject::value_mapping(entity)))) + } else { + // id != entity.id , then don't send anything, still listening + None + } + })) + }) }) - })]) + .argument(InputValue::new("id", TypeRef::named(TypeRef::ID))), + ]) } } diff --git a/crates/torii/graphql/src/tests/subscription_test.rs b/crates/torii/graphql/src/tests/subscription_test.rs index e04a9c73e5..90d5313ec4 100644 --- a/crates/torii/graphql/src/tests/subscription_test.rs +++ b/crates/torii/graphql/src/tests/subscription_test.rs @@ -56,6 +56,50 @@ mod tests { rx.recv().await.unwrap(); } + #[sqlx::test(migrations = "../migrations")] + async fn test_entity_subscription_with_id(pool: SqlitePool) { + // Sleep in order to run this test in a single thread + tokio::time::sleep(Duration::from_secs(1)).await; + let state = init(&pool).await; + // 0. Preprocess expected entity value + let key = vec![FieldElement::ONE]; + let entity_id = format!("{:#x}", poseidon_hash_many(&key)); + let keys_str = key.iter().map(|k| format!("{:#x}", k)).collect::>().join(","); + let expected_value: async_graphql::Value = value!({ + "entityUpdated": { "id": entity_id.clone(), "keys":vec![keys_str.clone()], "componentNames": "Moves" } + }); + let (tx, mut rx) = mpsc::channel(10); + + tokio::spawn(async move { + // 1. Open process and sleep.Go to execute subscription + tokio::time::sleep(Duration::from_secs(1)).await; + + // Set entity with one moves component + // remaining: 10, last_direction: 0 + let moves_values = vec![FieldElement::from_hex_be("0xa").unwrap(), FieldElement::ZERO]; + state.set_entity("Moves".to_string(), key, moves_values).await.unwrap(); + // 3. fn publish() is called from state.set_entity() + + tx.send(()).await.unwrap(); + }); + + // 2. The subscription is executed and it is listeing, waiting for publish() to be executed + let response_value = run_graphql_subscription( + &pool, + r#" + subscription { + entityUpdated(id: "0x579e8877c7755365d5ec1ec7d3a94a457eff5d1f40482bbe9729c064cdead2") { + id, keys, componentNames + } + }"#, + ) + .await; + // 4. The subcription has received the message from publish() + // 5. Compare values + assert_eq!(expected_value, response_value); + rx.recv().await.unwrap(); + } + #[sqlx::test(migrations = "../migrations")] async fn test_component_subscription(pool: SqlitePool) { // Sleep in order to run this test at the end in a single thread @@ -104,4 +148,53 @@ mod tests { assert_eq!(expected_value, response_value); rx.recv().await.unwrap(); } + + #[sqlx::test(migrations = "../migrations")] + async fn test_component_subscription_with_id(pool: SqlitePool) { + // Sleep in order to run this test at the end in a single thread + tokio::time::sleep(Duration::from_secs(2)).await; + + let state = Sql::new(pool.clone(), FieldElement::ZERO).await.unwrap(); + // 0. Preprocess component value + let name = "Test".to_string(); + let component_id = name.to_lowercase(); + let class_hash = FieldElement::TWO; + let hex_class_hash = format!("{:#x}", class_hash); + let expected_value: async_graphql::Value = value!({ + "componentRegistered": { "id": component_id.clone(), "name":name, "classHash": hex_class_hash } + }); + let (tx, mut rx) = mpsc::channel(7); + + tokio::spawn(async move { + // 1. Open process and sleep.Go to execute subscription + tokio::time::sleep(Duration::from_secs(1)).await; + + let component = Component { + name, + members: vec![Member { name: "test".into(), ty: "u32".into(), key: false }], + class_hash, + ..Default::default() + }; + state.register_component(component).await.unwrap(); + // 3. fn publish() is called from state.set_entity() + + tx.send(()).await.unwrap(); + }); + + // 2. The subscription is executed and it is listeing, waiting for publish() to be executed + let response_value = run_graphql_subscription( + &pool, + r#" + subscription { + componentRegistered(id: "test") { + id, name, classHash + } + }"#, + ) + .await; + // 4. The subcription has received the message from publish() + // 5. Compare values + assert_eq!(expected_value, response_value); + rx.recv().await.unwrap(); + } }