Skip to content

Commit

Permalink
Allow setting an explicit upstream account name (#3600)
Browse files Browse the repository at this point in the history
  • Loading branch information
sandhose authored Nov 29, 2024
1 parent 2c01b43 commit 2e3b8bd
Show file tree
Hide file tree
Showing 22 changed files with 302 additions and 42 deletions.
3 changes: 2 additions & 1 deletion crates/cli/src/commands/manage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,9 +955,10 @@ impl UserCreationRequest<'_> {
}

for (provider, subject) in upstream_provider_mappings {
// Note that we don't pass a human_account_name here, as we don't ask for it
let link = repo
.upstream_oauth_link()
.add(rng, clock, provider, subject)
.add(rng, clock, provider, subject, None)
.await?;

repo.upstream_oauth_link()
Expand Down
3 changes: 3 additions & 0 deletions crates/cli/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ fn map_claims_imports(
mas_data_model::UpsreamOAuthProviderSetEmailVerification::Import
}
},
account_name: mas_data_model::UpstreamOAuthProviderSubjectPreference {
template: config.account_name.template.clone(),
},
}
}

Expand Down
24 changes: 24 additions & 0 deletions crates/config/src/sections/upstream_oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,23 @@ impl EmailImportPreference {
}
}

/// What should be done for the account name attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct AccountNameImportPreference {
/// The Jinja2 template to use for the account name. This name is only used
/// for display purposes.
///
/// If not provided, it will be ignored.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub template: Option<String>,
}

impl AccountNameImportPreference {
const fn is_default(&self) -> bool {
self.template.is_none()
}
}

/// How claims should be imported
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct ClaimsImports {
Expand All @@ -307,6 +324,13 @@ pub struct ClaimsImports {
/// `email_verified` claims
#[serde(default, skip_serializing_if = "EmailImportPreference::is_default")]
pub email: EmailImportPreference,

/// Set a human-readable name for the upstream account for display purposes
#[serde(
default,
skip_serializing_if = "AccountNameImportPreference::is_default"
)]
pub account_name: AccountNameImportPreference,
}

