From dbb9adcd225c79b1f84a9ac0ee3fdd363d4214d2 Mon Sep 17 00:00:00 2001 From: Sebastien Blot Date: Thu, 7 Nov 2024 11:03:26 +0100 Subject: [PATCH] Centralized allowlists support --- .github/workflows/go-tests-windows.yml | 2 +- .github/workflows/go-tests.yml | 2 +- .golangci.yml | 38 +- Makefile | 61 +- README.md | 2 +- azure-pipelines.yml | 2 +- cmd/crowdsec-cli/cliallowlists/allowlists.go | 519 +++++ cmd/crowdsec-cli/clibouncer/add.go | 2 +- cmd/crowdsec-cli/clibouncer/bouncers.go | 2 + cmd/crowdsec-cli/clibouncer/delete.go | 62 +- cmd/crowdsec-cli/clibouncer/inspect.go | 1 + cmd/crowdsec-cli/clicapi/capi.go | 21 + .../clinotifications/notifications.go | 2 +- cmd/crowdsec-cli/clipapi/papi.go | 2 +- cmd/crowdsec-cli/dashboard.go | 4 +- cmd/crowdsec-cli/main.go | 5 + cmd/crowdsec-cli/setup.go | 1 + cmd/crowdsec/appsec.go | 2 +- cmd/crowdsec/crowdsec.go | 12 +- cmd/crowdsec/lapiclient.go | 4 +- cmd/crowdsec/main.go | 4 +- cmd/notification-email/main.go | 2 +- go.mod | 2 +- pkg/acquisition/acquisition.go | 24 +- .../configuration/configuration.go | 10 +- pkg/acquisition/http.go | 12 + pkg/acquisition/modules/appsec/appsec.go | 147 +- .../modules/appsec/appsec_hooks_test.go | 4 +- .../modules/appsec/appsec_rules_test.go | 119 +- .../modules/appsec/appsec_runner.go | 50 +- .../modules/appsec/appsec_runner_test.go | 153 ++ pkg/acquisition/modules/appsec/appsec_test.go | 24 +- .../modules/appsec/bodyprocessors/raw.go | 7 +- pkg/acquisition/modules/appsec/utils.go | 168 +- .../modules/cloudwatch/cloudwatch.go | 5 +- pkg/acquisition/modules/docker/docker.go | 13 +- pkg/acquisition/modules/file/file.go | 10 +- pkg/acquisition/modules/http/http.go | 414 ++++ pkg/acquisition/modules/http/http_test.go | 784 +++++++ pkg/acquisition/modules/http/testdata/ca.crt | 23 + .../modules/http/testdata/client.crt | 24 + .../modules/http/testdata/client.key | 27 + .../modules/http/testdata/server.crt | 23 + .../modules/http/testdata/server.key | 27 + .../modules/journalctl/journalctl.go | 9 +- pkg/acquisition/modules/kafka/kafka.go | 9 +- pkg/acquisition/modules/kinesis/kinesis.go | 8 +- .../modules/kubernetesaudit/k8s_audit.go | 32 +- .../loki/internal/lokiclient/loki_client.go | 9 +- pkg/acquisition/modules/loki/loki.go | 61 +- pkg/acquisition/modules/loki/loki_test.go | 23 +- pkg/acquisition/modules/s3/s3.go | 8 +- .../syslog/internal/parser/rfc3164/parse.go | 1 - .../syslog/internal/parser/rfc5424/parse.go | 3 - .../internal/parser/rfc5424/parse_test.go | 58 +- .../syslog/internal/server/syslogserver.go | 1 - pkg/acquisition/modules/syslog/syslog.go | 8 +- .../wineventlog/wineventlog_windows.go | 14 +- pkg/alertcontext/alertcontext.go | 152 +- pkg/alertcontext/alertcontext_test.go | 161 ++ pkg/apiclient/allowlists_service.go | 91 + pkg/apiclient/auth_jwt.go | 3 - pkg/apiclient/client.go | 69 + pkg/apiclient/decisions_service.go | 13 + pkg/apiclient/decisions_service_test.go | 27 +- pkg/apiserver/alerts_test.go | 41 +- pkg/apiserver/allowlists_test.go | 119 + pkg/apiserver/api_key_test.go | 58 +- pkg/apiserver/apic.go | 164 +- pkg/apiserver/apic_metrics.go | 8 +- pkg/apiserver/apic_test.go | 32 +- pkg/apiserver/apiserver.go | 28 +- pkg/apiserver/apiserver_test.go | 20 +- pkg/apiserver/controllers/controller.go | 5 + pkg/apiserver/controllers/v1/alerts.go | 21 +- pkg/apiserver/controllers/v1/allowlist.go | 126 + pkg/apiserver/decisions_test.go | 16 +- pkg/apiserver/jwt_test.go | 10 +- pkg/apiserver/machines_test.go | 6 +- pkg/apiserver/middlewares/v1/api_key.go | 79 +- pkg/apiserver/papi.go | 6 +- pkg/apiserver/papi_cmd.go | 97 +- pkg/appsec/appsec.go | 73 +- pkg/appsec/appsec_rule/appsec_rule.go | 1 - pkg/appsec/appsec_rule/modsec_rule_test.go | 2 - pkg/appsec/coraza_logger.go | 2 +- pkg/appsec/request_test.go | 3 - pkg/cache/cache_test.go | 13 +- pkg/csconfig/api.go | 24 +- pkg/csconfig/api_test.go | 5 + pkg/csplugin/broker.go | 4 +- pkg/cticlient/types.go | 2 - pkg/cwversion/component/component.go | 29 +- pkg/database/allowlists.go | 255 +++ pkg/database/allowlists_test.go | 99 + pkg/database/bouncers.go | 19 +- pkg/database/ent/allowlist.go | 189 ++ pkg/database/ent/allowlist/allowlist.go | 133 ++ pkg/database/ent/allowlist/where.go | 429 ++++ pkg/database/ent/allowlist_create.go | 321 +++ pkg/database/ent/allowlist_delete.go | 88 + pkg/database/ent/allowlist_query.go | 636 ++++++ pkg/database/ent/allowlist_update.go | 421 ++++ pkg/database/ent/allowlistitem.go | 231 ++ .../ent/allowlistitem/allowlistitem.go | 165 ++ pkg/database/ent/allowlistitem/where.go | 664 ++++++ pkg/database/ent/allowlistitem_create.go | 398 ++++ pkg/database/ent/allowlistitem_delete.go | 88 + pkg/database/ent/allowlistitem_query.go | 636 ++++++ pkg/database/ent/allowlistitem_update.go | 463 ++++ pkg/database/ent/bouncer.go | 13 +- pkg/database/ent/bouncer/bouncer.go | 10 + pkg/database/ent/bouncer/where.go | 15 + pkg/database/ent/bouncer_create.go | 25 + pkg/database/ent/client.go | 374 ++- pkg/database/ent/ent.go | 22 +- pkg/database/ent/hook/hook.go | 24 + pkg/database/ent/migrate/schema.go | 91 + pkg/database/ent/mutation.go | 2022 ++++++++++++++++- pkg/database/ent/predicate/predicate.go | 6 + pkg/database/ent/runtime.go | 30 + pkg/database/ent/schema/allowlist.go | 44 + pkg/database/ent/schema/allowlist_item.go | 51 + pkg/database/ent/schema/bouncer.go | 2 + pkg/database/ent/tx.go | 6 + pkg/database/flush.go | 22 + pkg/exprhelpers/crowdsec_cti.go | 20 +- pkg/exprhelpers/geoip.go | 3 - pkg/fflag/crowdsec.go | 14 +- pkg/leakybucket/blackhole.go | 2 - pkg/leakybucket/bucket.go | 1 - pkg/leakybucket/buckets.go | 1 - pkg/leakybucket/conditional.go | 6 +- pkg/leakybucket/manager_load_test.go | 59 +- pkg/leakybucket/manager_run.go | 13 +- pkg/leakybucket/overflows.go | 27 +- pkg/leakybucket/processor.go | 3 +- pkg/leakybucket/reset_filter.go | 10 +- pkg/leakybucket/uniq.go | 6 +- pkg/longpollclient/client.go | 21 +- pkg/models/allowlist_item.go | 100 + pkg/models/check_allowlist_response.go | 50 + pkg/models/get_allowlist_response.go | 174 ++ pkg/models/get_allowlists_response.go | 78 + pkg/models/localapi_swagger.yaml | 173 ++ pkg/modelscapi/allowlist_link.go | 166 ++ pkg/modelscapi/centralapi_swagger.yaml | 48 + .../get_decisions_stream_response_links.go | 62 + pkg/parser/enrich.go | 6 +- pkg/parser/enrich_geoip.go | 3 - pkg/parser/node.go | 14 +- pkg/parser/parsing_test.go | 4 +- pkg/parser/runtime.go | 49 +- pkg/parser/unix_parser.go | 2 +- pkg/types/appsec_event.go | 9 +- pkg/types/constants.go | 34 +- pkg/types/event.go | 22 +- pkg/types/event_test.go | 2 - pkg/types/getfstype.go | 1 - pkg/types/ip.go | 3 +- pkg/types/ip_test.go | 5 +- pkg/types/utils.go | 8 +- rpm/SPECS/crowdsec.spec | 2 +- test/ansible/vagrant/fedora-40/Vagrantfile | 2 +- test/ansible/vagrant/fedora-41/Vagrantfile | 13 + test/ansible/vagrant/fedora-41/skip | 9 + .../vagrant/opensuse-leap-15/Vagrantfile | 10 + test/ansible/vagrant/opensuse-leap-15/skip | 9 + test/bats/10_bouncers.bats | 55 +- test/lib/init/crowdsec-daemon | 6 +- 170 files changed, 13064 insertions(+), 809 deletions(-) create mode 100644 cmd/crowdsec-cli/cliallowlists/allowlists.go create mode 100644 pkg/acquisition/http.go create mode 100644 pkg/acquisition/modules/appsec/appsec_runner_test.go create mode 100644 pkg/acquisition/modules/http/http.go create mode 100644 pkg/acquisition/modules/http/http_test.go create mode 100644 pkg/acquisition/modules/http/testdata/ca.crt create mode 100644 pkg/acquisition/modules/http/testdata/client.crt create mode 100644 pkg/acquisition/modules/http/testdata/client.key create mode 100644 pkg/acquisition/modules/http/testdata/server.crt create mode 100644 pkg/acquisition/modules/http/testdata/server.key create mode 100644 pkg/apiclient/allowlists_service.go create mode 100644 pkg/apiserver/allowlists_test.go create mode 100644 pkg/apiserver/controllers/v1/allowlist.go create mode 100644 pkg/database/allowlists.go create mode 100644 pkg/database/allowlists_test.go create mode 100644 pkg/database/ent/allowlist.go create mode 100644 pkg/database/ent/allowlist/allowlist.go create mode 100644 pkg/database/ent/allowlist/where.go create mode 100644 pkg/database/ent/allowlist_create.go create mode 100644 pkg/database/ent/allowlist_delete.go create mode 100644 pkg/database/ent/allowlist_query.go create mode 100644 pkg/database/ent/allowlist_update.go create mode 100644 pkg/database/ent/allowlistitem.go create mode 100644 pkg/database/ent/allowlistitem/allowlistitem.go create mode 100644 pkg/database/ent/allowlistitem/where.go create mode 100644 pkg/database/ent/allowlistitem_create.go create mode 100644 pkg/database/ent/allowlistitem_delete.go create mode 100644 pkg/database/ent/allowlistitem_query.go create mode 100644 pkg/database/ent/allowlistitem_update.go create mode 100644 pkg/database/ent/schema/allowlist.go create mode 100644 pkg/database/ent/schema/allowlist_item.go create mode 100644 pkg/models/allowlist_item.go create mode 100644 pkg/models/check_allowlist_response.go create mode 100644 pkg/models/get_allowlist_response.go create mode 100644 pkg/models/get_allowlists_response.go create mode 100644 pkg/modelscapi/allowlist_link.go create mode 100644 test/ansible/vagrant/fedora-41/Vagrantfile create mode 100644 test/ansible/vagrant/fedora-41/skip create mode 100644 test/ansible/vagrant/opensuse-leap-15/Vagrantfile create mode 100644 test/ansible/vagrant/opensuse-leap-15/skip diff --git a/.github/workflows/go-tests-windows.yml b/.github/workflows/go-tests-windows.yml index 2966b999a4a..3276dbb1bfd 100644 --- a/.github/workflows/go-tests-windows.yml +++ b/.github/workflows/go-tests-windows.yml @@ -61,6 +61,6 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v6 with: - version: v1.61 + version: v1.62 args: --issues-exit-code=1 --timeout 10m only-new-issues: false diff --git a/.github/workflows/go-tests.yml b/.github/workflows/go-tests.yml index 3f4aa67e139..3638696b4f6 100644 --- a/.github/workflows/go-tests.yml +++ b/.github/workflows/go-tests.yml @@ -190,6 +190,6 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v6 with: - version: v1.61 + version: v1.62 args: --issues-exit-code=1 --timeout 10m only-new-issues: false diff --git a/.golangci.yml b/.golangci.yml index 271e3a57d34..7217c6da2b1 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -183,7 +183,6 @@ linters-settings: - ifElseChain - importShadow - hugeParam - - rangeValCopy - commentedOutCode - commentedOutImport - unnamedResult @@ -211,9 +210,7 @@ linters: # # DEPRECATED by golangi-lint # - - execinquery - exportloopref - - gomnd # # Redundant @@ -348,10 +345,6 @@ issues: - errorlint text: "type switch on error will fail on wrapped errors. Use errors.As to check for specific errors" - - linters: - - errorlint - text: "comparing with .* will fail on wrapped errors. Use errors.Is to check for a specific error" - - linters: - nosprintfhostport text: "host:port in url should be constructed with net.JoinHostPort and not directly with fmt.Sprintf" @@ -460,3 +453,34 @@ issues: - revive path: "cmd/crowdsec/win_service.go" text: "deep-exit: .*" + + - linters: + - recvcheck + path: "pkg/csplugin/hclog_adapter.go" + text: 'the methods of "HCLogAdapter" use pointer receiver and non-pointer receiver.' + + # encoding to json/yaml requires value receivers + - linters: + - recvcheck + path: "pkg/cwhub/item.go" + text: 'the methods of "Item" use pointer receiver and non-pointer receiver.' + + - linters: + - gocritic + path: "cmd/crowdsec-cli" + text: "rangeValCopy: .*" + + - linters: + - gocritic + path: "pkg/(cticlient|hubtest)" + text: "rangeValCopy: .*" + + - linters: + - gocritic + path: "(.+)_test.go" + text: "rangeValCopy: .*" + + - linters: + - gocritic + path: "pkg/(appsec|acquisition|dumps|alertcontext|leakybucket|exprhelpers)" + text: "rangeValCopy: .*" diff --git a/Makefile b/Makefile index 29a84d5b066..f8ae66e1cb6 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,7 @@ BUILD_RE2_WASM ?= 0 # for your distribution (look for libre2.a). See the Dockerfile for an example of how to build it. BUILD_STATIC ?= 0 -# List of plugins to build +# List of notification plugins to build PLUGINS ?= $(patsubst ./cmd/notification-%,%,$(wildcard ./cmd/notification-*)) #-------------------------------------- @@ -80,9 +80,17 @@ endif #expr_debug tag is required to enable the debug mode in expr GO_TAGS := netgo,osusergo,sqlite_omit_load_extension,expr_debug +# Allow building on ubuntu 24.10, see https://github.com/golang/go/issues/70023 +export CGO_LDFLAGS_ALLOW=-Wl,--(push|pop)-state.* + # this will be used by Go in the make target, some distributions require it export PKG_CONFIG_PATH:=/usr/local/lib/pkgconfig:$(PKG_CONFIG_PATH) +#-------------------------------------- +# +# Choose the re2 backend. +# + ifeq ($(call bool,$(BUILD_RE2_WASM)),0) ifeq ($(PKG_CONFIG),) $(error "pkg-config is not available. Please install pkg-config.") @@ -90,35 +98,28 @@ endif ifeq ($(RE2_CHECK),) RE2_FAIL := "libre2-dev is not installed, please install it or set BUILD_RE2_WASM=1 to use the WebAssembly version" +# if you prefer to build WASM instead of a critical error, comment out RE2_FAIL and uncomment RE2_MSG. +# RE2_MSG := Fallback to WebAssembly regexp library. To use the C++ version, make sure you have installed libre2-dev and pkg-config. else # += adds a space that we don't want GO_TAGS := $(GO_TAGS),re2_cgo LD_OPTS_VARS += -X '$(GO_MODULE_NAME)/pkg/cwversion.Libre2=C++' +RE2_MSG := Using C++ regexp library endif -endif - -# Build static to avoid the runtime dependency on libre2.so -ifeq ($(call bool,$(BUILD_STATIC)),1) -BUILD_TYPE = static -EXTLDFLAGS := -extldflags '-static' else -BUILD_TYPE = dynamic -EXTLDFLAGS := +RE2_MSG := Using WebAssembly regexp library endif -# Build with debug symbols, and disable optimizations + inlining, to use Delve -ifeq ($(call bool,$(DEBUG)),1) -STRIP_SYMBOLS := -DISABLE_OPTIMIZATION := -gcflags "-N -l" +ifeq ($(call bool,$(BUILD_RE2_WASM)),1) else -STRIP_SYMBOLS := -s -DISABLE_OPTIMIZATION := +ifneq (,$(RE2_CHECK)) +endif endif #-------------------------------------- - +# # Handle optional components and build profiles, to save space on the final binaries. - +# # Keep it safe for now until we decide how to expand on the idea. Either choose a profile or exclude components manually. # For example if we want to disable some component by default, or have opt-in components (INCLUDE?). @@ -131,6 +132,7 @@ COMPONENTS := \ datasource_cloudwatch \ datasource_docker \ datasource_file \ + datasource_http \ datasource_k8saudit \ datasource_kafka \ datasource_journalctl \ @@ -178,6 +180,23 @@ endif #-------------------------------------- +ifeq ($(call bool,$(BUILD_STATIC)),1) +BUILD_TYPE = static +EXTLDFLAGS := -extldflags '-static' +else +BUILD_TYPE = dynamic +EXTLDFLAGS := +endif + +# Build with debug symbols, and disable optimizations + inlining, to use Delve +ifeq ($(call bool,$(DEBUG)),1) +STRIP_SYMBOLS := +DISABLE_OPTIMIZATION := -gcflags "-N -l" +else +STRIP_SYMBOLS := -s +DISABLE_OPTIMIZATION := +endif + export LD_OPTS=-ldflags "$(STRIP_SYMBOLS) $(EXTLDFLAGS) $(LD_OPTS_VARS)" \ -trimpath -tags $(GO_TAGS) $(DISABLE_OPTIMIZATION) @@ -193,17 +212,13 @@ build: build-info crowdsec cscli plugins ## Build crowdsec, cscli and plugins .PHONY: build-info build-info: ## Print build information $(info Building $(BUILD_VERSION) ($(BUILD_TAG)) $(BUILD_TYPE) for $(GOOS)/$(GOARCH)) - $(info Excluded components: $(EXCLUDE_LIST)) + $(info Excluded components: $(if $(EXCLUDE_LIST),$(EXCLUDE_LIST),none)) ifneq (,$(RE2_FAIL)) $(error $(RE2_FAIL)) endif -ifneq (,$(RE2_CHECK)) - $(info Using C++ regexp library) -else - $(info Fallback to WebAssembly regexp library. To use the C++ version, make sure you have installed libre2-dev and pkg-config.) -endif + $(info $(RE2_MSG)) ifeq ($(call bool,$(DEBUG)),1) $(info Building with debug symbols and disabled optimizations) diff --git a/README.md b/README.md index a900f0ee514..1e57d4e91c4 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ The architecture is as follows : CrowdSec

-Once an unwanted behavior is detected, deal with it through a [bouncer](https://hub.crowdsec.net/browse/#bouncers). The aggressive IP, scenario triggered and timestamp are sent for curation, to avoid poisoning & false positives. (This can be disabled). If verified, this IP is then redistributed to all CrowdSec users running the same scenario. +Once an unwanted behavior is detected, deal with it through a [bouncer](https://app.crowdsec.net/hub/remediation-components). The aggressive IP, scenario triggered and timestamp are sent for curation, to avoid poisoning & false positives. (This can be disabled). If verified, this IP is then redistributed to all CrowdSec users running the same scenario. ## Outnumbering hackers all together diff --git a/azure-pipelines.yml b/azure-pipelines.yml index acbcabc20c5..bcf327bdf38 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -21,7 +21,7 @@ stages: - task: GoTool@0 displayName: "Install Go" inputs: - version: '1.23' + version: '1.23.3' - pwsh: | choco install -y make diff --git a/cmd/crowdsec-cli/cliallowlists/allowlists.go b/cmd/crowdsec-cli/cliallowlists/allowlists.go new file mode 100644 index 00000000000..9660a6ec817 --- /dev/null +++ b/cmd/crowdsec-cli/cliallowlists/allowlists.go @@ -0,0 +1,519 @@ +package cliallowlists + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/url" + "slices" + "strings" + "time" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/fatih/color" + "github.com/go-openapi/strfmt" + "github.com/jedib0t/go-pretty/v6/table" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +type configGetter = func() *csconfig.Config + +type cliAllowLists struct { + db *database.Client + client *apiclient.ApiClient + cfg configGetter +} + +func New(cfg configGetter) *cliAllowLists { + return &cliAllowLists{ + cfg: cfg, + } +} + +// validAllowlists returns a list of valid allowlists name for command completion +func (cli *cliAllowLists) validAllowlists(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + var err error + + cfg := cli.cfg() + ctx := cmd.Context() + + // need to load config and db because PersistentPreRunE is not called for completions + + if err = require.LAPI(cfg); err != nil { + cobra.CompError("unable to load LAPI " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + cli.db, err = require.DBClient(ctx, cfg.DbConfig) + if err != nil { + cobra.CompError("unable to load dbclient " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + allowlists, err := cli.db.ListAllowLists(ctx, false) + if err != nil { + cobra.CompError("unable to list allowlists " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + ret := []string{} + + for _, allowlist := range allowlists { + if strings.Contains(allowlist.Name, toComplete) && !slices.Contains(args, allowlist.Name) { + ret = append(ret, allowlist.Name) + } + } + + return ret, cobra.ShellCompDirectiveNoFileComp +} + +func (cli *cliAllowLists) listHuman(out io.Writer, allowlists *models.GetAllowlistsResponse) { + t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer + t.AppendHeader(table.Row{"Name", "Description", "Creation Date", "Updated at", "Managed by Console", "Size"}) + + for _, allowlist := range *allowlists { + t.AppendRow(table.Row{allowlist.Name, allowlist.Description, allowlist.CreatedAt, allowlist.UpdatedAt, allowlist.ConsoleManaged, len(allowlist.Items)}) + } + + io.WriteString(out, t.Render()+"\n") +} + +func (cli *cliAllowLists) listContentHuman(out io.Writer, allowlist *models.GetAllowlistResponse) { + t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer + t.AppendHeader(table.Row{"Value", "Comment", "Expiration", "Created at"}) + + for _, content := range allowlist.Items { + expiration := "never" + if !time.Time(content.Expiration).IsZero() { + expiration = content.Expiration.String() + } + t.AppendRow(table.Row{content.Value, content.Description, expiration, allowlist.CreatedAt}) + } + + io.WriteString(out, t.Render()+"\n") +} + +func (cli *cliAllowLists) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "allowlists [action]", + Short: "Manage centralized allowlists", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + } + + cmd.AddCommand(cli.newCreateCmd()) + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newDeleteCmd()) + cmd.AddCommand(cli.newAddCmd()) + cmd.AddCommand(cli.newRemoveCmd()) + cmd.AddCommand(cli.newInspectCmd()) + return cmd +} + +func (cli *cliAllowLists) newCreateCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "create [allowlist_name]", + Example: "cscli allowlists create my_allowlist -d 'my allowlist description'", + Short: "Create a new allowlist", + Args: cobra.ExactArgs(1), + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + var err error + cfg := cli.cfg() + if err = require.LAPI(cfg); err != nil { + return err + } + cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) + if err != nil { + return err + } + return nil + }, + RunE: cli.create, + } + + flags := cmd.Flags() + + flags.StringP("description", "d", "", "description of the allowlist") + + cmd.MarkFlagRequired("description") + + return cmd +} + +func (cli *cliAllowLists) create(cmd *cobra.Command, args []string) error { + name := args[0] + description := cmd.Flag("description").Value.String() + + _, err := cli.db.CreateAllowList(cmd.Context(), name, description, false) + + if err != nil { + return err + } + + log.Infof("allowlist '%s' created successfully", name) + + return nil +} + +func (cli *cliAllowLists) newListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Example: `cscli allowlists list`, + Short: "List all allowlists", + Args: cobra.NoArgs, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := cfg.LoadAPIClient(); err != nil { + return fmt.Errorf("loading api client: %w", err) + } + apiURL, err := url.Parse(cfg.API.Client.Credentials.URL) + if err != nil { + return fmt.Errorf("parsing api url: %w", err) + } + + cli.client, err = apiclient.NewClient(&apiclient.Config{ + MachineID: cfg.API.Client.Credentials.Login, + Password: strfmt.Password(cfg.API.Client.Credentials.Password), + URL: apiURL, + VersionPrefix: "v1", + }) + if err != nil { + return fmt.Errorf("creating api client: %w", err) + } + + return nil + }, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.list(cmd, color.Output) + }, + } + + return cmd +} + +func (cli *cliAllowLists) list(cmd *cobra.Command, out io.Writer) error { + allowlists, _, err := cli.client.Allowlists.List(cmd.Context(), apiclient.AllowlistListOpts{WithContent: true}) + if err != nil { + return err + } + + switch cli.cfg().Cscli.Output { + case "human": + cli.listHuman(out, allowlists) + case "json": + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(allowlists); err != nil { + return errors.New("failed to serialize") + } + + return nil + case "raw": + //return cli.listCSV(out, allowlists) + } + + return nil +} + +func (cli *cliAllowLists) newDeleteCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "delete [allowlist_name]", + Short: "Delete an allowlist", + Example: `cscli allowlists delete my_allowlist`, + Args: cobra.ExactArgs(1), + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + var err error + cfg := cli.cfg() + + if err = require.LAPI(cfg); err != nil { + return err + } + cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) + if err != nil { + return err + } + return nil + }, + RunE: cli.delete, + } + + return cmd +} + +func (cli *cliAllowLists) delete(cmd *cobra.Command, args []string) error { + name := args[0] + list, err := cli.db.GetAllowList(cmd.Context(), name, false) + + if err != nil { + return err + } + + if list == nil { + return fmt.Errorf("allowlist %s not found", name) + } + + if list.FromConsole { + return fmt.Errorf("allowlist %s is managed by console, cannot delete with cscli", name) + } + + err = cli.db.DeleteAllowList(cmd.Context(), name, false) + if err != nil { + return err + } + + log.Infof("allowlist '%s' deleted successfully", name) + + return nil +} + +func (cli *cliAllowLists) newAddCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "add [allowlist_name] --value [value] [-e expiration] [-d comment]", + Short: "Add content an allowlist", + Example: `cscli allowlists add my_allowlist --value 1.2.3.4 --value 2.3.4.5 -e 1h -d "my comment"`, + Args: cobra.ExactArgs(1), + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + var err error + cfg := cli.cfg() + + if err = require.LAPI(cfg); err != nil { + return err + } + + cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) + if err != nil { + return err + } + return nil + }, + RunE: cli.add, + } + + flags := cmd.Flags() + + flags.StringSliceP("value", "v", nil, "value to add to the allowlist") + flags.StringP("expiration", "e", "", "expiration duration") + flags.StringP("comment", "d", "", "comment for the value") + + cmd.MarkFlagRequired("value") + + return cmd +} + +func (cli *cliAllowLists) add(cmd *cobra.Command, args []string) error { + + var expiration time.Duration + + name := args[0] + values, err := cmd.Flags().GetStringSlice("value") + comment := cmd.Flag("comment").Value.String() + + if err != nil { + return err + } + + expirationStr := cmd.Flag("expiration").Value.String() + + if expirationStr != "" { + //FIXME: handle days (and maybe more ?) + expiration, err = time.ParseDuration(expirationStr) + + if err != nil { + return err + } + } + + allowlist, err := cli.db.GetAllowList(cmd.Context(), name, true) + + if err != nil { + return fmt.Errorf("unable to get allowlist: %w", err) + } + + if allowlist.FromConsole { + return fmt.Errorf("allowlist %s is managed by console, cannot update with cscli", name) + } + + toAdd := make([]*models.AllowlistItem, 0) + + for _, v := range values { + found := false + for _, item := range allowlist.Edges.AllowlistItems { + if item.Value == v { + found = true + log.Warnf("value %s already in allowlist", v) + break + } + } + if !found { + toAdd = append(toAdd, &models.AllowlistItem{Value: v, Description: comment, Expiration: strfmt.DateTime(time.Now().UTC().Add(expiration))}) + } + } + + if len(toAdd) == 0 { + log.Warn("no value to add to allowlist") + return nil + } + + log.Debugf("adding %d values to allowlist %s", len(toAdd), name) + + err = cli.db.AddToAllowlist(cmd.Context(), allowlist, toAdd) + + if err != nil { + return fmt.Errorf("unable to add values to allowlist: %w", err) + } + + return nil +} + +func (cli *cliAllowLists) newInspectCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "inspect [allowlist_name]", + Example: `cscli allowlists inspect my_allowlist`, + Short: "Inspect an allowlist", + Args: cobra.ExactArgs(1), + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := cfg.LoadAPIClient(); err != nil { + return fmt.Errorf("loading api client: %w", err) + } + apiURL, err := url.Parse(cfg.API.Client.Credentials.URL) + if err != nil { + return fmt.Errorf("parsing api url: %w", err) + } + + cli.client, err = apiclient.NewClient(&apiclient.Config{ + MachineID: cfg.API.Client.Credentials.Login, + Password: strfmt.Password(cfg.API.Client.Credentials.Password), + URL: apiURL, + VersionPrefix: "v1", + }) + if err != nil { + return fmt.Errorf("creating api client: %w", err) + } + + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.inspect(cmd, args, color.Output) + }, + } + + return cmd +} + +func (cli *cliAllowLists) inspect(cmd *cobra.Command, args []string, out io.Writer) error { + name := args[0] + allowlist, _, err := cli.client.Allowlists.Get(cmd.Context(), name, apiclient.AllowlistGetOpts{WithContent: true}) + + if err != nil { + return fmt.Errorf("unable to get allowlist: %w", err) + } + + switch cli.cfg().Cscli.Output { + case "human": + cli.listContentHuman(out, allowlist) + case "json": + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(allowlist); err != nil { + return errors.New("failed to serialize") + } + + return nil + case "raw": + //return cli.listCSV(out, allowlists) + } + + return nil +} + +func (cli *cliAllowLists) newRemoveCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "remove [allowlist_name] --value [value]", + Short: "remove content from an allowlist", + Example: `cscli allowlists remove my_allowlist --value 1.2.3.4 --value 2.3.4.5"`, + Args: cobra.ExactArgs(1), + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + var err error + cfg := cli.cfg() + + if err = require.LAPI(cfg); err != nil { + return err + } + + cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) + if err != nil { + return err + } + return nil + }, + RunE: cli.remove, + } + + flags := cmd.Flags() + + flags.StringSliceP("value", "v", nil, "value to remove from the allowlist") + + cmd.MarkFlagRequired("value") + + return cmd +} + +func (cli *cliAllowLists) remove(cmd *cobra.Command, args []string) error { + name := args[0] + values, err := cmd.Flags().GetStringSlice("value") + + if err != nil { + return err + } + + allowlist, err := cli.db.GetAllowList(cmd.Context(), name, true) + + if err != nil { + return fmt.Errorf("unable to get allowlist: %w", err) + } + + if allowlist.FromConsole { + return fmt.Errorf("allowlist %s is managed by console, cannot update with cscli", name) + } + + toRemove := make([]string, 0) + + for _, v := range values { + found := false + for _, item := range allowlist.Edges.AllowlistItems { + if item.Value == v { + found = true + break + } + } + if found { + toRemove = append(toRemove, v) + } + } + + if len(toRemove) == 0 { + log.Warn("no value to remove from allowlist") + } + + log.Debugf("removing %d values from allowlist %s", len(toRemove), name) + + nbDeleted, err := cli.db.RemoveFromAllowlist(cmd.Context(), allowlist, toRemove...) + + if err != nil { + return fmt.Errorf("unable to remove values from allowlist: %w", err) + } + + log.Infof("removed %d values from allowlist %s", nbDeleted, name) + + return nil +} diff --git a/cmd/crowdsec-cli/clibouncer/add.go b/cmd/crowdsec-cli/clibouncer/add.go index 8c40507a996..7cc74e45fba 100644 --- a/cmd/crowdsec-cli/clibouncer/add.go +++ b/cmd/crowdsec-cli/clibouncer/add.go @@ -24,7 +24,7 @@ func (cli *cliBouncers) add(ctx context.Context, bouncerName string, key string) } } - _, err = cli.db.CreateBouncer(ctx, bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType) + _, err = cli.db.CreateBouncer(ctx, bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType, false) if err != nil { return fmt.Errorf("unable to create bouncer: %w", err) } diff --git a/cmd/crowdsec-cli/clibouncer/bouncers.go b/cmd/crowdsec-cli/clibouncer/bouncers.go index 876b613be53..2b0a3556873 100644 --- a/cmd/crowdsec-cli/clibouncer/bouncers.go +++ b/cmd/crowdsec-cli/clibouncer/bouncers.go @@ -77,6 +77,7 @@ type bouncerInfo struct { AuthType string `json:"auth_type"` OS string `json:"os,omitempty"` Featureflags []string `json:"featureflags,omitempty"` + AutoCreated bool `json:"auto_created"` } func newBouncerInfo(b *ent.Bouncer) bouncerInfo { @@ -92,6 +93,7 @@ func newBouncerInfo(b *ent.Bouncer) bouncerInfo { AuthType: b.AuthType, OS: clientinfo.GetOSNameAndVersion(b), Featureflags: clientinfo.GetFeatureFlagList(b), + AutoCreated: b.AutoCreated, } } diff --git a/cmd/crowdsec-cli/clibouncer/delete.go b/cmd/crowdsec-cli/clibouncer/delete.go index 6e2f312d4af..33419f483b6 100644 --- a/cmd/crowdsec-cli/clibouncer/delete.go +++ b/cmd/crowdsec-cli/clibouncer/delete.go @@ -4,25 +4,73 @@ import ( "context" "errors" "fmt" + "strings" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/types" ) +func (cli *cliBouncers) findParentBouncer(bouncerName string, bouncers []*ent.Bouncer) (string, error) { + bouncerPrefix := strings.Split(bouncerName, "@")[0] + for _, bouncer := range bouncers { + if strings.HasPrefix(bouncer.Name, bouncerPrefix) && !bouncer.AutoCreated { + return bouncer.Name, nil + } + } + + return "", errors.New("no parent bouncer found") +} + func (cli *cliBouncers) delete(ctx context.Context, bouncers []string, ignoreMissing bool) error { - for _, bouncerID := range bouncers { - if err := cli.db.DeleteBouncer(ctx, bouncerID); err != nil { - var notFoundErr *database.BouncerNotFoundError + allBouncers, err := cli.db.ListBouncers(ctx) + if err != nil { + return fmt.Errorf("unable to list bouncers: %w", err) + } + for _, bouncerName := range bouncers { + bouncer, err := cli.db.SelectBouncerByName(ctx, bouncerName) + if err != nil { + var notFoundErr *ent.NotFoundError if ignoreMissing && errors.As(err, ¬FoundErr) { - return nil + continue } + return fmt.Errorf("unable to delete bouncer %s: %w", bouncerName, err) + } + + // For TLS bouncers, always delete them, they have no parents + if bouncer.AuthType == types.TlsAuthType { + if err := cli.db.DeleteBouncer(ctx, bouncerName); err != nil { + return fmt.Errorf("unable to delete bouncer %s: %w", bouncerName, err) + } + continue + } + + if bouncer.AutoCreated { + parentBouncer, err := cli.findParentBouncer(bouncerName, allBouncers) + if err != nil { + log.Errorf("bouncer '%s' is auto-created, but couldn't find a parent bouncer", err) + continue + } + log.Warnf("bouncer '%s' is auto-created and cannot be deleted, delete parent bouncer %s instead", bouncerName, parentBouncer) + continue + } + //Try to find all child bouncers and delete them + for _, childBouncer := range allBouncers { + if strings.HasPrefix(childBouncer.Name, bouncerName+"@") && childBouncer.AutoCreated { + if err := cli.db.DeleteBouncer(ctx, childBouncer.Name); err != nil { + return fmt.Errorf("unable to delete bouncer %s: %w", childBouncer.Name, err) + } + log.Infof("bouncer '%s' deleted successfully", childBouncer.Name) + } + } - return fmt.Errorf("unable to delete bouncer: %w", err) + if err := cli.db.DeleteBouncer(ctx, bouncerName); err != nil { + return fmt.Errorf("unable to delete bouncer %s: %w", bouncerName, err) } - log.Infof("bouncer '%s' deleted successfully", bouncerID) + log.Infof("bouncer '%s' deleted successfully", bouncerName) } return nil diff --git a/cmd/crowdsec-cli/clibouncer/inspect.go b/cmd/crowdsec-cli/clibouncer/inspect.go index 6dac386b888..b62344baa9b 100644 --- a/cmd/crowdsec-cli/clibouncer/inspect.go +++ b/cmd/crowdsec-cli/clibouncer/inspect.go @@ -40,6 +40,7 @@ func (cli *cliBouncers) inspectHuman(out io.Writer, bouncer *ent.Bouncer) { {"Last Pull", lastPull}, {"Auth type", bouncer.AuthType}, {"OS", clientinfo.GetOSNameAndVersion(bouncer)}, + {"Auto Created", bouncer.AutoCreated}, }) for _, ff := range clientinfo.GetFeatureFlagList(bouncer) { diff --git a/cmd/crowdsec-cli/clicapi/capi.go b/cmd/crowdsec-cli/clicapi/capi.go index cba66f11104..61d59836fdd 100644 --- a/cmd/crowdsec-cli/clicapi/capi.go +++ b/cmd/crowdsec-cli/clicapi/capi.go @@ -225,6 +225,27 @@ func (cli *cliCapi) Status(ctx context.Context, out io.Writer, hub *cwhub.Hub) e fmt.Fprint(out, "Your instance is enrolled in the console\n") } + switch *cfg.API.Server.OnlineClient.Sharing { + case true: + fmt.Fprint(out, "Sharing signals is enabled\n") + case false: + fmt.Fprint(out, "Sharing signals is disabled\n") + } + + switch *cfg.API.Server.OnlineClient.PullConfig.Community { + case true: + fmt.Fprint(out, "Pulling community blocklist is enabled\n") + case false: + fmt.Fprint(out, "Pulling community blocklist is disabled\n") + } + + switch *cfg.API.Server.OnlineClient.PullConfig.Blocklists { + case true: + fmt.Fprint(out, "Pulling blocklists from the console is enabled\n") + case false: + fmt.Fprint(out, "Pulling blocklists from the console is disabled\n") + } + return nil } diff --git a/cmd/crowdsec-cli/clinotifications/notifications.go b/cmd/crowdsec-cli/clinotifications/notifications.go index baf899c10cf..80ffebeaa23 100644 --- a/cmd/crowdsec-cli/clinotifications/notifications.go +++ b/cmd/crowdsec-cli/clinotifications/notifications.go @@ -260,7 +260,7 @@ func (cli *cliNotifications) notificationConfigFilter(cmd *cobra.Command, args [ return ret, cobra.ShellCompDirectiveNoFileComp } -func (cli cliNotifications) newTestCmd() *cobra.Command { +func (cli *cliNotifications) newTestCmd() *cobra.Command { var ( pluginBroker csplugin.PluginBroker pluginTomb tomb.Tomb diff --git a/cmd/crowdsec-cli/clipapi/papi.go b/cmd/crowdsec-cli/clipapi/papi.go index 461215c3a39..7ac2455d28f 100644 --- a/cmd/crowdsec-cli/clipapi/papi.go +++ b/cmd/crowdsec-cli/clipapi/papi.go @@ -136,7 +136,7 @@ func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client t.Go(papi.SyncDecisions) - err = papi.PullOnce(time.Time{}, true) + err = papi.PullOnce(ctx, time.Time{}, true) if err != nil { return fmt.Errorf("unable to sync decisions: %w", err) } diff --git a/cmd/crowdsec-cli/dashboard.go b/cmd/crowdsec-cli/dashboard.go index 53a7dff85a0..7ddac093dcd 100644 --- a/cmd/crowdsec-cli/dashboard.go +++ b/cmd/crowdsec-cli/dashboard.go @@ -243,7 +243,8 @@ func (cli *cliDashboard) newStopCmd() *cobra.Command { } func (cli *cliDashboard) newShowPasswordCmd() *cobra.Command { - cmd := &cobra.Command{Use: "show-password", + cmd := &cobra.Command{ + Use: "show-password", Short: "displays password of metabase.", Args: cobra.NoArgs, DisableAutoGenTag: true, @@ -457,7 +458,6 @@ func checkGroups(forceYes *bool) (*user.Group, error) { func (cli *cliDashboard) chownDatabase(gid string) error { cfg := cli.cfg() intID, err := strconv.Atoi(gid) - if err != nil { return fmt.Errorf("unable to convert group ID to int: %s", err) } diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index 1cca03b1d3d..bd12901d580 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -12,9 +12,11 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clialert" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cliallowlists" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clibouncer" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clicapi" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cliconsole" @@ -163,6 +165,8 @@ func (cli *cliRoot) initialize() error { } } + csConfig.DbConfig.LogLevel = ptr.Of(cli.wantedLogLevel()) + return nil } @@ -279,6 +283,7 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall cmd.AddCommand(cliitem.NewContext(cli.cfg).NewCommand()) cmd.AddCommand(cliitem.NewAppsecConfig(cli.cfg).NewCommand()) cmd.AddCommand(cliitem.NewAppsecRule(cli.cfg).NewCommand()) + cmd.AddCommand(cliallowlists.New(cli.cfg).NewCommand()) cli.addSetup(cmd) diff --git a/cmd/crowdsec-cli/setup.go b/cmd/crowdsec-cli/setup.go index 66c0d71e777..3581d69f052 100644 --- a/cmd/crowdsec-cli/setup.go +++ b/cmd/crowdsec-cli/setup.go @@ -1,4 +1,5 @@ //go:build !no_cscli_setup + package main import ( diff --git a/cmd/crowdsec/appsec.go b/cmd/crowdsec/appsec.go index cb02b137dcd..4320133b063 100644 --- a/cmd/crowdsec/appsec.go +++ b/cmd/crowdsec/appsec.go @@ -1,4 +1,4 @@ -// +build !no_datasource_appsec +//go:build !no_datasource_appsec package main diff --git a/cmd/crowdsec/crowdsec.go b/cmd/crowdsec/crowdsec.go index c44d71d2093..4df66d5a773 100644 --- a/cmd/crowdsec/crowdsec.go +++ b/cmd/crowdsec/crowdsec.go @@ -14,6 +14,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/acquisition" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/alertcontext" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" @@ -51,6 +52,15 @@ func initCrowdsec(cConfig *csconfig.Config, hub *cwhub.Hub) (*parser.Parsers, [] return nil, nil, err } + err = apiclient.InitLAPIClient( + context.TODO(), cConfig.API.Client.Credentials.URL, cConfig.API.Client.Credentials.PapiURL, + cConfig.API.Client.Credentials.Login, cConfig.API.Client.Credentials.Password, + hub.GetInstalledListForAPI()) + + if err != nil { + return nil, nil, fmt.Errorf("while initializing LAPIClient: %w", err) + } + datasources, err := LoadAcquisition(cConfig) if err != nil { return nil, nil, fmt.Errorf("while loading acquisition config: %w", err) @@ -116,7 +126,7 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H }) bucketWg.Wait() - apiClient, err := AuthenticatedLAPIClient(*cConfig.API.Client.Credentials, hub) + apiClient, err := apiclient.GetLAPIClient() if err != nil { return err } diff --git a/cmd/crowdsec/lapiclient.go b/cmd/crowdsec/lapiclient.go index eed517f9df9..6656ba6b4c2 100644 --- a/cmd/crowdsec/lapiclient.go +++ b/cmd/crowdsec/lapiclient.go @@ -14,7 +14,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" ) -func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) { +func AuthenticatedLAPIClient(ctx context.Context, credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) { apiURL, err := url.Parse(credentials.URL) if err != nil { return nil, fmt.Errorf("parsing api url ('%s'): %w", credentials.URL, err) @@ -44,7 +44,7 @@ func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub. return nil, fmt.Errorf("new client api: %w", err) } - authResp, _, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + authResp, _, err := client.Auth.AuthenticateWatcher(ctx, models.WatcherAuthRequest{ MachineID: &credentials.Login, Password: &password, Scenarios: itemsForAPI, diff --git a/cmd/crowdsec/main.go b/cmd/crowdsec/main.go index 6d8ca24c335..e414f59f3e2 100644 --- a/cmd/crowdsec/main.go +++ b/cmd/crowdsec/main.go @@ -148,14 +148,14 @@ func (l *labelsMap) String() string { return "labels" } -func (l labelsMap) Set(label string) error { +func (l *labelsMap) Set(label string) error { for _, pair := range strings.Split(label, ",") { split := strings.Split(pair, ":") if len(split) != 2 { return fmt.Errorf("invalid format for label '%s', must be key:value", pair) } - l[split[0]] = split[1] + (*l)[split[0]] = split[1] } return nil diff --git a/cmd/notification-email/main.go b/cmd/notification-email/main.go index 5fc02cdd1d7..b61644611b4 100644 --- a/cmd/notification-email/main.go +++ b/cmd/notification-email/main.go @@ -68,7 +68,7 @@ func (n *EmailPlugin) Configure(ctx context.Context, config *protobufs.Config) ( EncryptionType: "ssltls", AuthType: "login", SenderEmail: "crowdsec@crowdsec.local", - HeloHost: "localhost", + HeloHost: "localhost", } if err := yaml.Unmarshal(config.Config, &d); err != nil { diff --git a/go.mod b/go.mod index c889b62cb8c..f4bd9379a2d 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/crowdsecurity/crowdsec -go 1.23 +go 1.23.3 // Don't use the toolchain directive to avoid uncontrolled downloads during // a build, especially in sandboxed environments (freebsd, gentoo...). diff --git a/pkg/acquisition/acquisition.go b/pkg/acquisition/acquisition.go index 1ad385105d3..ef5a413b91f 100644 --- a/pkg/acquisition/acquisition.go +++ b/pkg/acquisition/acquisition.go @@ -337,6 +337,20 @@ func GetMetrics(sources []DataSource, aggregated bool) error { return nil } +// There's no need for an actual deep copy +// The event is almost empty, we are mostly interested in allocating new maps for Parsed/Meta/... +func copyEvent(evt types.Event, line string) types.Event { + evtCopy := types.MakeEvent(evt.ExpectMode == types.TIMEMACHINE, evt.Type, evt.Process) + evtCopy.Line = evt.Line + evtCopy.Line.Raw = line + evtCopy.Line.Labels = make(map[string]string) + for k, v := range evt.Line.Labels { + evtCopy.Line.Labels[k] = v + } + + return evtCopy +} + func transform(transformChan chan types.Event, output chan types.Event, AcquisTomb *tomb.Tomb, transformRuntime *vm.Program, logger *log.Entry) { defer trace.CatchPanic("crowdsec/acquis") logger.Infof("transformer started") @@ -363,8 +377,7 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo switch v := out.(type) { case string: logger.Tracef("transform expression returned %s", v) - evt.Line.Raw = v - output <- evt + output <- copyEvent(evt, v) case []interface{}: logger.Tracef("transform expression returned %v", v) //nolint:asasalint // We actually want to log the slice content @@ -373,19 +386,16 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo if !ok { logger.Errorf("transform expression returned []interface{}, but cannot assert an element to string") output <- evt - continue } - evt.Line.Raw = l - output <- evt + output <- copyEvent(evt, l) } case []string: logger.Tracef("transform expression returned %v", v) for _, line := range v { - evt.Line.Raw = line - output <- evt + output <- copyEvent(evt, line) } default: logger.Errorf("transform expression returned an invalid type %T, sending event as-is", out) diff --git a/pkg/acquisition/configuration/configuration.go b/pkg/acquisition/configuration/configuration.go index 3e27da1b9e6..a9d570d2788 100644 --- a/pkg/acquisition/configuration/configuration.go +++ b/pkg/acquisition/configuration/configuration.go @@ -13,12 +13,14 @@ type DataSourceCommonCfg struct { UseTimeMachine bool `yaml:"use_time_machine,omitempty"` UniqueId string `yaml:"unique_id,omitempty"` TransformExpr string `yaml:"transform,omitempty"` - Config map[string]interface{} `yaml:",inline"` //to keep the datasource-specific configuration directives + Config map[string]interface{} `yaml:",inline"` // to keep the datasource-specific configuration directives } -var TAIL_MODE = "tail" -var CAT_MODE = "cat" -var SERVER_MODE = "server" // No difference with tail, just a bit more verbose +var ( + TAIL_MODE = "tail" + CAT_MODE = "cat" + SERVER_MODE = "server" // No difference with tail, just a bit more verbose +) const ( METRICS_NONE = iota diff --git a/pkg/acquisition/http.go b/pkg/acquisition/http.go new file mode 100644 index 00000000000..59745772b62 --- /dev/null +++ b/pkg/acquisition/http.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_http + +package acquisition + +import ( + httpacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/http" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("http", func() DataSource { return &httpacquisition.HTTPSource{} }) +} diff --git a/pkg/acquisition/modules/appsec/appsec.go b/pkg/acquisition/modules/appsec/appsec.go index a6dcffe89a2..ed6c85ea95e 100644 --- a/pkg/acquisition/modules/appsec/appsec.go +++ b/pkg/acquisition/modules/appsec/appsec.go @@ -20,6 +20,7 @@ import ( "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/appsec" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -31,6 +32,8 @@ const ( ) var DefaultAuthCacheDuration = (1 * time.Minute) +var negativeAllowlistCacheDuration = (5 * time.Minute) +var positiveAllowlistCacheDuration = (5 * time.Minute) // configuration structure of the acquis for the application security engine type AppsecSourceConfig struct { @@ -41,6 +44,7 @@ type AppsecSourceConfig struct { Path string `yaml:"path"` Routines int `yaml:"routines"` AppsecConfig string `yaml:"appsec_config"` + AppsecConfigs []string `yaml:"appsec_configs"` AppsecConfigPath string `yaml:"appsec_config_path"` AuthCacheDuration *time.Duration `yaml:"auth_cache_duration"` configuration.DataSourceCommonCfg `yaml:",inline"` @@ -48,18 +52,20 @@ type AppsecSourceConfig struct { // runtime structure of AppsecSourceConfig type AppsecSource struct { - metricsLevel int - config AppsecSourceConfig - logger *log.Entry - mux *http.ServeMux - server *http.Server - outChan chan types.Event - InChan chan appsec.ParsedRequest - AppsecRuntime *appsec.AppsecRuntimeConfig - AppsecConfigs map[string]appsec.AppsecConfig - lapiURL string - AuthCache AuthCache - AppsecRunners []AppsecRunner // one for each go-routine + metricsLevel int + config AppsecSourceConfig + logger *log.Entry + mux *http.ServeMux + server *http.Server + outChan chan types.Event + InChan chan appsec.ParsedRequest + AppsecRuntime *appsec.AppsecRuntimeConfig + AppsecConfigs map[string]appsec.AppsecConfig + lapiURL string + AuthCache AuthCache + AppsecRunners []AppsecRunner // one for each go-routine + allowlistCache allowlistCache + apiClient *apiclient.ApiClient } // Struct to handle cache of authentication @@ -68,6 +74,17 @@ type AuthCache struct { mu sync.RWMutex } +// FIXME: auth and allowlist should probably be merged to a common structure +type allowlistCache struct { + mu sync.RWMutex + allowlist map[string]allowlistCacheEntry +} + +type allowlistCacheEntry struct { + allowlisted bool + expiration time.Time +} + func NewAuthCache() AuthCache { return AuthCache{ APIKeys: make(map[string]time.Time, 0), @@ -85,9 +102,34 @@ func (ac *AuthCache) Get(apiKey string) (time.Time, bool) { ac.mu.RLock() expiration, exists := ac.APIKeys[apiKey] ac.mu.RUnlock() + return expiration, exists } +func NewAllowlistCache() allowlistCache { + return allowlistCache{ + allowlist: make(map[string]allowlistCacheEntry, 0), + mu: sync.RWMutex{}, + } +} + +func (ac *allowlistCache) Set(value string, allowlisted bool, expiration time.Time) { + ac.mu.Lock() + ac.allowlist[value] = allowlistCacheEntry{ + allowlisted: allowlisted, + expiration: expiration, + } + ac.mu.Unlock() +} + +func (ac *allowlistCache) Get(value string) (bool, time.Time, bool) { + ac.mu.RLock() + entry, exists := ac.allowlist[value] + ac.mu.RUnlock() + + return entry.allowlisted, entry.expiration, exists +} + // @tko + @sbl : we might want to get rid of that or improve it type BodyResponse struct { Action string `json:"action"` @@ -120,14 +162,19 @@ func (w *AppsecSource) UnmarshalConfig(yamlConfig []byte) error { w.config.Routines = 1 } - if w.config.AppsecConfig == "" && w.config.AppsecConfigPath == "" { + if w.config.AppsecConfig == "" && w.config.AppsecConfigPath == "" && len(w.config.AppsecConfigs) == 0 { return errors.New("appsec_config or appsec_config_path must be set") } + if (w.config.AppsecConfig != "" || w.config.AppsecConfigPath != "") && len(w.config.AppsecConfigs) != 0 { + return errors.New("appsec_config and appsec_config_path are mutually exclusive with appsec_configs") + } + if w.config.Name == "" { if w.config.ListenSocket != "" && w.config.ListenAddr == "" { w.config.Name = w.config.ListenSocket } + if w.config.ListenSocket == "" { w.config.Name = fmt.Sprintf("%s%s", w.config.ListenAddr, w.config.Path) } @@ -153,6 +200,7 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe if err != nil { return fmt.Errorf("unable to parse appsec configuration: %w", err) } + w.logger = logger w.metricsLevel = MetricsLevel w.logger.Tracef("Appsec configuration: %+v", w.config) @@ -172,6 +220,9 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe w.InChan = make(chan appsec.ParsedRequest) appsecCfg := appsec.AppsecConfig{Logger: w.logger.WithField("component", "appsec_config")} + //we keep the datasource name + appsecCfg.Name = w.config.Name + // let's load the associated appsec_config: if w.config.AppsecConfigPath != "" { err := appsecCfg.LoadByPath(w.config.AppsecConfigPath) @@ -183,10 +234,20 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe if err != nil { return fmt.Errorf("unable to load appsec_config: %w", err) } + } else if len(w.config.AppsecConfigs) > 0 { + for _, appsecConfig := range w.config.AppsecConfigs { + err := appsecCfg.Load(appsecConfig) + if err != nil { + return fmt.Errorf("unable to load appsec_config: %w", err) + } + } } else { return errors.New("no appsec_config provided") } + // Now we can set up the logger + appsecCfg.SetUpLogger() + w.AppsecRuntime, err = appsecCfg.Build() if err != nil { return fmt.Errorf("unable to build appsec_config: %w", err) @@ -211,10 +272,12 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe AppsecRuntime: &wrt, Labels: w.config.Labels, } + err := runner.Init(appsecCfg.GetDataDir()) if err != nil { return fmt.Errorf("unable to initialize runner: %w", err) } + w.AppsecRunners[nbRoutine] = runner } @@ -222,6 +285,13 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe // We donĀ“t use the wrapper provided by coraza because we want to fully control what happens when a rule match to send the information in crowdsec w.mux.HandleFunc(w.config.Path, w.appsecHandler) + + w.apiClient, err = apiclient.GetLAPIClient() + if err != nil { + return fmt.Errorf("unable to get authenticated LAPI client: %w", err) + } + w.allowlistCache = NewAllowlistCache() + return nil } @@ -243,10 +313,12 @@ func (w *AppsecSource) OneShotAcquisition(_ context.Context, _ chan types.Event, func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { w.outChan = out + t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/appsec/live") w.logger.Infof("%d appsec runner to start", len(w.AppsecRunners)) + for _, runner := range w.AppsecRunners { runner.outChan = out t.Go(func() error { @@ -254,6 +326,7 @@ func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types. return runner.Run(t) }) } + t.Go(func() error { if w.config.ListenSocket != "" { w.logger.Infof("creating unix socket %s", w.config.ListenSocket) @@ -268,10 +341,11 @@ func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types. } else { err = w.server.Serve(listener) } - if err != nil && err != http.ErrServerClosed { + if err != nil && !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("appsec server failed: %w", err) } } + return nil }) t.Go(func() error { @@ -288,6 +362,7 @@ func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types. return fmt.Errorf("appsec server failed: %w", err) } } + return nil }) <-t.Dying() @@ -297,6 +372,7 @@ func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types. w.server.Shutdown(ctx) return nil }) + return nil } @@ -334,6 +410,29 @@ func (w *AppsecSource) IsAuth(apiKey string) bool { return resp.StatusCode == http.StatusOK } +func (w *AppsecSource) isAllowlisted(ctx context.Context, value string) bool { + var err error + + allowlisted, expiration, exists := w.allowlistCache.Get(value) + if exists && !time.Now().After(expiration) { + return allowlisted + } + + allowlisted, _, err = w.apiClient.Allowlists.CheckIfAllowlisted(ctx, value) + if err != nil { + w.logger.Errorf("unable to check if %s is allowlisted: %s", value, err) + return false + } + + if allowlisted { + w.allowlistCache.Set(value, allowlisted, time.Now().Add(positiveAllowlistCacheDuration)) + } else { + w.allowlistCache.Set(value, allowlisted, time.Now().Add(negativeAllowlistCacheDuration)) + } + + return allowlisted +} + // should this be in the runner ? func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) { w.logger.Debugf("Received request from '%s' on %s", r.RemoteAddr, r.URL.Path) @@ -359,6 +458,25 @@ func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) { w.AuthCache.Set(apiKey, time.Now().Add(*w.config.AuthCacheDuration)) } + // check if the client IP is allowlisted + if w.isAllowlisted(r.Context(), clientIP) { + w.logger.Infof("%s is allowlisted by LAPI, not processing", clientIP) + statusCode, appsecResponse := w.AppsecRuntime.GenerateResponse(appsec.AppsecTempResponse{ + InBandInterrupt: false, + OutOfBandInterrupt: false, + Action: appsec.AllowRemediation, + }, w.logger) + body, err := json.Marshal(appsecResponse) + if err != nil { + w.logger.Errorf("unable to serialize response: %s", err) + rw.WriteHeader(http.StatusInternalServerError) + return + } + rw.WriteHeader(statusCode) + rw.Write(body) + return + } + // parse the request only once parsedRequest, err := appsec.NewParsedRequestFromRequest(r, w.logger) if err != nil { @@ -391,6 +509,7 @@ func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) { logger.Debugf("Response: %+v", appsecResponse) rw.WriteHeader(statusCode) + body, err := json.Marshal(appsecResponse) if err != nil { logger.Errorf("unable to serialize response: %s", err) diff --git a/pkg/acquisition/modules/appsec/appsec_hooks_test.go b/pkg/acquisition/modules/appsec/appsec_hooks_test.go index c549d2ef1d1..d87384a0189 100644 --- a/pkg/acquisition/modules/appsec/appsec_hooks_test.go +++ b/pkg/acquisition/modules/appsec/appsec_hooks_test.go @@ -341,7 +341,6 @@ func TestAppsecOnMatchHooks(t *testing.T) { } func TestAppsecPreEvalHooks(t *testing.T) { - tests := []appsecRuleTest{ { name: "Basic pre_eval hook to disable inband rule", @@ -403,7 +402,6 @@ func TestAppsecPreEvalHooks(t *testing.T) { require.Len(t, responses, 1) require.True(t, responses[0].InBandInterrupt) - }, }, { @@ -670,7 +668,6 @@ func TestAppsecPreEvalHooks(t *testing.T) { } func TestAppsecRemediationConfigHooks(t *testing.T) { - tests := []appsecRuleTest{ { name: "Basic matching rule", @@ -759,6 +756,7 @@ func TestAppsecRemediationConfigHooks(t *testing.T) { }) } } + func TestOnMatchRemediationHooks(t *testing.T) { tests := []appsecRuleTest{ { diff --git a/pkg/acquisition/modules/appsec/appsec_rules_test.go b/pkg/acquisition/modules/appsec/appsec_rules_test.go index 909f16357ed..00093c5a5ad 100644 --- a/pkg/acquisition/modules/appsec/appsec_rules_test.go +++ b/pkg/acquisition/modules/appsec/appsec_rules_test.go @@ -28,7 +28,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, @@ -59,7 +60,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"tutu"}}, @@ -84,7 +86,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, @@ -110,7 +113,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, @@ -136,7 +140,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, @@ -165,7 +170,8 @@ func TestAppsecRuleMatches(t *testing.T) { {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')"}}, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"bla"}}, @@ -192,7 +198,8 @@ func TestAppsecRuleMatches(t *testing.T) { {Filter: "IsInBand == true", Apply: []string{"SetReturnCode(418)"}}, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"bla"}}, @@ -219,7 +226,8 @@ func TestAppsecRuleMatches(t *testing.T) { {Filter: "IsInBand == true", Apply: []string{"SetRemediationByName('rule42', 'captcha')"}}, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"bla"}}, @@ -243,7 +251,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Headers: http.Header{"Cookie": []string{"foo=toto"}}, @@ -273,7 +282,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Headers: http.Header{"Cookie": []string{"foo=toto; bar=tutu"}}, @@ -303,7 +313,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Headers: http.Header{"Cookie": []string{"bar=tutu; tututata=toto"}}, @@ -333,7 +344,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Headers: http.Header{"Content-Type": []string{"multipart/form-data; boundary=boundary"}}, @@ -354,6 +366,32 @@ toto require.Len(t, events[1].Appsec.MatchedRules, 1) require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + require.Len(t, responses, 1) + require.True(t, responses[0].InBandInterrupt) + }, + }, + { + name: "Basic matching IP address", + expected_load_ok: true, + inband_native_rules: []string{ + "SecRule REMOTE_ADDR \"@ipMatch 1.2.3.4\" \"id:1,phase:1,log,deny,msg: 'block ip'\"", + }, + input_request: appsec.ParsedRequest{ + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", + Method: "GET", + URI: "/urllll", + Headers: http.Header{"Content-Type": []string{"multipart/form-data; boundary=boundary"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + + require.Equal(t, types.LOG, events[1].Type) + require.True(t, events[1].Appsec.HasInBandMatches) + require.Len(t, events[1].Appsec.MatchedRules, 1) + require.Equal(t, "block ip", events[1].Appsec.MatchedRules[0]["msg"]) + require.Len(t, responses, 1) require.True(t, responses[0].InBandInterrupt) }, @@ -381,7 +419,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/toto", }, @@ -404,7 +443,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/TOTO", }, @@ -427,7 +467,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/toto", }, @@ -451,7 +492,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/?foo=dG90bw", }, @@ -475,7 +517,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/?foo=dG90bw===", }, @@ -499,7 +542,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/?foo=toto", }, @@ -523,7 +567,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/?foo=%42%42%2F%41", }, @@ -547,7 +592,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/?foo=%20%20%42%42%2F%41%20%20", }, @@ -585,7 +631,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/foobar?something=toto&foobar=smth", }, @@ -612,7 +659,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/foobar?something=toto&foobar=smth", }, @@ -639,7 +687,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/", Body: []byte("smth=toto&foobar=other"), @@ -668,7 +717,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/", Body: []byte("smth=toto&foobar=other"), @@ -697,7 +747,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/", Headers: http.Header{"foobar": []string{"toto"}}, @@ -725,7 +776,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/", Headers: http.Header{"foobar": []string{"toto"}}, @@ -748,7 +800,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/", }, @@ -770,7 +823,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/", Proto: "HTTP/3.1", @@ -793,7 +847,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/foobar", }, @@ -815,7 +870,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/foobar?a=b", }, @@ -837,7 +893,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/", Body: []byte("foobar=42421"), diff --git a/pkg/acquisition/modules/appsec/appsec_runner.go b/pkg/acquisition/modules/appsec/appsec_runner.go index de34b62d704..d4535d3f9a2 100644 --- a/pkg/acquisition/modules/appsec/appsec_runner.go +++ b/pkg/acquisition/modules/appsec/appsec_runner.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "slices" + "strings" "time" "github.com/prometheus/client_golang/prometheus" @@ -31,23 +32,38 @@ type AppsecRunner struct { logger *log.Entry } +func (r *AppsecRunner) MergeDedupRules(collections []appsec.AppsecCollection, logger *log.Entry) string { + var rulesArr []string + dedupRules := make(map[string]struct{}) + + for _, collection := range collections { + for _, rule := range collection.Rules { + if _, ok := dedupRules[rule]; !ok { + rulesArr = append(rulesArr, rule) + dedupRules[rule] = struct{}{} + } else { + logger.Debugf("Discarding duplicate rule : %s", rule) + } + } + } + if len(rulesArr) != len(dedupRules) { + logger.Warningf("%d rules were discarded as they were duplicates", len(rulesArr)-len(dedupRules)) + } + + return strings.Join(rulesArr, "\n") +} + func (r *AppsecRunner) Init(datadir string) error { var err error fs := os.DirFS(datadir) - inBandRules := "" - outOfBandRules := "" - - for _, collection := range r.AppsecRuntime.InBandRules { - inBandRules += collection.String() - } - - for _, collection := range r.AppsecRuntime.OutOfBandRules { - outOfBandRules += collection.String() - } inBandLogger := r.logger.Dup().WithField("band", "inband") outBandLogger := r.logger.Dup().WithField("band", "outband") + //While loading rules, we dedup rules based on their content, while keeping the order + inBandRules := r.MergeDedupRules(r.AppsecRuntime.InBandRules, inBandLogger) + outOfBandRules := r.MergeDedupRules(r.AppsecRuntime.OutOfBandRules, outBandLogger) + //setting up inband engine inbandCfg := coraza.NewWAFConfig().WithDirectives(inBandRules).WithRootFS(fs).WithDebugLogger(appsec.NewCrzLogger(inBandLogger)) if !r.AppsecRuntime.Config.InbandOptions.DisableBodyInspection { @@ -74,6 +90,9 @@ func (r *AppsecRunner) Init(datadir string) error { outbandCfg = outbandCfg.WithRequestBodyInMemoryLimit(*r.AppsecRuntime.Config.OutOfBandOptions.RequestBodyInMemoryLimit) } r.AppsecOutbandEngine, err = coraza.NewWAF(outbandCfg) + if err != nil { + return fmt.Errorf("unable to initialize outband engine : %w", err) + } if r.AppsecRuntime.DisabledInBandRulesTags != nil { for _, tag := range r.AppsecRuntime.DisabledInBandRulesTags { @@ -102,10 +121,6 @@ func (r *AppsecRunner) Init(datadir string) error { r.logger.Tracef("Loaded inband rules: %+v", r.AppsecInbandEngine.GetRuleGroup().GetRules()) r.logger.Tracef("Loaded outband rules: %+v", r.AppsecOutbandEngine.GetRuleGroup().GetRules()) - if err != nil { - return fmt.Errorf("unable to initialize outband engine : %w", err) - } - return nil } @@ -135,7 +150,7 @@ func (r *AppsecRunner) processRequest(tx appsec.ExtendedTransaction, request *ap //FIXME: should we abort here ? } - request.Tx.ProcessConnection(request.RemoteAddr, 0, "", 0) + request.Tx.ProcessConnection(request.ClientIP, 0, "", 0) for k, v := range request.Args { for _, vv := range v { @@ -249,7 +264,7 @@ func (r *AppsecRunner) handleInBandInterrupt(request *appsec.ParsedRequest) { // Should the in band match trigger an overflow ? if r.AppsecRuntime.Response.SendAlert { - appsecOvlfw, err := AppsecEventGeneration(evt) + appsecOvlfw, err := AppsecEventGeneration(evt, request.HTTPRequest) if err != nil { r.logger.Errorf("unable to generate appsec event : %s", err) return @@ -293,7 +308,7 @@ func (r *AppsecRunner) handleOutBandInterrupt(request *appsec.ParsedRequest) { // Should the match trigger an overflow ? if r.AppsecRuntime.Response.SendAlert { - appsecOvlfw, err := AppsecEventGeneration(evt) + appsecOvlfw, err := AppsecEventGeneration(evt, request.HTTPRequest) if err != nil { r.logger.Errorf("unable to generate appsec event : %s", err) return @@ -363,7 +378,6 @@ func (r *AppsecRunner) handleRequest(request *appsec.ParsedRequest) { // time spent to process inband AND out of band rules globalParsingElapsed := time.Since(startGlobalParsing) AppsecGlobalParsingHistogram.With(prometheus.Labels{"source": request.RemoteAddrNormalized, "appsec_engine": request.AppsecEngine}).Observe(globalParsingElapsed.Seconds()) - } func (r *AppsecRunner) Run(t *tomb.Tomb) error { diff --git a/pkg/acquisition/modules/appsec/appsec_runner_test.go b/pkg/acquisition/modules/appsec/appsec_runner_test.go new file mode 100644 index 00000000000..d07fb153186 --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec_runner_test.go @@ -0,0 +1,153 @@ +package appsecacquisition + +import ( + "testing" + + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" +) + +func TestAppsecRuleLoad(t *testing.T) { + log.SetLevel(log.TraceLevel) + tests := []appsecRuleTest{ + { + name: "simple rule load", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + }, + afterload_asserts: func(runner AppsecRunner) { + require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 1) + }, + }, + { + name: "simple native rule load", + expected_load_ok: true, + inband_native_rules: []string{ + `Secrule REQUEST_HEADERS:Content-Type "@rx ^application/x-www-form-urlencoded" "id:100,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=URLENCODED"`, + }, + afterload_asserts: func(runner AppsecRunner) { + require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 1) + }, + }, + { + name: "simple native rule load (2)", + expected_load_ok: true, + inband_native_rules: []string{ + `Secrule REQUEST_HEADERS:Content-Type "@rx ^application/x-www-form-urlencoded" "id:100,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=URLENCODED"`, + `Secrule REQUEST_HEADERS:Content-Type "@rx ^multipart/form-data" "id:101,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=MULTIPART"`, + }, + afterload_asserts: func(runner AppsecRunner) { + require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 2) + }, + }, + { + name: "simple native rule load + dedup", + expected_load_ok: true, + inband_native_rules: []string{ + `Secrule REQUEST_HEADERS:Content-Type "@rx ^application/x-www-form-urlencoded" "id:100,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=URLENCODED"`, + `Secrule REQUEST_HEADERS:Content-Type "@rx ^multipart/form-data" "id:101,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=MULTIPART"`, + `Secrule REQUEST_HEADERS:Content-Type "@rx ^application/x-www-form-urlencoded" "id:100,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=URLENCODED"`, + }, + afterload_asserts: func(runner AppsecRunner) { + require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 2) + }, + }, + { + name: "multi simple rule load", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + }, + afterload_asserts: func(runner AppsecRunner) { + require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 2) + }, + }, + { + name: "multi simple rule load", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + }, + afterload_asserts: func(runner AppsecRunner) { + require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 2) + }, + }, + { + name: "imbricated rule load", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + + Or: []appsec_rule.CustomRule{ + { + //Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + //Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "tutu"}, + }, + { + //Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "tata"}, + }, { + //Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "titi"}, + }, + }, + }, + }, + afterload_asserts: func(runner AppsecRunner) { + require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 4) + }, + }, + { + name: "invalid inband rule", + expected_load_ok: false, + inband_native_rules: []string{ + "this_is_not_a_rule", + }, + }, + { + name: "invalid outofband rule", + expected_load_ok: false, + outofband_native_rules: []string{ + "this_is_not_a_rule", + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} diff --git a/pkg/acquisition/modules/appsec/appsec_test.go b/pkg/acquisition/modules/appsec/appsec_test.go index d2079b43726..c0af1002f49 100644 --- a/pkg/acquisition/modules/appsec/appsec_test.go +++ b/pkg/acquisition/modules/appsec/appsec_test.go @@ -18,6 +18,8 @@ type appsecRuleTest struct { expected_load_ok bool inband_rules []appsec_rule.CustomRule outofband_rules []appsec_rule.CustomRule + inband_native_rules []string + outofband_native_rules []string on_load []appsec.Hook pre_eval []appsec.Hook post_eval []appsec.Hook @@ -28,6 +30,7 @@ type appsecRuleTest struct { DefaultRemediation string DefaultPassAction string input_request appsec.ParsedRequest + afterload_asserts func(runner AppsecRunner) output_asserts func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) } @@ -53,6 +56,8 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) { inbandRules = append(inbandRules, strRule) } + inbandRules = append(inbandRules, test.inband_native_rules...) + outofbandRules = append(outofbandRules, test.outofband_native_rules...) for ridx, rule := range test.outofband_rules { strRule, _, err := rule.Convert(appsec_rule.ModsecurityRuleType, rule.Name) if err != nil { @@ -61,7 +66,8 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) { outofbandRules = append(outofbandRules, strRule) } - appsecCfg := appsec.AppsecConfig{Logger: logger, + appsecCfg := appsec.AppsecConfig{ + Logger: logger, OnLoad: test.on_load, PreEval: test.pre_eval, PostEval: test.post_eval, @@ -70,7 +76,8 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) { UserBlockedHTTPCode: test.UserBlockedHTTPCode, UserPassedHTTPCode: test.UserPassedHTTPCode, DefaultRemediation: test.DefaultRemediation, - DefaultPassAction: test.DefaultPassAction} + DefaultPassAction: test.DefaultPassAction, + } AppsecRuntime, err := appsecCfg.Build() if err != nil { t.Fatalf("unable to build appsec runtime : %s", err) @@ -91,8 +98,21 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) { } err = runner.Init("/tmp/") if err != nil { + if !test.expected_load_ok { + return + } t.Fatalf("unable to initialize runner : %s", err) } + if !test.expected_load_ok { + t.Fatalf("expected load to fail but it didn't") + } + + if test.afterload_asserts != nil { + //afterload asserts are just to evaluate the state of the runner after the rules have been loaded + //if it's present, don't try to process requests + test.afterload_asserts(runner) + return + } input := test.input_request input.ResponseChannel = make(chan appsec.AppsecTempResponse) diff --git a/pkg/acquisition/modules/appsec/bodyprocessors/raw.go b/pkg/acquisition/modules/appsec/bodyprocessors/raw.go index e2e23eb57ae..aa467ecf048 100644 --- a/pkg/acquisition/modules/appsec/bodyprocessors/raw.go +++ b/pkg/acquisition/modules/appsec/bodyprocessors/raw.go @@ -9,8 +9,7 @@ import ( "github.com/crowdsecurity/coraza/v3/experimental/plugins/plugintypes" ) -type rawBodyProcessor struct { -} +type rawBodyProcessor struct{} type setterInterface interface { Set(string) @@ -33,9 +32,7 @@ func (*rawBodyProcessor) ProcessResponse(reader io.Reader, v plugintypes.Transac return nil } -var ( - _ plugintypes.BodyProcessor = &rawBodyProcessor{} -) +var _ plugintypes.BodyProcessor = &rawBodyProcessor{} //nolint:gochecknoinits //Coraza recommends to use init() for registering plugins func init() { diff --git a/pkg/acquisition/modules/appsec/utils.go b/pkg/acquisition/modules/appsec/utils.go index 4fb1a979d14..8995b305680 100644 --- a/pkg/acquisition/modules/appsec/utils.go +++ b/pkg/acquisition/modules/appsec/utils.go @@ -1,10 +1,10 @@ package appsecacquisition import ( + "errors" "fmt" "net" - "slices" - "strconv" + "net/http" "time" "github.com/oschwald/geoip2-golang" @@ -22,29 +22,44 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var appsecMetaKeys = []string{ - "id", - "name", - "method", - "uri", - "matched_zones", - "msg", -} +func AppsecEventGenerationGeoIPEnrich(src *models.Source) error { -func appendMeta(meta models.Meta, key string, value string) models.Meta { - if value == "" { - return meta + if src == nil || src.Scope == nil || *src.Scope != types.Ip { + return errors.New("source is nil or not an IP") } - meta = append(meta, &models.MetaItems0{ - Key: key, - Value: value, - }) + //GeoIP enrich + asndata, err := exprhelpers.GeoIPASNEnrich(src.IP) + + if err != nil { + return err + } else if asndata != nil { + record := asndata.(*geoip2.ASN) + src.AsName = record.AutonomousSystemOrganization + src.AsNumber = fmt.Sprintf("%d", record.AutonomousSystemNumber) + } - return meta + cityData, err := exprhelpers.GeoIPEnrich(src.IP) + if err != nil { + return err + } else if cityData != nil { + record := cityData.(*geoip2.City) + src.Cn = record.Country.IsoCode + src.Latitude = float32(record.Location.Latitude) + src.Longitude = float32(record.Location.Longitude) + } + + rangeData, err := exprhelpers.GeoIPRangeEnrich(src.IP) + if err != nil { + return err + } else if rangeData != nil { + record := rangeData.(*net.IPNet) + src.Range = record.String() + } + return nil } -func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) { +func AppsecEventGeneration(inEvt types.Event, request *http.Request) (*types.Event, error) { // if the request didnd't trigger inband rules, we don't want to generate an event to LAPI/CAPI if !inEvt.Appsec.HasInBandMatches { return nil, nil @@ -60,34 +75,12 @@ func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) { Scope: ptr.Of(types.Ip), } - asndata, err := exprhelpers.GeoIPASNEnrich(sourceIP) - - if err != nil { - log.Errorf("Unable to enrich ip '%s' for ASN: %s", sourceIP, err) - } else if asndata != nil { - record := asndata.(*geoip2.ASN) - source.AsName = record.AutonomousSystemOrganization - source.AsNumber = fmt.Sprintf("%d", record.AutonomousSystemNumber) - } - - cityData, err := exprhelpers.GeoIPEnrich(sourceIP) - if err != nil { - log.Errorf("Unable to enrich ip '%s' for geo data: %s", sourceIP, err) - } else if cityData != nil { - record := cityData.(*geoip2.City) - source.Cn = record.Country.IsoCode - source.Latitude = float32(record.Location.Latitude) - source.Longitude = float32(record.Location.Longitude) - } - - rangeData, err := exprhelpers.GeoIPRangeEnrich(sourceIP) - if err != nil { - log.Errorf("Unable to enrich ip '%s' for range: %s", sourceIP, err) - } else if rangeData != nil { - record := rangeData.(*net.IPNet) - source.Range = record.String() + // Enrich source with GeoIP data + if err := AppsecEventGenerationGeoIPEnrich(&source); err != nil { + log.Errorf("unable to enrich source with GeoIP data : %s", err) } + // Build overflow evt.Overflow.Sources = make(map[string]models.Source) evt.Overflow.Sources[sourceIP] = source @@ -95,83 +88,11 @@ func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) { alert.Capacity = ptr.Of(int32(1)) alert.Events = make([]*models.Event, len(evt.Appsec.GetRuleIDs())) - now := ptr.Of(time.Now().UTC().Format(time.RFC3339)) - - tmpAppsecContext := make(map[string][]string) - - for _, matched_rule := range inEvt.Appsec.MatchedRules { - evtRule := models.Event{} - - evtRule.Timestamp = now - - evtRule.Meta = make(models.Meta, 0) - - for _, key := range appsecMetaKeys { - if tmpAppsecContext[key] == nil { - tmpAppsecContext[key] = make([]string, 0) - } - - switch value := matched_rule[key].(type) { - case string: - evtRule.Meta = appendMeta(evtRule.Meta, key, value) - - if value != "" && !slices.Contains(tmpAppsecContext[key], value) { - tmpAppsecContext[key] = append(tmpAppsecContext[key], value) - } - case int: - val := strconv.Itoa(value) - evtRule.Meta = appendMeta(evtRule.Meta, key, val) - - if val != "" && !slices.Contains(tmpAppsecContext[key], val) { - tmpAppsecContext[key] = append(tmpAppsecContext[key], val) - } - case []string: - for _, v := range value { - evtRule.Meta = appendMeta(evtRule.Meta, key, v) - - if v != "" && !slices.Contains(tmpAppsecContext[key], v) { - tmpAppsecContext[key] = append(tmpAppsecContext[key], v) - } - } - case []int: - for _, v := range value { - val := strconv.Itoa(v) - evtRule.Meta = appendMeta(evtRule.Meta, key, val) - - if val != "" && !slices.Contains(tmpAppsecContext[key], val) { - tmpAppsecContext[key] = append(tmpAppsecContext[key], val) - } - } - default: - val := fmt.Sprintf("%v", value) - evtRule.Meta = appendMeta(evtRule.Meta, key, val) - - if val != "" && !slices.Contains(tmpAppsecContext[key], val) { - tmpAppsecContext[key] = append(tmpAppsecContext[key], val) - } - } - } - - alert.Events = append(alert.Events, &evtRule) - } - - metas := make([]*models.MetaItems0, 0) - - for key, values := range tmpAppsecContext { - if len(values) == 0 { - continue - } - - valueStr, err := alertcontext.TruncateContext(values, alertcontext.MaxContextValueLen) - if err != nil { - log.Warning(err.Error()) + metas, errors := alertcontext.AppsecEventToContext(inEvt.Appsec, request) + if len(errors) > 0 { + for _, err := range errors { + log.Errorf("failed to generate appsec context: %s", err) } - - meta := models.MetaItems0{ - Key: key, - Value: valueStr, - } - metas = append(metas, &meta) } alert.Meta = metas @@ -195,10 +116,7 @@ func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) { } func EventFromRequest(r *appsec.ParsedRequest, labels map[string]string) (types.Event, error) { - evt := types.Event{} - // we might want to change this based on in-band vs out-of-band ? - evt.Type = types.LOG - evt.ExpectMode = types.LIVE + evt := types.MakeEvent(false, types.LOG, true) // def needs fixing evt.Stage = "s00-raw" evt.Parsed = map[string]string{ diff --git a/pkg/acquisition/modules/cloudwatch/cloudwatch.go b/pkg/acquisition/modules/cloudwatch/cloudwatch.go index 2df70b3312b..ba267c9050b 100644 --- a/pkg/acquisition/modules/cloudwatch/cloudwatch.go +++ b/pkg/acquisition/modules/cloudwatch/cloudwatch.go @@ -710,7 +710,7 @@ func (cw *CloudwatchSource) CatLogStream(ctx context.Context, cfg *LogStreamTail func cwLogToEvent(log *cloudwatchlogs.OutputLogEvent, cfg *LogStreamTailConfig) (types.Event, error) { l := types.Line{} - evt := types.Event{} + evt := types.MakeEvent(cfg.ExpectMode == types.TIMEMACHINE, types.LOG, true) if log.Message == nil { return evt, errors.New("nil message") } @@ -726,9 +726,6 @@ func cwLogToEvent(log *cloudwatchlogs.OutputLogEvent, cfg *LogStreamTailConfig) l.Process = true l.Module = "cloudwatch" evt.Line = l - evt.Process = true - evt.Type = types.LOG - evt.ExpectMode = cfg.ExpectMode cfg.logger.Debugf("returned event labels : %+v", evt.Line.Labels) return evt, nil } diff --git a/pkg/acquisition/modules/docker/docker.go b/pkg/acquisition/modules/docker/docker.go index 2f79d4dcee6..b27255ec13f 100644 --- a/pkg/acquisition/modules/docker/docker.go +++ b/pkg/acquisition/modules/docker/docker.go @@ -334,7 +334,10 @@ func (d *DockerSource) OneShotAcquisition(ctx context.Context, out chan types.Ev if d.metricsLevel != configuration.METRICS_NONE { linesRead.With(prometheus.Labels{"source": containerConfig.Name}).Inc() } - evt := types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} + evt := types.MakeEvent(true, types.LOG, true) + evt.Line = l + evt.Process = true + evt.Type = types.LOG out <- evt d.logger.Debugf("Sent line to parsing: %+v", evt.Line.Raw) } @@ -579,12 +582,8 @@ func (d *DockerSource) TailDocker(ctx context.Context, container *ContainerConfi l.Src = container.Name l.Process = true l.Module = d.GetName() - var evt types.Event - if !d.Config.UseTimeMachine { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} - } else { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} - } + evt := types.MakeEvent(d.Config.UseTimeMachine, types.LOG, true) + evt.Line = l linesRead.With(prometheus.Labels{"source": container.Name}).Inc() outChan <- evt d.logger.Debugf("Sent line to parsing: %+v", evt.Line.Raw) diff --git a/pkg/acquisition/modules/file/file.go b/pkg/acquisition/modules/file/file.go index f752d04aada..9f439b0c82e 100644 --- a/pkg/acquisition/modules/file/file.go +++ b/pkg/acquisition/modules/file/file.go @@ -621,11 +621,9 @@ func (f *FileSource) tailFile(out chan types.Event, t *tomb.Tomb, tail *tail.Tai // we're tailing, it must be real time logs logger.Debugf("pushing %+v", l) - expectMode := types.LIVE - if f.config.UseTimeMachine { - expectMode = types.TIMEMACHINE - } - out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: expectMode} + evt := types.MakeEvent(f.config.UseTimeMachine, types.LOG, true) + evt.Line = l + out <- evt } } } @@ -684,7 +682,7 @@ func (f *FileSource) readFile(filename string, out chan types.Event, t *tomb.Tom linesRead.With(prometheus.Labels{"source": filename}).Inc() // we're reading logs at once, it must be time-machine buckets - out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} + out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE, Unmarshaled: make(map[string]interface{})} } } diff --git a/pkg/acquisition/modules/http/http.go b/pkg/acquisition/modules/http/http.go new file mode 100644 index 00000000000..6bb8228f32c --- /dev/null +++ b/pkg/acquisition/modules/http/http.go @@ -0,0 +1,414 @@ +package httpacquisition + +import ( + "compress/gzip" + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "time" + + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + + "gopkg.in/tomb.v2" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/trace" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +var dataSourceName = "http" + +var linesRead = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_httpsource_hits_total", + Help: "Total lines that were read from http source", + }, + []string{"path", "src"}) + +type HttpConfiguration struct { + //IPFilter []string `yaml:"ip_filter"` + //ChunkSize *int64 `yaml:"chunk_size"` + ListenAddr string `yaml:"listen_addr"` + Path string `yaml:"path"` + AuthType string `yaml:"auth_type"` + BasicAuth *BasicAuthConfig `yaml:"basic_auth"` + Headers *map[string]string `yaml:"headers"` + TLS *TLSConfig `yaml:"tls"` + CustomStatusCode *int `yaml:"custom_status_code"` + CustomHeaders *map[string]string `yaml:"custom_headers"` + MaxBodySize *int64 `yaml:"max_body_size"` + Timeout *time.Duration `yaml:"timeout"` + configuration.DataSourceCommonCfg `yaml:",inline"` +} + +type BasicAuthConfig struct { + Username string `yaml:"username"` + Password string `yaml:"password"` +} + +type TLSConfig struct { + InsecureSkipVerify bool `yaml:"insecure_skip_verify"` + ServerCert string `yaml:"server_cert"` + ServerKey string `yaml:"server_key"` + CaCert string `yaml:"ca_cert"` +} + +type HTTPSource struct { + metricsLevel int + Config HttpConfiguration + logger *log.Entry + Server *http.Server +} + +func (h *HTTPSource) GetUuid() string { + return h.Config.UniqueId +} + +func (h *HTTPSource) UnmarshalConfig(yamlConfig []byte) error { + h.Config = HttpConfiguration{} + err := yaml.Unmarshal(yamlConfig, &h.Config) + if err != nil { + return fmt.Errorf("cannot parse %s datasource configuration: %w", dataSourceName, err) + } + + if h.Config.Mode == "" { + h.Config.Mode = configuration.TAIL_MODE + } + + return nil +} + +func (hc *HttpConfiguration) Validate() error { + if hc.ListenAddr == "" { + return errors.New("listen_addr is required") + } + + if hc.Path == "" { + hc.Path = "/" + } + if hc.Path[0] != '/' { + return errors.New("path must start with /") + } + + switch hc.AuthType { + case "basic_auth": + baseErr := "basic_auth is selected, but" + if hc.BasicAuth == nil { + return errors.New(baseErr + " basic_auth is not provided") + } + if hc.BasicAuth.Username == "" { + return errors.New(baseErr + " username is not provided") + } + if hc.BasicAuth.Password == "" { + return errors.New(baseErr + " password is not provided") + } + case "headers": + if hc.Headers == nil { + return errors.New("headers is selected, but headers is not provided") + } + case "mtls": + if hc.TLS == nil || hc.TLS.CaCert == "" { + return errors.New("mtls is selected, but ca_cert is not provided") + } + default: + return errors.New("invalid auth_type: must be one of basic_auth, headers, mtls") + } + + if hc.TLS != nil { + if hc.TLS.ServerCert == "" { + return errors.New("server_cert is required") + } + if hc.TLS.ServerKey == "" { + return errors.New("server_key is required") + } + } + + if hc.MaxBodySize != nil && *hc.MaxBodySize <= 0 { + return errors.New("max_body_size must be positive") + } + + /* + if hc.ChunkSize != nil && *hc.ChunkSize <= 0 { + return errors.New("chunk_size must be positive") + } + */ + + if hc.CustomStatusCode != nil { + statusText := http.StatusText(*hc.CustomStatusCode) + if statusText == "" { + return errors.New("invalid HTTP status code") + } + } + + return nil +} + +func (h *HTTPSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { + h.logger = logger + h.metricsLevel = MetricsLevel + err := h.UnmarshalConfig(yamlConfig) + if err != nil { + return err + } + + if err := h.Config.Validate(); err != nil { + return fmt.Errorf("invalid configuration: %w", err) + } + + return nil +} + +func (h *HTTPSource) ConfigureByDSN(string, map[string]string, *log.Entry, string) error { + return fmt.Errorf("%s datasource does not support command-line acquisition", dataSourceName) +} + +func (h *HTTPSource) GetMode() string { + return h.Config.Mode +} + +func (h *HTTPSource) GetName() string { + return dataSourceName +} + +func (h *HTTPSource) OneShotAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { + return fmt.Errorf("%s datasource does not support one-shot acquisition", dataSourceName) +} + +func (h *HTTPSource) CanRun() error { + return nil +} + +func (h *HTTPSource) GetMetrics() []prometheus.Collector { + return []prometheus.Collector{linesRead} +} + +func (h *HTTPSource) GetAggregMetrics() []prometheus.Collector { + return []prometheus.Collector{linesRead} +} + +func (h *HTTPSource) Dump() interface{} { + return h +} + +func (hc *HttpConfiguration) NewTLSConfig() (*tls.Config, error) { + tlsConfig := tls.Config{ + InsecureSkipVerify: hc.TLS.InsecureSkipVerify, + } + + if hc.TLS.ServerCert != "" && hc.TLS.ServerKey != "" { + cert, err := tls.LoadX509KeyPair(hc.TLS.ServerCert, hc.TLS.ServerKey) + if err != nil { + return nil, fmt.Errorf("failed to load server cert/key: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + if hc.AuthType == "mtls" && hc.TLS.CaCert != "" { + caCert, err := os.ReadFile(hc.TLS.CaCert) + if err != nil { + return nil, fmt.Errorf("failed to read ca cert: %w", err) + } + + caCertPool, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("failed to load system cert pool: %w", err) + } + + if caCertPool == nil { + caCertPool = x509.NewCertPool() + } + caCertPool.AppendCertsFromPEM(caCert) + tlsConfig.ClientCAs = caCertPool + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + + return &tlsConfig, nil +} + +func authorizeRequest(r *http.Request, hc *HttpConfiguration) error { + if hc.AuthType == "basic_auth" { + username, password, ok := r.BasicAuth() + if !ok { + return errors.New("missing basic auth") + } + if username != hc.BasicAuth.Username || password != hc.BasicAuth.Password { + return errors.New("invalid basic auth") + } + } + if hc.AuthType == "headers" { + for key, value := range *hc.Headers { + if r.Header.Get(key) != value { + return errors.New("invalid headers") + } + } + } + return nil +} + +func (h *HTTPSource) processRequest(w http.ResponseWriter, r *http.Request, hc *HttpConfiguration, out chan types.Event) error { + if hc.MaxBodySize != nil && r.ContentLength > *hc.MaxBodySize { + w.WriteHeader(http.StatusRequestEntityTooLarge) + return fmt.Errorf("body size exceeds max body size: %d > %d", r.ContentLength, *hc.MaxBodySize) + } + + srcHost, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return err + } + + defer r.Body.Close() + + reader := r.Body + + if r.Header.Get("Content-Encoding") == "gzip" { + reader, err = gzip.NewReader(r.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer reader.Close() + } + + decoder := json.NewDecoder(reader) + for { + var message json.RawMessage + + if err := decoder.Decode(&message); err != nil { + if err == io.EOF { + break + } + w.WriteHeader(http.StatusBadRequest) + return fmt.Errorf("failed to decode: %w", err) + } + + line := types.Line{ + Raw: string(message), + Src: srcHost, + Time: time.Now().UTC(), + Labels: hc.Labels, + Process: true, + Module: h.GetName(), + } + + if h.metricsLevel == configuration.METRICS_AGGREGATE { + line.Src = hc.Path + } + + evt := types.MakeEvent(h.Config.UseTimeMachine, types.LOG, true) + evt.Line = line + + if h.metricsLevel == configuration.METRICS_AGGREGATE { + linesRead.With(prometheus.Labels{"path": hc.Path, "src": ""}).Inc() + } else if h.metricsLevel == configuration.METRICS_FULL { + linesRead.With(prometheus.Labels{"path": hc.Path, "src": srcHost}).Inc() + } + + h.logger.Tracef("line to send: %+v", line) + out <- evt + } + + return nil +} + +func (h *HTTPSource) RunServer(out chan types.Event, t *tomb.Tomb) error { + mux := http.NewServeMux() + mux.HandleFunc(h.Config.Path, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + h.logger.Errorf("method not allowed: %s", r.Method) + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + if err := authorizeRequest(r, &h.Config); err != nil { + h.logger.Errorf("failed to authorize request from '%s': %s", r.RemoteAddr, err) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + err := h.processRequest(w, r, &h.Config, out) + if err != nil { + h.logger.Errorf("failed to process request from '%s': %s", r.RemoteAddr, err) + return + } + + if h.Config.CustomHeaders != nil { + for key, value := range *h.Config.CustomHeaders { + w.Header().Set(key, value) + } + } + if h.Config.CustomStatusCode != nil { + w.WriteHeader(*h.Config.CustomStatusCode) + } else { + w.WriteHeader(http.StatusOK) + } + + w.Write([]byte("OK")) + }) + + h.Server = &http.Server{ + Addr: h.Config.ListenAddr, + Handler: mux, + } + + if h.Config.Timeout != nil { + h.Server.ReadTimeout = *h.Config.Timeout + } + + if h.Config.TLS != nil { + tlsConfig, err := h.Config.NewTLSConfig() + if err != nil { + return fmt.Errorf("failed to create tls config: %w", err) + } + h.logger.Tracef("tls config: %+v", tlsConfig) + h.Server.TLSConfig = tlsConfig + } + + t.Go(func() error { + defer trace.CatchPanic("crowdsec/acquis/http/server") + if h.Config.TLS != nil { + h.logger.Infof("start https server on %s", h.Config.ListenAddr) + err := h.Server.ListenAndServeTLS(h.Config.TLS.ServerCert, h.Config.TLS.ServerKey) + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("https server failed: %w", err) + } + } else { + h.logger.Infof("start http server on %s", h.Config.ListenAddr) + err := h.Server.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("http server failed: %w", err) + } + } + return nil + }) + + //nolint //fp + for { + select { + case <-t.Dying(): + h.logger.Infof("%s datasource stopping", dataSourceName) + if err := h.Server.Close(); err != nil { + return fmt.Errorf("while closing %s server: %w", dataSourceName, err) + } + return nil + } + } +} + +func (h *HTTPSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { + h.logger.Debugf("start http server on %s", h.Config.ListenAddr) + + t.Go(func() error { + defer trace.CatchPanic("crowdsec/acquis/http/live") + return h.RunServer(out, t) + }) + + return nil +} diff --git a/pkg/acquisition/modules/http/http_test.go b/pkg/acquisition/modules/http/http_test.go new file mode 100644 index 00000000000..4d99134419f --- /dev/null +++ b/pkg/acquisition/modules/http/http_test.go @@ -0,0 +1,784 @@ +package httpacquisition + +import ( + "compress/gzip" + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/crowdsecurity/crowdsec/pkg/types" + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/tomb.v2" +) + +const ( + testHTTPServerAddr = "http://127.0.0.1:8080" + testHTTPServerAddrTLS = "https://127.0.0.1:8080" +) + +func TestConfigure(t *testing.T) { + tests := []struct { + config string + expectedErr string + }{ + { + config: ` +foobar: bla`, + expectedErr: "invalid configuration: listen_addr is required", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: wrongpath`, + expectedErr: "invalid configuration: path must start with /", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: basic_auth`, + expectedErr: "invalid configuration: basic_auth is selected, but basic_auth is not provided", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers`, + expectedErr: "invalid configuration: headers is selected, but headers is not provided", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: basic_auth +basic_auth: + username: 132`, + expectedErr: "invalid configuration: basic_auth is selected, but password is not provided", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: basic_auth +basic_auth: + password: 132`, + expectedErr: "invalid configuration: basic_auth is selected, but username is not provided", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers:`, + expectedErr: "invalid configuration: headers is selected, but headers is not provided", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: toto`, + expectedErr: "invalid configuration: invalid auth_type: must be one of basic_auth, headers, mtls", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: value +tls: + server_key: key`, + expectedErr: "invalid configuration: server_cert is required", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: value +tls: + server_cert: cert`, + expectedErr: "invalid configuration: server_key is required", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: mtls +tls: + server_cert: cert + server_key: key`, + expectedErr: "invalid configuration: mtls is selected, but ca_cert is not provided", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: value +max_body_size: 0`, + expectedErr: "invalid configuration: max_body_size must be positive", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: value +timeout: toto`, + expectedErr: "cannot parse http datasource configuration: yaml: unmarshal errors:\n line 8: cannot unmarshal !!str `toto` into time.Duration", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: value +custom_status_code: 999`, + expectedErr: "invalid configuration: invalid HTTP status code", + }, + } + + subLogger := log.WithFields(log.Fields{ + "type": "http", + }) + + for _, test := range tests { + h := HTTPSource{} + err := h.Configure([]byte(test.config), subLogger, 0) + cstest.AssertErrorContains(t, err, test.expectedErr) + } +} + +func TestGetUuid(t *testing.T) { + h := HTTPSource{} + h.Config.UniqueId = "test" + assert.Equal(t, "test", h.GetUuid()) +} + +func TestUnmarshalConfig(t *testing.T) { + h := HTTPSource{} + err := h.UnmarshalConfig([]byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: 15 + auth_type: headers`)) + cstest.AssertErrorMessage(t, err, "cannot parse http datasource configuration: yaml: line 4: found a tab character that violates indentation") +} + +func TestConfigureByDSN(t *testing.T) { + h := HTTPSource{} + err := h.ConfigureByDSN("http://localhost:8080/test", map[string]string{}, log.WithFields(log.Fields{ + "type": "http", + }), "test") + cstest.AssertErrorMessage( + t, + err, + "http datasource does not support command-line acquisition", + ) +} + +func TestGetMode(t *testing.T) { + h := HTTPSource{} + h.Config.Mode = "test" + assert.Equal(t, "test", h.GetMode()) +} + +func TestGetName(t *testing.T) { + h := HTTPSource{} + assert.Equal(t, "http", h.GetName()) +} + +func SetupAndRunHTTPSource(t *testing.T, h *HTTPSource, config []byte, metricLevel int) (chan types.Event, *tomb.Tomb) { + ctx := context.Background() + subLogger := log.WithFields(log.Fields{ + "type": "http", + }) + err := h.Configure(config, subLogger, metricLevel) + require.NoError(t, err) + tomb := tomb.Tomb{} + out := make(chan types.Event) + err = h.StreamingAcquisition(ctx, out, &tomb) + require.NoError(t, err) + + for _, metric := range h.GetMetrics() { + prometheus.Register(metric) + } + + return out, &tomb +} + +func TestStreamingAcquisitionWrongHTTPMethod(t *testing.T) { + h := &HTTPSource{} + _, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: basic_auth +basic_auth: + username: test + password: test`), 0) + + time.Sleep(1 * time.Second) + + res, err := http.Get(fmt.Sprintf("%s/test", testHTTPServerAddr)) + require.NoError(t, err) + assert.Equal(t, http.StatusMethodNotAllowed, res.StatusCode) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionUnknownPath(t *testing.T) { + h := &HTTPSource{} + _, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: basic_auth +basic_auth: + username: test + password: test`), 0) + + time.Sleep(1 * time.Second) + + res, err := http.Get(fmt.Sprintf("%s/unknown", testHTTPServerAddr)) + require.NoError(t, err) + assert.Equal(t, http.StatusNotFound, res.StatusCode) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionBasicAuth(t *testing.T) { + h := &HTTPSource{} + _, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: basic_auth +basic_auth: + username: test + password: test`), 0) + + time.Sleep(1 * time.Second) + + client := &http.Client{} + + resp, err := http.Post(fmt.Sprintf("%s/test", testHTTPServerAddr), "application/json", strings.NewReader("test")) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test")) + require.NoError(t, err) + req.SetBasicAuth("test", "WrongPassword") + + resp, err = client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionBadHeaders(t *testing.T) { + h := &HTTPSource{} + _, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test`), 0) + + time.Sleep(1 * time.Second) + + client := &http.Client{} + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test")) + require.NoError(t, err) + + req.Header.Add("Key", "wrong") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionMaxBodySize(t *testing.T) { + h := &HTTPSource{} + _, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test +max_body_size: 5`), 0) + + time.Sleep(1 * time.Second) + + client := &http.Client{} + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("testtest")) + require.NoError(t, err) + + req.Header.Add("Key", "test") + resp, err := client.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusRequestEntityTooLarge, resp.StatusCode) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionSuccess(t *testing.T) { + h := &HTTPSource{} + out, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test`), 2) + + time.Sleep(1 * time.Second) + rawEvt := `{"test": "test"}` + + errChan := make(chan error) + go assertEvents(out, []string{rawEvt}, errChan) + + client := &http.Client{} + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt)) + require.NoError(t, err) + + req.Header.Add("Key", "test") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + err = <-errChan + require.NoError(t, err) + + assertMetrics(t, h.GetMetrics(), 1) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionCustomStatusCodeAndCustomHeaders(t *testing.T) { + h := &HTTPSource{} + out, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test +custom_status_code: 201 +custom_headers: + success: true`), 2) + + time.Sleep(1 * time.Second) + + rawEvt := `{"test": "test"}` + errChan := make(chan error) + go assertEvents(out, []string{rawEvt}, errChan) + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt)) + require.NoError(t, err) + + req.Header.Add("Key", "test") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusCreated, resp.StatusCode) + assert.Equal(t, "true", resp.Header.Get("Success")) + + err = <-errChan + require.NoError(t, err) + + assertMetrics(t, h.GetMetrics(), 1) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +type slowReader struct { + delay time.Duration + body []byte + index int +} + +func (sr *slowReader) Read(p []byte) (int, error) { + if sr.index >= len(sr.body) { + return 0, io.EOF + } + time.Sleep(sr.delay) // Simulate a delay in reading + n := copy(p, sr.body[sr.index:]) + sr.index += n + return n, nil +} + +func assertEvents(out chan types.Event, expected []string, errChan chan error) { + readLines := []types.Event{} + + for i := 0; i < len(expected); i++ { + select { + case event := <-out: + readLines = append(readLines, event) + case <-time.After(2 * time.Second): + errChan <- errors.New("timeout waiting for event") + return + } + } + + if len(readLines) != len(expected) { + errChan <- fmt.Errorf("expected %d lines, got %d", len(expected), len(readLines)) + return + } + + for i, evt := range readLines { + if evt.Line.Raw != expected[i] { + errChan <- fmt.Errorf(`expected %s, got '%+v'`, expected, evt.Line.Raw) + return + } + if evt.Line.Src != "127.0.0.1" { + errChan <- fmt.Errorf("expected '127.0.0.1', got '%s'", evt.Line.Src) + return + } + if evt.Line.Module != "http" { + errChan <- fmt.Errorf("expected 'http', got '%s'", evt.Line.Module) + return + } + } + errChan <- nil +} + +func TestStreamingAcquisitionTimeout(t *testing.T) { + h := &HTTPSource{} + _, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test +timeout: 1s`), 0) + + time.Sleep(1 * time.Second) + + slow := &slowReader{ + delay: 2 * time.Second, + body: []byte(`{"test": "delayed_payload"}`), + } + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), slow) + require.NoError(t, err) + + req.Header.Add("Key", "test") + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionTLSHTTPRequest(t *testing.T) { + h := &HTTPSource{} + _, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +auth_type: mtls +path: /test +tls: + server_cert: testdata/server.crt + server_key: testdata/server.key + ca_cert: testdata/ca.crt`), 0) + + time.Sleep(1 * time.Second) + + resp, err := http.Post(fmt.Sprintf("%s/test", testHTTPServerAddr), "application/json", strings.NewReader("test")) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionTLSWithHeadersAuthSuccess(t *testing.T) { + h := &HTTPSource{} + out, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test +tls: + server_cert: testdata/server.crt + server_key: testdata/server.key +`), 0) + + time.Sleep(1 * time.Second) + + caCert, err := os.ReadFile("testdata/server.crt") + require.NoError(t, err) + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + tlsConfig := &tls.Config{ + RootCAs: caCertPool, + } + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + } + + rawEvt := `{"test": "test"}` + errChan := make(chan error) + go assertEvents(out, []string{rawEvt}, errChan) + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt)) + require.NoError(t, err) + + req.Header.Add("Key", "test") + resp, err := client.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + err = <-errChan + require.NoError(t, err) + + assertMetrics(t, h.GetMetrics(), 0) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionMTLS(t *testing.T) { + h := &HTTPSource{} + out, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: mtls +tls: + server_cert: testdata/server.crt + server_key: testdata/server.key + ca_cert: testdata/ca.crt`), 0) + + time.Sleep(1 * time.Second) + + // init client cert + cert, err := tls.LoadX509KeyPair("testdata/client.crt", "testdata/client.key") + require.NoError(t, err) + + caCert, err := os.ReadFile("testdata/ca.crt") + require.NoError(t, err) + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, + } + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + } + + rawEvt := `{"test": "test"}` + errChan := make(chan error) + go assertEvents(out, []string{rawEvt}, errChan) + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt)) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + err = <-errChan + require.NoError(t, err) + + assertMetrics(t, h.GetMetrics(), 0) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionGzipData(t *testing.T) { + h := &HTTPSource{} + out, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test`), 2) + + time.Sleep(1 * time.Second) + + rawEvt := `{"test": "test"}` + errChan := make(chan error) + go assertEvents(out, []string{rawEvt, rawEvt}, errChan) + + var b strings.Builder + gz := gzip.NewWriter(&b) + + _, err := gz.Write([]byte(rawEvt)) + require.NoError(t, err) + + _, err = gz.Write([]byte(rawEvt)) + require.NoError(t, err) + + err = gz.Close() + require.NoError(t, err) + + // send gzipped compressed data + client := &http.Client{} + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(b.String())) + require.NoError(t, err) + + req.Header.Add("Key", "test") + req.Header.Add("Content-Encoding", "gzip") + req.Header.Add("Content-Type", "application/json") + + resp, err := client.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + err = <-errChan + require.NoError(t, err) + + assertMetrics(t, h.GetMetrics(), 2) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionNDJson(t *testing.T) { + h := &HTTPSource{} + out, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test`), 2) + + time.Sleep(1 * time.Second) + rawEvt := `{"test": "test"}` + + errChan := make(chan error) + go assertEvents(out, []string{rawEvt, rawEvt}, errChan) + + client := &http.Client{} + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(fmt.Sprintf("%s\n%s\n", rawEvt, rawEvt))) + + require.NoError(t, err) + + req.Header.Add("Key", "test") + req.Header.Add("Content-Type", "application/x-ndjson") + + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + err = <-errChan + require.NoError(t, err) + + assertMetrics(t, h.GetMetrics(), 2) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func assertMetrics(t *testing.T, metrics []prometheus.Collector, expected int) { + promMetrics, err := prometheus.DefaultGatherer.Gather() + require.NoError(t, err) + + isExist := false + for _, metricFamily := range promMetrics { + if metricFamily.GetName() == "cs_httpsource_hits_total" { + isExist = true + assert.Len(t, metricFamily.GetMetric(), 1) + for _, metric := range metricFamily.GetMetric() { + assert.InDelta(t, float64(expected), metric.GetCounter().GetValue(), 0.000001) + labels := metric.GetLabel() + assert.Len(t, labels, 2) + assert.Equal(t, "path", labels[0].GetName()) + assert.Equal(t, "/test", labels[0].GetValue()) + assert.Equal(t, "src", labels[1].GetName()) + assert.Equal(t, "127.0.0.1", labels[1].GetValue()) + } + } + } + if !isExist && expected > 0 { + t.Fatalf("expected metric cs_httpsource_hits_total not found") + } + + for _, metric := range metrics { + metric.(*prometheus.CounterVec).Reset() + } +} diff --git a/pkg/acquisition/modules/http/testdata/ca.crt b/pkg/acquisition/modules/http/testdata/ca.crt new file mode 100644 index 00000000000..ac81b9db8a6 --- /dev/null +++ b/pkg/acquisition/modules/http/testdata/ca.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIIDvzCCAqegAwIBAgIUHQfsFpWkCy7gAmDa3A6O+y5CvAswDQYJKoZIhvcNAQEL +BQAwbzELMAkGA1UEBhMCRlIxFjAUBgNVBAgTDUlsZS1kZS1GcmFuY2UxDjAMBgNV +BAcTBVBhcmlzMREwDwYDVQQKEwhDcm93ZHNlYzERMA8GA1UECxMIQ3Jvd2RzZWMx +EjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0yNDEwMjMxMDAxMDBaFw0yOTEwMjIxMDAx +MDBaMG8xCzAJBgNVBAYTAkZSMRYwFAYDVQQIEw1JbGUtZGUtRnJhbmNlMQ4wDAYD +VQQHEwVQYXJpczERMA8GA1UEChMIQ3Jvd2RzZWMxETAPBgNVBAsTCENyb3dkc2Vj +MRIwEAYDVQQDEwlsb2NhbGhvc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK +AoIBAQCZSR2/A24bpVHSiEeSlelfdA32uhk9wHkauwy2qxos/G/UmKG/dgWrHzRh +LawlFVHtVn4u7Hjqz2y2EsH3bX42jC5NMVARgXIOBr1dE6F5/bPqA6SoVgkDm9wh +ZBigyAMxYsR4+3ahuf0pQflBShKrLZ1UYoe6tQXob7l3x5vThEhNkBawBkLfWpj7 +7Imm1tGyEZdxCMkT400KRtSmJRrnpiOCUosnacwgp7MCbKWOIOng07Eh16cVUiuI +BthWU/LycIuac2xaD9PFpeK/MpwASRRPXZgPUhiZuaa7vttD0phCdDaS46Oln5/7 +tFRZH0alBZmkpVZJCWAP4ujIA3vLAgMBAAGjUzBRMA4GA1UdDwEB/wQEAwIBBjAP +BgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBTwpg+WN1nZJs4gj5hfoa+fMSZjGTAP +BgNVHREECDAGhwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQAZuOWT8zHcwbWvC6Jm +/ccgB/U7SbeIYFJrCZd9mTyqsgnkFNH8yJ5F4dXXtPXr+SO/uWWa3G5hams3qVFf +zWzzPDQdyhUhfh5fjUHR2RsSGBmCxcapYHpVvAP5aY1/ujYrXMvAJV0hfDO2tGHb +rveuJxhe8ymQ1Yb2u9NcmI1HG9IVt3Airz4gAIUJWbFvRigky0bukfddOkfiUiaF +DMPJQO6HAj8d8ctSHHVZWzhAInZ1pDg6HIHYF44m1tT27pSQoi0ZFocskDi/fC2f +EIF0nu5fRLUS6BZEfpnDi9U0lbJ/kUrgT5IFHMFqXdRpDqcnXpJZhYtp5l6GoqjM +gT33 +-----END CERTIFICATE----- diff --git a/pkg/acquisition/modules/http/testdata/client.crt b/pkg/acquisition/modules/http/testdata/client.crt new file mode 100644 index 00000000000..55efdddad09 --- /dev/null +++ b/pkg/acquisition/modules/http/testdata/client.crt @@ -0,0 +1,24 @@ +-----BEGIN CERTIFICATE----- +MIID7jCCAtagAwIBAgIUJMTPh3oPJLPgsnb9T85ieb4EuOQwDQYJKoZIhvcNAQEL +BQAwbzELMAkGA1UEBhMCRlIxFjAUBgNVBAgTDUlsZS1kZS1GcmFuY2UxDjAMBgNV +BAcTBVBhcmlzMREwDwYDVQQKEwhDcm93ZHNlYzERMA8GA1UECxMIQ3Jvd2RzZWMx +EjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0yNDEwMjMxMDQ2MDBaFw0yNTEwMjMxMDQ2 +MDBaMHIxCzAJBgNVBAYTAkZSMRYwFAYDVQQIEw1JbGUtZGUtRnJhbmNlMQ4wDAYD +VQQHEwVQYXJpczERMA8GA1UEChMIQ3Jvd2RzZWMxFzAVBgNVBAsTDlRlc3Rpbmcg +Y2xpZW50MQ8wDQYDVQQDEwZjbGllbnQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQDAUOdpRieRrrH6krUjgcjLgJg6TzoWAb/iv6rfcioX1L9bj9fZSkwu +GqKzXX/PceIXElzQgiGJZErbJtnTzhGS80QgtAB8BwWQIT2zgoGcYJf7pPFvmcMM +qMGFwK0dMC+LHPk+ePtFz8dskI2XJ8jgBdtuZcnDblMuVGtjYT6n0rszvRdo118+ +mlGCLPzOfsO1JdOqLWAR88yZfqCFt1TrwmzpRT1crJQeM6i7muw4aO0L7uSek9QM +6APHz0QexSq7/zHOtRjA4jnJbDzZJHRlwOdlsNU9cmTz6uWIQXlg+2ovD55YurNy ++jYfmfDYpimhoeGf54zaETp1fTuTJYpxAgMBAAGjfzB9MA4GA1UdDwEB/wQEAwIF +oDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAd +BgNVHQ4EFgQUmH0/7RuKnoW7sEK4Cr8eVNGbb8swHwYDVR0jBBgwFoAU8KYPljdZ +2SbOII+YX6GvnzEmYxkwDQYJKoZIhvcNAQELBQADggEBAHVn9Zuoyxu9iTFoyJ50 +e/XKcmt2uK2M1x+ap2Av7Wb/Omikx/R2YPq7994BfiUCAezY2YtreZzkE6Io1wNM +qApijEJnlqEmOXiYJqlF89QrCcsAsz6lfaqitYBZSL3o4KT+7/uUDVxgNEjEksRz +9qy6DFBLvyhxbOM2zDEV+MVfemBWSvNiojHqXzDBkZnBHHclJLuIKsXDZDGhKbNd +hsoGU00RLevvcUpUJ3a68ekgwiYFJifm0uyfmao9lmiB3i+8ZW3Q4rbwHtD+U7U2 +3n+U5PkhiUAveuMfrvUMzsTolZiop9ZLtcALDUFaqyr4tjfVOf5+CGjiluio7oE1 +UYg= +-----END CERTIFICATE----- diff --git a/pkg/acquisition/modules/http/testdata/client.key b/pkg/acquisition/modules/http/testdata/client.key new file mode 100644 index 00000000000..f8ef2efbd58 --- /dev/null +++ b/pkg/acquisition/modules/http/testdata/client.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAwFDnaUYnka6x+pK1I4HIy4CYOk86FgG/4r+q33IqF9S/W4/X +2UpMLhqis11/z3HiFxJc0IIhiWRK2ybZ084RkvNEILQAfAcFkCE9s4KBnGCX+6Tx +b5nDDKjBhcCtHTAvixz5Pnj7Rc/HbJCNlyfI4AXbbmXJw25TLlRrY2E+p9K7M70X +aNdfPppRgiz8zn7DtSXTqi1gEfPMmX6ghbdU68Js6UU9XKyUHjOou5rsOGjtC+7k +npPUDOgDx89EHsUqu/8xzrUYwOI5yWw82SR0ZcDnZbDVPXJk8+rliEF5YPtqLw+e +WLqzcvo2H5nw2KYpoaHhn+eM2hE6dX07kyWKcQIDAQABAoIBAQChriKuza0MfBri +9x3UCRN/is/wDZVe1P+2KL8F9ZvPxytNVeP4qM7c38WzF8MQ6sRR8z0WiqCZOjj4 +f3QX7iG2MlAvUkUqAFk778ZIuUov5sE/bU8RLOrfJKz1vqOLa2w8/xHH5LwS1/jn +m6t9zZPCSwpMiMSUSZci1xQlS6b6POZMjeqLPqv9cP8PJNv9UNrHFcQnQi1iwKJH +MJ7CQI3R8FSeGad3P7tB9YDaBm7hHmd/TevuFkymcKNT44XBSgddPDfgKui6sHTY +QQWgWI9VGVO350ZBLRLkrk8wboY4vc15qbBzYFG66WiR/tNdLt3rDYxpcXaDvcQy +e47mYNVxAoGBAMFsUmPDssqzmOkmZxHDM+VmgHYPXjDqQdE299FtuClobUW4iU4g +By7o84aCIBQz2sp9f1KM+10lr+Bqw3s7QBbR5M67PA8Zm45DL9t70NR/NZHGzFRD +BR/NMbwzCqNtY2UGDhYQLGhW8heAwsYwir8ZqmOfKTd9aY1pu/S8m9AlAoGBAP6I +483EIN8R5y+beGcGynYeIrH5Gc+W2FxWIW9jh/G7vRbhMlW4z0GxV3uEAYmOlBH2 +AqUkV6+uzU0P4B/m3vCYqLycBVDwifJazDj9nskVL5kGMxia62iwDMXs5nqNS4WJ +ZM5Gl2xIiwmgWnYnujM3eKF2wbm439wj4na80SldAoGANdIqatA9o+GtntKsw2iJ +vD91Z2SHVR0aC1k8Q+4/3GXOYiQjMLYAybDQcpEq0/RJ4SZik1nfZ9/gvJV4p4Wp +I7Br9opq/9ikTEWtv2kIhtiO02151ciAWIUEXdXmE+uQSMASk1kUwkPPQXL2v6cq +NFqz6tyS33nqMQtG3abNxHECgYA4AEA2nmcpDRRTSh50dG8JC9pQU+EU5jhWIHEc +w8Y+LjMNHKDpcU7QQkdgGowICsGTLhAo61ULhycORGboPfBg+QVu8djNlQ6Urttt +0ocj8LBXN6D4UeVnVAyLY3LWFc4+5Bq0s51PKqrEhG5Cvrzd1d+JjspSpVVDZvXF +cAeI1QKBgC/cMN3+2Sc+2biu46DnkdYpdF/N0VGMOgzz+unSVD4RA2mEJ9UdwGga +feshtrtcroHtEmc+WDYgTTnAq1MbsVFQYIwZ5fL/GJ1R8ccaWiPuX2HrKALKG4Y3 +CMFpDUWhRgtaBsmuOpUq3FeS5cyPNMHk6axL1KyFoJk9AgfhqhTp +-----END RSA PRIVATE KEY----- diff --git a/pkg/acquisition/modules/http/testdata/server.crt b/pkg/acquisition/modules/http/testdata/server.crt new file mode 100644 index 00000000000..7a02c606c9d --- /dev/null +++ b/pkg/acquisition/modules/http/testdata/server.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID5jCCAs6gAwIBAgIUU3F6URi0oTe9ontkf7JqXOo89QYwDQYJKoZIhvcNAQEL +BQAwbzELMAkGA1UEBhMCRlIxFjAUBgNVBAgTDUlsZS1kZS1GcmFuY2UxDjAMBgNV +BAcTBVBhcmlzMREwDwYDVQQKEwhDcm93ZHNlYzERMA8GA1UECxMIQ3Jvd2RzZWMx +EjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0yNDEwMjMxMDAzMDBaFw0yNTEwMjMxMDAz +MDBaMG8xCzAJBgNVBAYTAkZSMRYwFAYDVQQIEw1JbGUtZGUtRnJhbmNlMQ4wDAYD +VQQHEwVQYXJpczERMA8GA1UEChMIQ3Jvd2RzZWMxETAPBgNVBAsTCENyb3dkc2Vj +MRIwEAYDVQQDEwlsb2NhbGhvc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK +AoIBAQC/lnUubjBGe5x0LgIE5GeG52LRzj99iLWuvey4qbSwFZ07ECgv+JttVwDm +AjEeakj2ZR46WHvHAR9eBNkRCORyWX0iKVIzm09PXYi80KtwGLaA8YMEio9/08Cc ++LS0TuP0yiOcw+btrhmvvauDzcQhA6u55q8anCZiF2BlHfX9Sh6QKewA3NhOkzbU +VTxqrOqfcRsGNub7dheqfP5bfrPkF6Y6l/0Fhyx0NMsu1zaQ0hCls2hkTf0Y3XGt +IQNWoN22seexR3qRmPf0j3jBa0qOmGgd6kAd+YpsjDblgCNUIJZiVj51fVb0sGRx +ShkfKGU6t0eznTWPCqswujO/sn+pAgMBAAGjejB4MA4GA1UdDwEB/wQEAwIFoDAd +BgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNV +HQ4EFgQUOiIF+7Wzx1J8Ki3DiBfx+E6zlSUwGgYDVR0RBBMwEYIJbG9jYWxob3N0 +hwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQA0dzlhBr/0wXPyj/iWxMOXxZ1FNJ9f +lxBMhLAgX0WrT2ys+284J7Hcn0lJeqelluYpmeKn9vmCAEj3MmUmHzZyf//lhuUJ +0DlYWIHUsGaJHJ7A+1hQqrcXHhkcRy5WGIM9VoddKbBbg2b6qzTSvxn8EnuD7H4h +28wLyGLCzsSXoVcAB8u+svYt29TPuy6xmMAokyIShV8FsE77fjVTgtCuxmx1PKv3 +zd6+uEae7bbZ+GJH1zKF0vokejQvmByt+YuIXlNbMseaMUeDdpy+6qlRvbbN1dyp +rkQXfWvidMfSue5nH/akAn83v/CdKxG6tfW83d9Rud3naabUkywALDng +-----END CERTIFICATE----- diff --git a/pkg/acquisition/modules/http/testdata/server.key b/pkg/acquisition/modules/http/testdata/server.key new file mode 100644 index 00000000000..4d0ee53b4c2 --- /dev/null +++ b/pkg/acquisition/modules/http/testdata/server.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAv5Z1Lm4wRnucdC4CBORnhudi0c4/fYi1rr3suKm0sBWdOxAo +L/ibbVcA5gIxHmpI9mUeOlh7xwEfXgTZEQjkcll9IilSM5tPT12IvNCrcBi2gPGD +BIqPf9PAnPi0tE7j9MojnMPm7a4Zr72rg83EIQOrueavGpwmYhdgZR31/UoekCns +ANzYTpM21FU8aqzqn3EbBjbm+3YXqnz+W36z5BemOpf9BYcsdDTLLtc2kNIQpbNo +ZE39GN1xrSEDVqDdtrHnsUd6kZj39I94wWtKjphoHepAHfmKbIw25YAjVCCWYlY+ +dX1W9LBkcUoZHyhlOrdHs501jwqrMLozv7J/qQIDAQABAoIBAF1Vd/rJlV0Q5RQ4 +QaWOe9zdpmedeZK3YgMh5UvE6RCLRxC5+0n7bASlSPvEf5dYofjfJA26g3pcUqKj +6/d/hIMsk2hsBu67L7TzVSTe51XxxB8nCPPSaLwWNZSDGM1qTWU4gIbjbQHHOh5C +YWcRfAW1WxhyiEWHYq+QwdYg9XCRrSg1UzvVvW1Yt2wDGcSZP5whbXipfw3BITDs +XU7ODYNkU1sjIzQZwzVGxOf9qKdhZFZ26Vhoz8OJNMLyJxY7EspuwR7HbDGt11Pb +CxOt/BV44LwdVYeqh57oIKtckQW33W/6EeaWr7GfMzyH5WSrsOJoK5IJVrZaPTcS +QiMYLA0CgYEA9vMVsGshBl3TeRGaU3XLHqooXD4kszbdnjfPrwGlfCO/iybhDqo5 +WFypM/bYcIWzbTez/ihufHEHPSCUbFEcN4B+oczGcuxTcZjFyvJYvq2ycxPUiDIi +JnVUcVxgh1Yn39+CsQ/b6meP7MumTD2P3I87CeQGlWTO5Ys9mdw0BjcCgYEAxpv1 +64l5UoFJGr4yElNKDIKnhEFbJZsLGKiiuVXcS1QVHW5Az5ar9fPxuepyHpz416l3 +ppncuhJiUIP+jbu5e0s0LsN46mLS3wkHLgYJj06CNT3uOSLSg1iFl7DusdbyiaA7 +wEJ/aotS1NZ4XaeryAWHwYJ6Kag3nz6NV3ZYuR8CgYEAxAFCuMj+6F+2RsTa+d1n +v8oMyNImLPyiQD9KHzyuTW7OTDMqtIoVg/Xf8re9KOpl9I0e1t7eevT3auQeCi8C +t2bMm7290V+UB3jbnO5n08hn+ADIUuV/x4ie4m8QyrpuYbm0sLbGtTFHwgoNzzuZ +oNUqZfpP42mk8fpnhWSLAlcCgYEAgpY7XRI4HkJ5ocbav2faMV2a7X/XgWNvKViA +HeJRhYoUlBRRMuz7xi0OjFKVlIFbsNlxna5fDk1WLWCMd/6tl168Qd8u2tX9lr6l +5OH9WSeiv4Un5JN73PbQaAvi9jXBpTIg92oBwzk2TlFyNQoxDcRtHZQ/5LIBWIhV +gOOEtLsCgYEA1wbGc4XlH+/nXVsvx7gmfK8pZG8XA4/ToeIEURwPYrxtQZLB4iZs +aqWGgIwiB4F4UkuKZIjMrgInU9y0fG6EL96Qty7Yjh7dGy1vJTZl6C+QU6o4sEwl +r5Id5BNLEaqISWQ0LvzfwdfABYlvFfBdaGbzUzLEitD79eyhxuNEOBw= +-----END RSA PRIVATE KEY----- diff --git a/pkg/acquisition/modules/journalctl/journalctl.go b/pkg/acquisition/modules/journalctl/journalctl.go index e7a35d5a3ba..27f20b9f446 100644 --- a/pkg/acquisition/modules/journalctl/journalctl.go +++ b/pkg/acquisition/modules/journalctl/journalctl.go @@ -136,12 +136,9 @@ func (j *JournalCtlSource) runJournalCtl(ctx context.Context, out chan types.Eve if j.metricsLevel != configuration.METRICS_NONE { linesRead.With(prometheus.Labels{"source": j.src}).Inc() } - var evt types.Event - if !j.config.UseTimeMachine { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} - } else { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} - } + + evt := types.MakeEvent(j.config.UseTimeMachine, types.LOG, true) + evt.Line = l out <- evt case stderrLine := <-stderrChan: logger.Warnf("Got stderr message : %s", stderrLine) diff --git a/pkg/acquisition/modules/kafka/kafka.go b/pkg/acquisition/modules/kafka/kafka.go index a9a5e13e958..77fc44e310d 100644 --- a/pkg/acquisition/modules/kafka/kafka.go +++ b/pkg/acquisition/modules/kafka/kafka.go @@ -173,13 +173,8 @@ func (k *KafkaSource) ReadMessage(ctx context.Context, out chan types.Event) err if k.metricsLevel != configuration.METRICS_NONE { linesRead.With(prometheus.Labels{"topic": k.Config.Topic}).Inc() } - var evt types.Event - - if !k.Config.UseTimeMachine { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} - } else { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} - } + evt := types.MakeEvent(k.Config.UseTimeMachine, types.LOG, true) + evt.Line = l out <- evt } } diff --git a/pkg/acquisition/modules/kinesis/kinesis.go b/pkg/acquisition/modules/kinesis/kinesis.go index 3cfc224aa25..3744e43f38d 100644 --- a/pkg/acquisition/modules/kinesis/kinesis.go +++ b/pkg/acquisition/modules/kinesis/kinesis.go @@ -322,12 +322,8 @@ func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan } else { l.Src = k.Config.StreamName } - var evt types.Event - if !k.Config.UseTimeMachine { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} - } else { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} - } + evt := types.MakeEvent(k.Config.UseTimeMachine, types.LOG, true) + evt.Line = l out <- evt } } diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go index 30fc5c467ea..aaa83a3bbb2 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go @@ -66,6 +66,7 @@ func (ka *KubernetesAuditSource) GetAggregMetrics() []prometheus.Collector { func (ka *KubernetesAuditSource) UnmarshalConfig(yamlConfig []byte) error { k8sConfig := KubernetesAuditConfiguration{} + err := yaml.UnmarshalStrict(yamlConfig, &k8sConfig) if err != nil { return fmt.Errorf("cannot parse k8s-audit configuration: %w", err) @@ -92,6 +93,7 @@ func (ka *KubernetesAuditSource) UnmarshalConfig(yamlConfig []byte) error { if ka.config.Mode == "" { ka.config.Mode = configuration.TAIL_MODE } + return nil } @@ -116,6 +118,7 @@ func (ka *KubernetesAuditSource) Configure(config []byte, logger *log.Entry, Met } ka.mux.HandleFunc(ka.config.WebhookPath, ka.webhookHandler) + return nil } @@ -137,6 +140,7 @@ func (ka *KubernetesAuditSource) OneShotAcquisition(_ context.Context, _ chan ty func (ka *KubernetesAuditSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { ka.outChan = out + t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/k8s-audit/live") ka.logger.Infof("Starting k8s-audit server on %s:%d%s", ka.config.ListenAddr, ka.config.ListenPort, ka.config.WebhookPath) @@ -145,13 +149,16 @@ func (ka *KubernetesAuditSource) StreamingAcquisition(ctx context.Context, out c if err != nil && err != http.ErrServerClosed { return fmt.Errorf("k8s-audit server failed: %w", err) } + return nil }) <-t.Dying() ka.logger.Infof("Stopping k8s-audit server on %s:%d%s", ka.config.ListenAddr, ka.config.ListenPort, ka.config.WebhookPath) ka.server.Shutdown(ctx) + return nil }) + return nil } @@ -167,51 +174,58 @@ func (ka *KubernetesAuditSource) webhookHandler(w http.ResponseWriter, r *http.R if ka.metricsLevel != configuration.METRICS_NONE { requestCount.WithLabelValues(ka.addr).Inc() } + if r.Method != http.MethodPost { w.WriteHeader(http.StatusMethodNotAllowed) return } + ka.logger.Tracef("webhookHandler called") + var auditEvents audit.EventList jsonBody, err := io.ReadAll(r.Body) if err != nil { ka.logger.Errorf("Error reading request body: %v", err) w.WriteHeader(http.StatusInternalServerError) + return } + ka.logger.Tracef("webhookHandler receveid: %s", string(jsonBody)) + err = json.Unmarshal(jsonBody, &auditEvents) if err != nil { ka.logger.Errorf("Error decoding audit events: %s", err) w.WriteHeader(http.StatusInternalServerError) + return } remoteIP := strings.Split(r.RemoteAddr, ":")[0] - for _, auditEvent := range auditEvents.Items { + + for idx := range auditEvents.Items { if ka.metricsLevel != configuration.METRICS_NONE { eventCount.WithLabelValues(ka.addr).Inc() } - bytesEvent, err := json.Marshal(auditEvent) + + bytesEvent, err := json.Marshal(auditEvents.Items[idx]) if err != nil { ka.logger.Errorf("Error serializing audit event: %s", err) continue } + ka.logger.Tracef("Got audit event: %s", string(bytesEvent)) l := types.Line{ Raw: string(bytesEvent), Labels: ka.config.Labels, - Time: auditEvent.StageTimestamp.Time, + Time: auditEvents.Items[idx].StageTimestamp.Time, Src: remoteIP, Process: true, Module: ka.GetName(), } - ka.outChan <- types.Event{ - Line: l, - Process: true, - Type: types.LOG, - ExpectMode: types.LIVE, - } + evt := types.MakeEvent(ka.config.UseTimeMachine, types.LOG, true) + evt.Line = l + ka.outChan <- evt } } diff --git a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go index 846e833abea..5996518e191 100644 --- a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go +++ b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go @@ -119,7 +119,7 @@ func (lc *LokiClient) queryRange(ctx context.Context, uri string, c chan *LokiQu case <-lc.t.Dying(): return lc.t.Err() case <-ticker.C: - resp, err := lc.Get(uri) + resp, err := lc.Get(ctx, uri) if err != nil { if ok := lc.shouldRetry(); !ok { return fmt.Errorf("error querying range: %w", err) @@ -205,6 +205,7 @@ func (lc *LokiClient) getURLFor(endpoint string, params map[string]string) strin func (lc *LokiClient) Ready(ctx context.Context) error { tick := time.NewTicker(500 * time.Millisecond) url := lc.getURLFor("ready", nil) + lc.Logger.Debugf("Using url: %s for ready check", url) for { select { case <-ctx.Done(): @@ -215,7 +216,7 @@ func (lc *LokiClient) Ready(ctx context.Context) error { return lc.t.Err() case <-tick.C: lc.Logger.Debug("Checking if Loki is ready") - resp, err := lc.Get(url) + resp, err := lc.Get(ctx, url) if err != nil { lc.Logger.Warnf("Error checking if Loki is ready: %s", err) continue @@ -300,8 +301,8 @@ func (lc *LokiClient) QueryRange(ctx context.Context, infinite bool) chan *LokiQ } // Create a wrapper for http.Get to be able to set headers and auth -func (lc *LokiClient) Get(url string) (*http.Response, error) { - request, err := http.NewRequest(http.MethodGet, url, nil) +func (lc *LokiClient) Get(ctx context.Context, url string) (*http.Response, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } diff --git a/pkg/acquisition/modules/loki/loki.go b/pkg/acquisition/modules/loki/loki.go index e39c76af22c..c57e6a67c94 100644 --- a/pkg/acquisition/modules/loki/loki.go +++ b/pkg/acquisition/modules/loki/loki.go @@ -53,6 +53,7 @@ type LokiConfiguration struct { WaitForReady time.Duration `yaml:"wait_for_ready"` // Retry interval, default is 10 seconds Auth LokiAuthConfiguration `yaml:"auth"` MaxFailureDuration time.Duration `yaml:"max_failure_duration"` // Max duration of failure before stopping the source + NoReadyCheck bool `yaml:"no_ready_check"` // Bypass /ready check before starting configuration.DataSourceCommonCfg `yaml:",inline"` } @@ -229,6 +230,14 @@ func (l *LokiSource) ConfigureByDSN(dsn string, labels map[string]string, logger l.logger.Logger.SetLevel(level) } + if noReadyCheck := params.Get("no_ready_check"); noReadyCheck != "" { + noReadyCheck, err := strconv.ParseBool(noReadyCheck) + if err != nil { + return fmt.Errorf("invalid no_ready_check in dsn: %w", err) + } + l.Config.NoReadyCheck = noReadyCheck + } + l.Config.URL = fmt.Sprintf("%s://%s", scheme, u.Host) if u.User != nil { l.Config.Auth.Username = u.User.Username() @@ -264,26 +273,28 @@ func (l *LokiSource) GetName() string { func (l *LokiSource) OneShotAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { l.logger.Debug("Loki one shot acquisition") l.Client.SetTomb(t) - readyCtx, cancel := context.WithTimeout(ctx, l.Config.WaitForReady) - defer cancel() - err := l.Client.Ready(readyCtx) - if err != nil { - return fmt.Errorf("loki is not ready: %w", err) + + if !l.Config.NoReadyCheck { + readyCtx, readyCancel := context.WithTimeout(ctx, l.Config.WaitForReady) + defer readyCancel() + err := l.Client.Ready(readyCtx) + if err != nil { + return fmt.Errorf("loki is not ready: %w", err) + } } - ctx, cancel = context.WithCancel(ctx) - c := l.Client.QueryRange(ctx, false) + lokiCtx, cancel := context.WithCancel(ctx) + defer cancel() + c := l.Client.QueryRange(lokiCtx, false) for { select { case <-t.Dying(): l.logger.Debug("Loki one shot acquisition stopped") - cancel() return nil case resp, ok := <-c: if !ok { l.logger.Info("Loki acquisition done, chan closed") - cancel() return nil } for _, stream := range resp.Data.Result { @@ -307,41 +318,33 @@ func (l *LokiSource) readOneEntry(entry lokiclient.Entry, labels map[string]stri if l.metricsLevel != configuration.METRICS_NONE { linesRead.With(prometheus.Labels{"source": l.Config.URL}).Inc() } - expectMode := types.LIVE - if l.Config.UseTimeMachine { - expectMode = types.TIMEMACHINE - } - out <- types.Event{ - Line: ll, - Process: true, - Type: types.LOG, - ExpectMode: expectMode, - } + evt := types.MakeEvent(l.Config.UseTimeMachine, types.LOG, true) + evt.Line = ll + out <- evt } func (l *LokiSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { l.Client.SetTomb(t) - readyCtx, cancel := context.WithTimeout(ctx, l.Config.WaitForReady) - defer cancel() - err := l.Client.Ready(readyCtx) - if err != nil { - return fmt.Errorf("loki is not ready: %w", err) + + if !l.Config.NoReadyCheck { + readyCtx, readyCancel := context.WithTimeout(ctx, l.Config.WaitForReady) + defer readyCancel() + err := l.Client.Ready(readyCtx) + if err != nil { + return fmt.Errorf("loki is not ready: %w", err) + } } ll := l.logger.WithField("websocket_url", l.lokiWebsocket) t.Go(func() error { ctx, cancel := context.WithCancel(ctx) defer cancel() respChan := l.Client.QueryRange(ctx, true) - if err != nil { - ll.Errorf("could not start loki tail: %s", err) - return fmt.Errorf("while starting loki tail: %w", err) - } for { select { case resp, ok := <-respChan: if !ok { ll.Warnf("loki channel closed") - return err + return errors.New("loki channel closed") } for _, stream := range resp.Data.Result { for _, entry := range stream.Entries { diff --git a/pkg/acquisition/modules/loki/loki_test.go b/pkg/acquisition/modules/loki/loki_test.go index cacdda32d80..643aefad715 100644 --- a/pkg/acquisition/modules/loki/loki_test.go +++ b/pkg/acquisition/modules/loki/loki_test.go @@ -34,6 +34,7 @@ func TestConfiguration(t *testing.T) { password string waitForReady time.Duration delayFor time.Duration + noReadyCheck bool testName string }{ { @@ -99,6 +100,19 @@ query: > mode: tail source: loki url: http://localhost:3100/ +no_ready_check: true +query: > + {server="demo"} +`, + expectedErr: "", + testName: "Correct config with no_ready_check", + noReadyCheck: true, + }, + { + config: ` +mode: tail +source: loki +url: http://localhost:3100/ auth: username: foo password: bar @@ -148,6 +162,8 @@ query: > t.Fatalf("Wrong DelayFor %v != %v", lokiSource.Config.DelayFor, test.delayFor) } } + + assert.Equal(t, test.noReadyCheck, lokiSource.Config.NoReadyCheck) }) } } @@ -164,6 +180,7 @@ func TestConfigureDSN(t *testing.T) { scheme string waitForReady time.Duration delayFor time.Duration + noReadyCheck bool }{ { name: "Wrong scheme", @@ -202,10 +219,11 @@ func TestConfigureDSN(t *testing.T) { }, { name: "Correct DSN", - dsn: `loki://localhost:3100/?query={server="demo"}&wait_for_ready=5s&delay_for=1s`, + dsn: `loki://localhost:3100/?query={server="demo"}&wait_for_ready=5s&delay_for=1s&no_ready_check=true`, expectedErr: "", waitForReady: 5 * time.Second, delayFor: 1 * time.Second, + noReadyCheck: true, }, { name: "SSL DSN", @@ -256,6 +274,9 @@ func TestConfigureDSN(t *testing.T) { t.Fatalf("Wrong DelayFor %v != %v", lokiSource.Config.DelayFor, test.delayFor) } } + + assert.Equal(t, test.noReadyCheck, lokiSource.Config.NoReadyCheck) + } } diff --git a/pkg/acquisition/modules/s3/s3.go b/pkg/acquisition/modules/s3/s3.go index acd78ceba8f..cdc84a8a3ca 100644 --- a/pkg/acquisition/modules/s3/s3.go +++ b/pkg/acquisition/modules/s3/s3.go @@ -443,12 +443,8 @@ func (s *S3Source) readFile(bucket string, key string) error { } else if s.MetricsLevel == configuration.METRICS_AGGREGATE { l.Src = bucket } - var evt types.Event - if !s.Config.UseTimeMachine { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} - } else { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} - } + evt := types.MakeEvent(s.Config.UseTimeMachine, types.LOG, true) + evt.Line = l s.out <- evt } } diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go index 66d842ed519..04c7053ef27 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go @@ -48,7 +48,6 @@ func WithStrictHostname() RFC3164Option { } func (r *RFC3164) parsePRI() error { - pri := 0 if r.buf[r.position] != '<' { diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go index 639e91e1224..c9aa89f7256 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go @@ -48,7 +48,6 @@ func WithStrictHostname() RFC5424Option { } func (r *RFC5424) parsePRI() error { - pri := 0 if r.buf[r.position] != '<' { @@ -94,7 +93,6 @@ func (r *RFC5424) parseVersion() error { } func (r *RFC5424) parseTimestamp() error { - timestamp := []byte{} if r.buf[r.position] == NIL_VALUE { @@ -121,7 +119,6 @@ func (r *RFC5424) parseTimestamp() error { } date, err := time.Parse(VALID_TIMESTAMP, string(timestamp)) - if err != nil { return errors.New("timestamp is not valid") } diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go index 0938e947fe7..d3a68c196db 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go @@ -94,7 +94,8 @@ func TestParse(t *testing.T) { }{ { "valid msg", - `<13>1 2021-05-18T11:58:40.828081+02:42 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla`, expected{ + `<13>1 2021-05-18T11:58:40.828081+02:42 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla`, + expected{ Timestamp: time.Date(2021, 5, 18, 11, 58, 40, 828081000, time.FixedZone("+0242", 9720)), Hostname: "mantis", Tag: "sshd", @@ -102,11 +103,14 @@ func TestParse(t *testing.T) { MsgID: "", Message: "blabla", PRI: 13, - }, "", []RFC5424Option{}, + }, + "", + []RFC5424Option{}, }, { "valid msg with msgid", - `<13>1 2021-05-18T11:58:40.828081+02:42 mantis foobar 49340 123123 [timeQuality isSynced="0" tzKnown="1"] blabla`, expected{ + `<13>1 2021-05-18T11:58:40.828081+02:42 mantis foobar 49340 123123 [timeQuality isSynced="0" tzKnown="1"] blabla`, + expected{ Timestamp: time.Date(2021, 5, 18, 11, 58, 40, 828081000, time.FixedZone("+0242", 9720)), Hostname: "mantis", Tag: "foobar", @@ -114,11 +118,14 @@ func TestParse(t *testing.T) { MsgID: "123123", Message: "blabla", PRI: 13, - }, "", []RFC5424Option{}, + }, + "", + []RFC5424Option{}, }, { "valid msg with repeating SD", - `<13>1 2021-05-18T11:58:40.828081+02:42 mantis foobar 49340 123123 [timeQuality isSynced="0" tzKnown="1"][foo="bar][a] blabla`, expected{ + `<13>1 2021-05-18T11:58:40.828081+02:42 mantis foobar 49340 123123 [timeQuality isSynced="0" tzKnown="1"][foo="bar][a] blabla`, + expected{ Timestamp: time.Date(2021, 5, 18, 11, 58, 40, 828081000, time.FixedZone("+0242", 9720)), Hostname: "mantis", Tag: "foobar", @@ -126,36 +133,53 @@ func TestParse(t *testing.T) { MsgID: "123123", Message: "blabla", PRI: 13, - }, "", []RFC5424Option{}, + }, + "", + []RFC5424Option{}, }, { "invalid SD", - `<13>1 2021-05-18T11:58:40.828081+02:00 mantis foobar 49340 123123 [timeQuality asd`, expected{}, "structured data must end with ']'", []RFC5424Option{}, + `<13>1 2021-05-18T11:58:40.828081+02:00 mantis foobar 49340 123123 [timeQuality asd`, + expected{}, + "structured data must end with ']'", + []RFC5424Option{}, }, { "invalid version", - `<13>42 2021-05-18T11:58:40.828081+02:00 mantis foobar 49340 123123 [timeQuality isSynced="0" tzKnown="1"] blabla`, expected{}, "version must be 1", []RFC5424Option{}, + `<13>42 2021-05-18T11:58:40.828081+02:00 mantis foobar 49340 123123 [timeQuality isSynced="0" tzKnown="1"] blabla`, + expected{}, + "version must be 1", + []RFC5424Option{}, }, { "invalid message", - `<13>1`, expected{}, "version must be followed by a space", []RFC5424Option{}, + `<13>1`, + expected{}, + "version must be followed by a space", + []RFC5424Option{}, }, { "valid msg with empty fields", - `<13>1 - foo - - - - blabla`, expected{ + `<13>1 - foo - - - - blabla`, + expected{ Timestamp: time.Now().UTC(), Hostname: "foo", PRI: 13, Message: "blabla", - }, "", []RFC5424Option{}, + }, + "", + []RFC5424Option{}, }, { "valid msg with empty fields", - `<13>1 - - - - - - blabla`, expected{ + `<13>1 - - - - - - blabla`, + expected{ Timestamp: time.Now().UTC(), PRI: 13, Message: "blabla", - }, "", []RFC5424Option{}, + }, + "", + []RFC5424Option{}, }, { "valid msg with escaped SD", @@ -167,7 +191,9 @@ func TestParse(t *testing.T) { Hostname: "testhostname", MsgID: `sn="msgid"`, Message: `testmessage`, - }, "", []RFC5424Option{}, + }, + "", + []RFC5424Option{}, }, { "valid complex msg", @@ -179,7 +205,9 @@ func TestParse(t *testing.T) { PRI: 13, MsgID: `sn="msgid"`, Message: `source: sn="www.foobar.com" | message: 1.1.1.1 - - [24/May/2022:10:57:37 +0200] "GET /dist/precache-manifest.58b57debe6bc4f96698da0dc314461e9.js HTTP/2.0" 304 0 "https://www.foobar.com/sw.js" "Mozilla/5.0 (Linux; Android 9; ANE-LX1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/101.0.4951.61 Mobile Safari/537.36" "-" "www.foobar.com" sn="www.foobar.com" rt=0.000 ua="-" us="-" ut="-" ul="-" cs=HIT { request: /dist/precache-manifest.58b57debe6bc4f96698da0dc314461e9.js | src_ip_geo_country: DE | MONTH: May | COMMONAPACHELOG: 1.1.1.1 - - [24/May/2022:10:57:37 +0200] "GET /dist/precache-manifest.58b57debe6bc4f96698da0dc314461e9.js HTTP/2.0" 304 0 | auth: - | HOUR: 10 | gl2_remote_ip: 172.31.32.142 | ident: - | gl2_remote_port: 43375 | BASE10NUM: [2.0, 304, 0] | pid: -1 | program: nginx | gl2_source_input: 623ed3440183476d61cff974 | INT: +0200 | is_private_ip: false | YEAR: 2022 | src_ip_geo_city: Achern | clientip: 1.1.1.1 | USERNAME:`, - }, "", []RFC5424Option{}, + }, + "", + []RFC5424Option{}, }, { "partial message", diff --git a/pkg/acquisition/modules/syslog/internal/server/syslogserver.go b/pkg/acquisition/modules/syslog/internal/server/syslogserver.go index 7118c295b54..83f5e5a57e5 100644 --- a/pkg/acquisition/modules/syslog/internal/server/syslogserver.go +++ b/pkg/acquisition/modules/syslog/internal/server/syslogserver.go @@ -25,7 +25,6 @@ type SyslogMessage struct { } func (s *SyslogServer) Listen(listenAddr string, port int) error { - s.listenAddr = listenAddr s.port = port udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", s.listenAddr, s.port)) diff --git a/pkg/acquisition/modules/syslog/syslog.go b/pkg/acquisition/modules/syslog/syslog.go index 33a2f1542db..fb6a04600c1 100644 --- a/pkg/acquisition/modules/syslog/syslog.go +++ b/pkg/acquisition/modules/syslog/syslog.go @@ -235,11 +235,9 @@ func (s *SyslogSource) handleSyslogMsg(out chan types.Event, t *tomb.Tomb, c cha l.Time = ts l.Src = syslogLine.Client l.Process = true - if !s.config.UseTimeMachine { - out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} - } else { - out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} - } + evt := types.MakeEvent(s.config.UseTimeMachine, types.LOG, true) + evt.Line = l + out <- evt } } } diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_windows.go b/pkg/acquisition/modules/wineventlog/wineventlog_windows.go index 887be8b7dd3..8283bcc21a2 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog_windows.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog_windows.go @@ -94,7 +94,7 @@ func (w *WinEventLogSource) getXMLEvents(config *winlog.SubscribeConfig, publish 2000, // Timeout in milliseconds to wait. 0, // Reserved. Must be zero. &returned) // The number of handles in the array that are set by the API. - if err == windows.ERROR_NO_MORE_ITEMS { + if errors.Is(err, windows.ERROR_NO_MORE_ITEMS) { return nil, err } else if err != nil { return nil, fmt.Errorf("wevtapi.EvtNext failed: %v", err) @@ -188,7 +188,7 @@ func (w *WinEventLogSource) getEvents(out chan types.Event, t *tomb.Tomb) error } if status == syscall.WAIT_OBJECT_0 { renderedEvents, err := w.getXMLEvents(w.evtConfig, publisherCache, subscription, 500) - if err == windows.ERROR_NO_MORE_ITEMS { + if errors.Is(err, windows.ERROR_NO_MORE_ITEMS) { windows.ResetEvent(w.evtConfig.SignalEvent) } else if err != nil { w.logger.Errorf("getXMLEvents failed: %v", err) @@ -206,9 +206,9 @@ func (w *WinEventLogSource) getEvents(out chan types.Event, t *tomb.Tomb) error l.Src = w.name l.Process = true if !w.config.UseTimeMachine { - out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} + out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE, Unmarshaled: make(map[string]interface{})} } else { - out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} + out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE, Unmarshaled: make(map[string]interface{})} } } } @@ -411,7 +411,7 @@ OUTER_LOOP: return nil default: evts, err := w.getXMLEvents(w.evtConfig, publisherCache, handle, 500) - if err == windows.ERROR_NO_MORE_ITEMS { + if errors.Is(err, windows.ERROR_NO_MORE_ITEMS) { log.Info("No more items") break OUTER_LOOP } else if err != nil { @@ -430,7 +430,9 @@ OUTER_LOOP: l.Time = time.Now() l.Src = w.name l.Process = true - out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} + csevt := types.MakeEvent(w.config.UseTimeMachine, types.LOG, true) + csevt.Line = l + out <- csevt } } } diff --git a/pkg/alertcontext/alertcontext.go b/pkg/alertcontext/alertcontext.go index 16ebc6d0ac2..1b7d1e20018 100644 --- a/pkg/alertcontext/alertcontext.go +++ b/pkg/alertcontext/alertcontext.go @@ -3,6 +3,7 @@ package alertcontext import ( "encoding/json" "fmt" + "net/http" "slices" "strconv" @@ -30,7 +31,10 @@ type Context struct { func ValidateContextExpr(key string, expressions []string) error { for _, expression := range expressions { - _, err := expr.Compile(expression, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) + _, err := expr.Compile(expression, exprhelpers.GetExprOptions(map[string]interface{}{ + "evt": &types.Event{}, + "match": &types.MatchedRule{}, + "req": &http.Request{}})...) if err != nil { return fmt.Errorf("compilation of '%s' failed: %w", expression, err) } @@ -72,7 +76,10 @@ func NewAlertContext(contextToSend map[string][]string, valueLength int) error { } for _, value := range values { - valueCompiled, err := expr.Compile(value, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) + valueCompiled, err := expr.Compile(value, exprhelpers.GetExprOptions(map[string]interface{}{ + "evt": &types.Event{}, + "match": &types.MatchedRule{}, + "req": &http.Request{}})...) if err != nil { return fmt.Errorf("compilation of '%s' context value failed: %w", value, err) } @@ -85,6 +92,32 @@ func NewAlertContext(contextToSend map[string][]string, valueLength int) error { return nil } +// Truncate the context map to fit in the context value length +func TruncateContextMap(contextMap map[string][]string, contextValueLen int) ([]*models.MetaItems0, []error) { + metas := make([]*models.MetaItems0, 0) + errors := make([]error, 0) + + for key, values := range contextMap { + if len(values) == 0 { + continue + } + + valueStr, err := TruncateContext(values, alertContext.ContextValueLen) + if err != nil { + errors = append(errors, fmt.Errorf("error truncating content for %s: %w", key, err)) + continue + } + + meta := models.MetaItems0{ + Key: key, + Value: valueStr, + } + metas = append(metas, &meta) + } + return metas, errors +} + +// Truncate an individual []string to fit in the context value length func TruncateContext(values []string, contextValueLen int) (string, error) { valueByte, err := json.Marshal(values) if err != nil { @@ -116,61 +149,102 @@ func TruncateContext(values []string, contextValueLen int) (string, error) { return ret, nil } -func EventToContext(events []types.Event) (models.Meta, []error) { +func EvalAlertContextRules(evt types.Event, match *types.MatchedRule, request *http.Request, tmpContext map[string][]string) []error { + var errors []error - metas := make([]*models.MetaItems0, 0) - tmpContext := make(map[string][]string) + //if we're evaluating context for appsec event, match and request will be present. + //otherwise, only evt will be. + if match == nil { + match = types.NewMatchedRule() + } + if request == nil { + request = &http.Request{} + } - for _, evt := range events { - for key, values := range alertContext.ContextToSendCompiled { - if _, ok := tmpContext[key]; !ok { - tmpContext[key] = make([]string, 0) - } + for key, values := range alertContext.ContextToSendCompiled { - for _, value := range values { - var val string + if _, ok := tmpContext[key]; !ok { + tmpContext[key] = make([]string, 0) + } - output, err := expr.Run(value, map[string]interface{}{"evt": evt}) - if err != nil { - errors = append(errors, fmt.Errorf("failed to get value for %s: %w", key, err)) - continue - } + for _, value := range values { + var val string - switch out := output.(type) { - case string: - val = out - case int: - val = strconv.Itoa(out) - default: - errors = append(errors, fmt.Errorf("unexpected return type for %s: %T", key, output)) - continue + output, err := expr.Run(value, map[string]interface{}{"match": match, "evt": evt, "req": request}) + if err != nil { + errors = append(errors, fmt.Errorf("failed to get value for %s: %w", key, err)) + continue + } + switch out := output.(type) { + case string: + val = out + if val != "" && !slices.Contains(tmpContext[key], val) { + tmpContext[key] = append(tmpContext[key], val) } - + case []string: + for _, v := range out { + if v != "" && !slices.Contains(tmpContext[key], v) { + tmpContext[key] = append(tmpContext[key], v) + } + } + case int: + val = strconv.Itoa(out) + if val != "" && !slices.Contains(tmpContext[key], val) { + tmpContext[key] = append(tmpContext[key], val) + } + case []int: + for _, v := range out { + val = strconv.Itoa(v) + if val != "" && !slices.Contains(tmpContext[key], val) { + tmpContext[key] = append(tmpContext[key], val) + } + } + default: + val := fmt.Sprintf("%v", output) if val != "" && !slices.Contains(tmpContext[key], val) { tmpContext[key] = append(tmpContext[key], val) } } } } + return errors +} - for key, values := range tmpContext { - if len(values) == 0 { - continue - } +// Iterate over the individual appsec matched rules to create the needed alert context. +func AppsecEventToContext(event types.AppsecEvent, request *http.Request) (models.Meta, []error) { + var errors []error - valueStr, err := TruncateContext(values, alertContext.ContextValueLen) - if err != nil { - log.Warning(err.Error()) - } + tmpContext := make(map[string][]string) - meta := models.MetaItems0{ - Key: key, - Value: valueStr, - } - metas = append(metas, &meta) + evt := types.MakeEvent(false, types.LOG, false) + for _, matched_rule := range event.MatchedRules { + tmpErrors := EvalAlertContextRules(evt, &matched_rule, request, tmpContext) + errors = append(errors, tmpErrors...) } + metas, truncErrors := TruncateContextMap(tmpContext, alertContext.ContextValueLen) + errors = append(errors, truncErrors...) + + ret := models.Meta(metas) + + return ret, errors +} + +// Iterate over the individual events to create the needed alert context. +func EventToContext(events []types.Event) (models.Meta, []error) { + var errors []error + + tmpContext := make(map[string][]string) + + for _, evt := range events { + tmpErrors := EvalAlertContextRules(evt, nil, nil, tmpContext) + errors = append(errors, tmpErrors...) + } + + metas, truncErrors := TruncateContextMap(tmpContext, alertContext.ContextValueLen) + errors = append(errors, truncErrors...) + ret := models.Meta(metas) return ret, errors diff --git a/pkg/alertcontext/alertcontext_test.go b/pkg/alertcontext/alertcontext_test.go index c111d1bbcfb..284ff451bc2 100644 --- a/pkg/alertcontext/alertcontext_test.go +++ b/pkg/alertcontext/alertcontext_test.go @@ -2,6 +2,7 @@ package alertcontext import ( "fmt" + "net/http" "testing" "github.com/stretchr/testify/assert" @@ -9,6 +10,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" + "github.com/crowdsecurity/go-cs-lib/ptr" ) func TestNewAlertContext(t *testing.T) { @@ -200,3 +202,162 @@ func TestEventToContext(t *testing.T) { assert.ElementsMatch(t, test.expectedResult, metas) } } + +func TestValidateContextExpr(t *testing.T) { + tests := []struct { + name string + key string + exprs []string + expectedErr *string + }{ + { + name: "basic config", + key: "source_ip", + exprs: []string{ + "evt.Parsed.source_ip", + }, + expectedErr: nil, + }, + { + name: "basic config with non existent field", + key: "source_ip", + exprs: []string{ + "evt.invalid.source_ip", + }, + expectedErr: ptr.Of("compilation of 'evt.invalid.source_ip' failed: type types.Event has no field invalid"), + }, + } + for _, test := range tests { + fmt.Printf("Running test '%s'\n", test.name) + err := ValidateContextExpr(test.key, test.exprs) + if test.expectedErr == nil { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, *test.expectedErr) + } + } +} + +func TestAppsecEventToContext(t *testing.T) { + tests := []struct { + name string + contextToSend map[string][]string + match types.AppsecEvent + req *http.Request + expectedResult models.Meta + expectedErrLen int + }{ + { + name: "basic test on match", + contextToSend: map[string][]string{ + "id": {"match.id"}, + }, + match: types.AppsecEvent{ + MatchedRules: types.MatchedRules{ + { + "id": "test", + }, + }, + }, + req: &http.Request{}, + expectedResult: []*models.MetaItems0{ + { + Key: "id", + Value: "[\"test\"]", + }, + }, + expectedErrLen: 0, + }, + { + name: "basic test on req", + contextToSend: map[string][]string{ + "ua": {"req.UserAgent()"}, + }, + match: types.AppsecEvent{ + MatchedRules: types.MatchedRules{ + { + "id": "test", + }, + }, + }, + req: &http.Request{ + Header: map[string][]string{ + "User-Agent": {"test"}, + }, + }, + expectedResult: []*models.MetaItems0{ + { + Key: "ua", + Value: "[\"test\"]", + }, + }, + expectedErrLen: 0, + }, + { + name: "test on req -> []string", + contextToSend: map[string][]string{ + "foobarxx": {"req.Header.Values('Foobar')"}, + }, + match: types.AppsecEvent{ + MatchedRules: types.MatchedRules{ + { + "id": "test", + }, + }, + }, + req: &http.Request{ + Header: map[string][]string{ + "User-Agent": {"test"}, + "Foobar": {"test1", "test2"}, + }, + }, + expectedResult: []*models.MetaItems0{ + { + Key: "foobarxx", + Value: "[\"test1\",\"test2\"]", + }, + }, + expectedErrLen: 0, + }, + { + name: "test on type int", + contextToSend: map[string][]string{ + "foobarxx": {"len(req.Header.Values('Foobar'))"}, + }, + match: types.AppsecEvent{ + MatchedRules: types.MatchedRules{ + { + "id": "test", + }, + }, + }, + req: &http.Request{ + Header: map[string][]string{ + "User-Agent": {"test"}, + "Foobar": {"test1", "test2"}, + }, + }, + expectedResult: []*models.MetaItems0{ + { + Key: "foobarxx", + Value: "[\"2\"]", + }, + }, + expectedErrLen: 0, + }, + } + + for _, test := range tests { + //reset cache + alertContext = Context{} + //compile + if err := NewAlertContext(test.contextToSend, 100); err != nil { + t.Fatalf("failed to compile %s: %s", test.name, err) + } + //run + + metas, errors := AppsecEventToContext(test.match, test.req) + assert.Len(t, errors, test.expectedErrLen) + assert.ElementsMatch(t, test.expectedResult, metas) + } +} diff --git a/pkg/apiclient/allowlists_service.go b/pkg/apiclient/allowlists_service.go new file mode 100644 index 00000000000..382f8fcfe5b --- /dev/null +++ b/pkg/apiclient/allowlists_service.go @@ -0,0 +1,91 @@ +package apiclient + +import ( + "context" + "fmt" + "net/http" + + "github.com/crowdsecurity/crowdsec/pkg/models" + qs "github.com/google/go-querystring/query" + log "github.com/sirupsen/logrus" +) + +type AllowlistsService service + +type AllowlistListOpts struct { + WithContent bool `url:"with_content,omitempty"` +} + +func (s *AllowlistsService) List(ctx context.Context, opts AllowlistListOpts) (*models.GetAllowlistsResponse, *Response, error) { + u := s.client.URLPrefix + "/allowlists" + + params, err := qs.Values(opts) + if err != nil { + return nil, nil, fmt.Errorf("building query: %w", err) + } + + u += "?" + params.Encode() + + req, err := s.client.NewRequest(http.MethodGet, u, nil) + if err != nil { + return nil, nil, err + } + + allowlists := &models.GetAllowlistsResponse{} + + resp, err := s.client.Do(ctx, req, allowlists) + if err != nil { + return nil, resp, err + } + + return allowlists, resp, nil +} + +type AllowlistGetOpts struct { + WithContent bool `url:"with_content,omitempty"` +} + +func (s *AllowlistsService) Get(ctx context.Context, name string, opts AllowlistGetOpts) (*models.GetAllowlistResponse, *Response, error) { + u := s.client.URLPrefix + "/allowlists/" + name + + params, err := qs.Values(opts) + if err != nil { + return nil, nil, fmt.Errorf("building query: %w", err) + } + + u += "?" + params.Encode() + + log.Debugf("GET %s", u) + + req, err := s.client.NewRequest(http.MethodGet, u, nil) + if err != nil { + return nil, nil, err + } + + allowlist := &models.GetAllowlistResponse{} + + resp, err := s.client.Do(ctx, req, allowlist) + if err != nil { + return nil, resp, err + } + + return allowlist, resp, nil +} + +func (s *AllowlistsService) CheckIfAllowlisted(ctx context.Context, value string) (bool, *Response, error) { + u := s.client.URLPrefix + "/allowlists/check/" + value + + req, err := s.client.NewRequest(http.MethodHead, u, nil) + if err != nil { + return false, nil, err + } + + var discardBody interface{} + + resp, err := s.client.Do(ctx, req, discardBody) + if err != nil { + return false, resp, err + } + + return resp.Response.StatusCode == http.StatusOK, resp, nil +} diff --git a/pkg/apiclient/auth_jwt.go b/pkg/apiclient/auth_jwt.go index 193486ff065..c43e9fc291c 100644 --- a/pkg/apiclient/auth_jwt.go +++ b/pkg/apiclient/auth_jwt.go @@ -62,7 +62,6 @@ func (t *JWTTransport) refreshJwtToken() error { enc := json.NewEncoder(buf) enc.SetEscapeHTML(false) err = enc.Encode(auth) - if err != nil { return fmt.Errorf("could not encode jwt auth body: %w", err) } @@ -169,7 +168,6 @@ func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error) // RoundTrip implements the RoundTripper interface. func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { - var resp *http.Response attemptsCount := make(map[int]int) @@ -229,7 +227,6 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { } } return resp, nil - } func (t *JWTTransport) Client() *http.Client { diff --git a/pkg/apiclient/client.go b/pkg/apiclient/client.go index 47d97a28344..b6632594a5b 100644 --- a/pkg/apiclient/client.go +++ b/pkg/apiclient/client.go @@ -4,12 +4,15 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" "fmt" "net" "net/http" "net/url" "strings" + "time" + "github.com/go-openapi/strfmt" "github.com/golang-jwt/jwt/v4" "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" @@ -20,6 +23,7 @@ var ( InsecureSkipVerify = false Cert *tls.Certificate CaCertPool *x509.CertPool + lapiClient *ApiClient ) type ApiClient struct { @@ -36,6 +40,7 @@ type ApiClient struct { Decisions *DecisionsService DecisionDelete *DecisionDeleteService Alerts *AlertsService + Allowlists *AllowlistsService Auth *AuthService Metrics *MetricsService Signal *SignalService @@ -66,6 +71,68 @@ type service struct { client *ApiClient } +func InitLAPIClient(ctx context.Context, apiUrl string, papiUrl string, login string, password string, scenarios []string) error { + + if lapiClient != nil { + return errors.New("client already initialized") + } + + apiURL, err := url.Parse(apiUrl) + if err != nil { + return fmt.Errorf("parsing api url ('%s'): %w", apiURL, err) + } + + papiURL, err := url.Parse(papiUrl) + if err != nil { + return fmt.Errorf("parsing polling api url ('%s'): %w", papiURL, err) + } + + pwd := strfmt.Password(password) + + client, err := NewClient(&Config{ + MachineID: login, + Password: pwd, + Scenarios: scenarios, + URL: apiURL, + PapiURL: papiURL, + VersionPrefix: "v1", + UpdateScenario: func(_ context.Context) ([]string, error) { + return scenarios, nil + }, + }) + if err != nil { + return fmt.Errorf("new client api: %w", err) + } + + authResp, _, err := client.Auth.AuthenticateWatcher(ctx, models.WatcherAuthRequest{ + MachineID: &login, + Password: &pwd, + Scenarios: scenarios, + }) + if err != nil { + return fmt.Errorf("authenticate watcher (%s): %w", login, err) + } + + var expiration time.Time + if err := expiration.UnmarshalText([]byte(authResp.Expire)); err != nil { + return fmt.Errorf("unable to parse jwt expiration: %w", err) + } + + client.GetClient().Transport.(*JWTTransport).Token = authResp.Token + client.GetClient().Transport.(*JWTTransport).Expiration = expiration + + lapiClient = client + + return nil +} + +func GetLAPIClient() (*ApiClient, error) { + if lapiClient == nil { + return nil, errors.New("client not initialized") + } + return lapiClient, nil +} + func NewClient(config *Config) (*ApiClient, error) { userAgent := config.UserAgent if userAgent == "" { @@ -115,6 +182,7 @@ func NewClient(config *Config) (*ApiClient, error) { c.common.client = c c.Decisions = (*DecisionsService)(&c.common) c.Alerts = (*AlertsService)(&c.common) + c.Allowlists = (*AllowlistsService)(&c.common) c.Auth = (*AuthService)(&c.common) c.Metrics = (*MetricsService)(&c.common) c.Signal = (*SignalService)(&c.common) @@ -157,6 +225,7 @@ func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *htt c.common.client = c c.Decisions = (*DecisionsService)(&c.common) c.Alerts = (*AlertsService)(&c.common) + c.Allowlists = (*AllowlistsService)(&c.common) c.Auth = (*AuthService)(&c.common) c.Metrics = (*MetricsService)(&c.common) c.Signal = (*SignalService)(&c.common) diff --git a/pkg/apiclient/decisions_service.go b/pkg/apiclient/decisions_service.go index 98f26cad9ae..fea2f39072d 100644 --- a/pkg/apiclient/decisions_service.go +++ b/pkg/apiclient/decisions_service.go @@ -31,6 +31,8 @@ type DecisionsListOpts struct { type DecisionsStreamOpts struct { Startup bool `url:"startup,omitempty"` + CommunityPull bool `url:"community_pull"` + AdditionalPull bool `url:"additional_pull"` Scopes string `url:"scopes,omitempty"` ScenariosContaining string `url:"scenarios_containing,omitempty"` ScenariosNotContaining string `url:"scenarios_not_containing,omitempty"` @@ -43,6 +45,17 @@ func (o *DecisionsStreamOpts) addQueryParamsToURL(url string) (string, error) { return "", err } + //Those 2 are a bit different + //They default to true, and we only want to include them if they are false + + if params.Get("community_pull") == "true" { + params.Del("community_pull") + } + + if params.Get("additional_pull") == "true" { + params.Del("additional_pull") + } + return fmt.Sprintf("%s?%s", url, params.Encode()), nil } diff --git a/pkg/apiclient/decisions_service_test.go b/pkg/apiclient/decisions_service_test.go index 54c44f43eda..942d14689ff 100644 --- a/pkg/apiclient/decisions_service_test.go +++ b/pkg/apiclient/decisions_service_test.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/url" + "strings" "testing" log "github.com/sirupsen/logrus" @@ -87,7 +88,7 @@ func TestDecisionsStream(t *testing.T) { testMethod(t, r, http.MethodGet) if r.Method == http.MethodGet { - if r.URL.RawQuery == "startup=true" { + if strings.Contains(r.URL.RawQuery, "startup=true") { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"deleted":null,"new":[{"duration":"3h59m55.756182786s","id":4,"origin":"cscli","scenario":"manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'","scope":"Ip","type":"ban","value":"1.2.3.4"}]}`)) } else { @@ -160,7 +161,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { testMethod(t, r, http.MethodGet) if r.Method == http.MethodGet { - if r.URL.RawQuery == "startup=true" { + if strings.Contains(r.URL.RawQuery, "startup=true") { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"deleted":[{"scope":"ip","decisions":["1.2.3.5"]}],"new":[{"scope":"ip", "scenario": "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'", "decisions":[{"duration":"3h59m55.756182786s","value":"1.2.3.4"}]}]}`)) } else { @@ -429,6 +430,8 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { Scopes string ScenariosContaining string ScenariosNotContaining string + CommunityPull bool + AdditionalPull bool } tests := []struct { @@ -440,11 +443,17 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { { name: "no filter", expected: baseURLString + "?", + fields: fields{ + CommunityPull: true, + AdditionalPull: true, + }, }, { name: "startup=true", fields: fields{ - Startup: true, + Startup: true, + CommunityPull: true, + AdditionalPull: true, }, expected: baseURLString + "?startup=true", }, @@ -455,9 +464,19 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { Scopes: "ip,range", ScenariosContaining: "ssh", ScenariosNotContaining: "bf", + CommunityPull: true, + AdditionalPull: true, }, expected: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true", }, + { + name: "pull options", + fields: fields{ + CommunityPull: false, + AdditionalPull: false, + }, + expected: baseURLString + "?additional_pull=false&community_pull=false", + }, } for _, tt := range tests { @@ -467,6 +486,8 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { Scopes: tt.fields.Scopes, ScenariosContaining: tt.fields.ScenariosContaining, ScenariosNotContaining: tt.fields.ScenariosNotContaining, + CommunityPull: tt.fields.CommunityPull, + AdditionalPull: tt.fields.AdditionalPull, } got, err := o.addQueryParamsToURL(baseURLString) diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index 4cc215c344f..9b79b0dd311 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -16,27 +16,35 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/models" ) +const ( + passwordAuthType = "password" + apiKeyAuthType = "apikey" +) + type LAPI struct { router *gin.Engine loginResp models.WatcherAuthResponse bouncerKey string DBConfig *csconfig.DatabaseCfg + DBClient *database.Client } func SetupLAPITest(t *testing.T, ctx context.Context) LAPI { t.Helper() router, loginResp, config := InitMachineTest(t, ctx) - APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig) + APIKey, dbClient := CreateTestBouncer(t, ctx, config.API.Server.DbConfig) return LAPI{ router: router, loginResp: loginResp, bouncerKey: APIKey, DBConfig: config.API.Server.DbConfig, + DBClient: dbClient, } } @@ -51,14 +59,17 @@ func (l *LAPI) RecordResponse(t *testing.T, ctx context.Context, verb string, ur require.NoError(t, err) switch authType { - case "apikey": + case apiKeyAuthType: req.Header.Add("X-Api-Key", l.bouncerKey) - case "password": + case passwordAuthType: AddAuthHeaders(req, l.loginResp) default: t.Fatal("auth type not supported") } + // Port is required for gin to properly parse the client IP + req.RemoteAddr = "127.0.0.1:1234" + l.router.ServeHTTP(w, req) return w @@ -118,14 +129,14 @@ func TestCreateAlert(t *testing.T) { w := lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password") assert.Equal(t, 400, w.Code) - assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) + assert.JSONEq(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Create Alert with invalid input alertContent := GetAlertReaderFromFile(t, "./tests/invalidAlert_sample.json") w = lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", alertContent, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, + assert.JSONEq(t, `{"message":"validation failure list:\n0.scenario in body is required\n0.scenario_hash in body is required\n0.scenario_version in body is required\n0.simulated in body is required\n0.source in body is required"}`, w.Body.String()) @@ -173,7 +184,7 @@ func TestAlertListFilters(t *testing.T) { w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?test=test", alertContent, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) + assert.JSONEq(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) // get without filters @@ -239,7 +250,7 @@ func TestAlertListFilters(t *testing.T) { w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String()) + assert.JSONEq(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String()) // test range (ok) @@ -258,7 +269,7 @@ func TestAlertListFilters(t *testing.T) { w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=ratata", emptyBody, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String()) + assert.JSONEq(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String()) // test since (ok) @@ -329,7 +340,7 @@ func TestAlertListFilters(t *testing.T) { w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String()) + assert.JSONEq(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String()) } func TestAlertBulkInsert(t *testing.T) { @@ -351,7 +362,7 @@ func TestListAlert(t *testing.T) { w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?test=test", emptyBody, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) + assert.JSONEq(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) // List Alert @@ -394,7 +405,7 @@ func TestDeleteAlert(t *testing.T) { req.RemoteAddr = "127.0.0.2:4242" lapi.router.ServeHTTP(w, req) assert.Equal(t, 403, w.Code) - assert.Equal(t, `{"message":"access forbidden from this IP (127.0.0.2)"}`, w.Body.String()) + assert.JSONEq(t, `{"message":"access forbidden from this IP (127.0.0.2)"}`, w.Body.String()) // Delete Alert w = httptest.NewRecorder() @@ -403,7 +414,7 @@ func TestDeleteAlert(t *testing.T) { req.RemoteAddr = "127.0.0.1:4242" lapi.router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) - assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) + assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String()) } func TestDeleteAlertByID(t *testing.T) { @@ -418,7 +429,7 @@ func TestDeleteAlertByID(t *testing.T) { req.RemoteAddr = "127.0.0.2:4242" lapi.router.ServeHTTP(w, req) assert.Equal(t, 403, w.Code) - assert.Equal(t, `{"message":"access forbidden from this IP (127.0.0.2)"}`, w.Body.String()) + assert.JSONEq(t, `{"message":"access forbidden from this IP (127.0.0.2)"}`, w.Body.String()) // Delete Alert w = httptest.NewRecorder() @@ -427,7 +438,7 @@ func TestDeleteAlertByID(t *testing.T) { req.RemoteAddr = "127.0.0.1:4242" lapi.router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) - assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) + assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String()) } func TestDeleteAlertTrustedIPS(t *testing.T) { @@ -472,7 +483,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) - assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) + assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String()) } lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") diff --git a/pkg/apiserver/allowlists_test.go b/pkg/apiserver/allowlists_test.go new file mode 100644 index 00000000000..3575a21897d --- /dev/null +++ b/pkg/apiserver/allowlists_test.go @@ -0,0 +1,119 @@ +package apiserver + +import ( + "context" + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/go-openapi/strfmt" + "github.com/stretchr/testify/require" +) + +func TestAllowlistList(t *testing.T) { + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + + _, err := lapi.DBClient.CreateAllowList(ctx, "test", "test", false) + + require.NoError(t, err) + + w := lapi.RecordResponse(t, ctx, http.MethodGet, "/v1/allowlists", emptyBody, passwordAuthType) + + require.Equal(t, http.StatusOK, w.Code) + + allowlists := models.GetAllowlistsResponse{} + + err = json.Unmarshal(w.Body.Bytes(), &allowlists) + require.NoError(t, err) + + require.Len(t, allowlists, 1) + require.Equal(t, "test", allowlists[0].Name) +} + +func TestGetAllowlist(t *testing.T) { + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + + l, err := lapi.DBClient.CreateAllowList(ctx, "test", "test", false) + + require.NoError(t, err) + + lapi.DBClient.AddToAllowlist(ctx, l, []*models.AllowlistItem{ + { + Value: "1.2.3.4", + }, + { + Value: "2.3.4.5", + Expiration: strfmt.DateTime(time.Now().Add(-time.Hour)), // expired + }, + }) + + w := lapi.RecordResponse(t, ctx, http.MethodGet, "/v1/allowlists/test?with_content=true", emptyBody, passwordAuthType) + + require.Equal(t, http.StatusOK, w.Code) + + allowlist := models.GetAllowlistResponse{} + + err = json.Unmarshal(w.Body.Bytes(), &allowlist) + require.NoError(t, err) + + require.Equal(t, "test", allowlist.Name) + require.Len(t, allowlist.Items, 1) + require.Equal(t, allowlist.Items[0].Value, "1.2.3.4") +} + +func TestCheckInAllowlist(t *testing.T) { + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + + l, err := lapi.DBClient.CreateAllowList(ctx, "test", "test", false) + + require.NoError(t, err) + + lapi.DBClient.AddToAllowlist(ctx, l, []*models.AllowlistItem{ + { + Value: "1.2.3.4", + }, + { + Value: "2.3.4.5", + Expiration: strfmt.DateTime(time.Now().Add(-time.Hour)), // expired + }, + }) + + // GET request, should return 200 and status in body + w := lapi.RecordResponse(t, ctx, http.MethodGet, "/v1/allowlists/check/1.2.3.4", emptyBody, passwordAuthType) + + require.Equal(t, http.StatusOK, w.Code) + + resp := models.CheckAllowlistResponse{} + + err = json.Unmarshal(w.Body.Bytes(), &resp) + require.NoError(t, err) + + require.True(t, resp.Allowlisted) + + // GET request, should return 200 and status in body + w = lapi.RecordResponse(t, ctx, http.MethodGet, "/v1/allowlists/check/2.3.4.5", emptyBody, passwordAuthType) + + require.Equal(t, http.StatusOK, w.Code) + + resp = models.CheckAllowlistResponse{} + + err = json.Unmarshal(w.Body.Bytes(), &resp) + + require.NoError(t, err) + require.False(t, resp.Allowlisted) + + // HEAD request, should return 200 + w = lapi.RecordResponse(t, ctx, http.MethodHead, "/v1/allowlists/check/1.2.3.4", emptyBody, passwordAuthType) + + require.Equal(t, http.StatusOK, w.Code) + + // HEAD request, should return 204 + w = lapi.RecordResponse(t, ctx, http.MethodHead, "/v1/allowlists/check/2.3.4.5", emptyBody, passwordAuthType) + + require.Equal(t, http.StatusNoContent, w.Code) +} diff --git a/pkg/apiserver/api_key_test.go b/pkg/apiserver/api_key_test.go index e6ed68a6e0d..89e37cd3852 100644 --- a/pkg/apiserver/api_key_test.go +++ b/pkg/apiserver/api_key_test.go @@ -14,34 +14,80 @@ func TestAPIKey(t *testing.T) { ctx := context.Background() router, config := NewAPITest(t, ctx) - APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig) + APIKey, _ := CreateTestBouncer(t, ctx, config.API.Server.DbConfig) // Login with empty token w := httptest.NewRecorder() req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) + req.RemoteAddr = "127.0.0.1:1234" router.ServeHTTP(w, req) - assert.Equal(t, 403, w.Code) - assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String()) + assert.Equal(t, http.StatusForbidden, w.Code) + assert.JSONEq(t, `{"message":"access forbidden"}`, w.Body.String()) // Login with invalid token w = httptest.NewRecorder() req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Api-Key", "a1b2c3d4e5f6") + req.RemoteAddr = "127.0.0.1:1234" router.ServeHTTP(w, req) - assert.Equal(t, 403, w.Code) - assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String()) + assert.Equal(t, http.StatusForbidden, w.Code) + assert.JSONEq(t, `{"message":"access forbidden"}`, w.Body.String()) // Login with valid token w = httptest.NewRecorder() req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Api-Key", APIKey) + req.RemoteAddr = "127.0.0.1:1234" router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "null", w.Body.String()) + + // Login with valid token from another IP + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) + req.Header.Add("User-Agent", UserAgent) + req.Header.Add("X-Api-Key", APIKey) + req.RemoteAddr = "4.3.2.1:1234" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "null", w.Body.String()) + + // Make the requests multiple times to make sure we only create one + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) + req.Header.Add("User-Agent", UserAgent) + req.Header.Add("X-Api-Key", APIKey) + req.RemoteAddr = "4.3.2.1:1234" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "null", w.Body.String()) + + // Use the original bouncer again + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) + req.Header.Add("User-Agent", UserAgent) + req.Header.Add("X-Api-Key", APIKey) + req.RemoteAddr = "127.0.0.1:1234" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "null", w.Body.String()) + + // Check if our second bouncer was properly created + bouncers := GetBouncers(t, config.API.Server.DbConfig) + + assert.Len(t, bouncers, 2) + assert.Equal(t, "test@4.3.2.1", bouncers[1].Name) + assert.Equal(t, bouncers[0].APIKey, bouncers[1].APIKey) + assert.Equal(t, bouncers[0].AuthType, bouncers[1].AuthType) + assert.False(t, bouncers[0].AutoCreated) + assert.True(t, bouncers[1].AutoCreated) } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index fff0ebcacbf..d420ef5ac3c 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -1,7 +1,9 @@ package apiserver import ( + "bufio" "context" + "encoding/json" "errors" "fmt" "math/rand" @@ -69,6 +71,10 @@ type apic struct { consoleConfig *csconfig.ConsoleConfig isPulling chan bool whitelists *csconfig.CapiWhitelist + + pullBlocklists bool + pullCommunity bool + shareSignals bool } // randomDuration returns a duration value between d-delta and d+delta @@ -198,6 +204,9 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient usageMetricsIntervalFirst: randomDuration(usageMetricsInterval, usageMetricsIntervalDelta), isPulling: make(chan bool, 1), whitelists: apicWhitelist, + pullBlocklists: *config.PullConfig.Blocklists, + pullCommunity: *config.PullConfig.Community, + shareSignals: *config.Sharing, } password := strfmt.Password(config.Credentials.Password) @@ -295,7 +304,7 @@ func (a *apic) Push(ctx context.Context) error { var signals []*models.AddSignalsRequestItem for _, alert := range alerts { - if ok := shouldShareAlert(alert, a.consoleConfig); ok { + if ok := shouldShareAlert(alert, a.consoleConfig, a.shareSignals); ok { signals = append(signals, alertToSignal(alert, getScenarioTrustOfAlert(alert), *a.consoleConfig.ShareContext)) } } @@ -324,7 +333,12 @@ func getScenarioTrustOfAlert(alert *models.Alert) string { return scenarioTrust } -func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig) bool { +func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig, shareSignals bool) bool { + if !shareSignals { + log.Debugf("sharing signals is disabled") + return false + } + if *alert.Simulated { log.Debugf("simulation enabled for alert (id:%d), will not be sent to CAPI", alert.ID) return false @@ -625,7 +639,9 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error { log.Infof("Starting community-blocklist update") - data, _, err := a.apiClient.Decisions.GetStreamV3(ctx, apiclient.DecisionsStreamOpts{Startup: a.startup}) + log.Debugf("Community pull: %t | Blocklist pull: %t", a.pullCommunity, a.pullBlocklists) + + data, _, err := a.apiClient.Decisions.GetStreamV3(ctx, apiclient.DecisionsStreamOpts{Startup: a.startup, CommunityPull: a.pullCommunity, AdditionalPull: a.pullBlocklists}) if err != nil { return fmt.Errorf("get stream: %w", err) } @@ -650,28 +666,40 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error { log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted) - if len(data.New) == 0 { - log.Infof("capi/community-blocklist : received 0 new entries (expected if you just installed crowdsec)") - return nil - } - - // create one alert for community blocklist using the first decision - decisions := a.apiClient.Decisions.GetDecisionsFromGroups(data.New) - // apply APIC specific whitelists - decisions = a.ApplyApicWhitelists(decisions) + if len(data.New) > 0 { + // create one alert for community blocklist using the first decision + decisions := a.apiClient.Decisions.GetDecisionsFromGroups(data.New) + // apply APIC specific whitelists + decisions = a.ApplyApicWhitelists(decisions) - alert := createAlertForDecision(decisions[0]) - alertsFromCapi := []*models.Alert{alert} - alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters) + alert := createAlertForDecision(decisions[0]) + alertsFromCapi := []*models.Alert{alert} + alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters) - err = a.SaveAlerts(ctx, alertsFromCapi, addCounters, deleteCounters) - if err != nil { - return fmt.Errorf("while saving alerts: %w", err) + err = a.SaveAlerts(ctx, alertsFromCapi, addCounters, deleteCounters) + if err != nil { + return fmt.Errorf("while saving alerts: %w", err) + } + } else { + if a.pullCommunity { + log.Info("capi/community-blocklist : received 0 new entries (expected if you just installed crowdsec)") + } else { + log.Debug("capi/community-blocklist : community blocklist pull is disabled") + } } - // update blocklists - if err := a.UpdateBlocklists(ctx, data.Links, addCounters, forcePull); err != nil { - return fmt.Errorf("while updating blocklists: %w", err) + // update allowlists/blocklists + if data.Links != nil { + if len(data.Links.Blocklists) > 0 { + if err := a.UpdateBlocklists(ctx, data.Links.Blocklists, addCounters, forcePull); err != nil { + return fmt.Errorf("while updating blocklists: %w", err) + } + } + if len(data.Links.Allowlists) > 0 { + if err := a.UpdateAllowlists(ctx, data.Links.Allowlists, forcePull); err != nil { + return fmt.Errorf("while updating allowlists: %w", err) + } + } } return nil @@ -680,15 +708,94 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error { // we receive a link to a blocklist, we pull the content of the blocklist and we create one alert func (a *apic) PullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, forcePull bool) error { addCounters, _ := makeAddAndDeleteCounters() - if err := a.UpdateBlocklists(ctx, &modelscapi.GetDecisionsStreamResponseLinks{ - Blocklists: []*modelscapi.BlocklistLink{blocklist}, - }, addCounters, forcePull); err != nil { + if err := a.UpdateBlocklists(ctx, []*modelscapi.BlocklistLink{blocklist}, addCounters, forcePull); err != nil { return fmt.Errorf("while pulling blocklist: %w", err) } return nil } +func (a *apic) PullAllowlist(ctx context.Context, allowlist *modelscapi.AllowlistLink, forcePull bool) error { + if err := a.UpdateAllowlists(ctx, []*modelscapi.AllowlistLink{allowlist}, forcePull); err != nil { + return fmt.Errorf("while pulling allowlist: %w", err) + } + + return nil +} + +func (a *apic) UpdateAllowlists(ctx context.Context, allowlistsLinks []*modelscapi.AllowlistLink, forcePull bool) error { + if len(allowlistsLinks) == 0 { + return nil + } + + defaultClient, err := apiclient.NewDefaultClient(a.apiClient.BaseURL, "", "", nil) + if err != nil { + return fmt.Errorf("while creating default client: %w", err) + } + + for _, link := range allowlistsLinks { + if link.URL == nil { + log.Warningf("allowlist has no URL") + continue + } + if link.Name == nil { + log.Warningf("allowlist has no name") + continue + } + + description := "" + if link.Description != nil { + description = *link.Description + } + + resp, err := defaultClient.GetClient().Get(*link.URL) + if err != nil { + log.Errorf("while pulling allowlist: %s", err) + continue + } + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + items := make([]*models.AllowlistItem, 0) + for scanner.Scan() { + item := scanner.Text() + j := &models.AllowlistItem{} + if err := json.Unmarshal([]byte(item), j); err != nil { + log.Errorf("while unmarshalling allowlist item: %s", err) + continue + } + items = append(items, j) + } + + list, err := a.dbClient.GetAllowList(ctx, *link.Name, false) + + if err != nil { + if !ent.IsNotFound(err) { + log.Errorf("while getting allowlist %s: %s", *link.Name, err) + continue + } + } + + if list == nil { + list, err = a.dbClient.CreateAllowList(ctx, *link.Name, description, true) + if err != nil { + log.Errorf("while creating allowlist %s: %s", *link.Name, err) + continue + } + } + + err = a.dbClient.ReplaceAllowlist(ctx, list, items, true) + if err != nil { + log.Errorf("while replacing allowlist %s: %s", *link.Name, err) + continue + } + + log.Infof("Allowlist %s updated", *link.Name) + } + + return nil +} + // if decisions is whitelisted: return representation of the whitelist ip or cidr // if not whitelisted: empty string func (a *apic) whitelistedBy(decision *models.Decision) string { @@ -862,14 +969,11 @@ func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient, return nil } -func (a *apic) UpdateBlocklists(ctx context.Context, links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error { - if links == nil { +func (a *apic) UpdateBlocklists(ctx context.Context, blocklists []*modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error { + if len(blocklists) == 0 { return nil } - if links.Blocklists == nil { - return nil - } // we must use a different http client than apiClient's because the transport of apiClient is jwtTransport or here we have signed apis that are incompatibles // we can use the same baseUrl as the urls are absolute and the parse will take care of it defaultClient, err := apiclient.NewDefaultClient(a.apiClient.BaseURL, "", "", nil) @@ -877,7 +981,7 @@ func (a *apic) UpdateBlocklists(ctx context.Context, links *modelscapi.GetDecisi return fmt.Errorf("while creating default client: %w", err) } - for _, blocklist := range links.Blocklists { + for _, blocklist := range blocklists { if err := a.updateBlocklist(ctx, defaultClient, blocklist, addCounters, forcePull); err != nil { return err } diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index aa8db3f1c85..fe0dfd55821 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -368,10 +368,14 @@ func (a *apic) SendUsageMetrics(ctx context.Context) { if err != nil { log.Errorf("unable to send usage metrics: %s", err) - if resp == nil || resp.Response.StatusCode >= http.StatusBadRequest && resp.Response.StatusCode != http.StatusUnprocessableEntity { + if resp == nil || resp.Response == nil { + // Most likely a transient network error, it will be retried later + continue + } + + if resp.Response.StatusCode >= http.StatusBadRequest && resp.Response.StatusCode != http.StatusUnprocessableEntity { // In case of 422, mark the metrics as sent anyway, the API did not like what we sent, // and it's unlikely we'll be able to fix it - // also if resp is nil, we should'nt mark the metrics as sent could be network issue continue } } diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 99fee6e32bf..a8fbb40c4fa 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -69,7 +69,10 @@ func getAPIC(t *testing.T, ctx context.Context) *apic { ShareCustomScenarios: ptr.Of(false), ShareContext: ptr.Of(false), }, - isPulling: make(chan bool, 1), + isPulling: make(chan bool, 1), + shareSignals: true, + pullBlocklists: true, + pullCommunity: true, } } @@ -200,6 +203,11 @@ func TestNewAPIC(t *testing.T) { Login: "foo", Password: "bar", }, + Sharing: ptr.Of(true), + PullConfig: csconfig.CapiPullConfig{ + Community: ptr.Of(true), + Blocklists: ptr.Of(true), + }, } } @@ -1193,6 +1201,7 @@ func TestShouldShareAlert(t *testing.T) { tests := []struct { name string consoleConfig *csconfig.ConsoleConfig + shareSignals bool alert *models.Alert expectedRet bool expectedTrust string @@ -1203,6 +1212,7 @@ func TestShouldShareAlert(t *testing.T) { ShareCustomScenarios: ptr.Of(true), }, alert: &models.Alert{Simulated: ptr.Of(false)}, + shareSignals: true, expectedRet: true, expectedTrust: "custom", }, @@ -1212,6 +1222,7 @@ func TestShouldShareAlert(t *testing.T) { ShareCustomScenarios: ptr.Of(false), }, alert: &models.Alert{Simulated: ptr.Of(false)}, + shareSignals: true, expectedRet: false, expectedTrust: "custom", }, @@ -1220,6 +1231,7 @@ func TestShouldShareAlert(t *testing.T) { consoleConfig: &csconfig.ConsoleConfig{ ShareManualDecisions: ptr.Of(true), }, + shareSignals: true, alert: &models.Alert{ Simulated: ptr.Of(false), Decisions: []*models.Decision{{Origin: ptr.Of(types.CscliOrigin)}}, @@ -1232,6 +1244,7 @@ func TestShouldShareAlert(t *testing.T) { consoleConfig: &csconfig.ConsoleConfig{ ShareManualDecisions: ptr.Of(false), }, + shareSignals: true, alert: &models.Alert{ Simulated: ptr.Of(false), Decisions: []*models.Decision{{Origin: ptr.Of(types.CscliOrigin)}}, @@ -1244,6 +1257,7 @@ func TestShouldShareAlert(t *testing.T) { consoleConfig: &csconfig.ConsoleConfig{ ShareTaintedScenarios: ptr.Of(true), }, + shareSignals: true, alert: &models.Alert{ Simulated: ptr.Of(false), ScenarioHash: ptr.Of("whateverHash"), @@ -1256,6 +1270,7 @@ func TestShouldShareAlert(t *testing.T) { consoleConfig: &csconfig.ConsoleConfig{ ShareTaintedScenarios: ptr.Of(false), }, + shareSignals: true, alert: &models.Alert{ Simulated: ptr.Of(false), ScenarioHash: ptr.Of("whateverHash"), @@ -1263,11 +1278,24 @@ func TestShouldShareAlert(t *testing.T) { expectedRet: false, expectedTrust: "tainted", }, + { + name: "manual alert should not be shared if global sharing is disabled", + consoleConfig: &csconfig.ConsoleConfig{ + ShareManualDecisions: ptr.Of(true), + }, + shareSignals: false, + alert: &models.Alert{ + Simulated: ptr.Of(false), + ScenarioHash: ptr.Of("whateverHash"), + }, + expectedRet: false, + expectedTrust: "manual", + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - ret := shouldShareAlert(tc.alert, tc.consoleConfig) + ret := shouldShareAlert(tc.alert, tc.consoleConfig, tc.shareSignals) assert.Equal(t, tc.expectedRet, ret) }) } diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 35f9beaf635..05f9150b037 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -46,20 +46,11 @@ type APIServer struct { consoleConfig *csconfig.ConsoleConfig } -func recoverFromPanic(c *gin.Context) { - err := recover() - if err == nil { - return - } - - // Check for a broken connection, as it is not really a - // condition that warrants a panic stack trace. - brokenPipe := false - +func isBrokenConnection(err any) bool { if ne, ok := err.(*net.OpError); ok { if se, ok := ne.Err.(*os.SyscallError); ok { if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") { - brokenPipe = true + return true } } } @@ -79,11 +70,22 @@ func recoverFromPanic(c *gin.Context) { errors.Is(strErr, errClosedBody) || errors.Is(strErr, errHandlerComplete) || errors.Is(strErr, errStreamClosed) { - brokenPipe = true + return true } } - if brokenPipe { + return false +} + +func recoverFromPanic(c *gin.Context) { + err := recover() + if err == nil { + return + } + + // Check for a broken connection, as it is not really a + // condition that warrants a panic stack trace. + if isBrokenConnection(err) { log.Warningf("client %s disconnected: %s", c.ClientIP(), err) c.Abort() } else { diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index cdf99462c35..8b1d0d9b401 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -24,6 +24,7 @@ import ( middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -62,6 +63,7 @@ func LoadTestConfig(t *testing.T) csconfig.Config { } apiServerConfig := csconfig.LocalApiServerCfg{ ListenURI: "http://127.0.0.1:8080", + LogLevel: ptr.Of(log.DebugLevel), DbConfig: &dbconfig, ProfilesPath: "./tests/profiles.yaml", ConsoleConfig: &csconfig.ConsoleConfig{ @@ -206,6 +208,18 @@ func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) return "" } +func GetBouncers(t *testing.T, config *csconfig.DatabaseCfg) []*ent.Bouncer { + ctx := context.Background() + + dbClient, err := database.NewClient(ctx, config) + require.NoError(t, err) + + bouncers, err := dbClient.ListBouncers(ctx) + require.NoError(t, err) + + return bouncers +} + func GetAlertReaderFromFile(t *testing.T, path string) *strings.Reader { alertContentBytes, err := os.ReadFile(path) require.NoError(t, err) @@ -283,17 +297,17 @@ func CreateTestMachine(t *testing.T, ctx context.Context, router *gin.Engine, to return body } -func CreateTestBouncer(t *testing.T, ctx context.Context, config *csconfig.DatabaseCfg) string { +func CreateTestBouncer(t *testing.T, ctx context.Context, config *csconfig.DatabaseCfg) (string, *database.Client) { dbClient, err := database.NewClient(ctx, config) require.NoError(t, err) apiKey, err := middlewares.GenerateAPIKey(keyLength) require.NoError(t, err) - _, err = dbClient.CreateBouncer(ctx, "test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) + _, err = dbClient.CreateBouncer(ctx, "test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType, false) require.NoError(t, err) - return apiKey + return apiKey, dbClient } func TestWithWrongDBConfig(t *testing.T) { diff --git a/pkg/apiserver/controllers/controller.go b/pkg/apiserver/controllers/controller.go index 719bb231006..58f757b6657 100644 --- a/pkg/apiserver/controllers/controller.go +++ b/pkg/apiserver/controllers/controller.go @@ -123,6 +123,11 @@ func (c *Controller) NewV1() error { jwtAuth.DELETE("/decisions", c.HandlerV1.DeleteDecisions) jwtAuth.DELETE("/decisions/:decision_id", c.HandlerV1.DeleteDecisionById) jwtAuth.GET("/heartbeat", c.HandlerV1.HeartBeat) + jwtAuth.GET("/allowlists", c.HandlerV1.GetAllowlists) + jwtAuth.GET("/allowlists/:allowlist_name", c.HandlerV1.GetAllowlist) + jwtAuth.GET("/allowlists/check/:ip_or_range", c.HandlerV1.CheckInAllowlist) + jwtAuth.HEAD("/allowlists/check/:ip_or_range", c.HandlerV1.CheckInAllowlist) + } apiKeyAuth := groupV1.Group("") diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index d1f93228512..a1582cb05b2 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -141,6 +141,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { } stopFlush := false + alertsToSave := make([]*models.Alert, 0) for _, alert := range input { // normalize scope for alert.Source and decisions @@ -154,6 +155,19 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { } } + if alert.Source.Scope != nil && (*alert.Source.Scope == types.Ip || *alert.Source.Scope == types.Range) && // Allowlist only works for IP/range + alert.Source.Value != nil && // Is this possible ? + len(alert.Decisions) == 0 { // If there's no decisions, means it's coming from crowdsec (not cscli), so we can apply allowlist + isAllowlisted, err := c.DBClient.IsAllowlisted(ctx, *alert.Source.Value) + if err == nil && isAllowlisted { + log.Infof("alert source %s is allowlisted, skipping", *alert.Source.Value) + continue + } else if err != nil { + //FIXME: Do we still want to process the alert normally if we can't check the allowlist ? + log.Errorf("error while checking allowlist: %s", err) + } + } + alert.MachineID = machineID // generate uuid here for alert alert.UUID = uuid.NewString() @@ -189,6 +203,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { stopFlush = true } + alertsToSave = append(alertsToSave, alert) continue } @@ -234,13 +249,15 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { break } } + + alertsToSave = append(alertsToSave, alert) } if stopFlush { c.DBClient.CanFlush = false } - alerts, err := c.DBClient.CreateAlert(ctx, machineID, input) + alerts, err := c.DBClient.CreateAlert(ctx, machineID, alertsToSave) c.DBClient.CanFlush = true if err != nil { @@ -250,7 +267,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { if c.AlertsAddChan != nil { select { - case c.AlertsAddChan <- input: + case c.AlertsAddChan <- alertsToSave: log.Debug("alert sent to CAPI channel") default: log.Warning("Cannot send alert to Central API channel") diff --git a/pkg/apiserver/controllers/v1/allowlist.go b/pkg/apiserver/controllers/v1/allowlist.go new file mode 100644 index 00000000000..64a409e4ecd --- /dev/null +++ b/pkg/apiserver/controllers/v1/allowlist.go @@ -0,0 +1,126 @@ +package v1 + +import ( + "net/http" + "time" + + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/gin-gonic/gin" + "github.com/go-openapi/strfmt" +) + +func (c *Controller) CheckInAllowlist(gctx *gin.Context) { + value := gctx.Param("ip_or_range") + + if value == "" { + gctx.JSON(http.StatusBadRequest, gin.H{"message": "value is required"}) + return + } + + allowlisted, err := c.DBClient.IsAllowlisted(gctx.Request.Context(), value) + + if err != nil { + c.HandleDBErrors(gctx, err) + return + } + + if gctx.Request.Method == http.MethodHead { + if allowlisted { + gctx.Status(http.StatusOK) + } else { + gctx.Status(http.StatusNoContent) + } + return + } + + resp := models.CheckAllowlistResponse{ + Allowlisted: allowlisted, + } + + gctx.JSON(http.StatusOK, resp) +} + +func (c *Controller) GetAllowlists(gctx *gin.Context) { + params := gctx.Request.URL.Query() + + withContent := params.Get("with_content") == "true" + + allowlists, err := c.DBClient.ListAllowLists(gctx.Request.Context(), withContent) + + if err != nil { + c.HandleDBErrors(gctx, err) + return + } + + resp := models.GetAllowlistsResponse{} + + for _, allowlist := range allowlists { + items := make([]*models.AllowlistItem, 0) + if withContent { + for _, item := range allowlist.Edges.AllowlistItems { + if !item.ExpiresAt.IsZero() && item.ExpiresAt.Before(time.Now()) { + continue + } + items = append(items, &models.AllowlistItem{ + CreatedAt: strfmt.DateTime(item.CreatedAt), + Description: item.Comment, + Expiration: strfmt.DateTime(item.ExpiresAt), + Value: item.Value, + }) + } + } + resp = append(resp, &models.GetAllowlistResponse{ + AllowlistID: allowlist.AllowlistID, + Name: allowlist.Name, + Description: allowlist.Description, + CreatedAt: strfmt.DateTime(allowlist.CreatedAt), + UpdatedAt: strfmt.DateTime(allowlist.UpdatedAt), + ConsoleManaged: allowlist.FromConsole, + Items: items, + }) + } + + gctx.JSON(http.StatusOK, resp) +} + +func (c *Controller) GetAllowlist(gctx *gin.Context) { + allowlist := gctx.Param("allowlist_name") + + params := gctx.Request.URL.Query() + withContent := params.Get("with_content") == "true" + + allowlistModel, err := c.DBClient.GetAllowList(gctx.Request.Context(), allowlist, withContent) + + if err != nil { + c.HandleDBErrors(gctx, err) + return + } + + items := make([]*models.AllowlistItem, 0) + + if withContent { + for _, item := range allowlistModel.Edges.AllowlistItems { + if !item.ExpiresAt.IsZero() && item.ExpiresAt.Before(time.Now()) { + continue + } + items = append(items, &models.AllowlistItem{ + CreatedAt: strfmt.DateTime(item.CreatedAt), + Description: item.Comment, + Expiration: strfmt.DateTime(item.ExpiresAt), + Value: item.Value, + }) + } + } + + resp := models.GetAllowlistResponse{ + AllowlistID: allowlistModel.AllowlistID, + Name: allowlistModel.Name, + Description: allowlistModel.Description, + CreatedAt: strfmt.DateTime(allowlistModel.CreatedAt), + UpdatedAt: strfmt.DateTime(allowlistModel.UpdatedAt), + ConsoleManaged: allowlistModel.FromConsole, + Items: items, + } + + gctx.JSON(http.StatusOK, resp) +} diff --git a/pkg/apiserver/decisions_test.go b/pkg/apiserver/decisions_test.go index a0af6956443..cb5d2e1c4f1 100644 --- a/pkg/apiserver/decisions_test.go +++ b/pkg/apiserver/decisions_test.go @@ -22,19 +22,19 @@ func TestDeleteDecisionRange(t *testing.T) { // delete by ip wrong w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) + assert.JSONEq(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by range w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String()) + assert.JSONEq(t, `{"nbDeleted":"2"}`, w.Body.String()) // delete by range : ensure it was already deleted w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) + assert.JSONEq(t, `{"nbDeleted":"0"}`, w.Body.String()) } func TestDeleteDecisionFilter(t *testing.T) { @@ -48,19 +48,19 @@ func TestDeleteDecisionFilter(t *testing.T) { w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) + assert.JSONEq(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by ip good w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) + assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String()) // delete by scope/value w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) + assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String()) } func TestDeleteDecisionFilterByScenario(t *testing.T) { @@ -74,13 +74,13 @@ func TestDeleteDecisionFilterByScenario(t *testing.T) { w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) + assert.JSONEq(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by scenario good w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String()) + assert.JSONEq(t, `{"nbDeleted":"2"}`, w.Body.String()) } func TestGetDecisionFilters(t *testing.T) { diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go index f6f51763975..72ae0302ae4 100644 --- a/pkg/apiserver/jwt_test.go +++ b/pkg/apiserver/jwt_test.go @@ -23,7 +23,7 @@ func TestLogin(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, `{"code":401,"message":"machine test not validated"}`, w.Body.String()) + assert.JSONEq(t, `{"code":401,"message":"machine test not validated"}`, w.Body.String()) // Login with machine not exist w = httptest.NewRecorder() @@ -32,7 +32,7 @@ func TestLogin(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, `{"code":401,"message":"ent: machine not found"}`, w.Body.String()) + assert.JSONEq(t, `{"code":401,"message":"ent: machine not found"}`, w.Body.String()) // Login with invalid body w = httptest.NewRecorder() @@ -41,7 +41,7 @@ func TestLogin(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, `{"code":401,"message":"missing: invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) + assert.JSONEq(t, `{"code":401,"message":"missing: invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Login with invalid format w = httptest.NewRecorder() @@ -50,7 +50,7 @@ func TestLogin(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String()) + assert.JSONEq(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String()) // Validate machine ValidateMachine(t, ctx, "test", config.API.Server.DbConfig) @@ -62,7 +62,7 @@ func TestLogin(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, `{"code":401,"message":"incorrect Username or Password"}`, w.Body.String()) + assert.JSONEq(t, `{"code":401,"message":"incorrect Username or Password"}`, w.Body.String()) // Login with valid machine w = httptest.NewRecorder() diff --git a/pkg/apiserver/machines_test.go b/pkg/apiserver/machines_test.go index 969f75707d6..57b96f54ddd 100644 --- a/pkg/apiserver/machines_test.go +++ b/pkg/apiserver/machines_test.go @@ -25,7 +25,7 @@ func TestCreateMachine(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, http.StatusBadRequest, w.Code) - assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) + assert.JSONEq(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Create machine with invalid input w = httptest.NewRecorder() @@ -34,7 +34,7 @@ func TestCreateMachine(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, http.StatusUnprocessableEntity, w.Code) - assert.Equal(t, `{"message":"validation failure list:\nmachine_id in body is required\npassword in body is required"}`, w.Body.String()) + assert.JSONEq(t, `{"message":"validation failure list:\nmachine_id in body is required\npassword in body is required"}`, w.Body.String()) // Create machine b, err := json.Marshal(MachineTest) @@ -144,7 +144,7 @@ func TestCreateMachineAlreadyExist(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, http.StatusForbidden, w.Code) - assert.Equal(t, `{"message":"user 'test': user already exist"}`, w.Body.String()) + assert.JSONEq(t, `{"message":"user 'test': user already exist"}`, w.Body.String()) } func TestAutoRegistration(t *testing.T) { diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index d438c9b15a4..df2f68930d6 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -89,7 +89,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Infof("Creating bouncer %s", bouncerName) - bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) + bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType, true) if err != nil { logger.Errorf("while creating bouncer db entry: %s", err) return nil @@ -114,18 +114,68 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer { return nil } + clientIP := c.ClientIP() + ctx := c.Request.Context() hashStr := HashSHA512(val[0]) - bouncer, err := a.DbClient.SelectBouncer(ctx, hashStr) + // Appsec case, we only care if the key is valid + // No content is returned, no last_pull update or anything + if c.Request.Method == http.MethodHead { + bouncer, err := a.DbClient.SelectBouncers(ctx, hashStr, types.ApiKeyAuthType) + if err != nil { + logger.Errorf("while fetching bouncer info: %s", err) + return nil + } + return bouncer[0] + } + + // most common case, check if this specific bouncer exists + bouncer, err := a.DbClient.SelectBouncerWithIP(ctx, hashStr, clientIP) + if err != nil && !ent.IsNotFound(err) { + logger.Errorf("while fetching bouncer info: %s", err) + return nil + } + + // We found the bouncer with key and IP, we can use it + if bouncer != nil { + if bouncer.AuthType != types.ApiKeyAuthType { + logger.Errorf("bouncer isn't allowed to auth by API key") + return nil + } + return bouncer + } + + // We didn't find the bouncer with key and IP, let's try to find it with the key only + bouncers, err := a.DbClient.SelectBouncers(ctx, hashStr, types.ApiKeyAuthType) if err != nil { logger.Errorf("while fetching bouncer info: %s", err) return nil } - if bouncer.AuthType != types.ApiKeyAuthType { - logger.Errorf("bouncer %s attempted to login using an API key but it is configured to auth with %s", bouncer.Name, bouncer.AuthType) + if len(bouncers) == 0 { + logger.Debugf("no bouncer found with this key") + return nil + } + + logger.Debugf("found %d bouncers with this key", len(bouncers)) + + // We only have one bouncer with this key and no IP + // This is the first request made by this bouncer, keep this one + if len(bouncers) == 1 && bouncers[0].IPAddress == "" { + return bouncers[0] + } + + // Bouncers are ordered by ID, first one *should* be the manually created one + // Can probably get a bit weird if the user deletes the manually created one + bouncerName := fmt.Sprintf("%s@%s", bouncers[0].Name, clientIP) + + logger.Infof("Creating bouncer %s", bouncerName) + + bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, clientIP, hashStr, types.ApiKeyAuthType, true) + if err != nil { + logger.Errorf("while creating bouncer db entry: %s", err) return nil } @@ -156,27 +206,20 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { return } - logger = logger.WithField("name", bouncer.Name) - - if bouncer.IPAddress == "" { - if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { - logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - - return - } + // Appsec request, return immediately if we found something + if c.Request.Method == http.MethodHead { + c.Set(BouncerContextKey, bouncer) + return } - // Don't update IP on HEAD request, as it's used by the appsec to check the validity of the API key provided - if bouncer.IPAddress != clientIP && bouncer.IPAddress != "" && c.Request.Method != http.MethodHead { - log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, clientIP, bouncer.IPAddress) + logger = logger.WithField("name", bouncer.Name) + // 1st time we see this bouncer, we update its IP + if bouncer.IPAddress == "" { if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() - return } } diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index 7dd6b346aa9..83ba13843b9 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -205,8 +205,8 @@ func reverse(s []longpollclient.Event) []longpollclient.Event { return a } -func (p *Papi) PullOnce(since time.Time, sync bool) error { - events, err := p.Client.PullOnce(since) +func (p *Papi) PullOnce(ctx context.Context, since time.Time, sync bool) error { + events, err := p.Client.PullOnce(ctx, since) if err != nil { return err } @@ -261,7 +261,7 @@ func (p *Papi) Pull(ctx context.Context) error { p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp) - for event := range p.Client.Start(lastTimestamp) { + for event := range p.Client.Start(ctx, lastTimestamp) { logger := p.Logger.WithField("request-id", event.RequestId) // update last timestamp in database newTime := time.Now().UTC() diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index 78f5dc9b0fe..6cef9cb387a 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -6,11 +6,13 @@ import ( "fmt" "time" + "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/modelscapi" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -21,25 +23,18 @@ type deleteDecisions struct { Decisions []string `json:"decisions"` } -type blocklistLink struct { - // blocklist name - Name string `json:"name"` - // blocklist url - Url string `json:"url"` - // blocklist remediation - Remediation string `json:"remediation"` - // blocklist scope - Scope string `json:"scope,omitempty"` - // blocklist duration - Duration string `json:"duration,omitempty"` +type forcePull struct { + Blocklist *modelscapi.BlocklistLink `json:"blocklist,omitempty"` + Allowlist *modelscapi.AllowlistLink `json:"allowlist,omitempty"` } -type forcePull struct { - Blocklist *blocklistLink `json:"blocklist,omitempty"` +type blocklistUnsubscribe struct { + Name string `json:"name"` } -type listUnsubscribe struct { +type allowlistUnsubscribe struct { Name string `json:"name"` + Id string `json:"id"` } func DecisionCmd(message *Message, p *Papi, sync bool) error { @@ -174,10 +169,10 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { func ManagementCmd(message *Message, p *Papi, sync bool) error { ctx := context.TODO() - if sync { + /*if sync { p.Logger.Infof("Ignoring management command from PAPI in sync mode") return nil - } + }*/ switch message.Header.OperationCmd { case "blocklist_unsubscribe": @@ -186,7 +181,7 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { return err } - unsubscribeMsg := listUnsubscribe{} + unsubscribeMsg := blocklistUnsubscribe{} if err := json.Unmarshal(data, &unsubscribeMsg); err != nil { return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) } @@ -224,27 +219,79 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { ctx := context.TODO() - if forcePullMsg.Blocklist == nil { + if forcePullMsg.Blocklist == nil && forcePullMsg.Allowlist == nil { p.Logger.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") err = p.apic.PullTop(ctx, true) if err != nil { return fmt.Errorf("failed to force pull operation: %w", err) } - } else { - p.Logger.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name) + } else if forcePullMsg.Blocklist != nil { + err = forcePullMsg.Blocklist.Validate(strfmt.Default) + + if err != nil { + return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) + } + + p.Logger.Infof("Received blocklist force_pull command from PAPI, pulling blocklist %s", *forcePullMsg.Blocklist.Name) err = p.apic.PullBlocklist(ctx, &modelscapi.BlocklistLink{ - Name: &forcePullMsg.Blocklist.Name, - URL: &forcePullMsg.Blocklist.Url, - Remediation: &forcePullMsg.Blocklist.Remediation, - Scope: &forcePullMsg.Blocklist.Scope, - Duration: &forcePullMsg.Blocklist.Duration, + Name: forcePullMsg.Blocklist.Name, + URL: forcePullMsg.Blocklist.URL, + Remediation: forcePullMsg.Blocklist.Remediation, + Scope: forcePullMsg.Blocklist.Scope, + Duration: forcePullMsg.Blocklist.Duration, + }, true) + if err != nil { + return fmt.Errorf("failed to force pull operation: %w", err) + } + } else if forcePullMsg.Allowlist != nil { + err = forcePullMsg.Allowlist.Validate(strfmt.Default) + + if err != nil { + return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) + } + + p.Logger.Infof("Received allowlist force_pull command from PAPI, pulling allowlist %s", *forcePullMsg.Allowlist.Name) + + err = p.apic.PullAllowlist(ctx, &modelscapi.AllowlistLink{ + Name: forcePullMsg.Allowlist.Name, + URL: forcePullMsg.Allowlist.URL, + ID: forcePullMsg.Allowlist.ID, + CreatedAt: forcePullMsg.Allowlist.CreatedAt, + UpdatedAt: forcePullMsg.Allowlist.UpdatedAt, }, true) if err != nil { return fmt.Errorf("failed to force pull operation: %w", err) } } + case "allowlist_unsubscribe": + data, err := json.Marshal(message.Data) + if err != nil { + return err + } + + unsubscribeMsg := allowlistUnsubscribe{} + + if err := json.Unmarshal(data, &unsubscribeMsg); err != nil { + return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) + } + + if unsubscribeMsg.Name == "" { + return fmt.Errorf("message for '%s' contains bad data format: missing allowlist name", message.Header.OperationType) + } + + p.Logger.Infof("Received allowlist_unsubscribe command from PAPI, unsubscribing from allowlist %s", unsubscribeMsg.Name) + + err = p.DBClient.DeleteAllowList(ctx, unsubscribeMsg.Name, true) + + if err != nil { + if !ent.IsNotFound(err) { + return err + } + p.Logger.Warningf("Allowlist %s not found", unsubscribeMsg.Name) + } + return nil default: return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType) } diff --git a/pkg/appsec/appsec.go b/pkg/appsec/appsec.go index 30784b23db0..5f01f76d993 100644 --- a/pkg/appsec/appsec.go +++ b/pkg/appsec/appsec.go @@ -1,7 +1,6 @@ package appsec import ( - "errors" "fmt" "net/http" "os" @@ -150,6 +149,17 @@ func (w *AppsecRuntimeConfig) ClearResponse() { w.Response.SendAlert = true } +func (wc *AppsecConfig) SetUpLogger() { + if wc.LogLevel == nil { + lvl := wc.Logger.Logger.GetLevel() + wc.LogLevel = &lvl + } + + /* wc.Name is actually the datasource name.*/ + wc.Logger = wc.Logger.Dup().WithField("name", wc.Name) + wc.Logger.Logger.SetLevel(*wc.LogLevel) +} + func (wc *AppsecConfig) LoadByPath(file string) error { wc.Logger.Debugf("loading config %s", file) @@ -157,20 +167,65 @@ func (wc *AppsecConfig) LoadByPath(file string) error { if err != nil { return fmt.Errorf("unable to read file %s : %s", file, err) } - err = yaml.UnmarshalStrict(yamlFile, wc) + + //as LoadByPath can be called several time, we append rules/hooks, but override other options + var tmp AppsecConfig + + err = yaml.UnmarshalStrict(yamlFile, &tmp) if err != nil { return fmt.Errorf("unable to parse yaml file %s : %s", file, err) } - if wc.Name == "" { - return errors.New("name cannot be empty") + if wc.Name == "" && tmp.Name != "" { + wc.Name = tmp.Name } - if wc.LogLevel == nil { - lvl := wc.Logger.Logger.GetLevel() - wc.LogLevel = &lvl + + //We can append rules/hooks + if tmp.OutOfBandRules != nil { + wc.OutOfBandRules = append(wc.OutOfBandRules, tmp.OutOfBandRules...) } - wc.Logger = wc.Logger.Dup().WithField("name", wc.Name) - wc.Logger.Logger.SetLevel(*wc.LogLevel) + if tmp.InBandRules != nil { + wc.InBandRules = append(wc.InBandRules, tmp.InBandRules...) + } + if tmp.OnLoad != nil { + wc.OnLoad = append(wc.OnLoad, tmp.OnLoad...) + } + if tmp.PreEval != nil { + wc.PreEval = append(wc.PreEval, tmp.PreEval...) + } + if tmp.PostEval != nil { + wc.PostEval = append(wc.PostEval, tmp.PostEval...) + } + if tmp.OnMatch != nil { + wc.OnMatch = append(wc.OnMatch, tmp.OnMatch...) + } + if tmp.VariablesTracking != nil { + wc.VariablesTracking = append(wc.VariablesTracking, tmp.VariablesTracking...) + } + + //override other options + wc.LogLevel = tmp.LogLevel + + wc.DefaultRemediation = tmp.DefaultRemediation + wc.DefaultPassAction = tmp.DefaultPassAction + wc.BouncerBlockedHTTPCode = tmp.BouncerBlockedHTTPCode + wc.BouncerPassedHTTPCode = tmp.BouncerPassedHTTPCode + wc.UserBlockedHTTPCode = tmp.UserBlockedHTTPCode + wc.UserPassedHTTPCode = tmp.UserPassedHTTPCode + + if tmp.InbandOptions.DisableBodyInspection { + wc.InbandOptions.DisableBodyInspection = true + } + if tmp.InbandOptions.RequestBodyInMemoryLimit != nil { + wc.InbandOptions.RequestBodyInMemoryLimit = tmp.InbandOptions.RequestBodyInMemoryLimit + } + if tmp.OutOfBandOptions.DisableBodyInspection { + wc.OutOfBandOptions.DisableBodyInspection = true + } + if tmp.OutOfBandOptions.RequestBodyInMemoryLimit != nil { + wc.OutOfBandOptions.RequestBodyInMemoryLimit = tmp.OutOfBandOptions.RequestBodyInMemoryLimit + } + return nil } diff --git a/pkg/appsec/appsec_rule/appsec_rule.go b/pkg/appsec/appsec_rule/appsec_rule.go index 136d8b11cb7..9d47c0eed5c 100644 --- a/pkg/appsec/appsec_rule/appsec_rule.go +++ b/pkg/appsec/appsec_rule/appsec_rule.go @@ -47,7 +47,6 @@ type CustomRule struct { } func (v *CustomRule) Convert(ruleType string, appsecRuleName string) (string, []uint32, error) { - if v.Zones == nil && v.And == nil && v.Or == nil { return "", nil, errors.New("no zones defined") } diff --git a/pkg/appsec/appsec_rule/modsec_rule_test.go b/pkg/appsec/appsec_rule/modsec_rule_test.go index ffb8a15ff1f..74e9b85426e 100644 --- a/pkg/appsec/appsec_rule/modsec_rule_test.go +++ b/pkg/appsec/appsec_rule/modsec_rule_test.go @@ -88,7 +88,6 @@ func TestVPatchRuleString(t *testing.T) { rule: CustomRule{ And: []CustomRule{ { - Zones: []string{"ARGS"}, Variables: []string{"foo"}, Match: Match{Type: "regex", Value: "[^a-zA-Z]"}, @@ -161,7 +160,6 @@ SecRule ARGS_GET:foo "@rx [^a-zA-Z]" "id:1519945803,phase:2,deny,log,msg:'OR AND for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { actual, _, err := tt.rule.Convert(ModsecurityRuleType, tt.name) - if err != nil { t.Errorf("Error converting rule: %s", err) } diff --git a/pkg/appsec/coraza_logger.go b/pkg/appsec/coraza_logger.go index d2c1612cbd7..93e31be5876 100644 --- a/pkg/appsec/coraza_logger.go +++ b/pkg/appsec/coraza_logger.go @@ -124,7 +124,7 @@ func (e *crzLogEvent) Stringer(key string, val fmt.Stringer) dbg.Event { return e } -func (e crzLogEvent) IsEnabled() bool { +func (e *crzLogEvent) IsEnabled() bool { return !e.muted } diff --git a/pkg/appsec/request_test.go b/pkg/appsec/request_test.go index f8333e4e5f9..8b457e24dab 100644 --- a/pkg/appsec/request_test.go +++ b/pkg/appsec/request_test.go @@ -3,7 +3,6 @@ package appsec import "testing" func TestBodyDumper(t *testing.T) { - tests := []struct { name string req *ParsedRequest @@ -159,7 +158,6 @@ func TestBodyDumper(t *testing.T) { } for idx, test := range tests { - t.Run(test.name, func(t *testing.T) { orig_dr := test.req.DumpRequest() result := test.filter(orig_dr).GetFilteredRequest() @@ -177,5 +175,4 @@ func TestBodyDumper(t *testing.T) { } }) } - } diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index a4e0bd0127a..4da9fd5bf7b 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -5,26 +5,27 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCreateSetGet(t *testing.T) { err := CacheInit(CacheCfg{Name: "test", Size: 100, TTL: 1 * time.Second}) - assert.Empty(t, err) + require.NoError(t, err) //set & get err = SetKey("test", "testkey0", "testvalue1", nil) - assert.Empty(t, err) + require.NoError(t, err) ret, err := GetKey("test", "testkey0") assert.Equal(t, "testvalue1", ret) - assert.Empty(t, err) + require.NoError(t, err) //re-set err = SetKey("test", "testkey0", "testvalue2", nil) - assert.Empty(t, err) + require.NoError(t, err) assert.Equal(t, "testvalue1", ret) - assert.Empty(t, err) + require.NoError(t, err) //expire time.Sleep(1500 * time.Millisecond) ret, err = GetKey("test", "testkey0") assert.Equal(t, "", ret) - assert.Empty(t, err) + require.NoError(t, err) } diff --git a/pkg/csconfig/api.go b/pkg/csconfig/api.go index 3014b729a9e..5f2f8f9248b 100644 --- a/pkg/csconfig/api.go +++ b/pkg/csconfig/api.go @@ -38,10 +38,17 @@ type ApiCredentialsCfg struct { CertPath string `yaml:"cert_path,omitempty"` } -/*global api config (for lapi->oapi)*/ +type CapiPullConfig struct { + Community *bool `yaml:"community,omitempty"` + Blocklists *bool `yaml:"blocklists,omitempty"` +} + +/*global api config (for lapi->capi)*/ type OnlineApiClientCfg struct { CredentialsFilePath string `yaml:"credentials_path,omitempty"` // credz will be edited by software, store in diff file Credentials *ApiCredentialsCfg `yaml:"-"` + PullConfig CapiPullConfig `yaml:"pull,omitempty"` + Sharing *bool `yaml:"sharing,omitempty"` } /*local api config (for crowdsec/cscli->lapi)*/ @@ -344,6 +351,21 @@ func (c *Config) LoadAPIServer(inCli bool) error { log.Printf("push and pull to Central API disabled") } + //Set default values for CAPI push/pull + if c.API.Server.OnlineClient != nil { + if c.API.Server.OnlineClient.PullConfig.Community == nil { + c.API.Server.OnlineClient.PullConfig.Community = ptr.Of(true) + } + + if c.API.Server.OnlineClient.PullConfig.Blocklists == nil { + c.API.Server.OnlineClient.PullConfig.Blocklists = ptr.Of(true) + } + + if c.API.Server.OnlineClient.Sharing == nil { + c.API.Server.OnlineClient.Sharing = ptr.Of(true) + } + } + if err := c.LoadDBConfig(inCli); err != nil { return err } diff --git a/pkg/csconfig/api_test.go b/pkg/csconfig/api_test.go index dff3c3afc8c..17802ba31dd 100644 --- a/pkg/csconfig/api_test.go +++ b/pkg/csconfig/api_test.go @@ -212,6 +212,11 @@ func TestLoadAPIServer(t *testing.T) { Login: "test", Password: "testpassword", }, + Sharing: ptr.Of(true), + PullConfig: CapiPullConfig{ + Community: ptr.Of(true), + Blocklists: ptr.Of(true), + }, }, Profiles: tmpLAPI.Profiles, ProfilesPath: "./testdata/profiles.yaml", diff --git a/pkg/csplugin/broker.go b/pkg/csplugin/broker.go index e996fa9b68c..f53c831e186 100644 --- a/pkg/csplugin/broker.go +++ b/pkg/csplugin/broker.go @@ -91,7 +91,6 @@ func (pb *PluginBroker) Init(ctx context.Context, pluginCfg *csconfig.PluginCfg, pb.watcher = PluginWatcher{} pb.watcher.Init(pb.pluginConfigByName, pb.alertsByPluginName) return nil - } func (pb *PluginBroker) Kill() { @@ -166,6 +165,7 @@ func (pb *PluginBroker) addProfileAlert(profileAlert ProfileAlert) { pb.watcher.Inserts <- pluginName } } + func (pb *PluginBroker) profilesContainPlugin(pluginName string) bool { for _, profileCfg := range pb.profileConfigs { for _, name := range profileCfg.Notifications { @@ -176,6 +176,7 @@ func (pb *PluginBroker) profilesContainPlugin(pluginName string) bool { } return false } + func (pb *PluginBroker) loadConfig(path string) error { files, err := listFilesAtPath(path) if err != nil { @@ -277,7 +278,6 @@ func (pb *PluginBroker) loadPlugins(ctx context.Context, path string) error { } func (pb *PluginBroker) loadNotificationPlugin(name string, binaryPath string) (protobufs.NotifierServer, error) { - handshake, err := getHandshake() if err != nil { return nil, err diff --git a/pkg/cticlient/types.go b/pkg/cticlient/types.go index 2ad0a6eb34e..954d24641b4 100644 --- a/pkg/cticlient/types.go +++ b/pkg/cticlient/types.go @@ -210,7 +210,6 @@ func (c *SmokeItem) GetFalsePositives() []string { } func (c *SmokeItem) IsFalsePositive() bool { - if c.Classifications.FalsePositives != nil { if len(c.Classifications.FalsePositives) > 0 { return true @@ -284,7 +283,6 @@ func (c *FireItem) GetFalsePositives() []string { } func (c *FireItem) IsFalsePositive() bool { - if c.Classifications.FalsePositives != nil { if len(c.Classifications.FalsePositives) > 0 { return true diff --git a/pkg/cwversion/component/component.go b/pkg/cwversion/component/component.go index 4036b63cf00..7ed596525e0 100644 --- a/pkg/cwversion/component/component.go +++ b/pkg/cwversion/component/component.go @@ -7,20 +7,21 @@ package component // Built is a map of all the known components, and whether they are built-in or not. // This is populated as soon as possible by the respective init() functions -var Built = map[string]bool { - "datasource_appsec": false, - "datasource_cloudwatch": false, - "datasource_docker": false, - "datasource_file": false, - "datasource_journalctl": false, - "datasource_k8s-audit": false, - "datasource_kafka": false, - "datasource_kinesis": false, - "datasource_loki": false, - "datasource_s3": false, - "datasource_syslog": false, - "datasource_wineventlog":false, - "cscli_setup": false, +var Built = map[string]bool{ + "datasource_appsec": false, + "datasource_cloudwatch": false, + "datasource_docker": false, + "datasource_file": false, + "datasource_journalctl": false, + "datasource_k8s-audit": false, + "datasource_kafka": false, + "datasource_kinesis": false, + "datasource_loki": false, + "datasource_s3": false, + "datasource_syslog": false, + "datasource_wineventlog": false, + "datasource_http": false, + "cscli_setup": false, } func Register(name string) { diff --git a/pkg/database/allowlists.go b/pkg/database/allowlists.go new file mode 100644 index 00000000000..3dc6ea7306f --- /dev/null +++ b/pkg/database/allowlists.go @@ -0,0 +1,255 @@ +package database + +import ( + "context" + "fmt" + "time" + + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func (c *Client) CreateAllowList(ctx context.Context, name string, description string, fromConsole bool) (*ent.AllowList, error) { + allowlist, err := c.Ent.AllowList.Create(). + SetName(name). + SetFromConsole(fromConsole). + SetDescription(description). + Save(ctx) + if err != nil { + return nil, fmt.Errorf("unable to create allowlist: %w", err) + } + + return allowlist, nil +} + +func (c *Client) DeleteAllowList(ctx context.Context, name string, fromConsole bool) error { + + nbDeleted, err := c.Ent.AllowListItem.Delete().Where(allowlistitem.HasAllowlistWith(allowlist.NameEQ(name), allowlist.FromConsoleEQ(fromConsole))).Exec(ctx) + if err != nil { + return fmt.Errorf("unable to delete allowlist items: %w", err) + } + + c.Log.Debugf("deleted %d items from allowlist %s", nbDeleted, name) + + nbDeleted, err = c.Ent.AllowList. + Delete(). + Where(allowlist.NameEQ(name), allowlist.FromConsoleEQ(fromConsole)). + Exec(ctx) + if err != nil { + return fmt.Errorf("unable to delete allowlist: %w", err) + } + + if nbDeleted == 0 { + return fmt.Errorf("allowlist %s not found", name) + } + + return nil +} + +func (c *Client) ListAllowLists(ctx context.Context, withContent bool) ([]*ent.AllowList, error) { + q := c.Ent.AllowList.Query() + if withContent { + q = q.WithAllowlistItems() + } + result, err := q.All(ctx) + if err != nil { + return nil, fmt.Errorf("unable to list allowlists: %w", err) + } + + return result, nil +} + +func (c *Client) GetAllowList(ctx context.Context, name string, withContent bool) (*ent.AllowList, error) { + q := c.Ent.AllowList.Query().Where(allowlist.NameEQ(name)) + if withContent { + q = q.WithAllowlistItems() + } + result, err := q.First(ctx) + if err != nil { + return nil, err + } + + return result, nil +} + +func (c *Client) AddToAllowlist(ctx context.Context, list *ent.AllowList, items []*models.AllowlistItem) error { + successCount := 0 + + c.Log.Debugf("adding %d values to allowlist %s", len(items), list.Name) + c.Log.Tracef("values: %+v", items) + + for _, item := range items { + //FIXME: wrap this in a transaction + c.Log.Debugf("adding value %s to allowlist %s", item.Value, list.Name) + sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(item.Value) + if err != nil { + c.Log.Errorf("unable to parse value %s: %s", item.Value, err) + continue + } + query := c.Ent.AllowListItem.Create(). + SetValue(item.Value). + SetIPSize(int64(sz)). + SetStartIP(start_ip). + SetStartSuffix(start_sfx). + SetEndIP(end_ip). + SetEndSuffix(end_sfx). + SetComment(item.Description) + + if !time.Time(item.Expiration).IsZero() { + query = query.SetExpiresAt(time.Time(item.Expiration)) + } + + content, err := query.Save(ctx) + if err != nil { + c.Log.Errorf("unable to add value to allowlist: %s", err) + } + + c.Log.Debugf("Updating allowlist %s with value %s (exp: %s)", list.Name, item.Value, item.Expiration) + + //We don't have a clean way to handle name conflict from the console, so use id + err = c.Ent.AllowList.Update().AddAllowlistItems(content).Where(allowlist.IDEQ(list.ID)).Exec(ctx) + if err != nil { + c.Log.Errorf("unable to add value to allowlist: %s", err) + continue + } + successCount++ + } + + c.Log.Infof("added %d values to allowlist %s", successCount, list.Name) + + return nil +} + +func (c *Client) RemoveFromAllowlist(ctx context.Context, list *ent.AllowList, values ...string) (int, error) { + c.Log.Debugf("removing %d values from allowlist %s", len(values), list.Name) + c.Log.Tracef("values: %v", values) + + nbDeleted, err := c.Ent.AllowListItem.Delete().Where( + allowlistitem.HasAllowlistWith(allowlist.IDEQ(list.ID)), + allowlistitem.ValueIn(values...), + ).Exec(ctx) + + if err != nil { + return 0, fmt.Errorf("unable to remove values from allowlist: %w", err) + } + + return nbDeleted, nil +} + +func (c *Client) ReplaceAllowlist(ctx context.Context, list *ent.AllowList, items []*models.AllowlistItem, fromConsole bool) error { + c.Log.Debugf("replacing values in allowlist %s", list.Name) + c.Log.Tracef("items: %+v", items) + + _, err := c.Ent.AllowListItem.Delete().Where(allowlistitem.HasAllowlistWith(allowlist.IDEQ(list.ID))).Exec(ctx) + + if err != nil { + return fmt.Errorf("unable to delete allowlist contents: %w", err) + } + + err = c.AddToAllowlist(ctx, list, items) + + if err != nil { + return fmt.Errorf("unable to add values to allowlist: %w", err) + } + + if !list.FromConsole && fromConsole { + c.Log.Infof("marking allowlist %s as managed from console", list.Name) + err = c.Ent.AllowList.Update().SetFromConsole(fromConsole).Where(allowlist.IDEQ(list.ID)).Exec(ctx) + if err != nil { + return fmt.Errorf("unable to update allowlist: %w", err) + } + } + + return nil +} + +func (c *Client) IsAllowlisted(ctx context.Context, value string) (bool, error) { + /* + Few cases: + - value is an IP/range directly is in allowlist + - value is an IP/range in a range in allowlist + - value is a range and an IP/range belonging to it is in allowlist + */ + + sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(value) + if err != nil { + return false, fmt.Errorf("unable to parse value %s: %w", value, err) + } + + c.Log.Debugf("checking if %s is allowlisted", value) + + now := time.Now().UTC() + query := c.Ent.AllowListItem.Query().Where( + allowlistitem.Or( + allowlistitem.ExpiresAtGTE(now), + allowlistitem.ExpiresAtIsNil(), + ), + allowlistitem.IPSizeEQ(int64(sz)), + ) + + if sz == 4 { + query = query.Where( + allowlistitem.Or( + // Value contained inside a range or exact match + allowlistitem.And( + allowlistitem.StartIPLTE(start_ip), + allowlistitem.EndIPGTE(end_ip), + ), + // Value contains another allowlisted value + allowlistitem.And( + allowlistitem.StartIPGTE(start_ip), + allowlistitem.EndIPLTE(end_ip), + ), + )) + } + + if sz == 16 { + query = query.Where( + // Value contained inside a range or exact match + allowlistitem.Or( + allowlistitem.And( + allowlistitem.Or( + allowlistitem.StartIPLT(start_ip), + allowlistitem.And( + allowlistitem.StartIPEQ(start_ip), + allowlistitem.StartSuffixLTE(start_sfx), + )), + allowlistitem.Or( + allowlistitem.EndIPGT(end_ip), + allowlistitem.And( + allowlistitem.EndIPEQ(end_ip), + allowlistitem.EndSuffixGTE(end_sfx), + ), + ), + ), + // Value contains another allowlisted value + allowlistitem.And( + allowlistitem.Or( + allowlistitem.StartIPGT(start_ip), + allowlistitem.And( + allowlistitem.StartIPEQ(start_ip), + allowlistitem.StartSuffixGTE(start_sfx), + )), + allowlistitem.Or( + allowlistitem.EndIPLT(end_ip), + allowlistitem.And( + allowlistitem.EndIPEQ(end_ip), + allowlistitem.EndSuffixLTE(end_sfx), + ), + ), + ), + ), + ) + } + + allowed, err := query.Exist(ctx) + + if err != nil { + return false, fmt.Errorf("unable to check if value is allowlisted: %w", err) + } + + return allowed, nil +} diff --git a/pkg/database/allowlists_test.go b/pkg/database/allowlists_test.go new file mode 100644 index 00000000000..10f1b3aa9ae --- /dev/null +++ b/pkg/database/allowlists_test.go @@ -0,0 +1,99 @@ +package database + +import ( + "context" + "os" + "testing" + "time" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/go-openapi/strfmt" + "github.com/stretchr/testify/require" +) + +func getDBClient(t *testing.T, ctx context.Context) *Client { + t.Helper() + + dbPath, err := os.CreateTemp("", "*sqlite") + require.NoError(t, err) + dbClient, err := NewClient(ctx, &csconfig.DatabaseCfg{ + Type: "sqlite", + DbName: "crowdsec", + DbPath: dbPath.Name(), + }) + require.NoError(t, err) + + return dbClient +} + +func TestCheckAllowlist(t *testing.T) { + + ctx := context.Background() + dbClient := getDBClient(t, ctx) + + allowlist, err := dbClient.CreateAllowList(ctx, "test", "test", false) + + require.NoError(t, err) + + err = dbClient.AddToAllowlist(ctx, allowlist, []*models.AllowlistItem{ + { + CreatedAt: strfmt.DateTime(time.Now()), + Value: "1.2.3.4", + }, + { + CreatedAt: strfmt.DateTime(time.Now()), + Value: "8.0.0.0/8", + }, + { + CreatedAt: strfmt.DateTime(time.Now()), + Value: "2001:db8::/32", + }, + { + CreatedAt: strfmt.DateTime(time.Now()), + Value: "2.3.4.5", + Expiration: strfmt.DateTime(time.Now().Add(-time.Hour)), // expired item + }, + { + CreatedAt: strfmt.DateTime(time.Now()), + Value: "8a95:c186:9f96:4c75:0dad:49c6:ff62:94b8", + }, + }) + + require.NoError(t, err) + + // Exatch match + allowlisted, err := dbClient.IsAllowlisted(ctx, "1.2.3.4") + require.NoError(t, err) + require.True(t, allowlisted) + + // CIDR match + allowlisted, err = dbClient.IsAllowlisted(ctx, "8.8.8.8") + require.NoError(t, err) + require.True(t, allowlisted) + + // IPv6 match + allowlisted, err = dbClient.IsAllowlisted(ctx, "2001:db8::1") + require.NoError(t, err) + require.True(t, allowlisted) + + // Expired item + allowlisted, err = dbClient.IsAllowlisted(ctx, "2.3.4.5") + require.NoError(t, err) + require.False(t, allowlisted) + + // Decision on a range that contains an allowlisted value + allowlisted, err = dbClient.IsAllowlisted(ctx, "1.2.3.0/24") + require.NoError(t, err) + require.True(t, allowlisted) + + // No match + allowlisted, err = dbClient.IsAllowlisted(ctx, "42.42.42.42") + require.NoError(t, err) + require.False(t, allowlisted) + + // IPv6 range that contains an allowlisted value + allowlisted, err = dbClient.IsAllowlisted(ctx, "8a95:c186:9f96:4c75::/64") + require.NoError(t, err) + require.True(t, allowlisted) +} diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go index 04ef830ae72..f9e62bc6522 100644 --- a/pkg/database/bouncers.go +++ b/pkg/database/bouncers.go @@ -41,8 +41,19 @@ func (c *Client) BouncerUpdateBaseMetrics(ctx context.Context, bouncerName strin return nil } -func (c *Client) SelectBouncer(ctx context.Context, apiKeyHash string) (*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(ctx) +func (c *Client) SelectBouncers(ctx context.Context, apiKeyHash string, authType string) ([]*ent.Bouncer, error) { + //Order by ID so manually created bouncer will be first in the list to use as the base name + //when automatically creating a new entry if API keys are shared + result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash), bouncer.AuthTypeEQ(authType)).Order(ent.Asc(bouncer.FieldID)).All(ctx) + if err != nil { + return nil, err + } + + return result, nil +} + +func (c *Client) SelectBouncerWithIP(ctx context.Context, apiKeyHash string, clientIP string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash), bouncer.IPAddressEQ(clientIP)).First(ctx) if err != nil { return nil, err } @@ -68,13 +79,15 @@ func (c *Client) ListBouncers(ctx context.Context) ([]*ent.Bouncer, error) { return result, nil } -func (c *Client) CreateBouncer(ctx context.Context, name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) { +func (c *Client) CreateBouncer(ctx context.Context, name string, ipAddr string, apiKey string, authType string, autoCreated bool) (*ent.Bouncer, error) { bouncer, err := c.Ent.Bouncer. Create(). SetName(name). SetAPIKey(apiKey). SetRevoked(false). SetAuthType(authType). + SetIPAddress(ipAddr). + SetAutoCreated(autoCreated). Save(ctx) if err != nil { if ent.IsConstraintError(err) { diff --git a/pkg/database/ent/allowlist.go b/pkg/database/ent/allowlist.go new file mode 100644 index 00000000000..99b36687a7b --- /dev/null +++ b/pkg/database/ent/allowlist.go @@ -0,0 +1,189 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist" +) + +// AllowList is the model entity for the AllowList schema. +type AllowList struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // FromConsole holds the value of the "from_console" field. + FromConsole bool `json:"from_console,omitempty"` + // Description holds the value of the "description" field. + Description string `json:"description,omitempty"` + // AllowlistID holds the value of the "allowlist_id" field. + AllowlistID string `json:"allowlist_id,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the AllowListQuery when eager-loading is set. + Edges AllowListEdges `json:"edges"` + selectValues sql.SelectValues +} + +// AllowListEdges holds the relations/edges for other nodes in the graph. +type AllowListEdges struct { + // AllowlistItems holds the value of the allowlist_items edge. + AllowlistItems []*AllowListItem `json:"allowlist_items,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// AllowlistItemsOrErr returns the AllowlistItems value or an error if the edge +// was not loaded in eager-loading. +func (e AllowListEdges) AllowlistItemsOrErr() ([]*AllowListItem, error) { + if e.loadedTypes[0] { + return e.AllowlistItems, nil + } + return nil, &NotLoadedError{edge: "allowlist_items"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*AllowList) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case allowlist.FieldFromConsole: + values[i] = new(sql.NullBool) + case allowlist.FieldID: + values[i] = new(sql.NullInt64) + case allowlist.FieldName, allowlist.FieldDescription, allowlist.FieldAllowlistID: + values[i] = new(sql.NullString) + case allowlist.FieldCreatedAt, allowlist.FieldUpdatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the AllowList fields. +func (al *AllowList) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case allowlist.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + al.ID = int(value.Int64) + case allowlist.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + al.CreatedAt = value.Time + } + case allowlist.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + al.UpdatedAt = value.Time + } + case allowlist.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + al.Name = value.String + } + case allowlist.FieldFromConsole: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field from_console", values[i]) + } else if value.Valid { + al.FromConsole = value.Bool + } + case allowlist.FieldDescription: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field description", values[i]) + } else if value.Valid { + al.Description = value.String + } + case allowlist.FieldAllowlistID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field allowlist_id", values[i]) + } else if value.Valid { + al.AllowlistID = value.String + } + default: + al.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the AllowList. +// This includes values selected through modifiers, order, etc. +func (al *AllowList) Value(name string) (ent.Value, error) { + return al.selectValues.Get(name) +} + +// QueryAllowlistItems queries the "allowlist_items" edge of the AllowList entity. +func (al *AllowList) QueryAllowlistItems() *AllowListItemQuery { + return NewAllowListClient(al.config).QueryAllowlistItems(al) +} + +// Update returns a builder for updating this AllowList. +// Note that you need to call AllowList.Unwrap() before calling this method if this AllowList +// was returned from a transaction, and the transaction was committed or rolled back. +func (al *AllowList) Update() *AllowListUpdateOne { + return NewAllowListClient(al.config).UpdateOne(al) +} + +// Unwrap unwraps the AllowList entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (al *AllowList) Unwrap() *AllowList { + _tx, ok := al.config.driver.(*txDriver) + if !ok { + panic("ent: AllowList is not a transactional entity") + } + al.config.driver = _tx.drv + return al +} + +// String implements the fmt.Stringer. +func (al *AllowList) String() string { + var builder strings.Builder + builder.WriteString("AllowList(") + builder.WriteString(fmt.Sprintf("id=%v, ", al.ID)) + builder.WriteString("created_at=") + builder.WriteString(al.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(al.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(al.Name) + builder.WriteString(", ") + builder.WriteString("from_console=") + builder.WriteString(fmt.Sprintf("%v", al.FromConsole)) + builder.WriteString(", ") + builder.WriteString("description=") + builder.WriteString(al.Description) + builder.WriteString(", ") + builder.WriteString("allowlist_id=") + builder.WriteString(al.AllowlistID) + builder.WriteByte(')') + return builder.String() +} + +// AllowLists is a parsable slice of AllowList. +type AllowLists []*AllowList diff --git a/pkg/database/ent/allowlist/allowlist.go b/pkg/database/ent/allowlist/allowlist.go new file mode 100644 index 00000000000..36cac5c1b21 --- /dev/null +++ b/pkg/database/ent/allowlist/allowlist.go @@ -0,0 +1,133 @@ +// Code generated by ent, DO NOT EDIT. + +package allowlist + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the allowlist type in the database. + Label = "allow_list" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldFromConsole holds the string denoting the from_console field in the database. + FieldFromConsole = "from_console" + // FieldDescription holds the string denoting the description field in the database. + FieldDescription = "description" + // FieldAllowlistID holds the string denoting the allowlist_id field in the database. + FieldAllowlistID = "allowlist_id" + // EdgeAllowlistItems holds the string denoting the allowlist_items edge name in mutations. + EdgeAllowlistItems = "allowlist_items" + // Table holds the table name of the allowlist in the database. + Table = "allow_lists" + // AllowlistItemsTable is the table that holds the allowlist_items relation/edge. The primary key declared below. + AllowlistItemsTable = "allow_list_allowlist_items" + // AllowlistItemsInverseTable is the table name for the AllowListItem entity. + // It exists in this package in order to avoid circular dependency with the "allowlistitem" package. + AllowlistItemsInverseTable = "allow_list_items" +) + +// Columns holds all SQL columns for allowlist fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldName, + FieldFromConsole, + FieldDescription, + FieldAllowlistID, +} + +var ( + // AllowlistItemsPrimaryKey and AllowlistItemsColumn2 are the table columns denoting the + // primary key for the allowlist_items relation (M2M). + AllowlistItemsPrimaryKey = []string{"allow_list_id", "allow_list_item_id"} +) + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time +) + +// OrderOption defines the ordering options for the AllowList queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByFromConsole orders the results by the from_console field. +func ByFromConsole(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFromConsole, opts...).ToFunc() +} + +// ByDescription orders the results by the description field. +func ByDescription(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDescription, opts...).ToFunc() +} + +// ByAllowlistID orders the results by the allowlist_id field. +func ByAllowlistID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAllowlistID, opts...).ToFunc() +} + +// ByAllowlistItemsCount orders the results by allowlist_items count. +func ByAllowlistItemsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAllowlistItemsStep(), opts...) + } +} + +// ByAllowlistItems orders the results by allowlist_items terms. +func ByAllowlistItems(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAllowlistItemsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newAllowlistItemsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AllowlistItemsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, AllowlistItemsTable, AllowlistItemsPrimaryKey...), + ) +} diff --git a/pkg/database/ent/allowlist/where.go b/pkg/database/ent/allowlist/where.go new file mode 100644 index 00000000000..d8b43be2cf9 --- /dev/null +++ b/pkg/database/ent/allowlist/where.go @@ -0,0 +1,429 @@ +// Code generated by ent, DO NOT EDIT. + +package allowlist + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.AllowList { + return predicate.AllowList(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.AllowList { + return predicate.AllowList(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.AllowList { + return predicate.AllowList(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.AllowList { + return predicate.AllowList(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.AllowList { + return predicate.AllowList(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.AllowList { + return predicate.AllowList(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.AllowList { + return predicate.AllowList(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.AllowList { + return predicate.AllowList(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.AllowList { + return predicate.AllowList(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldEQ(FieldName, v)) +} + +// FromConsole applies equality check predicate on the "from_console" field. It's identical to FromConsoleEQ. +func FromConsole(v bool) predicate.AllowList { + return predicate.AllowList(sql.FieldEQ(FieldFromConsole, v)) +} + +// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ. +func Description(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldEQ(FieldDescription, v)) +} + +// AllowlistID applies equality check predicate on the "allowlist_id" field. It's identical to AllowlistIDEQ. +func AllowlistID(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldEQ(FieldAllowlistID, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.AllowList { + return predicate.AllowList(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.AllowList { + return predicate.AllowList(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.AllowList { + return predicate.AllowList(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldContainsFold(FieldName, v)) +} + +// FromConsoleEQ applies the EQ predicate on the "from_console" field. +func FromConsoleEQ(v bool) predicate.AllowList { + return predicate.AllowList(sql.FieldEQ(FieldFromConsole, v)) +} + +// FromConsoleNEQ applies the NEQ predicate on the "from_console" field. +func FromConsoleNEQ(v bool) predicate.AllowList { + return predicate.AllowList(sql.FieldNEQ(FieldFromConsole, v)) +} + +// DescriptionEQ applies the EQ predicate on the "description" field. +func DescriptionEQ(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldEQ(FieldDescription, v)) +} + +// DescriptionNEQ applies the NEQ predicate on the "description" field. +func DescriptionNEQ(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldNEQ(FieldDescription, v)) +} + +// DescriptionIn applies the In predicate on the "description" field. +func DescriptionIn(vs ...string) predicate.AllowList { + return predicate.AllowList(sql.FieldIn(FieldDescription, vs...)) +} + +// DescriptionNotIn applies the NotIn predicate on the "description" field. +func DescriptionNotIn(vs ...string) predicate.AllowList { + return predicate.AllowList(sql.FieldNotIn(FieldDescription, vs...)) +} + +// DescriptionGT applies the GT predicate on the "description" field. +func DescriptionGT(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldGT(FieldDescription, v)) +} + +// DescriptionGTE applies the GTE predicate on the "description" field. +func DescriptionGTE(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldGTE(FieldDescription, v)) +} + +// DescriptionLT applies the LT predicate on the "description" field. +func DescriptionLT(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldLT(FieldDescription, v)) +} + +// DescriptionLTE applies the LTE predicate on the "description" field. +func DescriptionLTE(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldLTE(FieldDescription, v)) +} + +// DescriptionContains applies the Contains predicate on the "description" field. +func DescriptionContains(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldContains(FieldDescription, v)) +} + +// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field. +func DescriptionHasPrefix(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldHasPrefix(FieldDescription, v)) +} + +// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field. +func DescriptionHasSuffix(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldHasSuffix(FieldDescription, v)) +} + +// DescriptionIsNil applies the IsNil predicate on the "description" field. +func DescriptionIsNil() predicate.AllowList { + return predicate.AllowList(sql.FieldIsNull(FieldDescription)) +} + +// DescriptionNotNil applies the NotNil predicate on the "description" field. +func DescriptionNotNil() predicate.AllowList { + return predicate.AllowList(sql.FieldNotNull(FieldDescription)) +} + +// DescriptionEqualFold applies the EqualFold predicate on the "description" field. +func DescriptionEqualFold(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldEqualFold(FieldDescription, v)) +} + +// DescriptionContainsFold applies the ContainsFold predicate on the "description" field. +func DescriptionContainsFold(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldContainsFold(FieldDescription, v)) +} + +// AllowlistIDEQ applies the EQ predicate on the "allowlist_id" field. +func AllowlistIDEQ(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldEQ(FieldAllowlistID, v)) +} + +// AllowlistIDNEQ applies the NEQ predicate on the "allowlist_id" field. +func AllowlistIDNEQ(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldNEQ(FieldAllowlistID, v)) +} + +// AllowlistIDIn applies the In predicate on the "allowlist_id" field. +func AllowlistIDIn(vs ...string) predicate.AllowList { + return predicate.AllowList(sql.FieldIn(FieldAllowlistID, vs...)) +} + +// AllowlistIDNotIn applies the NotIn predicate on the "allowlist_id" field. +func AllowlistIDNotIn(vs ...string) predicate.AllowList { + return predicate.AllowList(sql.FieldNotIn(FieldAllowlistID, vs...)) +} + +// AllowlistIDGT applies the GT predicate on the "allowlist_id" field. +func AllowlistIDGT(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldGT(FieldAllowlistID, v)) +} + +// AllowlistIDGTE applies the GTE predicate on the "allowlist_id" field. +func AllowlistIDGTE(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldGTE(FieldAllowlistID, v)) +} + +// AllowlistIDLT applies the LT predicate on the "allowlist_id" field. +func AllowlistIDLT(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldLT(FieldAllowlistID, v)) +} + +// AllowlistIDLTE applies the LTE predicate on the "allowlist_id" field. +func AllowlistIDLTE(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldLTE(FieldAllowlistID, v)) +} + +// AllowlistIDContains applies the Contains predicate on the "allowlist_id" field. +func AllowlistIDContains(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldContains(FieldAllowlistID, v)) +} + +// AllowlistIDHasPrefix applies the HasPrefix predicate on the "allowlist_id" field. +func AllowlistIDHasPrefix(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldHasPrefix(FieldAllowlistID, v)) +} + +// AllowlistIDHasSuffix applies the HasSuffix predicate on the "allowlist_id" field. +func AllowlistIDHasSuffix(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldHasSuffix(FieldAllowlistID, v)) +} + +// AllowlistIDIsNil applies the IsNil predicate on the "allowlist_id" field. +func AllowlistIDIsNil() predicate.AllowList { + return predicate.AllowList(sql.FieldIsNull(FieldAllowlistID)) +} + +// AllowlistIDNotNil applies the NotNil predicate on the "allowlist_id" field. +func AllowlistIDNotNil() predicate.AllowList { + return predicate.AllowList(sql.FieldNotNull(FieldAllowlistID)) +} + +// AllowlistIDEqualFold applies the EqualFold predicate on the "allowlist_id" field. +func AllowlistIDEqualFold(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldEqualFold(FieldAllowlistID, v)) +} + +// AllowlistIDContainsFold applies the ContainsFold predicate on the "allowlist_id" field. +func AllowlistIDContainsFold(v string) predicate.AllowList { + return predicate.AllowList(sql.FieldContainsFold(FieldAllowlistID, v)) +} + +// HasAllowlistItems applies the HasEdge predicate on the "allowlist_items" edge. +func HasAllowlistItems() predicate.AllowList { + return predicate.AllowList(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, AllowlistItemsTable, AllowlistItemsPrimaryKey...), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAllowlistItemsWith applies the HasEdge predicate on the "allowlist_items" edge with a given conditions (other predicates). +func HasAllowlistItemsWith(preds ...predicate.AllowListItem) predicate.AllowList { + return predicate.AllowList(func(s *sql.Selector) { + step := newAllowlistItemsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.AllowList) predicate.AllowList { + return predicate.AllowList(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.AllowList) predicate.AllowList { + return predicate.AllowList(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.AllowList) predicate.AllowList { + return predicate.AllowList(sql.NotPredicates(p)) +} diff --git a/pkg/database/ent/allowlist_create.go b/pkg/database/ent/allowlist_create.go new file mode 100644 index 00000000000..ec9d29b6ae5 --- /dev/null +++ b/pkg/database/ent/allowlist_create.go @@ -0,0 +1,321 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem" +) + +// AllowListCreate is the builder for creating a AllowList entity. +type AllowListCreate struct { + config + mutation *AllowListMutation + hooks []Hook +} + +// SetCreatedAt sets the "created_at" field. +func (alc *AllowListCreate) SetCreatedAt(t time.Time) *AllowListCreate { + alc.mutation.SetCreatedAt(t) + return alc +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (alc *AllowListCreate) SetNillableCreatedAt(t *time.Time) *AllowListCreate { + if t != nil { + alc.SetCreatedAt(*t) + } + return alc +} + +// SetUpdatedAt sets the "updated_at" field. +func (alc *AllowListCreate) SetUpdatedAt(t time.Time) *AllowListCreate { + alc.mutation.SetUpdatedAt(t) + return alc +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (alc *AllowListCreate) SetNillableUpdatedAt(t *time.Time) *AllowListCreate { + if t != nil { + alc.SetUpdatedAt(*t) + } + return alc +} + +// SetName sets the "name" field. +func (alc *AllowListCreate) SetName(s string) *AllowListCreate { + alc.mutation.SetName(s) + return alc +} + +// SetFromConsole sets the "from_console" field. +func (alc *AllowListCreate) SetFromConsole(b bool) *AllowListCreate { + alc.mutation.SetFromConsole(b) + return alc +} + +// SetDescription sets the "description" field. +func (alc *AllowListCreate) SetDescription(s string) *AllowListCreate { + alc.mutation.SetDescription(s) + return alc +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (alc *AllowListCreate) SetNillableDescription(s *string) *AllowListCreate { + if s != nil { + alc.SetDescription(*s) + } + return alc +} + +// SetAllowlistID sets the "allowlist_id" field. +func (alc *AllowListCreate) SetAllowlistID(s string) *AllowListCreate { + alc.mutation.SetAllowlistID(s) + return alc +} + +// SetNillableAllowlistID sets the "allowlist_id" field if the given value is not nil. +func (alc *AllowListCreate) SetNillableAllowlistID(s *string) *AllowListCreate { + if s != nil { + alc.SetAllowlistID(*s) + } + return alc +} + +// AddAllowlistItemIDs adds the "allowlist_items" edge to the AllowListItem entity by IDs. +func (alc *AllowListCreate) AddAllowlistItemIDs(ids ...int) *AllowListCreate { + alc.mutation.AddAllowlistItemIDs(ids...) + return alc +} + +// AddAllowlistItems adds the "allowlist_items" edges to the AllowListItem entity. +func (alc *AllowListCreate) AddAllowlistItems(a ...*AllowListItem) *AllowListCreate { + ids := make([]int, len(a)) + for i := range a { + ids[i] = a[i].ID + } + return alc.AddAllowlistItemIDs(ids...) +} + +// Mutation returns the AllowListMutation object of the builder. +func (alc *AllowListCreate) Mutation() *AllowListMutation { + return alc.mutation +} + +// Save creates the AllowList in the database. +func (alc *AllowListCreate) Save(ctx context.Context) (*AllowList, error) { + alc.defaults() + return withHooks(ctx, alc.sqlSave, alc.mutation, alc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (alc *AllowListCreate) SaveX(ctx context.Context) *AllowList { + v, err := alc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (alc *AllowListCreate) Exec(ctx context.Context) error { + _, err := alc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (alc *AllowListCreate) ExecX(ctx context.Context) { + if err := alc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (alc *AllowListCreate) defaults() { + if _, ok := alc.mutation.CreatedAt(); !ok { + v := allowlist.DefaultCreatedAt() + alc.mutation.SetCreatedAt(v) + } + if _, ok := alc.mutation.UpdatedAt(); !ok { + v := allowlist.DefaultUpdatedAt() + alc.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (alc *AllowListCreate) check() error { + if _, ok := alc.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AllowList.created_at"`)} + } + if _, ok := alc.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AllowList.updated_at"`)} + } + if _, ok := alc.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "AllowList.name"`)} + } + if _, ok := alc.mutation.FromConsole(); !ok { + return &ValidationError{Name: "from_console", err: errors.New(`ent: missing required field "AllowList.from_console"`)} + } + return nil +} + +func (alc *AllowListCreate) sqlSave(ctx context.Context) (*AllowList, error) { + if err := alc.check(); err != nil { + return nil, err + } + _node, _spec := alc.createSpec() + if err := sqlgraph.CreateNode(ctx, alc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + alc.mutation.id = &_node.ID + alc.mutation.done = true + return _node, nil +} + +func (alc *AllowListCreate) createSpec() (*AllowList, *sqlgraph.CreateSpec) { + var ( + _node = &AllowList{config: alc.config} + _spec = sqlgraph.NewCreateSpec(allowlist.Table, sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt)) + ) + if value, ok := alc.mutation.CreatedAt(); ok { + _spec.SetField(allowlist.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := alc.mutation.UpdatedAt(); ok { + _spec.SetField(allowlist.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := alc.mutation.Name(); ok { + _spec.SetField(allowlist.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := alc.mutation.FromConsole(); ok { + _spec.SetField(allowlist.FieldFromConsole, field.TypeBool, value) + _node.FromConsole = value + } + if value, ok := alc.mutation.Description(); ok { + _spec.SetField(allowlist.FieldDescription, field.TypeString, value) + _node.Description = value + } + if value, ok := alc.mutation.AllowlistID(); ok { + _spec.SetField(allowlist.FieldAllowlistID, field.TypeString, value) + _node.AllowlistID = value + } + if nodes := alc.mutation.AllowlistItemsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: allowlist.AllowlistItemsTable, + Columns: allowlist.AllowlistItemsPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// AllowListCreateBulk is the builder for creating many AllowList entities in bulk. +type AllowListCreateBulk struct { + config + err error + builders []*AllowListCreate +} + +// Save creates the AllowList entities in the database. +func (alcb *AllowListCreateBulk) Save(ctx context.Context) ([]*AllowList, error) { + if alcb.err != nil { + return nil, alcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(alcb.builders)) + nodes := make([]*AllowList, len(alcb.builders)) + mutators := make([]Mutator, len(alcb.builders)) + for i := range alcb.builders { + func(i int, root context.Context) { + builder := alcb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*AllowListMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, alcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, alcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, alcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (alcb *AllowListCreateBulk) SaveX(ctx context.Context) []*AllowList { + v, err := alcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (alcb *AllowListCreateBulk) Exec(ctx context.Context) error { + _, err := alcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (alcb *AllowListCreateBulk) ExecX(ctx context.Context) { + if err := alcb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/pkg/database/ent/allowlist_delete.go b/pkg/database/ent/allowlist_delete.go new file mode 100644 index 00000000000..dcfaa214f6f --- /dev/null +++ b/pkg/database/ent/allowlist_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// AllowListDelete is the builder for deleting a AllowList entity. +type AllowListDelete struct { + config + hooks []Hook + mutation *AllowListMutation +} + +// Where appends a list predicates to the AllowListDelete builder. +func (ald *AllowListDelete) Where(ps ...predicate.AllowList) *AllowListDelete { + ald.mutation.Where(ps...) + return ald +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (ald *AllowListDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, ald.sqlExec, ald.mutation, ald.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (ald *AllowListDelete) ExecX(ctx context.Context) int { + n, err := ald.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (ald *AllowListDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(allowlist.Table, sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt)) + if ps := ald.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, ald.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + ald.mutation.done = true + return affected, err +} + +// AllowListDeleteOne is the builder for deleting a single AllowList entity. +type AllowListDeleteOne struct { + ald *AllowListDelete +} + +// Where appends a list predicates to the AllowListDelete builder. +func (aldo *AllowListDeleteOne) Where(ps ...predicate.AllowList) *AllowListDeleteOne { + aldo.ald.mutation.Where(ps...) + return aldo +} + +// Exec executes the deletion query. +func (aldo *AllowListDeleteOne) Exec(ctx context.Context) error { + n, err := aldo.ald.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{allowlist.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (aldo *AllowListDeleteOne) ExecX(ctx context.Context) { + if err := aldo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/pkg/database/ent/allowlist_query.go b/pkg/database/ent/allowlist_query.go new file mode 100644 index 00000000000..511c3c051f7 --- /dev/null +++ b/pkg/database/ent/allowlist_query.go @@ -0,0 +1,636 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// AllowListQuery is the builder for querying AllowList entities. +type AllowListQuery struct { + config + ctx *QueryContext + order []allowlist.OrderOption + inters []Interceptor + predicates []predicate.AllowList + withAllowlistItems *AllowListItemQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the AllowListQuery builder. +func (alq *AllowListQuery) Where(ps ...predicate.AllowList) *AllowListQuery { + alq.predicates = append(alq.predicates, ps...) + return alq +} + +// Limit the number of records to be returned by this query. +func (alq *AllowListQuery) Limit(limit int) *AllowListQuery { + alq.ctx.Limit = &limit + return alq +} + +// Offset to start from. +func (alq *AllowListQuery) Offset(offset int) *AllowListQuery { + alq.ctx.Offset = &offset + return alq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (alq *AllowListQuery) Unique(unique bool) *AllowListQuery { + alq.ctx.Unique = &unique + return alq +} + +// Order specifies how the records should be ordered. +func (alq *AllowListQuery) Order(o ...allowlist.OrderOption) *AllowListQuery { + alq.order = append(alq.order, o...) + return alq +} + +// QueryAllowlistItems chains the current query on the "allowlist_items" edge. +func (alq *AllowListQuery) QueryAllowlistItems() *AllowListItemQuery { + query := (&AllowListItemClient{config: alq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := alq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := alq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(allowlist.Table, allowlist.FieldID, selector), + sqlgraph.To(allowlistitem.Table, allowlistitem.FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, allowlist.AllowlistItemsTable, allowlist.AllowlistItemsPrimaryKey...), + ) + fromU = sqlgraph.SetNeighbors(alq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first AllowList entity from the query. +// Returns a *NotFoundError when no AllowList was found. +func (alq *AllowListQuery) First(ctx context.Context) (*AllowList, error) { + nodes, err := alq.Limit(1).All(setContextOp(ctx, alq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{allowlist.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (alq *AllowListQuery) FirstX(ctx context.Context) *AllowList { + node, err := alq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first AllowList ID from the query. +// Returns a *NotFoundError when no AllowList ID was found. +func (alq *AllowListQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = alq.Limit(1).IDs(setContextOp(ctx, alq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{allowlist.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (alq *AllowListQuery) FirstIDX(ctx context.Context) int { + id, err := alq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single AllowList entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one AllowList entity is found. +// Returns a *NotFoundError when no AllowList entities are found. +func (alq *AllowListQuery) Only(ctx context.Context) (*AllowList, error) { + nodes, err := alq.Limit(2).All(setContextOp(ctx, alq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{allowlist.Label} + default: + return nil, &NotSingularError{allowlist.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (alq *AllowListQuery) OnlyX(ctx context.Context) *AllowList { + node, err := alq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only AllowList ID in the query. +// Returns a *NotSingularError when more than one AllowList ID is found. +// Returns a *NotFoundError when no entities are found. +func (alq *AllowListQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = alq.Limit(2).IDs(setContextOp(ctx, alq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{allowlist.Label} + default: + err = &NotSingularError{allowlist.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (alq *AllowListQuery) OnlyIDX(ctx context.Context) int { + id, err := alq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of AllowLists. +func (alq *AllowListQuery) All(ctx context.Context) ([]*AllowList, error) { + ctx = setContextOp(ctx, alq.ctx, "All") + if err := alq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*AllowList, *AllowListQuery]() + return withInterceptors[[]*AllowList](ctx, alq, qr, alq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (alq *AllowListQuery) AllX(ctx context.Context) []*AllowList { + nodes, err := alq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of AllowList IDs. +func (alq *AllowListQuery) IDs(ctx context.Context) (ids []int, err error) { + if alq.ctx.Unique == nil && alq.path != nil { + alq.Unique(true) + } + ctx = setContextOp(ctx, alq.ctx, "IDs") + if err = alq.Select(allowlist.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (alq *AllowListQuery) IDsX(ctx context.Context) []int { + ids, err := alq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (alq *AllowListQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, alq.ctx, "Count") + if err := alq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, alq, querierCount[*AllowListQuery](), alq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (alq *AllowListQuery) CountX(ctx context.Context) int { + count, err := alq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (alq *AllowListQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, alq.ctx, "Exist") + switch _, err := alq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (alq *AllowListQuery) ExistX(ctx context.Context) bool { + exist, err := alq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the AllowListQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (alq *AllowListQuery) Clone() *AllowListQuery { + if alq == nil { + return nil + } + return &AllowListQuery{ + config: alq.config, + ctx: alq.ctx.Clone(), + order: append([]allowlist.OrderOption{}, alq.order...), + inters: append([]Interceptor{}, alq.inters...), + predicates: append([]predicate.AllowList{}, alq.predicates...), + withAllowlistItems: alq.withAllowlistItems.Clone(), + // clone intermediate query. + sql: alq.sql.Clone(), + path: alq.path, + } +} + +// WithAllowlistItems tells the query-builder to eager-load the nodes that are connected to +// the "allowlist_items" edge. The optional arguments are used to configure the query builder of the edge. +func (alq *AllowListQuery) WithAllowlistItems(opts ...func(*AllowListItemQuery)) *AllowListQuery { + query := (&AllowListItemClient{config: alq.config}).Query() + for _, opt := range opts { + opt(query) + } + alq.withAllowlistItems = query + return alq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.AllowList.Query(). +// GroupBy(allowlist.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (alq *AllowListQuery) GroupBy(field string, fields ...string) *AllowListGroupBy { + alq.ctx.Fields = append([]string{field}, fields...) + grbuild := &AllowListGroupBy{build: alq} + grbuild.flds = &alq.ctx.Fields + grbuild.label = allowlist.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.AllowList.Query(). +// Select(allowlist.FieldCreatedAt). +// Scan(ctx, &v) +func (alq *AllowListQuery) Select(fields ...string) *AllowListSelect { + alq.ctx.Fields = append(alq.ctx.Fields, fields...) + sbuild := &AllowListSelect{AllowListQuery: alq} + sbuild.label = allowlist.Label + sbuild.flds, sbuild.scan = &alq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AllowListSelect configured with the given aggregations. +func (alq *AllowListQuery) Aggregate(fns ...AggregateFunc) *AllowListSelect { + return alq.Select().Aggregate(fns...) +} + +func (alq *AllowListQuery) prepareQuery(ctx context.Context) error { + for _, inter := range alq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, alq); err != nil { + return err + } + } + } + for _, f := range alq.ctx.Fields { + if !allowlist.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if alq.path != nil { + prev, err := alq.path(ctx) + if err != nil { + return err + } + alq.sql = prev + } + return nil +} + +func (alq *AllowListQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AllowList, error) { + var ( + nodes = []*AllowList{} + _spec = alq.querySpec() + loadedTypes = [1]bool{ + alq.withAllowlistItems != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*AllowList).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &AllowList{config: alq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, alq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := alq.withAllowlistItems; query != nil { + if err := alq.loadAllowlistItems(ctx, query, nodes, + func(n *AllowList) { n.Edges.AllowlistItems = []*AllowListItem{} }, + func(n *AllowList, e *AllowListItem) { n.Edges.AllowlistItems = append(n.Edges.AllowlistItems, e) }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (alq *AllowListQuery) loadAllowlistItems(ctx context.Context, query *AllowListItemQuery, nodes []*AllowList, init func(*AllowList), assign func(*AllowList, *AllowListItem)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*AllowList) + nids := make(map[int]map[*AllowList]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(allowlist.AllowlistItemsTable) + s.Join(joinT).On(s.C(allowlistitem.FieldID), joinT.C(allowlist.AllowlistItemsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(allowlist.AllowlistItemsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(allowlist.AllowlistItemsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + if err := query.prepareQuery(ctx); err != nil { + return err + } + qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]any, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err + } + return append([]any{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []any) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*AllowList]struct{}{byID[outValue]: {}} + return assign(columns[1:], values[1:]) + } + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + }) + neighbors, err := withInterceptors[[]*AllowListItem](ctx, query, qr, query.inters) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "allowlist_items" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil +} + +func (alq *AllowListQuery) sqlCount(ctx context.Context) (int, error) { + _spec := alq.querySpec() + _spec.Node.Columns = alq.ctx.Fields + if len(alq.ctx.Fields) > 0 { + _spec.Unique = alq.ctx.Unique != nil && *alq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, alq.driver, _spec) +} + +func (alq *AllowListQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(allowlist.Table, allowlist.Columns, sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt)) + _spec.From = alq.sql + if unique := alq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if alq.path != nil { + _spec.Unique = true + } + if fields := alq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, allowlist.FieldID) + for i := range fields { + if fields[i] != allowlist.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := alq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := alq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := alq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := alq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (alq *AllowListQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(alq.driver.Dialect()) + t1 := builder.Table(allowlist.Table) + columns := alq.ctx.Fields + if len(columns) == 0 { + columns = allowlist.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if alq.sql != nil { + selector = alq.sql + selector.Select(selector.Columns(columns...)...) + } + if alq.ctx.Unique != nil && *alq.ctx.Unique { + selector.Distinct() + } + for _, p := range alq.predicates { + p(selector) + } + for _, p := range alq.order { + p(selector) + } + if offset := alq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := alq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// AllowListGroupBy is the group-by builder for AllowList entities. +type AllowListGroupBy struct { + selector + build *AllowListQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (algb *AllowListGroupBy) Aggregate(fns ...AggregateFunc) *AllowListGroupBy { + algb.fns = append(algb.fns, fns...) + return algb +} + +// Scan applies the selector query and scans the result into the given value. +func (algb *AllowListGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, algb.build.ctx, "GroupBy") + if err := algb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AllowListQuery, *AllowListGroupBy](ctx, algb.build, algb, algb.build.inters, v) +} + +func (algb *AllowListGroupBy) sqlScan(ctx context.Context, root *AllowListQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(algb.fns)) + for _, fn := range algb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*algb.flds)+len(algb.fns)) + for _, f := range *algb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*algb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := algb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// AllowListSelect is the builder for selecting fields of AllowList entities. +type AllowListSelect struct { + *AllowListQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (als *AllowListSelect) Aggregate(fns ...AggregateFunc) *AllowListSelect { + als.fns = append(als.fns, fns...) + return als +} + +// Scan applies the selector query and scans the result into the given value. +func (als *AllowListSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, als.ctx, "Select") + if err := als.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AllowListQuery, *AllowListSelect](ctx, als.AllowListQuery, als, als.inters, v) +} + +func (als *AllowListSelect) sqlScan(ctx context.Context, root *AllowListQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(als.fns)) + for _, fn := range als.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*als.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := als.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/pkg/database/ent/allowlist_update.go b/pkg/database/ent/allowlist_update.go new file mode 100644 index 00000000000..b2a6bc65c9d --- /dev/null +++ b/pkg/database/ent/allowlist_update.go @@ -0,0 +1,421 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// AllowListUpdate is the builder for updating AllowList entities. +type AllowListUpdate struct { + config + hooks []Hook + mutation *AllowListMutation +} + +// Where appends a list predicates to the AllowListUpdate builder. +func (alu *AllowListUpdate) Where(ps ...predicate.AllowList) *AllowListUpdate { + alu.mutation.Where(ps...) + return alu +} + +// SetUpdatedAt sets the "updated_at" field. +func (alu *AllowListUpdate) SetUpdatedAt(t time.Time) *AllowListUpdate { + alu.mutation.SetUpdatedAt(t) + return alu +} + +// SetFromConsole sets the "from_console" field. +func (alu *AllowListUpdate) SetFromConsole(b bool) *AllowListUpdate { + alu.mutation.SetFromConsole(b) + return alu +} + +// SetNillableFromConsole sets the "from_console" field if the given value is not nil. +func (alu *AllowListUpdate) SetNillableFromConsole(b *bool) *AllowListUpdate { + if b != nil { + alu.SetFromConsole(*b) + } + return alu +} + +// AddAllowlistItemIDs adds the "allowlist_items" edge to the AllowListItem entity by IDs. +func (alu *AllowListUpdate) AddAllowlistItemIDs(ids ...int) *AllowListUpdate { + alu.mutation.AddAllowlistItemIDs(ids...) + return alu +} + +// AddAllowlistItems adds the "allowlist_items" edges to the AllowListItem entity. +func (alu *AllowListUpdate) AddAllowlistItems(a ...*AllowListItem) *AllowListUpdate { + ids := make([]int, len(a)) + for i := range a { + ids[i] = a[i].ID + } + return alu.AddAllowlistItemIDs(ids...) +} + +// Mutation returns the AllowListMutation object of the builder. +func (alu *AllowListUpdate) Mutation() *AllowListMutation { + return alu.mutation +} + +// ClearAllowlistItems clears all "allowlist_items" edges to the AllowListItem entity. +func (alu *AllowListUpdate) ClearAllowlistItems() *AllowListUpdate { + alu.mutation.ClearAllowlistItems() + return alu +} + +// RemoveAllowlistItemIDs removes the "allowlist_items" edge to AllowListItem entities by IDs. +func (alu *AllowListUpdate) RemoveAllowlistItemIDs(ids ...int) *AllowListUpdate { + alu.mutation.RemoveAllowlistItemIDs(ids...) + return alu +} + +// RemoveAllowlistItems removes "allowlist_items" edges to AllowListItem entities. +func (alu *AllowListUpdate) RemoveAllowlistItems(a ...*AllowListItem) *AllowListUpdate { + ids := make([]int, len(a)) + for i := range a { + ids[i] = a[i].ID + } + return alu.RemoveAllowlistItemIDs(ids...) +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (alu *AllowListUpdate) Save(ctx context.Context) (int, error) { + alu.defaults() + return withHooks(ctx, alu.sqlSave, alu.mutation, alu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (alu *AllowListUpdate) SaveX(ctx context.Context) int { + affected, err := alu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (alu *AllowListUpdate) Exec(ctx context.Context) error { + _, err := alu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (alu *AllowListUpdate) ExecX(ctx context.Context) { + if err := alu.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (alu *AllowListUpdate) defaults() { + if _, ok := alu.mutation.UpdatedAt(); !ok { + v := allowlist.UpdateDefaultUpdatedAt() + alu.mutation.SetUpdatedAt(v) + } +} + +func (alu *AllowListUpdate) sqlSave(ctx context.Context) (n int, err error) { + _spec := sqlgraph.NewUpdateSpec(allowlist.Table, allowlist.Columns, sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt)) + if ps := alu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := alu.mutation.UpdatedAt(); ok { + _spec.SetField(allowlist.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := alu.mutation.FromConsole(); ok { + _spec.SetField(allowlist.FieldFromConsole, field.TypeBool, value) + } + if alu.mutation.DescriptionCleared() { + _spec.ClearField(allowlist.FieldDescription, field.TypeString) + } + if alu.mutation.AllowlistIDCleared() { + _spec.ClearField(allowlist.FieldAllowlistID, field.TypeString) + } + if alu.mutation.AllowlistItemsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: allowlist.AllowlistItemsTable, + Columns: allowlist.AllowlistItemsPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := alu.mutation.RemovedAllowlistItemsIDs(); len(nodes) > 0 && !alu.mutation.AllowlistItemsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: allowlist.AllowlistItemsTable, + Columns: allowlist.AllowlistItemsPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := alu.mutation.AllowlistItemsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: allowlist.AllowlistItemsTable, + Columns: allowlist.AllowlistItemsPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, alu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{allowlist.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + alu.mutation.done = true + return n, nil +} + +// AllowListUpdateOne is the builder for updating a single AllowList entity. +type AllowListUpdateOne struct { + config + fields []string + hooks []Hook + mutation *AllowListMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (aluo *AllowListUpdateOne) SetUpdatedAt(t time.Time) *AllowListUpdateOne { + aluo.mutation.SetUpdatedAt(t) + return aluo +} + +// SetFromConsole sets the "from_console" field. +func (aluo *AllowListUpdateOne) SetFromConsole(b bool) *AllowListUpdateOne { + aluo.mutation.SetFromConsole(b) + return aluo +} + +// SetNillableFromConsole sets the "from_console" field if the given value is not nil. +func (aluo *AllowListUpdateOne) SetNillableFromConsole(b *bool) *AllowListUpdateOne { + if b != nil { + aluo.SetFromConsole(*b) + } + return aluo +} + +// AddAllowlistItemIDs adds the "allowlist_items" edge to the AllowListItem entity by IDs. +func (aluo *AllowListUpdateOne) AddAllowlistItemIDs(ids ...int) *AllowListUpdateOne { + aluo.mutation.AddAllowlistItemIDs(ids...) + return aluo +} + +// AddAllowlistItems adds the "allowlist_items" edges to the AllowListItem entity. +func (aluo *AllowListUpdateOne) AddAllowlistItems(a ...*AllowListItem) *AllowListUpdateOne { + ids := make([]int, len(a)) + for i := range a { + ids[i] = a[i].ID + } + return aluo.AddAllowlistItemIDs(ids...) +} + +// Mutation returns the AllowListMutation object of the builder. +func (aluo *AllowListUpdateOne) Mutation() *AllowListMutation { + return aluo.mutation +} + +// ClearAllowlistItems clears all "allowlist_items" edges to the AllowListItem entity. +func (aluo *AllowListUpdateOne) ClearAllowlistItems() *AllowListUpdateOne { + aluo.mutation.ClearAllowlistItems() + return aluo +} + +// RemoveAllowlistItemIDs removes the "allowlist_items" edge to AllowListItem entities by IDs. +func (aluo *AllowListUpdateOne) RemoveAllowlistItemIDs(ids ...int) *AllowListUpdateOne { + aluo.mutation.RemoveAllowlistItemIDs(ids...) + return aluo +} + +// RemoveAllowlistItems removes "allowlist_items" edges to AllowListItem entities. +func (aluo *AllowListUpdateOne) RemoveAllowlistItems(a ...*AllowListItem) *AllowListUpdateOne { + ids := make([]int, len(a)) + for i := range a { + ids[i] = a[i].ID + } + return aluo.RemoveAllowlistItemIDs(ids...) +} + +// Where appends a list predicates to the AllowListUpdate builder. +func (aluo *AllowListUpdateOne) Where(ps ...predicate.AllowList) *AllowListUpdateOne { + aluo.mutation.Where(ps...) + return aluo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (aluo *AllowListUpdateOne) Select(field string, fields ...string) *AllowListUpdateOne { + aluo.fields = append([]string{field}, fields...) + return aluo +} + +// Save executes the query and returns the updated AllowList entity. +func (aluo *AllowListUpdateOne) Save(ctx context.Context) (*AllowList, error) { + aluo.defaults() + return withHooks(ctx, aluo.sqlSave, aluo.mutation, aluo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (aluo *AllowListUpdateOne) SaveX(ctx context.Context) *AllowList { + node, err := aluo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (aluo *AllowListUpdateOne) Exec(ctx context.Context) error { + _, err := aluo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (aluo *AllowListUpdateOne) ExecX(ctx context.Context) { + if err := aluo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (aluo *AllowListUpdateOne) defaults() { + if _, ok := aluo.mutation.UpdatedAt(); !ok { + v := allowlist.UpdateDefaultUpdatedAt() + aluo.mutation.SetUpdatedAt(v) + } +} + +func (aluo *AllowListUpdateOne) sqlSave(ctx context.Context) (_node *AllowList, err error) { + _spec := sqlgraph.NewUpdateSpec(allowlist.Table, allowlist.Columns, sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt)) + id, ok := aluo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AllowList.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := aluo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, allowlist.FieldID) + for _, f := range fields { + if !allowlist.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != allowlist.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := aluo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := aluo.mutation.UpdatedAt(); ok { + _spec.SetField(allowlist.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := aluo.mutation.FromConsole(); ok { + _spec.SetField(allowlist.FieldFromConsole, field.TypeBool, value) + } + if aluo.mutation.DescriptionCleared() { + _spec.ClearField(allowlist.FieldDescription, field.TypeString) + } + if aluo.mutation.AllowlistIDCleared() { + _spec.ClearField(allowlist.FieldAllowlistID, field.TypeString) + } + if aluo.mutation.AllowlistItemsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: allowlist.AllowlistItemsTable, + Columns: allowlist.AllowlistItemsPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := aluo.mutation.RemovedAllowlistItemsIDs(); len(nodes) > 0 && !aluo.mutation.AllowlistItemsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: allowlist.AllowlistItemsTable, + Columns: allowlist.AllowlistItemsPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := aluo.mutation.AllowlistItemsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: allowlist.AllowlistItemsTable, + Columns: allowlist.AllowlistItemsPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &AllowList{config: aluo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, aluo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{allowlist.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + aluo.mutation.done = true + return _node, nil +} diff --git a/pkg/database/ent/allowlistitem.go b/pkg/database/ent/allowlistitem.go new file mode 100644 index 00000000000..2c0d997b1d7 --- /dev/null +++ b/pkg/database/ent/allowlistitem.go @@ -0,0 +1,231 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem" +) + +// AllowListItem is the model entity for the AllowListItem schema. +type AllowListItem struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // ExpiresAt holds the value of the "expires_at" field. + ExpiresAt time.Time `json:"expires_at,omitempty"` + // Comment holds the value of the "comment" field. + Comment string `json:"comment,omitempty"` + // Value holds the value of the "value" field. + Value string `json:"value,omitempty"` + // StartIP holds the value of the "start_ip" field. + StartIP int64 `json:"start_ip,omitempty"` + // EndIP holds the value of the "end_ip" field. + EndIP int64 `json:"end_ip,omitempty"` + // StartSuffix holds the value of the "start_suffix" field. + StartSuffix int64 `json:"start_suffix,omitempty"` + // EndSuffix holds the value of the "end_suffix" field. + EndSuffix int64 `json:"end_suffix,omitempty"` + // IPSize holds the value of the "ip_size" field. + IPSize int64 `json:"ip_size,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the AllowListItemQuery when eager-loading is set. + Edges AllowListItemEdges `json:"edges"` + selectValues sql.SelectValues +} + +// AllowListItemEdges holds the relations/edges for other nodes in the graph. +type AllowListItemEdges struct { + // Allowlist holds the value of the allowlist edge. + Allowlist []*AllowList `json:"allowlist,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// AllowlistOrErr returns the Allowlist value or an error if the edge +// was not loaded in eager-loading. +func (e AllowListItemEdges) AllowlistOrErr() ([]*AllowList, error) { + if e.loadedTypes[0] { + return e.Allowlist, nil + } + return nil, &NotLoadedError{edge: "allowlist"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*AllowListItem) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case allowlistitem.FieldID, allowlistitem.FieldStartIP, allowlistitem.FieldEndIP, allowlistitem.FieldStartSuffix, allowlistitem.FieldEndSuffix, allowlistitem.FieldIPSize: + values[i] = new(sql.NullInt64) + case allowlistitem.FieldComment, allowlistitem.FieldValue: + values[i] = new(sql.NullString) + case allowlistitem.FieldCreatedAt, allowlistitem.FieldUpdatedAt, allowlistitem.FieldExpiresAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the AllowListItem fields. +func (ali *AllowListItem) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case allowlistitem.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + ali.ID = int(value.Int64) + case allowlistitem.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + ali.CreatedAt = value.Time + } + case allowlistitem.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + ali.UpdatedAt = value.Time + } + case allowlistitem.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + ali.ExpiresAt = value.Time + } + case allowlistitem.FieldComment: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field comment", values[i]) + } else if value.Valid { + ali.Comment = value.String + } + case allowlistitem.FieldValue: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field value", values[i]) + } else if value.Valid { + ali.Value = value.String + } + case allowlistitem.FieldStartIP: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field start_ip", values[i]) + } else if value.Valid { + ali.StartIP = value.Int64 + } + case allowlistitem.FieldEndIP: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field end_ip", values[i]) + } else if value.Valid { + ali.EndIP = value.Int64 + } + case allowlistitem.FieldStartSuffix: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field start_suffix", values[i]) + } else if value.Valid { + ali.StartSuffix = value.Int64 + } + case allowlistitem.FieldEndSuffix: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field end_suffix", values[i]) + } else if value.Valid { + ali.EndSuffix = value.Int64 + } + case allowlistitem.FieldIPSize: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field ip_size", values[i]) + } else if value.Valid { + ali.IPSize = value.Int64 + } + default: + ali.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// GetValue returns the ent.Value that was dynamically selected and assigned to the AllowListItem. +// This includes values selected through modifiers, order, etc. +func (ali *AllowListItem) GetValue(name string) (ent.Value, error) { + return ali.selectValues.Get(name) +} + +// QueryAllowlist queries the "allowlist" edge of the AllowListItem entity. +func (ali *AllowListItem) QueryAllowlist() *AllowListQuery { + return NewAllowListItemClient(ali.config).QueryAllowlist(ali) +} + +// Update returns a builder for updating this AllowListItem. +// Note that you need to call AllowListItem.Unwrap() before calling this method if this AllowListItem +// was returned from a transaction, and the transaction was committed or rolled back. +func (ali *AllowListItem) Update() *AllowListItemUpdateOne { + return NewAllowListItemClient(ali.config).UpdateOne(ali) +} + +// Unwrap unwraps the AllowListItem entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (ali *AllowListItem) Unwrap() *AllowListItem { + _tx, ok := ali.config.driver.(*txDriver) + if !ok { + panic("ent: AllowListItem is not a transactional entity") + } + ali.config.driver = _tx.drv + return ali +} + +// String implements the fmt.Stringer. +func (ali *AllowListItem) String() string { + var builder strings.Builder + builder.WriteString("AllowListItem(") + builder.WriteString(fmt.Sprintf("id=%v, ", ali.ID)) + builder.WriteString("created_at=") + builder.WriteString(ali.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(ali.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("expires_at=") + builder.WriteString(ali.ExpiresAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("comment=") + builder.WriteString(ali.Comment) + builder.WriteString(", ") + builder.WriteString("value=") + builder.WriteString(ali.Value) + builder.WriteString(", ") + builder.WriteString("start_ip=") + builder.WriteString(fmt.Sprintf("%v", ali.StartIP)) + builder.WriteString(", ") + builder.WriteString("end_ip=") + builder.WriteString(fmt.Sprintf("%v", ali.EndIP)) + builder.WriteString(", ") + builder.WriteString("start_suffix=") + builder.WriteString(fmt.Sprintf("%v", ali.StartSuffix)) + builder.WriteString(", ") + builder.WriteString("end_suffix=") + builder.WriteString(fmt.Sprintf("%v", ali.EndSuffix)) + builder.WriteString(", ") + builder.WriteString("ip_size=") + builder.WriteString(fmt.Sprintf("%v", ali.IPSize)) + builder.WriteByte(')') + return builder.String() +} + +// AllowListItems is a parsable slice of AllowListItem. +type AllowListItems []*AllowListItem diff --git a/pkg/database/ent/allowlistitem/allowlistitem.go b/pkg/database/ent/allowlistitem/allowlistitem.go new file mode 100644 index 00000000000..5474763eac3 --- /dev/null +++ b/pkg/database/ent/allowlistitem/allowlistitem.go @@ -0,0 +1,165 @@ +// Code generated by ent, DO NOT EDIT. + +package allowlistitem + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the allowlistitem type in the database. + Label = "allow_list_item" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" + // FieldComment holds the string denoting the comment field in the database. + FieldComment = "comment" + // FieldValue holds the string denoting the value field in the database. + FieldValue = "value" + // FieldStartIP holds the string denoting the start_ip field in the database. + FieldStartIP = "start_ip" + // FieldEndIP holds the string denoting the end_ip field in the database. + FieldEndIP = "end_ip" + // FieldStartSuffix holds the string denoting the start_suffix field in the database. + FieldStartSuffix = "start_suffix" + // FieldEndSuffix holds the string denoting the end_suffix field in the database. + FieldEndSuffix = "end_suffix" + // FieldIPSize holds the string denoting the ip_size field in the database. + FieldIPSize = "ip_size" + // EdgeAllowlist holds the string denoting the allowlist edge name in mutations. + EdgeAllowlist = "allowlist" + // Table holds the table name of the allowlistitem in the database. + Table = "allow_list_items" + // AllowlistTable is the table that holds the allowlist relation/edge. The primary key declared below. + AllowlistTable = "allow_list_allowlist_items" + // AllowlistInverseTable is the table name for the AllowList entity. + // It exists in this package in order to avoid circular dependency with the "allowlist" package. + AllowlistInverseTable = "allow_lists" +) + +// Columns holds all SQL columns for allowlistitem fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldExpiresAt, + FieldComment, + FieldValue, + FieldStartIP, + FieldEndIP, + FieldStartSuffix, + FieldEndSuffix, + FieldIPSize, +} + +var ( + // AllowlistPrimaryKey and AllowlistColumn2 are the table columns denoting the + // primary key for the allowlist relation (M2M). + AllowlistPrimaryKey = []string{"allow_list_id", "allow_list_item_id"} +) + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time +) + +// OrderOption defines the ordering options for the AllowListItem queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} + +// ByComment orders the results by the comment field. +func ByComment(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldComment, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} + +// ByStartIP orders the results by the start_ip field. +func ByStartIP(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStartIP, opts...).ToFunc() +} + +// ByEndIP orders the results by the end_ip field. +func ByEndIP(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEndIP, opts...).ToFunc() +} + +// ByStartSuffix orders the results by the start_suffix field. +func ByStartSuffix(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStartSuffix, opts...).ToFunc() +} + +// ByEndSuffix orders the results by the end_suffix field. +func ByEndSuffix(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEndSuffix, opts...).ToFunc() +} + +// ByIPSize orders the results by the ip_size field. +func ByIPSize(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIPSize, opts...).ToFunc() +} + +// ByAllowlistCount orders the results by allowlist count. +func ByAllowlistCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAllowlistStep(), opts...) + } +} + +// ByAllowlist orders the results by allowlist terms. +func ByAllowlist(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAllowlistStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newAllowlistStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AllowlistInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2M, true, AllowlistTable, AllowlistPrimaryKey...), + ) +} diff --git a/pkg/database/ent/allowlistitem/where.go b/pkg/database/ent/allowlistitem/where.go new file mode 100644 index 00000000000..32a10d77c22 --- /dev/null +++ b/pkg/database/ent/allowlistitem/where.go @@ -0,0 +1,664 @@ +// Code generated by ent, DO NOT EDIT. + +package allowlistitem + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldExpiresAt, v)) +} + +// Comment applies equality check predicate on the "comment" field. It's identical to CommentEQ. +func Comment(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldComment, v)) +} + +// Value applies equality check predicate on the "value" field. It's identical to ValueEQ. +func Value(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldValue, v)) +} + +// StartIP applies equality check predicate on the "start_ip" field. It's identical to StartIPEQ. +func StartIP(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldStartIP, v)) +} + +// EndIP applies equality check predicate on the "end_ip" field. It's identical to EndIPEQ. +func EndIP(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldEndIP, v)) +} + +// StartSuffix applies equality check predicate on the "start_suffix" field. It's identical to StartSuffixEQ. +func StartSuffix(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldStartSuffix, v)) +} + +// EndSuffix applies equality check predicate on the "end_suffix" field. It's identical to EndSuffixEQ. +func EndSuffix(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldEndSuffix, v)) +} + +// IPSize applies equality check predicate on the "ip_size" field. It's identical to IPSizeEQ. +func IPSize(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldIPSize, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. +func ExpiresAtIn(vs ...time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLTE(FieldExpiresAt, v)) +} + +// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field. +func ExpiresAtIsNil() predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIsNull(FieldExpiresAt)) +} + +// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field. +func ExpiresAtNotNil() predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotNull(FieldExpiresAt)) +} + +// CommentEQ applies the EQ predicate on the "comment" field. +func CommentEQ(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldComment, v)) +} + +// CommentNEQ applies the NEQ predicate on the "comment" field. +func CommentNEQ(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNEQ(FieldComment, v)) +} + +// CommentIn applies the In predicate on the "comment" field. +func CommentIn(vs ...string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIn(FieldComment, vs...)) +} + +// CommentNotIn applies the NotIn predicate on the "comment" field. +func CommentNotIn(vs ...string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotIn(FieldComment, vs...)) +} + +// CommentGT applies the GT predicate on the "comment" field. +func CommentGT(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGT(FieldComment, v)) +} + +// CommentGTE applies the GTE predicate on the "comment" field. +func CommentGTE(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGTE(FieldComment, v)) +} + +// CommentLT applies the LT predicate on the "comment" field. +func CommentLT(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLT(FieldComment, v)) +} + +// CommentLTE applies the LTE predicate on the "comment" field. +func CommentLTE(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLTE(FieldComment, v)) +} + +// CommentContains applies the Contains predicate on the "comment" field. +func CommentContains(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldContains(FieldComment, v)) +} + +// CommentHasPrefix applies the HasPrefix predicate on the "comment" field. +func CommentHasPrefix(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldHasPrefix(FieldComment, v)) +} + +// CommentHasSuffix applies the HasSuffix predicate on the "comment" field. +func CommentHasSuffix(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldHasSuffix(FieldComment, v)) +} + +// CommentIsNil applies the IsNil predicate on the "comment" field. +func CommentIsNil() predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIsNull(FieldComment)) +} + +// CommentNotNil applies the NotNil predicate on the "comment" field. +func CommentNotNil() predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotNull(FieldComment)) +} + +// CommentEqualFold applies the EqualFold predicate on the "comment" field. +func CommentEqualFold(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEqualFold(FieldComment, v)) +} + +// CommentContainsFold applies the ContainsFold predicate on the "comment" field. +func CommentContainsFold(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldContainsFold(FieldComment, v)) +} + +// ValueEQ applies the EQ predicate on the "value" field. +func ValueEQ(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldValue, v)) +} + +// ValueNEQ applies the NEQ predicate on the "value" field. +func ValueNEQ(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNEQ(FieldValue, v)) +} + +// ValueIn applies the In predicate on the "value" field. +func ValueIn(vs ...string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIn(FieldValue, vs...)) +} + +// ValueNotIn applies the NotIn predicate on the "value" field. +func ValueNotIn(vs ...string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotIn(FieldValue, vs...)) +} + +// ValueGT applies the GT predicate on the "value" field. +func ValueGT(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGT(FieldValue, v)) +} + +// ValueGTE applies the GTE predicate on the "value" field. +func ValueGTE(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGTE(FieldValue, v)) +} + +// ValueLT applies the LT predicate on the "value" field. +func ValueLT(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLT(FieldValue, v)) +} + +// ValueLTE applies the LTE predicate on the "value" field. +func ValueLTE(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLTE(FieldValue, v)) +} + +// ValueContains applies the Contains predicate on the "value" field. +func ValueContains(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldContains(FieldValue, v)) +} + +// ValueHasPrefix applies the HasPrefix predicate on the "value" field. +func ValueHasPrefix(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldHasPrefix(FieldValue, v)) +} + +// ValueHasSuffix applies the HasSuffix predicate on the "value" field. +func ValueHasSuffix(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldHasSuffix(FieldValue, v)) +} + +// ValueEqualFold applies the EqualFold predicate on the "value" field. +func ValueEqualFold(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEqualFold(FieldValue, v)) +} + +// ValueContainsFold applies the ContainsFold predicate on the "value" field. +func ValueContainsFold(v string) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldContainsFold(FieldValue, v)) +} + +// StartIPEQ applies the EQ predicate on the "start_ip" field. +func StartIPEQ(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldStartIP, v)) +} + +// StartIPNEQ applies the NEQ predicate on the "start_ip" field. +func StartIPNEQ(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNEQ(FieldStartIP, v)) +} + +// StartIPIn applies the In predicate on the "start_ip" field. +func StartIPIn(vs ...int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIn(FieldStartIP, vs...)) +} + +// StartIPNotIn applies the NotIn predicate on the "start_ip" field. +func StartIPNotIn(vs ...int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotIn(FieldStartIP, vs...)) +} + +// StartIPGT applies the GT predicate on the "start_ip" field. +func StartIPGT(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGT(FieldStartIP, v)) +} + +// StartIPGTE applies the GTE predicate on the "start_ip" field. +func StartIPGTE(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGTE(FieldStartIP, v)) +} + +// StartIPLT applies the LT predicate on the "start_ip" field. +func StartIPLT(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLT(FieldStartIP, v)) +} + +// StartIPLTE applies the LTE predicate on the "start_ip" field. +func StartIPLTE(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLTE(FieldStartIP, v)) +} + +// StartIPIsNil applies the IsNil predicate on the "start_ip" field. +func StartIPIsNil() predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIsNull(FieldStartIP)) +} + +// StartIPNotNil applies the NotNil predicate on the "start_ip" field. +func StartIPNotNil() predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotNull(FieldStartIP)) +} + +// EndIPEQ applies the EQ predicate on the "end_ip" field. +func EndIPEQ(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldEndIP, v)) +} + +// EndIPNEQ applies the NEQ predicate on the "end_ip" field. +func EndIPNEQ(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNEQ(FieldEndIP, v)) +} + +// EndIPIn applies the In predicate on the "end_ip" field. +func EndIPIn(vs ...int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIn(FieldEndIP, vs...)) +} + +// EndIPNotIn applies the NotIn predicate on the "end_ip" field. +func EndIPNotIn(vs ...int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotIn(FieldEndIP, vs...)) +} + +// EndIPGT applies the GT predicate on the "end_ip" field. +func EndIPGT(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGT(FieldEndIP, v)) +} + +// EndIPGTE applies the GTE predicate on the "end_ip" field. +func EndIPGTE(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGTE(FieldEndIP, v)) +} + +// EndIPLT applies the LT predicate on the "end_ip" field. +func EndIPLT(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLT(FieldEndIP, v)) +} + +// EndIPLTE applies the LTE predicate on the "end_ip" field. +func EndIPLTE(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLTE(FieldEndIP, v)) +} + +// EndIPIsNil applies the IsNil predicate on the "end_ip" field. +func EndIPIsNil() predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIsNull(FieldEndIP)) +} + +// EndIPNotNil applies the NotNil predicate on the "end_ip" field. +func EndIPNotNil() predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotNull(FieldEndIP)) +} + +// StartSuffixEQ applies the EQ predicate on the "start_suffix" field. +func StartSuffixEQ(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldStartSuffix, v)) +} + +// StartSuffixNEQ applies the NEQ predicate on the "start_suffix" field. +func StartSuffixNEQ(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNEQ(FieldStartSuffix, v)) +} + +// StartSuffixIn applies the In predicate on the "start_suffix" field. +func StartSuffixIn(vs ...int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIn(FieldStartSuffix, vs...)) +} + +// StartSuffixNotIn applies the NotIn predicate on the "start_suffix" field. +func StartSuffixNotIn(vs ...int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotIn(FieldStartSuffix, vs...)) +} + +// StartSuffixGT applies the GT predicate on the "start_suffix" field. +func StartSuffixGT(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGT(FieldStartSuffix, v)) +} + +// StartSuffixGTE applies the GTE predicate on the "start_suffix" field. +func StartSuffixGTE(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGTE(FieldStartSuffix, v)) +} + +// StartSuffixLT applies the LT predicate on the "start_suffix" field. +func StartSuffixLT(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLT(FieldStartSuffix, v)) +} + +// StartSuffixLTE applies the LTE predicate on the "start_suffix" field. +func StartSuffixLTE(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLTE(FieldStartSuffix, v)) +} + +// StartSuffixIsNil applies the IsNil predicate on the "start_suffix" field. +func StartSuffixIsNil() predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIsNull(FieldStartSuffix)) +} + +// StartSuffixNotNil applies the NotNil predicate on the "start_suffix" field. +func StartSuffixNotNil() predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotNull(FieldStartSuffix)) +} + +// EndSuffixEQ applies the EQ predicate on the "end_suffix" field. +func EndSuffixEQ(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldEndSuffix, v)) +} + +// EndSuffixNEQ applies the NEQ predicate on the "end_suffix" field. +func EndSuffixNEQ(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNEQ(FieldEndSuffix, v)) +} + +// EndSuffixIn applies the In predicate on the "end_suffix" field. +func EndSuffixIn(vs ...int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIn(FieldEndSuffix, vs...)) +} + +// EndSuffixNotIn applies the NotIn predicate on the "end_suffix" field. +func EndSuffixNotIn(vs ...int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotIn(FieldEndSuffix, vs...)) +} + +// EndSuffixGT applies the GT predicate on the "end_suffix" field. +func EndSuffixGT(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGT(FieldEndSuffix, v)) +} + +// EndSuffixGTE applies the GTE predicate on the "end_suffix" field. +func EndSuffixGTE(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGTE(FieldEndSuffix, v)) +} + +// EndSuffixLT applies the LT predicate on the "end_suffix" field. +func EndSuffixLT(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLT(FieldEndSuffix, v)) +} + +// EndSuffixLTE applies the LTE predicate on the "end_suffix" field. +func EndSuffixLTE(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLTE(FieldEndSuffix, v)) +} + +// EndSuffixIsNil applies the IsNil predicate on the "end_suffix" field. +func EndSuffixIsNil() predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIsNull(FieldEndSuffix)) +} + +// EndSuffixNotNil applies the NotNil predicate on the "end_suffix" field. +func EndSuffixNotNil() predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotNull(FieldEndSuffix)) +} + +// IPSizeEQ applies the EQ predicate on the "ip_size" field. +func IPSizeEQ(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldEQ(FieldIPSize, v)) +} + +// IPSizeNEQ applies the NEQ predicate on the "ip_size" field. +func IPSizeNEQ(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNEQ(FieldIPSize, v)) +} + +// IPSizeIn applies the In predicate on the "ip_size" field. +func IPSizeIn(vs ...int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIn(FieldIPSize, vs...)) +} + +// IPSizeNotIn applies the NotIn predicate on the "ip_size" field. +func IPSizeNotIn(vs ...int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotIn(FieldIPSize, vs...)) +} + +// IPSizeGT applies the GT predicate on the "ip_size" field. +func IPSizeGT(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGT(FieldIPSize, v)) +} + +// IPSizeGTE applies the GTE predicate on the "ip_size" field. +func IPSizeGTE(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldGTE(FieldIPSize, v)) +} + +// IPSizeLT applies the LT predicate on the "ip_size" field. +func IPSizeLT(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLT(FieldIPSize, v)) +} + +// IPSizeLTE applies the LTE predicate on the "ip_size" field. +func IPSizeLTE(v int64) predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldLTE(FieldIPSize, v)) +} + +// IPSizeIsNil applies the IsNil predicate on the "ip_size" field. +func IPSizeIsNil() predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldIsNull(FieldIPSize)) +} + +// IPSizeNotNil applies the NotNil predicate on the "ip_size" field. +func IPSizeNotNil() predicate.AllowListItem { + return predicate.AllowListItem(sql.FieldNotNull(FieldIPSize)) +} + +// HasAllowlist applies the HasEdge predicate on the "allowlist" edge. +func HasAllowlist() predicate.AllowListItem { + return predicate.AllowListItem(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2M, true, AllowlistTable, AllowlistPrimaryKey...), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAllowlistWith applies the HasEdge predicate on the "allowlist" edge with a given conditions (other predicates). +func HasAllowlistWith(preds ...predicate.AllowList) predicate.AllowListItem { + return predicate.AllowListItem(func(s *sql.Selector) { + step := newAllowlistStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.AllowListItem) predicate.AllowListItem { + return predicate.AllowListItem(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.AllowListItem) predicate.AllowListItem { + return predicate.AllowListItem(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.AllowListItem) predicate.AllowListItem { + return predicate.AllowListItem(sql.NotPredicates(p)) +} diff --git a/pkg/database/ent/allowlistitem_create.go b/pkg/database/ent/allowlistitem_create.go new file mode 100644 index 00000000000..502cec11db7 --- /dev/null +++ b/pkg/database/ent/allowlistitem_create.go @@ -0,0 +1,398 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem" +) + +// AllowListItemCreate is the builder for creating a AllowListItem entity. +type AllowListItemCreate struct { + config + mutation *AllowListItemMutation + hooks []Hook +} + +// SetCreatedAt sets the "created_at" field. +func (alic *AllowListItemCreate) SetCreatedAt(t time.Time) *AllowListItemCreate { + alic.mutation.SetCreatedAt(t) + return alic +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (alic *AllowListItemCreate) SetNillableCreatedAt(t *time.Time) *AllowListItemCreate { + if t != nil { + alic.SetCreatedAt(*t) + } + return alic +} + +// SetUpdatedAt sets the "updated_at" field. +func (alic *AllowListItemCreate) SetUpdatedAt(t time.Time) *AllowListItemCreate { + alic.mutation.SetUpdatedAt(t) + return alic +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (alic *AllowListItemCreate) SetNillableUpdatedAt(t *time.Time) *AllowListItemCreate { + if t != nil { + alic.SetUpdatedAt(*t) + } + return alic +} + +// SetExpiresAt sets the "expires_at" field. +func (alic *AllowListItemCreate) SetExpiresAt(t time.Time) *AllowListItemCreate { + alic.mutation.SetExpiresAt(t) + return alic +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (alic *AllowListItemCreate) SetNillableExpiresAt(t *time.Time) *AllowListItemCreate { + if t != nil { + alic.SetExpiresAt(*t) + } + return alic +} + +// SetComment sets the "comment" field. +func (alic *AllowListItemCreate) SetComment(s string) *AllowListItemCreate { + alic.mutation.SetComment(s) + return alic +} + +// SetNillableComment sets the "comment" field if the given value is not nil. +func (alic *AllowListItemCreate) SetNillableComment(s *string) *AllowListItemCreate { + if s != nil { + alic.SetComment(*s) + } + return alic +} + +// SetValue sets the "value" field. +func (alic *AllowListItemCreate) SetValue(s string) *AllowListItemCreate { + alic.mutation.SetValue(s) + return alic +} + +// SetStartIP sets the "start_ip" field. +func (alic *AllowListItemCreate) SetStartIP(i int64) *AllowListItemCreate { + alic.mutation.SetStartIP(i) + return alic +} + +// SetNillableStartIP sets the "start_ip" field if the given value is not nil. +func (alic *AllowListItemCreate) SetNillableStartIP(i *int64) *AllowListItemCreate { + if i != nil { + alic.SetStartIP(*i) + } + return alic +} + +// SetEndIP sets the "end_ip" field. +func (alic *AllowListItemCreate) SetEndIP(i int64) *AllowListItemCreate { + alic.mutation.SetEndIP(i) + return alic +} + +// SetNillableEndIP sets the "end_ip" field if the given value is not nil. +func (alic *AllowListItemCreate) SetNillableEndIP(i *int64) *AllowListItemCreate { + if i != nil { + alic.SetEndIP(*i) + } + return alic +} + +// SetStartSuffix sets the "start_suffix" field. +func (alic *AllowListItemCreate) SetStartSuffix(i int64) *AllowListItemCreate { + alic.mutation.SetStartSuffix(i) + return alic +} + +// SetNillableStartSuffix sets the "start_suffix" field if the given value is not nil. +func (alic *AllowListItemCreate) SetNillableStartSuffix(i *int64) *AllowListItemCreate { + if i != nil { + alic.SetStartSuffix(*i) + } + return alic +} + +// SetEndSuffix sets the "end_suffix" field. +func (alic *AllowListItemCreate) SetEndSuffix(i int64) *AllowListItemCreate { + alic.mutation.SetEndSuffix(i) + return alic +} + +// SetNillableEndSuffix sets the "end_suffix" field if the given value is not nil. +func (alic *AllowListItemCreate) SetNillableEndSuffix(i *int64) *AllowListItemCreate { + if i != nil { + alic.SetEndSuffix(*i) + } + return alic +} + +// SetIPSize sets the "ip_size" field. +func (alic *AllowListItemCreate) SetIPSize(i int64) *AllowListItemCreate { + alic.mutation.SetIPSize(i) + return alic +} + +// SetNillableIPSize sets the "ip_size" field if the given value is not nil. +func (alic *AllowListItemCreate) SetNillableIPSize(i *int64) *AllowListItemCreate { + if i != nil { + alic.SetIPSize(*i) + } + return alic +} + +// AddAllowlistIDs adds the "allowlist" edge to the AllowList entity by IDs. +func (alic *AllowListItemCreate) AddAllowlistIDs(ids ...int) *AllowListItemCreate { + alic.mutation.AddAllowlistIDs(ids...) + return alic +} + +// AddAllowlist adds the "allowlist" edges to the AllowList entity. +func (alic *AllowListItemCreate) AddAllowlist(a ...*AllowList) *AllowListItemCreate { + ids := make([]int, len(a)) + for i := range a { + ids[i] = a[i].ID + } + return alic.AddAllowlistIDs(ids...) +} + +// Mutation returns the AllowListItemMutation object of the builder. +func (alic *AllowListItemCreate) Mutation() *AllowListItemMutation { + return alic.mutation +} + +// Save creates the AllowListItem in the database. +func (alic *AllowListItemCreate) Save(ctx context.Context) (*AllowListItem, error) { + alic.defaults() + return withHooks(ctx, alic.sqlSave, alic.mutation, alic.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (alic *AllowListItemCreate) SaveX(ctx context.Context) *AllowListItem { + v, err := alic.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (alic *AllowListItemCreate) Exec(ctx context.Context) error { + _, err := alic.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (alic *AllowListItemCreate) ExecX(ctx context.Context) { + if err := alic.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (alic *AllowListItemCreate) defaults() { + if _, ok := alic.mutation.CreatedAt(); !ok { + v := allowlistitem.DefaultCreatedAt() + alic.mutation.SetCreatedAt(v) + } + if _, ok := alic.mutation.UpdatedAt(); !ok { + v := allowlistitem.DefaultUpdatedAt() + alic.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (alic *AllowListItemCreate) check() error { + if _, ok := alic.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AllowListItem.created_at"`)} + } + if _, ok := alic.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AllowListItem.updated_at"`)} + } + if _, ok := alic.mutation.Value(); !ok { + return &ValidationError{Name: "value", err: errors.New(`ent: missing required field "AllowListItem.value"`)} + } + return nil +} + +func (alic *AllowListItemCreate) sqlSave(ctx context.Context) (*AllowListItem, error) { + if err := alic.check(); err != nil { + return nil, err + } + _node, _spec := alic.createSpec() + if err := sqlgraph.CreateNode(ctx, alic.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + alic.mutation.id = &_node.ID + alic.mutation.done = true + return _node, nil +} + +func (alic *AllowListItemCreate) createSpec() (*AllowListItem, *sqlgraph.CreateSpec) { + var ( + _node = &AllowListItem{config: alic.config} + _spec = sqlgraph.NewCreateSpec(allowlistitem.Table, sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt)) + ) + if value, ok := alic.mutation.CreatedAt(); ok { + _spec.SetField(allowlistitem.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := alic.mutation.UpdatedAt(); ok { + _spec.SetField(allowlistitem.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := alic.mutation.ExpiresAt(); ok { + _spec.SetField(allowlistitem.FieldExpiresAt, field.TypeTime, value) + _node.ExpiresAt = value + } + if value, ok := alic.mutation.Comment(); ok { + _spec.SetField(allowlistitem.FieldComment, field.TypeString, value) + _node.Comment = value + } + if value, ok := alic.mutation.Value(); ok { + _spec.SetField(allowlistitem.FieldValue, field.TypeString, value) + _node.Value = value + } + if value, ok := alic.mutation.StartIP(); ok { + _spec.SetField(allowlistitem.FieldStartIP, field.TypeInt64, value) + _node.StartIP = value + } + if value, ok := alic.mutation.EndIP(); ok { + _spec.SetField(allowlistitem.FieldEndIP, field.TypeInt64, value) + _node.EndIP = value + } + if value, ok := alic.mutation.StartSuffix(); ok { + _spec.SetField(allowlistitem.FieldStartSuffix, field.TypeInt64, value) + _node.StartSuffix = value + } + if value, ok := alic.mutation.EndSuffix(); ok { + _spec.SetField(allowlistitem.FieldEndSuffix, field.TypeInt64, value) + _node.EndSuffix = value + } + if value, ok := alic.mutation.IPSize(); ok { + _spec.SetField(allowlistitem.FieldIPSize, field.TypeInt64, value) + _node.IPSize = value + } + if nodes := alic.mutation.AllowlistIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: allowlistitem.AllowlistTable, + Columns: allowlistitem.AllowlistPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// AllowListItemCreateBulk is the builder for creating many AllowListItem entities in bulk. +type AllowListItemCreateBulk struct { + config + err error + builders []*AllowListItemCreate +} + +// Save creates the AllowListItem entities in the database. +func (alicb *AllowListItemCreateBulk) Save(ctx context.Context) ([]*AllowListItem, error) { + if alicb.err != nil { + return nil, alicb.err + } + specs := make([]*sqlgraph.CreateSpec, len(alicb.builders)) + nodes := make([]*AllowListItem, len(alicb.builders)) + mutators := make([]Mutator, len(alicb.builders)) + for i := range alicb.builders { + func(i int, root context.Context) { + builder := alicb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*AllowListItemMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, alicb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, alicb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, alicb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (alicb *AllowListItemCreateBulk) SaveX(ctx context.Context) []*AllowListItem { + v, err := alicb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (alicb *AllowListItemCreateBulk) Exec(ctx context.Context) error { + _, err := alicb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (alicb *AllowListItemCreateBulk) ExecX(ctx context.Context) { + if err := alicb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/pkg/database/ent/allowlistitem_delete.go b/pkg/database/ent/allowlistitem_delete.go new file mode 100644 index 00000000000..87b340012df --- /dev/null +++ b/pkg/database/ent/allowlistitem_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// AllowListItemDelete is the builder for deleting a AllowListItem entity. +type AllowListItemDelete struct { + config + hooks []Hook + mutation *AllowListItemMutation +} + +// Where appends a list predicates to the AllowListItemDelete builder. +func (alid *AllowListItemDelete) Where(ps ...predicate.AllowListItem) *AllowListItemDelete { + alid.mutation.Where(ps...) + return alid +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (alid *AllowListItemDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, alid.sqlExec, alid.mutation, alid.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (alid *AllowListItemDelete) ExecX(ctx context.Context) int { + n, err := alid.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (alid *AllowListItemDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(allowlistitem.Table, sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt)) + if ps := alid.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, alid.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + alid.mutation.done = true + return affected, err +} + +// AllowListItemDeleteOne is the builder for deleting a single AllowListItem entity. +type AllowListItemDeleteOne struct { + alid *AllowListItemDelete +} + +// Where appends a list predicates to the AllowListItemDelete builder. +func (alido *AllowListItemDeleteOne) Where(ps ...predicate.AllowListItem) *AllowListItemDeleteOne { + alido.alid.mutation.Where(ps...) + return alido +} + +// Exec executes the deletion query. +func (alido *AllowListItemDeleteOne) Exec(ctx context.Context) error { + n, err := alido.alid.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{allowlistitem.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (alido *AllowListItemDeleteOne) ExecX(ctx context.Context) { + if err := alido.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/pkg/database/ent/allowlistitem_query.go b/pkg/database/ent/allowlistitem_query.go new file mode 100644 index 00000000000..628b680a27c --- /dev/null +++ b/pkg/database/ent/allowlistitem_query.go @@ -0,0 +1,636 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// AllowListItemQuery is the builder for querying AllowListItem entities. +type AllowListItemQuery struct { + config + ctx *QueryContext + order []allowlistitem.OrderOption + inters []Interceptor + predicates []predicate.AllowListItem + withAllowlist *AllowListQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the AllowListItemQuery builder. +func (aliq *AllowListItemQuery) Where(ps ...predicate.AllowListItem) *AllowListItemQuery { + aliq.predicates = append(aliq.predicates, ps...) + return aliq +} + +// Limit the number of records to be returned by this query. +func (aliq *AllowListItemQuery) Limit(limit int) *AllowListItemQuery { + aliq.ctx.Limit = &limit + return aliq +} + +// Offset to start from. +func (aliq *AllowListItemQuery) Offset(offset int) *AllowListItemQuery { + aliq.ctx.Offset = &offset + return aliq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (aliq *AllowListItemQuery) Unique(unique bool) *AllowListItemQuery { + aliq.ctx.Unique = &unique + return aliq +} + +// Order specifies how the records should be ordered. +func (aliq *AllowListItemQuery) Order(o ...allowlistitem.OrderOption) *AllowListItemQuery { + aliq.order = append(aliq.order, o...) + return aliq +} + +// QueryAllowlist chains the current query on the "allowlist" edge. +func (aliq *AllowListItemQuery) QueryAllowlist() *AllowListQuery { + query := (&AllowListClient{config: aliq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := aliq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := aliq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(allowlistitem.Table, allowlistitem.FieldID, selector), + sqlgraph.To(allowlist.Table, allowlist.FieldID), + sqlgraph.Edge(sqlgraph.M2M, true, allowlistitem.AllowlistTable, allowlistitem.AllowlistPrimaryKey...), + ) + fromU = sqlgraph.SetNeighbors(aliq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first AllowListItem entity from the query. +// Returns a *NotFoundError when no AllowListItem was found. +func (aliq *AllowListItemQuery) First(ctx context.Context) (*AllowListItem, error) { + nodes, err := aliq.Limit(1).All(setContextOp(ctx, aliq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{allowlistitem.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (aliq *AllowListItemQuery) FirstX(ctx context.Context) *AllowListItem { + node, err := aliq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first AllowListItem ID from the query. +// Returns a *NotFoundError when no AllowListItem ID was found. +func (aliq *AllowListItemQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = aliq.Limit(1).IDs(setContextOp(ctx, aliq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{allowlistitem.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (aliq *AllowListItemQuery) FirstIDX(ctx context.Context) int { + id, err := aliq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single AllowListItem entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one AllowListItem entity is found. +// Returns a *NotFoundError when no AllowListItem entities are found. +func (aliq *AllowListItemQuery) Only(ctx context.Context) (*AllowListItem, error) { + nodes, err := aliq.Limit(2).All(setContextOp(ctx, aliq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{allowlistitem.Label} + default: + return nil, &NotSingularError{allowlistitem.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (aliq *AllowListItemQuery) OnlyX(ctx context.Context) *AllowListItem { + node, err := aliq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only AllowListItem ID in the query. +// Returns a *NotSingularError when more than one AllowListItem ID is found. +// Returns a *NotFoundError when no entities are found. +func (aliq *AllowListItemQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = aliq.Limit(2).IDs(setContextOp(ctx, aliq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{allowlistitem.Label} + default: + err = &NotSingularError{allowlistitem.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (aliq *AllowListItemQuery) OnlyIDX(ctx context.Context) int { + id, err := aliq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of AllowListItems. +func (aliq *AllowListItemQuery) All(ctx context.Context) ([]*AllowListItem, error) { + ctx = setContextOp(ctx, aliq.ctx, "All") + if err := aliq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*AllowListItem, *AllowListItemQuery]() + return withInterceptors[[]*AllowListItem](ctx, aliq, qr, aliq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (aliq *AllowListItemQuery) AllX(ctx context.Context) []*AllowListItem { + nodes, err := aliq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of AllowListItem IDs. +func (aliq *AllowListItemQuery) IDs(ctx context.Context) (ids []int, err error) { + if aliq.ctx.Unique == nil && aliq.path != nil { + aliq.Unique(true) + } + ctx = setContextOp(ctx, aliq.ctx, "IDs") + if err = aliq.Select(allowlistitem.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (aliq *AllowListItemQuery) IDsX(ctx context.Context) []int { + ids, err := aliq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (aliq *AllowListItemQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, aliq.ctx, "Count") + if err := aliq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, aliq, querierCount[*AllowListItemQuery](), aliq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (aliq *AllowListItemQuery) CountX(ctx context.Context) int { + count, err := aliq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (aliq *AllowListItemQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, aliq.ctx, "Exist") + switch _, err := aliq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (aliq *AllowListItemQuery) ExistX(ctx context.Context) bool { + exist, err := aliq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the AllowListItemQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (aliq *AllowListItemQuery) Clone() *AllowListItemQuery { + if aliq == nil { + return nil + } + return &AllowListItemQuery{ + config: aliq.config, + ctx: aliq.ctx.Clone(), + order: append([]allowlistitem.OrderOption{}, aliq.order...), + inters: append([]Interceptor{}, aliq.inters...), + predicates: append([]predicate.AllowListItem{}, aliq.predicates...), + withAllowlist: aliq.withAllowlist.Clone(), + // clone intermediate query. + sql: aliq.sql.Clone(), + path: aliq.path, + } +} + +// WithAllowlist tells the query-builder to eager-load the nodes that are connected to +// the "allowlist" edge. The optional arguments are used to configure the query builder of the edge. +func (aliq *AllowListItemQuery) WithAllowlist(opts ...func(*AllowListQuery)) *AllowListItemQuery { + query := (&AllowListClient{config: aliq.config}).Query() + for _, opt := range opts { + opt(query) + } + aliq.withAllowlist = query + return aliq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.AllowListItem.Query(). +// GroupBy(allowlistitem.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (aliq *AllowListItemQuery) GroupBy(field string, fields ...string) *AllowListItemGroupBy { + aliq.ctx.Fields = append([]string{field}, fields...) + grbuild := &AllowListItemGroupBy{build: aliq} + grbuild.flds = &aliq.ctx.Fields + grbuild.label = allowlistitem.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.AllowListItem.Query(). +// Select(allowlistitem.FieldCreatedAt). +// Scan(ctx, &v) +func (aliq *AllowListItemQuery) Select(fields ...string) *AllowListItemSelect { + aliq.ctx.Fields = append(aliq.ctx.Fields, fields...) + sbuild := &AllowListItemSelect{AllowListItemQuery: aliq} + sbuild.label = allowlistitem.Label + sbuild.flds, sbuild.scan = &aliq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AllowListItemSelect configured with the given aggregations. +func (aliq *AllowListItemQuery) Aggregate(fns ...AggregateFunc) *AllowListItemSelect { + return aliq.Select().Aggregate(fns...) +} + +func (aliq *AllowListItemQuery) prepareQuery(ctx context.Context) error { + for _, inter := range aliq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, aliq); err != nil { + return err + } + } + } + for _, f := range aliq.ctx.Fields { + if !allowlistitem.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if aliq.path != nil { + prev, err := aliq.path(ctx) + if err != nil { + return err + } + aliq.sql = prev + } + return nil +} + +func (aliq *AllowListItemQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AllowListItem, error) { + var ( + nodes = []*AllowListItem{} + _spec = aliq.querySpec() + loadedTypes = [1]bool{ + aliq.withAllowlist != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*AllowListItem).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &AllowListItem{config: aliq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, aliq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := aliq.withAllowlist; query != nil { + if err := aliq.loadAllowlist(ctx, query, nodes, + func(n *AllowListItem) { n.Edges.Allowlist = []*AllowList{} }, + func(n *AllowListItem, e *AllowList) { n.Edges.Allowlist = append(n.Edges.Allowlist, e) }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (aliq *AllowListItemQuery) loadAllowlist(ctx context.Context, query *AllowListQuery, nodes []*AllowListItem, init func(*AllowListItem), assign func(*AllowListItem, *AllowList)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*AllowListItem) + nids := make(map[int]map[*AllowListItem]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(allowlistitem.AllowlistTable) + s.Join(joinT).On(s.C(allowlist.FieldID), joinT.C(allowlistitem.AllowlistPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(allowlistitem.AllowlistPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(allowlistitem.AllowlistPrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + if err := query.prepareQuery(ctx); err != nil { + return err + } + qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]any, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err + } + return append([]any{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []any) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*AllowListItem]struct{}{byID[outValue]: {}} + return assign(columns[1:], values[1:]) + } + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + }) + neighbors, err := withInterceptors[[]*AllowList](ctx, query, qr, query.inters) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "allowlist" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil +} + +func (aliq *AllowListItemQuery) sqlCount(ctx context.Context) (int, error) { + _spec := aliq.querySpec() + _spec.Node.Columns = aliq.ctx.Fields + if len(aliq.ctx.Fields) > 0 { + _spec.Unique = aliq.ctx.Unique != nil && *aliq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, aliq.driver, _spec) +} + +func (aliq *AllowListItemQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(allowlistitem.Table, allowlistitem.Columns, sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt)) + _spec.From = aliq.sql + if unique := aliq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if aliq.path != nil { + _spec.Unique = true + } + if fields := aliq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, allowlistitem.FieldID) + for i := range fields { + if fields[i] != allowlistitem.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := aliq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := aliq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := aliq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := aliq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (aliq *AllowListItemQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(aliq.driver.Dialect()) + t1 := builder.Table(allowlistitem.Table) + columns := aliq.ctx.Fields + if len(columns) == 0 { + columns = allowlistitem.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if aliq.sql != nil { + selector = aliq.sql + selector.Select(selector.Columns(columns...)...) + } + if aliq.ctx.Unique != nil && *aliq.ctx.Unique { + selector.Distinct() + } + for _, p := range aliq.predicates { + p(selector) + } + for _, p := range aliq.order { + p(selector) + } + if offset := aliq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := aliq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// AllowListItemGroupBy is the group-by builder for AllowListItem entities. +type AllowListItemGroupBy struct { + selector + build *AllowListItemQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (aligb *AllowListItemGroupBy) Aggregate(fns ...AggregateFunc) *AllowListItemGroupBy { + aligb.fns = append(aligb.fns, fns...) + return aligb +} + +// Scan applies the selector query and scans the result into the given value. +func (aligb *AllowListItemGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, aligb.build.ctx, "GroupBy") + if err := aligb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AllowListItemQuery, *AllowListItemGroupBy](ctx, aligb.build, aligb, aligb.build.inters, v) +} + +func (aligb *AllowListItemGroupBy) sqlScan(ctx context.Context, root *AllowListItemQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(aligb.fns)) + for _, fn := range aligb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*aligb.flds)+len(aligb.fns)) + for _, f := range *aligb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*aligb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := aligb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// AllowListItemSelect is the builder for selecting fields of AllowListItem entities. +type AllowListItemSelect struct { + *AllowListItemQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (alis *AllowListItemSelect) Aggregate(fns ...AggregateFunc) *AllowListItemSelect { + alis.fns = append(alis.fns, fns...) + return alis +} + +// Scan applies the selector query and scans the result into the given value. +func (alis *AllowListItemSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, alis.ctx, "Select") + if err := alis.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AllowListItemQuery, *AllowListItemSelect](ctx, alis.AllowListItemQuery, alis, alis.inters, v) +} + +func (alis *AllowListItemSelect) sqlScan(ctx context.Context, root *AllowListItemQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(alis.fns)) + for _, fn := range alis.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*alis.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := alis.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/pkg/database/ent/allowlistitem_update.go b/pkg/database/ent/allowlistitem_update.go new file mode 100644 index 00000000000..e6878955afe --- /dev/null +++ b/pkg/database/ent/allowlistitem_update.go @@ -0,0 +1,463 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" +) + +// AllowListItemUpdate is the builder for updating AllowListItem entities. +type AllowListItemUpdate struct { + config + hooks []Hook + mutation *AllowListItemMutation +} + +// Where appends a list predicates to the AllowListItemUpdate builder. +func (aliu *AllowListItemUpdate) Where(ps ...predicate.AllowListItem) *AllowListItemUpdate { + aliu.mutation.Where(ps...) + return aliu +} + +// SetUpdatedAt sets the "updated_at" field. +func (aliu *AllowListItemUpdate) SetUpdatedAt(t time.Time) *AllowListItemUpdate { + aliu.mutation.SetUpdatedAt(t) + return aliu +} + +// SetExpiresAt sets the "expires_at" field. +func (aliu *AllowListItemUpdate) SetExpiresAt(t time.Time) *AllowListItemUpdate { + aliu.mutation.SetExpiresAt(t) + return aliu +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (aliu *AllowListItemUpdate) SetNillableExpiresAt(t *time.Time) *AllowListItemUpdate { + if t != nil { + aliu.SetExpiresAt(*t) + } + return aliu +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (aliu *AllowListItemUpdate) ClearExpiresAt() *AllowListItemUpdate { + aliu.mutation.ClearExpiresAt() + return aliu +} + +// AddAllowlistIDs adds the "allowlist" edge to the AllowList entity by IDs. +func (aliu *AllowListItemUpdate) AddAllowlistIDs(ids ...int) *AllowListItemUpdate { + aliu.mutation.AddAllowlistIDs(ids...) + return aliu +} + +// AddAllowlist adds the "allowlist" edges to the AllowList entity. +func (aliu *AllowListItemUpdate) AddAllowlist(a ...*AllowList) *AllowListItemUpdate { + ids := make([]int, len(a)) + for i := range a { + ids[i] = a[i].ID + } + return aliu.AddAllowlistIDs(ids...) +} + +// Mutation returns the AllowListItemMutation object of the builder. +func (aliu *AllowListItemUpdate) Mutation() *AllowListItemMutation { + return aliu.mutation +} + +// ClearAllowlist clears all "allowlist" edges to the AllowList entity. +func (aliu *AllowListItemUpdate) ClearAllowlist() *AllowListItemUpdate { + aliu.mutation.ClearAllowlist() + return aliu +} + +// RemoveAllowlistIDs removes the "allowlist" edge to AllowList entities by IDs. +func (aliu *AllowListItemUpdate) RemoveAllowlistIDs(ids ...int) *AllowListItemUpdate { + aliu.mutation.RemoveAllowlistIDs(ids...) + return aliu +} + +// RemoveAllowlist removes "allowlist" edges to AllowList entities. +func (aliu *AllowListItemUpdate) RemoveAllowlist(a ...*AllowList) *AllowListItemUpdate { + ids := make([]int, len(a)) + for i := range a { + ids[i] = a[i].ID + } + return aliu.RemoveAllowlistIDs(ids...) +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (aliu *AllowListItemUpdate) Save(ctx context.Context) (int, error) { + aliu.defaults() + return withHooks(ctx, aliu.sqlSave, aliu.mutation, aliu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (aliu *AllowListItemUpdate) SaveX(ctx context.Context) int { + affected, err := aliu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (aliu *AllowListItemUpdate) Exec(ctx context.Context) error { + _, err := aliu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (aliu *AllowListItemUpdate) ExecX(ctx context.Context) { + if err := aliu.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (aliu *AllowListItemUpdate) defaults() { + if _, ok := aliu.mutation.UpdatedAt(); !ok { + v := allowlistitem.UpdateDefaultUpdatedAt() + aliu.mutation.SetUpdatedAt(v) + } +} + +func (aliu *AllowListItemUpdate) sqlSave(ctx context.Context) (n int, err error) { + _spec := sqlgraph.NewUpdateSpec(allowlistitem.Table, allowlistitem.Columns, sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt)) + if ps := aliu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := aliu.mutation.UpdatedAt(); ok { + _spec.SetField(allowlistitem.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := aliu.mutation.ExpiresAt(); ok { + _spec.SetField(allowlistitem.FieldExpiresAt, field.TypeTime, value) + } + if aliu.mutation.ExpiresAtCleared() { + _spec.ClearField(allowlistitem.FieldExpiresAt, field.TypeTime) + } + if aliu.mutation.CommentCleared() { + _spec.ClearField(allowlistitem.FieldComment, field.TypeString) + } + if aliu.mutation.StartIPCleared() { + _spec.ClearField(allowlistitem.FieldStartIP, field.TypeInt64) + } + if aliu.mutation.EndIPCleared() { + _spec.ClearField(allowlistitem.FieldEndIP, field.TypeInt64) + } + if aliu.mutation.StartSuffixCleared() { + _spec.ClearField(allowlistitem.FieldStartSuffix, field.TypeInt64) + } + if aliu.mutation.EndSuffixCleared() { + _spec.ClearField(allowlistitem.FieldEndSuffix, field.TypeInt64) + } + if aliu.mutation.IPSizeCleared() { + _spec.ClearField(allowlistitem.FieldIPSize, field.TypeInt64) + } + if aliu.mutation.AllowlistCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: allowlistitem.AllowlistTable, + Columns: allowlistitem.AllowlistPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := aliu.mutation.RemovedAllowlistIDs(); len(nodes) > 0 && !aliu.mutation.AllowlistCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: allowlistitem.AllowlistTable, + Columns: allowlistitem.AllowlistPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := aliu.mutation.AllowlistIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: allowlistitem.AllowlistTable, + Columns: allowlistitem.AllowlistPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, aliu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{allowlistitem.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + aliu.mutation.done = true + return n, nil +} + +// AllowListItemUpdateOne is the builder for updating a single AllowListItem entity. +type AllowListItemUpdateOne struct { + config + fields []string + hooks []Hook + mutation *AllowListItemMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (aliuo *AllowListItemUpdateOne) SetUpdatedAt(t time.Time) *AllowListItemUpdateOne { + aliuo.mutation.SetUpdatedAt(t) + return aliuo +} + +// SetExpiresAt sets the "expires_at" field. +func (aliuo *AllowListItemUpdateOne) SetExpiresAt(t time.Time) *AllowListItemUpdateOne { + aliuo.mutation.SetExpiresAt(t) + return aliuo +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (aliuo *AllowListItemUpdateOne) SetNillableExpiresAt(t *time.Time) *AllowListItemUpdateOne { + if t != nil { + aliuo.SetExpiresAt(*t) + } + return aliuo +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (aliuo *AllowListItemUpdateOne) ClearExpiresAt() *AllowListItemUpdateOne { + aliuo.mutation.ClearExpiresAt() + return aliuo +} + +// AddAllowlistIDs adds the "allowlist" edge to the AllowList entity by IDs. +func (aliuo *AllowListItemUpdateOne) AddAllowlistIDs(ids ...int) *AllowListItemUpdateOne { + aliuo.mutation.AddAllowlistIDs(ids...) + return aliuo +} + +// AddAllowlist adds the "allowlist" edges to the AllowList entity. +func (aliuo *AllowListItemUpdateOne) AddAllowlist(a ...*AllowList) *AllowListItemUpdateOne { + ids := make([]int, len(a)) + for i := range a { + ids[i] = a[i].ID + } + return aliuo.AddAllowlistIDs(ids...) +} + +// Mutation returns the AllowListItemMutation object of the builder. +func (aliuo *AllowListItemUpdateOne) Mutation() *AllowListItemMutation { + return aliuo.mutation +} + +// ClearAllowlist clears all "allowlist" edges to the AllowList entity. +func (aliuo *AllowListItemUpdateOne) ClearAllowlist() *AllowListItemUpdateOne { + aliuo.mutation.ClearAllowlist() + return aliuo +} + +// RemoveAllowlistIDs removes the "allowlist" edge to AllowList entities by IDs. +func (aliuo *AllowListItemUpdateOne) RemoveAllowlistIDs(ids ...int) *AllowListItemUpdateOne { + aliuo.mutation.RemoveAllowlistIDs(ids...) + return aliuo +} + +// RemoveAllowlist removes "allowlist" edges to AllowList entities. +func (aliuo *AllowListItemUpdateOne) RemoveAllowlist(a ...*AllowList) *AllowListItemUpdateOne { + ids := make([]int, len(a)) + for i := range a { + ids[i] = a[i].ID + } + return aliuo.RemoveAllowlistIDs(ids...) +} + +// Where appends a list predicates to the AllowListItemUpdate builder. +func (aliuo *AllowListItemUpdateOne) Where(ps ...predicate.AllowListItem) *AllowListItemUpdateOne { + aliuo.mutation.Where(ps...) + return aliuo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (aliuo *AllowListItemUpdateOne) Select(field string, fields ...string) *AllowListItemUpdateOne { + aliuo.fields = append([]string{field}, fields...) + return aliuo +} + +// Save executes the query and returns the updated AllowListItem entity. +func (aliuo *AllowListItemUpdateOne) Save(ctx context.Context) (*AllowListItem, error) { + aliuo.defaults() + return withHooks(ctx, aliuo.sqlSave, aliuo.mutation, aliuo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (aliuo *AllowListItemUpdateOne) SaveX(ctx context.Context) *AllowListItem { + node, err := aliuo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (aliuo *AllowListItemUpdateOne) Exec(ctx context.Context) error { + _, err := aliuo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (aliuo *AllowListItemUpdateOne) ExecX(ctx context.Context) { + if err := aliuo.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (aliuo *AllowListItemUpdateOne) defaults() { + if _, ok := aliuo.mutation.UpdatedAt(); !ok { + v := allowlistitem.UpdateDefaultUpdatedAt() + aliuo.mutation.SetUpdatedAt(v) + } +} + +func (aliuo *AllowListItemUpdateOne) sqlSave(ctx context.Context) (_node *AllowListItem, err error) { + _spec := sqlgraph.NewUpdateSpec(allowlistitem.Table, allowlistitem.Columns, sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt)) + id, ok := aliuo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AllowListItem.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := aliuo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, allowlistitem.FieldID) + for _, f := range fields { + if !allowlistitem.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != allowlistitem.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := aliuo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := aliuo.mutation.UpdatedAt(); ok { + _spec.SetField(allowlistitem.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := aliuo.mutation.ExpiresAt(); ok { + _spec.SetField(allowlistitem.FieldExpiresAt, field.TypeTime, value) + } + if aliuo.mutation.ExpiresAtCleared() { + _spec.ClearField(allowlistitem.FieldExpiresAt, field.TypeTime) + } + if aliuo.mutation.CommentCleared() { + _spec.ClearField(allowlistitem.FieldComment, field.TypeString) + } + if aliuo.mutation.StartIPCleared() { + _spec.ClearField(allowlistitem.FieldStartIP, field.TypeInt64) + } + if aliuo.mutation.EndIPCleared() { + _spec.ClearField(allowlistitem.FieldEndIP, field.TypeInt64) + } + if aliuo.mutation.StartSuffixCleared() { + _spec.ClearField(allowlistitem.FieldStartSuffix, field.TypeInt64) + } + if aliuo.mutation.EndSuffixCleared() { + _spec.ClearField(allowlistitem.FieldEndSuffix, field.TypeInt64) + } + if aliuo.mutation.IPSizeCleared() { + _spec.ClearField(allowlistitem.FieldIPSize, field.TypeInt64) + } + if aliuo.mutation.AllowlistCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: allowlistitem.AllowlistTable, + Columns: allowlistitem.AllowlistPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := aliuo.mutation.RemovedAllowlistIDs(); len(nodes) > 0 && !aliuo.mutation.AllowlistCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: allowlistitem.AllowlistTable, + Columns: allowlistitem.AllowlistPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := aliuo.mutation.AllowlistIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: true, + Table: allowlistitem.AllowlistTable, + Columns: allowlistitem.AllowlistPrimaryKey, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &AllowListItem{config: aliuo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, aliuo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{allowlistitem.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + aliuo.mutation.done = true + return _node, nil +} diff --git a/pkg/database/ent/bouncer.go b/pkg/database/ent/bouncer.go index 3b4d619e384..197f61cde19 100644 --- a/pkg/database/ent/bouncer.go +++ b/pkg/database/ent/bouncer.go @@ -43,6 +43,8 @@ type Bouncer struct { Osversion string `json:"osversion,omitempty"` // Featureflags holds the value of the "featureflags" field. Featureflags string `json:"featureflags,omitempty"` + // AutoCreated holds the value of the "auto_created" field. + AutoCreated bool `json:"auto_created"` selectValues sql.SelectValues } @@ -51,7 +53,7 @@ func (*Bouncer) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case bouncer.FieldRevoked: + case bouncer.FieldRevoked, bouncer.FieldAutoCreated: values[i] = new(sql.NullBool) case bouncer.FieldID: values[i] = new(sql.NullInt64) @@ -159,6 +161,12 @@ func (b *Bouncer) assignValues(columns []string, values []any) error { } else if value.Valid { b.Featureflags = value.String } + case bouncer.FieldAutoCreated: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field auto_created", values[i]) + } else if value.Valid { + b.AutoCreated = value.Bool + } default: b.selectValues.Set(columns[i], values[i]) } @@ -234,6 +242,9 @@ func (b *Bouncer) String() string { builder.WriteString(", ") builder.WriteString("featureflags=") builder.WriteString(b.Featureflags) + builder.WriteString(", ") + builder.WriteString("auto_created=") + builder.WriteString(fmt.Sprintf("%v", b.AutoCreated)) builder.WriteByte(')') return builder.String() } diff --git a/pkg/database/ent/bouncer/bouncer.go b/pkg/database/ent/bouncer/bouncer.go index a6f62aeadd5..f25b5a5815a 100644 --- a/pkg/database/ent/bouncer/bouncer.go +++ b/pkg/database/ent/bouncer/bouncer.go @@ -39,6 +39,8 @@ const ( FieldOsversion = "osversion" // FieldFeatureflags holds the string denoting the featureflags field in the database. FieldFeatureflags = "featureflags" + // FieldAutoCreated holds the string denoting the auto_created field in the database. + FieldAutoCreated = "auto_created" // Table holds the table name of the bouncer in the database. Table = "bouncers" ) @@ -59,6 +61,7 @@ var Columns = []string{ FieldOsname, FieldOsversion, FieldFeatureflags, + FieldAutoCreated, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -82,6 +85,8 @@ var ( DefaultIPAddress string // DefaultAuthType holds the default value on creation for the "auth_type" field. DefaultAuthType string + // DefaultAutoCreated holds the default value on creation for the "auto_created" field. + DefaultAutoCreated bool ) // OrderOption defines the ordering options for the Bouncer queries. @@ -156,3 +161,8 @@ func ByOsversion(opts ...sql.OrderTermOption) OrderOption { func ByFeatureflags(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldFeatureflags, opts...).ToFunc() } + +// ByAutoCreated orders the results by the auto_created field. +func ByAutoCreated(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAutoCreated, opts...).ToFunc() +} diff --git a/pkg/database/ent/bouncer/where.go b/pkg/database/ent/bouncer/where.go index e02199bc0a9..79b8999354f 100644 --- a/pkg/database/ent/bouncer/where.go +++ b/pkg/database/ent/bouncer/where.go @@ -119,6 +119,11 @@ func Featureflags(v string) predicate.Bouncer { return predicate.Bouncer(sql.FieldEQ(FieldFeatureflags, v)) } +// AutoCreated applies equality check predicate on the "auto_created" field. It's identical to AutoCreatedEQ. +func AutoCreated(v bool) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldAutoCreated, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Bouncer { return predicate.Bouncer(sql.FieldEQ(FieldCreatedAt, v)) @@ -904,6 +909,16 @@ func FeatureflagsContainsFold(v string) predicate.Bouncer { return predicate.Bouncer(sql.FieldContainsFold(FieldFeatureflags, v)) } +// AutoCreatedEQ applies the EQ predicate on the "auto_created" field. +func AutoCreatedEQ(v bool) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldAutoCreated, v)) +} + +// AutoCreatedNEQ applies the NEQ predicate on the "auto_created" field. +func AutoCreatedNEQ(v bool) predicate.Bouncer { + return predicate.Bouncer(sql.FieldNEQ(FieldAutoCreated, v)) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.Bouncer) predicate.Bouncer { return predicate.Bouncer(sql.AndPredicates(predicates...)) diff --git a/pkg/database/ent/bouncer_create.go b/pkg/database/ent/bouncer_create.go index 29b23f87cf1..9ff4c0e0820 100644 --- a/pkg/database/ent/bouncer_create.go +++ b/pkg/database/ent/bouncer_create.go @@ -178,6 +178,20 @@ func (bc *BouncerCreate) SetNillableFeatureflags(s *string) *BouncerCreate { return bc } +// SetAutoCreated sets the "auto_created" field. +func (bc *BouncerCreate) SetAutoCreated(b bool) *BouncerCreate { + bc.mutation.SetAutoCreated(b) + return bc +} + +// SetNillableAutoCreated sets the "auto_created" field if the given value is not nil. +func (bc *BouncerCreate) SetNillableAutoCreated(b *bool) *BouncerCreate { + if b != nil { + bc.SetAutoCreated(*b) + } + return bc +} + // Mutation returns the BouncerMutation object of the builder. func (bc *BouncerCreate) Mutation() *BouncerMutation { return bc.mutation @@ -229,6 +243,10 @@ func (bc *BouncerCreate) defaults() { v := bouncer.DefaultAuthType bc.mutation.SetAuthType(v) } + if _, ok := bc.mutation.AutoCreated(); !ok { + v := bouncer.DefaultAutoCreated + bc.mutation.SetAutoCreated(v) + } } // check runs all checks and user-defined validators on the builder. @@ -251,6 +269,9 @@ func (bc *BouncerCreate) check() error { if _, ok := bc.mutation.AuthType(); !ok { return &ValidationError{Name: "auth_type", err: errors.New(`ent: missing required field "Bouncer.auth_type"`)} } + if _, ok := bc.mutation.AutoCreated(); !ok { + return &ValidationError{Name: "auto_created", err: errors.New(`ent: missing required field "Bouncer.auto_created"`)} + } return nil } @@ -329,6 +350,10 @@ func (bc *BouncerCreate) createSpec() (*Bouncer, *sqlgraph.CreateSpec) { _spec.SetField(bouncer.FieldFeatureflags, field.TypeString, value) _node.Featureflags = value } + if value, ok := bc.mutation.AutoCreated(); ok { + _spec.SetField(bouncer.FieldAutoCreated, field.TypeBool, value) + _node.AutoCreated = value + } return _node, _spec } diff --git a/pkg/database/ent/client.go b/pkg/database/ent/client.go index 59686102ebe..bc7c0330459 100644 --- a/pkg/database/ent/client.go +++ b/pkg/database/ent/client.go @@ -16,6 +16,8 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" @@ -33,6 +35,10 @@ type Client struct { Schema *migrate.Schema // Alert is the client for interacting with the Alert builders. Alert *AlertClient + // AllowList is the client for interacting with the AllowList builders. + AllowList *AllowListClient + // AllowListItem is the client for interacting with the AllowListItem builders. + AllowListItem *AllowListItemClient // Bouncer is the client for interacting with the Bouncer builders. Bouncer *BouncerClient // ConfigItem is the client for interacting with the ConfigItem builders. @@ -61,6 +67,8 @@ func NewClient(opts ...Option) *Client { func (c *Client) init() { c.Schema = migrate.NewSchema(c.driver) c.Alert = NewAlertClient(c.config) + c.AllowList = NewAllowListClient(c.config) + c.AllowListItem = NewAllowListItemClient(c.config) c.Bouncer = NewBouncerClient(c.config) c.ConfigItem = NewConfigItemClient(c.config) c.Decision = NewDecisionClient(c.config) @@ -159,17 +167,19 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { cfg := c.config cfg.driver = tx return &Tx{ - ctx: ctx, - config: cfg, - Alert: NewAlertClient(cfg), - Bouncer: NewBouncerClient(cfg), - ConfigItem: NewConfigItemClient(cfg), - Decision: NewDecisionClient(cfg), - Event: NewEventClient(cfg), - Lock: NewLockClient(cfg), - Machine: NewMachineClient(cfg), - Meta: NewMetaClient(cfg), - Metric: NewMetricClient(cfg), + ctx: ctx, + config: cfg, + Alert: NewAlertClient(cfg), + AllowList: NewAllowListClient(cfg), + AllowListItem: NewAllowListItemClient(cfg), + Bouncer: NewBouncerClient(cfg), + ConfigItem: NewConfigItemClient(cfg), + Decision: NewDecisionClient(cfg), + Event: NewEventClient(cfg), + Lock: NewLockClient(cfg), + Machine: NewMachineClient(cfg), + Meta: NewMetaClient(cfg), + Metric: NewMetricClient(cfg), }, nil } @@ -187,17 +197,19 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) cfg := c.config cfg.driver = &txDriver{tx: tx, drv: c.driver} return &Tx{ - ctx: ctx, - config: cfg, - Alert: NewAlertClient(cfg), - Bouncer: NewBouncerClient(cfg), - ConfigItem: NewConfigItemClient(cfg), - Decision: NewDecisionClient(cfg), - Event: NewEventClient(cfg), - Lock: NewLockClient(cfg), - Machine: NewMachineClient(cfg), - Meta: NewMetaClient(cfg), - Metric: NewMetricClient(cfg), + ctx: ctx, + config: cfg, + Alert: NewAlertClient(cfg), + AllowList: NewAllowListClient(cfg), + AllowListItem: NewAllowListItemClient(cfg), + Bouncer: NewBouncerClient(cfg), + ConfigItem: NewConfigItemClient(cfg), + Decision: NewDecisionClient(cfg), + Event: NewEventClient(cfg), + Lock: NewLockClient(cfg), + Machine: NewMachineClient(cfg), + Meta: NewMetaClient(cfg), + Metric: NewMetricClient(cfg), }, nil } @@ -227,8 +239,8 @@ func (c *Client) Close() error { // In order to add hooks to a specific client, call: `client.Node.Use(...)`. func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ - c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Lock, c.Machine, - c.Meta, c.Metric, + c.Alert, c.AllowList, c.AllowListItem, c.Bouncer, c.ConfigItem, c.Decision, + c.Event, c.Lock, c.Machine, c.Meta, c.Metric, } { n.Use(hooks...) } @@ -238,8 +250,8 @@ func (c *Client) Use(hooks ...Hook) { // In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`. func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ - c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Lock, c.Machine, - c.Meta, c.Metric, + c.Alert, c.AllowList, c.AllowListItem, c.Bouncer, c.ConfigItem, c.Decision, + c.Event, c.Lock, c.Machine, c.Meta, c.Metric, } { n.Intercept(interceptors...) } @@ -250,6 +262,10 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { switch m := m.(type) { case *AlertMutation: return c.Alert.mutate(ctx, m) + case *AllowListMutation: + return c.AllowList.mutate(ctx, m) + case *AllowListItemMutation: + return c.AllowListItem.mutate(ctx, m) case *BouncerMutation: return c.Bouncer.mutate(ctx, m) case *ConfigItemMutation: @@ -468,6 +484,304 @@ func (c *AlertClient) mutate(ctx context.Context, m *AlertMutation) (Value, erro } } +// AllowListClient is a client for the AllowList schema. +type AllowListClient struct { + config +} + +// NewAllowListClient returns a client for the AllowList from the given config. +func NewAllowListClient(c config) *AllowListClient { + return &AllowListClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `allowlist.Hooks(f(g(h())))`. +func (c *AllowListClient) Use(hooks ...Hook) { + c.hooks.AllowList = append(c.hooks.AllowList, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `allowlist.Intercept(f(g(h())))`. +func (c *AllowListClient) Intercept(interceptors ...Interceptor) { + c.inters.AllowList = append(c.inters.AllowList, interceptors...) +} + +// Create returns a builder for creating a AllowList entity. +func (c *AllowListClient) Create() *AllowListCreate { + mutation := newAllowListMutation(c.config, OpCreate) + return &AllowListCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of AllowList entities. +func (c *AllowListClient) CreateBulk(builders ...*AllowListCreate) *AllowListCreateBulk { + return &AllowListCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AllowListClient) MapCreateBulk(slice any, setFunc func(*AllowListCreate, int)) *AllowListCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AllowListCreateBulk{err: fmt.Errorf("calling to AllowListClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AllowListCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AllowListCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for AllowList. +func (c *AllowListClient) Update() *AllowListUpdate { + mutation := newAllowListMutation(c.config, OpUpdate) + return &AllowListUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *AllowListClient) UpdateOne(al *AllowList) *AllowListUpdateOne { + mutation := newAllowListMutation(c.config, OpUpdateOne, withAllowList(al)) + return &AllowListUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *AllowListClient) UpdateOneID(id int) *AllowListUpdateOne { + mutation := newAllowListMutation(c.config, OpUpdateOne, withAllowListID(id)) + return &AllowListUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for AllowList. +func (c *AllowListClient) Delete() *AllowListDelete { + mutation := newAllowListMutation(c.config, OpDelete) + return &AllowListDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *AllowListClient) DeleteOne(al *AllowList) *AllowListDeleteOne { + return c.DeleteOneID(al.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *AllowListClient) DeleteOneID(id int) *AllowListDeleteOne { + builder := c.Delete().Where(allowlist.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &AllowListDeleteOne{builder} +} + +// Query returns a query builder for AllowList. +func (c *AllowListClient) Query() *AllowListQuery { + return &AllowListQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeAllowList}, + inters: c.Interceptors(), + } +} + +// Get returns a AllowList entity by its id. +func (c *AllowListClient) Get(ctx context.Context, id int) (*AllowList, error) { + return c.Query().Where(allowlist.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *AllowListClient) GetX(ctx context.Context, id int) *AllowList { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryAllowlistItems queries the allowlist_items edge of a AllowList. +func (c *AllowListClient) QueryAllowlistItems(al *AllowList) *AllowListItemQuery { + query := (&AllowListItemClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := al.ID + step := sqlgraph.NewStep( + sqlgraph.From(allowlist.Table, allowlist.FieldID, id), + sqlgraph.To(allowlistitem.Table, allowlistitem.FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, allowlist.AllowlistItemsTable, allowlist.AllowlistItemsPrimaryKey...), + ) + fromV = sqlgraph.Neighbors(al.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *AllowListClient) Hooks() []Hook { + return c.hooks.AllowList +} + +// Interceptors returns the client interceptors. +func (c *AllowListClient) Interceptors() []Interceptor { + return c.inters.AllowList +} + +func (c *AllowListClient) mutate(ctx context.Context, m *AllowListMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AllowListCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AllowListUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AllowListUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AllowListDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown AllowList mutation op: %q", m.Op()) + } +} + +// AllowListItemClient is a client for the AllowListItem schema. +type AllowListItemClient struct { + config +} + +// NewAllowListItemClient returns a client for the AllowListItem from the given config. +func NewAllowListItemClient(c config) *AllowListItemClient { + return &AllowListItemClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `allowlistitem.Hooks(f(g(h())))`. +func (c *AllowListItemClient) Use(hooks ...Hook) { + c.hooks.AllowListItem = append(c.hooks.AllowListItem, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `allowlistitem.Intercept(f(g(h())))`. +func (c *AllowListItemClient) Intercept(interceptors ...Interceptor) { + c.inters.AllowListItem = append(c.inters.AllowListItem, interceptors...) +} + +// Create returns a builder for creating a AllowListItem entity. +func (c *AllowListItemClient) Create() *AllowListItemCreate { + mutation := newAllowListItemMutation(c.config, OpCreate) + return &AllowListItemCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of AllowListItem entities. +func (c *AllowListItemClient) CreateBulk(builders ...*AllowListItemCreate) *AllowListItemCreateBulk { + return &AllowListItemCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AllowListItemClient) MapCreateBulk(slice any, setFunc func(*AllowListItemCreate, int)) *AllowListItemCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AllowListItemCreateBulk{err: fmt.Errorf("calling to AllowListItemClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AllowListItemCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AllowListItemCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for AllowListItem. +func (c *AllowListItemClient) Update() *AllowListItemUpdate { + mutation := newAllowListItemMutation(c.config, OpUpdate) + return &AllowListItemUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *AllowListItemClient) UpdateOne(ali *AllowListItem) *AllowListItemUpdateOne { + mutation := newAllowListItemMutation(c.config, OpUpdateOne, withAllowListItem(ali)) + return &AllowListItemUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *AllowListItemClient) UpdateOneID(id int) *AllowListItemUpdateOne { + mutation := newAllowListItemMutation(c.config, OpUpdateOne, withAllowListItemID(id)) + return &AllowListItemUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for AllowListItem. +func (c *AllowListItemClient) Delete() *AllowListItemDelete { + mutation := newAllowListItemMutation(c.config, OpDelete) + return &AllowListItemDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *AllowListItemClient) DeleteOne(ali *AllowListItem) *AllowListItemDeleteOne { + return c.DeleteOneID(ali.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *AllowListItemClient) DeleteOneID(id int) *AllowListItemDeleteOne { + builder := c.Delete().Where(allowlistitem.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &AllowListItemDeleteOne{builder} +} + +// Query returns a query builder for AllowListItem. +func (c *AllowListItemClient) Query() *AllowListItemQuery { + return &AllowListItemQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeAllowListItem}, + inters: c.Interceptors(), + } +} + +// Get returns a AllowListItem entity by its id. +func (c *AllowListItemClient) Get(ctx context.Context, id int) (*AllowListItem, error) { + return c.Query().Where(allowlistitem.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *AllowListItemClient) GetX(ctx context.Context, id int) *AllowListItem { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryAllowlist queries the allowlist edge of a AllowListItem. +func (c *AllowListItemClient) QueryAllowlist(ali *AllowListItem) *AllowListQuery { + query := (&AllowListClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := ali.ID + step := sqlgraph.NewStep( + sqlgraph.From(allowlistitem.Table, allowlistitem.FieldID, id), + sqlgraph.To(allowlist.Table, allowlist.FieldID), + sqlgraph.Edge(sqlgraph.M2M, true, allowlistitem.AllowlistTable, allowlistitem.AllowlistPrimaryKey...), + ) + fromV = sqlgraph.Neighbors(ali.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *AllowListItemClient) Hooks() []Hook { + return c.hooks.AllowListItem +} + +// Interceptors returns the client interceptors. +func (c *AllowListItemClient) Interceptors() []Interceptor { + return c.inters.AllowListItem +} + +func (c *AllowListItemClient) mutate(ctx context.Context, m *AllowListItemMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AllowListItemCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AllowListItemUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AllowListItemUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AllowListItemDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown AllowListItem mutation op: %q", m.Op()) + } +} + // BouncerClient is a client for the Bouncer schema. type BouncerClient struct { config @@ -1599,11 +1913,11 @@ func (c *MetricClient) mutate(ctx context.Context, m *MetricMutation) (Value, er // hooks and interceptors per client, for fast access. type ( hooks struct { - Alert, Bouncer, ConfigItem, Decision, Event, Lock, Machine, Meta, - Metric []ent.Hook + Alert, AllowList, AllowListItem, Bouncer, ConfigItem, Decision, Event, Lock, + Machine, Meta, Metric []ent.Hook } inters struct { - Alert, Bouncer, ConfigItem, Decision, Event, Lock, Machine, Meta, - Metric []ent.Interceptor + Alert, AllowList, AllowListItem, Bouncer, ConfigItem, Decision, Event, Lock, + Machine, Meta, Metric []ent.Interceptor } ) diff --git a/pkg/database/ent/ent.go b/pkg/database/ent/ent.go index 2a5ad188197..e38db54aa59 100644 --- a/pkg/database/ent/ent.go +++ b/pkg/database/ent/ent.go @@ -13,6 +13,8 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" @@ -81,15 +83,17 @@ var ( func checkColumn(table, column string) error { initCheck.Do(func() { columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ - alert.Table: alert.ValidColumn, - bouncer.Table: bouncer.ValidColumn, - configitem.Table: configitem.ValidColumn, - decision.Table: decision.ValidColumn, - event.Table: event.ValidColumn, - lock.Table: lock.ValidColumn, - machine.Table: machine.ValidColumn, - meta.Table: meta.ValidColumn, - metric.Table: metric.ValidColumn, + alert.Table: alert.ValidColumn, + allowlist.Table: allowlist.ValidColumn, + allowlistitem.Table: allowlistitem.ValidColumn, + bouncer.Table: bouncer.ValidColumn, + configitem.Table: configitem.ValidColumn, + decision.Table: decision.ValidColumn, + event.Table: event.ValidColumn, + lock.Table: lock.ValidColumn, + machine.Table: machine.ValidColumn, + meta.Table: meta.ValidColumn, + metric.Table: metric.ValidColumn, }) }) return columnCheck(table, column) diff --git a/pkg/database/ent/hook/hook.go b/pkg/database/ent/hook/hook.go index 62cc07820d0..b5ddfc81290 100644 --- a/pkg/database/ent/hook/hook.go +++ b/pkg/database/ent/hook/hook.go @@ -21,6 +21,30 @@ func (f AlertFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AlertMutation", m) } +// The AllowListFunc type is an adapter to allow the use of ordinary +// function as AllowList mutator. +type AllowListFunc func(context.Context, *ent.AllowListMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f AllowListFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.AllowListMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AllowListMutation", m) +} + +// The AllowListItemFunc type is an adapter to allow the use of ordinary +// function as AllowListItem mutator. +type AllowListItemFunc func(context.Context, *ent.AllowListItemMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f AllowListItemFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.AllowListItemMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AllowListItemMutation", m) +} + // The BouncerFunc type is an adapter to allow the use of ordinary // function as Bouncer mutator. type BouncerFunc func(context.Context, *ent.BouncerMutation) (ent.Value, error) diff --git a/pkg/database/ent/migrate/schema.go b/pkg/database/ent/migrate/schema.go index 986f5bc8c67..932c27dd7a6 100644 --- a/pkg/database/ent/migrate/schema.go +++ b/pkg/database/ent/migrate/schema.go @@ -58,6 +58,66 @@ var ( }, }, } + // AllowListsColumns holds the columns for the "allow_lists" table. + AllowListsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, + {Name: "name", Type: field.TypeString}, + {Name: "from_console", Type: field.TypeBool}, + {Name: "description", Type: field.TypeString, Nullable: true}, + {Name: "allowlist_id", Type: field.TypeString, Nullable: true}, + } + // AllowListsTable holds the schema information for the "allow_lists" table. + AllowListsTable = &schema.Table{ + Name: "allow_lists", + Columns: AllowListsColumns, + PrimaryKey: []*schema.Column{AllowListsColumns[0]}, + Indexes: []*schema.Index{ + { + Name: "allowlist_id", + Unique: true, + Columns: []*schema.Column{AllowListsColumns[0]}, + }, + { + Name: "allowlist_name", + Unique: true, + Columns: []*schema.Column{AllowListsColumns[3]}, + }, + }, + } + // AllowListItemsColumns holds the columns for the "allow_list_items" table. + AllowListItemsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "created_at", Type: field.TypeTime}, + {Name: "updated_at", Type: field.TypeTime}, + {Name: "expires_at", Type: field.TypeTime, Nullable: true}, + {Name: "comment", Type: field.TypeString, Nullable: true}, + {Name: "value", Type: field.TypeString}, + {Name: "start_ip", Type: field.TypeInt64, Nullable: true}, + {Name: "end_ip", Type: field.TypeInt64, Nullable: true}, + {Name: "start_suffix", Type: field.TypeInt64, Nullable: true}, + {Name: "end_suffix", Type: field.TypeInt64, Nullable: true}, + {Name: "ip_size", Type: field.TypeInt64, Nullable: true}, + } + // AllowListItemsTable holds the schema information for the "allow_list_items" table. + AllowListItemsTable = &schema.Table{ + Name: "allow_list_items", + Columns: AllowListItemsColumns, + PrimaryKey: []*schema.Column{AllowListItemsColumns[0]}, + Indexes: []*schema.Index{ + { + Name: "allowlistitem_id", + Unique: false, + Columns: []*schema.Column{AllowListItemsColumns[0]}, + }, + { + Name: "allowlistitem_start_ip_end_ip", + Unique: false, + Columns: []*schema.Column{AllowListItemsColumns[6], AllowListItemsColumns[7]}, + }, + }, + } // BouncersColumns holds the columns for the "bouncers" table. BouncersColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, @@ -74,6 +134,7 @@ var ( {Name: "osname", Type: field.TypeString, Nullable: true}, {Name: "osversion", Type: field.TypeString, Nullable: true}, {Name: "featureflags", Type: field.TypeString, Nullable: true}, + {Name: "auto_created", Type: field.TypeBool, Default: false}, } // BouncersTable holds the schema information for the "bouncers" table. BouncersTable = &schema.Table{ @@ -264,9 +325,36 @@ var ( Columns: MetricsColumns, PrimaryKey: []*schema.Column{MetricsColumns[0]}, } + // AllowListAllowlistItemsColumns holds the columns for the "allow_list_allowlist_items" table. + AllowListAllowlistItemsColumns = []*schema.Column{ + {Name: "allow_list_id", Type: field.TypeInt}, + {Name: "allow_list_item_id", Type: field.TypeInt}, + } + // AllowListAllowlistItemsTable holds the schema information for the "allow_list_allowlist_items" table. + AllowListAllowlistItemsTable = &schema.Table{ + Name: "allow_list_allowlist_items", + Columns: AllowListAllowlistItemsColumns, + PrimaryKey: []*schema.Column{AllowListAllowlistItemsColumns[0], AllowListAllowlistItemsColumns[1]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "allow_list_allowlist_items_allow_list_id", + Columns: []*schema.Column{AllowListAllowlistItemsColumns[0]}, + RefColumns: []*schema.Column{AllowListsColumns[0]}, + OnDelete: schema.Cascade, + }, + { + Symbol: "allow_list_allowlist_items_allow_list_item_id", + Columns: []*schema.Column{AllowListAllowlistItemsColumns[1]}, + RefColumns: []*schema.Column{AllowListItemsColumns[0]}, + OnDelete: schema.Cascade, + }, + }, + } // Tables holds all the tables in the schema. Tables = []*schema.Table{ AlertsTable, + AllowListsTable, + AllowListItemsTable, BouncersTable, ConfigItemsTable, DecisionsTable, @@ -275,6 +363,7 @@ var ( MachinesTable, MetaTable, MetricsTable, + AllowListAllowlistItemsTable, } ) @@ -283,4 +372,6 @@ func init() { DecisionsTable.ForeignKeys[0].RefTable = AlertsTable EventsTable.ForeignKeys[0].RefTable = AlertsTable MetaTable.ForeignKeys[0].RefTable = AlertsTable + AllowListAllowlistItemsTable.ForeignKeys[0].RefTable = AllowListsTable + AllowListAllowlistItemsTable.ForeignKeys[1].RefTable = AllowListItemsTable } diff --git a/pkg/database/ent/mutation.go b/pkg/database/ent/mutation.go index 5c6596f3db4..f45bd47a5fb 100644 --- a/pkg/database/ent/mutation.go +++ b/pkg/database/ent/mutation.go @@ -12,6 +12,8 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" @@ -33,15 +35,17 @@ const ( OpUpdateOne = ent.OpUpdateOne // Node types. - TypeAlert = "Alert" - TypeBouncer = "Bouncer" - TypeConfigItem = "ConfigItem" - TypeDecision = "Decision" - TypeEvent = "Event" - TypeLock = "Lock" - TypeMachine = "Machine" - TypeMeta = "Meta" - TypeMetric = "Metric" + TypeAlert = "Alert" + TypeAllowList = "AllowList" + TypeAllowListItem = "AllowListItem" + TypeBouncer = "Bouncer" + TypeConfigItem = "ConfigItem" + TypeDecision = "Decision" + TypeEvent = "Event" + TypeLock = "Lock" + TypeMachine = "Machine" + TypeMeta = "Meta" + TypeMetric = "Metric" ) // AlertMutation represents an operation that mutates the Alert nodes in the graph. @@ -2452,6 +2456,1950 @@ func (m *AlertMutation) ResetEdge(name string) error { return fmt.Errorf("unknown Alert edge %s", name) } +// AllowListMutation represents an operation that mutates the AllowList nodes in the graph. +type AllowListMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + name *string + from_console *bool + description *string + allowlist_id *string + clearedFields map[string]struct{} + allowlist_items map[int]struct{} + removedallowlist_items map[int]struct{} + clearedallowlist_items bool + done bool + oldValue func(context.Context) (*AllowList, error) + predicates []predicate.AllowList +} + +var _ ent.Mutation = (*AllowListMutation)(nil) + +// allowlistOption allows management of the mutation configuration using functional options. +type allowlistOption func(*AllowListMutation) + +// newAllowListMutation creates new mutation for the AllowList entity. +func newAllowListMutation(c config, op Op, opts ...allowlistOption) *AllowListMutation { + m := &AllowListMutation{ + config: c, + op: op, + typ: TypeAllowList, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withAllowListID sets the ID field of the mutation. +func withAllowListID(id int) allowlistOption { + return func(m *AllowListMutation) { + var ( + err error + once sync.Once + value *AllowList + ) + m.oldValue = func(ctx context.Context) (*AllowList, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().AllowList.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withAllowList sets the old AllowList of the mutation. +func withAllowList(node *AllowList) allowlistOption { + return func(m *AllowListMutation) { + m.oldValue = func(context.Context) (*AllowList, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m AllowListMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m AllowListMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *AllowListMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *AllowListMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().AllowList.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *AllowListMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *AllowListMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the AllowList entity. +// If the AllowList object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *AllowListMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *AllowListMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *AllowListMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the AllowList entity. +// If the AllowList object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *AllowListMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetName sets the "name" field. +func (m *AllowListMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *AllowListMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the AllowList entity. +// If the AllowList object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *AllowListMutation) ResetName() { + m.name = nil +} + +// SetFromConsole sets the "from_console" field. +func (m *AllowListMutation) SetFromConsole(b bool) { + m.from_console = &b +} + +// FromConsole returns the value of the "from_console" field in the mutation. +func (m *AllowListMutation) FromConsole() (r bool, exists bool) { + v := m.from_console + if v == nil { + return + } + return *v, true +} + +// OldFromConsole returns the old "from_console" field's value of the AllowList entity. +// If the AllowList object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListMutation) OldFromConsole(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFromConsole is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFromConsole requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFromConsole: %w", err) + } + return oldValue.FromConsole, nil +} + +// ResetFromConsole resets all changes to the "from_console" field. +func (m *AllowListMutation) ResetFromConsole() { + m.from_console = nil +} + +// SetDescription sets the "description" field. +func (m *AllowListMutation) SetDescription(s string) { + m.description = &s +} + +// Description returns the value of the "description" field in the mutation. +func (m *AllowListMutation) Description() (r string, exists bool) { + v := m.description + if v == nil { + return + } + return *v, true +} + +// OldDescription returns the old "description" field's value of the AllowList entity. +// If the AllowList object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListMutation) OldDescription(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDescription is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDescription requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDescription: %w", err) + } + return oldValue.Description, nil +} + +// ClearDescription clears the value of the "description" field. +func (m *AllowListMutation) ClearDescription() { + m.description = nil + m.clearedFields[allowlist.FieldDescription] = struct{}{} +} + +// DescriptionCleared returns if the "description" field was cleared in this mutation. +func (m *AllowListMutation) DescriptionCleared() bool { + _, ok := m.clearedFields[allowlist.FieldDescription] + return ok +} + +// ResetDescription resets all changes to the "description" field. +func (m *AllowListMutation) ResetDescription() { + m.description = nil + delete(m.clearedFields, allowlist.FieldDescription) +} + +// SetAllowlistID sets the "allowlist_id" field. +func (m *AllowListMutation) SetAllowlistID(s string) { + m.allowlist_id = &s +} + +// AllowlistID returns the value of the "allowlist_id" field in the mutation. +func (m *AllowListMutation) AllowlistID() (r string, exists bool) { + v := m.allowlist_id + if v == nil { + return + } + return *v, true +} + +// OldAllowlistID returns the old "allowlist_id" field's value of the AllowList entity. +// If the AllowList object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListMutation) OldAllowlistID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAllowlistID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAllowlistID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAllowlistID: %w", err) + } + return oldValue.AllowlistID, nil +} + +// ClearAllowlistID clears the value of the "allowlist_id" field. +func (m *AllowListMutation) ClearAllowlistID() { + m.allowlist_id = nil + m.clearedFields[allowlist.FieldAllowlistID] = struct{}{} +} + +// AllowlistIDCleared returns if the "allowlist_id" field was cleared in this mutation. +func (m *AllowListMutation) AllowlistIDCleared() bool { + _, ok := m.clearedFields[allowlist.FieldAllowlistID] + return ok +} + +// ResetAllowlistID resets all changes to the "allowlist_id" field. +func (m *AllowListMutation) ResetAllowlistID() { + m.allowlist_id = nil + delete(m.clearedFields, allowlist.FieldAllowlistID) +} + +// AddAllowlistItemIDs adds the "allowlist_items" edge to the AllowListItem entity by ids. +func (m *AllowListMutation) AddAllowlistItemIDs(ids ...int) { + if m.allowlist_items == nil { + m.allowlist_items = make(map[int]struct{}) + } + for i := range ids { + m.allowlist_items[ids[i]] = struct{}{} + } +} + +// ClearAllowlistItems clears the "allowlist_items" edge to the AllowListItem entity. +func (m *AllowListMutation) ClearAllowlistItems() { + m.clearedallowlist_items = true +} + +// AllowlistItemsCleared reports if the "allowlist_items" edge to the AllowListItem entity was cleared. +func (m *AllowListMutation) AllowlistItemsCleared() bool { + return m.clearedallowlist_items +} + +// RemoveAllowlistItemIDs removes the "allowlist_items" edge to the AllowListItem entity by IDs. +func (m *AllowListMutation) RemoveAllowlistItemIDs(ids ...int) { + if m.removedallowlist_items == nil { + m.removedallowlist_items = make(map[int]struct{}) + } + for i := range ids { + delete(m.allowlist_items, ids[i]) + m.removedallowlist_items[ids[i]] = struct{}{} + } +} + +// RemovedAllowlistItems returns the removed IDs of the "allowlist_items" edge to the AllowListItem entity. +func (m *AllowListMutation) RemovedAllowlistItemsIDs() (ids []int) { + for id := range m.removedallowlist_items { + ids = append(ids, id) + } + return +} + +// AllowlistItemsIDs returns the "allowlist_items" edge IDs in the mutation. +func (m *AllowListMutation) AllowlistItemsIDs() (ids []int) { + for id := range m.allowlist_items { + ids = append(ids, id) + } + return +} + +// ResetAllowlistItems resets all changes to the "allowlist_items" edge. +func (m *AllowListMutation) ResetAllowlistItems() { + m.allowlist_items = nil + m.clearedallowlist_items = false + m.removedallowlist_items = nil +} + +// Where appends a list predicates to the AllowListMutation builder. +func (m *AllowListMutation) Where(ps ...predicate.AllowList) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the AllowListMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AllowListMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.AllowList, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *AllowListMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *AllowListMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (AllowList). +func (m *AllowListMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *AllowListMutation) Fields() []string { + fields := make([]string, 0, 6) + if m.created_at != nil { + fields = append(fields, allowlist.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, allowlist.FieldUpdatedAt) + } + if m.name != nil { + fields = append(fields, allowlist.FieldName) + } + if m.from_console != nil { + fields = append(fields, allowlist.FieldFromConsole) + } + if m.description != nil { + fields = append(fields, allowlist.FieldDescription) + } + if m.allowlist_id != nil { + fields = append(fields, allowlist.FieldAllowlistID) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *AllowListMutation) Field(name string) (ent.Value, bool) { + switch name { + case allowlist.FieldCreatedAt: + return m.CreatedAt() + case allowlist.FieldUpdatedAt: + return m.UpdatedAt() + case allowlist.FieldName: + return m.Name() + case allowlist.FieldFromConsole: + return m.FromConsole() + case allowlist.FieldDescription: + return m.Description() + case allowlist.FieldAllowlistID: + return m.AllowlistID() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *AllowListMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case allowlist.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case allowlist.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case allowlist.FieldName: + return m.OldName(ctx) + case allowlist.FieldFromConsole: + return m.OldFromConsole(ctx) + case allowlist.FieldDescription: + return m.OldDescription(ctx) + case allowlist.FieldAllowlistID: + return m.OldAllowlistID(ctx) + } + return nil, fmt.Errorf("unknown AllowList field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AllowListMutation) SetField(name string, value ent.Value) error { + switch name { + case allowlist.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case allowlist.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case allowlist.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case allowlist.FieldFromConsole: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFromConsole(v) + return nil + case allowlist.FieldDescription: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDescription(v) + return nil + case allowlist.FieldAllowlistID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAllowlistID(v) + return nil + } + return fmt.Errorf("unknown AllowList field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *AllowListMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *AllowListMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AllowListMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown AllowList numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *AllowListMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(allowlist.FieldDescription) { + fields = append(fields, allowlist.FieldDescription) + } + if m.FieldCleared(allowlist.FieldAllowlistID) { + fields = append(fields, allowlist.FieldAllowlistID) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *AllowListMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *AllowListMutation) ClearField(name string) error { + switch name { + case allowlist.FieldDescription: + m.ClearDescription() + return nil + case allowlist.FieldAllowlistID: + m.ClearAllowlistID() + return nil + } + return fmt.Errorf("unknown AllowList nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *AllowListMutation) ResetField(name string) error { + switch name { + case allowlist.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case allowlist.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case allowlist.FieldName: + m.ResetName() + return nil + case allowlist.FieldFromConsole: + m.ResetFromConsole() + return nil + case allowlist.FieldDescription: + m.ResetDescription() + return nil + case allowlist.FieldAllowlistID: + m.ResetAllowlistID() + return nil + } + return fmt.Errorf("unknown AllowList field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *AllowListMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.allowlist_items != nil { + edges = append(edges, allowlist.EdgeAllowlistItems) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *AllowListMutation) AddedIDs(name string) []ent.Value { + switch name { + case allowlist.EdgeAllowlistItems: + ids := make([]ent.Value, 0, len(m.allowlist_items)) + for id := range m.allowlist_items { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *AllowListMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + if m.removedallowlist_items != nil { + edges = append(edges, allowlist.EdgeAllowlistItems) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *AllowListMutation) RemovedIDs(name string) []ent.Value { + switch name { + case allowlist.EdgeAllowlistItems: + ids := make([]ent.Value, 0, len(m.removedallowlist_items)) + for id := range m.removedallowlist_items { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *AllowListMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedallowlist_items { + edges = append(edges, allowlist.EdgeAllowlistItems) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *AllowListMutation) EdgeCleared(name string) bool { + switch name { + case allowlist.EdgeAllowlistItems: + return m.clearedallowlist_items + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *AllowListMutation) ClearEdge(name string) error { + switch name { + } + return fmt.Errorf("unknown AllowList unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *AllowListMutation) ResetEdge(name string) error { + switch name { + case allowlist.EdgeAllowlistItems: + m.ResetAllowlistItems() + return nil + } + return fmt.Errorf("unknown AllowList edge %s", name) +} + +// AllowListItemMutation represents an operation that mutates the AllowListItem nodes in the graph. +type AllowListItemMutation struct { + config + op Op + typ string + id *int + created_at *time.Time + updated_at *time.Time + expires_at *time.Time + comment *string + value *string + start_ip *int64 + addstart_ip *int64 + end_ip *int64 + addend_ip *int64 + start_suffix *int64 + addstart_suffix *int64 + end_suffix *int64 + addend_suffix *int64 + ip_size *int64 + addip_size *int64 + clearedFields map[string]struct{} + allowlist map[int]struct{} + removedallowlist map[int]struct{} + clearedallowlist bool + done bool + oldValue func(context.Context) (*AllowListItem, error) + predicates []predicate.AllowListItem +} + +var _ ent.Mutation = (*AllowListItemMutation)(nil) + +// allowlistitemOption allows management of the mutation configuration using functional options. +type allowlistitemOption func(*AllowListItemMutation) + +// newAllowListItemMutation creates new mutation for the AllowListItem entity. +func newAllowListItemMutation(c config, op Op, opts ...allowlistitemOption) *AllowListItemMutation { + m := &AllowListItemMutation{ + config: c, + op: op, + typ: TypeAllowListItem, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withAllowListItemID sets the ID field of the mutation. +func withAllowListItemID(id int) allowlistitemOption { + return func(m *AllowListItemMutation) { + var ( + err error + once sync.Once + value *AllowListItem + ) + m.oldValue = func(ctx context.Context) (*AllowListItem, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().AllowListItem.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withAllowListItem sets the old AllowListItem of the mutation. +func withAllowListItem(node *AllowListItem) allowlistitemOption { + return func(m *AllowListItemMutation) { + m.oldValue = func(context.Context) (*AllowListItem, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m AllowListItemMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m AllowListItemMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *AllowListItemMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *AllowListItemMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().AllowListItem.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *AllowListItemMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *AllowListItemMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the AllowListItem entity. +// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListItemMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *AllowListItemMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *AllowListItemMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *AllowListItemMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the AllowListItem entity. +// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListItemMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *AllowListItemMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetExpiresAt sets the "expires_at" field. +func (m *AllowListItemMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *AllowListItemMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the AllowListItem entity. +// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListItemMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (m *AllowListItemMutation) ClearExpiresAt() { + m.expires_at = nil + m.clearedFields[allowlistitem.FieldExpiresAt] = struct{}{} +} + +// ExpiresAtCleared returns if the "expires_at" field was cleared in this mutation. +func (m *AllowListItemMutation) ExpiresAtCleared() bool { + _, ok := m.clearedFields[allowlistitem.FieldExpiresAt] + return ok +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *AllowListItemMutation) ResetExpiresAt() { + m.expires_at = nil + delete(m.clearedFields, allowlistitem.FieldExpiresAt) +} + +// SetComment sets the "comment" field. +func (m *AllowListItemMutation) SetComment(s string) { + m.comment = &s +} + +// Comment returns the value of the "comment" field in the mutation. +func (m *AllowListItemMutation) Comment() (r string, exists bool) { + v := m.comment + if v == nil { + return + } + return *v, true +} + +// OldComment returns the old "comment" field's value of the AllowListItem entity. +// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListItemMutation) OldComment(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldComment is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldComment requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldComment: %w", err) + } + return oldValue.Comment, nil +} + +// ClearComment clears the value of the "comment" field. +func (m *AllowListItemMutation) ClearComment() { + m.comment = nil + m.clearedFields[allowlistitem.FieldComment] = struct{}{} +} + +// CommentCleared returns if the "comment" field was cleared in this mutation. +func (m *AllowListItemMutation) CommentCleared() bool { + _, ok := m.clearedFields[allowlistitem.FieldComment] + return ok +} + +// ResetComment resets all changes to the "comment" field. +func (m *AllowListItemMutation) ResetComment() { + m.comment = nil + delete(m.clearedFields, allowlistitem.FieldComment) +} + +// SetValue sets the "value" field. +func (m *AllowListItemMutation) SetValue(s string) { + m.value = &s +} + +// Value returns the value of the "value" field in the mutation. +func (m *AllowListItemMutation) Value() (r string, exists bool) { + v := m.value + if v == nil { + return + } + return *v, true +} + +// OldValue returns the old "value" field's value of the AllowListItem entity. +// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListItemMutation) OldValue(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldValue is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldValue requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldValue: %w", err) + } + return oldValue.Value, nil +} + +// ResetValue resets all changes to the "value" field. +func (m *AllowListItemMutation) ResetValue() { + m.value = nil +} + +// SetStartIP sets the "start_ip" field. +func (m *AllowListItemMutation) SetStartIP(i int64) { + m.start_ip = &i + m.addstart_ip = nil +} + +// StartIP returns the value of the "start_ip" field in the mutation. +func (m *AllowListItemMutation) StartIP() (r int64, exists bool) { + v := m.start_ip + if v == nil { + return + } + return *v, true +} + +// OldStartIP returns the old "start_ip" field's value of the AllowListItem entity. +// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListItemMutation) OldStartIP(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStartIP is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStartIP requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStartIP: %w", err) + } + return oldValue.StartIP, nil +} + +// AddStartIP adds i to the "start_ip" field. +func (m *AllowListItemMutation) AddStartIP(i int64) { + if m.addstart_ip != nil { + *m.addstart_ip += i + } else { + m.addstart_ip = &i + } +} + +// AddedStartIP returns the value that was added to the "start_ip" field in this mutation. +func (m *AllowListItemMutation) AddedStartIP() (r int64, exists bool) { + v := m.addstart_ip + if v == nil { + return + } + return *v, true +} + +// ClearStartIP clears the value of the "start_ip" field. +func (m *AllowListItemMutation) ClearStartIP() { + m.start_ip = nil + m.addstart_ip = nil + m.clearedFields[allowlistitem.FieldStartIP] = struct{}{} +} + +// StartIPCleared returns if the "start_ip" field was cleared in this mutation. +func (m *AllowListItemMutation) StartIPCleared() bool { + _, ok := m.clearedFields[allowlistitem.FieldStartIP] + return ok +} + +// ResetStartIP resets all changes to the "start_ip" field. +func (m *AllowListItemMutation) ResetStartIP() { + m.start_ip = nil + m.addstart_ip = nil + delete(m.clearedFields, allowlistitem.FieldStartIP) +} + +// SetEndIP sets the "end_ip" field. +func (m *AllowListItemMutation) SetEndIP(i int64) { + m.end_ip = &i + m.addend_ip = nil +} + +// EndIP returns the value of the "end_ip" field in the mutation. +func (m *AllowListItemMutation) EndIP() (r int64, exists bool) { + v := m.end_ip + if v == nil { + return + } + return *v, true +} + +// OldEndIP returns the old "end_ip" field's value of the AllowListItem entity. +// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListItemMutation) OldEndIP(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEndIP is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEndIP requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEndIP: %w", err) + } + return oldValue.EndIP, nil +} + +// AddEndIP adds i to the "end_ip" field. +func (m *AllowListItemMutation) AddEndIP(i int64) { + if m.addend_ip != nil { + *m.addend_ip += i + } else { + m.addend_ip = &i + } +} + +// AddedEndIP returns the value that was added to the "end_ip" field in this mutation. +func (m *AllowListItemMutation) AddedEndIP() (r int64, exists bool) { + v := m.addend_ip + if v == nil { + return + } + return *v, true +} + +// ClearEndIP clears the value of the "end_ip" field. +func (m *AllowListItemMutation) ClearEndIP() { + m.end_ip = nil + m.addend_ip = nil + m.clearedFields[allowlistitem.FieldEndIP] = struct{}{} +} + +// EndIPCleared returns if the "end_ip" field was cleared in this mutation. +func (m *AllowListItemMutation) EndIPCleared() bool { + _, ok := m.clearedFields[allowlistitem.FieldEndIP] + return ok +} + +// ResetEndIP resets all changes to the "end_ip" field. +func (m *AllowListItemMutation) ResetEndIP() { + m.end_ip = nil + m.addend_ip = nil + delete(m.clearedFields, allowlistitem.FieldEndIP) +} + +// SetStartSuffix sets the "start_suffix" field. +func (m *AllowListItemMutation) SetStartSuffix(i int64) { + m.start_suffix = &i + m.addstart_suffix = nil +} + +// StartSuffix returns the value of the "start_suffix" field in the mutation. +func (m *AllowListItemMutation) StartSuffix() (r int64, exists bool) { + v := m.start_suffix + if v == nil { + return + } + return *v, true +} + +// OldStartSuffix returns the old "start_suffix" field's value of the AllowListItem entity. +// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListItemMutation) OldStartSuffix(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStartSuffix is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStartSuffix requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStartSuffix: %w", err) + } + return oldValue.StartSuffix, nil +} + +// AddStartSuffix adds i to the "start_suffix" field. +func (m *AllowListItemMutation) AddStartSuffix(i int64) { + if m.addstart_suffix != nil { + *m.addstart_suffix += i + } else { + m.addstart_suffix = &i + } +} + +// AddedStartSuffix returns the value that was added to the "start_suffix" field in this mutation. +func (m *AllowListItemMutation) AddedStartSuffix() (r int64, exists bool) { + v := m.addstart_suffix + if v == nil { + return + } + return *v, true +} + +// ClearStartSuffix clears the value of the "start_suffix" field. +func (m *AllowListItemMutation) ClearStartSuffix() { + m.start_suffix = nil + m.addstart_suffix = nil + m.clearedFields[allowlistitem.FieldStartSuffix] = struct{}{} +} + +// StartSuffixCleared returns if the "start_suffix" field was cleared in this mutation. +func (m *AllowListItemMutation) StartSuffixCleared() bool { + _, ok := m.clearedFields[allowlistitem.FieldStartSuffix] + return ok +} + +// ResetStartSuffix resets all changes to the "start_suffix" field. +func (m *AllowListItemMutation) ResetStartSuffix() { + m.start_suffix = nil + m.addstart_suffix = nil + delete(m.clearedFields, allowlistitem.FieldStartSuffix) +} + +// SetEndSuffix sets the "end_suffix" field. +func (m *AllowListItemMutation) SetEndSuffix(i int64) { + m.end_suffix = &i + m.addend_suffix = nil +} + +// EndSuffix returns the value of the "end_suffix" field in the mutation. +func (m *AllowListItemMutation) EndSuffix() (r int64, exists bool) { + v := m.end_suffix + if v == nil { + return + } + return *v, true +} + +// OldEndSuffix returns the old "end_suffix" field's value of the AllowListItem entity. +// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListItemMutation) OldEndSuffix(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEndSuffix is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEndSuffix requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEndSuffix: %w", err) + } + return oldValue.EndSuffix, nil +} + +// AddEndSuffix adds i to the "end_suffix" field. +func (m *AllowListItemMutation) AddEndSuffix(i int64) { + if m.addend_suffix != nil { + *m.addend_suffix += i + } else { + m.addend_suffix = &i + } +} + +// AddedEndSuffix returns the value that was added to the "end_suffix" field in this mutation. +func (m *AllowListItemMutation) AddedEndSuffix() (r int64, exists bool) { + v := m.addend_suffix + if v == nil { + return + } + return *v, true +} + +// ClearEndSuffix clears the value of the "end_suffix" field. +func (m *AllowListItemMutation) ClearEndSuffix() { + m.end_suffix = nil + m.addend_suffix = nil + m.clearedFields[allowlistitem.FieldEndSuffix] = struct{}{} +} + +// EndSuffixCleared returns if the "end_suffix" field was cleared in this mutation. +func (m *AllowListItemMutation) EndSuffixCleared() bool { + _, ok := m.clearedFields[allowlistitem.FieldEndSuffix] + return ok +} + +// ResetEndSuffix resets all changes to the "end_suffix" field. +func (m *AllowListItemMutation) ResetEndSuffix() { + m.end_suffix = nil + m.addend_suffix = nil + delete(m.clearedFields, allowlistitem.FieldEndSuffix) +} + +// SetIPSize sets the "ip_size" field. +func (m *AllowListItemMutation) SetIPSize(i int64) { + m.ip_size = &i + m.addip_size = nil +} + +// IPSize returns the value of the "ip_size" field in the mutation. +func (m *AllowListItemMutation) IPSize() (r int64, exists bool) { + v := m.ip_size + if v == nil { + return + } + return *v, true +} + +// OldIPSize returns the old "ip_size" field's value of the AllowListItem entity. +// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AllowListItemMutation) OldIPSize(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIPSize is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIPSize requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIPSize: %w", err) + } + return oldValue.IPSize, nil +} + +// AddIPSize adds i to the "ip_size" field. +func (m *AllowListItemMutation) AddIPSize(i int64) { + if m.addip_size != nil { + *m.addip_size += i + } else { + m.addip_size = &i + } +} + +// AddedIPSize returns the value that was added to the "ip_size" field in this mutation. +func (m *AllowListItemMutation) AddedIPSize() (r int64, exists bool) { + v := m.addip_size + if v == nil { + return + } + return *v, true +} + +// ClearIPSize clears the value of the "ip_size" field. +func (m *AllowListItemMutation) ClearIPSize() { + m.ip_size = nil + m.addip_size = nil + m.clearedFields[allowlistitem.FieldIPSize] = struct{}{} +} + +// IPSizeCleared returns if the "ip_size" field was cleared in this mutation. +func (m *AllowListItemMutation) IPSizeCleared() bool { + _, ok := m.clearedFields[allowlistitem.FieldIPSize] + return ok +} + +// ResetIPSize resets all changes to the "ip_size" field. +func (m *AllowListItemMutation) ResetIPSize() { + m.ip_size = nil + m.addip_size = nil + delete(m.clearedFields, allowlistitem.FieldIPSize) +} + +// AddAllowlistIDs adds the "allowlist" edge to the AllowList entity by ids. +func (m *AllowListItemMutation) AddAllowlistIDs(ids ...int) { + if m.allowlist == nil { + m.allowlist = make(map[int]struct{}) + } + for i := range ids { + m.allowlist[ids[i]] = struct{}{} + } +} + +// ClearAllowlist clears the "allowlist" edge to the AllowList entity. +func (m *AllowListItemMutation) ClearAllowlist() { + m.clearedallowlist = true +} + +// AllowlistCleared reports if the "allowlist" edge to the AllowList entity was cleared. +func (m *AllowListItemMutation) AllowlistCleared() bool { + return m.clearedallowlist +} + +// RemoveAllowlistIDs removes the "allowlist" edge to the AllowList entity by IDs. +func (m *AllowListItemMutation) RemoveAllowlistIDs(ids ...int) { + if m.removedallowlist == nil { + m.removedallowlist = make(map[int]struct{}) + } + for i := range ids { + delete(m.allowlist, ids[i]) + m.removedallowlist[ids[i]] = struct{}{} + } +} + +// RemovedAllowlist returns the removed IDs of the "allowlist" edge to the AllowList entity. +func (m *AllowListItemMutation) RemovedAllowlistIDs() (ids []int) { + for id := range m.removedallowlist { + ids = append(ids, id) + } + return +} + +// AllowlistIDs returns the "allowlist" edge IDs in the mutation. +func (m *AllowListItemMutation) AllowlistIDs() (ids []int) { + for id := range m.allowlist { + ids = append(ids, id) + } + return +} + +// ResetAllowlist resets all changes to the "allowlist" edge. +func (m *AllowListItemMutation) ResetAllowlist() { + m.allowlist = nil + m.clearedallowlist = false + m.removedallowlist = nil +} + +// Where appends a list predicates to the AllowListItemMutation builder. +func (m *AllowListItemMutation) Where(ps ...predicate.AllowListItem) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the AllowListItemMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AllowListItemMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.AllowListItem, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *AllowListItemMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *AllowListItemMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (AllowListItem). +func (m *AllowListItemMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *AllowListItemMutation) Fields() []string { + fields := make([]string, 0, 10) + if m.created_at != nil { + fields = append(fields, allowlistitem.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, allowlistitem.FieldUpdatedAt) + } + if m.expires_at != nil { + fields = append(fields, allowlistitem.FieldExpiresAt) + } + if m.comment != nil { + fields = append(fields, allowlistitem.FieldComment) + } + if m.value != nil { + fields = append(fields, allowlistitem.FieldValue) + } + if m.start_ip != nil { + fields = append(fields, allowlistitem.FieldStartIP) + } + if m.end_ip != nil { + fields = append(fields, allowlistitem.FieldEndIP) + } + if m.start_suffix != nil { + fields = append(fields, allowlistitem.FieldStartSuffix) + } + if m.end_suffix != nil { + fields = append(fields, allowlistitem.FieldEndSuffix) + } + if m.ip_size != nil { + fields = append(fields, allowlistitem.FieldIPSize) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *AllowListItemMutation) Field(name string) (ent.Value, bool) { + switch name { + case allowlistitem.FieldCreatedAt: + return m.CreatedAt() + case allowlistitem.FieldUpdatedAt: + return m.UpdatedAt() + case allowlistitem.FieldExpiresAt: + return m.ExpiresAt() + case allowlistitem.FieldComment: + return m.Comment() + case allowlistitem.FieldValue: + return m.Value() + case allowlistitem.FieldStartIP: + return m.StartIP() + case allowlistitem.FieldEndIP: + return m.EndIP() + case allowlistitem.FieldStartSuffix: + return m.StartSuffix() + case allowlistitem.FieldEndSuffix: + return m.EndSuffix() + case allowlistitem.FieldIPSize: + return m.IPSize() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *AllowListItemMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case allowlistitem.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case allowlistitem.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case allowlistitem.FieldExpiresAt: + return m.OldExpiresAt(ctx) + case allowlistitem.FieldComment: + return m.OldComment(ctx) + case allowlistitem.FieldValue: + return m.OldValue(ctx) + case allowlistitem.FieldStartIP: + return m.OldStartIP(ctx) + case allowlistitem.FieldEndIP: + return m.OldEndIP(ctx) + case allowlistitem.FieldStartSuffix: + return m.OldStartSuffix(ctx) + case allowlistitem.FieldEndSuffix: + return m.OldEndSuffix(ctx) + case allowlistitem.FieldIPSize: + return m.OldIPSize(ctx) + } + return nil, fmt.Errorf("unknown AllowListItem field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AllowListItemMutation) SetField(name string, value ent.Value) error { + switch name { + case allowlistitem.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case allowlistitem.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case allowlistitem.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + case allowlistitem.FieldComment: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetComment(v) + return nil + case allowlistitem.FieldValue: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetValue(v) + return nil + case allowlistitem.FieldStartIP: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStartIP(v) + return nil + case allowlistitem.FieldEndIP: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEndIP(v) + return nil + case allowlistitem.FieldStartSuffix: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStartSuffix(v) + return nil + case allowlistitem.FieldEndSuffix: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEndSuffix(v) + return nil + case allowlistitem.FieldIPSize: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIPSize(v) + return nil + } + return fmt.Errorf("unknown AllowListItem field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *AllowListItemMutation) AddedFields() []string { + var fields []string + if m.addstart_ip != nil { + fields = append(fields, allowlistitem.FieldStartIP) + } + if m.addend_ip != nil { + fields = append(fields, allowlistitem.FieldEndIP) + } + if m.addstart_suffix != nil { + fields = append(fields, allowlistitem.FieldStartSuffix) + } + if m.addend_suffix != nil { + fields = append(fields, allowlistitem.FieldEndSuffix) + } + if m.addip_size != nil { + fields = append(fields, allowlistitem.FieldIPSize) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *AllowListItemMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case allowlistitem.FieldStartIP: + return m.AddedStartIP() + case allowlistitem.FieldEndIP: + return m.AddedEndIP() + case allowlistitem.FieldStartSuffix: + return m.AddedStartSuffix() + case allowlistitem.FieldEndSuffix: + return m.AddedEndSuffix() + case allowlistitem.FieldIPSize: + return m.AddedIPSize() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AllowListItemMutation) AddField(name string, value ent.Value) error { + switch name { + case allowlistitem.FieldStartIP: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddStartIP(v) + return nil + case allowlistitem.FieldEndIP: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddEndIP(v) + return nil + case allowlistitem.FieldStartSuffix: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddStartSuffix(v) + return nil + case allowlistitem.FieldEndSuffix: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddEndSuffix(v) + return nil + case allowlistitem.FieldIPSize: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddIPSize(v) + return nil + } + return fmt.Errorf("unknown AllowListItem numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *AllowListItemMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(allowlistitem.FieldExpiresAt) { + fields = append(fields, allowlistitem.FieldExpiresAt) + } + if m.FieldCleared(allowlistitem.FieldComment) { + fields = append(fields, allowlistitem.FieldComment) + } + if m.FieldCleared(allowlistitem.FieldStartIP) { + fields = append(fields, allowlistitem.FieldStartIP) + } + if m.FieldCleared(allowlistitem.FieldEndIP) { + fields = append(fields, allowlistitem.FieldEndIP) + } + if m.FieldCleared(allowlistitem.FieldStartSuffix) { + fields = append(fields, allowlistitem.FieldStartSuffix) + } + if m.FieldCleared(allowlistitem.FieldEndSuffix) { + fields = append(fields, allowlistitem.FieldEndSuffix) + } + if m.FieldCleared(allowlistitem.FieldIPSize) { + fields = append(fields, allowlistitem.FieldIPSize) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *AllowListItemMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *AllowListItemMutation) ClearField(name string) error { + switch name { + case allowlistitem.FieldExpiresAt: + m.ClearExpiresAt() + return nil + case allowlistitem.FieldComment: + m.ClearComment() + return nil + case allowlistitem.FieldStartIP: + m.ClearStartIP() + return nil + case allowlistitem.FieldEndIP: + m.ClearEndIP() + return nil + case allowlistitem.FieldStartSuffix: + m.ClearStartSuffix() + return nil + case allowlistitem.FieldEndSuffix: + m.ClearEndSuffix() + return nil + case allowlistitem.FieldIPSize: + m.ClearIPSize() + return nil + } + return fmt.Errorf("unknown AllowListItem nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *AllowListItemMutation) ResetField(name string) error { + switch name { + case allowlistitem.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case allowlistitem.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case allowlistitem.FieldExpiresAt: + m.ResetExpiresAt() + return nil + case allowlistitem.FieldComment: + m.ResetComment() + return nil + case allowlistitem.FieldValue: + m.ResetValue() + return nil + case allowlistitem.FieldStartIP: + m.ResetStartIP() + return nil + case allowlistitem.FieldEndIP: + m.ResetEndIP() + return nil + case allowlistitem.FieldStartSuffix: + m.ResetStartSuffix() + return nil + case allowlistitem.FieldEndSuffix: + m.ResetEndSuffix() + return nil + case allowlistitem.FieldIPSize: + m.ResetIPSize() + return nil + } + return fmt.Errorf("unknown AllowListItem field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *AllowListItemMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.allowlist != nil { + edges = append(edges, allowlistitem.EdgeAllowlist) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *AllowListItemMutation) AddedIDs(name string) []ent.Value { + switch name { + case allowlistitem.EdgeAllowlist: + ids := make([]ent.Value, 0, len(m.allowlist)) + for id := range m.allowlist { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *AllowListItemMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + if m.removedallowlist != nil { + edges = append(edges, allowlistitem.EdgeAllowlist) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *AllowListItemMutation) RemovedIDs(name string) []ent.Value { + switch name { + case allowlistitem.EdgeAllowlist: + ids := make([]ent.Value, 0, len(m.removedallowlist)) + for id := range m.removedallowlist { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *AllowListItemMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedallowlist { + edges = append(edges, allowlistitem.EdgeAllowlist) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *AllowListItemMutation) EdgeCleared(name string) bool { + switch name { + case allowlistitem.EdgeAllowlist: + return m.clearedallowlist + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *AllowListItemMutation) ClearEdge(name string) error { + switch name { + } + return fmt.Errorf("unknown AllowListItem unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *AllowListItemMutation) ResetEdge(name string) error { + switch name { + case allowlistitem.EdgeAllowlist: + m.ResetAllowlist() + return nil + } + return fmt.Errorf("unknown AllowListItem edge %s", name) +} + // BouncerMutation represents an operation that mutates the Bouncer nodes in the graph. type BouncerMutation struct { config @@ -2471,6 +4419,7 @@ type BouncerMutation struct { osname *string osversion *string featureflags *string + auto_created *bool clearedFields map[string]struct{} done bool oldValue func(context.Context) (*Bouncer, error) @@ -3134,6 +5083,42 @@ func (m *BouncerMutation) ResetFeatureflags() { delete(m.clearedFields, bouncer.FieldFeatureflags) } +// SetAutoCreated sets the "auto_created" field. +func (m *BouncerMutation) SetAutoCreated(b bool) { + m.auto_created = &b +} + +// AutoCreated returns the value of the "auto_created" field in the mutation. +func (m *BouncerMutation) AutoCreated() (r bool, exists bool) { + v := m.auto_created + if v == nil { + return + } + return *v, true +} + +// OldAutoCreated returns the old "auto_created" field's value of the Bouncer entity. +// If the Bouncer object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *BouncerMutation) OldAutoCreated(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAutoCreated is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAutoCreated requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAutoCreated: %w", err) + } + return oldValue.AutoCreated, nil +} + +// ResetAutoCreated resets all changes to the "auto_created" field. +func (m *BouncerMutation) ResetAutoCreated() { + m.auto_created = nil +} + // Where appends a list predicates to the BouncerMutation builder. func (m *BouncerMutation) Where(ps ...predicate.Bouncer) { m.predicates = append(m.predicates, ps...) @@ -3168,7 +5153,7 @@ func (m *BouncerMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *BouncerMutation) Fields() []string { - fields := make([]string, 0, 13) + fields := make([]string, 0, 14) if m.created_at != nil { fields = append(fields, bouncer.FieldCreatedAt) } @@ -3208,6 +5193,9 @@ func (m *BouncerMutation) Fields() []string { if m.featureflags != nil { fields = append(fields, bouncer.FieldFeatureflags) } + if m.auto_created != nil { + fields = append(fields, bouncer.FieldAutoCreated) + } return fields } @@ -3242,6 +5230,8 @@ func (m *BouncerMutation) Field(name string) (ent.Value, bool) { return m.Osversion() case bouncer.FieldFeatureflags: return m.Featureflags() + case bouncer.FieldAutoCreated: + return m.AutoCreated() } return nil, false } @@ -3277,6 +5267,8 @@ func (m *BouncerMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldOsversion(ctx) case bouncer.FieldFeatureflags: return m.OldFeatureflags(ctx) + case bouncer.FieldAutoCreated: + return m.OldAutoCreated(ctx) } return nil, fmt.Errorf("unknown Bouncer field %s", name) } @@ -3377,6 +5369,13 @@ func (m *BouncerMutation) SetField(name string, value ent.Value) error { } m.SetFeatureflags(v) return nil + case bouncer.FieldAutoCreated: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAutoCreated(v) + return nil } return fmt.Errorf("unknown Bouncer field %s", name) } @@ -3510,6 +5509,9 @@ func (m *BouncerMutation) ResetField(name string) error { case bouncer.FieldFeatureflags: m.ResetFeatureflags() return nil + case bouncer.FieldAutoCreated: + m.ResetAutoCreated() + return nil } return fmt.Errorf("unknown Bouncer field %s", name) } diff --git a/pkg/database/ent/predicate/predicate.go b/pkg/database/ent/predicate/predicate.go index 8ad03e2fc48..97e574aa167 100644 --- a/pkg/database/ent/predicate/predicate.go +++ b/pkg/database/ent/predicate/predicate.go @@ -9,6 +9,12 @@ import ( // Alert is the predicate function for alert builders. type Alert func(*sql.Selector) +// AllowList is the predicate function for allowlist builders. +type AllowList func(*sql.Selector) + +// AllowListItem is the predicate function for allowlistitem builders. +type AllowListItem func(*sql.Selector) + // Bouncer is the predicate function for bouncer builders. type Bouncer func(*sql.Selector) diff --git a/pkg/database/ent/runtime.go b/pkg/database/ent/runtime.go index 15413490633..989e67fda7d 100644 --- a/pkg/database/ent/runtime.go +++ b/pkg/database/ent/runtime.go @@ -6,6 +6,8 @@ import ( "time" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" @@ -56,6 +58,30 @@ func init() { alertDescSimulated := alertFields[21].Descriptor() // alert.DefaultSimulated holds the default value on creation for the simulated field. alert.DefaultSimulated = alertDescSimulated.Default.(bool) + allowlistFields := schema.AllowList{}.Fields() + _ = allowlistFields + // allowlistDescCreatedAt is the schema descriptor for created_at field. + allowlistDescCreatedAt := allowlistFields[0].Descriptor() + // allowlist.DefaultCreatedAt holds the default value on creation for the created_at field. + allowlist.DefaultCreatedAt = allowlistDescCreatedAt.Default.(func() time.Time) + // allowlistDescUpdatedAt is the schema descriptor for updated_at field. + allowlistDescUpdatedAt := allowlistFields[1].Descriptor() + // allowlist.DefaultUpdatedAt holds the default value on creation for the updated_at field. + allowlist.DefaultUpdatedAt = allowlistDescUpdatedAt.Default.(func() time.Time) + // allowlist.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + allowlist.UpdateDefaultUpdatedAt = allowlistDescUpdatedAt.UpdateDefault.(func() time.Time) + allowlistitemFields := schema.AllowListItem{}.Fields() + _ = allowlistitemFields + // allowlistitemDescCreatedAt is the schema descriptor for created_at field. + allowlistitemDescCreatedAt := allowlistitemFields[0].Descriptor() + // allowlistitem.DefaultCreatedAt holds the default value on creation for the created_at field. + allowlistitem.DefaultCreatedAt = allowlistitemDescCreatedAt.Default.(func() time.Time) + // allowlistitemDescUpdatedAt is the schema descriptor for updated_at field. + allowlistitemDescUpdatedAt := allowlistitemFields[1].Descriptor() + // allowlistitem.DefaultUpdatedAt holds the default value on creation for the updated_at field. + allowlistitem.DefaultUpdatedAt = allowlistitemDescUpdatedAt.Default.(func() time.Time) + // allowlistitem.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + allowlistitem.UpdateDefaultUpdatedAt = allowlistitemDescUpdatedAt.UpdateDefault.(func() time.Time) bouncerFields := schema.Bouncer{}.Fields() _ = bouncerFields // bouncerDescCreatedAt is the schema descriptor for created_at field. @@ -76,6 +102,10 @@ func init() { bouncerDescAuthType := bouncerFields[9].Descriptor() // bouncer.DefaultAuthType holds the default value on creation for the auth_type field. bouncer.DefaultAuthType = bouncerDescAuthType.Default.(string) + // bouncerDescAutoCreated is the schema descriptor for auto_created field. + bouncerDescAutoCreated := bouncerFields[13].Descriptor() + // bouncer.DefaultAutoCreated holds the default value on creation for the auto_created field. + bouncer.DefaultAutoCreated = bouncerDescAutoCreated.Default.(bool) configitemFields := schema.ConfigItem{}.Fields() _ = configitemFields // configitemDescCreatedAt is the schema descriptor for created_at field. diff --git a/pkg/database/ent/schema/allowlist.go b/pkg/database/ent/schema/allowlist.go new file mode 100644 index 00000000000..1df878ac87a --- /dev/null +++ b/pkg/database/ent/schema/allowlist.go @@ -0,0 +1,44 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" + + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +// Alert holds the schema definition for the Alert entity. +type AllowList struct { + ent.Schema +} + +// Fields of the Alert. +func (AllowList) Fields() []ent.Field { + return []ent.Field{ + field.Time("created_at"). + Default(types.UtcNow). + Immutable(), + field.Time("updated_at"). + Default(types.UtcNow). + UpdateDefault(types.UtcNow), + field.String("name").Immutable(), + field.Bool("from_console"), + field.String("description").Optional().Immutable(), + field.String("allowlist_id").Optional().Immutable(), + } +} + +func (AllowList) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("id").Unique(), + index.Fields("name").Unique(), + } +} + +func (AllowList) Edges() []ent.Edge { + return []ent.Edge{ + edge.To("allowlist_items", AllowListItem.Type), + } +} diff --git a/pkg/database/ent/schema/allowlist_item.go b/pkg/database/ent/schema/allowlist_item.go new file mode 100644 index 00000000000..907b8a44729 --- /dev/null +++ b/pkg/database/ent/schema/allowlist_item.go @@ -0,0 +1,51 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" + + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +// AllowListItem holds the schema definition for the AllowListItem entity. +type AllowListItem struct { + ent.Schema +} + +// Fields of the AllowListItem. +func (AllowListItem) Fields() []ent.Field { + return []ent.Field{ + field.Time("created_at"). + Default(types.UtcNow). + Immutable(), + field.Time("updated_at"). + Default(types.UtcNow). + UpdateDefault(types.UtcNow), + field.Time("expires_at"). + Optional(), + field.String("comment").Optional().Immutable(), + field.String("value").Immutable(), //For textual representation of the IP/range + //Use the same fields as the decision table + field.Int64("start_ip").Optional().Immutable(), + field.Int64("end_ip").Optional().Immutable(), + field.Int64("start_suffix").Optional().Immutable(), + field.Int64("end_suffix").Optional().Immutable(), + field.Int64("ip_size").Optional().Immutable(), + } +} + +func (AllowListItem) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("id"), + index.Fields("start_ip", "end_ip"), + } +} + +func (AllowListItem) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("allowlist", AllowList.Type). + Ref("allowlist_items"), + } +} diff --git a/pkg/database/ent/schema/bouncer.go b/pkg/database/ent/schema/bouncer.go index 599c4c404fc..c176bf0f766 100644 --- a/pkg/database/ent/schema/bouncer.go +++ b/pkg/database/ent/schema/bouncer.go @@ -33,6 +33,8 @@ func (Bouncer) Fields() []ent.Field { field.String("osname").Optional(), field.String("osversion").Optional(), field.String("featureflags").Optional(), + // Old auto-created TLS bouncers will have a wrong value for this field + field.Bool("auto_created").StructTag(`json:"auto_created"`).Default(false).Immutable(), } } diff --git a/pkg/database/ent/tx.go b/pkg/database/ent/tx.go index bf8221ce4a5..69983beebc5 100644 --- a/pkg/database/ent/tx.go +++ b/pkg/database/ent/tx.go @@ -14,6 +14,10 @@ type Tx struct { config // Alert is the client for interacting with the Alert builders. Alert *AlertClient + // AllowList is the client for interacting with the AllowList builders. + AllowList *AllowListClient + // AllowListItem is the client for interacting with the AllowListItem builders. + AllowListItem *AllowListItemClient // Bouncer is the client for interacting with the Bouncer builders. Bouncer *BouncerClient // ConfigItem is the client for interacting with the ConfigItem builders. @@ -162,6 +166,8 @@ func (tx *Tx) Client() *Client { func (tx *Tx) init() { tx.Alert = NewAlertClient(tx.config) + tx.AllowList = NewAllowListClient(tx.config) + tx.AllowListItem = NewAllowListItemClient(tx.config) tx.Bouncer = NewBouncerClient(tx.config) tx.ConfigItem = NewConfigItemClient(tx.config) tx.Decision = NewDecisionClient(tx.config) diff --git a/pkg/database/flush.go b/pkg/database/flush.go index 8f646ddc961..e66e47a7635 100644 --- a/pkg/database/flush.go +++ b/pkg/database/flush.go @@ -13,6 +13,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem" "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" @@ -115,6 +116,13 @@ func (c *Client) StartFlushScheduler(ctx context.Context, config *csconfig.Flush metricsJob.SingletonMode() + allowlistsJob, err := scheduler.Every(flushInterval).Do(c.flushAllowlists, ctx) + if err != nil { + return nil, fmt.Errorf("while starting FlushAllowlists scheduler: %w", err) + } + + allowlistsJob.SingletonMode() + scheduler.StartAsync() return scheduler, nil @@ -309,3 +317,17 @@ func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) e return nil } + +func (c *Client) flushAllowlists(ctx context.Context) { + deleted, err := c.Ent.AllowListItem.Delete().Where( + allowlistitem.ExpiresAtLTE(time.Now().UTC()), + ).Exec(ctx) + if err != nil { + c.Log.Errorf("while flushing allowlists: %s", err) + return + } + + if deleted > 0 { + c.Log.Debugf("flushed %d allowlists", deleted) + } +} diff --git a/pkg/exprhelpers/crowdsec_cti.go b/pkg/exprhelpers/crowdsec_cti.go index ccd67b27a49..9b9eac4b95c 100644 --- a/pkg/exprhelpers/crowdsec_cti.go +++ b/pkg/exprhelpers/crowdsec_cti.go @@ -12,16 +12,20 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var CTIUrl = "https://cti.api.crowdsec.net" -var CTIUrlSuffix = "/v2/smoke/" -var CTIApiKey = "" +var ( + CTIUrl = "https://cti.api.crowdsec.net" + CTIUrlSuffix = "/v2/smoke/" + CTIApiKey = "" +) // this is set for non-recoverable errors, such as 403 when querying API or empty API key var CTIApiEnabled = false // when hitting quotas or auth errors, we temporarily disable the API -var CTIBackOffUntil time.Time -var CTIBackOffDuration = 5 * time.Minute +var ( + CTIBackOffUntil time.Time + CTIBackOffDuration = 5 * time.Minute +) var ctiClient *cticlient.CrowdsecCTIClient @@ -62,8 +66,10 @@ func ShutdownCrowdsecCTI() { } // Cache for responses -var CTICache gcache.Cache -var CacheExpiration time.Duration +var ( + CTICache gcache.Cache + CacheExpiration time.Duration +) func CrowdsecCTIInitCache(size int, ttl time.Duration) { CTICache = gcache.New(size).LRU().Build() diff --git a/pkg/exprhelpers/geoip.go b/pkg/exprhelpers/geoip.go index fb0c344d884..6d8813dc0ad 100644 --- a/pkg/exprhelpers/geoip.go +++ b/pkg/exprhelpers/geoip.go @@ -14,7 +14,6 @@ func GeoIPEnrich(params ...any) (any, error) { parsedIP := net.ParseIP(ip) city, err := geoIPCityReader.City(parsedIP) - if err != nil { return nil, err } @@ -31,7 +30,6 @@ func GeoIPASNEnrich(params ...any) (any, error) { parsedIP := net.ParseIP(ip) asn, err := geoIPASNReader.ASN(parsedIP) - if err != nil { return nil, err } @@ -50,7 +48,6 @@ func GeoIPRangeEnrich(params ...any) (any, error) { parsedIP := net.ParseIP(ip) rangeIP, ok, err := geoIPRangeReader.LookupNetwork(parsedIP, &dummy) - if err != nil { return nil, err } diff --git a/pkg/fflag/crowdsec.go b/pkg/fflag/crowdsec.go index d42d6a05ef6..ea397bfe5bc 100644 --- a/pkg/fflag/crowdsec.go +++ b/pkg/fflag/crowdsec.go @@ -2,12 +2,14 @@ package fflag var Crowdsec = FeatureRegister{EnvPrefix: "CROWDSEC_FEATURE_"} -var CscliSetup = &Feature{Name: "cscli_setup", Description: "Enable cscli setup command (service detection)"} -var DisableHttpRetryBackoff = &Feature{Name: "disable_http_retry_backoff", Description: "Disable http retry backoff"} -var ChunkedDecisionsStream = &Feature{Name: "chunked_decisions_stream", Description: "Enable chunked decisions stream"} -var PapiClient = &Feature{Name: "papi_client", Description: "Enable Polling API client", State: DeprecatedState} -var Re2GrokSupport = &Feature{Name: "re2_grok_support", Description: "Enable RE2 support for GROK patterns"} -var Re2RegexpInfileSupport = &Feature{Name: "re2_regexp_in_file_support", Description: "Enable RE2 support for RegexpInFile expr helper"} +var ( + CscliSetup = &Feature{Name: "cscli_setup", Description: "Enable cscli setup command (service detection)"} + DisableHttpRetryBackoff = &Feature{Name: "disable_http_retry_backoff", Description: "Disable http retry backoff"} + ChunkedDecisionsStream = &Feature{Name: "chunked_decisions_stream", Description: "Enable chunked decisions stream"} + PapiClient = &Feature{Name: "papi_client", Description: "Enable Polling API client", State: DeprecatedState} + Re2GrokSupport = &Feature{Name: "re2_grok_support", Description: "Enable RE2 support for GROK patterns"} + Re2RegexpInfileSupport = &Feature{Name: "re2_regexp_in_file_support", Description: "Enable RE2 support for RegexpInFile expr helper"} +) func RegisterAllFeatures() error { err := Crowdsec.RegisterFeature(CscliSetup) diff --git a/pkg/leakybucket/blackhole.go b/pkg/leakybucket/blackhole.go index b12f169acd9..bda2e7c9ed1 100644 --- a/pkg/leakybucket/blackhole.go +++ b/pkg/leakybucket/blackhole.go @@ -49,7 +49,6 @@ func (bl *Blackhole) OnBucketOverflow(bucketFactory *BucketFactory) func(*Leaky, tmp = append(tmp, element) } else { leaky.logger.Debugf("%s left blackhole %s ago", element.key, leaky.Ovflw_ts.Sub(element.expiration)) - } } bl.hiddenKeys = tmp @@ -64,5 +63,4 @@ func (bl *Blackhole) OnBucketOverflow(bucketFactory *BucketFactory) func(*Leaky, leaky.logger.Debugf("Adding overflow to blackhole (%s)", leaky.First_ts) return alert, queue } - } diff --git a/pkg/leakybucket/bucket.go b/pkg/leakybucket/bucket.go index e981551af8f..bc81a505925 100644 --- a/pkg/leakybucket/bucket.go +++ b/pkg/leakybucket/bucket.go @@ -204,7 +204,6 @@ func FromFactory(bucketFactory BucketFactory) *Leaky { /* for now mimic a leak routine */ //LeakRoutine us the life of a bucket. It dies when the bucket underflows or overflows func LeakRoutine(leaky *Leaky) error { - var ( durationTickerChan = make(<-chan time.Time) durationTicker *time.Ticker diff --git a/pkg/leakybucket/buckets.go b/pkg/leakybucket/buckets.go index cfe8d7c302e..72948da1ad7 100644 --- a/pkg/leakybucket/buckets.go +++ b/pkg/leakybucket/buckets.go @@ -25,5 +25,4 @@ func NewBuckets() *Buckets { func GetKey(bucketCfg BucketFactory, stackkey string) string { return fmt.Sprintf("%x", sha1.Sum([]byte(bucketCfg.Filter+stackkey+bucketCfg.Name))) - } diff --git a/pkg/leakybucket/conditional.go b/pkg/leakybucket/conditional.go index a203a639743..b3a84b07c21 100644 --- a/pkg/leakybucket/conditional.go +++ b/pkg/leakybucket/conditional.go @@ -11,8 +11,10 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var conditionalExprCache map[string]vm.Program -var conditionalExprCacheLock sync.Mutex +var ( + conditionalExprCache map[string]vm.Program + conditionalExprCacheLock sync.Mutex +) type ConditionalOverflow struct { ConditionalFilter string diff --git a/pkg/leakybucket/manager_load_test.go b/pkg/leakybucket/manager_load_test.go index 513f11ff373..9d207da164e 100644 --- a/pkg/leakybucket/manager_load_test.go +++ b/pkg/leakybucket/manager_load_test.go @@ -51,93 +51,86 @@ func TestBadBucketsConfig(t *testing.T) { } func TestLeakyBucketsConfig(t *testing.T) { - var CfgTests = []cfgTest{ - //leaky with bad capacity + CfgTests := []cfgTest{ + // leaky with bad capacity {BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 0}, false, false}, - //leaky with empty leakspeed + // leaky with empty leakspeed {BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 1}, false, false}, - //leaky with missing filter + // leaky with missing filter {BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 1, LeakSpeed: "1s"}, false, true}, - //leaky with invalid leakspeed + // leaky with invalid leakspeed {BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 1, LeakSpeed: "abs", Filter: "true"}, false, false}, - //leaky with valid filter + // leaky with valid filter {BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 1, LeakSpeed: "1s", Filter: "true"}, true, true}, - //leaky with invalid filter + // leaky with invalid filter {BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 1, LeakSpeed: "1s", Filter: "xu"}, false, true}, - //leaky with valid filter + // leaky with valid filter {BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 1, LeakSpeed: "1s", Filter: "true"}, true, true}, - //leaky with bad overflow filter + // leaky with bad overflow filter {BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 1, LeakSpeed: "1s", Filter: "true", OverflowFilter: "xu"}, false, true}, } if err := runTest(CfgTests); err != nil { t.Fatalf("%s", err) } - } func TestBlackholeConfig(t *testing.T) { - var CfgTests = []cfgTest{ - //basic bh + CfgTests := []cfgTest{ + // basic bh {BucketFactory{Name: "test", Description: "test1", Type: "trigger", Filter: "true", Blackhole: "15s"}, true, true}, - //bad bh + // bad bh {BucketFactory{Name: "test", Description: "test1", Type: "trigger", Filter: "true", Blackhole: "abc"}, false, true}, } if err := runTest(CfgTests); err != nil { t.Fatalf("%s", err) } - } func TestTriggerBucketsConfig(t *testing.T) { - var CfgTests = []cfgTest{ - //basic valid counter + CfgTests := []cfgTest{ + // basic valid counter {BucketFactory{Name: "test", Description: "test1", Type: "trigger", Filter: "true"}, true, true}, } if err := runTest(CfgTests); err != nil { t.Fatalf("%s", err) } - } func TestCounterBucketsConfig(t *testing.T) { - var CfgTests = []cfgTest{ - - //basic valid counter + CfgTests := []cfgTest{ + // basic valid counter {BucketFactory{Name: "test", Description: "test1", Type: "counter", Capacity: -1, Duration: "5s", Filter: "true"}, true, true}, - //missing duration + // missing duration {BucketFactory{Name: "test", Description: "test1", Type: "counter", Capacity: -1, Filter: "true"}, false, false}, - //bad duration + // bad duration {BucketFactory{Name: "test", Description: "test1", Type: "counter", Capacity: -1, Duration: "abc", Filter: "true"}, false, false}, - //capacity must be -1 + // capacity must be -1 {BucketFactory{Name: "test", Description: "test1", Type: "counter", Capacity: 0, Duration: "5s", Filter: "true"}, false, false}, } if err := runTest(CfgTests); err != nil { t.Fatalf("%s", err) } - } func TestBayesianBucketsConfig(t *testing.T) { - var CfgTests = []cfgTest{ - - //basic valid counter + CfgTests := []cfgTest{ + // basic valid counter {BucketFactory{Name: "test", Description: "test1", Type: "bayesian", Capacity: -1, Filter: "true", BayesianPrior: 0.5, BayesianThreshold: 0.5, BayesianConditions: []RawBayesianCondition{{ConditionalFilterName: "true", ProbGivenEvil: 0.5, ProbGivenBenign: 0.5}}}, true, true}, - //bad capacity + // bad capacity {BucketFactory{Name: "test", Description: "test1", Type: "bayesian", Capacity: 1, Filter: "true", BayesianPrior: 0.5, BayesianThreshold: 0.5, BayesianConditions: []RawBayesianCondition{{ConditionalFilterName: "true", ProbGivenEvil: 0.5, ProbGivenBenign: 0.5}}}, false, false}, - //missing prior + // missing prior {BucketFactory{Name: "test", Description: "test1", Type: "bayesian", Capacity: -1, Filter: "true", BayesianThreshold: 0.5, BayesianConditions: []RawBayesianCondition{{ConditionalFilterName: "true", ProbGivenEvil: 0.5, ProbGivenBenign: 0.5}}}, false, false}, - //missing threshold + // missing threshold {BucketFactory{Name: "test", Description: "test1", Type: "bayesian", Capacity: -1, Filter: "true", BayesianPrior: 0.5, BayesianConditions: []RawBayesianCondition{{ConditionalFilterName: "true", ProbGivenEvil: 0.5, ProbGivenBenign: 0.5}}}, false, false}, - //bad prior + // bad prior {BucketFactory{Name: "test", Description: "test1", Type: "bayesian", Capacity: -1, Filter: "true", BayesianPrior: 1.5, BayesianThreshold: 0.5, BayesianConditions: []RawBayesianCondition{{ConditionalFilterName: "true", ProbGivenEvil: 0.5, ProbGivenBenign: 0.5}}}, false, false}, - //bad threshold + // bad threshold {BucketFactory{Name: "test", Description: "test1", Type: "bayesian", Capacity: -1, Filter: "true", BayesianPrior: 0.5, BayesianThreshold: 1.5, BayesianConditions: []RawBayesianCondition{{ConditionalFilterName: "true", ProbGivenEvil: 0.5, ProbGivenBenign: 0.5}}}, false, false}, } if err := runTest(CfgTests); err != nil { t.Fatalf("%s", err) } - } diff --git a/pkg/leakybucket/manager_run.go b/pkg/leakybucket/manager_run.go index 2858d8b5635..e6712e6e47e 100644 --- a/pkg/leakybucket/manager_run.go +++ b/pkg/leakybucket/manager_run.go @@ -17,9 +17,11 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var serialized map[string]Leaky -var BucketPourCache map[string][]types.Event -var BucketPourTrack bool +var ( + serialized map[string]Leaky + BucketPourCache map[string][]types.Event + BucketPourTrack bool +) /* The leaky routines lifecycle are based on "real" time. @@ -243,7 +245,6 @@ func PourItemToBucket(bucket *Leaky, holder BucketFactory, buckets *Buckets, par } func LoadOrStoreBucketFromHolder(partitionKey string, buckets *Buckets, holder BucketFactory, expectMode int) (*Leaky, error) { - biface, ok := buckets.Bucket_map.Load(partitionKey) /* the bucket doesn't exist, create it !*/ @@ -283,9 +284,7 @@ func LoadOrStoreBucketFromHolder(partitionKey string, buckets *Buckets, holder B var orderEvent map[string]*sync.WaitGroup func PourItemToHolders(parsed types.Event, holders []BucketFactory, buckets *Buckets) (bool, error) { - var ( - ok, condition, poured bool - ) + var ok, condition, poured bool if BucketPourTrack { if BucketPourCache == nil { diff --git a/pkg/leakybucket/overflows.go b/pkg/leakybucket/overflows.go index 39b0e6a0ec4..126bcd05685 100644 --- a/pkg/leakybucket/overflows.go +++ b/pkg/leakybucket/overflows.go @@ -198,22 +198,24 @@ func eventSources(evt types.Event, leaky *Leaky) (map[string]models.Source, erro func EventsFromQueue(queue *types.Queue) []*models.Event { events := []*models.Event{} - for _, evt := range queue.Queue { - if evt.Meta == nil { + qEvents := queue.GetQueue() + + for idx := range qEvents { + if qEvents[idx].Meta == nil { continue } meta := models.Meta{} // we want consistence - skeys := make([]string, 0, len(evt.Meta)) - for k := range evt.Meta { + skeys := make([]string, 0, len(qEvents[idx].Meta)) + for k := range qEvents[idx].Meta { skeys = append(skeys, k) } sort.Strings(skeys) for _, k := range skeys { - v := evt.Meta[k] + v := qEvents[idx].Meta[k] subMeta := models.MetaItems0{Key: k, Value: v} meta = append(meta, &subMeta) } @@ -223,15 +225,15 @@ func EventsFromQueue(queue *types.Queue) []*models.Event { Meta: meta, } // either MarshaledTime is present and is extracted from log - if evt.MarshaledTime != "" { - tmpTimeStamp := evt.MarshaledTime + if qEvents[idx].MarshaledTime != "" { + tmpTimeStamp := qEvents[idx].MarshaledTime ovflwEvent.Timestamp = &tmpTimeStamp - } else if !evt.Time.IsZero() { // or .Time has been set during parse as time.Now().UTC() + } else if !qEvents[idx].Time.IsZero() { // or .Time has been set during parse as time.Now().UTC() ovflwEvent.Timestamp = new(string) - raw, err := evt.Time.MarshalText() + raw, err := qEvents[idx].Time.MarshalText() if err != nil { - log.Warningf("while serializing time '%s' : %s", evt.Time.String(), err) + log.Warningf("while serializing time '%s' : %s", qEvents[idx].Time.String(), err) } else { *ovflwEvent.Timestamp = string(raw) } @@ -253,8 +255,9 @@ func alertFormatSource(leaky *Leaky, queue *types.Queue) (map[string]models.Sour log.Debugf("Formatting (%s) - scope Info : scope_type:%s / scope_filter:%s", leaky.Name, leaky.scopeType.Scope, leaky.scopeType.Filter) - for _, evt := range queue.Queue { - srcs, err := SourceFromEvent(evt, leaky) + qEvents := queue.GetQueue() + for idx := range qEvents { + srcs, err := SourceFromEvent(qEvents[idx], leaky) if err != nil { return nil, "", fmt.Errorf("while extracting scope from bucket %s: %w", leaky.Name, err) } diff --git a/pkg/leakybucket/processor.go b/pkg/leakybucket/processor.go index 81af3000c1c..dc5330a612e 100644 --- a/pkg/leakybucket/processor.go +++ b/pkg/leakybucket/processor.go @@ -10,8 +10,7 @@ type Processor interface { AfterBucketPour(Bucket *BucketFactory) func(types.Event, *Leaky) *types.Event } -type DumbProcessor struct { -} +type DumbProcessor struct{} func (d *DumbProcessor) OnBucketInit(bucketFactory *BucketFactory) error { return nil diff --git a/pkg/leakybucket/reset_filter.go b/pkg/leakybucket/reset_filter.go index 452ccc085b1..3b9b876aff4 100644 --- a/pkg/leakybucket/reset_filter.go +++ b/pkg/leakybucket/reset_filter.go @@ -23,10 +23,12 @@ type CancelOnFilter struct { Debug bool } -var cancelExprCacheLock sync.Mutex -var cancelExprCache map[string]struct { - CancelOnFilter *vm.Program -} +var ( + cancelExprCacheLock sync.Mutex + cancelExprCache map[string]struct { + CancelOnFilter *vm.Program + } +) func (u *CancelOnFilter) OnBucketPour(bucketFactory *BucketFactory) func(types.Event, *Leaky) *types.Event { return func(msg types.Event, leaky *Leaky) *types.Event { diff --git a/pkg/leakybucket/uniq.go b/pkg/leakybucket/uniq.go index 0cc0583390b..3a4683ae309 100644 --- a/pkg/leakybucket/uniq.go +++ b/pkg/leakybucket/uniq.go @@ -16,8 +16,10 @@ import ( // on overflow // on leak -var uniqExprCache map[string]vm.Program -var uniqExprCacheLock sync.Mutex +var ( + uniqExprCache map[string]vm.Program + uniqExprCacheLock sync.Mutex +) type Uniq struct { DistinctCompiled *vm.Program diff --git a/pkg/longpollclient/client.go b/pkg/longpollclient/client.go index 5a7af0bfa63..5c395185b20 100644 --- a/pkg/longpollclient/client.go +++ b/pkg/longpollclient/client.go @@ -1,6 +1,7 @@ package longpollclient import ( + "context" "encoding/json" "errors" "fmt" @@ -50,7 +51,7 @@ var errUnauthorized = errors.New("user is not authorized to use PAPI") const timeoutMessage = "no events before timeout" -func (c *LongPollClient) doQuery() (*http.Response, error) { +func (c *LongPollClient) doQuery(ctx context.Context) (*http.Response, error) { logger := c.logger.WithField("method", "doQuery") query := c.url.Query() query.Set("since_time", fmt.Sprintf("%d", c.since)) @@ -59,7 +60,7 @@ func (c *LongPollClient) doQuery() (*http.Response, error) { logger.Debugf("Query parameters: %s", c.url.RawQuery) - req, err := http.NewRequest(http.MethodGet, c.url.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url.String(), nil) if err != nil { logger.Errorf("failed to create request: %s", err) return nil, err @@ -73,10 +74,10 @@ func (c *LongPollClient) doQuery() (*http.Response, error) { return resp, nil } -func (c *LongPollClient) poll() error { +func (c *LongPollClient) poll(ctx context.Context) error { logger := c.logger.WithField("method", "poll") - resp, err := c.doQuery() + resp, err := c.doQuery(ctx) if err != nil { return err } @@ -146,7 +147,7 @@ func (c *LongPollClient) poll() error { } } -func (c *LongPollClient) pollEvents() error { +func (c *LongPollClient) pollEvents(ctx context.Context) error { for { select { case <-c.t.Dying(): @@ -154,7 +155,7 @@ func (c *LongPollClient) pollEvents() error { return nil default: c.logger.Debug("Polling PAPI") - err := c.poll() + err := c.poll(ctx) if err != nil { c.logger.Errorf("failed to poll: %s", err) if errors.Is(err, errUnauthorized) { @@ -168,12 +169,12 @@ func (c *LongPollClient) pollEvents() error { } } -func (c *LongPollClient) Start(since time.Time) chan Event { +func (c *LongPollClient) Start(ctx context.Context, since time.Time) chan Event { c.logger.Infof("starting polling client") c.c = make(chan Event) c.since = since.Unix() * 1000 c.timeout = "45" - c.t.Go(c.pollEvents) + c.t.Go(func() error {return c.pollEvents(ctx)}) return c.c } @@ -182,11 +183,11 @@ func (c *LongPollClient) Stop() error { return nil } -func (c *LongPollClient) PullOnce(since time.Time) ([]Event, error) { +func (c *LongPollClient) PullOnce(ctx context.Context, since time.Time) ([]Event, error) { c.logger.Debug("Pulling PAPI once") c.since = since.Unix() * 1000 c.timeout = "1" - resp, err := c.doQuery() + resp, err := c.doQuery(ctx) if err != nil { return nil, err } diff --git a/pkg/models/allowlist_item.go b/pkg/models/allowlist_item.go new file mode 100644 index 00000000000..3d688d52e5d --- /dev/null +++ b/pkg/models/allowlist_item.go @@ -0,0 +1,100 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// AllowlistItem AllowlistItem +// +// swagger:model AllowlistItem +type AllowlistItem struct { + + // creation date of the allowlist item + // Format: date-time + CreatedAt strfmt.DateTime `json:"created_at,omitempty"` + + // description of the allowlist item + Description string `json:"description,omitempty"` + + // expiration date of the allowlist item + // Format: date-time + Expiration strfmt.DateTime `json:"expiration,omitempty"` + + // value of the allowlist item + Value string `json:"value,omitempty"` +} + +// Validate validates this allowlist item +func (m *AllowlistItem) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateCreatedAt(formats); err != nil { + res = append(res, err) + } + + if err := m.validateExpiration(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *AllowlistItem) validateCreatedAt(formats strfmt.Registry) error { + if swag.IsZero(m.CreatedAt) { // not required + return nil + } + + if err := validate.FormatOf("created_at", "body", "date-time", m.CreatedAt.String(), formats); err != nil { + return err + } + + return nil +} + +func (m *AllowlistItem) validateExpiration(formats strfmt.Registry) error { + if swag.IsZero(m.Expiration) { // not required + return nil + } + + if err := validate.FormatOf("expiration", "body", "date-time", m.Expiration.String(), formats); err != nil { + return err + } + + return nil +} + +// ContextValidate validates this allowlist item based on context it is used +func (m *AllowlistItem) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *AllowlistItem) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *AllowlistItem) UnmarshalBinary(b []byte) error { + var res AllowlistItem + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/check_allowlist_response.go b/pkg/models/check_allowlist_response.go new file mode 100644 index 00000000000..ab57d4fea71 --- /dev/null +++ b/pkg/models/check_allowlist_response.go @@ -0,0 +1,50 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" +) + +// CheckAllowlistResponse CheckAllowlistResponse +// +// swagger:model CheckAllowlistResponse +type CheckAllowlistResponse struct { + + // true if the IP or range is in the allowlist + Allowlisted bool `json:"allowlisted,omitempty"` +} + +// Validate validates this check allowlist response +func (m *CheckAllowlistResponse) Validate(formats strfmt.Registry) error { + return nil +} + +// ContextValidate validates this check allowlist response based on context it is used +func (m *CheckAllowlistResponse) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *CheckAllowlistResponse) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *CheckAllowlistResponse) UnmarshalBinary(b []byte) error { + var res CheckAllowlistResponse + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/get_allowlist_response.go b/pkg/models/get_allowlist_response.go new file mode 100644 index 00000000000..4459457ecb3 --- /dev/null +++ b/pkg/models/get_allowlist_response.go @@ -0,0 +1,174 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + "strconv" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// GetAllowlistResponse GetAllowlistResponse +// +// swagger:model GetAllowlistResponse +type GetAllowlistResponse struct { + + // id of the allowlist + AllowlistID string `json:"allowlist_id,omitempty"` + + // true if the allowlist is managed by the console + ConsoleManaged bool `json:"console_managed,omitempty"` + + // creation date of the allowlist + // Format: date-time + CreatedAt strfmt.DateTime `json:"created_at,omitempty"` + + // description of the allowlist + Description string `json:"description,omitempty"` + + // items in the allowlist + Items []*AllowlistItem `json:"items"` + + // name of the allowlist + Name string `json:"name,omitempty"` + + // last update date of the allowlist + // Format: date-time + UpdatedAt strfmt.DateTime `json:"updated_at,omitempty"` +} + +// Validate validates this get allowlist response +func (m *GetAllowlistResponse) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateCreatedAt(formats); err != nil { + res = append(res, err) + } + + if err := m.validateItems(formats); err != nil { + res = append(res, err) + } + + if err := m.validateUpdatedAt(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *GetAllowlistResponse) validateCreatedAt(formats strfmt.Registry) error { + if swag.IsZero(m.CreatedAt) { // not required + return nil + } + + if err := validate.FormatOf("created_at", "body", "date-time", m.CreatedAt.String(), formats); err != nil { + return err + } + + return nil +} + +func (m *GetAllowlistResponse) validateItems(formats strfmt.Registry) error { + if swag.IsZero(m.Items) { // not required + return nil + } + + for i := 0; i < len(m.Items); i++ { + if swag.IsZero(m.Items[i]) { // not required + continue + } + + if m.Items[i] != nil { + if err := m.Items[i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("items" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("items" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +func (m *GetAllowlistResponse) validateUpdatedAt(formats strfmt.Registry) error { + if swag.IsZero(m.UpdatedAt) { // not required + return nil + } + + if err := validate.FormatOf("updated_at", "body", "date-time", m.UpdatedAt.String(), formats); err != nil { + return err + } + + return nil +} + +// ContextValidate validate this get allowlist response based on the context it is used +func (m *GetAllowlistResponse) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + if err := m.contextValidateItems(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *GetAllowlistResponse) contextValidateItems(ctx context.Context, formats strfmt.Registry) error { + + for i := 0; i < len(m.Items); i++ { + + if m.Items[i] != nil { + + if swag.IsZero(m.Items[i]) { // not required + return nil + } + + if err := m.Items[i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("items" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("items" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +// MarshalBinary interface implementation +func (m *GetAllowlistResponse) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *GetAllowlistResponse) UnmarshalBinary(b []byte) error { + var res GetAllowlistResponse + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/models/get_allowlists_response.go b/pkg/models/get_allowlists_response.go new file mode 100644 index 00000000000..dd6a80918c6 --- /dev/null +++ b/pkg/models/get_allowlists_response.go @@ -0,0 +1,78 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + "strconv" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" +) + +// GetAllowlistsResponse GetAllowlistsResponse +// +// swagger:model GetAllowlistsResponse +type GetAllowlistsResponse []*GetAllowlistResponse + +// Validate validates this get allowlists response +func (m GetAllowlistsResponse) Validate(formats strfmt.Registry) error { + var res []error + + for i := 0; i < len(m); i++ { + if swag.IsZero(m[i]) { // not required + continue + } + + if m[i] != nil { + if err := m[i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName(strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName(strconv.Itoa(i)) + } + return err + } + } + + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +// ContextValidate validate this get allowlists response based on the context it is used +func (m GetAllowlistsResponse) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + for i := 0; i < len(m); i++ { + + if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + + if err := m[i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName(strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName(strconv.Itoa(i)) + } + return err + } + } + + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} diff --git a/pkg/models/localapi_swagger.yaml b/pkg/models/localapi_swagger.yaml index 01bbe6f8bde..5d156e5791f 100644 --- a/pkg/models/localapi_swagger.yaml +++ b/pkg/models/localapi_swagger.yaml @@ -719,6 +719,120 @@ paths: security: - APIKeyAuthorizer: [] - JWTAuthorizer: [] + /allowlists: + get: + description: Get a list of all allowlists + summary: getAllowlists + tags: + - watchers + operationId: getAllowlists + produces: + - application/json + responses: + '200': + description: successful operation + schema: + $ref: '#/definitions/GetAllowlistsResponse' + headers: {} + /allowlists/{allowlist_name}: + get: + description: Get a specific allowlist + summary: getAllowlist + tags: + - watchers + operationId: getAllowlist + produces: + - application/json + parameters: + - name: allowlist_name + in: path + required: true + type: string + description: '' + responses: + '200': + description: successful operation + schema: + $ref: '#/definitions/GetAllowlistResponse' + headers: {} + '404': + description: "404 response" + schema: + $ref: "#/definitions/ErrorResponse" + head: + description: Get a specific allowlist + summary: getAllowlist + tags: + - watchers + operationId: headAllowlist + produces: + - application/json + parameters: + - name: allowlist_name + in: path + required: true + type: string + description: '' + - name: with_content + in: query + required: false + type: boolean + description: 'if true, the content of the allowlist will be returned as well' + responses: + '200': + description: successful operation + headers: {} + '404': + description: "404 response" + /allowlists/check/{ip_or_range}: + get: + description: Check if an IP or range is in an allowlist + summary: checkAllowlist + tags: + - watchers + operationId: checkAllowlist + produces: + - application/json + parameters: + - name: ip_or_range + in: path + required: true + type: string + description: '' + responses: + '200': + description: successful operation + schema: + $ref: '#/definitions/CheckAllowlistResponse' + headers: {} + '400': + description: "missing ip_or_range" + schema: + $ref: "#/definitions/ErrorResponse" + head: + description: Check if an IP or range is in an allowlist + summary: checkAllowlist + tags: + - watchers + operationId: headCheckAllowlist + produces: + - application/json + parameters: + - name: ip_or_range + in: path + required: true + type: string + description: '' + responses: + '200': + description: IP or range is in an allowlist + headers: {} + '204': + description: "IP or range is not in an allowlist" + '400': + description: "missing ip_or_range" + schema: + $ref: "#/definitions/ErrorResponse" definitions: WatcherRegistrationRequest: title: WatcherRegistrationRequest @@ -1220,6 +1334,65 @@ definitions: status: type: string description: status of the hub item (official, custom, tainted, etc.) + GetAllowlistsResponse: + title: GetAllowlistsResponse + type: array + items: + $ref: '#/definitions/GetAllowlistResponse' + GetAllowlistResponse: + title: GetAllowlistResponse + type: object + properties: + name: + type: string + description: name of the allowlist + allowlist_id: + type: string + description: id of the allowlist + description: + type: string + description: description of the allowlist + items: + type: array + items: + $ref: '#/definitions/AllowlistItem' + description: items in the allowlist + created_at: + type: string + format: date-time + description: creation date of the allowlist + updated_at: + type: string + format: date-time + description: last update date of the allowlist + console_managed: + type: boolean + description: true if the allowlist is managed by the console + AllowlistItem: + title: AllowlistItem + type: object + properties: + value: + type: string + description: value of the allowlist item + description: + type: string + description: description of the allowlist item + created_at: + type: string + format: date-time + description: creation date of the allowlist item + expiration: + type: string + format: date-time + description: expiration date of the allowlist item + CheckAllowlistResponse: + title: CheckAllowlistResponse + type: object + properties: + allowlisted: + type: boolean + description: 'true if the IP or range is in the allowlist' ErrorResponse: type: "object" required: diff --git a/pkg/modelscapi/allowlist_link.go b/pkg/modelscapi/allowlist_link.go new file mode 100644 index 00000000000..ce9fce17357 --- /dev/null +++ b/pkg/modelscapi/allowlist_link.go @@ -0,0 +1,166 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package modelscapi + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// AllowlistLink allowlist link +// +// swagger:model AllowlistLink +type AllowlistLink struct { + + // the creation date of the allowlist + // Required: true + // Format: date-time + CreatedAt *strfmt.DateTime `json:"created_at"` + + // the description of the allowlist + // Required: true + Description *string `json:"description"` + + // the id of the allowlist + // Required: true + ID *string `json:"id"` + + // the name of the allowlist + // Required: true + Name *string `json:"name"` + + // the last update date of the allowlist + // Required: true + // Format: date-time + UpdatedAt *strfmt.DateTime `json:"updated_at"` + + // the url from which the allowlist content can be downloaded + // Required: true + URL *string `json:"url"` +} + +// Validate validates this allowlist link +func (m *AllowlistLink) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateCreatedAt(formats); err != nil { + res = append(res, err) + } + + if err := m.validateDescription(formats); err != nil { + res = append(res, err) + } + + if err := m.validateID(formats); err != nil { + res = append(res, err) + } + + if err := m.validateName(formats); err != nil { + res = append(res, err) + } + + if err := m.validateUpdatedAt(formats); err != nil { + res = append(res, err) + } + + if err := m.validateURL(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *AllowlistLink) validateCreatedAt(formats strfmt.Registry) error { + + if err := validate.Required("created_at", "body", m.CreatedAt); err != nil { + return err + } + + if err := validate.FormatOf("created_at", "body", "date-time", m.CreatedAt.String(), formats); err != nil { + return err + } + + return nil +} + +func (m *AllowlistLink) validateDescription(formats strfmt.Registry) error { + + if err := validate.Required("description", "body", m.Description); err != nil { + return err + } + + return nil +} + +func (m *AllowlistLink) validateID(formats strfmt.Registry) error { + + if err := validate.Required("id", "body", m.ID); err != nil { + return err + } + + return nil +} + +func (m *AllowlistLink) validateName(formats strfmt.Registry) error { + + if err := validate.Required("name", "body", m.Name); err != nil { + return err + } + + return nil +} + +func (m *AllowlistLink) validateUpdatedAt(formats strfmt.Registry) error { + + if err := validate.Required("updated_at", "body", m.UpdatedAt); err != nil { + return err + } + + if err := validate.FormatOf("updated_at", "body", "date-time", m.UpdatedAt.String(), formats); err != nil { + return err + } + + return nil +} + +func (m *AllowlistLink) validateURL(formats strfmt.Registry) error { + + if err := validate.Required("url", "body", m.URL); err != nil { + return err + } + + return nil +} + +// ContextValidate validates this allowlist link based on context it is used +func (m *AllowlistLink) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *AllowlistLink) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *AllowlistLink) UnmarshalBinary(b []byte) error { + var res AllowlistLink + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/pkg/modelscapi/centralapi_swagger.yaml b/pkg/modelscapi/centralapi_swagger.yaml index bd695894f2b..6a830ee3820 100644 --- a/pkg/modelscapi/centralapi_swagger.yaml +++ b/pkg/modelscapi/centralapi_swagger.yaml @@ -55,6 +55,19 @@ paths: description: "returns list of top decisions to add or delete" produces: - "application/json" + parameters: + - in: query + name: "community_pull" + type: "boolean" + default: true + required: false + description: "Fetch the community blocklist content" + - in: query + name: "additional_pull" + type: "boolean" + default: true + required: false + description: "Fetch additional blocklists content" responses: "200": description: "200 response" @@ -535,6 +548,37 @@ definitions: description: "the scope of decisions in the blocklist" duration: type: string + AllowlistLink: + type: object + required: + - name + - description + - url + - id + - created_at + - updated_at + properties: + name: + type: string + description: "the name of the allowlist" + description: + type: string + description: "the description of the allowlist" + url: + type: string + description: "the url from which the allowlist content can be downloaded" + id: + type: string + description: "the id of the allowlist" + created_at: + type: string + format: date-time + description: "the creation date of the allowlist" + updated_at: + type: string + format: date-time + description: "the last update date of the allowlist" + AddSignalsRequestItemDecisionsItem: type: "object" required: @@ -872,4 +916,8 @@ definitions: type: array items: $ref: "#/definitions/BlocklistLink" + allowlists: + type: array + items: + $ref: "#/definitions/AllowlistLink" diff --git a/pkg/modelscapi/get_decisions_stream_response_links.go b/pkg/modelscapi/get_decisions_stream_response_links.go index 6b9054574f1..f9e320aee38 100644 --- a/pkg/modelscapi/get_decisions_stream_response_links.go +++ b/pkg/modelscapi/get_decisions_stream_response_links.go @@ -19,6 +19,9 @@ import ( // swagger:model GetDecisionsStreamResponseLinks type GetDecisionsStreamResponseLinks struct { + // allowlists + Allowlists []*AllowlistLink `json:"allowlists"` + // blocklists Blocklists []*BlocklistLink `json:"blocklists"` } @@ -27,6 +30,10 @@ type GetDecisionsStreamResponseLinks struct { func (m *GetDecisionsStreamResponseLinks) Validate(formats strfmt.Registry) error { var res []error + if err := m.validateAllowlists(formats); err != nil { + res = append(res, err) + } + if err := m.validateBlocklists(formats); err != nil { res = append(res, err) } @@ -37,6 +44,32 @@ func (m *GetDecisionsStreamResponseLinks) Validate(formats strfmt.Registry) erro return nil } +func (m *GetDecisionsStreamResponseLinks) validateAllowlists(formats strfmt.Registry) error { + if swag.IsZero(m.Allowlists) { // not required + return nil + } + + for i := 0; i < len(m.Allowlists); i++ { + if swag.IsZero(m.Allowlists[i]) { // not required + continue + } + + if m.Allowlists[i] != nil { + if err := m.Allowlists[i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("allowlists" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("allowlists" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + func (m *GetDecisionsStreamResponseLinks) validateBlocklists(formats strfmt.Registry) error { if swag.IsZero(m.Blocklists) { // not required return nil @@ -67,6 +100,10 @@ func (m *GetDecisionsStreamResponseLinks) validateBlocklists(formats strfmt.Regi func (m *GetDecisionsStreamResponseLinks) ContextValidate(ctx context.Context, formats strfmt.Registry) error { var res []error + if err := m.contextValidateAllowlists(ctx, formats); err != nil { + res = append(res, err) + } + if err := m.contextValidateBlocklists(ctx, formats); err != nil { res = append(res, err) } @@ -77,6 +114,31 @@ func (m *GetDecisionsStreamResponseLinks) ContextValidate(ctx context.Context, f return nil } +func (m *GetDecisionsStreamResponseLinks) contextValidateAllowlists(ctx context.Context, formats strfmt.Registry) error { + + for i := 0; i < len(m.Allowlists); i++ { + + if m.Allowlists[i] != nil { + + if swag.IsZero(m.Allowlists[i]) { // not required + return nil + } + + if err := m.Allowlists[i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("allowlists" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("allowlists" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + func (m *GetDecisionsStreamResponseLinks) contextValidateBlocklists(ctx context.Context, formats strfmt.Registry) error { for i := 0; i < len(m.Blocklists); i++ { diff --git a/pkg/parser/enrich.go b/pkg/parser/enrich.go index 661410d20d3..a69cd963813 100644 --- a/pkg/parser/enrich.go +++ b/pkg/parser/enrich.go @@ -7,8 +7,10 @@ import ( ) /* should be part of a package shared with enrich/geoip.go */ -type EnrichFunc func(string, *types.Event, *log.Entry) (map[string]string, error) -type InitFunc func(map[string]string) (interface{}, error) +type ( + EnrichFunc func(string, *types.Event, *log.Entry) (map[string]string, error) + InitFunc func(map[string]string) (interface{}, error) +) type EnricherCtx struct { Registered map[string]*Enricher diff --git a/pkg/parser/enrich_geoip.go b/pkg/parser/enrich_geoip.go index 1756927bc4b..79a70077283 100644 --- a/pkg/parser/enrich_geoip.go +++ b/pkg/parser/enrich_geoip.go @@ -18,7 +18,6 @@ func IpToRange(field string, p *types.Event, plog *log.Entry) (map[string]string } r, err := exprhelpers.GeoIPRangeEnrich(field) - if err != nil { plog.Errorf("Unable to enrich ip '%s'", field) return nil, nil //nolint:nilerr @@ -47,7 +46,6 @@ func GeoIpASN(field string, p *types.Event, plog *log.Entry) (map[string]string, } r, err := exprhelpers.GeoIPASNEnrich(field) - if err != nil { plog.Debugf("Unable to enrich ip '%s'", field) return nil, nil //nolint:nilerr @@ -81,7 +79,6 @@ func GeoIpCity(field string, p *types.Event, plog *log.Entry) (map[string]string } r, err := exprhelpers.GeoIPEnrich(field) - if err != nil { plog.Debugf("Unable to enrich ip '%s'", field) return nil, nil //nolint:nilerr diff --git a/pkg/parser/node.go b/pkg/parser/node.go index 26046ae4fd6..62a1ff6c4e2 100644 --- a/pkg/parser/node.go +++ b/pkg/parser/node.go @@ -3,6 +3,7 @@ package parser import ( "errors" "fmt" + "strconv" "strings" "time" @@ -236,7 +237,7 @@ func (n *Node) processGrok(p *types.Event, cachedExprEnv map[string]any) (bool, case string: gstr = out case int: - gstr = fmt.Sprintf("%d", out) + gstr = strconv.Itoa(out) case float64, float32: gstr = fmt.Sprintf("%f", out) default: @@ -357,16 +358,17 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri } // Iterate on leafs - for _, leaf := range n.LeavesNodes { - ret, err := leaf.process(p, ctx, cachedExprEnv) + leaves := n.LeavesNodes + for idx := range leaves { + ret, err := leaves[idx].process(p, ctx, cachedExprEnv) if err != nil { - clog.Tracef("\tNode (%s) failed : %v", leaf.rn, err) + clog.Tracef("\tNode (%s) failed : %v", leaves[idx].rn, err) clog.Debugf("Event leaving node : ko") return false, err } - clog.Tracef("\tsub-node (%s) ret : %v (strategy:%s)", leaf.rn, ret, n.OnSuccess) + clog.Tracef("\tsub-node (%s) ret : %v (strategy:%s)", leaves[idx].rn, ret, n.OnSuccess) if ret { NodeState = true @@ -593,7 +595,7 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { /* compile leafs if present */ for idx := range n.LeavesNodes { if n.LeavesNodes[idx].Name == "" { - n.LeavesNodes[idx].Name = fmt.Sprintf("child-%s", n.Name) + n.LeavesNodes[idx].Name = "child-" + n.Name } /*propagate debug/stats to child nodes*/ if !n.LeavesNodes[idx].Debug && n.Debug { diff --git a/pkg/parser/parsing_test.go b/pkg/parser/parsing_test.go index 269d51a1ba2..5f6f924e7df 100644 --- a/pkg/parser/parsing_test.go +++ b/pkg/parser/parsing_test.go @@ -151,7 +151,7 @@ func testOneParser(pctx *UnixParserCtx, ectx EnricherCtx, dir string, b *testing b.ResetTimer() } - for range(count) { + for range count { if !testFile(tests, *pctx, pnodes) { return errors.New("test failed") } @@ -285,7 +285,7 @@ func matchEvent(expected types.Event, out types.Event, debug bool) ([]string, bo valid = true - for mapIdx := range(len(expectMaps)) { + for mapIdx := range len(expectMaps) { for expKey, expVal := range expectMaps[mapIdx] { outVal, ok := outMaps[mapIdx][expKey] if !ok { diff --git a/pkg/parser/runtime.go b/pkg/parser/runtime.go index 8068690b68f..7af82a71535 100644 --- a/pkg/parser/runtime.go +++ b/pkg/parser/runtime.go @@ -29,10 +29,11 @@ func SetTargetByName(target string, value string, evt *types.Event) bool { return false } - //it's a hack, we do it for the user + // it's a hack, we do it for the user target = strings.TrimPrefix(target, "evt.") log.Debugf("setting target %s to %s", target, value) + defer func() { if r := recover(); r != nil { log.Errorf("Runtime error while trying to set '%s': %+v", target, r) @@ -46,6 +47,7 @@ func SetTargetByName(target string, value string, evt *types.Event) bool { //event is nil return false } + for _, f := range strings.Split(target, ".") { /* ** According to current Event layout we only have to handle struct and map @@ -57,7 +59,9 @@ func SetTargetByName(target string, value string, evt *types.Event) bool { if (tmp == reflect.Value{}) || tmp.IsZero() { log.Debugf("map entry is zero in '%s'", target) } + iter.SetMapIndex(reflect.ValueOf(f), reflect.ValueOf(value)) + return true case reflect.Struct: tmp := iter.FieldByName(f) @@ -65,9 +69,11 @@ func SetTargetByName(target string, value string, evt *types.Event) bool { log.Debugf("'%s' is not a valid target because '%s' is not valid", target, f) return false } + if tmp.Kind() == reflect.Ptr { tmp = reflect.Indirect(tmp) } + iter = tmp case reflect.Ptr: tmp := iter.Elem() @@ -82,11 +88,14 @@ func SetTargetByName(target string, value string, evt *types.Event) bool { log.Errorf("'%s' can't be set", target) return false } + if iter.Kind() != reflect.String { log.Errorf("Expected string, got %v when handling '%s'", iter.Kind(), target) return false } + iter.Set(reflect.ValueOf(value)) + return true } @@ -248,14 +257,18 @@ func stageidx(stage string, stages []string) int { return -1 } -var ParseDump bool -var DumpFolder string +var ( + ParseDump bool + DumpFolder string +) -var StageParseCache dumps.ParserResults -var StageParseMutex sync.Mutex +var ( + StageParseCache dumps.ParserResults + StageParseMutex sync.Mutex +) func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error) { - var event = xp + event := xp /* the stage is undefined, probably line is freshly acquired, set to first stage !*/ if event.Stage == "" && len(ctx.Stages) > 0 { @@ -317,46 +330,46 @@ func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error) } isStageOK := false - for idx, node := range nodes { + for idx := range nodes { //Only process current stage's nodes - if event.Stage != node.Stage { + if event.Stage != nodes[idx].Stage { continue } clog := log.WithFields(log.Fields{ - "node-name": node.rn, + "node-name": nodes[idx].rn, "stage": event.Stage, }) - clog.Tracef("Processing node %d/%d -> %s", idx, len(nodes), node.rn) + clog.Tracef("Processing node %d/%d -> %s", idx, len(nodes), nodes[idx].rn) if ctx.Profiling { - node.Profiling = true + nodes[idx].Profiling = true } - ret, err := node.process(&event, ctx, map[string]interface{}{"evt": &event}) + ret, err := nodes[idx].process(&event, ctx, map[string]interface{}{"evt": &event}) if err != nil { clog.Errorf("Error while processing node : %v", err) return event, err } - clog.Tracef("node (%s) ret : %v", node.rn, ret) + clog.Tracef("node (%s) ret : %v", nodes[idx].rn, ret) if ParseDump { var parserIdxInStage int StageParseMutex.Lock() - if len(StageParseCache[stage][node.Name]) == 0 { - StageParseCache[stage][node.Name] = make([]dumps.ParserResult, 0) + if len(StageParseCache[stage][nodes[idx].Name]) == 0 { + StageParseCache[stage][nodes[idx].Name] = make([]dumps.ParserResult, 0) parserIdxInStage = len(StageParseCache[stage]) } else { - parserIdxInStage = StageParseCache[stage][node.Name][0].Idx + parserIdxInStage = StageParseCache[stage][nodes[idx].Name][0].Idx } StageParseMutex.Unlock() evtcopy := deepcopy.Copy(event) parserInfo := dumps.ParserResult{Evt: evtcopy.(types.Event), Success: ret, Idx: parserIdxInStage} StageParseMutex.Lock() - StageParseCache[stage][node.Name] = append(StageParseCache[stage][node.Name], parserInfo) + StageParseCache[stage][nodes[idx].Name] = append(StageParseCache[stage][nodes[idx].Name], parserInfo) StageParseMutex.Unlock() } if ret { isStageOK = true } - if ret && node.OnSuccess == "next_stage" { + if ret && nodes[idx].OnSuccess == "next_stage" { clog.Debugf("node successful, stop end stage %s", stage) break } diff --git a/pkg/parser/unix_parser.go b/pkg/parser/unix_parser.go index 351de8ade56..f0f26a06645 100644 --- a/pkg/parser/unix_parser.go +++ b/pkg/parser/unix_parser.go @@ -43,7 +43,7 @@ func Init(c map[string]interface{}) (*UnixParserCtx, error) { } r.DataFolder = c["data"].(string) for _, f := range files { - if strings.Contains(f.Name(), ".") { + if strings.Contains(f.Name(), ".") || f.IsDir() { continue } if err := r.Grok.AddFromFile(filepath.Join(c["patterns"].(string), f.Name())); err != nil { diff --git a/pkg/types/appsec_event.go b/pkg/types/appsec_event.go index dc81c63b344..54163f53fef 100644 --- a/pkg/types/appsec_event.go +++ b/pkg/types/appsec_event.go @@ -18,7 +18,9 @@ len(evt.Waf.ByTagRx("*CVE*").ByConfidence("high").ByAction("block")) > 1 */ -type MatchedRules []map[string]interface{} +type MatchedRules []MatchedRule + +type MatchedRule map[string]interface{} type AppsecEvent struct { HasInBandMatches, HasOutBandMatches bool @@ -45,6 +47,10 @@ const ( Kind Field = "kind" ) +func NewMatchedRule() *MatchedRule { + return &MatchedRule{} +} + func (w AppsecEvent) GetVar(varName string) string { if w.Vars == nil { return "" @@ -54,7 +60,6 @@ func (w AppsecEvent) GetVar(varName string) string { } log.Infof("var %s not found. Available variables: %+v", varName, w.Vars) return "" - } // getters diff --git a/pkg/types/constants.go b/pkg/types/constants.go index acb5b5bfacf..2421b076b97 100644 --- a/pkg/types/constants.go +++ b/pkg/types/constants.go @@ -1,23 +1,29 @@ package types -const ApiKeyAuthType = "api-key" -const TlsAuthType = "tls" -const PasswordAuthType = "password" +const ( + ApiKeyAuthType = "api-key" + TlsAuthType = "tls" + PasswordAuthType = "password" +) -const PAPIBaseURL = "https://papi.api.crowdsec.net/" -const PAPIVersion = "v1" -const PAPIPollUrl = "/decisions/stream/poll" -const PAPIPermissionsUrl = "/permissions" +const ( + PAPIBaseURL = "https://papi.api.crowdsec.net/" + PAPIVersion = "v1" + PAPIPollUrl = "/decisions/stream/poll" + PAPIPermissionsUrl = "/permissions" +) const CAPIBaseURL = "https://api.crowdsec.net/" -const CscliOrigin = "cscli" -const CrowdSecOrigin = "crowdsec" -const ConsoleOrigin = "console" -const CscliImportOrigin = "cscli-import" -const ListOrigin = "lists" -const CAPIOrigin = "CAPI" -const CommunityBlocklistPullSourceScope = "crowdsecurity/community-blocklist" +const ( + CscliOrigin = "cscli" + CrowdSecOrigin = "crowdsec" + ConsoleOrigin = "console" + CscliImportOrigin = "cscli-import" + ListOrigin = "lists" + CAPIOrigin = "CAPI" + CommunityBlocklistPullSourceScope = "crowdsecurity/community-blocklist" +) const DecisionTypeBan = "ban" diff --git a/pkg/types/event.go b/pkg/types/event.go index e016d0294c4..0b09bf7cbdf 100644 --- a/pkg/types/event.go +++ b/pkg/types/event.go @@ -47,6 +47,23 @@ type Event struct { Meta map[string]string `yaml:"Meta,omitempty" json:"Meta,omitempty"` } +func MakeEvent(timeMachine bool, evtType int, process bool) Event { + evt := Event{ + Parsed: make(map[string]string), + Meta: make(map[string]string), + Unmarshaled: make(map[string]interface{}), + Enriched: make(map[string]string), + ExpectMode: LIVE, + Process: process, + Type: evtType, + } + if timeMachine { + evt.ExpectMode = TIMEMACHINE + } + + return evt +} + func (e *Event) SetMeta(key string, value string) bool { if e.Meta == nil { e.Meta = make(map[string]string) @@ -81,8 +98,9 @@ func (e *Event) GetType() string { func (e *Event) GetMeta(key string) string { if e.Type == OVFLW { - for _, alert := range e.Overflow.APIAlerts { - for _, event := range alert.Events { + alerts := e.Overflow.APIAlerts + for idx := range alerts { + for _, event := range alerts[idx].Events { if event.GetMeta(key) != "" { return event.GetMeta(key) } diff --git a/pkg/types/event_test.go b/pkg/types/event_test.go index 97b13f96d9a..638e42fe757 100644 --- a/pkg/types/event_test.go +++ b/pkg/types/event_test.go @@ -46,7 +46,6 @@ func TestSetParsed(t *testing.T) { assert.Equal(t, tt.value, tt.evt.Parsed[tt.key]) }) } - } func TestSetMeta(t *testing.T) { @@ -86,7 +85,6 @@ func TestSetMeta(t *testing.T) { assert.Equal(t, tt.value, tt.evt.GetMeta(tt.key)) }) } - } func TestParseIPSources(t *testing.T) { diff --git a/pkg/types/getfstype.go b/pkg/types/getfstype.go index 728e986bed0..c16fe86ec9c 100644 --- a/pkg/types/getfstype.go +++ b/pkg/types/getfstype.go @@ -100,7 +100,6 @@ func GetFSType(path string) (string, error) { var buf unix.Statfs_t err := unix.Statfs(path, &buf) - if err != nil { return "", err } diff --git a/pkg/types/ip.go b/pkg/types/ip.go index 9d08afd8809..47fb3fc83a5 100644 --- a/pkg/types/ip.go +++ b/pkg/types/ip.go @@ -23,7 +23,8 @@ func LastAddress(n net.IPNet) net.IP { ip[6] | ^n.Mask[6], ip[7] | ^n.Mask[7], ip[8] | ^n.Mask[8], ip[9] | ^n.Mask[9], ip[10] | ^n.Mask[10], ip[11] | ^n.Mask[11], ip[12] | ^n.Mask[12], ip[13] | ^n.Mask[13], ip[14] | ^n.Mask[14], - ip[15] | ^n.Mask[15]} + ip[15] | ^n.Mask[15], + } } return net.IPv4( diff --git a/pkg/types/ip_test.go b/pkg/types/ip_test.go index f8c14b12e3c..b9298ba487f 100644 --- a/pkg/types/ip_test.go +++ b/pkg/types/ip_test.go @@ -8,21 +8,20 @@ import ( ) func TestIP2Int(t *testing.T) { - tEmpty := net.IP{} _, _, _, err := IP2Ints(tEmpty) if !strings.Contains(err.Error(), "unexpected len 0 for ") { t.Fatalf("unexpected: %s", err) } } + func TestRange2Int(t *testing.T) { tEmpty := net.IPNet{} - //empty item + // empty item _, _, _, _, _, err := Range2Ints(tEmpty) if !strings.Contains(err.Error(), "converting first ip in range") { t.Fatalf("unexpected: %s", err) } - } func TestAdd2Int(t *testing.T) { diff --git a/pkg/types/utils.go b/pkg/types/utils.go index 712d44ba12d..3e1ae4f7547 100644 --- a/pkg/types/utils.go +++ b/pkg/types/utils.go @@ -10,9 +10,11 @@ import ( "gopkg.in/natefinch/lumberjack.v2" ) -var logFormatter log.Formatter -var LogOutput *lumberjack.Logger //io.Writer -var logLevel log.Level +var ( + logFormatter log.Formatter + LogOutput *lumberjack.Logger // io.Writer + logLevel log.Level +) func SetDefaultLoggerConfig(cfgMode string, cfgFolder string, cfgLevel log.Level, maxSize int, maxFiles int, maxAge int, compress *bool, forceColors bool) error { /*Configure logs*/ diff --git a/rpm/SPECS/crowdsec.spec b/rpm/SPECS/crowdsec.spec index ab71b650d11..ac438ad0c14 100644 --- a/rpm/SPECS/crowdsec.spec +++ b/rpm/SPECS/crowdsec.spec @@ -12,7 +12,7 @@ Patch0: user.patch BuildRoot: %{_tmppath}/%{name}-%{version}-%{release}-root-%(%{__id_u} -n) BuildRequires: systemd -Requires: crontabs +Requires: (crontabs or cron) %{?fc33:BuildRequires: systemd-rpm-macros} %{?fc34:BuildRequires: systemd-rpm-macros} %{?fc35:BuildRequires: systemd-rpm-macros} diff --git a/test/ansible/vagrant/fedora-40/Vagrantfile b/test/ansible/vagrant/fedora-40/Vagrantfile index ec03661fe39..5541d453acf 100644 --- a/test/ansible/vagrant/fedora-40/Vagrantfile +++ b/test/ansible/vagrant/fedora-40/Vagrantfile @@ -1,7 +1,7 @@ # frozen_string_literal: true Vagrant.configure('2') do |config| - config.vm.box = "fedora/39-cloud-base" + config.vm.box = "fedora/40-cloud-base" config.vm.provision "shell", inline: <<-SHELL SHELL end diff --git a/test/ansible/vagrant/fedora-41/Vagrantfile b/test/ansible/vagrant/fedora-41/Vagrantfile new file mode 100644 index 00000000000..3f905f51671 --- /dev/null +++ b/test/ansible/vagrant/fedora-41/Vagrantfile @@ -0,0 +1,13 @@ +# frozen_string_literal: true + +Vagrant.configure('2') do |config| + config.vm.box = "fedora/40-cloud-base" + config.vm.provision "shell", inline: <<-SHELL + SHELL + config.vm.provision "shell" do |s| + s.inline = "sudo dnf upgrade --refresh -y && sudo dnf install dnf-plugin-system-upgrade -y && sudo dnf system-upgrade download --releasever=41 -y && sudo dnf system-upgrade reboot -y" + end +end + +common = '../common' +load common if File.exist?(common) diff --git a/test/ansible/vagrant/fedora-41/skip b/test/ansible/vagrant/fedora-41/skip new file mode 100644 index 00000000000..4f1a9063d2b --- /dev/null +++ b/test/ansible/vagrant/fedora-41/skip @@ -0,0 +1,9 @@ +#!/bin/sh + +die() { + echo "$@" >&2 + exit 1 +} + +[ "${DB_BACKEND}" = "mysql" ] && die "mysql role does not support this distribution" +exit 0 diff --git a/test/ansible/vagrant/opensuse-leap-15/Vagrantfile b/test/ansible/vagrant/opensuse-leap-15/Vagrantfile new file mode 100644 index 00000000000..d10e68a50a7 --- /dev/null +++ b/test/ansible/vagrant/opensuse-leap-15/Vagrantfile @@ -0,0 +1,10 @@ +# frozen_string_literal: true + +Vagrant.configure('2') do |config| + config.vm.box = "opensuse/Leap-15.6.x86_64" + config.vm.provision "shell", inline: <<-SHELL + SHELL +end + +common = '../common' +load common if File.exist?(common) diff --git a/test/ansible/vagrant/opensuse-leap-15/skip b/test/ansible/vagrant/opensuse-leap-15/skip new file mode 100644 index 00000000000..4f1a9063d2b --- /dev/null +++ b/test/ansible/vagrant/opensuse-leap-15/skip @@ -0,0 +1,9 @@ +#!/bin/sh + +die() { + echo "$@" >&2 + exit 1 +} + +[ "${DB_BACKEND}" = "mysql" ] && die "mysql role does not support this distribution" +exit 0 diff --git a/test/bats/10_bouncers.bats b/test/bats/10_bouncers.bats index f99913dcee5..b1c90116dd2 100644 --- a/test/bats/10_bouncers.bats +++ b/test/bats/10_bouncers.bats @@ -63,7 +63,7 @@ teardown() { @test "delete non-existent bouncer" { # this is a fatal error, which is not consistent with "machines delete" rune -1 cscli bouncers delete something - assert_stderr --partial "unable to delete bouncer: 'something' does not exist" + assert_stderr --partial "unable to delete bouncer something: ent: bouncer not found" rune -0 cscli bouncers delete something --ignore-missing refute_stderr } @@ -144,3 +144,56 @@ teardown() { rune -0 cscli bouncers prune assert_output 'No bouncers to prune.' } + +curl_localhost() { + [[ -z "$API_KEY" ]] && { fail "${FUNCNAME[0]}: missing API_KEY"; } + local path=$1 + shift + curl "localhost:8080$path" -sS --fail-with-body -H "X-Api-Key: $API_KEY" "$@" +} + +# We can't use curl-with-key here, as we want to query localhost, not 127.0.0.1 +@test "multiple bouncers sharing api key" { + export API_KEY=bouncerkey + + # crowdsec needs to listen on all interfaces + rune -0 ./instance-crowdsec stop + rune -0 config_set 'del(.api.server.listen_socket) | del(.api.server.listen_uri)' + echo "{'api':{'server':{'listen_uri':0.0.0.0:8080}}}" >"${CONFIG_YAML}.local" + + rune -0 ./instance-crowdsec start + + # add a decision for our bouncers + rune -0 cscli decisions add -i '1.2.3.5' + + rune -0 cscli bouncers add test-auto -k "$API_KEY" + + # query with 127.0.0.1 as source ip + rune -0 curl_localhost "/v1/decisions/stream" -4 + rune -0 jq -r '.new' <(output) + assert_output --partial '1.2.3.5' + + # now with ::1, we should get the same IP, even though we are using the same key + rune -0 curl_localhost "/v1/decisions/stream" -6 + rune -0 jq -r '.new' <(output) + assert_output --partial '1.2.3.5' + + rune -0 cscli bouncers list -o json + rune -0 jq -c '[.[] | [.name,.revoked,.ip_address,.auto_created]]' <(output) + assert_json '[["test-auto",false,"127.0.0.1",false],["test-auto@::1",false,"::1",true]]' + + # check the 2nd bouncer was created automatically + rune -0 cscli bouncers inspect "test-auto@::1" -o json + rune -0 jq -r '.ip_address' <(output) + assert_output --partial '::1' + + # attempt to delete the auto-created bouncer, it should fail + rune -0 cscli bouncers delete 'test-auto@::1' + assert_stderr --partial 'cannot be deleted' + + # delete the "real" bouncer, it should delete both + rune -0 cscli bouncers delete 'test-auto' + + rune -0 cscli bouncers list -o json + assert_json [] +} diff --git a/test/lib/init/crowdsec-daemon b/test/lib/init/crowdsec-daemon index a232f344b6a..ba8e98992db 100755 --- a/test/lib/init/crowdsec-daemon +++ b/test/lib/init/crowdsec-daemon @@ -51,7 +51,11 @@ stop() { PGID="$(ps -o pgid= -p "$(cat "${DAEMON_PID}")" | tr -d ' ')" # ps above should work on linux, freebsd, busybox.. if [[ -n "${PGID}" ]]; then - kill -- "-${PGID}" + kill -- "-${PGID}" + + while pgrep -g "${PGID}" >/dev/null; do + sleep .05 + done fi rm -f -- "${DAEMON_PID}"