diff --git a/src/meta/src/controller/catalog.rs b/src/meta/src/controller/catalog.rs index 2b992bab70f55..05b80ff69b1e0 100644 --- a/src/meta/src/controller/catalog.rs +++ b/src/meta/src/controller/catalog.rs @@ -16,7 +16,7 @@ use std::iter; use itertools::Itertools; use risingwave_common::catalog::{DEFAULT_SCHEMA_NAME, SYSTEM_SCHEMAS}; -use risingwave_pb::catalog::{PbDatabase, PbSchema}; +use risingwave_pb::catalog::{PbDatabase, PbFunction, PbSchema}; use risingwave_pb::meta::subscribe_response::{ Info as NotificationInfo, Operation as NotificationOperation, }; @@ -29,7 +29,7 @@ use tokio::sync::RwLock; use crate::controller::utils::construct_obj_dependency_query; use crate::controller::ObjectModel; -use crate::manager::{DatabaseId, MetaSrvEnv, NotificationVersion, UserId}; +use crate::manager::{DatabaseId, FunctionId, MetaSrvEnv, NotificationVersion, UserId}; use crate::model_v2::object::ObjectType; use crate::model_v2::prelude::*; use crate::model_v2::{ @@ -297,6 +297,55 @@ impl CatalogController { .await; Ok(version) } + + pub async fn create_function( + &self, + mut pb_function: PbFunction, + ) -> MetaResult { + let inner = self.inner.write().await; + let txn = inner.db.begin().await?; + let owner_id = pb_function.owner; + + let function_obj = Self::create_object(&txn, ObjectType::Function, owner_id).await?; + pb_function.id = function_obj.oid as _; + let function: function::ActiveModel = pb_function.clone().into(); + function.insert(&txn).await?; + txn.commit().await?; + + let version = self + .notify_frontend( + NotificationOperation::Add, + NotificationInfo::Function(pb_function), + ) + .await; + + Ok(version) + } + + pub async fn drop_function(&self, function_id: FunctionId) -> MetaResult { + let inner = self.inner.write().await; + // todo: wrap the error and return the list if used by others. + let res = Object::delete(object::ActiveModel { + oid: ActiveValue::Set(function_id as _), + ..Default::default() + }) + .exec(&inner.db) + .await?; + if res.rows_affected == 0 { + return Err(MetaError::catalog_id_not_found("function", function_id)); + } + + let version = self + .notify_frontend( + NotificationOperation::Delete, + NotificationInfo::Function(PbFunction { + id: function_id, + ..Default::default() + }), + ) + .await; + Ok(version) + } } #[cfg(test)] @@ -304,23 +353,70 @@ mod tests { use risingwave_common::catalog::DEFAULT_SUPER_USER_ID; use super::*; + use crate::manager::SchemaId; + + const TEST_DATABASE_ID: DatabaseId = 1; + const TEST_SCHEMA_ID: SchemaId = 2; + const TEST_OWNER_ID: UserId = 1; #[tokio::test] #[cfg(not(madsim))] - async fn test_create_database() { - let mgr = CatalogController::new(MetaSrvEnv::for_test().await).unwrap(); + async fn test_create_database() -> MetaResult<()> { + let mgr = CatalogController::new(MetaSrvEnv::for_test().await)?; let db = PbDatabase { name: "test".to_string(), owner: DEFAULT_SUPER_USER_ID, ..Default::default() }; - mgr.create_database(db).await.unwrap(); + mgr.create_database(db).await?; + let db = Database::find() .filter(database::Column::Name.eq("test")) .one(&mgr.inner.read().await.db) - .await - .unwrap() + .await? .unwrap(); - mgr.drop_database(db.database_id as _).await.unwrap(); + mgr.drop_database(db.database_id as _).await?; + Ok(()) + } + + #[tokio::test] + #[cfg(not(madsim))] + async fn test_create_function() -> MetaResult<()> { + let mgr = CatalogController::new(MetaSrvEnv::for_test().await)?; + let return_type = risingwave_pb::data::DataType { + type_name: risingwave_pb::data::data_type::TypeName::Int32 as _, + ..Default::default() + }; + mgr.create_function(PbFunction { + schema_id: TEST_SCHEMA_ID, + database_id: TEST_DATABASE_ID, + name: "test_function".to_string(), + owner: TEST_OWNER_ID, + arg_types: vec![], + return_type: Some(return_type.clone()), + language: "python".to_string(), + kind: Some(risingwave_pb::catalog::function::Kind::Scalar( + Default::default(), + )), + ..Default::default() + }) + .await?; + + let function = Function::find() + .filter( + function::Column::DatabaseId + .eq(TEST_DATABASE_ID) + .and(function::Column::SchemaId.eq(TEST_SCHEMA_ID)) + .add(function::Column::Name.eq("test_function")), + ) + .one(&mgr.inner.read().await.db) + .await? + .unwrap(); + assert_eq!(function.return_type.0, return_type); + assert_eq!(function.language, "python"); + + mgr.drop_function(function.function_id as _).await?; + + Ok(()) } } diff --git a/src/meta/src/model_v2/function.rs b/src/meta/src/model_v2/function.rs index e60f9ea5a2bde..0b1b220abc448 100644 --- a/src/meta/src/model_v2/function.rs +++ b/src/meta/src/model_v2/function.rs @@ -12,7 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. +use risingwave_pb::catalog::function::Kind; +use risingwave_pb::catalog::PbFunction; use sea_orm::entity::prelude::*; +use sea_orm::ActiveValue; + +use crate::model_v2::{DataType, DataTypeArray}; + +#[derive(Clone, Debug, PartialEq, Eq, EnumIter, DeriveActiveEnum)] +#[sea_orm(rs_type = "String", db_type = "String(None)")] +pub enum FunctionKind { + #[sea_orm(string_value = "Scalar")] + Scalar, + #[sea_orm(string_value = "Table")] + Table, + #[sea_orm(string_value = "Aggregate")] + Aggregate, +} #[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] #[sea_orm(table_name = "function")] @@ -22,12 +38,12 @@ pub struct Model { pub name: String, pub schema_id: i32, pub database_id: i32, - pub arg_types: Option, - pub return_type: Option, - pub language: Option, - pub link: Option, - pub identifier: Option, - pub kind: Option, + pub arg_types: DataTypeArray, + pub return_type: DataType, + pub language: String, + pub link: String, + pub identifier: String, + pub kind: FunctionKind, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] @@ -77,3 +93,30 @@ impl Related for Entity { } impl ActiveModelBehavior for ActiveModel {} + +impl From for FunctionKind { + fn from(kind: Kind) -> Self { + match kind { + Kind::Scalar(_) => Self::Scalar, + Kind::Table(_) => Self::Table, + Kind::Aggregate(_) => Self::Aggregate, + } + } +} + +impl From for ActiveModel { + fn from(function: PbFunction) -> Self { + Self { + function_id: ActiveValue::Set(function.id as _), + name: ActiveValue::Set(function.name), + schema_id: ActiveValue::Set(function.schema_id as _), + database_id: ActiveValue::Set(function.database_id as _), + arg_types: ActiveValue::Set(DataTypeArray(function.arg_types)), + return_type: ActiveValue::Set(DataType(function.return_type.unwrap())), + language: ActiveValue::Set(function.language), + link: ActiveValue::Set(function.link), + identifier: ActiveValue::Set(function.identifier), + kind: ActiveValue::Set(function.kind.unwrap().into()), + } + } +} diff --git a/src/meta/src/model_v2/migration/src/m20230908_072257_init.rs b/src/meta/src/model_v2/migration/src/m20230908_072257_init.rs index e5a07e8c87574..010e3ffff77c4 100644 --- a/src/meta/src/model_v2/migration/src/m20230908_072257_init.rs +++ b/src/meta/src/model_v2/migration/src/m20230908_072257_init.rs @@ -164,6 +164,7 @@ impl MigrationTrait for Migration { .name("FK_object_owner_id") .from(Object::Table, Object::OwnerId) .to(User::Table, User::UserId) + .on_delete(ForeignKeyAction::Cascade) .to_owned(), ) .to_owned(), @@ -671,12 +672,12 @@ impl MigrationTrait for Migration { .col(ColumnDef::new(Function::Name).string().not_null()) .col(ColumnDef::new(Function::SchemaId).integer().not_null()) .col(ColumnDef::new(Function::DatabaseId).integer().not_null()) - .col(ColumnDef::new(Function::ArgTypes).json()) - .col(ColumnDef::new(Function::ReturnType).string()) - .col(ColumnDef::new(Function::Language).string()) - .col(ColumnDef::new(Function::Link).string()) - .col(ColumnDef::new(Function::Identifier).string()) - .col(ColumnDef::new(Function::Kind).json()) + .col(ColumnDef::new(Function::ArgTypes).json().not_null()) + .col(ColumnDef::new(Function::ReturnType).json().not_null()) + .col(ColumnDef::new(Function::Language).string().not_null()) + .col(ColumnDef::new(Function::Link).string().not_null()) + .col(ColumnDef::new(Function::Identifier).string().not_null()) + .col(ColumnDef::new(Function::Kind).string().not_null()) .foreign_key( &mut ForeignKey::create() .name("FK_function_database_id") diff --git a/src/meta/src/model_v2/mod.rs b/src/meta/src/model_v2/mod.rs index 8f496b56d1291..56bf61dc3ae18 100644 --- a/src/meta/src/model_v2/mod.rs +++ b/src/meta/src/model_v2/mod.rs @@ -41,3 +41,9 @@ pub mod worker_property; #[derive(Clone, Debug, PartialEq, FromJsonQueryResult, Eq, Serialize, Deserialize, Default)] pub struct I32Array(pub Vec); + +#[derive(Clone, Debug, PartialEq, FromJsonQueryResult, Eq, Serialize, Deserialize, Default)] +pub struct DataType(pub risingwave_pb::data::DataType); + +#[derive(Clone, Debug, PartialEq, FromJsonQueryResult, Eq, Serialize, Deserialize, Default)] +pub struct DataTypeArray(pub Vec); diff --git a/src/meta/src/model_v2/worker.rs b/src/meta/src/model_v2/worker.rs index 4002b55323a36..2a747f44a6c0f 100644 --- a/src/meta/src/model_v2/worker.rs +++ b/src/meta/src/model_v2/worker.rs @@ -13,6 +13,7 @@ // limitations under the License. use sea_orm::entity::prelude::*; + #[derive(Clone, Debug, PartialEq, Eq, EnumIter, DeriveActiveEnum)] #[sea_orm(rs_type = "String", db_type = "String(None)")] pub enum WorkerType {