impl ClaimsImports {
Expand Down
1 change: 1 addition & 0 deletions crates/data-model/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ pub struct UpstreamOAuthLink {
pub provider_id: Ulid,
pub user_id: Option<Ulid>,
pub subject: String,
pub human_account_name: Option<String>,
pub created_at: DateTime<Utc>,
}
4 changes: 4 additions & 0 deletions crates/data-model/src/upstream_oauth2/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,14 @@ pub struct ClaimsImports {
#[serde(default)]
pub email: ImportPreference,

#[serde(default)]
pub account_name: SubjectPreference,

#[serde(default)]
pub verify_email: SetEmailVerification,
}

// XXX: this should have another name
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct SubjectPreference {
#[serde(default)]
Expand Down
22 changes: 20 additions & 2 deletions crates/handlers/src/upstream_oauth2/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ pub(crate) async fn handler(
.as_deref()
.unwrap_or("{{ user.sub }}");
let subject = env
.render_str(template, context)
.render_str(template, context.clone())
.map_err(RouteError::ExtractSubject)?;

if subject.is_empty() {
Expand All @@ -375,8 +375,26 @@ pub(crate) async fn handler(
let link = if let Some(link) = maybe_link {
link
} else {
// Try to render the human account name if we have one,
// but just log if it fails
let human_account_name = provider
.claims_imports
.account_name
.template
.as_deref()
.and_then(|template| match env.render_str(template, context) {
Ok(name) => Some(name),
Err(e) => {
tracing::warn!(
error = &e as &dyn std::error::Error,
"Failed to render account name"
);
None
}
});

repo.upstream_oauth_link()
.add(&mut rng, &clock, &provider, subject)
.add(&mut rng, &clock, &provider, subject, human_account_name)
.await?
};

Expand Down
12 changes: 9 additions & 3 deletions crates/handlers/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ pub(crate) async fn get(
.await?
.ok_or(RouteError::ProviderNotFound)?;

let ctx = UpstreamRegister::default();
let ctx = UpstreamRegister::new(link.clone(), provider.clone());

let env = environment();

Expand Down Expand Up @@ -596,7 +596,7 @@ pub(crate) async fn post(
.map_or(false, |v| v == "true");

// Create a template context in case we need to re-render because of an error
let ctx = UpstreamRegister::default();
let ctx = UpstreamRegister::new(link.clone(), provider.clone());

let display_name = if provider
.claims_imports
Expand Down Expand Up @@ -954,7 +954,13 @@ mod tests {

let link = repo
.upstream_oauth_link()
.add(&mut rng, &state.clock, &provider, "subject".to_owned())
.add(
&mut rng,
&state.clock,
&provider,
"subject".to_owned(),
None,
)
.await
.unwrap();

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-- Copyright 2024 New Vector Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only
-- Please see LICENSE in the repository root for full details.

-- Add the human_account_name column to the upstream_oauth_links table to store
-- a human-readable name for the upstream account
ALTER TABLE "upstream_oauth_links"
ADD COLUMN "human_account_name" TEXT;
1 change: 1 addition & 0 deletions crates/storage-pg/src/iden.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,5 +122,6 @@ pub enum UpstreamOAuthLinks {
UpstreamOAuthProviderId,
UserId,
Subject,
HumanAccountName,
CreatedAt,
}
18 changes: 17 additions & 1 deletion crates/storage-pg/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ struct LinkLookup {
upstream_oauth_provider_id: Uuid,
user_id: Option<Uuid>,
subject: String,
human_account_name: Option<String>,
created_at: DateTime<Utc>,
}

Expand All @@ -57,6 +58,7 @@ impl From<LinkLookup> for UpstreamOAuthLink {
provider_id: Ulid::from(value.upstream_oauth_provider_id),
user_id: value.user_id.map(Ulid::from),
subject: value.subject,
human_account_name: value.human_account_name,
created_at: value.created_at,
}
}
Expand Down Expand Up @@ -124,6 +126,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
upstream_oauth_provider_id,
user_id,
subject,
human_account_name,
created_at
FROM upstream_oauth_links
WHERE upstream_oauth_link_id = $1
Expand Down Expand Up @@ -163,6 +166,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
upstream_oauth_provider_id,
user_id,
subject,
human_account_name,
created_at
FROM upstream_oauth_links
WHERE upstream_oauth_provider_id = $1
Expand All @@ -186,6 +190,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
db.query.text,
upstream_oauth_link.id,
upstream_oauth_link.subject = subject,
upstream_oauth_link.human_account_name = human_account_name,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
Expand All @@ -198,6 +203,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
clock: &dyn Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: String,
human_account_name: Option<String>,
) -> Result<UpstreamOAuthLink, Self::Error> {
let created_at = clock.now();
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
Expand All @@ -210,12 +216,14 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
upstream_oauth_provider_id,
user_id,
subject,
human_account_name,
created_at
) VALUES ($1, $2, NULL, $3, $4)
) VALUES ($1, $2, NULL, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_provider.id),
&subject,
human_account_name.as_deref(),
created_at,
)
.traced()
Expand All @@ -227,6 +235,7 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
provider_id: upstream_oauth_provider.id,
user_id: None,
subject,
human_account_name,
created_at,
})
}
Expand Down Expand Up @@ -300,6 +309,13 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)),
LinkLookupIden::Subject,
)
.expr_as(
Expr::col((
UpstreamOAuthLinks::Table,
UpstreamOAuthLinks::HumanAccountName,
)),
LinkLookupIden::HumanAccountName,
)
.expr_as(
Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::CreatedAt)),
LinkLookupIden::CreatedAt,
Expand Down
2 changes: 1 addition & 1 deletion crates/storage-pg/src/upstream_oauth2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ mod tests {
// Create a link
let link = repo
.upstream_oauth_link()
.add(&mut rng, &clock, &provider, "a-subject".to_owned())
.add(&mut rng, &clock, &provider, "a-subject".to_owned(), None)
.await
.unwrap();

Expand Down
3 changes: 3 additions & 0 deletions crates/storage/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync {
/// * `upsream_oauth_provider`: The upstream OAuth provider for which to
/// create the link
/// * `subject`: The subject of the upstream OAuth link to create
/// * `human_account_name`: A human-readable name for the upstream account
///
/// # Errors
///
Expand All @@ -138,6 +139,7 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync {
clock: &dyn Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: String,
human_account_name: Option<String>,
) -> Result<UpstreamOAuthLink, Self::Error>;

/// Associate an upstream OAuth link to a user
Expand Down Expand Up @@ -201,6 +203,7 @@ repository_impl!(UpstreamOAuthLinkRepository:
clock: &dyn Clock,
upstream_oauth_provider: &UpstreamOAuthProvider,
subject: String,
human_account_name: Option<String>,
) -> Result<UpstreamOAuthLink, Self::Error>;

async fn associate_to_user(
Expand Down
Loading

0 comments on commit 2e3b8bd

Please sign in to comment.