Skip to content

Commit

Permalink
fix: Replaces unwrap() with Err return.
Browse files Browse the repository at this point in the history
- The RwLock::write errors occur if there is a risk of deadlock.
- The timestamp SQL statement might fail if there is a database connection drop.
- The borrowing of the connection for calc_quota_usage_sync should never fail.
  • Loading branch information
tommie committed Jan 8, 2025
1 parent 581b620 commit 9911784
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 67 deletions.
16 changes: 8 additions & 8 deletions syncstorage-mysql/src/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub fn create(db: &MysqlDb, params: params::CreateBatch) -> DbResult<results::Cr
batch_uploads::user_id.eq(&user_id),
batch_uploads::collection_id.eq(&collection_id),
))
.execute(&mut *db.conn.write().unwrap())
.execute(&mut *db.conn.write()?)
.map_err(|e| -> DbError {
match e {
// The user tried to create two batches with the same timestamp
Expand Down Expand Up @@ -77,7 +77,7 @@ pub fn validate(db: &MysqlDb, params: params::ValidateBatch) -> DbResult<bool> {
.filter(batch_uploads::batch_id.eq(&batch_id))
.filter(batch_uploads::user_id.eq(&user_id))
.filter(batch_uploads::collection_id.eq(&collection_id))
.get_result::<i32>(&mut *db.conn.write().unwrap())
.get_result::<i32>(&mut *db.conn.write()?)
.optional()?;
Ok(exists.is_some())
}
Expand Down Expand Up @@ -127,11 +127,11 @@ pub fn delete(db: &MysqlDb, params: params::DeleteBatch) -> DbResult<()> {
.filter(batch_uploads::batch_id.eq(&batch_id))
.filter(batch_uploads::user_id.eq(&user_id))
.filter(batch_uploads::collection_id.eq(&collection_id))
.execute(&mut *db.conn.write().unwrap())?;
.execute(&mut *db.conn.write()?)?;
diesel::delete(batch_upload_items::table)
.filter(batch_upload_items::batch_id.eq(&batch_id))
.filter(batch_upload_items::user_id.eq(&user_id))
.execute(&mut *db.conn.write().unwrap())?;
.execute(&mut *db.conn.write()?)?;
Ok(())
}

Expand All @@ -151,7 +151,7 @@ pub fn commit(db: &MysqlDb, params: params::CommitBatch) -> DbResult<results::Co
.bind::<BigInt, _>(user_id)
.bind::<BigInt, _>(&db.timestamp().as_i64())
.bind::<BigInt, _>(&db.timestamp().as_i64())
.execute(&mut *db.conn.write().unwrap())?;
.execute(&mut *db.conn.write()?)?;

db.update_collection(user_id as u32, collection_id, None)?;

Expand Down Expand Up @@ -211,7 +211,7 @@ pub fn do_append(
)
.bind::<BigInt, _>(user_id.legacy_id as i64)
.bind::<BigInt, _>(batch_id)
.get_results::<ExistsResult>(&mut *db.conn.write().unwrap())?
.get_results::<ExistsResult>(&mut *db.conn.write()?)?
{
existing.insert(exist_idx(
user_id.legacy_id,
Expand All @@ -235,7 +235,7 @@ pub fn do_append(
payload_size,
ttl_offset: bso.ttl.map(|ttl| ttl as i32),
})
.execute(&mut *db.conn.write().unwrap())?;
.execute(&mut *db.conn.write()?)?;
} else {
diesel::insert_into(batch_upload_items::table)
.values((
Expand All @@ -247,7 +247,7 @@ pub fn do_append(
batch_upload_items::payload_size.eq(payload_size),
batch_upload_items::ttl_offset.eq(bso.ttl.map(|ttl| ttl as i32)),
))
.execute(&mut *db.conn.write().unwrap())?;
.execute(&mut *db.conn.write()?)?;
// make sure to include the key into our table check.
existing.insert(exist_idx);
}
Expand Down
6 changes: 6 additions & 0 deletions syncstorage-mysql/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,9 @@ from_error!(
DbError,
|error: std::boxed::Box<dyn std::error::Error>| DbError::internal_error(error.to_string())
);

impl<Guard> From<std::sync::PoisonError<Guard>> for DbError {
fn from(inner: std::sync::PoisonError<Guard>) -> DbError {
DbError::internal_error(inner.to_string())
}
}
68 changes: 33 additions & 35 deletions syncstorage-mysql/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ impl MysqlDb {
.filter(user_collections::user_id.eq(user_id))
.filter(user_collections::collection_id.eq(collection_id))
.lock_in_share_mode()
.first(&mut *self.conn.write().unwrap())
.first(&mut *self.conn.write()?)
.optional()?;
if let Some(modified) = modified {
let modified = SyncTimestamp::from_i64(modified)?;
Expand Down Expand Up @@ -222,7 +222,7 @@ impl MysqlDb {
.filter(user_collections::user_id.eq(user_id))
.filter(user_collections::collection_id.eq(collection_id))
.for_update()
.first(&mut *self.conn.write().unwrap())
.first(&mut *self.conn.write()?)
.optional()?;
let timestamp = if let Some((timestamp, modified)) = result {
let modified = SyncTimestamp::from_i64(modified)?;
Expand All @@ -238,8 +238,7 @@ impl MysqlDb {
now
} else {
let result = sql_query("SELECT UNIX_TIMESTAMP(UTC_TIMESTAMP(2))*1000 AS timestamp")
.get_result::<TimestampResult>(&mut *self.conn.write().unwrap())
.unwrap();
.get_result::<TimestampResult>(&mut *self.conn.write()?)?;
SyncTimestamp::from_i64(result.timestamp)?
};
self.set_timestamp(timestamp);
Expand All @@ -253,7 +252,7 @@ impl MysqlDb {

pub(super) fn begin(&self, for_write: bool) -> DbResult<()> {
<InternalConn as Connection>::TransactionManager::begin_transaction(
&mut *self.conn.write().unwrap(),
&mut *self.conn.write()?,
)?;
self.session.borrow_mut().in_transaction = true;
if for_write {
Expand All @@ -269,7 +268,7 @@ impl MysqlDb {
fn commit_sync(&self) -> DbResult<()> {
if self.session.borrow().in_transaction {
<InternalConn as Connection>::TransactionManager::commit_transaction(
&mut *self.conn.write().unwrap(),
&mut *self.conn.write()?,
)?;
}
Ok(())
Expand All @@ -278,7 +277,7 @@ impl MysqlDb {
fn rollback_sync(&self) -> DbResult<()> {
if self.session.borrow().in_transaction {
<InternalConn as Connection>::TransactionManager::rollback_transaction(
&mut *self.conn.write().unwrap(),
&mut *self.conn.write()?,
)?;
}
Ok(())
Expand All @@ -297,7 +296,7 @@ impl MysqlDb {
.bind::<BigInt, _>(user_id as i64)
.bind::<Integer, _>(TOMBSTONE)
.bind::<BigInt, _>(self.timestamp().as_i64())
.execute(&mut *self.conn.write().unwrap())?;
.execute(&mut *self.conn.write()?)?;
Ok(())
}

Expand All @@ -306,11 +305,11 @@ impl MysqlDb {
// Delete user data.
delete(bso::table)
.filter(bso::user_id.eq(user_id))
.execute(&mut *self.conn.write().unwrap())?;
.execute(&mut *self.conn.write()?)?;
// Delete user collections.
delete(user_collections::table)
.filter(user_collections::user_id.eq(user_id))
.execute(&mut *self.conn.write().unwrap())?;
.execute(&mut *self.conn.write()?)?;
Ok(())
}

Expand All @@ -323,11 +322,11 @@ impl MysqlDb {
let mut count = delete(bso::table)
.filter(bso::user_id.eq(user_id))
.filter(bso::collection_id.eq(&collection_id))
.execute(&mut *self.conn.write().unwrap())?;
.execute(&mut *self.conn.write()?)?;
count += delete(user_collections::table)
.filter(user_collections::user_id.eq(user_id))
.filter(user_collections::collection_id.eq(&collection_id))
.execute(&mut *self.conn.write().unwrap())?;
.execute(&mut *self.conn.write()?)?;
if count == 0 {
return Err(DbError::collection_not_found());
} else {
Expand All @@ -341,7 +340,7 @@ impl MysqlDb {
return Ok(id);
}

let id = self.conn.write().unwrap().transaction(|tx| {
let id = self.conn.write()?.transaction(|tx| {
diesel::insert_or_ignore_into(collections::table)
.values(collections::name.eq(name))
.execute(tx)?;
Expand Down Expand Up @@ -370,7 +369,7 @@ impl MysqlDb {
WHERE name = ?",
)
.bind::<Text, _>(name)
.get_result::<IdResult>(&mut *self.conn.write().unwrap())
.get_result::<IdResult>(&mut *self.conn.write()?)
.optional()?
.ok_or_else(DbError::collection_not_found)?
.id;
Expand All @@ -390,7 +389,7 @@ impl MysqlDb {
WHERE id = ?",
)
.bind::<Integer, _>(&id)
.get_result::<NameResult>(&mut *self.conn.write().unwrap())
.get_result::<NameResult>(&mut *self.conn.write()?)
.optional()?
.ok_or_else(DbError::collection_not_found)?
.name
Expand Down Expand Up @@ -428,7 +427,7 @@ impl MysqlDb {
}
}

self.conn.write().unwrap().transaction(|tx| {
self.conn.write()?.transaction(|tx| {
let payload = bso.payload.as_deref().unwrap_or_default();
let sortindex = bso.sortindex;
let ttl = bso.ttl.map_or(DEFAULT_BSO_TTL, |ttl| ttl);
Expand Down Expand Up @@ -552,7 +551,7 @@ impl MysqlDb {
// https://github.com/mozilla-services/server-syncstorage/blob/a0f8117/syncstorage/storage/sql/__init__.py#L404
query = query.offset(numeric_offset);
}
let mut bsos = query.load::<results::GetBso>(&mut *self.conn.write().unwrap())?;
let mut bsos = query.load::<results::GetBso>(&mut *self.conn.write()?)?;

// XXX: an additional get_collection_timestamp is done here in
// python to trigger potential CollectionNotFoundErrors
Expand Down Expand Up @@ -622,7 +621,7 @@ impl MysqlDb {
// https://github.com/mozilla-services/server-syncstorage/blob/a0f8117/syncstorage/storage/sql/__init__.py#L404
query = query.offset(numeric_offset);
}
let mut ids = query.load::<String>(&mut *self.conn.write().unwrap())?;
let mut ids = query.load::<String>(&mut *self.conn.write()?)?;

// XXX: an additional get_collection_timestamp is done here in
// python to trigger potential CollectionNotFoundErrors
Expand Down Expand Up @@ -657,7 +656,7 @@ impl MysqlDb {
.filter(bso::collection_id.eq(&collection_id))
.filter(bso::id.eq(&params.id))
.filter(bso::expiry.ge(self.timestamp().as_i64()))
.get_result::<results::GetBso>(&mut *self.conn.write().unwrap())
.get_result::<results::GetBso>(&mut *self.conn.write()?)
.optional()?)
}

Expand All @@ -669,7 +668,7 @@ impl MysqlDb {
.filter(bso::collection_id.eq(&collection_id))
.filter(bso::id.eq(params.id))
.filter(bso::expiry.gt(&self.timestamp().as_i64()))
.execute(&mut *self.conn.write().unwrap())?;
.execute(&mut *self.conn.write()?)?;
if affected_rows == 0 {
return Err(DbError::bso_not_found());
}
Expand All @@ -683,7 +682,7 @@ impl MysqlDb {
.filter(bso::user_id.eq(user_id))
.filter(bso::collection_id.eq(&collection_id))
.filter(bso::id.eq_any(params.ids))
.execute(&mut *self.conn.write().unwrap())?;
.execute(&mut *self.conn.write()?)?;
self.update_collection(user_id as u32, collection_id, None)
}

Expand Down Expand Up @@ -725,7 +724,7 @@ impl MysqlDb {
let modified = user_collections::table
.select(max(user_collections::modified))
.filter(user_collections::user_id.eq(user_id))
.first::<Option<i64>>(&mut *self.conn.write().unwrap())?
.first::<Option<i64>>(&mut *self.conn.write()?)?
.unwrap_or_default();
SyncTimestamp::from_i64(modified).map_err(Into::into)
}
Expand All @@ -748,7 +747,7 @@ impl MysqlDb {
.select(user_collections::modified)
.filter(user_collections::user_id.eq(user_id as i64))
.filter(user_collections::collection_id.eq(collection_id))
.first(&mut *self.conn.write().unwrap())
.first(&mut *self.conn.write()?)
.optional()?
.ok_or_else(DbError::collection_not_found)
}
Expand All @@ -761,7 +760,7 @@ impl MysqlDb {
.filter(bso::user_id.eq(user_id))
.filter(bso::collection_id.eq(&collection_id))
.filter(bso::id.eq(&params.id))
.first::<i64>(&mut *self.conn.write().unwrap())
.first::<i64>(&mut *self.conn.write()?)
.optional()?
.unwrap_or_default();
SyncTimestamp::from_i64(modified).map_err(Into::into)
Expand All @@ -782,7 +781,7 @@ impl MysqlDb {
))
.bind::<BigInt, _>(user_id.legacy_id as i64)
.bind::<Integer, _>(TOMBSTONE)
.load::<UserCollectionsResult>(&mut *self.conn.write().unwrap())?
.load::<UserCollectionsResult>(&mut *self.conn.write()?)?
.into_iter()
.map(|cr| {
SyncTimestamp::from_i64(cr.last_modified)
Expand All @@ -795,8 +794,7 @@ impl MysqlDb {

fn check_sync(&self) -> DbResult<results::Check> {
// has the database been up for more than 0 seconds?
let result =
sql_query("SHOW STATUS LIKE \"Uptime\"").execute(&mut *self.conn.write().unwrap())?;
let result = sql_query("SHOW STATUS LIKE \"Uptime\"").execute(&mut *self.conn.write()?)?;
Ok(result as u64 > 0)
}

Expand Down Expand Up @@ -830,7 +828,7 @@ impl MysqlDb {
let result = collections::table
.select((collections::id, collections::name))
.filter(collections::id.eq_any(uncached))
.load::<(i32, String)>(&mut *self.conn.write().unwrap())?;
.load::<(i32, String)>(&mut *self.conn.write()?)?;

for (id, name) in result {
names.insert(id, name.clone());
Expand All @@ -850,7 +848,7 @@ impl MysqlDb {
mut conn: Option<&mut InternalConn>,
) -> DbResult<SyncTimestamp> {
let quota = if self.quota.enabled {
self.calc_quota_usage_sync(user_id, collection_id, Some(&mut **conn.as_mut().unwrap()))?
self.calc_quota_usage_sync(user_id, collection_id, conn.as_deref_mut())?
} else {
results::GetQuotaUsage {
count: 0,
Expand Down Expand Up @@ -886,7 +884,7 @@ impl MysqlDb {
if let Some(conn) = conn {
q.execute(conn)?;
} else {
q.execute(&mut *self.conn.write().unwrap())?;
q.execute(&mut *self.conn.write()?)?;
}
Ok(self.timestamp())
}
Expand All @@ -901,7 +899,7 @@ impl MysqlDb {
.select(sql::<Nullable<BigInt>>("SUM(LENGTH(payload))"))
.filter(bso::user_id.eq(uid))
.filter(bso::expiry.gt(&self.timestamp().as_i64()))
.get_result::<Option<i64>>(&mut *self.conn.write().unwrap())?;
.get_result::<Option<i64>>(&mut *self.conn.write()?)?;
Ok(total_bytes.unwrap_or_default() as u64)
}

Expand All @@ -918,7 +916,7 @@ impl MysqlDb {
))
.filter(user_collections::user_id.eq(uid))
.filter(user_collections::collection_id.eq(params.collection_id))
.get_result(&mut *self.conn.write().unwrap())
.get_result(&mut *self.conn.write()?)
.optional()?
.unwrap_or_default();
Ok(results::GetQuotaUsage {
Expand All @@ -945,7 +943,7 @@ impl MysqlDb {
let (total_bytes, count): (i64, i32) = if let Some(conn) = conn {
q.get_result(conn)
} else {
q.get_result(&mut *self.conn.write().unwrap())
q.get_result(&mut *self.conn.write()?)
}
.optional()?
.unwrap_or_default();
Expand All @@ -964,7 +962,7 @@ impl MysqlDb {
.filter(bso::user_id.eq(user_id.legacy_id as i64))
.filter(bso::expiry.gt(&self.timestamp().as_i64()))
.group_by(bso::collection_id)
.load(&mut *self.conn.write().unwrap())?
.load(&mut *self.conn.write()?)?
.into_iter()
.collect();
self.map_collection_names(counts)
Expand All @@ -985,7 +983,7 @@ impl MysqlDb {
.filter(bso::user_id.eq(user_id.legacy_id as i64))
.filter(bso::expiry.gt(&self.timestamp().as_i64()))
.group_by(bso::collection_id)
.load(&mut *self.conn.write().unwrap())?
.load(&mut *self.conn.write()?)?
.into_iter()
.collect();
self.map_collection_names(counts)
Expand Down
2 changes: 1 addition & 1 deletion syncstorage-mysql/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ fn static_collection_id() -> DbResult<()> {
.filter(collections::name.ne(""))
.filter(collections::name.ne("xxx_col2")) // from server::test
.filter(collections::name.ne("col2")) // from older intergration tests
.load(&mut *db.inner.conn.write().unwrap())?
.load(&mut *db.inner.conn.write()?)?
.into_iter()
.collect();
assert_eq!(results.len(), cols.len(), "mismatched columns");
Expand Down
6 changes: 6 additions & 0 deletions tokenserver-db/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,9 @@ from_error!(
DbError,
|error: std::boxed::Box<dyn std::error::Error>| DbError::internal_error(error.to_string())
);

impl<Guard> From<std::sync::PoisonError<Guard>> for DbError {
fn from(inner: std::sync::PoisonError<Guard>) -> DbError {
DbError::internal_error(inner.to_string())
}
}
Loading

0 comments on commit 9911784

Please sign in to comment.