Skip to content

Commit

Permalink
feat(torii-grpc): composite query (#2113)
Browse files Browse the repository at this point in the history
* feat(torii-grpc): composite query

* chore
  • Loading branch information
Larkooo authored Jun 28, 2024
1 parent 1fcaa7a commit 93ccf66
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 65 deletions.
261 changes: 197 additions & 64 deletions crates/torii/grpc/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,28 +297,11 @@ impl DojoWorld {
table: &str,
model_relation_table: &str,
entity_relation_column: &str,
keys_clause: proto::types::KeysClause,
keys_clause: &proto::types::KeysClause,
limit: Option<u32>,
offset: Option<u32>,
) -> Result<(Vec<proto::types::Entity>, u32), Error> {
let keys = keys_clause
.keys
.iter()
.map(|bytes| {
if bytes.is_empty() {
return Ok("0x[0-9a-fA-F]+".to_string());
}
Ok(FieldElement::from_byte_slice_be(bytes)
.map(|felt| format!("{felt:#x}"))
.map_err(ParseError::FromByteSliceError)?)
})
.collect::<Result<Vec<_>, Error>>()?;
let mut keys_pattern = format!("^{}", keys.join("/"));

if keys_clause.pattern_matching == proto::types::PatternMatching::VariableLen as i32 {
keys_pattern += "(/0x[0-9a-fA-F]+)*";
}
keys_pattern += "/$";
let keys_pattern = build_keys_pattern(keys_clause)?;

// total count of rows that matches keys_pattern without limit and offset
let count_query = format!(
Expand Down Expand Up @@ -431,28 +414,11 @@ impl DojoWorld {

pub(crate) async fn events_by_keys(
&self,
keys_clause: proto::types::KeysClause,
keys_clause: &proto::types::KeysClause,
limit: Option<u32>,
offset: Option<u32>,
) -> Result<Vec<proto::types::Event>, Error> {
let keys = keys_clause
.keys
.iter()
.map(|bytes| {
if bytes.is_empty() {
return Ok("0x[0-9a-fA-F]+".to_string());
}
Ok(FieldElement::from_byte_slice_be(bytes)
.map(|felt| format!("{felt:#x}"))
.map_err(ParseError::FromByteSliceError)?)
})
.collect::<Result<Vec<_>, Error>>()?;
let mut keys_pattern = format!("^{}", keys.join("/"));

if keys_clause.pattern_matching == proto::types::PatternMatching::VariableLen as i32 {
keys_pattern += "(/0x[0-9a-fA-F]+)*";
}
keys_pattern += "/$";
let keys_pattern = build_keys_pattern(keys_clause)?;

let events_query = r#"
SELECT keys, data, transaction_hash
Expand Down Expand Up @@ -491,19 +457,7 @@ impl DojoWorld {
.value_type
.ok_or(QueryError::MissingParam("value_type".into()))?;

let comparison_value = match value_type {
proto::types::value::ValueType::StringValue(string) => string,
proto::types::value::ValueType::IntValue(int) => int.to_string(),
proto::types::value::ValueType::UintValue(uint) => uint.to_string(),
proto::types::value::ValueType::BoolValue(bool) => {
if bool {
"1".to_string()
} else {
"0".to_string()
}
}
_ => return Err(QueryError::UnsupportedQuery.into()),
};
let comparison_value = value_to_string(&value_type)?;

let models_query = format!(
r#"
Expand Down Expand Up @@ -558,15 +512,158 @@ impl DojoWorld {

async fn query_by_composite(
&self,
_table: &str,
_model_relation_table: &str,
_entity_relation_column: &str,
_composite: proto::types::CompositeClause,
_limit: Option<u32>,
_offset: Option<u32>,
table: &str,
model_relation_table: &str,
entity_relation_column: &str,
composite: proto::types::CompositeClause,
limit: Option<u32>,
offset: Option<u32>,
) -> Result<(Vec<proto::types::Entity>, u32), Error> {
// TODO: Implement
Err(QueryError::UnsupportedQuery.into())
// different types of clauses
let mut where_clauses = Vec::new();
let mut model_clauses: HashMap<String, Vec<(String, ComparisonOperator, String)>> =
HashMap::new();
let mut having_clauses = Vec::new();

// bind valeus for prepared statement
let mut bind_values = Vec::new();

for clause in composite.clauses {
match clause.clause_type.unwrap() {
ClauseType::HashedKeys(hashed_keys) => {
let ids = hashed_keys
.hashed_keys
.iter()
.map(|id| {
Ok(FieldElement::from_byte_slice_be(id)
.map(|id| format!("{table}.id = '{id:#x}'"))
.map_err(ParseError::FromByteSliceError)?)
})
.collect::<Result<Vec<_>, Error>>()?;
where_clauses.push(format!("({})", ids.join(" OR ")));
}
ClauseType::Keys(keys) => {
let keys_pattern = build_keys_pattern(&keys)?;
where_clauses.push(format!("{table}.keys REGEXP '{keys_pattern}'"));
}
ClauseType::Member(member) => {
let comparison_operator =
ComparisonOperator::from_repr(member.operator as usize)
.expect("invalid comparison operator");
let value = member.value.unwrap().value_type.unwrap();
let comparison_value = value_to_string(&value)?;

let column_name = format!("external_{}", member.member);

model_clauses.entry(member.model.clone()).or_default().push((
column_name,
comparison_operator,
comparison_value,
));

let model_id =
get_selector_from_name(&member.model).map_err(ParseError::NonAsciiName)?;
having_clauses.push(format!("INSTR(model_ids, '{:#x}') > 0", model_id));
}
_ => return Err(QueryError::UnsupportedQuery.into()),
}
}

let mut join_clauses = Vec::new();
for (model, clauses) in model_clauses {
let model_conditions = clauses
.into_iter()
.map(|(column, op, value)| {
bind_values.push(value);
format!("{}.{} {} ?", model, column, op)
})
.collect::<Vec<_>>()
.join(" AND ");

join_clauses.push(format!(
"JOIN {} ON {}.id = {}.entity_id AND ({})",
model, table, model, model_conditions
));
}

let join_clause = join_clauses.join(" ");
let where_clause = if !where_clauses.is_empty() {
format!("WHERE {}", where_clauses.join(" AND "))
} else {
String::new()
};
let having_clause = if !having_clauses.is_empty() {
format!("HAVING {}", having_clauses.join(" AND "))
} else {
String::new()
};

let count_query = format!(
r#"
SELECT COUNT(DISTINCT {table}.id)
FROM {table}
{join_clause}
{where_clause}
"#
);

let mut count_query = sqlx::query_scalar::<_, u32>(&count_query);
for value in &bind_values {
count_query = count_query.bind(value);
}

let total_count = count_query.fetch_one(&self.pool).await?;

if total_count == 0 {
return Ok((Vec::new(), 0));
}

let query = format!(
r#"
SELECT {table}.id, group_concat({model_relation_table}.model_id) as model_ids
FROM {table}
JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id
{join_clause}
{where_clause}
GROUP BY {table}.id
{having_clause}
ORDER BY {table}.event_id DESC
LIMIT ? OFFSET ?
"#
);

let mut db_query = sqlx::query_as(&query);
for value in bind_values {
db_query = db_query.bind(value);
}
db_query = db_query.bind(limit.unwrap_or(u32::MAX)).bind(offset.unwrap_or(0));

let db_entities: Vec<(String, String)> = db_query.fetch_all(&self.pool).await?;

let mut entities = Vec::with_capacity(db_entities.len());
for (entity_id, models_str) in &db_entities {
let model_ids: Vec<&str> = models_str.split(',').collect();
let schemas = self.model_cache.schemas(model_ids).await?;

let (entity_query, arrays_queries) = build_sql_query(
&schemas,
table,
entity_relation_column,
Some(&format!("{table}.id = ?")),
Some(&format!("{table}.id = ?")),
)?;

let row = sqlx::query(&entity_query).bind(entity_id).fetch_one(&self.pool).await?;
let mut arrays_rows = HashMap::new();
for (name, query) in arrays_queries {
let rows = sqlx::query(&query).bind(entity_id).fetch_all(&self.pool).await?;
arrays_rows.insert(name, rows);
}

entities.push(map_row_to_entity(&row, &arrays_rows, &schemas)?);
}

Ok((entities, total_count))
}

pub async fn model_metadata(&self, model: &str) -> Result<proto::types::ModelMetadata, Error> {
Expand Down Expand Up @@ -669,7 +766,7 @@ impl DojoWorld {
ENTITIES_TABLE,
ENTITIES_MODEL_RELATION_TABLE,
ENTITIES_ENTITY_RELATION_COLUMN,
keys,
&keys,
Some(query.limit),
Some(query.offset),
)
Expand Down Expand Up @@ -746,7 +843,7 @@ impl DojoWorld {
EVENT_MESSAGES_TABLE,
EVENT_MESSAGES_MODEL_RELATION_TABLE,
EVENT_MESSAGES_ENTITY_RELATION_COLUMN,
keys,
&keys,
Some(query.limit),
Some(query.offset),
)
Expand Down Expand Up @@ -783,9 +880,9 @@ impl DojoWorld {

async fn retrieve_events(
&self,
query: proto::types::EventQuery,
query: &proto::types::EventQuery,
) -> Result<proto::world::RetrieveEventsResponse, Error> {
let events = match query.keys {
let events = match &query.keys {
None => self.events_all(query.limit, query.offset).await?,
Some(keys) => self.events_by_keys(keys, Some(query.limit), Some(query.offset)).await?,
};
Expand Down Expand Up @@ -838,6 +935,42 @@ fn map_row_to_entity(
Ok(proto::types::Entity { hashed_keys: hashed_keys.to_bytes_be().to_vec(), models })
}

// this builds a sql safe regex pattern to match against for keys
fn build_keys_pattern(clause: &proto::types::KeysClause) -> Result<String, Error> {
let keys = clause
.keys
.iter()
.map(|bytes| {
if bytes.is_empty() {
return Ok("0x[0-9a-fA-F]+".to_string());
}
Ok(FieldElement::from_byte_slice_be(bytes)
.map(|felt| format!("{felt:#x}"))
.map_err(ParseError::FromByteSliceError)?)
})
.collect::<Result<Vec<_>, Error>>()?;
let mut keys_pattern = format!("^{}", keys.join("/"));

if clause.pattern_matching == proto::types::PatternMatching::VariableLen as i32 {
keys_pattern += "(/0x[0-9a-fA-F]+)*";
}
keys_pattern += "/$";

Ok(keys_pattern)
}

fn value_to_string(value: &proto::types::value::ValueType) -> Result<String, Error> {
match value {
proto::types::value::ValueType::StringValue(string) => Ok(string.clone()),
proto::types::value::ValueType::IntValue(int) => Ok(int.to_string()),
proto::types::value::ValueType::UintValue(uint) => Ok(uint.to_string()),
proto::types::value::ValueType::BoolValue(bool) => {
Ok(if *bool { "1".to_string() } else { "0".to_string() })
}
_ => Err(QueryError::UnsupportedQuery.into()),
}
}

type ServiceResult<T> = Result<Response<T>, Status>;
type SubscribeModelsResponseStream =
Pin<Box<dyn Stream<Item = Result<SubscribeModelsResponse, Status>> + Send>>;
Expand Down Expand Up @@ -943,7 +1076,7 @@ impl proto::world::world_server::World for DojoWorld {
.ok_or_else(|| Status::invalid_argument("Missing query argument"))?;

let events =
self.retrieve_events(query).await.map_err(|e| Status::internal(e.to_string()))?;
self.retrieve_events(&query).await.map_err(|e| Status::internal(e.to_string()))?;

Ok(Response::new(events))
}
Expand Down
2 changes: 1 addition & 1 deletion crates/torii/grpc/src/server/tests/entities_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async fn test_entities_queries() {
"entities",
"entity_model",
"entity_id",
KeysClause {
&KeysClause {
keys: vec![account.address().to_bytes_be().to_vec()],
pattern_matching: 0,
models: vec![],
Expand Down

0 comments on commit 93ccf66

Please sign in to comment.