diff --git a/.golangci.yaml b/.golangci.yaml index 3d9984e7..8bb78de2 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -65,6 +65,8 @@ linters-settings: disabled: true - name: early-return disabled: true + - name: use-any + disabled: true - name: exported arguments: - "disableStutteringCheck" diff --git a/go.mod b/go.mod index 85983e4a..9573d50f 100644 --- a/go.mod +++ b/go.mod @@ -74,6 +74,7 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 github.com/h2non/filetype v1.1.3 github.com/improbable-eng/grpc-web v0.15.0 + github.com/lib/pq v1.10.9 github.com/mssola/useragent v1.0.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index 2105dccd..60cb6237 100644 --- a/go.sum +++ b/go.sum @@ -292,6 +292,8 @@ github.com/labstack/echo/v4 v4.11.2/go.mod h1:UcGuQ8V6ZNRmSweBIJkPvGfwCMIlFmiqrP github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8= github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= diff --git a/store/db/db.go b/store/db/db.go index 4d6eac78..d65fc2f4 100644 --- a/store/db/db.go +++ b/store/db/db.go @@ -5,6 +5,7 @@ import ( "github.com/yourselfhosted/slash/server/profile" "github.com/yourselfhosted/slash/store" + "github.com/yourselfhosted/slash/store/db/postgres" "github.com/yourselfhosted/slash/store/db/sqlite" ) @@ -16,6 +17,8 @@ func NewDBDriver(profile *profile.Profile) (store.Driver, error) { switch profile.Driver { case "sqlite": driver, err = sqlite.NewDB(profile) + case "postgres": + driver, err = postgres.NewDB(profile) default: return nil, errors.New("unknown db driver") } diff --git a/store/db/postgres/activity.go b/store/db/postgres/activity.go new file mode 100644 index 00000000..a16c2b0b --- /dev/null +++ b/store/db/postgres/activity.go @@ -0,0 +1,88 @@ +package postgres + +import ( + "context" + "fmt" + "strings" + + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) { + stmt := ` + INSERT INTO activity ( + creator_id, + type, + level, + payload + ) + VALUES ($1, $2, $3, $4) + RETURNING id, created_ts + ` + if err := d.db.QueryRowContext(ctx, stmt, + create.CreatorID, + create.Type.String(), + create.Level.String(), + create.Payload, + ).Scan( + &create.ID, + &create.CreatedTs, + ); err != nil { + return nil, err + } + + activity := create + return activity, nil +} + +func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) { + where, args := []string{"1 = 1"}, []any{} + if find.Type != "" { + where, args = append(where, "type = $"+fmt.Sprint(len(args)+1)), append(args, find.Type.String()) + } + if find.Level != "" { + where, args = append(where, "level = $"+fmt.Sprint(len(args)+1)), append(args, find.Level.String()) + } + if find.Where != nil { + where = append(where, find.Where...) 
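+		// Entries in find.Where are caller-supplied SQL fragments; they are
+		// appended verbatim and combined with AND in the query built below.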
+ } + + query := ` + SELECT + id, + creator_id, + created_ts, + type, + level, + payload + FROM activity + WHERE ` + strings.Join(where, " AND ") + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*store.Activity{} + for rows.Next() { + activity := &store.Activity{} + if err := rows.Scan( + &activity.ID, + &activity.CreatorID, + &activity.CreatedTs, + &activity.Type, + &activity.Level, + &activity.Payload, + ); err != nil { + return nil, err + } + + list = append(list, activity) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} diff --git a/store/db/postgres/collection.go b/store/db/postgres/collection.go new file mode 100644 index 00000000..665e17ec --- /dev/null +++ b/store/db/postgres/collection.go @@ -0,0 +1,194 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/pkg/errors" + + "github.com/yourselfhosted/slash/internal/util" + storepb "github.com/yourselfhosted/slash/proto/gen/store" + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) { + set := []string{"creator_id", "name", "title", "description", "shortcut_ids", "visibility"} + args := []any{create.CreatorId, create.Name, create.Title, create.Description, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(create.ShortcutIds)), ","), "[]"), create.Visibility.String()} + placeholder := []string{"$1", "$2", "$3", "$4", "$5", "$6"} + + stmt := ` + INSERT INTO collection ( + ` + strings.Join(set, ", ") + ` + ) + VALUES (` + strings.Join(placeholder, ",") + `) + RETURNING id, created_ts, updated_ts + ` + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &create.Id, + &create.CreatedTs, + &create.UpdatedTs, + ); err != nil { + return nil, err + } + collection := create + return collection, nil +} + +func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollection) (*storepb.Collection, error) { + set, args := []string{}, []any{} + if update.Name != nil { + set, args = append(set, "name = $1"), append(args, *update.Name) + } + if update.Title != nil { + set, args = append(set, "title = $2"), append(args, *update.Title) + } + if update.Description != nil { + set, args = append(set, "description = $3"), append(args, *update.Description) + } + if update.ShortcutIDs != nil { + set, args = append(set, "shortcut_ids = $4"), append(args, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(update.ShortcutIDs)), ","), "[]")) + } + if update.Visibility != nil { + set, args = append(set, "visibility = $5"), append(args, update.Visibility.String()) + } + if len(set) == 0 { + return nil, errors.New("no update specified") + } + args = append(args, update.ID) + + stmt := ` + UPDATE collection + SET + ` + strings.Join(set, ", ") + ` + WHERE + id = $6 + RETURNING id, creator_id, created_ts, updated_ts, name, title, description, shortcut_ids, visibility + ` + collection := &storepb.Collection{} + var shortcutIDs, visibility string + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &collection.Id, + &collection.CreatorId, + &collection.CreatedTs, + &collection.UpdatedTs, + &collection.Name, + &collection.Title, + &collection.Description, + &shortcutIDs, + &visibility, + ); err != nil { + return nil, err + } + + collection.ShortcutIds = []int32{} + if shortcutIDs != "" { + for _, idStr := range strings.Split(shortcutIDs, ",") { + shortcutID, err := 
util.ConvertStringToInt32(idStr) + if err != nil { + return nil, errors.Wrap(err, "failed to convert shortcut id") + } + collection.ShortcutIds = append(collection.ShortcutIds, shortcutID) + } + } + collection.Visibility = convertVisibilityStringToStorepb(visibility) + return collection, nil +} + +func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([]*storepb.Collection, error) { + where, args := []string{"1 = 1"}, []any{} + if v := find.ID; v != nil { + where, args = append(where, "id = $1"), append(args, *v) + } + if v := find.CreatorID; v != nil { + where, args = append(where, "creator_id = $2"), append(args, *v) + } + if v := find.Name; v != nil { + where, args = append(where, "name = $3"), append(args, *v) + } + if v := find.VisibilityList; len(v) != 0 { + list := []string{} + for i, visibility := range v { + list = append(list, fmt.Sprintf("$%d", len(args)+i+1)) + args = append(args, visibility) + } + where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ","))) + } + + rows, err := d.db.QueryContext(ctx, ` + SELECT + id, + creator_id, + created_ts, + updated_ts, + name, + title, + description, + shortcut_ids, + visibility + FROM collection + WHERE `+strings.Join(where, " AND ")+` + ORDER BY created_ts DESC`, + args..., + ) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*storepb.Collection, 0) + for rows.Next() { + collection := &storepb.Collection{} + var shortcutIDs, visibility string + if err := rows.Scan( + &collection.Id, + &collection.CreatorId, + &collection.CreatedTs, + &collection.UpdatedTs, + &collection.Name, + &collection.Title, + &collection.Description, + &shortcutIDs, + &visibility, + ); err != nil { + return nil, err + } + + collection.ShortcutIds = []int32{} + if shortcutIDs != "" { + for _, idStr := range strings.Split(shortcutIDs, ",") { + shortcutID, err := util.ConvertStringToInt32(idStr) + if err != nil { + return nil, errors.Wrap(err, "failed to convert shortcut id") + } + collection.ShortcutIds = append(collection.ShortcutIds, shortcutID) + } + } + collection.Visibility = storepb.Visibility(storepb.Visibility_value[visibility]) + list = append(list, collection) + } + + if err := rows.Err(); err != nil { + return nil, err + } + return list, nil +} + +func (d *DB) DeleteCollection(ctx context.Context, delete *store.DeleteCollection) error { + if _, err := d.db.ExecContext(ctx, `DELETE FROM collection WHERE id = $1`, delete.ID); err != nil { + return err + } + + return nil +} + +func vacuumCollection(ctx context.Context, tx *sql.Tx) error { + stmt := `DELETE FROM collection WHERE creator_id NOT IN (SELECT id FROM user)` + _, err := tx.ExecContext(ctx, stmt) + if err != nil { + return err + } + + return nil +} diff --git a/store/db/postgres/memo.go b/store/db/postgres/memo.go new file mode 100644 index 00000000..aa0cf377 --- /dev/null +++ b/store/db/postgres/memo.go @@ -0,0 +1,208 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/pkg/errors" + + storepb "github.com/yourselfhosted/slash/proto/gen/store" + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Memo, error) { + set := []string{"creator_id", "name", "title", "content", "visibility", "tag"} + args := []any{create.CreatorId, create.Name, create.Title, create.Content, create.Visibility.String(), strings.Join(create.Tags, " ")} + + stmt := ` + INSERT INTO memo ( + ` + strings.Join(set, ", ") + ` + ) + VALUES (` + 
placeholders(len(args)) + `) + RETURNING id, created_ts, updated_ts, row_status + ` + + var rowStatus string + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &create.Id, + &create.CreatedTs, + &create.UpdatedTs, + &rowStatus, + ); err != nil { + return nil, err + } + create.RowStatus = store.ConvertRowStatusStringToStorepb(rowStatus) + memo := create + return memo, nil +} + +func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) (*storepb.Memo, error) { + set, args := []string{}, []any{} + if update.RowStatus != nil { + set = append(set, fmt.Sprintf("row_status = $%d", len(set)+1)) + args = append(args, update.RowStatus.String()) + } + if update.Name != nil { + set = append(set, fmt.Sprintf("name = $%d", len(set)+1)) + args = append(args, *update.Name) + } + if update.Title != nil { + set = append(set, fmt.Sprintf("title = $%d", len(set)+1)) + args = append(args, *update.Title) + } + if update.Content != nil { + set = append(set, fmt.Sprintf("content = $%d", len(set)+1)) + args = append(args, *update.Content) + } + if update.Visibility != nil { + set = append(set, fmt.Sprintf("visibility = $%d", len(set)+1)) + args = append(args, update.Visibility.String()) + } + if update.Tag != nil { + set = append(set, fmt.Sprintf("tag = $%d", len(set)+1)) + args = append(args, *update.Tag) + } + if len(set) == 0 { + return nil, errors.New("no update specified") + } + args = append(args, update.ID) + + stmt := ` + UPDATE memo + SET + ` + strings.Join(set, ", ") + ` + WHERE + id = $` + fmt.Sprint(len(set)+1) + ` + RETURNING id, creator_id, created_ts, updated_ts, row_status, name, title, content, visibility, tag + ` + + memo := &storepb.Memo{} + var rowStatus, visibility, tags string + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &memo.Id, + &memo.CreatorId, + &memo.CreatedTs, + &memo.UpdatedTs, + &rowStatus, + &memo.Name, + &memo.Title, + &memo.Content, + &visibility, + &tags, + ); err != nil { + return nil, err + } + memo.RowStatus = store.ConvertRowStatusStringToStorepb(rowStatus) + memo.Visibility = convertVisibilityStringToStorepb(visibility) + memo.Tags = filterTags(strings.Split(tags, " ")) + return memo, nil +} + +func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*storepb.Memo, error) { + where, args := []string{"1 = 1"}, []any{} + if v := find.ID; v != nil { + where, args = append(where, "id = $1"), append(args, *v) + } + if v := find.CreatorID; v != nil { + where, args = append(where, "creator_id = $2"), append(args, *v) + } + if v := find.RowStatus; v != nil { + where, args = append(where, "row_status = $3"), append(args, *v) + } + if v := find.Name; v != nil { + where, args = append(where, "name = $4"), append(args, *v) + } + if v := find.VisibilityList; len(v) != 0 { + list := []string{} + for i, visibility := range v { + list = append(list, fmt.Sprintf("$%d", len(args)+i+1)) + args = append(args, visibility) + } + where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ","))) + } + if v := find.Tag; v != nil { + where, args = append(where, "tag LIKE $"+fmt.Sprint(len(args)+1)), append(args, "%"+*v+"%") + } + + rows, err := d.db.QueryContext(ctx, ` + SELECT + id, + creator_id, + created_ts, + updated_ts, + row_status, + name, + title, + content, + visibility, + tag + FROM memo + WHERE `+strings.Join(where, " AND ")+` + ORDER BY created_ts DESC`, + args..., + ) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*storepb.Memo, 0) + for rows.Next() { + memo := &storepb.Memo{} + var 
rowStatus, visibility, tags string + if err := rows.Scan( + &memo.Id, + &memo.CreatorId, + &memo.CreatedTs, + &memo.UpdatedTs, + &rowStatus, + &memo.Name, + &memo.Title, + &memo.Content, + &visibility, + &tags, + ); err != nil { + return nil, err + } + memo.RowStatus = store.ConvertRowStatusStringToStorepb(rowStatus) + memo.Visibility = storepb.Visibility(storepb.Visibility_value[visibility]) + memo.Tags = filterTags(strings.Split(tags, " ")) + list = append(list, memo) + } + + if err := rows.Err(); err != nil { + return nil, err + } + return list, nil +} + +func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error { + if _, err := d.db.ExecContext(ctx, `DELETE FROM memo WHERE id = $1`, delete.ID); err != nil { + return err + } + return nil +} + +func vacuumMemo(ctx context.Context, tx *sql.Tx) error { + stmt := `DELETE FROM memo WHERE creator_id NOT IN (SELECT id FROM user)` + _, err := tx.ExecContext(ctx, stmt) + if err != nil { + return err + } + + return nil +} + +func placeholders(n int) string { + placeholder := "" + for i := 0; i < n; i++ { + if i == 0 { + placeholder = fmt.Sprintf("$%d", i+1) + } else { + placeholder = fmt.Sprintf("%s, $%d", placeholder, i+1) + } + } + return placeholder +} diff --git a/store/db/postgres/migration/dev/LATEST__SCHEMA.sql b/store/db/postgres/migration/dev/LATEST__SCHEMA.sql new file mode 100644 index 00000000..b8a257a9 --- /dev/null +++ b/store/db/postgres/migration/dev/LATEST__SCHEMA.sql @@ -0,0 +1,92 @@ +-- migration_history +CREATE TABLE migration_history ( + version TEXT NOT NULL PRIMARY KEY, + created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()) +); + +-- workspace_setting +CREATE TABLE workspace_setting ( + key TEXT NOT NULL UNIQUE, + value TEXT NOT NULL +); + +-- user +CREATE TABLE user ( + id SERIAL PRIMARY KEY, + created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', + email TEXT NOT NULL UNIQUE, + nickname TEXT NOT NULL, + password_hash TEXT NOT NULL, + role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER' +); + +CREATE INDEX idx_user_email ON user(email); + +-- user_setting +CREATE TABLE user_setting ( + user_id INTEGER REFERENCES user(id) NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + PRIMARY KEY (user_id, key) +); + +-- shortcut +CREATE TABLE shortcut ( + id SERIAL PRIMARY KEY, + creator_id INTEGER REFERENCES user(id) NOT NULL, + created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', + name TEXT NOT NULL UNIQUE, + link TEXT NOT NULL, + title TEXT NOT NULL DEFAULT '', + description TEXT NOT NULL DEFAULT '', + visibility TEXT NOT NULL CHECK (visibility IN ('PRIVATE', 'WORKSPACE', 'PUBLIC')) DEFAULT 'PRIVATE', + tag TEXT NOT NULL DEFAULT '', + og_metadata TEXT NOT NULL DEFAULT '{}' +); + +CREATE INDEX idx_shortcut_name ON shortcut(name); + +-- activity +CREATE TABLE activity ( + id SERIAL PRIMARY KEY, + creator_id INTEGER REFERENCES user(id) NOT NULL, + created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + type TEXT NOT NULL DEFAULT '', + level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO', + payload TEXT NOT NULL DEFAULT '{}' +); + +-- collection +CREATE TABLE collection ( + id SERIAL PRIMARY KEY, + creator_id INTEGER REFERENCES user(id) 
NOT NULL, + created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + name TEXT NOT NULL UNIQUE, + title TEXT NOT NULL DEFAULT '', + description TEXT NOT NULL DEFAULT '', + shortcut_ids INTEGER ARRAY NOT NULL, + visibility TEXT NOT NULL CHECK (visibility IN ('PRIVATE', 'WORKSPACE', 'PUBLIC')) DEFAULT 'PRIVATE' +); + +CREATE INDEX idx_collection_name ON collection(name); + +-- memo +CREATE TABLE memo ( + id SERIAL PRIMARY KEY, + creator_id INTEGER REFERENCES user(id) NOT NULL, + created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', + name TEXT NOT NULL UNIQUE, + title TEXT NOT NULL DEFAULT '', + content TEXT NOT NULL DEFAULT '', + visibility TEXT NOT NULL CHECK (visibility IN ('PRIVATE', 'WORKSPACE', 'PUBLIC')) DEFAULT 'PRIVATE', + tag TEXT NOT NULL DEFAULT '' +); + +CREATE INDEX idx_memo_name ON memo(name); diff --git a/store/db/postgres/migration/prod/LATEST__SCHEMA.sql b/store/db/postgres/migration/prod/LATEST__SCHEMA.sql new file mode 100644 index 00000000..b8a257a9 --- /dev/null +++ b/store/db/postgres/migration/prod/LATEST__SCHEMA.sql @@ -0,0 +1,92 @@ +-- migration_history +CREATE TABLE migration_history ( + version TEXT NOT NULL PRIMARY KEY, + created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()) +); + +-- workspace_setting +CREATE TABLE workspace_setting ( + key TEXT NOT NULL UNIQUE, + value TEXT NOT NULL +); + +-- user +CREATE TABLE user ( + id SERIAL PRIMARY KEY, + created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', + email TEXT NOT NULL UNIQUE, + nickname TEXT NOT NULL, + password_hash TEXT NOT NULL, + role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER' +); + +CREATE INDEX idx_user_email ON user(email); + +-- user_setting +CREATE TABLE user_setting ( + user_id INTEGER REFERENCES user(id) NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + PRIMARY KEY (user_id, key) +); + +-- shortcut +CREATE TABLE shortcut ( + id SERIAL PRIMARY KEY, + creator_id INTEGER REFERENCES user(id) NOT NULL, + created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', + name TEXT NOT NULL UNIQUE, + link TEXT NOT NULL, + title TEXT NOT NULL DEFAULT '', + description TEXT NOT NULL DEFAULT '', + visibility TEXT NOT NULL CHECK (visibility IN ('PRIVATE', 'WORKSPACE', 'PUBLIC')) DEFAULT 'PRIVATE', + tag TEXT NOT NULL DEFAULT '', + og_metadata TEXT NOT NULL DEFAULT '{}' +); + +CREATE INDEX idx_shortcut_name ON shortcut(name); + +-- activity +CREATE TABLE activity ( + id SERIAL PRIMARY KEY, + creator_id INTEGER REFERENCES user(id) NOT NULL, + created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + type TEXT NOT NULL DEFAULT '', + level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO', + payload TEXT NOT NULL DEFAULT '{}' +); + +-- collection +CREATE TABLE collection ( + id SERIAL PRIMARY KEY, + creator_id INTEGER REFERENCES user(id) NOT NULL, + created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + name TEXT NOT NULL UNIQUE, + title TEXT NOT 
NULL DEFAULT '', + description TEXT NOT NULL DEFAULT '', + shortcut_ids INTEGER ARRAY NOT NULL, + visibility TEXT NOT NULL CHECK (visibility IN ('PRIVATE', 'WORKSPACE', 'PUBLIC')) DEFAULT 'PRIVATE' +); + +CREATE INDEX idx_collection_name ON collection(name); + +-- memo +CREATE TABLE memo ( + id SERIAL PRIMARY KEY, + creator_id INTEGER REFERENCES user(id) NOT NULL, + created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), + row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', + name TEXT NOT NULL UNIQUE, + title TEXT NOT NULL DEFAULT '', + content TEXT NOT NULL DEFAULT '', + visibility TEXT NOT NULL CHECK (visibility IN ('PRIVATE', 'WORKSPACE', 'PUBLIC')) DEFAULT 'PRIVATE', + tag TEXT NOT NULL DEFAULT '' +); + +CREATE INDEX idx_memo_name ON memo(name); diff --git a/store/db/postgres/migration_history.go b/store/db/postgres/migration_history.go new file mode 100644 index 00000000..0d7c6e49 --- /dev/null +++ b/store/db/postgres/migration_history.go @@ -0,0 +1,57 @@ +package postgres + +import ( + "context" + + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) { + stmt := ` + INSERT INTO migration_history ( + version + ) + VALUES ($1) + ON CONFLICT(version) DO UPDATE + SET + version=EXCLUDED.version + RETURNING version, created_ts + ` + var migrationHistory store.MigrationHistory + if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan( + &migrationHistory.Version, + &migrationHistory.CreatedTs, + ); err != nil { + return nil, err + } + + return &migrationHistory, nil +} + +func (d *DB) ListMigrationHistories(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) { + query := "SELECT version, created_ts FROM migration_history ORDER BY created_ts DESC" + rows, err := d.db.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*store.MigrationHistory, 0) + for rows.Next() { + var migrationHistory store.MigrationHistory + if err := rows.Scan( + &migrationHistory.Version, + &migrationHistory.CreatedTs, + ); err != nil { + return nil, err + } + + list = append(list, &migrationHistory) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} diff --git a/store/db/postgres/migrator.go b/store/db/postgres/migrator.go new file mode 100644 index 00000000..724e1b1b --- /dev/null +++ b/store/db/postgres/migrator.go @@ -0,0 +1,233 @@ +package postgres + +import ( + "context" + "embed" + "fmt" + "io/fs" + "os" + "regexp" + "sort" + "time" + + "github.com/pkg/errors" + + "github.com/yourselfhosted/slash/server/version" + "github.com/yourselfhosted/slash/store" +) + +//go:embed migration +var migrationFS embed.FS + +//go:embed seed +var seedFS embed.FS + +// Migrate applies the latest schema to the database. +func (d *DB) Migrate(ctx context.Context) error { + currentVersion := version.GetCurrentVersion(d.profile.Mode) + if d.profile.Mode == "prod" { + _, err := os.Stat(d.profile.DSN) + if err != nil { + // If db file not exists, we should create a new one with latest schema. + if errors.Is(err, os.ErrNotExist) { + if err := d.applyLatestSchema(ctx); err != nil { + return errors.Wrap(err, "failed to apply latest schema") + } + // Upsert the newest version to migration_history. 
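+				// Recording the current version here marks the freshly created schema as
+				// already migrated, so later startups skip the incremental steps below.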
+ if _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{ + Version: currentVersion, + }); err != nil { + return errors.Wrap(err, "failed to upsert migration history") + } + } else { + return errors.Wrap(err, "failed to get db file stat") + } + } else { + // If db file exists, we should check if we need to migrate the database. + migrationHistoryList, err := d.ListMigrationHistories(ctx, &store.FindMigrationHistory{}) + if err != nil { + return errors.Wrap(err, "failed to find migration history") + } + // If no migration history, we should apply the latest version migration and upsert the migration history. + if len(migrationHistoryList) == 0 { + minorVersion := version.GetMinorVersion(currentVersion) + if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { + return errors.Wrapf(err, "failed to apply version %s migration", minorVersion) + } + _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{ + Version: currentVersion, + }) + if err != nil { + return errors.Wrap(err, "failed to upsert migration history") + } + return nil + } + + migrationHistoryVersionList := []string{} + for _, migrationHistory := range migrationHistoryList { + migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version) + } + sort.Sort(version.SortVersion(migrationHistoryVersionList)) + latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1] + + if version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) { + minorVersionList := getMinorVersionList() + // backup the raw database file before migration + rawBytes, err := os.ReadFile(d.profile.DSN) + if err != nil { + return errors.Wrap(err, "failed to read raw database file") + } + backupDBFilePath := fmt.Sprintf("%s/memos_%s_%d_backup.db", d.profile.Data, d.profile.Version, time.Now().Unix()) + if err := os.WriteFile(backupDBFilePath, rawBytes, 0644); err != nil { + return errors.Wrap(err, "failed to write raw database file") + } + println("succeed to copy a backup database file") + println("start migrate") + for _, minorVersion := range minorVersionList { + normalizedVersion := minorVersion + ".0" + if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) { + println("applying migration for", normalizedVersion) + if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { + return errors.Wrap(err, "failed to apply minor version migration") + } + } + } + println("end migrate") + + // remove the created backup db file after migrate succeed + if err := os.Remove(backupDBFilePath); err != nil { + println(fmt.Sprintf("Failed to remove temp database file, err %v", err)) + } + } + } + } else { + // In non-prod mode, we should always migrate the database. + if _, err := os.Stat(d.profile.DSN); errors.Is(err, os.ErrNotExist) { + if err := d.applyLatestSchema(ctx); err != nil { + return errors.Wrap(err, "failed to apply latest schema") + } + // In demo mode, we should seed the database. 
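+			// Seed files are embedded under seed/*.sql and applied in lexical order by seed().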
+ if d.profile.Mode == "demo" { + if err := d.seed(ctx); err != nil { + return errors.Wrap(err, "failed to seed") + } + } + } + } + + return nil +} + +const ( + latestSchemaFileName = "LATEST__SCHEMA.sql" +) + +func (d *DB) applyLatestSchema(ctx context.Context) error { + schemaMode := "dev" + if d.profile.Mode == "prod" { + schemaMode = "prod" + } + latestSchemaPath := fmt.Sprintf("migration/%s/%s", schemaMode, latestSchemaFileName) + buf, err := migrationFS.ReadFile(latestSchemaPath) + if err != nil { + return errors.Wrapf(err, "failed to read latest schema %q", latestSchemaPath) + } + stmt := string(buf) + if err := d.execute(ctx, stmt); err != nil { + return errors.Wrapf(err, "migrate error: %s", stmt) + } + return nil +} + +func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error { + filenames, err := fs.Glob(migrationFS, fmt.Sprintf("%s/%s/*.sql", "migration/prod", minorVersion)) + if err != nil { + return errors.Wrap(err, "failed to read ddl files") + } + + sort.Strings(filenames) + migrationStmt := "" + + // Loop over all migration files and execute them in order. + for _, filename := range filenames { + buf, err := migrationFS.ReadFile(filename) + if err != nil { + return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename) + } + stmt := string(buf) + migrationStmt += stmt + if err := d.execute(ctx, stmt); err != nil { + return errors.Wrapf(err, "migrate error: %s", stmt) + } + } + + // Upsert the newest version to migration_history. + version := minorVersion + ".0" + if _, err = d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{ + Version: version, + }); err != nil { + return errors.Wrapf(err, "failed to upsert migration history with version: %s", version) + } + + return nil +} + +func (d *DB) seed(ctx context.Context) error { + filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed")) + if err != nil { + return errors.Wrap(err, "failed to read seed files") + } + + sort.Strings(filenames) + + // Loop over all seed files and execute them in order. + for _, filename := range filenames { + buf, err := seedFS.ReadFile(filename) + if err != nil { + return errors.Wrapf(err, "failed to read seed file, filename=%s", filename) + } + stmt := string(buf) + if err := d.execute(ctx, stmt); err != nil { + return errors.Wrapf(err, "seed error: %s", stmt) + } + } + return nil +} + +// execute runs a single SQL statement within a transaction. +func (d *DB) execute(ctx context.Context, stmt string) error { + tx, err := d.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.ExecContext(ctx, stmt); err != nil { + return errors.Wrap(err, "failed to execute statement") + } + + return tx.Commit() +} + +// minorDirRegexp is a regular expression for minor version directory. 
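+// It matches embedded paths such as migration/prod/0.12, whose directory names
+// are collected as minor versions by getMinorVersionList below.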
+var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`) + +func getMinorVersionList() []string { + minorVersionList := []string{} + + if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error { + if err != nil { + return err + } + if file.IsDir() && minorDirRegexp.MatchString(path) { + minorVersionList = append(minorVersionList, file.Name()) + } + + return nil + }); err != nil { + panic(err) + } + + sort.Sort(version.SortVersion(minorVersionList)) + return minorVersionList +} diff --git a/store/db/postgres/postgres.go b/store/db/postgres/postgres.go new file mode 100644 index 00000000..6bcc319b --- /dev/null +++ b/store/db/postgres/postgres.go @@ -0,0 +1,45 @@ +package postgres + +import ( + "database/sql" + "log" + + // Import the PostgreSQL driver. + _ "github.com/lib/pq" + "github.com/pkg/errors" + + "github.com/yourselfhosted/slash/server/profile" + "github.com/yourselfhosted/slash/store" +) + +type DB struct { + db *sql.DB + profile *profile.Profile +} + +func NewDB(profile *profile.Profile) (store.Driver, error) { + if profile == nil { + return nil, errors.New("profile is nil") + } + + // Open the PostgreSQL connection + db, err := sql.Open("postgres", profile.DSN) + if err != nil { + log.Printf("Failed to open database: %s", err) + return nil, errors.Wrapf(err, "failed to open database: %s", profile.DSN) + } + + var driver store.Driver = &DB{ + db: db, + profile: profile, + } + return driver, nil +} + +func (d *DB) GetDB() *sql.DB { + return d.db +} + +func (d *DB) Close() error { + return d.db.Close() +} diff --git a/store/db/postgres/seed/10000__reset.sql b/store/db/postgres/seed/10000__reset.sql new file mode 100644 index 00000000..79be1d8a --- /dev/null +++ b/store/db/postgres/seed/10000__reset.sql @@ -0,0 +1,9 @@ +DELETE FROM activity; + +DELETE FROM shortcut; + +DELETE FROM user_setting; + +DELETE FROM user; + +DELETE FROM workspace_setting; diff --git a/store/db/postgres/shortcut.go b/store/db/postgres/shortcut.go new file mode 100644 index 00000000..821898df --- /dev/null +++ b/store/db/postgres/shortcut.go @@ -0,0 +1,228 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protojson" + + storepb "github.com/yourselfhosted/slash/proto/gen/store" + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) CreateShortcut(ctx context.Context, create *storepb.Shortcut) (*storepb.Shortcut, error) { + set := []string{"creator_id", "name", "link", "title", "description", "visibility", "tag"} + args := []any{create.CreatorId, create.Name, create.Link, create.Title, create.Description, create.Visibility.String(), strings.Join(create.Tags, " ")} + if create.OgMetadata != nil { + set = append(set, "og_metadata") + openGraphMetadataBytes, err := protojson.Marshal(create.OgMetadata) + if err != nil { + return nil, err + } + args = append(args, string(openGraphMetadataBytes)) + } + + stmt := fmt.Sprintf(` + INSERT INTO shortcut (%s) + VALUES (%s) + RETURNING id, created_ts, updated_ts, row_status + `, strings.Join(set, ","), placeholders(len(args))) + + var rowStatus string + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &create.Id, + &create.CreatedTs, + &create.UpdatedTs, + &rowStatus, + ); err != nil { + return nil, err + } + create.RowStatus = store.ConvertRowStatusStringToStorepb(rowStatus) + shortcut := create + return shortcut, nil +} + +func (d *DB) UpdateShortcut(ctx context.Context, update 
*store.UpdateShortcut) (*storepb.Shortcut, error) { + set, args := []string{}, []any{} + if update.RowStatus != nil { + set, args = append(set, fmt.Sprintf("row_status = $%d", len(args)+1)), append(args, update.RowStatus.String()) + } + if update.Name != nil { + set, args = append(set, fmt.Sprintf("name = $%d", len(args)+1)), append(args, *update.Name) + } + if update.Link != nil { + set, args = append(set, fmt.Sprintf("link = $%d", len(args)+1)), append(args, *update.Link) + } + if update.Title != nil { + set, args = append(set, fmt.Sprintf("title = $%d", len(args)+1)), append(args, *update.Title) + } + if update.Description != nil { + set, args = append(set, fmt.Sprintf("description = $%d", len(args)+1)), append(args, *update.Description) + } + if update.Visibility != nil { + set, args = append(set, fmt.Sprintf("visibility = $%d", len(args)+1)), append(args, update.Visibility.String()) + } + if update.Tag != nil { + set, args = append(set, fmt.Sprintf("tag = $%d", len(args)+1)), append(args, *update.Tag) + } + if update.OpenGraphMetadata != nil { + openGraphMetadataBytes, err := protojson.Marshal(update.OpenGraphMetadata) + if err != nil { + return nil, errors.Wrap(err, "Failed to marshal activity payload") + } + set, args = append(set, fmt.Sprintf("og_metadata = $%d", len(args)+1)), append(args, string(openGraphMetadataBytes)) + } + if len(set) == 0 { + return nil, errors.New("no update specified") + } + args = append(args, update.ID) + + stmt := fmt.Sprintf(` + UPDATE shortcut + SET %s + WHERE id = $%d + RETURNING id, creator_id, created_ts, updated_ts, row_status, name, link, title, description, visibility, tag, og_metadata + `, strings.Join(set, ","), len(args)) + + shortcut := &storepb.Shortcut{} + var rowStatus, visibility, tags, openGraphMetadataString string + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &shortcut.Id, + &shortcut.CreatorId, + &shortcut.CreatedTs, + &shortcut.UpdatedTs, + &rowStatus, + &shortcut.Name, + &shortcut.Link, + &shortcut.Title, + &shortcut.Description, + &visibility, + &tags, + &openGraphMetadataString, + ); err != nil { + return nil, err + } + shortcut.RowStatus = store.ConvertRowStatusStringToStorepb(rowStatus) + shortcut.Visibility = convertVisibilityStringToStorepb(visibility) + shortcut.Tags = filterTags(strings.Split(tags, " ")) + var ogMetadata storepb.OpenGraphMetadata + if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil { + return nil, err + } + shortcut.OgMetadata = &ogMetadata + return shortcut, nil +} + +func (d *DB) ListShortcuts(ctx context.Context, find *store.FindShortcut) ([]*storepb.Shortcut, error) { + where, args := []string{"1 = 1"}, []any{} + if v := find.ID; v != nil { + where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) + } + if v := find.CreatorID; v != nil { + where, args = append(where, fmt.Sprintf("creator_id = $%d", len(args)+1)), append(args, *v) + } + if v := find.RowStatus; v != nil { + where, args = append(where, fmt.Sprintf("row_status = $%d", len(args)+1)), append(args, *v) + } + if v := find.Name; v != nil { + where, args = append(where, fmt.Sprintf("name = $%d", len(args)+1)), append(args, *v) + } + if v := find.VisibilityList; len(v) != 0 { + list := []string{} + for _, visibility := range v { + list = append(list, fmt.Sprintf("$%d", len(args)+1)) + args = append(args, visibility) + } + where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ","))) + } + if v := find.Tag; v != nil { + where, args = append(where, 
fmt.Sprintf("tag LIKE $%d", len(args)+1)), append(args, "%"+*v+"%") + } + + rows, err := d.db.QueryContext(ctx, fmt.Sprintf(` + SELECT + id, + creator_id, + created_ts, + updated_ts, + row_status, + name, + link, + title, + description, + visibility, + tag, + og_metadata + FROM shortcut + WHERE %s + ORDER BY created_ts DESC + `, strings.Join(where, " AND ")), args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*storepb.Shortcut, 0) + for rows.Next() { + shortcut := &storepb.Shortcut{} + var rowStatus, visibility, tags, openGraphMetadataString string + if err := rows.Scan( + &shortcut.Id, + &shortcut.CreatorId, + &shortcut.CreatedTs, + &shortcut.UpdatedTs, + &rowStatus, + &shortcut.Name, + &shortcut.Link, + &shortcut.Title, + &shortcut.Description, + &visibility, + &tags, + &openGraphMetadataString, + ); err != nil { + return nil, err + } + shortcut.RowStatus = store.ConvertRowStatusStringToStorepb(rowStatus) + shortcut.Visibility = storepb.Visibility(storepb.Visibility_value[visibility]) + shortcut.Tags = filterTags(strings.Split(tags, " ")) + var ogMetadata storepb.OpenGraphMetadata + if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil { + return nil, err + } + shortcut.OgMetadata = &ogMetadata + list = append(list, shortcut) + } + + if err := rows.Err(); err != nil { + return nil, err + } + return list, nil +} + +func (d *DB) DeleteShortcut(ctx context.Context, delete *store.DeleteShortcut) error { + _, err := d.db.ExecContext(ctx, "DELETE FROM shortcut WHERE id = $1", delete.ID) + return err +} + +func vacuumShortcut(ctx context.Context, tx *sql.Tx) error { + stmt := `DELETE FROM shortcut WHERE creator_id NOT IN (SELECT id FROM "user")` + _, err := tx.ExecContext(ctx, stmt) + return err +} + +func filterTags(tags []string) []string { + result := []string{} + for _, tag := range tags { + if tag != "" { + result = append(result, tag) + } + } + return result +} + +func convertVisibilityStringToStorepb(visibility string) storepb.Visibility { + return storepb.Visibility(storepb.Visibility_value[visibility]) +} diff --git a/store/db/postgres/user.go b/store/db/postgres/user.go new file mode 100644 index 00000000..641e687a --- /dev/null +++ b/store/db/postgres/user.go @@ -0,0 +1,182 @@ +package postgres + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) { + stmt := ` + INSERT INTO "user" ( + email, + nickname, + password_hash, + role + ) + VALUES ($1, $2, $3, $4) + RETURNING id, created_ts, updated_ts, row_status + ` + if err := d.db.QueryRowContext(ctx, stmt, + create.Email, + create.Nickname, + create.PasswordHash, + create.Role, + ).Scan( + &create.ID, + &create.CreatedTs, + &create.UpdatedTs, + &create.RowStatus, + ); err != nil { + return nil, err + } + + user := create + return user, nil +} + +func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) { + set, args := []string{}, []any{} + if v := update.RowStatus; v != nil { + set, args = append(set, "row_status = $"+placeholder(len(args)+1)), append(args, *v) + } + if v := update.Email; v != nil { + set, args = append(set, "email = $"+placeholder(len(args)+1)), append(args, *v) + } + if v := update.Nickname; v != nil { + set, args = append(set, "nickname = $"+placeholder(len(args)+1)), append(args, *v) + } + if v := update.PasswordHash; v != nil { + set, args = append(set, "password_hash = 
$"+placeholder(len(args)+1)), append(args, *v) + } + if v := update.Role; v != nil { + set, args = append(set, "role = $"+placeholder(len(args)+1)), append(args, *v) + } + + if len(set) == 0 { + return nil, errors.New("no fields to update") + } + + stmt := ` + UPDATE "user" + SET ` + strings.Join(set, ", ") + ` + WHERE id = $` + placeholder(len(args)+1) + ` + RETURNING id, created_ts, updated_ts, row_status, email, nickname, password_hash, role + ` + args = append(args, update.ID) + user := &store.User{} + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &user.ID, + &user.CreatedTs, + &user.UpdatedTs, + &user.RowStatus, + &user.Email, + &user.Nickname, + &user.PasswordHash, + &user.Role, + ); err != nil { + return nil, err + } + + return user, nil +} + +func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.ID; v != nil { + where, args = append(where, "id = $"+placeholder(len(args)+1)), append(args, *v) + } + if v := find.RowStatus; v != nil { + where, args = append(where, "row_status = $"+placeholder(len(args)+1)), append(args, v.String()) + } + if v := find.Email; v != nil { + where, args = append(where, "email = $"+placeholder(len(args)+1)), append(args, *v) + } + if v := find.Nickname; v != nil { + where, args = append(where, "nickname = $"+placeholder(len(args)+1)), append(args, *v) + } + if v := find.Role; v != nil { + where, args = append(where, "role = $"+placeholder(len(args)+1)), append(args, *v) + } + + query := ` + SELECT + id, + created_ts, + updated_ts, + row_status, + email, + nickname, + password_hash, + role + FROM "user" + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY updated_ts DESC, created_ts DESC + ` + rows, err := d.db.QueryContext(ctx, query, args...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*store.User, 0) + for rows.Next() { + user := &store.User{} + if err := rows.Scan( + &user.ID, + &user.CreatedTs, + &user.UpdatedTs, + &user.RowStatus, + &user.Email, + &user.Nickname, + &user.PasswordHash, + &user.Role, + ); err != nil { + return nil, err + } + list = append(list, user) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error { + tx, err := d.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.ExecContext(ctx, ` + DELETE FROM "user" WHERE id = $1 + `, delete.ID); err != nil { + return err + } + + if err := vacuumUserSetting(ctx, tx); err != nil { + return err + } + if err := vacuumShortcut(ctx, tx); err != nil { + return err + } + if err := vacuumMemo(ctx, tx); err != nil { + return err + } + if err := vacuumCollection(ctx, tx); err != nil { + return err + } + + return tx.Commit() +} + +func placeholder(n int) string { + return "$" + fmt.Sprint(n) +} diff --git a/store/db/postgres/user_setting.go b/store/db/postgres/user_setting.go new file mode 100644 index 00000000..a797da18 --- /dev/null +++ b/store/db/postgres/user_setting.go @@ -0,0 +1,122 @@ +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + + "google.golang.org/protobuf/encoding/protojson" + + storepb "github.com/yourselfhosted/slash/proto/gen/store" + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) { + stmt := ` + INSERT INTO user_setting ( + user_id, key, value + ) + VALUES ($1, $2, $3) + ON CONFLICT(user_id, key) DO UPDATE + SET value = EXCLUDED.value + RETURNING user_id, key, value + ` + + var valueString string + if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS { + valueBytes, err := protojson.Marshal(upsert.GetAccessTokens()) + if err != nil { + return nil, err + } + valueString = string(valueBytes) + } else if upsert.Key == storepb.UserSettingKey_USER_SETTING_LOCALE { + valueString = upsert.GetLocale().String() + } else if upsert.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME { + valueString = upsert.GetColorTheme().String() + } else { + return nil, errors.New("invalid user setting key") + } + + if _, err := d.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString); err != nil { + return nil, err + } + + userSettingMessage := upsert + return userSettingMessage, nil +} + +func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*storepb.UserSetting, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED { + where, args = append(where, fmt.Sprintf("key = $%d", len(args)+1)), append(args, v.String()) + } + if v := find.UserID; v != nil { + where, args = append(where, fmt.Sprintf("user_id = $%d", len(args)+1)), append(args, *find.UserID) + } + + query := ` + SELECT + user_id, + key, + value + FROM user_setting + WHERE ` + strings.Join(where, " AND ") + rows, err := d.db.QueryContext(ctx, query, args...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + + userSettingList := make([]*storepb.UserSetting, 0) + for rows.Next() { + userSetting := &storepb.UserSetting{} + var keyString, valueString string + if err := rows.Scan( + &userSetting.UserId, + &keyString, + &valueString, + ); err != nil { + return nil, err + } + userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString]) + if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS { + accessTokensUserSetting := &storepb.AccessTokensUserSetting{} + if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil { + return nil, err + } + userSetting.Value = &storepb.UserSetting_AccessTokens{ + AccessTokens: accessTokensUserSetting, + } + } else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_LOCALE { + userSetting.Value = &storepb.UserSetting_Locale{ + Locale: storepb.LocaleUserSetting(storepb.LocaleUserSetting_value[valueString]), + } + } else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME { + userSetting.Value = &storepb.UserSetting_ColorTheme{ + ColorTheme: storepb.ColorThemeUserSetting(storepb.ColorThemeUserSetting_value[valueString]), + } + } else { + return nil, errors.New("invalid user setting key") + } + userSettingList = append(userSettingList, userSetting) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return userSettingList, nil +} + +func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error { + stmt := `DELETE FROM user_setting WHERE user_id NOT IN (SELECT id FROM "user")` + _, err := tx.ExecContext(ctx, stmt) + if err != nil { + return err + } + + return nil +} diff --git a/store/db/postgres/workspace_setting.go b/store/db/postgres/workspace_setting.go new file mode 100644 index 00000000..e6514425 --- /dev/null +++ b/store/db/postgres/workspace_setting.go @@ -0,0 +1,116 @@ +package postgres + +import ( + "context" + "errors" + "strconv" + "strings" + + "google.golang.org/protobuf/encoding/protojson" + + storepb "github.com/yourselfhosted/slash/proto/gen/store" + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) UpsertWorkspaceSetting(ctx context.Context, upsert *storepb.WorkspaceSetting) (*storepb.WorkspaceSetting, error) { + stmt := ` + INSERT INTO workspace_setting ( + key, + value + ) + VALUES ($1, $2) + ON CONFLICT(key) DO UPDATE + SET value = EXCLUDED.value + ` + var valueString string + if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY { + valueString = upsert.GetLicenseKey() + } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION { + valueString = upsert.GetSecretSession() + } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP { + valueString = strconv.FormatBool(upsert.GetEnableSignup()) + } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE { + valueString = upsert.GetCustomStyle() + } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT { + valueString = upsert.GetCustomScript() + } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP { + valueBytes, err := protojson.Marshal(upsert.GetAutoBackup()) + if err != nil { + return nil, err + } + valueString = string(valueBytes) + } else { + return nil, errors.New("invalid workspace setting key") + } + + if _, err := d.db.ExecContext(ctx, stmt, upsert.Key.String(), valueString); err != nil { + return nil, err + } + + workspaceSetting := upsert + return 
workspaceSetting, nil +} + +func (d *DB) ListWorkspaceSettings(ctx context.Context, find *store.FindWorkspaceSetting) ([]*storepb.WorkspaceSetting, error) { + where, args := []string{"1 = 1"}, []interface{}{} + + if find.Key != storepb.WorkspaceSettingKey_WORKSPACE_SETTING_KEY_UNSPECIFIED { + where, args = append(where, "key = $"+placeholder(len(args)+1)), append(args, find.Key.String()) + } + + query := ` + SELECT + key, + value + FROM workspace_setting + WHERE ` + strings.Join(where, " AND ") + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + + defer rows.Close() + + list := []*storepb.WorkspaceSetting{} + for rows.Next() { + workspaceSetting := &storepb.WorkspaceSetting{} + var keyString, valueString string + if err := rows.Scan( + &keyString, + &valueString, + ); err != nil { + return nil, err + } + workspaceSetting.Key = storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[keyString]) + if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY { + workspaceSetting.Value = &storepb.WorkspaceSetting_LicenseKey{LicenseKey: valueString} + } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION { + workspaceSetting.Value = &storepb.WorkspaceSetting_SecretSession{SecretSession: valueString} + } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP { + enableSignup, err := strconv.ParseBool(valueString) + if err != nil { + return nil, err + } + workspaceSetting.Value = &storepb.WorkspaceSetting_EnableSignup{EnableSignup: enableSignup} + } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE { + workspaceSetting.Value = &storepb.WorkspaceSetting_CustomStyle{CustomStyle: valueString} + } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT { + workspaceSetting.Value = &storepb.WorkspaceSetting_CustomScript{CustomScript: valueString} + } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP { + autoBackupSetting := &storepb.AutoBackupWorkspaceSetting{} + if err := protojson.Unmarshal([]byte(valueString), autoBackupSetting); err != nil { + return nil, err + } + workspaceSetting.Value = &storepb.WorkspaceSetting_AutoBackup{AutoBackup: autoBackupSetting} + } else { + continue + } + list = append(list, workspaceSetting) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +}
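
For reference, a minimal sketch of how the new driver gets selected, assuming a profile.Profile populated the same way the server already does for SQLite and that store.Driver exposes Migrate and Close (the DSN value is illustrative only):

	package main

	import (
		"context"
		"log"

		"github.com/yourselfhosted/slash/server/profile"
		"github.com/yourselfhosted/slash/store/db"
	)

	func main() {
		// Driver "postgres" routes NewDBDriver to postgres.NewDB, which opens the DSN with lib/pq.
		p := &profile.Profile{
			Mode:   "dev",
			Driver: "postgres",
			DSN:    "postgresql://slash:secret@localhost:5432/slash?sslmode=disable", // illustrative DSN
		}

		driver, err := db.NewDBDriver(p)
		if err != nil {
			log.Fatalf("failed to create db driver: %v", err)
		}
		defer driver.Close()

		// Applies migration/{dev,prod}/LATEST__SCHEMA.sql and, in demo mode, the seed files.
		if err := driver.Migrate(context.Background()); err != nil {
			log.Fatalf("failed to migrate: %v", err)
		}
	}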