diff --git a/commands.go b/commands.go index 60f0fc1..d864b3d 100644 --- a/commands.go +++ b/commands.go @@ -194,6 +194,7 @@ var cmdDeletePortal = &commands.FullHandler{ func fnDeletePortal(ce *WrappedCommandEvent) { ce.Portal.delete() - ce.Portal.cleanup(false) + + ce.Bridge.cleanupRoom(ce.Portal.MainIntent(), ce.Portal.MXID, false, ce.Log) ce.Log.Infofln("Deleted portal") } diff --git a/database/portal.go b/database/portal.go index 4b96d1a..e06c019 100644 --- a/database/portal.go +++ b/database/portal.go @@ -70,6 +70,8 @@ type Portal struct { FirstEventID id.EventID NextBatchID id.BatchID FirstSlackID string + + InSpace bool } func (p *Portal) Scan(row dbutil.Scannable) *Portal { @@ -78,7 +80,7 @@ func (p *Portal) Scan(row dbutil.Scannable) *Portal { err := row.Scan(&p.Key.TeamID, &p.Key.ChannelID, &mxid, &p.Type, &dmUserID, &p.PlainName, &p.Name, &p.NameSet, &p.Topic, &p.TopicSet, &p.Avatar, &avatarURL, &p.AvatarSet, &firstEventID, - &p.Encrypted, &nextBatchID, &firstSlackID) + &p.Encrypted, &nextBatchID, &firstSlackID, &p.InSpace) if err != nil { if err != sql.ErrNoRows { @@ -110,13 +112,13 @@ func (p *Portal) Insert() { query := "INSERT INTO portal" + " (team_id, channel_id, mxid, type, dm_user_id, plain_name," + " name, name_set, topic, topic_set, avatar, avatar_url, avatar_set," + - " first_event_id, encrypted, next_batch_id, first_slack_id)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)" + " first_event_id, encrypted, next_batch_id, first_slack_id, in_space)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18)" _, err := p.db.Exec(query, p.Key.TeamID, p.Key.ChannelID, p.mxidPtr(), p.Type, p.DMUserID, p.PlainName, p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet, - p.FirstEventID.String(), p.Encrypted, p.NextBatchID.String(), p.FirstSlackID) + p.FirstEventID.String(), p.Encrypted, p.NextBatchID.String(), p.FirstSlackID, p.InSpace) if err != nil { p.log.Warnfln("Failed to insert %s: %v", p.Key, err) @@ -127,13 +129,13 @@ func (p *Portal) Update(txn dbutil.Transaction) { query := "UPDATE portal SET" + " mxid=$1, type=$2, dm_user_id=$3, plain_name=$4, name=$5, name_set=$6," + " topic=$7, topic_set=$8, avatar=$9, avatar_url=$10, avatar_set=$11," + - " first_event_id=$12, encrypted=$13, next_batch_id=$14, first_slack_id=$15" + - " WHERE team_id=$16 AND channel_id=$17" + " first_event_id=$12, encrypted=$13, next_batch_id=$14, first_slack_id=$15, in_space=$16" + + " WHERE team_id=$17 AND channel_id=$18" args := []interface{}{p.mxidPtr(), p.Type, p.DMUserID, p.PlainName, p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet, p.FirstEventID.String(), p.Encrypted, p.NextBatchID.String(), p.FirstSlackID, - p.Key.TeamID, p.Key.ChannelID} + p.InSpace, p.Key.TeamID, p.Key.ChannelID} var err error if txn != nil { diff --git a/database/portalquery.go b/database/portalquery.go index 69dff4f..8aeca16 100644 --- a/database/portalquery.go +++ b/database/portalquery.go @@ -25,7 +25,7 @@ const ( portalSelect = "SELECT team_id, channel_id, mxid, type, " + " dm_user_id, plain_name, name, name_set, topic, topic_set," + " avatar, avatar_url, avatar_set, first_event_id," + - " encrypted, next_batch_id, first_slack_id FROM portal" + " encrypted, next_batch_id, first_slack_id, in_space FROM portal" ) type PortalQuery struct { diff --git a/database/teaminfo.go b/database/teaminfo.go index a172dd1..642f62c 100644 --- a/database/teaminfo.go +++ b/database/teaminfo.go @@ -38,7 +38,7 @@ func (tiq *TeamInfoQuery) New() *TeamInfo { } func (tiq *TeamInfoQuery) GetBySlackTeam(team string) *TeamInfo { - query := `SELECT team_id, team_domain, team_url, team_name, avatar, avatar_url FROM team_info WHERE team_id=$1` + query := `SELECT team_id, team_domain, team_url, team_name, avatar, avatar_url, space_room, name_set, avatar_set FROM team_info WHERE team_id=$1` row := tiq.db.QueryRow(query, team) if row == nil { @@ -48,6 +48,17 @@ func (tiq *TeamInfoQuery) GetBySlackTeam(team string) *TeamInfo { return tiq.New().Scan(row) } +func (tiq *TeamInfoQuery) GetByMXID(mxid id.RoomID) *TeamInfo { + query := `SELECT team_id, team_domain, team_url, team_name, avatar, avatar_url, space_room, name_set, avatar_set FROM team_info WHERE space_room=$1` + + row := tiq.db.QueryRow(query, mxid) + if row == nil { + return nil + } + + return tiq.New().Scan(row) +} + type TeamInfo struct { db *Database log log.Logger @@ -58,6 +69,9 @@ type TeamInfo struct { TeamName string Avatar string AvatarUrl id.ContentURI + SpaceRoom id.RoomID + NameSet bool + AvatarSet bool } func (ti *TeamInfo) Scan(row dbutil.Scannable) *TeamInfo { @@ -66,8 +80,9 @@ func (ti *TeamInfo) Scan(row dbutil.Scannable) *TeamInfo { var teamName sql.NullString var avatar sql.NullString var avatarUrl sql.NullString + var spaceRoom sql.NullString - err := row.Scan(&ti.TeamID, &teamDomain, &teamUrl, &teamName, &avatar, &avatarUrl) + err := row.Scan(&ti.TeamID, &teamDomain, &teamUrl, &teamName, &avatar, &avatarUrl, &spaceRoom, &ti.NameSet, &ti.AvatarSet) if err != nil { if err != sql.ErrNoRows { ti.log.Errorln("Database scan failed:", err) @@ -91,16 +106,26 @@ func (ti *TeamInfo) Scan(row dbutil.Scannable) *TeamInfo { if avatarUrl.Valid { ti.AvatarUrl, _ = id.ParseContentURI(avatarUrl.String) } + if spaceRoom.Valid { + ti.SpaceRoom = id.RoomID(spaceRoom.String) + } return ti } func (ti *TeamInfo) Upsert() { query := ` - INSERT INTO team_info (team_id, team_domain, team_url, team_name, avatar, avatar_url) - VALUES ($1, $2, $3, $4, $5, $6) + INSERT INTO team_info (team_id, team_domain, team_url, team_name, avatar, avatar_url, space_room, name_set, avatar_set) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ON CONFLICT (team_id) DO UPDATE - SET team_domain=excluded.team_domain, team_url=excluded.team_url, team_name=excluded.team_name, avatar=excluded.avatar, avatar_url=excluded.avatar_url + SET team_domain=excluded.team_domain, + team_url=excluded.team_url, + team_name=excluded.team_name, + avatar=excluded.avatar, + avatar_url=excluded.avatar_url, + space_room=excluded.space_room, + name_set=excluded.name_set, + avatar_set=excluded.avatar_set ` teamDomain := sqlNullString(ti.TeamDomain) @@ -108,8 +133,9 @@ func (ti *TeamInfo) Upsert() { teamName := sqlNullString(ti.TeamName) avatar := sqlNullString(ti.Avatar) avatarUrl := sqlNullString(ti.AvatarUrl.String()) + spaceRoom := sqlNullString(ti.SpaceRoom.String()) - _, err := ti.db.Exec(query, ti.TeamID, teamDomain, teamUrl, teamName, avatar, avatarUrl) + _, err := ti.db.Exec(query, ti.TeamID, teamDomain, teamUrl, teamName, avatar, avatarUrl, spaceRoom, ti.NameSet, ti.AvatarSet) if err != nil { ti.log.Warnfln("Failed to upsert team %s: %v", ti.TeamID, err) diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql index c7137c8..08779f5 100644 --- a/database/upgrades/00-latest-revision.sql +++ b/database/upgrades/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v15: Latest revision +-- v0 -> v16: Latest revision CREATE TABLE portal ( team_id TEXT, @@ -23,6 +23,8 @@ CREATE TABLE portal ( next_batch_id TEXT, first_slack_id TEXT, + in_space BOOLEAN DEFAULT false, + PRIMARY KEY (team_id, channel_id) ); @@ -54,28 +56,31 @@ CREATE TABLE puppet ( CREATE TABLE "user" ( mxid TEXT PRIMARY KEY, - management_room TEXT + management_room TEXT, + space_room TEXT ); CREATE TABLE "user_team" ( mxid TEXT NOT NULL, slack_email TEXT NOT NULL, - slack_id TEXT NOT NULL, + slack_id TEXT NOT NULL, team_name TEXT NOT NULL, - team_id TEXT NOT NULL, + team_id TEXT NOT NULL, - token TEXT, + token TEXT, cookie_token TEXT, + in_space BOOLEAN DEFAULT false, + PRIMARY KEY(mxid, slack_id, team_id) ); CREATE TABLE user_team_portal ( - matrix_user_id TEXT NOT NULL, - slack_user_id TEXT NOT NULL, - slack_team_id TEXT NOT NULL, + matrix_user_id TEXT NOT NULL, + slack_user_id TEXT NOT NULL, + slack_team_id TEXT NOT NULL, portal_channel_id TEXT NOT NULL, FOREIGN KEY(matrix_user_id, slack_user_id, slack_team_id) REFERENCES "user_team"(mxid, slack_id, team_id) ON DELETE CASCADE, FOREIGN KEY(slack_team_id, portal_channel_id) REFERENCES portal(team_id, channel_id) ON DELETE CASCADE @@ -85,9 +90,9 @@ CREATE TABLE message ( team_id TEXT NOT NULL, channel_id TEXT NOT NULL, - slack_message_id TEXT NOT NULL, - slack_thread_id TEXT, - matrix_message_id TEXT NOT NULL UNIQUE, + slack_message_id TEXT NOT NULL, + slack_thread_id TEXT, + matrix_message_id TEXT NOT NULL UNIQUE, author_id TEXT NOT NULL, @@ -118,30 +123,33 @@ CREATE TABLE attachment ( channel_id TEXT NOT NULL, slack_message_id TEXT NOT NULL, - slack_file_id TEXT NOT NULL, - matrix_event_id TEXT NOT NULL UNIQUE, - slack_thread_id TEXT, + slack_file_id TEXT NOT NULL, + matrix_event_id TEXT NOT NULL UNIQUE, + slack_thread_id TEXT, PRIMARY KEY(slack_message_id, slack_file_id, matrix_event_id), FOREIGN KEY(team_id, channel_id) REFERENCES portal(team_id, channel_id) ON DELETE CASCADE ); CREATE TABLE "team_info" ( - team_id TEXT NOT NULL UNIQUE, + team_id TEXT NOT NULL UNIQUE, team_domain TEXT, - team_url TEXT, - team_name TEXT, - avatar TEXT, - avatar_url TEXT + team_url TEXT, + team_name TEXT, + avatar TEXT, + avatar_url TEXT, + space_room TEXT, + name_set BOOLEAN DEFAULT false, + avatar_set BOOLEAN DEFAULT false ); CREATE TABLE backfill_state ( team_id TEXT, channel_id TEXT, backfill_complete BOOLEAN, - dispatched BOOLEAN, - message_count INTEGER, - immediate_complete BOOLEAN, + dispatched BOOLEAN, + message_count INTEGER, + immediate_complete BOOLEAN, PRIMARY KEY (team_id, channel_id), FOREIGN KEY (team_id, channel_id) REFERENCES portal (team_id, channel_id) ON DELETE CASCADE ); diff --git a/database/upgrades/16-add-spaces.sql b/database/upgrades/16-add-spaces.sql new file mode 100644 index 0000000..20757d4 --- /dev/null +++ b/database/upgrades/16-add-spaces.sql @@ -0,0 +1,8 @@ +-- v16: Add spaces + +ALTER TABLE "user" ADD space_room TEXT; +ALTER TABLE team_info ADD space_room TEXT; +ALTER TABLE team_info ADD name_set BOOLEAN DEFAULT false; +ALTER TABLE team_info ADD avatar_set BOOLEAN DEFAULT false; +ALTER TABLE portal ADD in_space BOOLEAN DEFAULT false; +ALTER TABLE user_team ADD in_space BOOLEAN DEFAULT false; diff --git a/database/user.go b/database/user.go index 1cbf2ff..ab2c6bf 100644 --- a/database/user.go +++ b/database/user.go @@ -32,6 +32,7 @@ type User struct { MXID id.UserID ManagementRoom id.RoomID + SpaceRoom id.RoomID TeamsLock sync.Mutex Teams map[string]*UserTeam @@ -47,7 +48,9 @@ func (user *User) loadTeams() { } func (u *User) Scan(row dbutil.Scannable) *User { - err := row.Scan(&u.MXID, &u.ManagementRoom) + var spaceRoom sql.NullString + + err := row.Scan(&u.MXID, &u.ManagementRoom, &spaceRoom) if err != nil { if err != sql.ErrNoRows { u.log.Errorln("Database scan failed:", err) @@ -56,6 +59,8 @@ func (u *User) Scan(row dbutil.Scannable) *User { return nil } + u.SpaceRoom = id.RoomID(spaceRoom.String) + u.loadTeams() return u @@ -79,9 +84,9 @@ func (u *User) SyncTeams() { } func (u *User) Insert() { - query := "INSERT INTO \"user\" (mxid, management_room) VALUES ($1, $2);" + query := "INSERT INTO \"user\" (mxid, management_room, space_room) VALUES ($1, $2, $3);" - _, err := u.db.Exec(query, u.MXID, u.ManagementRoom) + _, err := u.db.Exec(query, u.MXID, u.ManagementRoom, u.SpaceRoom) if err != nil { u.log.Warnfln("Failed to insert %s: %v", u.MXID, err) @@ -91,9 +96,9 @@ func (u *User) Insert() { } func (u *User) Update() { - query := "UPDATE \"user\" SET management_room=$1 WHERE mxid=$2;" + query := "UPDATE \"user\" SET management_room=$1 AND space_room=$2 WHERE mxid=$3;" - _, err := u.db.Exec(query, u.ManagementRoom, u.MXID) + _, err := u.db.Exec(query, u.ManagementRoom, u.SpaceRoom, u.MXID) if err != nil { u.log.Warnfln("Failed to update %q: %v", u.MXID, err) diff --git a/database/userquery.go b/database/userquery.go index edea870..1ece391 100644 --- a/database/userquery.go +++ b/database/userquery.go @@ -35,7 +35,7 @@ func (uq *UserQuery) New() *User { } func (uq *UserQuery) GetByMXID(userID id.UserID) *User { - query := `SELECT mxid, management_room FROM "user" WHERE mxid=$1` + query := `SELECT mxid, management_room, space_room FROM "user" WHERE mxid=$1` row := uq.db.QueryRow(query, userID) if row == nil { return nil @@ -45,7 +45,7 @@ func (uq *UserQuery) GetByMXID(userID id.UserID) *User { } func (uq *UserQuery) GetBySlackID(teamID, userID string) *User { - query := `SELECT u.mxid, u.management_room FROM "user" u` + + query := `SELECT u.mxid, u.management_room, u.space_room FROM "user" u` + ` INNER JOIN user_team ut ON u.mxid = ut.mxid` + ` WHERE ut.team_id=$1 AND ut.slack_id=$2` row := uq.db.QueryRow(query, teamID, userID) @@ -57,7 +57,7 @@ func (uq *UserQuery) GetBySlackID(teamID, userID string) *User { } func (uq *UserQuery) GetAll() []*User { - rows, err := uq.db.Query(`SELECT mxid, management_room FROM "user"`) + rows, err := uq.db.Query(`SELECT mxid, management_room, space_room FROM "user"`) if err != nil || rows == nil { return nil } diff --git a/database/userteam.go b/database/userteam.go index 244b5c7..5b57710 100644 --- a/database/userteam.go +++ b/database/userteam.go @@ -40,7 +40,7 @@ func (utq *UserTeamQuery) New() *UserTeam { } } -const userTeamSelect = "SELECT ut.mxid, ut.slack_email, ut.slack_id, ut.team_name, ut.team_id, ut.token, ut.cookie_token FROM user_team ut " +const userTeamSelect = "SELECT ut.mxid, ut.slack_email, ut.slack_id, ut.team_name, ut.team_id, ut.token, ut.cookie_token, ut.in_space FROM user_team ut " func (utq *UserTeamQuery) GetBySlackDomain(userID id.UserID, email, domain string) *UserTeam { query := userTeamSelect + "WHERE ut.mxid=$1 AND ut.slack_email=$2 AND ut.team_id=(SELECT team_id FROM team_info WHERE team_domain=$3)" @@ -129,6 +129,8 @@ type UserTeam struct { Token string CookieToken string + InSpace bool + Client *slack.Client RTM *slack.RTM } @@ -157,7 +159,7 @@ func (ut *UserTeam) Scan(row dbutil.Scannable) *UserTeam { var token sql.NullString var cookieToken sql.NullString - err := row.Scan(&ut.Key.MXID, &ut.SlackEmail, &ut.Key.SlackID, &ut.TeamName, &ut.Key.TeamID, &token, &cookieToken) + err := row.Scan(&ut.Key.MXID, &ut.SlackEmail, &ut.Key.SlackID, &ut.TeamName, &ut.Key.TeamID, &token, &cookieToken, &ut.InSpace) if err != nil { if err != sql.ErrNoRows { ut.log.Errorln("Database scan failed:", err) @@ -178,16 +180,16 @@ func (ut *UserTeam) Scan(row dbutil.Scannable) *UserTeam { func (ut *UserTeam) Upsert() { query := ` - INSERT INTO user_team (mxid, slack_email, slack_id, team_name, team_id, token, cookie_token) - VALUES ($1, $2, $3, $4, $5, $6, $7) + INSERT INTO user_team (mxid, slack_email, slack_id, team_name, team_id, token, cookie_token, in_space) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (mxid, slack_id, team_id) DO UPDATE - SET slack_email=excluded.slack_email, team_name=excluded.team_name, token=excluded.token, cookie_token=excluded.cookie_token + SET slack_email=excluded.slack_email, team_name=excluded.team_name, token=excluded.token, cookie_token=excluded.cookie_token, in_space=excluded.in_space ` token := sqlNullString(ut.Token) cookieToken := sqlNullString(ut.CookieToken) - _, err := ut.db.Exec(query, ut.Key.MXID, ut.SlackEmail, ut.Key.SlackID, ut.TeamName, ut.Key.TeamID, token, cookieToken) + _, err := ut.db.Exec(query, ut.Key.MXID, ut.SlackEmail, ut.Key.SlackID, ut.TeamName, ut.Key.TeamID, token, cookieToken, ut.InSpace) if err != nil { ut.log.Warnfln("Failed to upsert %s/%s/%s: %v", ut.Key.MXID, ut.Key.SlackID, ut.Key.TeamID, err) diff --git a/main.go b/main.go index daf0ad2..c3c7e20 100644 --- a/main.go +++ b/main.go @@ -65,6 +65,10 @@ type SlackBridge struct { portalsByID map[database.PortalKey]*Portal portalsLock sync.Mutex + teamsByMXID map[id.RoomID]*Team + teamsByID map[string]*Team + teamsLock sync.Mutex + puppets map[string]*Puppet puppetsByCustomMXID map[id.UserID]*Puppet puppetsLock sync.Mutex @@ -156,6 +160,9 @@ func main() { portalsByMXID: make(map[id.RoomID]*Portal), portalsByID: make(map[database.PortalKey]*Portal), + teamsByMXID: make(map[id.RoomID]*Team), + teamsByID: make(map[string]*Team), + puppets: make(map[string]*Puppet), puppetsByCustomMXID: make(map[id.UserID]*Puppet), } diff --git a/portal.go b/portal.go index 6b2b959..e9b3977 100644 --- a/portal.go +++ b/portal.go @@ -1019,7 +1019,7 @@ func (portal *Portal) sendSlackRepeatTyping() { func (portal *Portal) HandleMatrixLeave(brSender bridge.User) { portal.log.Debugln("User left private chat portal, cleaning up and deleting...") portal.delete() - portal.cleanup(false) + portal.bridge.cleanupRoom(portal.MainIntent(), portal.MXID, false, portal.log) // TODO: figure out how to close a dm from the API. @@ -1059,39 +1059,23 @@ func (portal *Portal) cleanupIfEmpty() { if len(users) == 0 { portal.log.Infoln("Room seems to be empty, cleaning up...") - portal.cleanup(false) portal.delete() - } -} - -func (portal *Portal) cleanup(puppetsOnly bool) { - if portal.MXID == "" { - return - } - - if portal.bridge.SpecVersions.UnstableFeatures["com.beeper.room_yeeting"] { - intent := portal.MainIntent() - err := intent.BeeperDeleteRoom(portal.MXID) - if err == nil || errors.Is(err, mautrix.MNotFound) { - return - } - portal.log.Warnfln("Failed to delete %s using hungryserv yeet endpoint, falling back to normal behavior: %v", portal.MXID, err) - } - - if portal.IsPrivateChat() { - _, err := portal.MainIntent().LeaveRoom(portal.MXID) - if err != nil { - portal.log.Warnln("Failed to leave private chat portal with main intent:", err) + if portal.bridge.SpecVersions.UnstableFeatures["com.beeper.room_yeeting"] { + intent := portal.MainIntent() + err := intent.BeeperDeleteRoom(portal.MXID) + if err == nil || errors.Is(err, mautrix.MNotFound) { + return + } + portal.log.Warnfln("Failed to delete %s using hungryserv yeet endpoint, falling back to normal behavior: %v", portal.MXID, err) } - - return + portal.bridge.cleanupRoom(portal.MainIntent(), portal.MXID, false, portal.log) } +} - intent := portal.MainIntent() - members, err := intent.JoinedMembers(portal.MXID) +func (br *SlackBridge) cleanupRoom(intent *appservice.IntentAPI, mxid id.RoomID, puppetsOnly bool, log log.Logger) { + members, err := intent.JoinedMembers(mxid) if err != nil { - portal.log.Errorln("Failed to get portal members for cleanup:", err) - + log.Errorln("Failed to get portal members for cleanup:", err) return } @@ -1100,23 +1084,23 @@ func (portal *Portal) cleanup(puppetsOnly bool) { continue } - puppet := portal.bridge.GetPuppetByMXID(member) + puppet := br.GetPuppetByMXID(member) if puppet != nil { - _, err = puppet.DefaultIntent().LeaveRoom(portal.MXID) + _, err = puppet.DefaultIntent().LeaveRoom(mxid) if err != nil { - portal.log.Errorln("Error leaving as puppet while cleaning up portal:", err) + log.Errorln("Error leaving as puppet while cleaning up portal:", err) } } else if !puppetsOnly { - _, err = intent.KickUser(portal.MXID, &mautrix.ReqKickUser{UserID: member, Reason: "Deleting portal"}) + _, err = intent.KickUser(mxid, &mautrix.ReqKickUser{UserID: member, Reason: "Deleting portal"}) if err != nil { - portal.log.Errorln("Error kicking user while cleaning up portal:", err) + log.Errorln("Error kicking user while cleaning up portal:", err) } } } - _, err = intent.LeaveRoom(portal.MXID) + _, err = intent.LeaveRoom(mxid) if err != nil { - portal.log.Errorln("Error leaving with main intent while cleaning up portal:", err) + log.Errorln("Error leaving with main intent while cleaning up portal:", err) } } diff --git a/teamportal.go b/teamportal.go new file mode 100644 index 0000000..38027d2 --- /dev/null +++ b/teamportal.go @@ -0,0 +1,327 @@ +package main + +import ( + "errors" + "fmt" + "sync" + + "github.com/slack-go/slack" + + "go.mau.fi/mautrix-slack/database" + log "maunium.net/go/maulogger/v2" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type Team struct { + *database.TeamInfo + + bridge *SlackBridge + log log.Logger + + roomCreateLock sync.Mutex +} + +func (br *SlackBridge) loadTeam(dbTeam *database.TeamInfo, id string, createIfNotExist bool) *Team { + if dbTeam == nil { + if id == "" || !createIfNotExist { + return nil + } + + dbTeam = br.DB.TeamInfo.New() + dbTeam.TeamID = id + dbTeam.Upsert() + } + + team := br.NewTeam(dbTeam) + + br.teamsByID[team.TeamID] = team + if team.SpaceRoom != "" { + br.teamsByMXID[team.SpaceRoom] = team + } + + return team +} + +func (br *SlackBridge) GetTeamByMXID(mxid id.RoomID) *Team { + br.teamsLock.Lock() + defer br.teamsLock.Unlock() + + portal, ok := br.teamsByMXID[mxid] + if !ok { + return br.loadTeam(br.DB.TeamInfo.GetByMXID(mxid), "", false) + } + + return portal +} + +func (br *SlackBridge) GetTeamByID(id string, createIfNotExist bool) *Team { + br.teamsLock.Lock() + defer br.teamsLock.Unlock() + + team, ok := br.teamsByID[id] + if !ok { + return br.loadTeam(br.DB.TeamInfo.GetBySlackTeam(id), id, createIfNotExist) + } + + return team +} + +// func (br *SlackBridge) GetAllTeams() []*Team { +// return br.dbTeamsToTeams(br.DB.TeamInfo.GetAll()) +// } + +// func (br *SlackBridge) dbTeamsToTeams(dbTeams []*database.TeamInfo) []*Team { +// br.teamsLock.Lock() +// defer br.teamsLock.Unlock() + +// output := make([]*Team, len(dbTeams)) +// for index, dbTeam := range dbTeams { +// if dbTeam == nil { +// continue +// } + +// team, ok := br.teamsByID[dbTeam.TeamID] +// if !ok { +// team = br.loadTeam(dbTeam, "", false) +// } + +// output[index] = team +// } + +// return output +// } + +func (br *SlackBridge) NewTeam(dbTeam *database.TeamInfo) *Team { + team := &Team{ + TeamInfo: dbTeam, + bridge: br, + log: br.Log.Sub(fmt.Sprintf("Team/%s", dbTeam.TeamID)), + } + + return team +} + +func (team *Team) getBridgeInfo() (string, event.BridgeEventContent) { + bridgeInfo := event.BridgeEventContent{ + BridgeBot: team.bridge.Bot.UserID, + Creator: team.bridge.Bot.UserID, + Protocol: event.BridgeInfoSection{ + ID: "slackgo", + DisplayName: "Slack", + AvatarURL: team.bridge.Config.AppService.Bot.ParsedAvatar.CUString(), + ExternalURL: "https://slack.com/", + }, + Channel: event.BridgeInfoSection{ + ID: team.TeamID, + DisplayName: team.TeamName, + AvatarURL: team.AvatarUrl.CUString(), + }, + } + bridgeInfoStateKey := fmt.Sprintf("fi.mau.slack://slackgo/%s", team.TeamID) + return bridgeInfoStateKey, bridgeInfo +} + +func (team *Team) UpdateBridgeInfo() { + if len(team.SpaceRoom) == 0 { + team.log.Debugln("Not updating bridge info: no Matrix room created") + return + } + team.log.Debugln("Updating bridge info...") + stateKey, content := team.getBridgeInfo() + _, err := team.bridge.Bot.SendStateEvent(team.SpaceRoom, event.StateBridge, stateKey, content) + if err != nil { + team.log.Warnln("Failed to update m.bridge:", err) + } + // TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec + _, err = team.bridge.Bot.SendStateEvent(team.SpaceRoom, event.StateHalfShotBridge, stateKey, content) + if err != nil { + team.log.Warnln("Failed to update uk.half-shot.bridge:", err) + } +} + +func (team *Team) CreateMatrixRoom(user *User, meta *slack.TeamInfo) error { + team.roomCreateLock.Lock() + defer team.roomCreateLock.Unlock() + if team.SpaceRoom != "" { + return nil + } + team.log.Infoln("Creating Matrix room for team") + team.UpdateInfo(user, meta) + + bridgeInfoStateKey, bridgeInfo := team.getBridgeInfo() + + initialState := []*event.Event{{ + Type: event.StateBridge, + Content: event.Content{Parsed: bridgeInfo}, + StateKey: &bridgeInfoStateKey, + }, { + // TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec + Type: event.StateHalfShotBridge, + Content: event.Content{Parsed: bridgeInfo}, + StateKey: &bridgeInfoStateKey, + }} + + if !team.AvatarUrl.IsEmpty() { + initialState = append(initialState, &event.Event{ + Type: event.StateRoomAvatar, + Content: event.Content{Parsed: &event.RoomAvatarEventContent{ + URL: team.AvatarUrl, + }}, + }) + } + + creationContent := map[string]interface{}{ + "type": event.RoomTypeSpace, + } + if !team.bridge.Config.Bridge.FederateRooms { + creationContent["m.federate"] = false + } + + resp, err := team.bridge.Bot.CreateRoom(&mautrix.ReqCreateRoom{ + Visibility: "private", + Name: team.TeamName, + Preset: "private_chat", + InitialState: initialState, + CreationContent: creationContent, + }) + if err != nil { + team.log.Warnln("Failed to create room:", err) + return err + } + + team.SpaceRoom = resp.RoomID + team.NameSet = true + team.AvatarSet = !team.AvatarUrl.IsEmpty() + team.Upsert() + team.bridge.teamsLock.Lock() + team.bridge.teamsByMXID[team.SpaceRoom] = team + team.bridge.teamsLock.Unlock() + team.log.Infoln("Matrix room created:", team.SpaceRoom) + + user.ensureInvited(nil, team.SpaceRoom, false) + + return nil +} + +func (team *Team) UpdateInfo(source *User, meta *slack.TeamInfo) (changed bool) { + changed = team.UpdateName(meta) || changed + changed = team.UpdateAvatar(meta) || changed + if team.TeamDomain != meta.Domain { + team.TeamDomain = meta.Domain + changed = true + } + if team.TeamUrl != meta.URL { + team.TeamUrl = meta.URL + changed = true + } + if changed { + team.UpdateBridgeInfo() + team.Upsert() + } + return +} + +func (team *Team) UpdateName(meta *slack.TeamInfo) (changed bool) { + if team.TeamName != meta.Name { + team.log.Debugfln("Updating name %q -> %q", team.TeamName, meta.Name) + team.TeamName = meta.Name + changed = true + } + if team.SpaceRoom != "" { + _, err := team.bridge.Bot.SetRoomName(team.SpaceRoom, team.TeamName) + if err != nil { + team.log.Warnln("Failed to update room name: %s", err) + } else { + team.NameSet = true + } + } + return +} + +func (team *Team) UpdateAvatar(meta *slack.TeamInfo) (changed bool) { + if meta.Icon["image_default"] != nil && meta.Icon["image_default"] == true && team.Avatar != "" { + team.Avatar = "" + team.AvatarUrl = id.MustParseContentURI("") + changed = true + } else if meta.Icon["image_default"] != nil && meta.Icon["image_default"] == false && meta.Icon["image_230"] != nil && team.Avatar != meta.Icon["image_230"] { + avatar, err := uploadPlainFile(team.bridge.AS.BotIntent(), meta.Icon["image_230"].(string)) + if err != nil { + team.log.Warnfln("Error uploading new team avatar for team %s: %v", team.TeamID, err) + } else { + team.Avatar = meta.Icon["image_230"].(string) + team.AvatarUrl = avatar + changed = true + } + } + if team.SpaceRoom != "" { + _, err := team.bridge.Bot.SetRoomAvatar(team.SpaceRoom, team.AvatarUrl) + if err != nil { + team.log.Warnln("Failed to update room avatar:", err) + } else { + team.AvatarSet = true + } + } + return +} + +func (team *Team) cleanup() { + if team.SpaceRoom == "" { + return + } + intent := team.bridge.Bot + if team.bridge.SpecVersions.UnstableFeatures["com.beeper.room_yeeting"] { + err := intent.BeeperDeleteRoom(team.SpaceRoom) + if err == nil || errors.Is(err, mautrix.MNotFound) { + return + } + team.log.Warnfln("Failed to delete %s using hungryserv yeet endpoint, falling back to normal behavior: %v", team.SpaceRoom, err) + } + team.bridge.cleanupRoom(intent, team.SpaceRoom, false, team.log) +} + +func (team *Team) RemoveMXID() { + team.bridge.teamsLock.Lock() + defer team.bridge.teamsLock.Unlock() + if team.SpaceRoom == "" { + return + } + delete(team.bridge.teamsByMXID, team.SpaceRoom) + team.SpaceRoom = "" + team.AvatarSet = false + team.NameSet = false + team.Upsert() +} + +// func (team *Team) Delete() { +// team.TeamInfo.Delete() +// team.bridge.teamsLock.Lock() +// delete(team.bridge.teamsByID, team.TeamID) +// if team.SpaceRoom != "" { +// delete(team.bridge.teamsByMXID, team.SpaceRoom) +// } +// team.bridge.teamsLock.Unlock() + +// } + +func (team *Team) addPortalToTeam(portal *database.Portal, isInSpace bool) bool { + if len(team.SpaceRoom) == 0 { + team.log.Errorln("Tried to add portal to space that has no matrix ID") + return false + } + + if len(portal.MXID) > 0 && !isInSpace { + _, err := team.bridge.Bot.SendStateEvent(team.SpaceRoom, event.StateSpaceChild, portal.MXID.String(), &event.SpaceChildEventContent{ + Via: []string{team.bridge.AS.HomeserverDomain}, + }) + if err != nil { + team.log.Errorfln("Failed to add portal %s to team space", portal.MXID) + } else { + isInSpace = true + } + } + + return isInSpace +} diff --git a/user.go b/user.go index fd89077..66d6f84 100644 --- a/user.go +++ b/user.go @@ -56,6 +56,8 @@ type User struct { BridgeStates map[string]*bridge.BridgeStateQueue PermissionLevel bridgeconfig.PermissionLevel + + spaceCreateLock sync.Mutex } func (user *User) GetPermissionLevel() bridgeconfig.PermissionLevel { @@ -555,7 +557,7 @@ func (user *User) isChannelOrOpenIM(channel *slack.Channel) bool { } } -func (user *User) SyncPortals(userTeam *database.UserTeam, force bool) error { +func (user *User) SyncPortals(team *Team, userTeam *database.UserTeam, force bool) error { channelInfo := map[string]slack.Channel{} if !strings.HasPrefix(userTeam.Token, "xoxs") { @@ -614,6 +616,8 @@ func (user *User) SyncPortals(userTeam *database.UserTeam, force bool) error { } else { portal.CreateMatrixRoom(user, userTeam, &channel, true) } + portal.InSpace = team.addPortalToTeam(portal.Portal, portal.InSpace) + portal.Update(nil) // Delete already handled ones from the map delete(channelInfo, dbPortal.Key.ChannelID) } @@ -628,6 +632,8 @@ func (user *User) SyncPortals(userTeam *database.UserTeam, force bool) error { } else { portal.CreateMatrixRoom(user, userTeam, &channel, true) } + portal.InSpace = team.addPortalToTeam(portal.Portal, portal.InSpace) + portal.Update(nil) } return nil @@ -635,46 +641,47 @@ func (user *User) SyncPortals(userTeam *database.UserTeam, force bool) error { func (user *User) UpdateTeam(userTeam *database.UserTeam, force bool) error { user.log.Debugfln("Updating team info for team %s", userTeam.Key.TeamID) - currentTeamInfo := user.bridge.DB.TeamInfo.GetBySlackTeam(userTeam.Key.TeamID) - if currentTeamInfo == nil { - currentTeamInfo = user.bridge.DB.TeamInfo.New() - currentTeamInfo.TeamID = userTeam.Key.TeamID - } + team := user.bridge.GetTeamByID(userTeam.Key.TeamID, true) teamInfo, err := userTeam.Client.GetTeamInfo() if err != nil { - user.log.Errorfln("Error fetching info for team %s: %v", userTeam.Key.TeamID, err) - return err + user.log.Errorln("Failed to fetch team info ", userTeam.Key.TeamID) + } + + var changed bool + if team.SpaceRoom == "" { + err = team.CreateMatrixRoom(user, teamInfo) } - changed := false - if currentTeamInfo.TeamName != teamInfo.Name { - currentTeamInfo.TeamName = teamInfo.Name + if team.TeamName != teamInfo.Name { + team.TeamName = teamInfo.Name changed = true } - if currentTeamInfo.TeamDomain != teamInfo.Domain { - currentTeamInfo.TeamDomain = teamInfo.Domain + if team.TeamDomain != teamInfo.Domain { + team.TeamDomain = teamInfo.Domain changed = true } - if currentTeamInfo.TeamUrl != teamInfo.URL { - currentTeamInfo.TeamUrl = teamInfo.URL + if team.TeamUrl != teamInfo.URL { + team.TeamUrl = teamInfo.URL changed = true } - if teamInfo.Icon["image_default"] != nil && teamInfo.Icon["image_default"] == true && currentTeamInfo.Avatar != "" { - currentTeamInfo.Avatar = "" - currentTeamInfo.AvatarUrl = id.MustParseContentURI("") + if teamInfo.Icon["image_default"] != nil && teamInfo.Icon["image_default"] == true && team.Avatar != "" { + team.Avatar = "" + team.AvatarUrl = id.MustParseContentURI("") changed = true - } else if teamInfo.Icon["image_default"] != nil && teamInfo.Icon["image_default"] == false && teamInfo.Icon["image_230"] != nil && currentTeamInfo.Avatar != teamInfo.Icon["image_230"] { + } else if teamInfo.Icon["image_default"] != nil && teamInfo.Icon["image_default"] == false && teamInfo.Icon["image_230"] != nil && team.Avatar != teamInfo.Icon["image_230"] { avatar, err := uploadPlainFile(user.bridge.AS.BotIntent(), teamInfo.Icon["image_230"].(string)) if err != nil { user.log.Warnfln("Error uploading new team avatar for team %s: %v", userTeam.Key.TeamID, err) } else { - currentTeamInfo.Avatar = teamInfo.Icon["image_230"].(string) - currentTeamInfo.AvatarUrl = avatar + team.Avatar = teamInfo.Icon["image_230"].(string) + team.AvatarUrl = avatar changed = true } + } else { + changed = team.UpdateInfo(user, teamInfo) } - currentTeamInfo.Upsert() + team.Upsert() emojis, err := userTeam.Client.GetEmoji() if err != nil { @@ -695,7 +702,12 @@ func (user *User) UpdateTeam(userTeam *database.UserTeam, force bool) error { for _, puppet := range puppets { puppet.UpdateInfo(userTeam, false, nil) } - return user.SyncPortals(userTeam, changed || force) + + inSpace := user.addTeamToSpace(team.TeamInfo, userTeam.InSpace) + userTeam.InSpace = inSpace + userTeam.Upsert() + + return user.SyncPortals(team, userTeam, changed || force) } func (user *User) Connect() error { @@ -909,3 +921,91 @@ func (user *User) updateChatMute(portal *Portal, muted bool) { user.log.Warnfln("Failed to update push rule for %s through double puppet: %v", portal.MXID, err) } } + +func (user *User) getSpaceRoom(ptr *id.RoomID, name, topic string, parent id.RoomID) id.RoomID { + if len(*ptr) > 0 { + return *ptr + } + user.spaceCreateLock.Lock() + defer user.spaceCreateLock.Unlock() + if len(*ptr) > 0 { + return *ptr + } + + initialState := []*event.Event{{ + Type: event.StateRoomAvatar, + Content: event.Content{ + Parsed: &event.RoomAvatarEventContent{ + URL: user.bridge.Config.AppService.Bot.ParsedAvatar, + }, + }, + }} + + if parent != "" { + parentIDStr := parent.String() + initialState = append(initialState, &event.Event{ + Type: event.StateSpaceParent, + StateKey: &parentIDStr, + Content: event.Content{ + Parsed: &event.SpaceParentEventContent{ + Canonical: true, + Via: []string{user.bridge.AS.HomeserverDomain}, + }, + }, + }) + } + + resp, err := user.bridge.Bot.CreateRoom(&mautrix.ReqCreateRoom{ + Visibility: "private", + Name: name, + Topic: topic, + InitialState: initialState, + CreationContent: map[string]interface{}{ + "type": event.RoomTypeSpace, + }, + PowerLevelOverride: &event.PowerLevelsEventContent{ + Users: map[id.UserID]int{ + user.bridge.Bot.UserID: 9001, + user.MXID: 50, + }, + }, + }) + + if err != nil { + user.log.Error("Failed to auto-create space room") + } else { + *ptr = resp.RoomID + user.Update() + user.ensureInvited(nil, *ptr, false) + + if parent != "" { + _, err = user.bridge.Bot.SendStateEvent(parent, event.StateSpaceChild, resp.RoomID.String(), &event.SpaceChildEventContent{ + Via: []string{user.bridge.AS.HomeserverDomain}, + Order: " 0000", + }) + if err != nil { + user.log.Error("Failed to add created space room to parent space") + } + } + } + return *ptr +} + +func (user *User) GetSpaceRoom() id.RoomID { + return user.getSpaceRoom(&user.SpaceRoom, "Slack", "Your Slack bridged chats", "") +} + +func (user *User) addTeamToSpace(teamInfo *database.TeamInfo, isInSpace bool) bool { + if len(teamInfo.SpaceRoom) > 0 && !isInSpace { + _, err := user.bridge.Bot.SendStateEvent(user.GetSpaceRoom(), event.StateSpaceChild, teamInfo.SpaceRoom.String(), &event.SpaceChildEventContent{ + Via: []string{user.bridge.AS.HomeserverDomain}, + }) + if err != nil { + user.log.Errorfln("Failed to add team space %s to user space", teamInfo.SpaceRoom) + } else { + isInSpace = true + } + } + + return isInSpace +}