From 77fac74c5cf13df42b9253fb091d09636f61146b Mon Sep 17 00:00:00 2001 From: derekwin Date: Tue, 21 Feb 2023 14:30:39 +0800 Subject: [PATCH] add sqlInjectSafe; add respUsersFavoriteCountMap --- README.md | 2 ++ biz/handler/core/feed_server.go | 2 -- kitex_server/interaction.go | 44 ++++++++++++++++++++++++++++++--- kitex_server/userservice.go | 20 ++++++++++++--- kitex_server/videoservice.go | 27 +++++++++++--------- tools/safe/sqlsafe.go | 17 +++++++++++++ tools/safe/sqlsafe_test.go | 16 ++++++++++++ 7 files changed, 108 insertions(+), 20 deletions(-) create mode 100644 tools/safe/sqlsafe.go create mode 100644 tools/safe/sqlsafe_test.go diff --git a/README.md b/README.md index 86e649a..8aa956d 100755 --- a/README.md +++ b/README.md @@ -29,6 +29,8 @@ go get github.com/ClubWeGo/usermicro@latest go get github.com/ClubWeGo/relationmicro@latest +go get github.com/ClubWeGo/favoritemicro@latest + # 说明 diff --git a/biz/handler/core/feed_server.go b/biz/handler/core/feed_server.go index d573193..cabf58c 100755 --- a/biz/handler/core/feed_server.go +++ b/biz/handler/core/feed_server.go @@ -4,7 +4,6 @@ package core import ( "context" - "log" "time" core "github.com/ClubWeGo/douyin/biz/model/core" @@ -53,7 +52,6 @@ func FeedMethod(ctx context.Context, c *app.RequestContext) { // 缓存未命中,去后端调api resultList, nextTime, err := kitex_server.GetFeed(latestTime, currentUserId, 30) if err != nil { - log.Println(err) resp.StatusCode = 1 resp.StatusMsg = &msgFailed c.JSON(consts.StatusOK, resp) diff --git a/kitex_server/interaction.go b/kitex_server/interaction.go index 89ab71b..663b7aa 100644 --- a/kitex_server/interaction.go +++ b/kitex_server/interaction.go @@ -2,6 +2,7 @@ package kitex_server import ( "context" + "sync" "github.com/ClubWeGo/douyin/biz/model/interaction" "github.com/ClubWeGo/douyin/tools/errno" @@ -81,7 +82,44 @@ func CountUserFavorite(ctx context.Context, uid int64) (int64, int64, error) { return res.FavoriteCount, res.FavoritedCount, nil } -// TODO : 传入userId切片,批量查询user对应的favorite, total_favorited -func GetFavoriteCountByUserIdSet(idSet []int64) (favoriteSet, favoritedSet []int64, err error) { - return []int64{}, []int64{}, nil +// 传入userId切片,批量查询user对应的favorite, total_favorited +// map[int64][]int64 [FavoriteCount FavoritedCount] +func GetUsersFavoriteCountMap(idSet []int64, respUsersFavoriteCountMap chan map[int64][]int64, wg *sync.WaitGroup, errChan chan error) { + defer wg.Done() + + res, err := FavoriteClient.UsersFavoriteCountMethod(context.Background(), &favorite.UsersFavoriteCountReq{ + UserIdList: idSet, + }) + if err != nil { + respUsersFavoriteCountMap <- map[int64][]int64{} + errChan <- err + return + } + respUsersFavoriteCountMap <- res.FavoriteCountMap + errChan <- nil +} + +// 传入videoId切片,批量查询video对应的favorite, favorited +// map[int64]int64 FavoriteCount +func GetVideosFavoriteCountMap(idSet []int64, respVideosFavoriteCountMap chan map[int64]int64, wg *sync.WaitGroup, errChan chan error) { + defer wg.Done() + + res, err := FavoriteClient.VideosFavoriteCountMethod(context.Background(), &favorite.VideosFavoriteCountReq{ + VideoIdList: idSet, + }) + if err != nil { + respVideosFavoriteCountMap <- map[int64]int64{} + errChan <- err + return + } + respVideosFavoriteCountMap <- res.FavoriteCountMap + errChan <- nil +} + +// 传入videoId切片和当前用户id,批量查询喜欢情况 +func GetIsFavoriteMap() (idSet []int64, currentUser int64, respIsFavoriteMap chan map[int64]bool, wg *sync.WaitGroup, errChan chan error) { + defer wg.Done() + + // res, err := FavoriteClient.FavoriteRelationMethod(context.Background(), &favorite.FavoriteRelationReq{}) + return } diff --git a/kitex_server/userservice.go b/kitex_server/userservice.go index a2b2fbe..c7aa569 100755 --- a/kitex_server/userservice.go +++ b/kitex_server/userservice.go @@ -100,7 +100,13 @@ func GetUserLatestMap(idSet []int64, currentUser int64, respUserMap chan map[int wgUser.Add(1) go GetRelationMap(idSet, currentUser, respRelationMap, wgUser, respRelationMapError) - // TODO : TotalFavourited, FavoriteCount,传入查询的userId切片,查对应这两个字段的切片,(结果需要携带UserId):从favorite服务 + // 批量查询TotalFavourited, FavoriteCount,传入查询的userId切片 + respUsersFavoriteCountMap := make(chan map[int64][]int64, 1) // [FavoriteCount FavoritedCount] + defer close(respRelationMap) + respUsersFavoriteCountMapError := make(chan error, 1) + defer close(respUsersFavoriteCountMapError) + wgUser.Add(1) + go GetUsersFavoriteCountMap(idSet, respUsersFavoriteCountMap, wgUser, respUsersFavoriteCountMapError) // 等待数据 wgUser.Wait() @@ -124,6 +130,12 @@ func GetUserLatestMap(idSet []int64, currentUser int64, respUserMap chan map[int if err != nil { errSlice = append(errSlice, err) } + + FavoriteCountMap := <-respUsersFavoriteCountMap + err = <-respUsersFavoriteCountMapError + if err != nil { + errSlice = append(errSlice, err) + } // TODO: 其他协程的错误处理 errChan <- errSlice // 错误切片 @@ -139,9 +151,9 @@ func GetUserLatestMap(idSet []int64, currentUser int64, respUserMap chan map[int Avatar: user.Avatar, BackgroundImage: user.BackgroundImage, Signature: user.Signature, - TotalFavourited: "", // TODO: 从获取的数据中拿 - WorkCount: VideoCountMap[id].Count, // 最新的count数据 - FavoriteCount: 0, // TODO: 从获取的数据中拿 + TotalFavourited: strconv.FormatInt(FavoriteCountMap[id][1], 10), // TODO: 从获取的数据中拿 + WorkCount: VideoCountMap[id].Count, // 最新的count数据 + FavoriteCount: FavoriteCountMap[id][0], // TODO: 从获取的数据中拿 } } diff --git a/kitex_server/videoservice.go b/kitex_server/videoservice.go index 839d370..e7cd75d 100755 --- a/kitex_server/videoservice.go +++ b/kitex_server/videoservice.go @@ -52,22 +52,27 @@ func GetVideoLatestMap(idSet []int64, currentUser int64, respVideoMap chan map[i wgVideo := &sync.WaitGroup{} // 本函数子协程的wg // 批量查询视频的 被喜欢数 ,传入视频id的切片,返回对应的FavoriteCount的切片(需携带对应视频id) 从Favorite服务 + respVideosFavoriteCountMap := make(chan map[int64]int64, 1) + defer close(respVideosFavoriteCountMap) + respVideosFavoriteCountMapError := make(chan error, 1) + defer close(respVideosFavoriteCountMapError) + wgVideo.Add(1) + go GetVideosFavoriteCountMap(idSet, respVideosFavoriteCountMap, wgVideo, respVideosFavoriteCountMapError) // 批量查询视频的评论数,传入视频id的切片,返回对应的评论数(需携带对应视频id),从comment服务 // 批量查询 is_favorite, 传入目标视频id切片和currentUser查is_favorite的切片(结果需要携带视频id,douyin里后续需要转成map):从favorite; + GetIsFavoriteMap() // 等待数据 wgVideo.Wait() - // // 处理协程错误 - var errSlice = []error{} // 防止外部设置的chan缓存不够造成阻塞,要求外部设置长度为1的error切片类型 - // err := <-respAuthorMapError - // if err != nil { - // errSlice = append(errSlice, err) - // } - - // // TODO: 其他协程的错误处理 + var errSlice = []error{} + VideosFavoriteCountMap := <-respVideosFavoriteCountMap + err := <-respVideosFavoriteCountMapError + if err != nil { + errSlice = append(errSlice, err) + } errChan <- errSlice // 记录错误的切片,至少应该返回一个空切片,否则chan会阻塞 @@ -75,9 +80,9 @@ func GetVideoLatestMap(idSet []int64, currentUser int64, respVideoMap chan map[i videoLatestMap := make(map[int64]core.Video, len(idSet)) // 视频切片的id是没有重复的 for _, id := range idSet { videoLatestMap[id] = core.Video{ // 视频id对应的Video存储查到的关键字段 - FavoriteCount: 0, // TODO:从拿到的MAP数据更新 - CommentCount: 0, // TODO:从拿到的MAP数据更新 - IsFavorite: false, // TODO:从拿到的MAP数据更新 + FavoriteCount: VideosFavoriteCountMap[id], // TODO:从拿到的MAP数据更新 + CommentCount: 0, // TODO:从拿到的MAP数据更新 + IsFavorite: false, // TODO:从拿到的MAP数据更新 } } respVideoMap <- videoLatestMap // 返回数据 diff --git a/tools/safe/sqlsafe.go b/tools/safe/sqlsafe.go new file mode 100644 index 0000000..b9715ad --- /dev/null +++ b/tools/safe/sqlsafe.go @@ -0,0 +1,17 @@ +package safe + +import ( + "errors" + "regexp" +) + +// ref: https://blog.csdn.net/qq_40127376/article/details/108516561 +var sqlInjectReg = regexp.MustCompile(`(.*\=.*\-\-.*)|(.*(\+|\-).*)|(.*\w+(%|\$|#|&)\w+.*)|(.*\|\|.*)|(.*\s+(and|or)\s+.*)|(.*\b(select|update|union|and|or|delete|insert|trancate|char|into|substr|ascii|declare|exec|count|master|into|drop|execute)\b.*)`) + +func SqlInjectCheck(input string) error { + reg := sqlInjectReg.FindAllString(input, 1) // 匹配一个就行 + if reg != nil { + return errors.New("输入存在非法字段") + } + return nil +} diff --git a/tools/safe/sqlsafe_test.go b/tools/safe/sqlsafe_test.go new file mode 100644 index 0000000..6426cb3 --- /dev/null +++ b/tools/safe/sqlsafe_test.go @@ -0,0 +1,16 @@ +package safe + +import ( + "log" + "testing" +) + +func TestSqlInjectCheck(t *testing.T) { + str1 := "select 1" + err := SqlInjectCheck(str1) + if err != nil { + log.Println(err) + return + } + log.Println("no") +}