diff --git a/ee/tabby-db/src/lib.rs b/ee/tabby-db/src/lib.rs index 248d38886baf..31aaaf69d88b 100644 --- a/ee/tabby-db/src/lib.rs +++ b/ee/tabby-db/src/lib.rs @@ -14,9 +14,9 @@ pub use repositories::RepositoryDAO; pub use server_setting::ServerSettingDAO; use sqlx::{query, query_scalar, sqlite::SqliteQueryResult, Pool, Sqlite, SqlitePool}; pub use threads::{ - ThreadDAO, ThreadMessageAttachmentClientCode, ThreadMessageAttachmentCode, - ThreadMessageAttachmentDoc, ThreadMessageAttachmentIssueDoc, ThreadMessageAttachmentPullDoc, - ThreadMessageAttachmentWebDoc, ThreadMessageDAO, + ThreadDAO, ThreadMessageAttachmentAuthor, ThreadMessageAttachmentClientCode, + ThreadMessageAttachmentCode, ThreadMessageAttachmentDoc, ThreadMessageAttachmentIssueDoc, + ThreadMessageAttachmentPullDoc, ThreadMessageAttachmentWebDoc, ThreadMessageDAO, }; use tokio::sync::Mutex; use user_completions::UserCompletionDailyStatsDAO; diff --git a/ee/tabby-db/src/threads.rs b/ee/tabby-db/src/threads.rs index 1e3909c6bc3e..a68709402efc 100644 --- a/ee/tabby-db/src/threads.rs +++ b/ee/tabby-db/src/threads.rs @@ -50,6 +50,7 @@ pub struct ThreadMessageAttachmentWebDoc { pub struct ThreadMessageAttachmentIssueDoc { pub title: String, pub link: String, + pub author_user_id: Option, pub body: String, pub closed: bool, } @@ -58,11 +59,19 @@ pub struct ThreadMessageAttachmentIssueDoc { pub struct ThreadMessageAttachmentPullDoc { pub title: String, pub link: String, + pub author_user_id: Option, pub body: String, pub diff: String, pub merged: bool, } +#[derive(Serialize, Deserialize)] +pub struct ThreadMessageAttachmentAuthor { + pub id: String, + pub name: String, + pub email: String, +} + #[derive(Serialize, Deserialize)] pub struct ThreadMessageAttachmentCode { pub git_url: String, diff --git a/ee/tabby-schema/graphql/schema.graphql b/ee/tabby-schema/graphql/schema.graphql index 97c59b64c9d0..67c1b82377d1 100644 --- a/ee/tabby-schema/graphql/schema.graphql +++ b/ee/tabby-schema/graphql/schema.graphql @@ -512,6 +512,7 @@ type MessageAttachmentCodeScores { type MessageAttachmentIssueDoc { title: String! link: String! + author: User body: String! closed: Boolean! } @@ -519,6 +520,7 @@ type MessageAttachmentIssueDoc { type MessageAttachmentPullDoc { title: String! link: String! + author: User body: String! patch: String! merged: Boolean! diff --git a/ee/tabby-schema/src/dao.rs b/ee/tabby-schema/src/dao.rs index 8e8ba645d82e..7961ca48dc24 100644 --- a/ee/tabby-schema/src/dao.rs +++ b/ee/tabby-schema/src/dao.rs @@ -5,11 +5,12 @@ use tabby_db::{ EmailSettingDAO, IntegrationDAO, InvitationDAO, JobRunDAO, OAuthCredentialDAO, ServerSettingDAO, ThreadDAO, ThreadMessageAttachmentClientCode, ThreadMessageAttachmentCode, ThreadMessageAttachmentDoc, ThreadMessageAttachmentIssueDoc, ThreadMessageAttachmentPullDoc, - ThreadMessageAttachmentWebDoc, ThreadMessageDAO, UserEventDAO, + ThreadMessageAttachmentWebDoc, UserEventDAO, }; use crate::{ integration::{Integration, IntegrationKind, IntegrationStatus}, + interface::UserValue, repository::RepositoryKind, schema::{ auth::{self, OAuthCredential, OAuthProvider}, @@ -22,7 +23,7 @@ use crate::{ user_event::{EventKind, UserEvent}, CoreError, }, - thread::{self, MessageAttachment}, + thread::{self}, }; impl From for auth::Invitation { @@ -228,33 +229,36 @@ impl From<&thread::MessageAttachmentCodeInput> for ThreadMessageAttachmentClient } } -impl From for thread::MessageAttachmentDoc { - fn from(value: ThreadMessageAttachmentDoc) -> Self { - match value { - ThreadMessageAttachmentDoc::Web(val) => { - thread::MessageAttachmentDoc::Web(thread::MessageAttachmentWebDoc { - title: val.title, - link: val.link, - content: val.content, - }) - } - ThreadMessageAttachmentDoc::Issue(val) => { - thread::MessageAttachmentDoc::Issue(thread::MessageAttachmentIssueDoc { - title: val.title, - link: val.link, - body: val.body, - closed: val.closed, - }) - } - ThreadMessageAttachmentDoc::Pull(val) => { - thread::MessageAttachmentDoc::Pull(thread::MessageAttachmentPullDoc { - title: val.title, - link: val.link, - body: val.body, - patch: val.diff, - merged: val.merged, - }) - } +pub fn from_thread_message_attachment_document( + doc: ThreadMessageAttachmentDoc, + author: Option, +) -> thread::MessageAttachmentDoc { + match doc { + ThreadMessageAttachmentDoc::Web(web) => { + thread::MessageAttachmentDoc::Web(thread::MessageAttachmentWebDoc { + title: web.title, + link: web.link, + content: web.content, + }) + } + ThreadMessageAttachmentDoc::Issue(issue) => { + thread::MessageAttachmentDoc::Issue(thread::MessageAttachmentIssueDoc { + title: issue.title, + link: issue.link, + author, + body: issue.body, + closed: issue.closed, + }) + } + ThreadMessageAttachmentDoc::Pull(pull) => { + thread::MessageAttachmentDoc::Pull(thread::MessageAttachmentPullDoc { + title: pull.title, + link: pull.link, + author, + body: pull.body, + patch: pull.diff, + merged: pull.merged, + }) } } } @@ -273,6 +277,9 @@ impl From<&thread::MessageAttachmentDoc> for ThreadMessageAttachmentDoc { ThreadMessageAttachmentDoc::Issue(ThreadMessageAttachmentIssueDoc { title: val.title.clone(), link: val.link.clone(), + author_user_id: val.author.as_ref().map(|x| match x { + UserValue::UserSecured(user) => user.id.to_string(), + }), body: val.body.clone(), closed: val.closed, }) @@ -281,6 +288,9 @@ impl From<&thread::MessageAttachmentDoc> for ThreadMessageAttachmentDoc { ThreadMessageAttachmentDoc::Pull(ThreadMessageAttachmentPullDoc { title: val.title.clone(), link: val.link.clone(), + author_user_id: val.author.as_ref().map(|x| match x { + UserValue::UserSecured(user) => user.id.to_string(), + }), body: val.body.clone(), diff: val.patch.clone(), merged: val.merged, @@ -301,37 +311,6 @@ impl From for thread::Thread { } } -impl TryFrom for thread::Message { - type Error = anyhow::Error; - fn try_from(value: ThreadMessageDAO) -> Result { - let code = value.code_attachments; - let client_code = value.client_code_attachments; - let doc = value.doc_attachments; - - let attachment = MessageAttachment { - code: code - .map(|x| x.0.into_iter().map(|i| i.into()).collect()) - .unwrap_or_default(), - client_code: client_code - .map(|x| x.0.into_iter().map(|i| i.into()).collect()) - .unwrap_or_default(), - doc: doc - .map(|x| x.0.into_iter().map(|i| i.into()).collect()) - .unwrap_or_default(), - }; - - Ok(Self { - id: value.id.as_id(), - thread_id: value.thread_id.as_id(), - role: thread::Role::from_enum_str(&value.role)?, - content: value.content, - attachment, - created_at: value.created_at, - updated_at: value.updated_at, - }) - } -} - lazy_static! { static ref HASHER: HashIds = HashIds::builder() .with_salt("tabby-id-serializer") diff --git a/ee/tabby-schema/src/schema/auth.rs b/ee/tabby-schema/src/schema/auth.rs index c05ed1b6d27b..404da5d72a8c 100644 --- a/ee/tabby-schema/src/schema/auth.rs +++ b/ee/tabby-schema/src/schema/auth.rs @@ -174,7 +174,7 @@ impl JWTPayload { } } -#[derive(Debug, GraphQLObject)] +#[derive(Debug, GraphQLObject, Clone)] #[graphql(context = Context, impl = [UserValue])] pub struct UserSecured { // === implements User === diff --git a/ee/tabby-schema/src/schema/thread/types.rs b/ee/tabby-schema/src/schema/thread/types.rs index 3b0cd7588a2f..7ac4a4c9ebe4 100644 --- a/ee/tabby-schema/src/schema/thread/types.rs +++ b/ee/tabby-schema/src/schema/thread/types.rs @@ -3,11 +3,11 @@ use juniper::{GraphQLEnum, GraphQLInputObject, GraphQLObject, GraphQLUnion, ID}; use serde::Serialize; use tabby_common::api::{ code::{CodeSearchDocument, CodeSearchHit, CodeSearchScores}, - structured_doc::{DocSearchDocument, DocSearchHit}, + structured_doc::DocSearchDocument, }; use validator::Validate; -use crate::{juniper::relay::NodeType, Context}; +use crate::{interface::UserValue, juniper::relay::NodeType, Context}; #[derive(GraphQLEnum, Serialize, Clone, PartialEq, Eq)] pub enum Role { @@ -55,6 +55,7 @@ pub struct UpdateMessageInput { } #[derive(GraphQLObject, Clone, Default)] +#[graphql(context = Context)] pub struct MessageAttachment { pub code: Vec, pub client_code: Vec, @@ -122,6 +123,7 @@ impl From for MessageCodeSearchHit { } #[derive(GraphQLUnion, Clone)] +#[graphql(context = Context)] pub enum MessageAttachmentDoc { Web(MessageAttachmentWebDoc), Issue(MessageAttachmentIssueDoc), @@ -136,24 +138,28 @@ pub struct MessageAttachmentWebDoc { } #[derive(GraphQLObject, Clone)] +#[graphql(context = Context)] pub struct MessageAttachmentIssueDoc { pub title: String, pub link: String, + pub author: Option, pub body: String, pub closed: bool, } #[derive(GraphQLObject, Clone)] +#[graphql(context = Context)] pub struct MessageAttachmentPullDoc { pub title: String, pub link: String, + pub author: Option, pub body: String, pub patch: String, pub merged: bool, } -impl From for MessageAttachmentDoc { - fn from(doc: DocSearchDocument) -> Self { +impl MessageAttachmentDoc { + pub fn from_doc_search_document(doc: DocSearchDocument, author: Option) -> Self { match doc { DocSearchDocument::Web(web) => MessageAttachmentDoc::Web(MessageAttachmentWebDoc { title: web.title, @@ -164,6 +170,7 @@ impl From for MessageAttachmentDoc { MessageAttachmentDoc::Issue(MessageAttachmentIssueDoc { title: issue.title, link: issue.link, + author, body: issue.body, closed: issue.closed, }) @@ -171,6 +178,7 @@ impl From for MessageAttachmentDoc { DocSearchDocument::Pull(pull) => MessageAttachmentDoc::Pull(MessageAttachmentPullDoc { title: pull.title, link: pull.link, + author, body: pull.body, patch: pull.diff, merged: pull.merged, @@ -180,20 +188,12 @@ impl From for MessageAttachmentDoc { } #[derive(GraphQLObject)] +#[graphql(context = Context)] pub struct MessageDocSearchHit { pub doc: MessageAttachmentDoc, pub score: f64, } -impl From for MessageDocSearchHit { - fn from(hit: DocSearchHit) -> Self { - Self { - doc: hit.doc.into(), - score: hit.score as f64, - } - } -} - #[derive(GraphQLObject)] #[graphql(context = Context)] pub struct Thread { @@ -245,6 +245,7 @@ pub struct ThreadAssistantMessageAttachmentsCode { } #[derive(GraphQLObject)] +#[graphql(context = Context)] pub struct ThreadAssistantMessageAttachmentsDoc { pub hits: Vec, } @@ -263,6 +264,7 @@ pub struct ThreadAssistantMessageCompleted { /// /// Apart from `thread_message_content_delta`, all other items will only appear once in the stream. #[derive(GraphQLUnion)] +#[graphql(context = Context)] pub enum ThreadRunItem { ThreadCreated(ThreadCreated), ThreadRelevantQuestions(ThreadRelevantQuestions), diff --git a/ee/tabby-webserver/src/service/access_policy.rs b/ee/tabby-webserver/src/service/access_policy.rs index ed5711a5d4f9..185c9b719ce3 100644 --- a/ee/tabby-webserver/src/service/access_policy.rs +++ b/ee/tabby-webserver/src/service/access_policy.rs @@ -57,3 +57,18 @@ impl AccessPolicyService for AccessPolicyServiceImpl { pub fn create(db: DbConn, context: Arc) -> impl AccessPolicyService { AccessPolicyServiceImpl { db, context } } + +#[cfg(test)] +pub mod testutils { + use tabby_schema::policy::AccessPolicy; + + use super::*; + + pub async fn make_policy() -> AccessPolicy { + AccessPolicy::new( + DbConn::new_in_memory().await.unwrap(), + &ID::from("nihao".to_string()), + false, + ) + } +} diff --git a/ee/tabby-webserver/src/service/answer.rs b/ee/tabby-webserver/src/service/answer.rs index 55ec60c225f8..7b9458a4aad3 100644 --- a/ee/tabby-webserver/src/service/answer.rs +++ b/ee/tabby-webserver/src/service/answer.rs @@ -22,18 +22,19 @@ use tabby_common::{ CodeSearch, CodeSearchError, CodeSearchHit, CodeSearchParams, CodeSearchQuery, CodeSearchScores, }, - structured_doc::{DocSearch, DocSearchError, DocSearchHit}, + structured_doc::{DocSearch, DocSearchDocument, DocSearchError, DocSearchHit}, }, config::AnswerConfig, }; use tabby_inference::ChatCompletionStream; use tabby_schema::{ + auth::AuthenticationService, context::{ContextInfoHelper, ContextService}, policy::AccessPolicy, repository::{Repository, RepositoryService}, thread::{ self, CodeQueryInput, CodeSearchParamsOverrideInput, DocQueryInput, MessageAttachment, - MessageAttachmentDoc, ThreadAssistantMessageAttachmentsCode, + MessageAttachmentDoc, MessageDocSearchHit, ThreadAssistantMessageAttachmentsCode, ThreadAssistantMessageAttachmentsDoc, ThreadAssistantMessageContentDelta, ThreadRelevantQuestions, ThreadRunItem, ThreadRunOptionsInput, }, @@ -44,6 +45,7 @@ use crate::bail; pub struct AnswerService { config: AnswerConfig, + auth: Arc, chat: Arc, code: Arc, doc: Arc, @@ -55,6 +57,7 @@ pub struct AnswerService { impl AnswerService { fn new( config: &AnswerConfig, + auth: Arc, chat: Arc, code: Arc, doc: Arc, @@ -64,6 +67,7 @@ impl AnswerService { ) -> Self { Self { config: config.clone(), + auth, chat, code, doc, @@ -122,14 +126,24 @@ impl AnswerService { if let Some(doc_query) = options.doc_query.as_ref() { let hits = self.collect_relevant_docs(&context_info_helper, doc_query) .await; - attachment.doc = hits.iter() - .map(|x| x.doc.clone().into()) - .collect::>(); + attachment.doc = futures::future::join_all(hits.iter().map(|x| async { + Self::new_message_attachment_doc(self.auth.clone(), x.doc.clone()).await + })).await; debug!("doc content: {:?}: {:?}", doc_query.content, attachment.doc.len()); if !attachment.doc.is_empty() { - let hits = hits.into_iter().map(|x| x.into()).collect::>(); + let hits = futures::future::join_all(hits.into_iter().map(|x| { + let score = x.score; + let doc = x.doc.clone(); + let auth = self.auth.clone(); + async move { + MessageDocSearchHit { + score: score as f64, + doc: Self::new_message_attachment_doc(auth, doc).await, + } + } + })).await; yield Ok(ThreadRunItem::ThreadAssistantMessageAttachmentsDoc( ThreadAssistantMessageAttachmentsDoc { hits } )); @@ -201,6 +215,23 @@ impl AnswerService { Ok(Box::pin(s)) } + async fn new_message_attachment_doc( + auth: Arc, + doc: DocSearchDocument, + ) -> MessageAttachmentDoc { + let email = match &doc { + DocSearchDocument::Issue(issue) => issue.author_email.as_deref(), + DocSearchDocument::Pull(pull) => pull.author_email.as_deref(), + _ => None, + }; + let user = if let Some(email) = email { + auth.get_user_by_email(email).await.ok().map(|x| x.into()) + } else { + None + }; + MessageAttachmentDoc::from_doc_search_document(doc, user) + } + async fn collect_relevant_code( &self, helper: &ContextInfoHelper, @@ -377,6 +408,7 @@ fn trim_bullet(s: &str) -> String { pub fn create( config: &AnswerConfig, + auth: Arc, chat: Arc, code: Arc, doc: Arc, @@ -384,7 +416,7 @@ pub fn create( serper: Option>, repository: Arc, ) -> AnswerService { - AnswerService::new(config, chat, code, doc, context, serper, repository) + AnswerService::new(config, auth, chat, code, doc, context, serper, repository) } fn convert_messages_to_chat_completion_request( @@ -639,13 +671,16 @@ mod tests { AsID, }; - use crate::answer::{ - merge_code_snippets, - testutils::{ - make_policy, make_repository_service, FakeChatCompletionStream, FakeCodeSearch, - FakeCodeSearchFail, FakeCodeSearchFailNotReady, FakeContextService, FakeDocSearch, + use crate::{ + answer::{ + merge_code_snippets, + testutils::{ + make_repository_service, FakeChatCompletionStream, FakeCodeSearch, + FakeCodeSearchFail, FakeCodeSearchFailNotReady, FakeContextService, FakeDocSearch, + }, + trim_bullet, AnswerService, }, - trim_bullet, AnswerService, + service::{access_policy::testutils::make_policy, auth}, }; const TEST_SOURCE_ID: &str = "source-1"; @@ -822,6 +857,7 @@ mod tests { #[tokio::test] async fn test_collect_relevant_code() { + let auth = Arc::new(auth::testutils::FakeAuthService::new(vec![])); let chat: Arc = Arc::new(FakeChatCompletionStream { return_error: false, }); @@ -836,6 +872,7 @@ mod tests { let mut service = AnswerService::new( &config, + auth.clone(), chat.clone(), code.clone(), doc.clone(), @@ -899,6 +936,7 @@ mod tests { service = AnswerService::new( &config, + auth.clone(), chat.clone(), code_fail_not_ready.clone(), doc.clone(), @@ -912,6 +950,7 @@ mod tests { service = AnswerService::new( &config, + auth.clone(), chat.clone(), code_fail.clone(), doc.clone(), @@ -923,6 +962,7 @@ mod tests { #[tokio::test] async fn test_generate_relevant_questions_v2() { + let auth = Arc::new(auth::testutils::FakeAuthService::new(vec![])); let chat: Arc = Arc::new(FakeChatCompletionStream { return_error: false, }); @@ -936,6 +976,7 @@ mod tests { let service = AnswerService::new( &config, + auth.clone(), chat.clone(), code.clone(), doc.clone(), @@ -983,6 +1024,7 @@ mod tests { #[tokio::test] async fn test_generate_relevant_questions_v2_error() { + let auth = Arc::new(auth::testutils::FakeAuthService::new(vec![])); let chat: Arc = Arc::new(FakeChatCompletionStream { return_error: true }); let code: Arc = Arc::new(FakeCodeSearch); @@ -995,6 +1037,7 @@ mod tests { let service = AnswerService::new( &config, + auth.clone(), chat.clone(), code.clone(), doc.clone(), @@ -1036,6 +1079,7 @@ mod tests { #[tokio::test] async fn test_collect_relevant_docs() { + let auth = Arc::new(auth::testutils::FakeAuthService::new(vec![])); let chat: Arc = Arc::new(FakeChatCompletionStream { return_error: false, }); @@ -1049,6 +1093,7 @@ mod tests { let service = AnswerService::new( &config, + auth.clone(), chat.clone(), code.clone(), doc.clone(), @@ -1105,6 +1150,7 @@ mod tests { use futures::StreamExt; use tabby_schema::{policy::AccessPolicy, thread::ThreadRunOptionsInput}; + let auth = Arc::new(auth::testutils::FakeAuthService::new(vec![])); let chat: Arc = Arc::new(FakeChatCompletionStream { return_error: false, }); @@ -1121,7 +1167,7 @@ mod tests { let db = DbConn::new_in_memory().await.unwrap(); let repo = make_repository_service(db).await.unwrap(); let service = Arc::new(AnswerService::new( - &config, chat, code, doc, context, serper, repo, + &config, auth, chat, code, doc, context, serper, repo, )); let db = DbConn::new_in_memory().await.unwrap(); diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index 8df4f01f4f15..e8d13c17dacf 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -32,6 +32,9 @@ use crate::{ oauth::{self, OAuthClient}, }; +#[cfg(test)] +pub mod testutils; + #[derive(Clone)] struct ImpersonateUserCredential { id: i64, diff --git a/ee/tabby-webserver/src/service/auth/testutils.rs b/ee/tabby-webserver/src/service/auth/testutils.rs new file mode 100644 index 000000000000..eca153cee36a --- /dev/null +++ b/ee/tabby-webserver/src/service/auth/testutils.rs @@ -0,0 +1,209 @@ +use async_trait::async_trait; +use chrono::{Duration, Utc}; +use juniper::ID; +use tabby_schema::{ + auth::{ + AuthenticationService, Invitation, JWTPayload, OAuthCredential, OAuthError, OAuthProvider, + OAuthResponse, RefreshTokenResponse, RegisterResponse, RequestInvitationInput, + TokenAuthResponse, UpdateOAuthCredentialInput, UserSecured, + }, + Result, +}; +use tokio::task::JoinHandle; + +pub struct FakeAuthService { + users: Vec, +} + +impl FakeAuthService { + pub fn new(users: Vec) -> Self { + FakeAuthService { users } + } +} + +#[async_trait] +impl AuthenticationService for FakeAuthService { + async fn register( + &self, + _email: String, + _password: String, + _invitation_code: Option, + _name: Option, + ) -> Result { + Ok(RegisterResponse::new( + "access_token".to_string(), + "refresh_token".to_string(), + )) + } + + async fn allow_self_signup(&self) -> Result { + Ok(true) + } + + async fn generate_reset_password_url(&self, id: &ID) -> Result { + Ok(format!("https://example.com/reset-password/{}", id)) + } + + async fn request_password_reset_email(&self, _email: String) -> Result>> { + Ok(None) + } + + async fn password_reset(&self, _code: &str, _password: &str) -> Result<()> { + Ok(()) + } + + async fn update_user_password( + &self, + _id: &ID, + _old_password: Option<&str>, + _new_password: &str, + ) -> Result<()> { + Ok(()) + } + + async fn update_user_avatar(&self, _id: &ID, _avatar: Option>) -> Result<()> { + Ok(()) + } + + async fn get_user_avatar(&self, _id: &ID) -> Result>> { + Ok(None) + } + + async fn update_user_name(&self, _id: &ID, _name: String) -> Result<()> { + Ok(()) + } + + async fn token_auth(&self, _email: String, _password: String) -> Result { + Ok(TokenAuthResponse::new( + "access_token".to_string(), + "refresh_token".to_string(), + )) + } + + async fn refresh_token(&self, _token: String) -> Result { + Ok(RefreshTokenResponse::new( + "access_token".to_string(), + "new_refresh_token".to_string(), + Utc::now() + Duration::days(30), + )) + } + + async fn verify_access_token(&self, _access_token: &str) -> Result { + Ok(JWTPayload::new( + ID::new("user_id"), + Utc::now().timestamp(), + Utc::now().timestamp() + Duration::days(30).num_seconds(), + false, + )) + } + + async fn verify_auth_token(&self, _token: &str) -> Result { + Ok(ID::new("user_id")) + } + + async fn is_admin_initialized(&self) -> Result { + Ok(true) + } + + async fn update_user_role(&self, _id: &ID, _is_admin: bool) -> Result<()> { + Ok(()) + } + + async fn get_user_by_email(&self, email: &str) -> Result { + self.users + .iter() + .find(|user| user.email == email) + .cloned() + .ok_or_else(|| anyhow::anyhow!("User not found")) + .map_err(Into::into) + } + + async fn get_user(&self, id: &ID) -> Result { + self.users + .iter() + .find(|user| user.id == *id) + .cloned() + .ok_or_else(|| anyhow::anyhow!("User not found")) + .map_err(Into::into) + } + + async fn create_invitation(&self, email: String) -> Result { + let invitation = Invitation { + id: ID::new("1"), + email: email.clone(), + code: "invitation_code".to_string(), + created_at: Utc::now(), + }; + Ok(invitation) + } + + async fn request_invitation_email(&self, input: RequestInvitationInput) -> Result { + self.create_invitation(input.email).await + } + + async fn delete_invitation(&self, id: &ID) -> Result { + Ok(id.clone()) + } + + async fn reset_user_auth_token(&self, _id: &ID) -> Result<()> { + Ok(()) + } + + async fn logout_all_sessions(&self, _id: &ID) -> Result<()> { + Ok(()) + } + + async fn list_users( + &self, + _after: Option, + _before: Option, + _first: Option, + _last: Option, + ) -> Result> { + Ok(self.users.clone()) + } + + async fn list_invitations( + &self, + _after: Option, + _before: Option, + _first: Option, + _last: Option, + ) -> Result> { + Ok(vec![]) + } + + async fn oauth( + &self, + _code: String, + _provider: OAuthProvider, + ) -> std::result::Result { + Ok(OAuthResponse { + access_token: "access_token".to_string(), + refresh_token: "refresh_token".to_string(), + }) + } + + async fn read_oauth_credential( + &self, + _provider: OAuthProvider, + ) -> Result> { + Ok(None) + } + + async fn oauth_callback_url(&self, _provider: OAuthProvider) -> Result { + Ok("https://example.com/oauth/callback/".to_string()) + } + + async fn update_oauth_credential(&self, _input: UpdateOAuthCredentialInput) -> Result<()> { + Ok(()) + } + + async fn delete_oauth_credential(&self, _provider: OAuthProvider) -> Result<()> { + Ok(()) + } + + async fn update_user_active(&self, _id: &ID, _active: bool) -> Result<()> { + Ok(()) + } +} diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index 387c637ad37f..66a90352f5e4 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -22,14 +22,18 @@ use std::sync::Arc; use answer::AnswerService; use anyhow::Context; use async_trait::async_trait; +pub use auth::create as new_auth_service; use axum::{ body::Body, http::{HeaderName, HeaderValue, Request, StatusCode}, middleware::Next, response::IntoResponse, }; +pub use email::new_email_service; use hyper::{HeaderMap, Uri}; use juniper::ID; +pub use license::new_license_service; +pub use setting::create as new_setting_service; use tabby_common::{ api::{code::CodeSearch, event::EventLogger}, constants::USER_HEADER_FIELD_NAME, @@ -58,10 +62,9 @@ use tabby_schema::{ AsID, AsRowid, CoreError, Result, ServiceLocator, }; -use self::{ - analytic::new_analytic_service, email::new_email_service, license::new_license_service, -}; +use self::analytic::new_analytic_service; use crate::rate_limit::UserRateLimiter; + struct ServerContext { db_conn: DbConn, mail: Arc, @@ -91,6 +94,7 @@ struct ServerContext { impl ServerContext { pub async fn new( logger: Arc, + auth: Arc, chat: Option>, completion: Option>, code: Arc, @@ -100,22 +104,18 @@ impl ServerContext { answer: Option>, context: Arc, web_documents: Arc, + mail: Arc, + license: Arc, + setting: Arc, db_conn: DbConn, embedding: Arc, ) -> Self { - let mail = Arc::new( - new_email_service(db_conn.clone()) - .await - .expect("failed to initialize mail service"), - ); - let license = Arc::new( - new_license_service(db_conn.clone()) - .await - .expect("failed to initialize license service"), - ); let user_event = Arc::new(user_event::create(db_conn.clone())); - let setting = Arc::new(setting::create(db_conn.clone())); - let thread = Arc::new(thread::create(db_conn.clone(), answer.clone())); + let thread = Arc::new(thread::create( + db_conn.clone(), + answer.clone(), + Some(auth.clone()), + )); let user_group = Arc::new(user_group::create(db_conn.clone())); let access_policy = Arc::new(access_policy::create(db_conn.clone(), context.clone())); @@ -132,16 +132,11 @@ impl ServerContext { .await; Self { - mail: mail.clone(), + mail, embedding, chat, completion, - auth: Arc::new(auth::create( - db_conn.clone(), - mail, - license.clone(), - setting.clone(), - )), + auth, web_documents, thread, context, @@ -354,6 +349,7 @@ impl ServiceLocator for ArcServerContext { pub async fn create_service_locator( logger: Arc, + auth: Arc, chat: Option>, completion: Option>, code: Arc, @@ -363,12 +359,16 @@ pub async fn create_service_locator( answer: Option>, context: Arc, web_documents: Arc, + mail: Arc, + license: Arc, + setting: Arc, db: DbConn, embedding: Arc, ) -> Arc { Arc::new(ArcServerContext::new( ServerContext::new( logger, + auth, chat, completion, code, @@ -378,6 +378,9 @@ pub async fn create_service_locator( answer, context, web_documents, + mail, + license, + setting, db, embedding, ) diff --git a/ee/tabby-webserver/src/service/thread.rs b/ee/tabby-webserver/src/service/thread.rs index e7ebcbb6ccdf..607709ca9da4 100644 --- a/ee/tabby-webserver/src/service/thread.rs +++ b/ee/tabby-webserver/src/service/thread.rs @@ -3,13 +3,15 @@ use std::sync::Arc; use async_trait::async_trait; use futures::StreamExt; use juniper::ID; -use tabby_db::{DbConn, ThreadMessageDAO}; +use tabby_db::{DbConn, ThreadMessageAttachmentDoc, ThreadMessageDAO}; use tabby_schema::{ - bail, + auth::AuthenticationService, + bail, from_thread_message_attachment_document, policy::AccessPolicy, thread::{ - self, CreateMessageInput, CreateThreadInput, MessageAttachmentInput, ThreadRunItem, - ThreadRunOptionsInput, ThreadRunStream, ThreadService, UpdateMessageInput, + self, CreateMessageInput, CreateThreadInput, MessageAttachment, MessageAttachmentDoc, + MessageAttachmentInput, ThreadRunItem, ThreadRunOptionsInput, ThreadRunStream, + ThreadService, UpdateMessageInput, }, AsID, AsRowid, DbEnum, Result, }; @@ -18,6 +20,7 @@ use super::{answer::AnswerService, graphql_pagination_to_filter}; struct ThreadServiceImpl { db: DbConn, + auth: Option>, answer: Option>, } @@ -27,7 +30,77 @@ impl ThreadServiceImpl { .db .list_thread_messages(thread_id.as_rowid()?, None, None, false) .await?; - to_vec_messages(messages) + self.to_vec_messages(messages).await + } + + async fn to_vec_messages( + &self, + messages: Vec, + ) -> Result> { + let mut output = vec![]; + output.reserve(messages.len()); + + for message in messages { + let code = message.code_attachments; + let client_code = message.client_code_attachments; + let doc = message.doc_attachments; + + let attachment = MessageAttachment { + code: code + .map(|x| x.0.into_iter().map(|i| i.into()).collect()) + .unwrap_or_default(), + client_code: client_code + .map(|x| x.0.into_iter().map(|i| i.into()).collect()) + .unwrap_or_default(), + doc: if let Some(docs) = doc { + self.to_message_attachment_docs(docs.0).await + } else { + vec![] + }, + }; + + output.push(thread::Message { + id: message.id.as_id(), + thread_id: message.thread_id.as_id(), + role: thread::Role::from_enum_str(&message.role)?, + content: message.content, + attachment, + created_at: message.created_at, + updated_at: message.updated_at, + }); + } + + Ok(output) + } + + async fn to_message_attachment_docs( + &self, + thread_docs: Vec, + ) -> Vec { + let mut output = vec![]; + output.reserve(thread_docs.len()); + for thread_doc in thread_docs { + let id = match &thread_doc { + ThreadMessageAttachmentDoc::Issue(issue) => issue.author_user_id.as_deref(), + ThreadMessageAttachmentDoc::Pull(pull) => pull.author_user_id.as_deref(), + _ => None, + }; + let user = if let Some(auth) = self.auth.as_ref() { + if let Some(id) = id { + auth.get_user(&juniper::ID::from(id.to_owned())) + .await + .ok() + .map(|x| x.into()) + } else { + None + } + } else { + None + }; + + output.push(from_thread_message_attachment_document(thread_doc, user)); + } + output } } @@ -243,7 +316,7 @@ impl ThreadService for ThreadServiceImpl { .list_thread_messages(thread_id, limit, skip_id, backwards) .await?; - to_vec_messages(messages) + self.to_vec_messages(messages).await } async fn delete_thread_message_pair( @@ -268,20 +341,12 @@ impl ThreadService for ThreadServiceImpl { } } -fn to_vec_messages(messages: Vec) -> Result> { - let mut output = vec![]; - output.reserve(messages.len()); - - for x in messages { - let message: thread::Message = x.try_into()?; - output.push(message); - } - - Ok(output) -} - -pub fn create(db: DbConn, answer: Option>) -> impl ThreadService { - ThreadServiceImpl { db, answer } +pub fn create( + db: DbConn, + answer: Option>, + auth: Option>, +) -> impl ThreadService { + ThreadServiceImpl { db, answer, auth } } #[cfg(test)] @@ -302,16 +367,19 @@ mod tests { use thread::MessageAttachmentCodeInput; use super::*; - use crate::answer::testutils::{ - make_repository_service, FakeChatCompletionStream, FakeCodeSearch, FakeContextService, - FakeDocSearch, + use crate::{ + answer::testutils::{ + make_repository_service, FakeChatCompletionStream, FakeCodeSearch, FakeContextService, + FakeDocSearch, + }, + service::auth, }; #[tokio::test] async fn test_create_thread() { let db = DbConn::new_in_memory().await.unwrap(); let user_id = create_user(&db).await.as_id(); - let service = create(db, None); + let service = create(db, None, None); let input = CreateThreadInput { user_message: CreateMessageInput { @@ -327,7 +395,7 @@ mod tests { async fn test_append_messages() { let db = DbConn::new_in_memory().await.unwrap(); let user_id = create_user(&db).await.as_id(); - let service = create(db, None); + let service = create(db, None, None); let thread_id = service .create( @@ -373,7 +441,7 @@ mod tests { async fn test_delete_thread_message_pair() { let db = DbConn::new_in_memory().await.unwrap(); let user_id = create_user(&db).await.as_id(); - let service = create(db.clone(), None); + let service = create(db.clone(), None, None); let thread_id = service .create( @@ -462,7 +530,7 @@ mod tests { async fn test_get_thread() { let db = DbConn::new_in_memory().await.unwrap(); let user_id = create_user(&db).await.as_id(); - let service = create(db, None); + let service = create(db, None, None); let input = CreateThreadInput { user_message: CreateMessageInput { @@ -486,7 +554,7 @@ mod tests { async fn test_delete_thread() { let db = DbConn::new_in_memory().await.unwrap(); let user_id = create_user(&db).await.as_id(); - let service = create(db.clone(), None); + let service = create(db.clone(), None, None); let input = CreateThreadInput { user_message: CreateMessageInput { @@ -513,7 +581,7 @@ mod tests { async fn test_set_persisted() { let db = DbConn::new_in_memory().await.unwrap(); let user_id = create_user(&db).await.as_id(); - let service = create(db.clone(), None); + let service = create(db.clone(), None, None); let input = CreateThreadInput { user_message: CreateMessageInput { @@ -548,6 +616,7 @@ mod tests { async fn test_create_run() { let db = DbConn::new_in_memory().await.unwrap(); let user_id = create_user(&db).await.as_id(); + let auth = Arc::new(auth::testutils::FakeAuthService::new(vec![])); let chat: Arc = Arc::new(FakeChatCompletionStream { return_error: false, }); @@ -559,6 +628,7 @@ mod tests { let repo = make_repository_service(db.clone()).await.unwrap(); let answer_service = Arc::new(crate::answer::create( &config, + auth.clone(), chat.clone(), code.clone(), doc.clone(), @@ -566,7 +636,7 @@ mod tests { serper, repo, )); - let service = create(db.clone(), Some(answer_service)); + let service = create(db.clone(), Some(answer_service), None); let input = CreateThreadInput { user_message: CreateMessageInput { @@ -591,7 +661,7 @@ mod tests { async fn test_list_threads() { let db = DbConn::new_in_memory().await.unwrap(); let user_id = create_user(&db).await.as_id(); - let service = create(db, None); + let service = create(db, None, None); for i in 0..3 { let input = CreateThreadInput { diff --git a/ee/tabby-webserver/src/webserver.rs b/ee/tabby-webserver/src/webserver.rs index a7e360955908..3e65b57e42bc 100644 --- a/ee/tabby-webserver/src/webserver.rs +++ b/ee/tabby-webserver/src/webserver.rs @@ -18,7 +18,8 @@ use crate::{ path::db_file, routes, service::{ - create_service_locator, event_logger::create_event_logger, integration, job, repository, + create_service_locator, event_logger::create_event_logger, integration, job, + new_auth_service, new_email_service, new_license_service, new_setting_service, repository, web_documents, }, }; @@ -88,9 +89,28 @@ impl Webserver { serper.is_some(), )); + let mail = Arc::new( + new_email_service(db.clone()) + .await + .expect("failed to initialize mail service"), + ); + let license = Arc::new( + new_license_service(db.clone()) + .await + .expect("failed to initialize license service"), + ); + let setting = Arc::new(new_setting_service(db.clone())); + let auth = Arc::new(new_auth_service( + db.clone(), + mail.clone(), + license.clone(), + setting.clone(), + )); + let answer = chat.as_ref().map(|chat| { Arc::new(crate::service::answer::create( &config.answer, + auth.clone(), chat.clone(), code.clone(), docsearch.clone(), @@ -102,6 +122,7 @@ impl Webserver { let ctx = create_service_locator( self.logger(), + auth, chat.clone(), completion.clone(), code.clone(), @@ -111,6 +132,9 @@ impl Webserver { answer.clone(), context.clone(), web_documents.clone(), + mail, + license, + setting, self.db.clone(), self.embedding.clone(), )