Skip to content

Commit

Permalink
Add gRPC entity query clauses (#1149)
Browse files Browse the repository at this point in the history
  • Loading branch information
tarrencev authored Nov 3, 2023
1 parent a3140d8 commit c82eb06
Show file tree
Hide file tree
Showing 11 changed files with 294 additions and 94 deletions.
55 changes: 52 additions & 3 deletions crates/dojo-types/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,62 @@ impl Member {
}
}

/// Represents a model of an entity
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct EntityModel {
#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)]
pub struct EntityQuery {
pub model: String,
pub clause: Clause,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)]
pub enum Clause {
Keys(KeysClause),
Attribute(AttributeClause),
Composite(CompositeClause),
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)]
pub struct KeysClause {
pub keys: Vec<FieldElement>,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)]
pub struct AttributeClause {
pub attribute: String,
pub operator: ComparisonOperator,
pub value: Value,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)]
pub struct CompositeClause {
pub operator: LogicalOperator,
pub clauses: Vec<Clause>,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)]
pub enum LogicalOperator {
And,
Or,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)]
pub enum ComparisonOperator {
Eq,
Neq,
Gt,
Gte,
Lt,
Lte,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)]
pub enum Value {
String(String),
Int(i64),
UInt(u64),
Bool(bool),
Bytes(Vec<u8>),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
pub schema: Ty,
Expand Down
2 changes: 2 additions & 0 deletions crates/torii/client/src/client/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub enum Error {
GrpcClient(#[from] torii_grpc::client::Error),
#[error(transparent)]
Model(#[from] ModelError),
#[error("Unsupported query")]
UnsupportedQuery,
}

#[derive(Debug, thiserror::Error)]
Expand Down
45 changes: 31 additions & 14 deletions crates/torii/client/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::collections::HashSet;
use std::sync::Arc;

use dojo_types::packing::unpack;
use dojo_types::schema::{EntityModel, Ty};
use dojo_types::schema::{Clause, EntityQuery, Ty};
use dojo_types::WorldMetadata;
use dojo_world::contracts::WorldContractReader;
use parking_lot::{RwLock, RwLockReadGuard};
Expand Down Expand Up @@ -46,7 +46,7 @@ impl Client {
torii_url: String,
rpc_url: String,
world: FieldElement,
entities: Option<Vec<EntityModel>>,
queries: Option<Vec<EntityQuery>>,
) -> Result<Self, Error> {
let mut grpc_client = torii_grpc::client::WorldClient::new(torii_url, world).await?;

Expand All @@ -61,13 +61,18 @@ impl Client {
let provider = JsonRpcClient::new(HttpTransport::new(rpc_url));
let world_reader = WorldContractReader::new(world, provider);

if let Some(entities_to_sync) = entities {
subbed_entities.add_entities(entities_to_sync)?;
if let Some(queries) = queries {
subbed_entities.add_entities(queries)?;

// TODO: change this to querying the gRPC url instead
let subbed_entities = subbed_entities.entities.read().clone();
for EntityModel { model, keys } in subbed_entities {
for EntityQuery { model, clause } in subbed_entities {
let model_reader = world_reader.model(&model).await?;
let keys = if let Clause::Keys(clause) = clause {
clause.keys
} else {
return Err(Error::UnsupportedQuery);
};
let values = model_reader.entity_storage(&keys).await?;

client_storage.set_entity_storage(
Expand All @@ -93,7 +98,7 @@ impl Client {
self.metadata.read()
}

pub fn subscribed_entities(&self) -> RwLockReadGuard<'_, HashSet<EntityModel>> {
pub fn subscribed_entities(&self) -> RwLockReadGuard<'_, HashSet<EntityQuery>> {
self.subscribed_entities.entities.read()
}

Expand All @@ -104,21 +109,27 @@ impl Client {
///
/// If the requested entity is not among the synced entities, it will attempt to fetch it from
/// the RPC.
pub async fn entity(&self, entity: &EntityModel) -> Result<Option<Ty>, Error> {
pub async fn entity(&self, entity: &EntityQuery) -> Result<Option<Ty>, Error> {
let Some(mut schema) = self.metadata.read().model(&entity.model).map(|m| m.schema.clone())
else {
return Ok(None);
};

let keys = if let Clause::Keys(clause) = entity.clone().clause {
clause.keys
} else {
return Err(Error::UnsupportedQuery);
};

if !self.subscribed_entities.is_synced(entity) {
let model = self.world_reader.model(&entity.model).await?;
return Ok(Some(model.entity(&entity.keys).await?));
return Ok(Some(model.entity(&keys).await?));
}

let Ok(Some(raw_values)) = self.storage.get_entity_storage(
cairo_short_string_to_felt(&entity.model)
.map_err(ParseError::CairoShortStringToFelt)?,
&entity.keys,
&keys,
) else {
return Ok(Some(schema));
};
Expand All @@ -131,7 +142,7 @@ impl Client {
.expect("qed; layout should exist");

let unpacked = unpack(raw_values, layout).unwrap();
let mut keys_and_unpacked = [entity.keys.to_vec(), unpacked].concat();
let mut keys_and_unpacked = [keys.to_vec(), unpacked].concat();

schema.deserialize(&mut keys_and_unpacked).unwrap();

Expand All @@ -158,9 +169,15 @@ impl Client {
/// Adds entities to the list of entities to be synced.
///
/// NOTE: This will establish a new subscription stream with the server.
pub async fn add_entities_to_sync(&self, entities: Vec<EntityModel>) -> Result<(), Error> {
pub async fn add_entities_to_sync(&self, entities: Vec<EntityQuery>) -> Result<(), Error> {
for entity in &entities {
self.initiate_entity(&entity.model, entity.keys.clone()).await?;
let keys = if let Clause::Keys(clause) = entity.clone().clause {
clause.keys
} else {
return Err(Error::UnsupportedQuery);
};

self.initiate_entity(&entity.model, keys.clone()).await?;
}

self.subscribed_entities.add_entities(entities)?;
Expand All @@ -179,7 +196,7 @@ impl Client {
/// Removes entities from the list of entities to be synced.
///
/// NOTE: This will establish a new subscription stream with the server.
pub async fn remove_entities_to_sync(&self, entities: Vec<EntityModel>) -> Result<(), Error> {
pub async fn remove_entities_to_sync(&self, entities: Vec<EntityQuery>) -> Result<(), Error> {
self.subscribed_entities.remove_entities(entities)?;

let updated_entities =
Expand All @@ -199,7 +216,7 @@ impl Client {

async fn initiate_subscription(
&self,
entities: Vec<EntityModel>,
entities: Vec<EntityQuery>,
) -> Result<EntityUpdateStreaming, Error> {
let mut grpc_client = self.inner.write().await;
let stream = grpc_client.subscribe_entities(entities).await?;
Expand Down
33 changes: 16 additions & 17 deletions crates/torii/client/src/client/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ mod tests {
use std::collections::HashMap;
use std::sync::Arc;

use dojo_types::schema::Ty;
use dojo_types::schema::{KeysClause, Ty};
use dojo_types::WorldMetadata;
use parking_lot::RwLock;
use starknet::core::utils::cairo_short_string_to_felt;
Expand Down Expand Up @@ -201,14 +201,15 @@ mod tests {
#[test]
fn err_if_set_values_too_many() {
let storage = create_dummy_storage();
let entity = dojo_types::schema::EntityModel {
let keys = vec![felt!("0x12345")];
let entity = dojo_types::schema::EntityQuery {
model: "Position".into(),
keys: vec![felt!("0x12345")],
clause: dojo_types::schema::Clause::Keys(KeysClause { keys: keys.clone() }),
};

let values = vec![felt!("1"), felt!("2"), felt!("3"), felt!("4"), felt!("5")];
let model = cairo_short_string_to_felt(&entity.model).unwrap();
let result = storage.set_entity_storage(model, entity.keys, values);
let result = storage.set_entity_storage(model, keys, values);

assert!(storage.storage.read().is_empty());
matches!(
Expand All @@ -220,14 +221,15 @@ mod tests {
#[test]
fn err_if_set_values_too_few() {
let storage = create_dummy_storage();
let entity = dojo_types::schema::EntityModel {
let keys = vec![felt!("0x12345")];
let entity = dojo_types::schema::EntityQuery {
model: "Position".into(),
keys: vec![felt!("0x12345")],
clause: dojo_types::schema::Clause::Keys(KeysClause { keys: keys.clone() }),
};

let values = vec![felt!("1"), felt!("2")];
let model = cairo_short_string_to_felt(&entity.model).unwrap();
let result = storage.set_entity_storage(model, entity.keys, values);
let result = storage.set_entity_storage(model, keys, values);

assert!(storage.storage.read().is_empty());
matches!(
Expand All @@ -239,9 +241,10 @@ mod tests {
#[test]
fn set_and_get_entity_value() {
let storage = create_dummy_storage();
let entity = dojo_types::schema::EntityModel {
let keys = vec![felt!("0x12345")];
let entity = dojo_types::schema::EntityQuery {
model: "Position".into(),
keys: vec![felt!("0x12345")],
clause: dojo_types::schema::Clause::Keys(KeysClause { keys: keys.clone() }),
};

assert!(storage.storage.read().is_empty(), "storage must be empty initially");
Expand All @@ -250,31 +253,27 @@ mod tests {

let expected_storage_addresses = compute_all_storage_addresses(
cairo_short_string_to_felt(&model.name).unwrap(),
&entity.keys,
&keys,
model.packed_size,
);

let expected_values = vec![felt!("1"), felt!("2"), felt!("3"), felt!("4")];
let model_name_in_felt = cairo_short_string_to_felt(&entity.model).unwrap();

storage
.set_entity_storage(model_name_in_felt, entity.keys.clone(), expected_values.clone())
.set_entity_storage(model_name_in_felt, keys.clone(), expected_values.clone())
.expect("set storage values");

let actual_values = storage
.get_entity_storage(model_name_in_felt, &entity.keys)
.get_entity_storage(model_name_in_felt, &keys)
.expect("model exist")
.expect("values are set");

let actual_storage_addresses =
storage.storage.read().clone().into_keys().collect::<Vec<_>>();

assert!(
storage
.model_index
.read()
.get(&model_name_in_felt)
.is_some_and(|e| e.contains(&entity.keys)),
storage.model_index.read().get(&model_name_in_felt).is_some_and(|e| e.contains(&keys)),
"entity keys must be indexed"
);
assert!(actual_values == expected_values);
Expand Down
Loading

0 comments on commit c82eb06

Please sign in to comment.