diff --git a/internal/crawlerdetect/baidu_strategy.go b/internal/crawlerdetect/baidu_strategy.go new file mode 100644 index 0000000..08a5399 --- /dev/null +++ b/internal/crawlerdetect/baidu_strategy.go @@ -0,0 +1,25 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crawlerdetect + +type BaiduStrategy struct { + *UniversalStrategy +} + +func NewBaiduStrategy() *BaiduStrategy { + return &BaiduStrategy{ + UniversalStrategy: NewUniversalStrategy([]string{"baidu.com", "baidu.jp"}), + } +} diff --git a/internal/crawlerdetect/baidu_strategy_test.go b/internal/crawlerdetect/baidu_strategy_test.go new file mode 100644 index 0000000..e7f7d9a --- /dev/null +++ b/internal/crawlerdetect/baidu_strategy_test.go @@ -0,0 +1,67 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crawlerdetect + +import ( + "errors" + "log" + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBaiduStrategy(t *testing.T) { + s := NewBaiduStrategy() + require.NotNil(t, s) + testCases := []struct { + name string + ip string + matched bool + errFunc require.ErrorAssertionFunc + }{ + { + name: "无效 ip", + ip: "256.0.0.0", + matched: false, + errFunc: func(t require.TestingT, err error, i ...interface{}) { + var dnsError *net.DNSError + if !errors.As(err, &dnsError) { + log.Fatal(err) + } + }, + }, + { + name: "非百度 ip", + ip: "166.249.90.77", + matched: false, + }, + { + name: "百度 ip", + ip: "111.206.198.69", + matched: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + m, err := s.CheckCrawler(tc.ip) + if err != nil { + tc.errFunc(t, err) + } + require.Equal(t, tc.matched, m) + }) + } +} diff --git a/internal/crawlerdetect/bing_strategy.go b/internal/crawlerdetect/bing_strategy.go new file mode 100644 index 0000000..be083ef --- /dev/null +++ b/internal/crawlerdetect/bing_strategy.go @@ -0,0 +1,25 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crawlerdetect + +type BingStrategy struct { + *UniversalStrategy +} + +func NewBingStrategy() *BingStrategy { + return &BingStrategy{ + UniversalStrategy: NewUniversalStrategy([]string{"search.msn.com"}), + } +} diff --git a/internal/crawlerdetect/bing_strategy_test.go b/internal/crawlerdetect/bing_strategy_test.go new file mode 100644 index 0000000..5b8476a --- /dev/null +++ b/internal/crawlerdetect/bing_strategy_test.go @@ -0,0 +1,67 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crawlerdetect + +import ( + "errors" + "log" + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBingStrategy(t *testing.T) { + s := NewBingStrategy() + require.NotNil(t, s) + testCases := []struct { + name string + ip string + matched bool + errFunc require.ErrorAssertionFunc + }{ + { + name: "无效 ip", + ip: "256.0.0.0", + matched: false, + errFunc: func(t require.TestingT, err error, i ...interface{}) { + var dnsError *net.DNSError + if !errors.As(err, &dnsError) { + log.Fatal(err) + } + }, + }, + { + name: "非必应 ip", + ip: "166.249.90.77", + matched: false, + }, + { + name: "必应 ip", + ip: "157.55.39.1", + matched: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + m, err := s.CheckCrawler(tc.ip) + if err != nil { + tc.errFunc(t, err) + } + require.Equal(t, tc.matched, m) + }) + } +} diff --git a/internal/crawlerdetect/crawler_detector.go b/internal/crawlerdetect/crawler_detector.go new file mode 100644 index 0000000..0a9b323 --- /dev/null +++ b/internal/crawlerdetect/crawler_detector.go @@ -0,0 +1,93 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crawlerdetect + +import ( + "net" + "slices" + "strings" +) + +const ( + Baidu = "baidu" + Bing = "bing" + Google = "google" + Sogou = "sogou" +) + +var strategyMap = map[string]Strategy{ + Baidu: NewBaiduStrategy(), + Bing: NewBingStrategy(), + Google: NewGoogleStrategy(), + Sogou: NewSoGouStrategy(), +} + +type Strategy interface { + CheckCrawler(ip string) (bool, error) +} + +type UniversalStrategy struct { + Hosts []string +} + +func NewUniversalStrategy(hosts []string) *UniversalStrategy { + return &UniversalStrategy{ + Hosts: hosts, + } +} + +func (s *UniversalStrategy) CheckCrawler(ip string) (bool, error) { + names, err := net.LookupAddr(ip) + if err != nil { + return false, err + } + if len(names) == 0 { + return false, nil + } + + name, matched := s.matchHost(names) + if !matched { + return false, nil + } + + ips, err := net.LookupIP(name) + if err != nil { + return false, err + } + if slices.ContainsFunc(ips, func(netIp net.IP) bool { + return netIp.String() == ip + }) { + return true, nil + } + + return false, nil +} + +func (s *UniversalStrategy) matchHost(names []string) (string, bool) { + var matchedName string + return matchedName, slices.ContainsFunc(s.Hosts, func(host string) bool { + return slices.ContainsFunc(names, func(name string) bool { + if strings.Contains(name, host) { + matchedName = name + return true + } + return false + }) + }) +} + +func NewCrawlerDetector(crawler string) Strategy { + return strategyMap[crawler] +} diff --git a/internal/crawlerdetect/google_strategy.go b/internal/crawlerdetect/google_strategy.go new file mode 100644 index 0000000..e8cecd0 --- /dev/null +++ b/internal/crawlerdetect/google_strategy.go @@ -0,0 +1,25 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crawlerdetect + +type GoogleStrategy struct { + *UniversalStrategy +} + +func NewGoogleStrategy() *GoogleStrategy { + return &GoogleStrategy{ + UniversalStrategy: NewUniversalStrategy([]string{"googlebot.com", "google.com", "googleusercontent.com"}), + } +} diff --git a/internal/crawlerdetect/google_strategy_test.go b/internal/crawlerdetect/google_strategy_test.go new file mode 100644 index 0000000..8556f2d --- /dev/null +++ b/internal/crawlerdetect/google_strategy_test.go @@ -0,0 +1,67 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crawlerdetect + +import ( + "errors" + "log" + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGoogleStrategy(t *testing.T) { + s := NewGoogleStrategy() + require.NotNil(t, s) + testCases := []struct { + name string + ip string + matched bool + errFunc require.ErrorAssertionFunc + }{ + { + name: "无效 ip", + ip: "256.0.0.0", + matched: false, + errFunc: func(t require.TestingT, err error, i ...interface{}) { + var dnsError *net.DNSError + if !errors.As(err, &dnsError) { + log.Fatal(err) + } + }, + }, + { + name: "非谷歌 ip", + ip: "166.249.90.77", + matched: false, + }, + { + name: "谷歌 ip", + ip: "66.249.90.77", + matched: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + m, err := s.CheckCrawler(tc.ip) + if err != nil { + tc.errFunc(t, err) + } + require.Equal(t, tc.matched, m) + }) + } +} diff --git a/internal/crawlerdetect/sogou_strategy.go b/internal/crawlerdetect/sogou_strategy.go new file mode 100644 index 0000000..202d093 --- /dev/null +++ b/internal/crawlerdetect/sogou_strategy.go @@ -0,0 +1,50 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crawlerdetect + +import ( + "net" + "slices" + "strings" +) + +type SoGouStrategy struct { + Hosts []string +} + +func NewSoGouStrategy() *SoGouStrategy { + return &SoGouStrategy{ + Hosts: []string{"sogou.com"}, + } +} + +func (s *SoGouStrategy) CheckCrawler(ip string) (bool, error) { + names, err := net.LookupAddr(ip) + if err != nil { + return false, err + } + if len(names) == 0 { + return false, nil + } + return s.matchHost(names), nil +} + +func (s *SoGouStrategy) matchHost(names []string) bool { + return slices.ContainsFunc(s.Hosts, func(host string) bool { + return slices.ContainsFunc(names, func(name string) bool { + return strings.Contains(name, host) + }) + }) +} diff --git a/internal/crawlerdetect/sogou_strategy_test.go b/internal/crawlerdetect/sogou_strategy_test.go new file mode 100644 index 0000000..16d9ed5 --- /dev/null +++ b/internal/crawlerdetect/sogou_strategy_test.go @@ -0,0 +1,67 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crawlerdetect + +import ( + "errors" + "log" + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSoGouStrategy(t *testing.T) { + s := NewSoGouStrategy() + require.NotNil(t, s) + testCases := []struct { + name string + ip string + matched bool + errFunc require.ErrorAssertionFunc + }{ + { + name: "无效 ip", + ip: "256.0.0.0", + matched: false, + errFunc: func(t require.TestingT, err error, i ...interface{}) { + var dnsError *net.DNSError + if !errors.As(err, &dnsError) { + log.Fatal(err) + } + }, + }, + { + name: "非搜狗 ip", + ip: "166.249.90.77", + matched: false, + }, + { + name: "搜狗 ip", + ip: "123.126.113.110", + matched: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + m, err := s.CheckCrawler(tc.ip) + if err != nil { + tc.errFunc(t, err) + } + require.Equal(t, tc.matched, m) + }) + } +} diff --git a/internal/ratelimit/mocks/ratelimit.mock.go b/internal/ratelimit/mocks/ratelimit.mock.go index 01883ae..b7b3b43 100644 --- a/internal/ratelimit/mocks/ratelimit.mock.go +++ b/internal/ratelimit/mocks/ratelimit.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: types.go +// Source: crawler_detector.go // Package limitmocks is a generated GoMock package. package limitmocks diff --git a/middlewares/crawlerdetect/builder.go b/middlewares/crawlerdetect/builder.go new file mode 100644 index 0000000..e02c5fa --- /dev/null +++ b/middlewares/crawlerdetect/builder.go @@ -0,0 +1,118 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crawlerdetect + +import ( + "log/slog" + "net/http" + "strings" + + "github.com/ecodeclub/ginx/internal/crawlerdetect" + "github.com/gin-gonic/gin" +) + +const Baidu = crawlerdetect.Baidu +const Bing = crawlerdetect.Bing +const Google = crawlerdetect.Google +const Sogou = crawlerdetect.Sogou + +type Builder struct { + crawlersMap map[string]string +} + +func NewBuilder() *Builder { + return &Builder{ + // 常用的 User-Agent 映射 + crawlersMap: map[string]string{ + "Baiduspider": Baidu, + "Baiduspider-render": Baidu, + + "bingbot": Bing, + "adidxbot": Bing, + "MicrosoftPreview": Bing, + + "Googlebot": Google, + "Googlebot-Image": Google, + "Googlebot-News": Google, + "Googlebot-Video": Google, + "Storebot-Google": Google, + "Google-InspectionTool": Google, + "GoogleOther": Google, + "Google-Extended": Google, + + "Sogou web spider": Sogou, + }, + } +} + +// AddUserAgent 添加 user-agent 映射 +// 例如: +// +// map[string][]string{ +// crawlerdetect.Baidu: []string{"NewBaiduUserAgent"}, +// crawlerdetect.Bing: []string{"NewBingUserAgent"}, +// } +func (b *Builder) AddUserAgent(userAgents map[string][]string) *Builder { + for crawler, values := range userAgents { + for _, userAgent := range values { + b.crawlersMap[userAgent] = crawler + } + } + return b +} + +func (b *Builder) RemoveUserAgent(userAgents ...string) *Builder { + for _, userAgent := range userAgents { + delete(b.crawlersMap, userAgent) + } + return b +} + +func (b *Builder) Build() gin.HandlerFunc { + return func(ctx *gin.Context) { + userAgent := ctx.GetHeader("User-Agent") + ip := ctx.ClientIP() + if ip == "" { + slog.ErrorContext(ctx, "crawlerdetect", "error", "ip is empty.") + ctx.AbortWithStatus(http.StatusForbidden) + return + } + crawlerDetector := b.getCrawlerDetector(userAgent) + if crawlerDetector == nil { + ctx.AbortWithStatus(http.StatusForbidden) + return + } + pass, err := crawlerDetector.CheckCrawler(ip) + if err != nil { + slog.ErrorContext(ctx, "crawlerdetect", "error", err.Error()) + ctx.AbortWithStatus(http.StatusInternalServerError) + return + } + if !pass { + ctx.AbortWithStatus(http.StatusForbidden) + return + } + ctx.Next() + } +} + +func (b *Builder) getCrawlerDetector(userAgent string) crawlerdetect.Strategy { + for key, value := range b.crawlersMap { + if strings.Contains(userAgent, key) { + return crawlerdetect.NewCrawlerDetector(value) + } + } + return nil +} diff --git a/middlewares/crawlerdetect/builder_test.go b/middlewares/crawlerdetect/builder_test.go new file mode 100644 index 0000000..51b5fcd --- /dev/null +++ b/middlewares/crawlerdetect/builder_test.go @@ -0,0 +1,274 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crawlerdetect + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func Test_Builder(t *testing.T) { + testCases := []struct { + name string + + reqBuilder func(t *testing.T) *http.Request + + wantCode int + }{ + { + name: "空 ip", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + return req + }, + wantCode: 403, + }, + { + name: "无效 ip", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("X-Forwarded-For", "256.0.0.0") + req.Header.Set("User-Agent", "Mozilla/5.0 (compatible; Baiduspider/2.0; +http://www.baidu.com/search/spider.html)") + return req + }, + wantCode: 500, + }, + { + name: "用户", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("X-Forwarded-For", "155.206.198.69") + req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/89.0.4389.82 Safari/537.36") + return req + }, + wantCode: 403, + }, + { + name: "百度 - Baiduspider", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 (compatible; Baiduspider/2.0; +http://www.baidu.com/search/spider.html)") + req.Header.Set("X-Forwarded-For", "111.206.198.69") + return req + }, + wantCode: 200, + }, + { + name: "百度 - Baiduspider-render", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 (iPhone;CPU iPhone OS 9_1 like Mac OS X) AppleWebKit/601.1.46 (KHTML, like Gecko)Version/9.0 Mobile/13B143 Safari/601.1 (compatible; Baiduspider-render/2.0;Smartapp; +http://www.baidu.com/search/spider.html)") + req.Header.Set("X-Forwarded-For", "111.206.198.69") + return req + }, + wantCode: 200, + }, + { + name: "必应 - bingbot", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 AppleWebKit/537.36 (KHTML, like Gecko; compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm) Chrome/") + req.Header.Set("X-Forwarded-For", "157.55.39.1") + return req + }, + wantCode: 200, + }, + { + name: "必应 - adidxbot", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 (Windows Phone 8.1; ARM; Trident/7.0; Touch; rv:11.0; IEMobile/11.0; NOKIA; Lumia 530) like Gecko (compatible; adidxbot/2.0; +http://www.bing.com/bingbot.htm)") + req.Header.Set("X-Forwarded-For", "157.55.39.1") + return req + }, + wantCode: 200, + }, + { + name: "必应 - MicrosoftPreview", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 AppleWebKit/537.36 (KHTML, like Gecko; compatible; MicrosoftPreview/2.0; +https://aka.ms/MicrosoftPreview) Chrome/W.X.Y.Z Safari/537.36") + req.Header.Set("X-Forwarded-For", "157.55.39.1") + return req + }, + wantCode: 200, + }, + { + name: "谷歌 - Googlebot", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Mobile Safari/537.36 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)") + req.Header.Set("X-Forwarded-For", "66.249.66.1") + return req + }, + wantCode: 200, + }, + { + name: "谷歌 - Googlebot-Image", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Mobile Safari/537.36 (compatible; Googlebot-Image/1.0; +http://www.google.com/bot.html)") + req.Header.Set("X-Forwarded-For", "35.247.243.240") + return req + }, + wantCode: 200, + }, + { + name: "谷歌 - Googlebot-News", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Mobile Safari/537.36 (compatible; Googlebot-News/1.0; +http://www.google.com/bot.html)") + req.Header.Set("X-Forwarded-For", "66.249.90.77") + return req + }, + wantCode: 200, + }, + { + name: "谷歌 - Storebot-Google", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64; Storebot-Google/1.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Safari/537.36") + req.Header.Set("X-Forwarded-For", "66.249.66.1") + return req + }, + wantCode: 200, + }, + { + name: "谷歌 - Google-InspectionTool", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Mobile Safari/537.36 (compatible; Google-InspectionTool/1.0;)") + req.Header.Set("X-Forwarded-For", "66.249.66.1") + return req + }, + wantCode: 200, + }, + { + name: "谷歌 - GoogleOther", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64; GoogleOther/1.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Safari/537.36") + req.Header.Set("X-Forwarded-For", "66.249.66.1") + return req + }, + wantCode: 200, + }, + { + name: "谷歌 - Google-Extended", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64; Google-Extended/1.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Safari/537.36") + req.Header.Set("X-Forwarded-For", "66.249.66.1") + return req + }, + wantCode: 200, + }, + { + name: "搜狗 - Sogou web spider", + reqBuilder: func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "/test", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64; Google-Extended/1.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Safari/537.36") + req.Header.Set("X-Forwarded-For", "66.249.66.1") + return req + }, + wantCode: 200, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server := gin.Default() + server.TrustedPlatform = "X-Forwarded-For" + server.Use(NewBuilder().Build()) + server.GET("/test", func(ctx *gin.Context) { + ctx.JSON(200, nil) + }) + + recorder := httptest.NewRecorder() + req := tc.reqBuilder(t) + + server.ServeHTTP(recorder, req) + + require.Equal(t, tc.wantCode, recorder.Code) + }) + } +} + +func TestBuilder_AddUserAgent(t *testing.T) { + b := NewBuilder().AddUserAgent(map[string][]string{ + Baidu: {"test-new-baidu-user-agent"}, + }) + v, exist := b.crawlersMap["test-new-baidu-user-agent"] + require.Equal(t, Baidu, v) + require.True(t, exist) +} + +func TestBuilder_RemoveUserAgent(t *testing.T) { + b := NewBuilder().RemoveUserAgent("Baiduspider") + v, exist := b.crawlersMap["Baiduspider"] + require.Equal(t, "", v) + require.False(t, exist) +} diff --git a/session/provider.mock_test.go b/session/provider.mock_test.go index 634c44e..6dc4734 100644 --- a/session/provider.mock_test.go +++ b/session/provider.mock_test.go @@ -13,11 +13,11 @@ // limitations under the License. // Code generated by MockGen. DO NOT EDIT. -// Source: session/types.go +// Source: session/crawler_detector.go // // Generated by this command: // -// mockgen -copyright_file=.license_header -source=session/types.go -package=session -destination=session/provider.mock_test.go Provider +// mockgen -copyright_file=.license_header -source=session/crawler_detector.go -package=session -destination=session/provider.mock_test.go Provider // // Package session is a generated GoMock package. package session