Skip to content

Commit

Permalink
Task expiration and deletion
Browse files Browse the repository at this point in the history
  • Loading branch information
inahga committed Apr 16, 2024
1 parent 72404dc commit b8bf913
Show file tree
Hide file tree
Showing 14 changed files with 652 additions and 263 deletions.
25 changes: 24 additions & 1 deletion src/api_mocks/aggregator_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
clients::aggregator_client::api_types::{
AggregatorApiConfig, AggregatorVdaf, AuthenticationToken, HpkeAeadId, HpkeConfig,
HpkeKdfId, HpkeKemId, HpkePublicKey, JanusDuration, QueryType, Role, TaskCreate, TaskId,
TaskIds, TaskResponse, TaskUploadMetrics,
TaskIds, TaskPatch, TaskResponse, TaskUploadMetrics,
},
entity::aggregator::{Feature, Features},
};
Expand Down Expand Up @@ -41,6 +41,7 @@ pub fn mock() -> impl Handler {
)
.post("/tasks", api(post_task))
.get("/tasks/:task_id", api(get_task))
.patch("/tasks/:task_id", api(patch_task))
.get("/task_ids", api(task_ids))
.delete("/tasks/:task_id", Status::Ok)
.get(
Expand Down Expand Up @@ -109,6 +110,28 @@ async fn post_task(
(State(task_create.clone()), Json(task_response(task_create)))
}

async fn patch_task(conn: &mut Conn, Json(patch): Json<TaskPatch>) -> Json<TaskResponse> {
let task_id = conn.param("task_id").unwrap();
Json(TaskResponse {
task_id: task_id.parse().unwrap(),
peer_aggregator_endpoint: "https://_".parse().unwrap(),
query_type: QueryType::TimeInterval,
vdaf: AggregatorVdaf::Prio3Count,
role: Role::Leader,
vdaf_verify_key: random_chars(10),
max_batch_query_count: 100,
task_expiration: patch.task_expiration,
report_expiry_age: None,
min_batch_size: 1000,
time_precision: JanusDuration::from_seconds(60),
tolerable_clock_skew: JanusDuration::from_seconds(60),
collector_hpke_config: random_hpke_config(),
aggregator_auth_token: Some(AuthenticationToken::new(random_chars(32))),
collector_auth_token: Some(AuthenticationToken::new(random_chars(32))),
aggregator_hpke_configs: repeat_with(random_hpke_config).take(5).collect(),
})
}

pub fn task_response(task_create: TaskCreate) -> TaskResponse {
let task_id = TaskId::try_from(
Sha256::digest(URL_SAFE_NO_PAD.decode(task_create.vdaf_verify_key).unwrap()).as_slice(),
Expand Down
39 changes: 32 additions & 7 deletions src/clients/aggregator_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ use crate::{
entity::{task::ProvisionableTask, Aggregator},
handler::Error,
};
use janus_messages::Time as JanusTime;
use serde::{de::DeserializeOwned, Serialize};
use trillium_client::{Client, KnownHeaderName};
use url::Url;
pub mod api_types;
pub use api_types::{AggregatorApiConfig, TaskCreate, TaskIds, TaskResponse, TaskUploadMetrics};

use self::api_types::TaskPatch;

const CONTENT_TYPE: &str = "application/vnd.janus.aggregator+json;version=0.1";

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -83,15 +86,26 @@ impl AggregatorClient {
self.get(&format!("tasks/{task_id}/metrics/uploads")).await
}

pub async fn delete_task(&self, task_id: &str) -> Result<(), ClientError> {
self.delete(&format!("tasks/{task_id}")).await
}

pub async fn create_task(&self, task: &ProvisionableTask) -> Result<TaskResponse, Error> {
let task_create = TaskCreate::build(&self.aggregator, task)?;
self.post("tasks", &task_create).await.map_err(Into::into)
}

pub async fn update_task_expiration(
&self,
task_id: &str,
expiration: Option<JanusTime>,
) -> Result<TaskResponse, Error> {
self.patch(
&format!("tasks/{task_id}"),
&TaskPatch {
task_expiration: expiration,
},
)
.await
.map_err(Into::into)
}

// private below here

async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T, ClientError> {
Expand Down Expand Up @@ -120,8 +134,19 @@ impl AggregatorClient {
.map_err(ClientError::from)
}

async fn delete(&self, path: &str) -> Result<(), ClientError> {
let _ = self.client.delete(path).success_or_client_error().await?;
Ok(())
async fn patch<T: DeserializeOwned>(
&self,
path: &str,
body: &impl Serialize,
) -> Result<T, ClientError> {
self.client
.patch(path)
.with_json_body(body)?
.with_request_header(KnownHeaderName::ContentType, CONTENT_TYPE)
.success_or_client_error()
.await?
.response_json()
.await
.map_err(ClientError::from)
}
}
5 changes: 5 additions & 0 deletions src/clients/aggregator_client/api_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,11 @@ impl TaskCreate {
}
}

#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub struct TaskPatch {
pub task_expiration: Option<JanusTime>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TaskResponse {
pub task_id: TaskId,
Expand Down
192 changes: 21 additions & 171 deletions src/entity/task.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
use crate::{
clients::aggregator_client::api_types::{TaskResponse, TaskUploadMetrics},
clients::aggregator_client::api_types::TaskResponse,
entity::{
account, membership, AccountColumn, Accounts, Aggregator, AggregatorColumn, Aggregators,
CollectorCredentialColumn, CollectorCredentials,
Aggregator, AggregatorColumn, Aggregators, CollectorCredentialColumn, CollectorCredentials,
},
};
use sea_orm::{
ActiveModelBehavior, ActiveModelTrait, ActiveValue, ConnectionTrait, DbErr, DeriveEntityModel,
DerivePrimaryKey, DeriveRelation, EntityTrait, EnumIter, IntoActiveModel, PrimaryKeyTrait,
Related, RelationDef, RelationTrait,
};
use serde::{Deserialize, Serialize};
use serde::Deserialize;
use std::fmt::Debug;
use time::{Duration, OffsetDateTime};
use uuid::Uuid;
use validator::{Validate, ValidationError};
Expand All @@ -22,172 +17,27 @@ pub use new_task::NewTask;
mod update_task;
pub use update_task::UpdateTask;
mod provisionable_task;
pub use provisionable_task::{ProvisionableTask, TaskProvisioningError};
pub use provisionable_task::ProvisionableTask;
pub mod model;
pub use model::*;

pub const DEFAULT_EXPIRATION_DURATION: Duration = Duration::days(365);

use super::json::Json;

#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)]
#[sea_orm(table_name = "task")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub id: String,
pub account_id: Uuid,
pub name: String,
pub vdaf: Json<Vdaf>,
pub min_batch_size: i64,
pub max_batch_size: Option<i64>,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
pub updated_at: OffsetDateTime,
pub time_precision_seconds: i32,

/// Deprecated metrics field. Never populated, only reads zero.
#[sea_orm(ignore)]
#[serde(default)]
pub report_count: i32,
/// Deprecated metrics field. Never populated, only reads zero.
#[sea_orm(ignore)]
#[serde(default)]
pub aggregate_collection_count: i32,

#[serde(default, with = "time::serde::rfc3339::option")]
pub expiration: Option<OffsetDateTime>,
pub leader_aggregator_id: Uuid,
pub helper_aggregator_id: Uuid,
pub collector_credential_id: Uuid,

pub report_counter_interval_collected: i64,
pub report_counter_decode_failure: i64,
pub report_counter_decrypt_failure: i64,
pub report_counter_expired: i64,
pub report_counter_outdated_key: i64,
pub report_counter_success: i64,
pub report_counter_too_early: i64,
pub report_counter_task_expired: i64,
}

impl Model {
pub async fn update_task_upload_metrics(
self,
metrics: TaskUploadMetrics,
db: impl ConnectionTrait,
) -> Result<Self, DbErr> {
let mut task = self.into_active_model();
task.report_counter_interval_collected =
ActiveValue::Set(metrics.interval_collected.try_into().unwrap_or(i64::MAX));
task.report_counter_decode_failure =
ActiveValue::Set(metrics.report_decode_failure.try_into().unwrap_or(i64::MAX));
task.report_counter_decrypt_failure = ActiveValue::Set(
metrics
.report_decrypt_failure
.try_into()
.unwrap_or(i64::MAX),
);
task.report_counter_expired =
ActiveValue::Set(metrics.report_expired.try_into().unwrap_or(i64::MAX));
task.report_counter_outdated_key =
ActiveValue::Set(metrics.report_outdated_key.try_into().unwrap_or(i64::MAX));
task.report_counter_success =
ActiveValue::Set(metrics.report_success.try_into().unwrap_or(i64::MAX));
task.report_counter_too_early =
ActiveValue::Set(metrics.report_too_early.try_into().unwrap_or(i64::MAX));
task.report_counter_task_expired =
ActiveValue::Set(metrics.task_expired.try_into().unwrap_or(i64::MAX));
task.updated_at = ActiveValue::Set(OffsetDateTime::now_utc());
task.update(&db).await
}

pub async fn leader_aggregator(
&self,
db: &impl ConnectionTrait,
) -> Result<super::Aggregator, DbErr> {
super::Aggregators::find_by_id(self.leader_aggregator_id)
.one(db)
.await
.transpose()
.ok_or(DbErr::Custom("expected leader aggregator".into()))?
}

pub async fn helper_aggregator(&self, db: &impl ConnectionTrait) -> Result<Aggregator, DbErr> {
Aggregators::find_by_id(self.helper_aggregator_id)
.one(db)
.await
.transpose()
.ok_or(DbErr::Custom("expected helper aggregator".into()))?
}

pub async fn aggregators(&self, db: &impl ConnectionTrait) -> Result<[Aggregator; 2], DbErr> {
futures_lite::future::try_zip(self.leader_aggregator(db), self.helper_aggregator(db))
.await
.map(|(leader, helper)| [leader, helper])
}

pub async fn first_party_aggregator(
&self,
db: &impl ConnectionTrait,
) -> Result<Option<Aggregator>, DbErr> {
Ok(self
.aggregators(db)
.await?
.into_iter()
.find(|agg| agg.is_first_party))
}
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "Accounts",
from = "Column::AccountId",
to = "AccountColumn::Id"
)]
Account,

#[sea_orm(
belongs_to = "Aggregators",
from = "Column::HelperAggregatorId",
to = "AggregatorColumn::Id"
)]
HelperAggregator,

#[sea_orm(
belongs_to = "Aggregators",
from = "Column::LeaderAggregatorId",
to = "AggregatorColumn::Id"
)]
LeaderAggregator,

#[sea_orm(
belongs_to = "CollectorCredentials",
from = "Column::CollectorCredentialId",
to = "CollectorCredentialColumn::Id"
)]
CollectorCredential,
#[derive(thiserror::Error, Debug, Clone, Copy)]
pub enum TaskProvisioningError {
#[error("discrepancy in {0}")]
Discrepancy(&'static str),
}

impl Related<account::Entity> for Entity {
fn to() -> RelationDef {
Relation::Account.def()
pub(crate) fn assert_same<T: Eq + Debug>(
ours: T,
theirs: T,
property: &'static str,
) -> Result<(), TaskProvisioningError> {
if ours == theirs {
Ok(())
} else {
log::error!("{property} discrepancy. ours: {ours:?}, theirs: {theirs:?}");
Err(TaskProvisioningError::Discrepancy(property))
}
}

impl Related<membership::Entity> for Entity {
fn to() -> RelationDef {
account::Relation::Memberships.def()
}

fn via() -> Option<RelationDef> {
Some(account::Relation::Tasks.def().rev())
}
}

impl Related<CollectorCredentials> for Entity {
fn to() -> RelationDef {
Relation::CollectorCredential.def()
}
}

impl ActiveModelBehavior for ActiveModel {}
Loading

0 comments on commit b8bf913

Please sign in to comment.