From dbb9adcd225c79b1f84a9ac0ee3fdd363d4214d2 Mon Sep 17 00:00:00 2001
From: Sebastien Blot
Date: Thu, 7 Nov 2024 11:03:26 +0100
Subject: [PATCH] Centralized allowlists support
---
.github/workflows/go-tests-windows.yml | 2 +-
.github/workflows/go-tests.yml | 2 +-
.golangci.yml | 38 +-
Makefile | 61 +-
README.md | 2 +-
azure-pipelines.yml | 2 +-
cmd/crowdsec-cli/cliallowlists/allowlists.go | 519 +++++
cmd/crowdsec-cli/clibouncer/add.go | 2 +-
cmd/crowdsec-cli/clibouncer/bouncers.go | 2 +
cmd/crowdsec-cli/clibouncer/delete.go | 62 +-
cmd/crowdsec-cli/clibouncer/inspect.go | 1 +
cmd/crowdsec-cli/clicapi/capi.go | 21 +
.../clinotifications/notifications.go | 2 +-
cmd/crowdsec-cli/clipapi/papi.go | 2 +-
cmd/crowdsec-cli/dashboard.go | 4 +-
cmd/crowdsec-cli/main.go | 5 +
cmd/crowdsec-cli/setup.go | 1 +
cmd/crowdsec/appsec.go | 2 +-
cmd/crowdsec/crowdsec.go | 12 +-
cmd/crowdsec/lapiclient.go | 4 +-
cmd/crowdsec/main.go | 4 +-
cmd/notification-email/main.go | 2 +-
go.mod | 2 +-
pkg/acquisition/acquisition.go | 24 +-
.../configuration/configuration.go | 10 +-
pkg/acquisition/http.go | 12 +
pkg/acquisition/modules/appsec/appsec.go | 147 +-
.../modules/appsec/appsec_hooks_test.go | 4 +-
.../modules/appsec/appsec_rules_test.go | 119 +-
.../modules/appsec/appsec_runner.go | 50 +-
.../modules/appsec/appsec_runner_test.go | 153 ++
pkg/acquisition/modules/appsec/appsec_test.go | 24 +-
.../modules/appsec/bodyprocessors/raw.go | 7 +-
pkg/acquisition/modules/appsec/utils.go | 168 +-
.../modules/cloudwatch/cloudwatch.go | 5 +-
pkg/acquisition/modules/docker/docker.go | 13 +-
pkg/acquisition/modules/file/file.go | 10 +-
pkg/acquisition/modules/http/http.go | 414 ++++
pkg/acquisition/modules/http/http_test.go | 784 +++++++
pkg/acquisition/modules/http/testdata/ca.crt | 23 +
.../modules/http/testdata/client.crt | 24 +
.../modules/http/testdata/client.key | 27 +
.../modules/http/testdata/server.crt | 23 +
.../modules/http/testdata/server.key | 27 +
.../modules/journalctl/journalctl.go | 9 +-
pkg/acquisition/modules/kafka/kafka.go | 9 +-
pkg/acquisition/modules/kinesis/kinesis.go | 8 +-
.../modules/kubernetesaudit/k8s_audit.go | 32 +-
.../loki/internal/lokiclient/loki_client.go | 9 +-
pkg/acquisition/modules/loki/loki.go | 61 +-
pkg/acquisition/modules/loki/loki_test.go | 23 +-
pkg/acquisition/modules/s3/s3.go | 8 +-
.../syslog/internal/parser/rfc3164/parse.go | 1 -
.../syslog/internal/parser/rfc5424/parse.go | 3 -
.../internal/parser/rfc5424/parse_test.go | 58 +-
.../syslog/internal/server/syslogserver.go | 1 -
pkg/acquisition/modules/syslog/syslog.go | 8 +-
.../wineventlog/wineventlog_windows.go | 14 +-
pkg/alertcontext/alertcontext.go | 152 +-
pkg/alertcontext/alertcontext_test.go | 161 ++
pkg/apiclient/allowlists_service.go | 91 +
pkg/apiclient/auth_jwt.go | 3 -
pkg/apiclient/client.go | 69 +
pkg/apiclient/decisions_service.go | 13 +
pkg/apiclient/decisions_service_test.go | 27 +-
pkg/apiserver/alerts_test.go | 41 +-
pkg/apiserver/allowlists_test.go | 119 +
pkg/apiserver/api_key_test.go | 58 +-
pkg/apiserver/apic.go | 164 +-
pkg/apiserver/apic_metrics.go | 8 +-
pkg/apiserver/apic_test.go | 32 +-
pkg/apiserver/apiserver.go | 28 +-
pkg/apiserver/apiserver_test.go | 20 +-
pkg/apiserver/controllers/controller.go | 5 +
pkg/apiserver/controllers/v1/alerts.go | 21 +-
pkg/apiserver/controllers/v1/allowlist.go | 126 +
pkg/apiserver/decisions_test.go | 16 +-
pkg/apiserver/jwt_test.go | 10 +-
pkg/apiserver/machines_test.go | 6 +-
pkg/apiserver/middlewares/v1/api_key.go | 79 +-
pkg/apiserver/papi.go | 6 +-
pkg/apiserver/papi_cmd.go | 97 +-
pkg/appsec/appsec.go | 73 +-
pkg/appsec/appsec_rule/appsec_rule.go | 1 -
pkg/appsec/appsec_rule/modsec_rule_test.go | 2 -
pkg/appsec/coraza_logger.go | 2 +-
pkg/appsec/request_test.go | 3 -
pkg/cache/cache_test.go | 13 +-
pkg/csconfig/api.go | 24 +-
pkg/csconfig/api_test.go | 5 +
pkg/csplugin/broker.go | 4 +-
pkg/cticlient/types.go | 2 -
pkg/cwversion/component/component.go | 29 +-
pkg/database/allowlists.go | 255 +++
pkg/database/allowlists_test.go | 99 +
pkg/database/bouncers.go | 19 +-
pkg/database/ent/allowlist.go | 189 ++
pkg/database/ent/allowlist/allowlist.go | 133 ++
pkg/database/ent/allowlist/where.go | 429 ++++
pkg/database/ent/allowlist_create.go | 321 +++
pkg/database/ent/allowlist_delete.go | 88 +
pkg/database/ent/allowlist_query.go | 636 ++++++
pkg/database/ent/allowlist_update.go | 421 ++++
pkg/database/ent/allowlistitem.go | 231 ++
.../ent/allowlistitem/allowlistitem.go | 165 ++
pkg/database/ent/allowlistitem/where.go | 664 ++++++
pkg/database/ent/allowlistitem_create.go | 398 ++++
pkg/database/ent/allowlistitem_delete.go | 88 +
pkg/database/ent/allowlistitem_query.go | 636 ++++++
pkg/database/ent/allowlistitem_update.go | 463 ++++
pkg/database/ent/bouncer.go | 13 +-
pkg/database/ent/bouncer/bouncer.go | 10 +
pkg/database/ent/bouncer/where.go | 15 +
pkg/database/ent/bouncer_create.go | 25 +
pkg/database/ent/client.go | 374 ++-
pkg/database/ent/ent.go | 22 +-
pkg/database/ent/hook/hook.go | 24 +
pkg/database/ent/migrate/schema.go | 91 +
pkg/database/ent/mutation.go | 2022 ++++++++++++++++-
pkg/database/ent/predicate/predicate.go | 6 +
pkg/database/ent/runtime.go | 30 +
pkg/database/ent/schema/allowlist.go | 44 +
pkg/database/ent/schema/allowlist_item.go | 51 +
pkg/database/ent/schema/bouncer.go | 2 +
pkg/database/ent/tx.go | 6 +
pkg/database/flush.go | 22 +
pkg/exprhelpers/crowdsec_cti.go | 20 +-
pkg/exprhelpers/geoip.go | 3 -
pkg/fflag/crowdsec.go | 14 +-
pkg/leakybucket/blackhole.go | 2 -
pkg/leakybucket/bucket.go | 1 -
pkg/leakybucket/buckets.go | 1 -
pkg/leakybucket/conditional.go | 6 +-
pkg/leakybucket/manager_load_test.go | 59 +-
pkg/leakybucket/manager_run.go | 13 +-
pkg/leakybucket/overflows.go | 27 +-
pkg/leakybucket/processor.go | 3 +-
pkg/leakybucket/reset_filter.go | 10 +-
pkg/leakybucket/uniq.go | 6 +-
pkg/longpollclient/client.go | 21 +-
pkg/models/allowlist_item.go | 100 +
pkg/models/check_allowlist_response.go | 50 +
pkg/models/get_allowlist_response.go | 174 ++
pkg/models/get_allowlists_response.go | 78 +
pkg/models/localapi_swagger.yaml | 173 ++
pkg/modelscapi/allowlist_link.go | 166 ++
pkg/modelscapi/centralapi_swagger.yaml | 48 +
.../get_decisions_stream_response_links.go | 62 +
pkg/parser/enrich.go | 6 +-
pkg/parser/enrich_geoip.go | 3 -
pkg/parser/node.go | 14 +-
pkg/parser/parsing_test.go | 4 +-
pkg/parser/runtime.go | 49 +-
pkg/parser/unix_parser.go | 2 +-
pkg/types/appsec_event.go | 9 +-
pkg/types/constants.go | 34 +-
pkg/types/event.go | 22 +-
pkg/types/event_test.go | 2 -
pkg/types/getfstype.go | 1 -
pkg/types/ip.go | 3 +-
pkg/types/ip_test.go | 5 +-
pkg/types/utils.go | 8 +-
rpm/SPECS/crowdsec.spec | 2 +-
test/ansible/vagrant/fedora-40/Vagrantfile | 2 +-
test/ansible/vagrant/fedora-41/Vagrantfile | 13 +
test/ansible/vagrant/fedora-41/skip | 9 +
.../vagrant/opensuse-leap-15/Vagrantfile | 10 +
test/ansible/vagrant/opensuse-leap-15/skip | 9 +
test/bats/10_bouncers.bats | 55 +-
test/lib/init/crowdsec-daemon | 6 +-
170 files changed, 13064 insertions(+), 809 deletions(-)
create mode 100644 cmd/crowdsec-cli/cliallowlists/allowlists.go
create mode 100644 pkg/acquisition/http.go
create mode 100644 pkg/acquisition/modules/appsec/appsec_runner_test.go
create mode 100644 pkg/acquisition/modules/http/http.go
create mode 100644 pkg/acquisition/modules/http/http_test.go
create mode 100644 pkg/acquisition/modules/http/testdata/ca.crt
create mode 100644 pkg/acquisition/modules/http/testdata/client.crt
create mode 100644 pkg/acquisition/modules/http/testdata/client.key
create mode 100644 pkg/acquisition/modules/http/testdata/server.crt
create mode 100644 pkg/acquisition/modules/http/testdata/server.key
create mode 100644 pkg/apiclient/allowlists_service.go
create mode 100644 pkg/apiserver/allowlists_test.go
create mode 100644 pkg/apiserver/controllers/v1/allowlist.go
create mode 100644 pkg/database/allowlists.go
create mode 100644 pkg/database/allowlists_test.go
create mode 100644 pkg/database/ent/allowlist.go
create mode 100644 pkg/database/ent/allowlist/allowlist.go
create mode 100644 pkg/database/ent/allowlist/where.go
create mode 100644 pkg/database/ent/allowlist_create.go
create mode 100644 pkg/database/ent/allowlist_delete.go
create mode 100644 pkg/database/ent/allowlist_query.go
create mode 100644 pkg/database/ent/allowlist_update.go
create mode 100644 pkg/database/ent/allowlistitem.go
create mode 100644 pkg/database/ent/allowlistitem/allowlistitem.go
create mode 100644 pkg/database/ent/allowlistitem/where.go
create mode 100644 pkg/database/ent/allowlistitem_create.go
create mode 100644 pkg/database/ent/allowlistitem_delete.go
create mode 100644 pkg/database/ent/allowlistitem_query.go
create mode 100644 pkg/database/ent/allowlistitem_update.go
create mode 100644 pkg/database/ent/schema/allowlist.go
create mode 100644 pkg/database/ent/schema/allowlist_item.go
create mode 100644 pkg/models/allowlist_item.go
create mode 100644 pkg/models/check_allowlist_response.go
create mode 100644 pkg/models/get_allowlist_response.go
create mode 100644 pkg/models/get_allowlists_response.go
create mode 100644 pkg/modelscapi/allowlist_link.go
create mode 100644 test/ansible/vagrant/fedora-41/Vagrantfile
create mode 100644 test/ansible/vagrant/fedora-41/skip
create mode 100644 test/ansible/vagrant/opensuse-leap-15/Vagrantfile
create mode 100644 test/ansible/vagrant/opensuse-leap-15/skip
diff --git a/.github/workflows/go-tests-windows.yml b/.github/workflows/go-tests-windows.yml
index 2966b999a4a..3276dbb1bfd 100644
--- a/.github/workflows/go-tests-windows.yml
+++ b/.github/workflows/go-tests-windows.yml
@@ -61,6 +61,6 @@ jobs:
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:
- version: v1.61
+ version: v1.62
args: --issues-exit-code=1 --timeout 10m
only-new-issues: false
diff --git a/.github/workflows/go-tests.yml b/.github/workflows/go-tests.yml
index 3f4aa67e139..3638696b4f6 100644
--- a/.github/workflows/go-tests.yml
+++ b/.github/workflows/go-tests.yml
@@ -190,6 +190,6 @@ jobs:
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:
- version: v1.61
+ version: v1.62
args: --issues-exit-code=1 --timeout 10m
only-new-issues: false
diff --git a/.golangci.yml b/.golangci.yml
index 271e3a57d34..7217c6da2b1 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -183,7 +183,6 @@ linters-settings:
- ifElseChain
- importShadow
- hugeParam
- - rangeValCopy
- commentedOutCode
- commentedOutImport
- unnamedResult
@@ -211,9 +210,7 @@ linters:
#
# DEPRECATED by golangi-lint
#
- - execinquery
- exportloopref
- - gomnd
#
# Redundant
@@ -348,10 +345,6 @@ issues:
- errorlint
text: "type switch on error will fail on wrapped errors. Use errors.As to check for specific errors"
- - linters:
- - errorlint
- text: "comparing with .* will fail on wrapped errors. Use errors.Is to check for a specific error"
-
- linters:
- nosprintfhostport
text: "host:port in url should be constructed with net.JoinHostPort and not directly with fmt.Sprintf"
@@ -460,3 +453,34 @@ issues:
- revive
path: "cmd/crowdsec/win_service.go"
text: "deep-exit: .*"
+
+ - linters:
+ - recvcheck
+ path: "pkg/csplugin/hclog_adapter.go"
+ text: 'the methods of "HCLogAdapter" use pointer receiver and non-pointer receiver.'
+
+ # encoding to json/yaml requires value receivers
+ - linters:
+ - recvcheck
+ path: "pkg/cwhub/item.go"
+ text: 'the methods of "Item" use pointer receiver and non-pointer receiver.'
+
+ - linters:
+ - gocritic
+ path: "cmd/crowdsec-cli"
+ text: "rangeValCopy: .*"
+
+ - linters:
+ - gocritic
+ path: "pkg/(cticlient|hubtest)"
+ text: "rangeValCopy: .*"
+
+ - linters:
+ - gocritic
+ path: "(.+)_test.go"
+ text: "rangeValCopy: .*"
+
+ - linters:
+ - gocritic
+ path: "pkg/(appsec|acquisition|dumps|alertcontext|leakybucket|exprhelpers)"
+ text: "rangeValCopy: .*"
diff --git a/Makefile b/Makefile
index 29a84d5b066..f8ae66e1cb6 100644
--- a/Makefile
+++ b/Makefile
@@ -22,7 +22,7 @@ BUILD_RE2_WASM ?= 0
# for your distribution (look for libre2.a). See the Dockerfile for an example of how to build it.
BUILD_STATIC ?= 0
-# List of plugins to build
+# List of notification plugins to build
PLUGINS ?= $(patsubst ./cmd/notification-%,%,$(wildcard ./cmd/notification-*))
#--------------------------------------
@@ -80,9 +80,17 @@ endif
#expr_debug tag is required to enable the debug mode in expr
GO_TAGS := netgo,osusergo,sqlite_omit_load_extension,expr_debug
+# Allow building on ubuntu 24.10, see https://github.com/golang/go/issues/70023
+export CGO_LDFLAGS_ALLOW=-Wl,--(push|pop)-state.*
+
# this will be used by Go in the make target, some distributions require it
export PKG_CONFIG_PATH:=/usr/local/lib/pkgconfig:$(PKG_CONFIG_PATH)
+#--------------------------------------
+#
+# Choose the re2 backend.
+#
+
ifeq ($(call bool,$(BUILD_RE2_WASM)),0)
ifeq ($(PKG_CONFIG),)
$(error "pkg-config is not available. Please install pkg-config.")
@@ -90,35 +98,28 @@ endif
ifeq ($(RE2_CHECK),)
RE2_FAIL := "libre2-dev is not installed, please install it or set BUILD_RE2_WASM=1 to use the WebAssembly version"
+# if you prefer to build WASM instead of a critical error, comment out RE2_FAIL and uncomment RE2_MSG.
+# RE2_MSG := Fallback to WebAssembly regexp library. To use the C++ version, make sure you have installed libre2-dev and pkg-config.
else
# += adds a space that we don't want
GO_TAGS := $(GO_TAGS),re2_cgo
LD_OPTS_VARS += -X '$(GO_MODULE_NAME)/pkg/cwversion.Libre2=C++'
+RE2_MSG := Using C++ regexp library
endif
-endif
-
-# Build static to avoid the runtime dependency on libre2.so
-ifeq ($(call bool,$(BUILD_STATIC)),1)
-BUILD_TYPE = static
-EXTLDFLAGS := -extldflags '-static'
else
-BUILD_TYPE = dynamic
-EXTLDFLAGS :=
+RE2_MSG := Using WebAssembly regexp library
endif
-# Build with debug symbols, and disable optimizations + inlining, to use Delve
-ifeq ($(call bool,$(DEBUG)),1)
-STRIP_SYMBOLS :=
-DISABLE_OPTIMIZATION := -gcflags "-N -l"
+ifeq ($(call bool,$(BUILD_RE2_WASM)),1)
else
-STRIP_SYMBOLS := -s
-DISABLE_OPTIMIZATION :=
+ifneq (,$(RE2_CHECK))
+endif
endif
#--------------------------------------
-
+#
# Handle optional components and build profiles, to save space on the final binaries.
-
+#
# Keep it safe for now until we decide how to expand on the idea. Either choose a profile or exclude components manually.
# For example if we want to disable some component by default, or have opt-in components (INCLUDE?).
@@ -131,6 +132,7 @@ COMPONENTS := \
datasource_cloudwatch \
datasource_docker \
datasource_file \
+ datasource_http \
datasource_k8saudit \
datasource_kafka \
datasource_journalctl \
@@ -178,6 +180,23 @@ endif
#--------------------------------------
+ifeq ($(call bool,$(BUILD_STATIC)),1)
+BUILD_TYPE = static
+EXTLDFLAGS := -extldflags '-static'
+else
+BUILD_TYPE = dynamic
+EXTLDFLAGS :=
+endif
+
+# Build with debug symbols, and disable optimizations + inlining, to use Delve
+ifeq ($(call bool,$(DEBUG)),1)
+STRIP_SYMBOLS :=
+DISABLE_OPTIMIZATION := -gcflags "-N -l"
+else
+STRIP_SYMBOLS := -s
+DISABLE_OPTIMIZATION :=
+endif
+
export LD_OPTS=-ldflags "$(STRIP_SYMBOLS) $(EXTLDFLAGS) $(LD_OPTS_VARS)" \
-trimpath -tags $(GO_TAGS) $(DISABLE_OPTIMIZATION)
@@ -193,17 +212,13 @@ build: build-info crowdsec cscli plugins ## Build crowdsec, cscli and plugins
.PHONY: build-info
build-info: ## Print build information
$(info Building $(BUILD_VERSION) ($(BUILD_TAG)) $(BUILD_TYPE) for $(GOOS)/$(GOARCH))
- $(info Excluded components: $(EXCLUDE_LIST))
+ $(info Excluded components: $(if $(EXCLUDE_LIST),$(EXCLUDE_LIST),none))
ifneq (,$(RE2_FAIL))
$(error $(RE2_FAIL))
endif
-ifneq (,$(RE2_CHECK))
- $(info Using C++ regexp library)
-else
- $(info Fallback to WebAssembly regexp library. To use the C++ version, make sure you have installed libre2-dev and pkg-config.)
-endif
+ $(info $(RE2_MSG))
ifeq ($(call bool,$(DEBUG)),1)
$(info Building with debug symbols and disabled optimizations)
diff --git a/README.md b/README.md
index a900f0ee514..1e57d4e91c4 100644
--- a/README.md
+++ b/README.md
@@ -84,7 +84,7 @@ The architecture is as follows :
-Once an unwanted behavior is detected, deal with it through a [bouncer](https://hub.crowdsec.net/browse/#bouncers). The aggressive IP, scenario triggered and timestamp are sent for curation, to avoid poisoning & false positives. (This can be disabled). If verified, this IP is then redistributed to all CrowdSec users running the same scenario.
+Once an unwanted behavior is detected, deal with it through a [bouncer](https://app.crowdsec.net/hub/remediation-components). The aggressive IP, scenario triggered and timestamp are sent for curation, to avoid poisoning & false positives. (This can be disabled). If verified, this IP is then redistributed to all CrowdSec users running the same scenario.
## Outnumbering hackers all together
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index acbcabc20c5..bcf327bdf38 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -21,7 +21,7 @@ stages:
- task: GoTool@0
displayName: "Install Go"
inputs:
- version: '1.23'
+ version: '1.23.3'
- pwsh: |
choco install -y make
diff --git a/cmd/crowdsec-cli/cliallowlists/allowlists.go b/cmd/crowdsec-cli/cliallowlists/allowlists.go
new file mode 100644
index 00000000000..9660a6ec817
--- /dev/null
+++ b/cmd/crowdsec-cli/cliallowlists/allowlists.go
@@ -0,0 +1,519 @@
+package cliallowlists
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/url"
+ "slices"
+ "strings"
+ "time"
+
+ "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable"
+ "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
+ "github.com/crowdsecurity/crowdsec/pkg/apiclient"
+ "github.com/crowdsecurity/crowdsec/pkg/csconfig"
+ "github.com/crowdsecurity/crowdsec/pkg/database"
+ "github.com/crowdsecurity/crowdsec/pkg/models"
+ "github.com/fatih/color"
+ "github.com/go-openapi/strfmt"
+ "github.com/jedib0t/go-pretty/v6/table"
+ log "github.com/sirupsen/logrus"
+ "github.com/spf13/cobra"
+)
+
+type configGetter = func() *csconfig.Config
+
+type cliAllowLists struct {
+ db *database.Client
+ client *apiclient.ApiClient
+ cfg configGetter
+}
+
+func New(cfg configGetter) *cliAllowLists {
+ return &cliAllowLists{
+ cfg: cfg,
+ }
+}
+
+// validAllowlists returns a list of valid allowlists name for command completion
+func (cli *cliAllowLists) validAllowlists(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
+ var err error
+
+ cfg := cli.cfg()
+ ctx := cmd.Context()
+
+ // need to load config and db because PersistentPreRunE is not called for completions
+
+ if err = require.LAPI(cfg); err != nil {
+ cobra.CompError("unable to load LAPI " + err.Error())
+ return nil, cobra.ShellCompDirectiveNoFileComp
+ }
+
+ cli.db, err = require.DBClient(ctx, cfg.DbConfig)
+ if err != nil {
+ cobra.CompError("unable to load dbclient " + err.Error())
+ return nil, cobra.ShellCompDirectiveNoFileComp
+ }
+
+ allowlists, err := cli.db.ListAllowLists(ctx, false)
+ if err != nil {
+ cobra.CompError("unable to list allowlists " + err.Error())
+ return nil, cobra.ShellCompDirectiveNoFileComp
+ }
+
+ ret := []string{}
+
+ for _, allowlist := range allowlists {
+ if strings.Contains(allowlist.Name, toComplete) && !slices.Contains(args, allowlist.Name) {
+ ret = append(ret, allowlist.Name)
+ }
+ }
+
+ return ret, cobra.ShellCompDirectiveNoFileComp
+}
+
+func (cli *cliAllowLists) listHuman(out io.Writer, allowlists *models.GetAllowlistsResponse) {
+ t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer
+ t.AppendHeader(table.Row{"Name", "Description", "Creation Date", "Updated at", "Managed by Console", "Size"})
+
+ for _, allowlist := range *allowlists {
+ t.AppendRow(table.Row{allowlist.Name, allowlist.Description, allowlist.CreatedAt, allowlist.UpdatedAt, allowlist.ConsoleManaged, len(allowlist.Items)})
+ }
+
+ io.WriteString(out, t.Render()+"\n")
+}
+
+func (cli *cliAllowLists) listContentHuman(out io.Writer, allowlist *models.GetAllowlistResponse) {
+ t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer
+ t.AppendHeader(table.Row{"Value", "Comment", "Expiration", "Created at"})
+
+ for _, content := range allowlist.Items {
+ expiration := "never"
+ if !time.Time(content.Expiration).IsZero() {
+ expiration = content.Expiration.String()
+ }
+ t.AppendRow(table.Row{content.Value, content.Description, expiration, allowlist.CreatedAt})
+ }
+
+ io.WriteString(out, t.Render()+"\n")
+}
+
+func (cli *cliAllowLists) NewCommand() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "allowlists [action]",
+ Short: "Manage centralized allowlists",
+ Args: cobra.MinimumNArgs(1),
+ DisableAutoGenTag: true,
+ }
+
+ cmd.AddCommand(cli.newCreateCmd())
+ cmd.AddCommand(cli.newListCmd())
+ cmd.AddCommand(cli.newDeleteCmd())
+ cmd.AddCommand(cli.newAddCmd())
+ cmd.AddCommand(cli.newRemoveCmd())
+ cmd.AddCommand(cli.newInspectCmd())
+ return cmd
+}
+
+func (cli *cliAllowLists) newCreateCmd() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "create [allowlist_name]",
+ Example: "cscli allowlists create my_allowlist -d 'my allowlist description'",
+ Short: "Create a new allowlist",
+ Args: cobra.ExactArgs(1),
+ PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
+ var err error
+ cfg := cli.cfg()
+ if err = require.LAPI(cfg); err != nil {
+ return err
+ }
+ cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig)
+ if err != nil {
+ return err
+ }
+ return nil
+ },
+ RunE: cli.create,
+ }
+
+ flags := cmd.Flags()
+
+ flags.StringP("description", "d", "", "description of the allowlist")
+
+ cmd.MarkFlagRequired("description")
+
+ return cmd
+}
+
+func (cli *cliAllowLists) create(cmd *cobra.Command, args []string) error {
+ name := args[0]
+ description := cmd.Flag("description").Value.String()
+
+ _, err := cli.db.CreateAllowList(cmd.Context(), name, description, false)
+
+ if err != nil {
+ return err
+ }
+
+ log.Infof("allowlist '%s' created successfully", name)
+
+ return nil
+}
+
+func (cli *cliAllowLists) newListCmd() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "list",
+ Example: `cscli allowlists list`,
+ Short: "List all allowlists",
+ Args: cobra.NoArgs,
+ PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
+ cfg := cli.cfg()
+ if err := cfg.LoadAPIClient(); err != nil {
+ return fmt.Errorf("loading api client: %w", err)
+ }
+ apiURL, err := url.Parse(cfg.API.Client.Credentials.URL)
+ if err != nil {
+ return fmt.Errorf("parsing api url: %w", err)
+ }
+
+ cli.client, err = apiclient.NewClient(&apiclient.Config{
+ MachineID: cfg.API.Client.Credentials.Login,
+ Password: strfmt.Password(cfg.API.Client.Credentials.Password),
+ URL: apiURL,
+ VersionPrefix: "v1",
+ })
+ if err != nil {
+ return fmt.Errorf("creating api client: %w", err)
+ }
+
+ return nil
+ },
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ return cli.list(cmd, color.Output)
+ },
+ }
+
+ return cmd
+}
+
+func (cli *cliAllowLists) list(cmd *cobra.Command, out io.Writer) error {
+ allowlists, _, err := cli.client.Allowlists.List(cmd.Context(), apiclient.AllowlistListOpts{WithContent: true})
+ if err != nil {
+ return err
+ }
+
+ switch cli.cfg().Cscli.Output {
+ case "human":
+ cli.listHuman(out, allowlists)
+ case "json":
+ enc := json.NewEncoder(out)
+ enc.SetIndent("", " ")
+
+ if err := enc.Encode(allowlists); err != nil {
+ return errors.New("failed to serialize")
+ }
+
+ return nil
+ case "raw":
+ //return cli.listCSV(out, allowlists)
+ }
+
+ return nil
+}
+
+func (cli *cliAllowLists) newDeleteCmd() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "delete [allowlist_name]",
+ Short: "Delete an allowlist",
+ Example: `cscli allowlists delete my_allowlist`,
+ Args: cobra.ExactArgs(1),
+ PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
+ var err error
+ cfg := cli.cfg()
+
+ if err = require.LAPI(cfg); err != nil {
+ return err
+ }
+ cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig)
+ if err != nil {
+ return err
+ }
+ return nil
+ },
+ RunE: cli.delete,
+ }
+
+ return cmd
+}
+
+func (cli *cliAllowLists) delete(cmd *cobra.Command, args []string) error {
+ name := args[0]
+ list, err := cli.db.GetAllowList(cmd.Context(), name, false)
+
+ if err != nil {
+ return err
+ }
+
+ if list == nil {
+ return fmt.Errorf("allowlist %s not found", name)
+ }
+
+ if list.FromConsole {
+ return fmt.Errorf("allowlist %s is managed by console, cannot delete with cscli", name)
+ }
+
+ err = cli.db.DeleteAllowList(cmd.Context(), name, false)
+ if err != nil {
+ return err
+ }
+
+ log.Infof("allowlist '%s' deleted successfully", name)
+
+ return nil
+}
+
+func (cli *cliAllowLists) newAddCmd() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "add [allowlist_name] --value [value] [-e expiration] [-d comment]",
+ Short: "Add content to an allowlist",
+ Example: `cscli allowlists add my_allowlist --value 1.2.3.4 --value 2.3.4.5 -e 1h -d "my comment"`,
+ Args: cobra.ExactArgs(1),
+ PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
+ var err error
+ cfg := cli.cfg()
+
+ if err = require.LAPI(cfg); err != nil {
+ return err
+ }
+
+ cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig)
+ if err != nil {
+ return err
+ }
+ return nil
+ },
+ RunE: cli.add,
+ }
+
+ flags := cmd.Flags()
+
+ flags.StringSliceP("value", "v", nil, "value to add to the allowlist")
+ flags.StringP("expiration", "e", "", "expiration duration")
+ flags.StringP("comment", "d", "", "comment for the value")
+
+ cmd.MarkFlagRequired("value")
+
+ return cmd
+}
+
+func (cli *cliAllowLists) add(cmd *cobra.Command, args []string) error {
+
+ var expiration time.Duration
+
+ name := args[0]
+ values, err := cmd.Flags().GetStringSlice("value")
+ comment := cmd.Flag("comment").Value.String()
+
+ if err != nil {
+ return err
+ }
+
+ expirationStr := cmd.Flag("expiration").Value.String()
+
+ if expirationStr != "" {
+ //FIXME: handle days (and maybe more ?)
+ expiration, err = time.ParseDuration(expirationStr)
+
+ if err != nil {
+ return err
+ }
+ }
+
+ allowlist, err := cli.db.GetAllowList(cmd.Context(), name, true)
+
+ if err != nil {
+ return fmt.Errorf("unable to get allowlist: %w", err)
+ }
+
+ if allowlist.FromConsole {
+ return fmt.Errorf("allowlist %s is managed by console, cannot update with cscli", name)
+ }
+
+ toAdd := make([]*models.AllowlistItem, 0)
+
+ for _, v := range values {
+ found := false
+ for _, item := range allowlist.Edges.AllowlistItems {
+ if item.Value == v {
+ found = true
+ log.Warnf("value %s already in allowlist", v)
+ break
+ }
+ }
+ if !found {
+ toAdd = append(toAdd, &models.AllowlistItem{Value: v, Description: comment, Expiration: strfmt.DateTime(time.Now().UTC().Add(expiration))})
+ }
+ }
+
+ if len(toAdd) == 0 {
+ log.Warn("no value to add to allowlist")
+ return nil
+ }
+
+ log.Debugf("adding %d values to allowlist %s", len(toAdd), name)
+
+ err = cli.db.AddToAllowlist(cmd.Context(), allowlist, toAdd)
+
+ if err != nil {
+ return fmt.Errorf("unable to add values to allowlist: %w", err)
+ }
+
+ return nil
+}
+
+func (cli *cliAllowLists) newInspectCmd() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "inspect [allowlist_name]",
+ Example: `cscli allowlists inspect my_allowlist`,
+ Short: "Inspect an allowlist",
+ Args: cobra.ExactArgs(1),
+ PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
+ cfg := cli.cfg()
+ if err := cfg.LoadAPIClient(); err != nil {
+ return fmt.Errorf("loading api client: %w", err)
+ }
+ apiURL, err := url.Parse(cfg.API.Client.Credentials.URL)
+ if err != nil {
+ return fmt.Errorf("parsing api url: %w", err)
+ }
+
+ cli.client, err = apiclient.NewClient(&apiclient.Config{
+ MachineID: cfg.API.Client.Credentials.Login,
+ Password: strfmt.Password(cfg.API.Client.Credentials.Password),
+ URL: apiURL,
+ VersionPrefix: "v1",
+ })
+ if err != nil {
+ return fmt.Errorf("creating api client: %w", err)
+ }
+
+ return nil
+ },
+ RunE: func(cmd *cobra.Command, args []string) error {
+ return cli.inspect(cmd, args, color.Output)
+ },
+ }
+
+ return cmd
+}
+
+func (cli *cliAllowLists) inspect(cmd *cobra.Command, args []string, out io.Writer) error {
+ name := args[0]
+ allowlist, _, err := cli.client.Allowlists.Get(cmd.Context(), name, apiclient.AllowlistGetOpts{WithContent: true})
+
+ if err != nil {
+ return fmt.Errorf("unable to get allowlist: %w", err)
+ }
+
+ switch cli.cfg().Cscli.Output {
+ case "human":
+ cli.listContentHuman(out, allowlist)
+ case "json":
+ enc := json.NewEncoder(out)
+ enc.SetIndent("", " ")
+
+ if err := enc.Encode(allowlist); err != nil {
+ return errors.New("failed to serialize")
+ }
+
+ return nil
+ case "raw":
+ //return cli.listCSV(out, allowlists)
+ }
+
+ return nil
+}
+
+func (cli *cliAllowLists) newRemoveCmd() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "remove [allowlist_name] --value [value]",
+ Short: "Remove content from an allowlist",
+ Example: `cscli allowlists remove my_allowlist --value 1.2.3.4 --value 2.3.4.5`,
+ Args: cobra.ExactArgs(1),
+ PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
+ var err error
+ cfg := cli.cfg()
+
+ if err = require.LAPI(cfg); err != nil {
+ return err
+ }
+
+ cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig)
+ if err != nil {
+ return err
+ }
+ return nil
+ },
+ RunE: cli.remove,
+ }
+
+ flags := cmd.Flags()
+
+ flags.StringSliceP("value", "v", nil, "value to remove from the allowlist")
+
+ cmd.MarkFlagRequired("value")
+
+ return cmd
+}
+
+func (cli *cliAllowLists) remove(cmd *cobra.Command, args []string) error {
+ name := args[0]
+ values, err := cmd.Flags().GetStringSlice("value")
+
+ if err != nil {
+ return err
+ }
+
+ allowlist, err := cli.db.GetAllowList(cmd.Context(), name, true)
+
+ if err != nil {
+ return fmt.Errorf("unable to get allowlist: %w", err)
+ }
+
+ if allowlist.FromConsole {
+ return fmt.Errorf("allowlist %s is managed by console, cannot update with cscli", name)
+ }
+
+ toRemove := make([]string, 0)
+
+ for _, v := range values {
+ found := false
+ for _, item := range allowlist.Edges.AllowlistItems {
+ if item.Value == v {
+ found = true
+ break
+ }
+ }
+ if found {
+ toRemove = append(toRemove, v)
+ }
+ }
+
+ if len(toRemove) == 0 {
+ log.Warn("no value to remove from allowlist")
+ }
+
+ log.Debugf("removing %d values from allowlist %s", len(toRemove), name)
+
+ nbDeleted, err := cli.db.RemoveFromAllowlist(cmd.Context(), allowlist, toRemove...)
+
+ if err != nil {
+ return fmt.Errorf("unable to remove values from allowlist: %w", err)
+ }
+
+ log.Infof("removed %d values from allowlist %s", nbDeleted, name)
+
+ return nil
+}
diff --git a/cmd/crowdsec-cli/clibouncer/add.go b/cmd/crowdsec-cli/clibouncer/add.go
index 8c40507a996..7cc74e45fba 100644
--- a/cmd/crowdsec-cli/clibouncer/add.go
+++ b/cmd/crowdsec-cli/clibouncer/add.go
@@ -24,7 +24,7 @@ func (cli *cliBouncers) add(ctx context.Context, bouncerName string, key string)
}
}
- _, err = cli.db.CreateBouncer(ctx, bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType)
+ _, err = cli.db.CreateBouncer(ctx, bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType, false)
if err != nil {
return fmt.Errorf("unable to create bouncer: %w", err)
}
diff --git a/cmd/crowdsec-cli/clibouncer/bouncers.go b/cmd/crowdsec-cli/clibouncer/bouncers.go
index 876b613be53..2b0a3556873 100644
--- a/cmd/crowdsec-cli/clibouncer/bouncers.go
+++ b/cmd/crowdsec-cli/clibouncer/bouncers.go
@@ -77,6 +77,7 @@ type bouncerInfo struct {
AuthType string `json:"auth_type"`
OS string `json:"os,omitempty"`
Featureflags []string `json:"featureflags,omitempty"`
+ AutoCreated bool `json:"auto_created"`
}
func newBouncerInfo(b *ent.Bouncer) bouncerInfo {
@@ -92,6 +93,7 @@ func newBouncerInfo(b *ent.Bouncer) bouncerInfo {
AuthType: b.AuthType,
OS: clientinfo.GetOSNameAndVersion(b),
Featureflags: clientinfo.GetFeatureFlagList(b),
+ AutoCreated: b.AutoCreated,
}
}
diff --git a/cmd/crowdsec-cli/clibouncer/delete.go b/cmd/crowdsec-cli/clibouncer/delete.go
index 6e2f312d4af..33419f483b6 100644
--- a/cmd/crowdsec-cli/clibouncer/delete.go
+++ b/cmd/crowdsec-cli/clibouncer/delete.go
@@ -4,25 +4,73 @@ import (
"context"
"errors"
"fmt"
+ "strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
- "github.com/crowdsecurity/crowdsec/pkg/database"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent"
+ "github.com/crowdsecurity/crowdsec/pkg/types"
)
+func (cli *cliBouncers) findParentBouncer(bouncerName string, bouncers []*ent.Bouncer) (string, error) {
+ bouncerPrefix := strings.Split(bouncerName, "@")[0]
+ for _, bouncer := range bouncers {
+		if bouncer.Name == bouncerPrefix && !bouncer.AutoCreated {
+ return bouncer.Name, nil
+ }
+ }
+
+ return "", errors.New("no parent bouncer found")
+}
+
func (cli *cliBouncers) delete(ctx context.Context, bouncers []string, ignoreMissing bool) error {
- for _, bouncerID := range bouncers {
- if err := cli.db.DeleteBouncer(ctx, bouncerID); err != nil {
- var notFoundErr *database.BouncerNotFoundError
+ allBouncers, err := cli.db.ListBouncers(ctx)
+ if err != nil {
+ return fmt.Errorf("unable to list bouncers: %w", err)
+ }
+ for _, bouncerName := range bouncers {
+ bouncer, err := cli.db.SelectBouncerByName(ctx, bouncerName)
+ if err != nil {
+ var notFoundErr *ent.NotFoundError
 			if ignoreMissing && errors.As(err, &notFoundErr) {
- return nil
+ continue
}
+ return fmt.Errorf("unable to delete bouncer %s: %w", bouncerName, err)
+ }
+
+ // For TLS bouncers, always delete them, they have no parents
+ if bouncer.AuthType == types.TlsAuthType {
+ if err := cli.db.DeleteBouncer(ctx, bouncerName); err != nil {
+ return fmt.Errorf("unable to delete bouncer %s: %w", bouncerName, err)
+ }
+ continue
+ }
+
+ if bouncer.AutoCreated {
+ parentBouncer, err := cli.findParentBouncer(bouncerName, allBouncers)
+ if err != nil {
+				log.Errorf("bouncer '%s' is auto-created, but couldn't find a parent bouncer: %s", bouncerName, err)
+ continue
+ }
+ log.Warnf("bouncer '%s' is auto-created and cannot be deleted, delete parent bouncer %s instead", bouncerName, parentBouncer)
+ continue
+ }
+		// Try to find all child bouncers and delete them
+ for _, childBouncer := range allBouncers {
+ if strings.HasPrefix(childBouncer.Name, bouncerName+"@") && childBouncer.AutoCreated {
+ if err := cli.db.DeleteBouncer(ctx, childBouncer.Name); err != nil {
+ return fmt.Errorf("unable to delete bouncer %s: %w", childBouncer.Name, err)
+ }
+ log.Infof("bouncer '%s' deleted successfully", childBouncer.Name)
+ }
+ }
- return fmt.Errorf("unable to delete bouncer: %w", err)
+ if err := cli.db.DeleteBouncer(ctx, bouncerName); err != nil {
+ return fmt.Errorf("unable to delete bouncer %s: %w", bouncerName, err)
}
- log.Infof("bouncer '%s' deleted successfully", bouncerID)
+ log.Infof("bouncer '%s' deleted successfully", bouncerName)
}
return nil
diff --git a/cmd/crowdsec-cli/clibouncer/inspect.go b/cmd/crowdsec-cli/clibouncer/inspect.go
index 6dac386b888..b62344baa9b 100644
--- a/cmd/crowdsec-cli/clibouncer/inspect.go
+++ b/cmd/crowdsec-cli/clibouncer/inspect.go
@@ -40,6 +40,7 @@ func (cli *cliBouncers) inspectHuman(out io.Writer, bouncer *ent.Bouncer) {
{"Last Pull", lastPull},
{"Auth type", bouncer.AuthType},
{"OS", clientinfo.GetOSNameAndVersion(bouncer)},
+ {"Auto Created", bouncer.AutoCreated},
})
for _, ff := range clientinfo.GetFeatureFlagList(bouncer) {
diff --git a/cmd/crowdsec-cli/clicapi/capi.go b/cmd/crowdsec-cli/clicapi/capi.go
index cba66f11104..61d59836fdd 100644
--- a/cmd/crowdsec-cli/clicapi/capi.go
+++ b/cmd/crowdsec-cli/clicapi/capi.go
@@ -225,6 +225,27 @@ func (cli *cliCapi) Status(ctx context.Context, out io.Writer, hub *cwhub.Hub) e
fmt.Fprint(out, "Your instance is enrolled in the console\n")
}
+ switch *cfg.API.Server.OnlineClient.Sharing {
+ case true:
+ fmt.Fprint(out, "Sharing signals is enabled\n")
+ case false:
+ fmt.Fprint(out, "Sharing signals is disabled\n")
+ }
+
+ switch *cfg.API.Server.OnlineClient.PullConfig.Community {
+ case true:
+ fmt.Fprint(out, "Pulling community blocklist is enabled\n")
+ case false:
+ fmt.Fprint(out, "Pulling community blocklist is disabled\n")
+ }
+
+ switch *cfg.API.Server.OnlineClient.PullConfig.Blocklists {
+ case true:
+ fmt.Fprint(out, "Pulling blocklists from the console is enabled\n")
+ case false:
+ fmt.Fprint(out, "Pulling blocklists from the console is disabled\n")
+ }
+
return nil
}
diff --git a/cmd/crowdsec-cli/clinotifications/notifications.go b/cmd/crowdsec-cli/clinotifications/notifications.go
index baf899c10cf..80ffebeaa23 100644
--- a/cmd/crowdsec-cli/clinotifications/notifications.go
+++ b/cmd/crowdsec-cli/clinotifications/notifications.go
@@ -260,7 +260,7 @@ func (cli *cliNotifications) notificationConfigFilter(cmd *cobra.Command, args [
return ret, cobra.ShellCompDirectiveNoFileComp
}
-func (cli cliNotifications) newTestCmd() *cobra.Command {
+func (cli *cliNotifications) newTestCmd() *cobra.Command {
var (
pluginBroker csplugin.PluginBroker
pluginTomb tomb.Tomb
diff --git a/cmd/crowdsec-cli/clipapi/papi.go b/cmd/crowdsec-cli/clipapi/papi.go
index 461215c3a39..7ac2455d28f 100644
--- a/cmd/crowdsec-cli/clipapi/papi.go
+++ b/cmd/crowdsec-cli/clipapi/papi.go
@@ -136,7 +136,7 @@ func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client
t.Go(papi.SyncDecisions)
- err = papi.PullOnce(time.Time{}, true)
+ err = papi.PullOnce(ctx, time.Time{}, true)
if err != nil {
return fmt.Errorf("unable to sync decisions: %w", err)
}
diff --git a/cmd/crowdsec-cli/dashboard.go b/cmd/crowdsec-cli/dashboard.go
index 53a7dff85a0..7ddac093dcd 100644
--- a/cmd/crowdsec-cli/dashboard.go
+++ b/cmd/crowdsec-cli/dashboard.go
@@ -243,7 +243,8 @@ func (cli *cliDashboard) newStopCmd() *cobra.Command {
}
func (cli *cliDashboard) newShowPasswordCmd() *cobra.Command {
- cmd := &cobra.Command{Use: "show-password",
+ cmd := &cobra.Command{
+ Use: "show-password",
Short: "displays password of metabase.",
Args: cobra.NoArgs,
DisableAutoGenTag: true,
@@ -457,7 +458,6 @@ func checkGroups(forceYes *bool) (*user.Group, error) {
func (cli *cliDashboard) chownDatabase(gid string) error {
cfg := cli.cfg()
intID, err := strconv.Atoi(gid)
-
if err != nil {
return fmt.Errorf("unable to convert group ID to int: %s", err)
}
diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go
index 1cca03b1d3d..bd12901d580 100644
--- a/cmd/crowdsec-cli/main.go
+++ b/cmd/crowdsec-cli/main.go
@@ -12,9 +12,11 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
+ "github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/go-cs-lib/trace"
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clialert"
+ "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cliallowlists"
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clibouncer"
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clicapi"
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cliconsole"
@@ -163,6 +165,8 @@ func (cli *cliRoot) initialize() error {
}
}
+ csConfig.DbConfig.LogLevel = ptr.Of(cli.wantedLogLevel())
+
return nil
}
@@ -279,6 +283,7 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall
cmd.AddCommand(cliitem.NewContext(cli.cfg).NewCommand())
cmd.AddCommand(cliitem.NewAppsecConfig(cli.cfg).NewCommand())
cmd.AddCommand(cliitem.NewAppsecRule(cli.cfg).NewCommand())
+ cmd.AddCommand(cliallowlists.New(cli.cfg).NewCommand())
cli.addSetup(cmd)
diff --git a/cmd/crowdsec-cli/setup.go b/cmd/crowdsec-cli/setup.go
index 66c0d71e777..3581d69f052 100644
--- a/cmd/crowdsec-cli/setup.go
+++ b/cmd/crowdsec-cli/setup.go
@@ -1,4 +1,5 @@
//go:build !no_cscli_setup
+
package main
import (
diff --git a/cmd/crowdsec/appsec.go b/cmd/crowdsec/appsec.go
index cb02b137dcd..4320133b063 100644
--- a/cmd/crowdsec/appsec.go
+++ b/cmd/crowdsec/appsec.go
@@ -1,4 +1,4 @@
-// +build !no_datasource_appsec
+//go:build !no_datasource_appsec
package main
diff --git a/cmd/crowdsec/crowdsec.go b/cmd/crowdsec/crowdsec.go
index c44d71d2093..4df66d5a773 100644
--- a/cmd/crowdsec/crowdsec.go
+++ b/cmd/crowdsec/crowdsec.go
@@ -14,6 +14,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/acquisition"
"github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration"
"github.com/crowdsecurity/crowdsec/pkg/alertcontext"
+ "github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/cwhub"
"github.com/crowdsecurity/crowdsec/pkg/exprhelpers"
@@ -51,6 +52,15 @@ func initCrowdsec(cConfig *csconfig.Config, hub *cwhub.Hub) (*parser.Parsers, []
return nil, nil, err
}
+ err = apiclient.InitLAPIClient(
+ context.TODO(), cConfig.API.Client.Credentials.URL, cConfig.API.Client.Credentials.PapiURL,
+ cConfig.API.Client.Credentials.Login, cConfig.API.Client.Credentials.Password,
+ hub.GetInstalledListForAPI())
+
+ if err != nil {
+ return nil, nil, fmt.Errorf("while initializing LAPIClient: %w", err)
+ }
+
datasources, err := LoadAcquisition(cConfig)
if err != nil {
return nil, nil, fmt.Errorf("while loading acquisition config: %w", err)
@@ -116,7 +126,7 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H
})
bucketWg.Wait()
- apiClient, err := AuthenticatedLAPIClient(*cConfig.API.Client.Credentials, hub)
+ apiClient, err := apiclient.GetLAPIClient()
if err != nil {
return err
}
diff --git a/cmd/crowdsec/lapiclient.go b/cmd/crowdsec/lapiclient.go
index eed517f9df9..6656ba6b4c2 100644
--- a/cmd/crowdsec/lapiclient.go
+++ b/cmd/crowdsec/lapiclient.go
@@ -14,7 +14,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/models"
)
-func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) {
+func AuthenticatedLAPIClient(ctx context.Context, credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) {
apiURL, err := url.Parse(credentials.URL)
if err != nil {
return nil, fmt.Errorf("parsing api url ('%s'): %w", credentials.URL, err)
@@ -44,7 +44,7 @@ func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.
return nil, fmt.Errorf("new client api: %w", err)
}
- authResp, _, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
+ authResp, _, err := client.Auth.AuthenticateWatcher(ctx, models.WatcherAuthRequest{
MachineID: &credentials.Login,
Password: &password,
Scenarios: itemsForAPI,
diff --git a/cmd/crowdsec/main.go b/cmd/crowdsec/main.go
index 6d8ca24c335..e414f59f3e2 100644
--- a/cmd/crowdsec/main.go
+++ b/cmd/crowdsec/main.go
@@ -148,14 +148,14 @@ func (l *labelsMap) String() string {
return "labels"
}
-func (l labelsMap) Set(label string) error {
+func (l *labelsMap) Set(label string) error {
for _, pair := range strings.Split(label, ",") {
split := strings.Split(pair, ":")
if len(split) != 2 {
return fmt.Errorf("invalid format for label '%s', must be key:value", pair)
}
- l[split[0]] = split[1]
+ (*l)[split[0]] = split[1]
}
return nil
diff --git a/cmd/notification-email/main.go b/cmd/notification-email/main.go
index 5fc02cdd1d7..b61644611b4 100644
--- a/cmd/notification-email/main.go
+++ b/cmd/notification-email/main.go
@@ -68,7 +68,7 @@ func (n *EmailPlugin) Configure(ctx context.Context, config *protobufs.Config) (
EncryptionType: "ssltls",
AuthType: "login",
SenderEmail: "crowdsec@crowdsec.local",
- HeloHost: "localhost",
+ HeloHost: "localhost",
}
if err := yaml.Unmarshal(config.Config, &d); err != nil {
diff --git a/go.mod b/go.mod
index c889b62cb8c..f4bd9379a2d 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module github.com/crowdsecurity/crowdsec
-go 1.23
+go 1.23.3
// Don't use the toolchain directive to avoid uncontrolled downloads during
// a build, especially in sandboxed environments (freebsd, gentoo...).
diff --git a/pkg/acquisition/acquisition.go b/pkg/acquisition/acquisition.go
index 1ad385105d3..ef5a413b91f 100644
--- a/pkg/acquisition/acquisition.go
+++ b/pkg/acquisition/acquisition.go
@@ -337,6 +337,20 @@ func GetMetrics(sources []DataSource, aggregated bool) error {
return nil
}
+// There's no need for an actual deep copy
+// The event is almost empty, we are mostly interested in allocating new maps for Parsed/Meta/...
+func copyEvent(evt types.Event, line string) types.Event {
+ evtCopy := types.MakeEvent(evt.ExpectMode == types.TIMEMACHINE, evt.Type, evt.Process)
+ evtCopy.Line = evt.Line
+ evtCopy.Line.Raw = line
+ evtCopy.Line.Labels = make(map[string]string)
+ for k, v := range evt.Line.Labels {
+ evtCopy.Line.Labels[k] = v
+ }
+
+ return evtCopy
+}
+
func transform(transformChan chan types.Event, output chan types.Event, AcquisTomb *tomb.Tomb, transformRuntime *vm.Program, logger *log.Entry) {
defer trace.CatchPanic("crowdsec/acquis")
logger.Infof("transformer started")
@@ -363,8 +377,7 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo
switch v := out.(type) {
case string:
logger.Tracef("transform expression returned %s", v)
- evt.Line.Raw = v
- output <- evt
+ output <- copyEvent(evt, v)
case []interface{}:
logger.Tracef("transform expression returned %v", v) //nolint:asasalint // We actually want to log the slice content
@@ -373,19 +386,16 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo
if !ok {
logger.Errorf("transform expression returned []interface{}, but cannot assert an element to string")
output <- evt
-
continue
}
- evt.Line.Raw = l
- output <- evt
+ output <- copyEvent(evt, l)
}
case []string:
logger.Tracef("transform expression returned %v", v)
for _, line := range v {
- evt.Line.Raw = line
- output <- evt
+ output <- copyEvent(evt, line)
}
default:
logger.Errorf("transform expression returned an invalid type %T, sending event as-is", out)
diff --git a/pkg/acquisition/configuration/configuration.go b/pkg/acquisition/configuration/configuration.go
index 3e27da1b9e6..a9d570d2788 100644
--- a/pkg/acquisition/configuration/configuration.go
+++ b/pkg/acquisition/configuration/configuration.go
@@ -13,12 +13,14 @@ type DataSourceCommonCfg struct {
UseTimeMachine bool `yaml:"use_time_machine,omitempty"`
UniqueId string `yaml:"unique_id,omitempty"`
TransformExpr string `yaml:"transform,omitempty"`
- Config map[string]interface{} `yaml:",inline"` //to keep the datasource-specific configuration directives
+ Config map[string]interface{} `yaml:",inline"` // to keep the datasource-specific configuration directives
}
-var TAIL_MODE = "tail"
-var CAT_MODE = "cat"
-var SERVER_MODE = "server" // No difference with tail, just a bit more verbose
+var (
+ TAIL_MODE = "tail"
+ CAT_MODE = "cat"
+ SERVER_MODE = "server" // No difference with tail, just a bit more verbose
+)
const (
METRICS_NONE = iota
diff --git a/pkg/acquisition/http.go b/pkg/acquisition/http.go
new file mode 100644
index 00000000000..59745772b62
--- /dev/null
+++ b/pkg/acquisition/http.go
@@ -0,0 +1,12 @@
+//go:build !no_datasource_http
+
+package acquisition
+
+import (
+ httpacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/http"
+)
+
+//nolint:gochecknoinits
+func init() {
+ registerDataSource("http", func() DataSource { return &httpacquisition.HTTPSource{} })
+}
diff --git a/pkg/acquisition/modules/appsec/appsec.go b/pkg/acquisition/modules/appsec/appsec.go
index a6dcffe89a2..ed6c85ea95e 100644
--- a/pkg/acquisition/modules/appsec/appsec.go
+++ b/pkg/acquisition/modules/appsec/appsec.go
@@ -20,6 +20,7 @@ import (
"github.com/crowdsecurity/go-cs-lib/trace"
"github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration"
+ "github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/appsec"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/types"
@@ -31,6 +32,8 @@ const (
)
var DefaultAuthCacheDuration = (1 * time.Minute)
+var negativeAllowlistCacheDuration = (5 * time.Minute)
+var positiveAllowlistCacheDuration = (5 * time.Minute)
// configuration structure of the acquis for the application security engine
type AppsecSourceConfig struct {
@@ -41,6 +44,7 @@ type AppsecSourceConfig struct {
Path string `yaml:"path"`
Routines int `yaml:"routines"`
AppsecConfig string `yaml:"appsec_config"`
+ AppsecConfigs []string `yaml:"appsec_configs"`
AppsecConfigPath string `yaml:"appsec_config_path"`
AuthCacheDuration *time.Duration `yaml:"auth_cache_duration"`
configuration.DataSourceCommonCfg `yaml:",inline"`
@@ -48,18 +52,20 @@ type AppsecSourceConfig struct {
// runtime structure of AppsecSourceConfig
type AppsecSource struct {
- metricsLevel int
- config AppsecSourceConfig
- logger *log.Entry
- mux *http.ServeMux
- server *http.Server
- outChan chan types.Event
- InChan chan appsec.ParsedRequest
- AppsecRuntime *appsec.AppsecRuntimeConfig
- AppsecConfigs map[string]appsec.AppsecConfig
- lapiURL string
- AuthCache AuthCache
- AppsecRunners []AppsecRunner // one for each go-routine
+ metricsLevel int
+ config AppsecSourceConfig
+ logger *log.Entry
+ mux *http.ServeMux
+ server *http.Server
+ outChan chan types.Event
+ InChan chan appsec.ParsedRequest
+ AppsecRuntime *appsec.AppsecRuntimeConfig
+ AppsecConfigs map[string]appsec.AppsecConfig
+ lapiURL string
+ AuthCache AuthCache
+ AppsecRunners []AppsecRunner // one for each go-routine
+ allowlistCache allowlistCache
+ apiClient *apiclient.ApiClient
}
// Struct to handle cache of authentication
@@ -68,6 +74,17 @@ type AuthCache struct {
mu sync.RWMutex
}
+// FIXME: auth and allowlist should probably be merged to a common structure
+type allowlistCache struct {
+ mu sync.RWMutex
+ allowlist map[string]allowlistCacheEntry
+}
+
+type allowlistCacheEntry struct {
+ allowlisted bool
+ expiration time.Time
+}
+
func NewAuthCache() AuthCache {
return AuthCache{
APIKeys: make(map[string]time.Time, 0),
@@ -85,9 +102,34 @@ func (ac *AuthCache) Get(apiKey string) (time.Time, bool) {
ac.mu.RLock()
expiration, exists := ac.APIKeys[apiKey]
ac.mu.RUnlock()
+
return expiration, exists
}
+func NewAllowlistCache() allowlistCache {
+ return allowlistCache{
+ allowlist: make(map[string]allowlistCacheEntry, 0),
+ mu: sync.RWMutex{},
+ }
+}
+
+func (ac *allowlistCache) Set(value string, allowlisted bool, expiration time.Time) {
+ ac.mu.Lock()
+ ac.allowlist[value] = allowlistCacheEntry{
+ allowlisted: allowlisted,
+ expiration: expiration,
+ }
+ ac.mu.Unlock()
+}
+
+func (ac *allowlistCache) Get(value string) (bool, time.Time, bool) {
+ ac.mu.RLock()
+ entry, exists := ac.allowlist[value]
+ ac.mu.RUnlock()
+
+ return entry.allowlisted, entry.expiration, exists
+}
+
// @tko + @sbl : we might want to get rid of that or improve it
type BodyResponse struct {
Action string `json:"action"`
@@ -120,14 +162,19 @@ func (w *AppsecSource) UnmarshalConfig(yamlConfig []byte) error {
w.config.Routines = 1
}
- if w.config.AppsecConfig == "" && w.config.AppsecConfigPath == "" {
+ if w.config.AppsecConfig == "" && w.config.AppsecConfigPath == "" && len(w.config.AppsecConfigs) == 0 {
return errors.New("appsec_config or appsec_config_path must be set")
}
+ if (w.config.AppsecConfig != "" || w.config.AppsecConfigPath != "") && len(w.config.AppsecConfigs) != 0 {
+ return errors.New("appsec_config and appsec_config_path are mutually exclusive with appsec_configs")
+ }
+
if w.config.Name == "" {
if w.config.ListenSocket != "" && w.config.ListenAddr == "" {
w.config.Name = w.config.ListenSocket
}
+
if w.config.ListenSocket == "" {
w.config.Name = fmt.Sprintf("%s%s", w.config.ListenAddr, w.config.Path)
}
@@ -153,6 +200,7 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe
if err != nil {
return fmt.Errorf("unable to parse appsec configuration: %w", err)
}
+
w.logger = logger
w.metricsLevel = MetricsLevel
w.logger.Tracef("Appsec configuration: %+v", w.config)
@@ -172,6 +220,9 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe
w.InChan = make(chan appsec.ParsedRequest)
appsecCfg := appsec.AppsecConfig{Logger: w.logger.WithField("component", "appsec_config")}
+ //we keep the datasource name
+ appsecCfg.Name = w.config.Name
+
// let's load the associated appsec_config:
if w.config.AppsecConfigPath != "" {
err := appsecCfg.LoadByPath(w.config.AppsecConfigPath)
@@ -183,10 +234,20 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe
if err != nil {
return fmt.Errorf("unable to load appsec_config: %w", err)
}
+ } else if len(w.config.AppsecConfigs) > 0 {
+ for _, appsecConfig := range w.config.AppsecConfigs {
+ err := appsecCfg.Load(appsecConfig)
+ if err != nil {
+ return fmt.Errorf("unable to load appsec_config: %w", err)
+ }
+ }
} else {
return errors.New("no appsec_config provided")
}
+ // Now we can set up the logger
+ appsecCfg.SetUpLogger()
+
w.AppsecRuntime, err = appsecCfg.Build()
if err != nil {
return fmt.Errorf("unable to build appsec_config: %w", err)
@@ -211,10 +272,12 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe
AppsecRuntime: &wrt,
Labels: w.config.Labels,
}
+
err := runner.Init(appsecCfg.GetDataDir())
if err != nil {
return fmt.Errorf("unable to initialize runner: %w", err)
}
+
w.AppsecRunners[nbRoutine] = runner
}
@@ -222,6 +285,13 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe
// We donĀ“t use the wrapper provided by coraza because we want to fully control what happens when a rule match to send the information in crowdsec
w.mux.HandleFunc(w.config.Path, w.appsecHandler)
+
+ w.apiClient, err = apiclient.GetLAPIClient()
+ if err != nil {
+ return fmt.Errorf("unable to get authenticated LAPI client: %w", err)
+ }
+ w.allowlistCache = NewAllowlistCache()
+
return nil
}
@@ -243,10 +313,12 @@ func (w *AppsecSource) OneShotAcquisition(_ context.Context, _ chan types.Event,
func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
w.outChan = out
+
t.Go(func() error {
defer trace.CatchPanic("crowdsec/acquis/appsec/live")
w.logger.Infof("%d appsec runner to start", len(w.AppsecRunners))
+
for _, runner := range w.AppsecRunners {
runner.outChan = out
t.Go(func() error {
@@ -254,6 +326,7 @@ func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types.
return runner.Run(t)
})
}
+
t.Go(func() error {
if w.config.ListenSocket != "" {
w.logger.Infof("creating unix socket %s", w.config.ListenSocket)
@@ -268,10 +341,11 @@ func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types.
} else {
err = w.server.Serve(listener)
}
- if err != nil && err != http.ErrServerClosed {
+ if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("appsec server failed: %w", err)
}
}
+
return nil
})
t.Go(func() error {
@@ -288,6 +362,7 @@ func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types.
return fmt.Errorf("appsec server failed: %w", err)
}
}
+
return nil
})
<-t.Dying()
@@ -297,6 +372,7 @@ func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types.
w.server.Shutdown(ctx)
return nil
})
+
return nil
}
@@ -334,6 +410,29 @@ func (w *AppsecSource) IsAuth(apiKey string) bool {
return resp.StatusCode == http.StatusOK
}
+func (w *AppsecSource) isAllowlisted(ctx context.Context, value string) bool {
+ var err error
+
+ allowlisted, expiration, exists := w.allowlistCache.Get(value)
+ if exists && !time.Now().After(expiration) {
+ return allowlisted
+ }
+
+ allowlisted, _, err = w.apiClient.Allowlists.CheckIfAllowlisted(ctx, value)
+ if err != nil {
+ w.logger.Errorf("unable to check if %s is allowlisted: %s", value, err)
+ return false
+ }
+
+ if allowlisted {
+ w.allowlistCache.Set(value, allowlisted, time.Now().Add(positiveAllowlistCacheDuration))
+ } else {
+ w.allowlistCache.Set(value, allowlisted, time.Now().Add(negativeAllowlistCacheDuration))
+ }
+
+ return allowlisted
+}
+
// should this be in the runner ?
func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) {
w.logger.Debugf("Received request from '%s' on %s", r.RemoteAddr, r.URL.Path)
@@ -359,6 +458,25 @@ func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) {
w.AuthCache.Set(apiKey, time.Now().Add(*w.config.AuthCacheDuration))
}
+ // check if the client IP is allowlisted
+ if w.isAllowlisted(r.Context(), clientIP) {
+ w.logger.Infof("%s is allowlisted by LAPI, not processing", clientIP)
+ statusCode, appsecResponse := w.AppsecRuntime.GenerateResponse(appsec.AppsecTempResponse{
+ InBandInterrupt: false,
+ OutOfBandInterrupt: false,
+ Action: appsec.AllowRemediation,
+ }, w.logger)
+ body, err := json.Marshal(appsecResponse)
+ if err != nil {
+ w.logger.Errorf("unable to serialize response: %s", err)
+ rw.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+ rw.WriteHeader(statusCode)
+ rw.Write(body)
+ return
+ }
+
// parse the request only once
parsedRequest, err := appsec.NewParsedRequestFromRequest(r, w.logger)
if err != nil {
@@ -391,6 +509,7 @@ func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) {
logger.Debugf("Response: %+v", appsecResponse)
rw.WriteHeader(statusCode)
+
body, err := json.Marshal(appsecResponse)
if err != nil {
logger.Errorf("unable to serialize response: %s", err)
diff --git a/pkg/acquisition/modules/appsec/appsec_hooks_test.go b/pkg/acquisition/modules/appsec/appsec_hooks_test.go
index c549d2ef1d1..d87384a0189 100644
--- a/pkg/acquisition/modules/appsec/appsec_hooks_test.go
+++ b/pkg/acquisition/modules/appsec/appsec_hooks_test.go
@@ -341,7 +341,6 @@ func TestAppsecOnMatchHooks(t *testing.T) {
}
func TestAppsecPreEvalHooks(t *testing.T) {
-
tests := []appsecRuleTest{
{
name: "Basic pre_eval hook to disable inband rule",
@@ -403,7 +402,6 @@ func TestAppsecPreEvalHooks(t *testing.T) {
require.Len(t, responses, 1)
require.True(t, responses[0].InBandInterrupt)
-
},
},
{
@@ -670,7 +668,6 @@ func TestAppsecPreEvalHooks(t *testing.T) {
}
func TestAppsecRemediationConfigHooks(t *testing.T) {
-
tests := []appsecRuleTest{
{
name: "Basic matching rule",
@@ -759,6 +756,7 @@ func TestAppsecRemediationConfigHooks(t *testing.T) {
})
}
}
+
func TestOnMatchRemediationHooks(t *testing.T) {
tests := []appsecRuleTest{
{
diff --git a/pkg/acquisition/modules/appsec/appsec_rules_test.go b/pkg/acquisition/modules/appsec/appsec_rules_test.go
index 909f16357ed..00093c5a5ad 100644
--- a/pkg/acquisition/modules/appsec/appsec_rules_test.go
+++ b/pkg/acquisition/modules/appsec/appsec_rules_test.go
@@ -28,7 +28,8 @@ func TestAppsecRuleMatches(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
@@ -59,7 +60,8 @@ func TestAppsecRuleMatches(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"tutu"}},
@@ -84,7 +86,8 @@ func TestAppsecRuleMatches(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
@@ -110,7 +113,8 @@ func TestAppsecRuleMatches(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
@@ -136,7 +140,8 @@ func TestAppsecRuleMatches(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
@@ -165,7 +170,8 @@ func TestAppsecRuleMatches(t *testing.T) {
{Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')"}},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"bla"}},
@@ -192,7 +198,8 @@ func TestAppsecRuleMatches(t *testing.T) {
{Filter: "IsInBand == true", Apply: []string{"SetReturnCode(418)"}},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"bla"}},
@@ -219,7 +226,8 @@ func TestAppsecRuleMatches(t *testing.T) {
{Filter: "IsInBand == true", Apply: []string{"SetRemediationByName('rule42', 'captcha')"}},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"bla"}},
@@ -243,7 +251,8 @@ func TestAppsecRuleMatches(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/urllll",
Headers: http.Header{"Cookie": []string{"foo=toto"}},
@@ -273,7 +282,8 @@ func TestAppsecRuleMatches(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/urllll",
Headers: http.Header{"Cookie": []string{"foo=toto; bar=tutu"}},
@@ -303,7 +313,8 @@ func TestAppsecRuleMatches(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/urllll",
Headers: http.Header{"Cookie": []string{"bar=tutu; tututata=toto"}},
@@ -333,7 +344,8 @@ func TestAppsecRuleMatches(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/urllll",
Headers: http.Header{"Content-Type": []string{"multipart/form-data; boundary=boundary"}},
@@ -354,6 +366,32 @@ toto
require.Len(t, events[1].Appsec.MatchedRules, 1)
require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"])
+ require.Len(t, responses, 1)
+ require.True(t, responses[0].InBandInterrupt)
+ },
+ },
+ {
+ name: "Basic matching IP address",
+ expected_load_ok: true,
+ inband_native_rules: []string{
+ "SecRule REMOTE_ADDR \"@ipMatch 1.2.3.4\" \"id:1,phase:1,log,deny,msg: 'block ip'\"",
+ },
+ input_request: appsec.ParsedRequest{
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
+ Method: "GET",
+ URI: "/urllll",
+ Headers: http.Header{"Content-Type": []string{"multipart/form-data; boundary=boundary"}},
+ },
+ output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
+ require.Len(t, events, 2)
+ require.Equal(t, types.APPSEC, events[0].Type)
+
+ require.Equal(t, types.LOG, events[1].Type)
+ require.True(t, events[1].Appsec.HasInBandMatches)
+ require.Len(t, events[1].Appsec.MatchedRules, 1)
+ require.Equal(t, "block ip", events[1].Appsec.MatchedRules[0]["msg"])
+
require.Len(t, responses, 1)
require.True(t, responses[0].InBandInterrupt)
},
@@ -381,7 +419,8 @@ func TestAppsecRuleTransforms(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/toto",
},
@@ -404,7 +443,8 @@ func TestAppsecRuleTransforms(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/TOTO",
},
@@ -427,7 +467,8 @@ func TestAppsecRuleTransforms(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/toto",
},
@@ -451,7 +492,8 @@ func TestAppsecRuleTransforms(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/?foo=dG90bw",
},
@@ -475,7 +517,8 @@ func TestAppsecRuleTransforms(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/?foo=dG90bw===",
},
@@ -499,7 +542,8 @@ func TestAppsecRuleTransforms(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/?foo=toto",
},
@@ -523,7 +567,8 @@ func TestAppsecRuleTransforms(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/?foo=%42%42%2F%41",
},
@@ -547,7 +592,8 @@ func TestAppsecRuleTransforms(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/?foo=%20%20%42%42%2F%41%20%20",
},
@@ -585,7 +631,8 @@ func TestAppsecRuleZones(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/foobar?something=toto&foobar=smth",
},
@@ -612,7 +659,8 @@ func TestAppsecRuleZones(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/foobar?something=toto&foobar=smth",
},
@@ -639,7 +687,8 @@ func TestAppsecRuleZones(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/",
Body: []byte("smth=toto&foobar=other"),
@@ -668,7 +717,8 @@ func TestAppsecRuleZones(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/",
Body: []byte("smth=toto&foobar=other"),
@@ -697,7 +747,8 @@ func TestAppsecRuleZones(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/",
Headers: http.Header{"foobar": []string{"toto"}},
@@ -725,7 +776,8 @@ func TestAppsecRuleZones(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/",
Headers: http.Header{"foobar": []string{"toto"}},
@@ -748,7 +800,8 @@ func TestAppsecRuleZones(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/",
},
@@ -770,7 +823,8 @@ func TestAppsecRuleZones(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/",
Proto: "HTTP/3.1",
@@ -793,7 +847,8 @@ func TestAppsecRuleZones(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/foobar",
},
@@ -815,7 +870,8 @@ func TestAppsecRuleZones(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/foobar?a=b",
},
@@ -837,7 +893,8 @@ func TestAppsecRuleZones(t *testing.T) {
},
},
input_request: appsec.ParsedRequest{
- RemoteAddr: "1.2.3.4",
+ ClientIP: "1.2.3.4",
+ RemoteAddr: "127.0.0.1",
Method: "GET",
URI: "/",
Body: []byte("foobar=42421"),
diff --git a/pkg/acquisition/modules/appsec/appsec_runner.go b/pkg/acquisition/modules/appsec/appsec_runner.go
index de34b62d704..d4535d3f9a2 100644
--- a/pkg/acquisition/modules/appsec/appsec_runner.go
+++ b/pkg/acquisition/modules/appsec/appsec_runner.go
@@ -4,6 +4,7 @@ import (
"fmt"
"os"
"slices"
+ "strings"
"time"
"github.com/prometheus/client_golang/prometheus"
@@ -31,23 +32,38 @@ type AppsecRunner struct {
logger *log.Entry
}
+func (r *AppsecRunner) MergeDedupRules(collections []appsec.AppsecCollection, logger *log.Entry) string {
+ var rulesArr []string
+	dedupRules := make(map[string]struct{})
+	discarded := 0
+	for _, collection := range collections {
+		for _, rule := range collection.Rules {
+			if _, ok := dedupRules[rule]; !ok {
+				rulesArr = append(rulesArr, rule)
+				dedupRules[rule] = struct{}{}
+			} else {
+				logger.Debugf("Discarding duplicate rule : %s", rule)
+				discarded++
+			}
+		}
+	}
+	if discarded > 0 {
+		logger.Warningf("%d rules were discarded as they were duplicates", discarded)
+	}
+ return strings.Join(rulesArr, "\n")
+}
+
func (r *AppsecRunner) Init(datadir string) error {
var err error
fs := os.DirFS(datadir)
- inBandRules := ""
- outOfBandRules := ""
-
- for _, collection := range r.AppsecRuntime.InBandRules {
- inBandRules += collection.String()
- }
-
- for _, collection := range r.AppsecRuntime.OutOfBandRules {
- outOfBandRules += collection.String()
- }
inBandLogger := r.logger.Dup().WithField("band", "inband")
outBandLogger := r.logger.Dup().WithField("band", "outband")
+ //While loading rules, we dedup rules based on their content, while keeping the order
+ inBandRules := r.MergeDedupRules(r.AppsecRuntime.InBandRules, inBandLogger)
+ outOfBandRules := r.MergeDedupRules(r.AppsecRuntime.OutOfBandRules, outBandLogger)
+
//setting up inband engine
inbandCfg := coraza.NewWAFConfig().WithDirectives(inBandRules).WithRootFS(fs).WithDebugLogger(appsec.NewCrzLogger(inBandLogger))
if !r.AppsecRuntime.Config.InbandOptions.DisableBodyInspection {
@@ -74,6 +90,9 @@ func (r *AppsecRunner) Init(datadir string) error {
outbandCfg = outbandCfg.WithRequestBodyInMemoryLimit(*r.AppsecRuntime.Config.OutOfBandOptions.RequestBodyInMemoryLimit)
}
r.AppsecOutbandEngine, err = coraza.NewWAF(outbandCfg)
+ if err != nil {
+ return fmt.Errorf("unable to initialize outband engine : %w", err)
+ }
if r.AppsecRuntime.DisabledInBandRulesTags != nil {
for _, tag := range r.AppsecRuntime.DisabledInBandRulesTags {
@@ -102,10 +121,6 @@ func (r *AppsecRunner) Init(datadir string) error {
r.logger.Tracef("Loaded inband rules: %+v", r.AppsecInbandEngine.GetRuleGroup().GetRules())
r.logger.Tracef("Loaded outband rules: %+v", r.AppsecOutbandEngine.GetRuleGroup().GetRules())
- if err != nil {
- return fmt.Errorf("unable to initialize outband engine : %w", err)
- }
-
return nil
}
@@ -135,7 +150,7 @@ func (r *AppsecRunner) processRequest(tx appsec.ExtendedTransaction, request *ap
//FIXME: should we abort here ?
}
- request.Tx.ProcessConnection(request.RemoteAddr, 0, "", 0)
+ request.Tx.ProcessConnection(request.ClientIP, 0, "", 0)
for k, v := range request.Args {
for _, vv := range v {
@@ -249,7 +264,7 @@ func (r *AppsecRunner) handleInBandInterrupt(request *appsec.ParsedRequest) {
// Should the in band match trigger an overflow ?
if r.AppsecRuntime.Response.SendAlert {
- appsecOvlfw, err := AppsecEventGeneration(evt)
+ appsecOvlfw, err := AppsecEventGeneration(evt, request.HTTPRequest)
if err != nil {
r.logger.Errorf("unable to generate appsec event : %s", err)
return
@@ -293,7 +308,7 @@ func (r *AppsecRunner) handleOutBandInterrupt(request *appsec.ParsedRequest) {
// Should the match trigger an overflow ?
if r.AppsecRuntime.Response.SendAlert {
- appsecOvlfw, err := AppsecEventGeneration(evt)
+ appsecOvlfw, err := AppsecEventGeneration(evt, request.HTTPRequest)
if err != nil {
r.logger.Errorf("unable to generate appsec event : %s", err)
return
@@ -363,7 +378,6 @@ func (r *AppsecRunner) handleRequest(request *appsec.ParsedRequest) {
// time spent to process inband AND out of band rules
globalParsingElapsed := time.Since(startGlobalParsing)
AppsecGlobalParsingHistogram.With(prometheus.Labels{"source": request.RemoteAddrNormalized, "appsec_engine": request.AppsecEngine}).Observe(globalParsingElapsed.Seconds())
-
}
func (r *AppsecRunner) Run(t *tomb.Tomb) error {
diff --git a/pkg/acquisition/modules/appsec/appsec_runner_test.go b/pkg/acquisition/modules/appsec/appsec_runner_test.go
new file mode 100644
index 00000000000..d07fb153186
--- /dev/null
+++ b/pkg/acquisition/modules/appsec/appsec_runner_test.go
@@ -0,0 +1,153 @@
+package appsecacquisition
+
+import (
+ "testing"
+
+ "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule"
+ log "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAppsecRuleLoad(t *testing.T) {
+ log.SetLevel(log.TraceLevel)
+ tests := []appsecRuleTest{
+ {
+ name: "simple rule load",
+ expected_load_ok: true,
+ inband_rules: []appsec_rule.CustomRule{
+ {
+ Name: "rule1",
+ Zones: []string{"ARGS"},
+ Match: appsec_rule.Match{Type: "equals", Value: "toto"},
+ },
+ },
+ afterload_asserts: func(runner AppsecRunner) {
+ require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 1)
+ },
+ },
+ {
+ name: "simple native rule load",
+ expected_load_ok: true,
+ inband_native_rules: []string{
+ `Secrule REQUEST_HEADERS:Content-Type "@rx ^application/x-www-form-urlencoded" "id:100,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=URLENCODED"`,
+ },
+ afterload_asserts: func(runner AppsecRunner) {
+ require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 1)
+ },
+ },
+ {
+ name: "simple native rule load (2)",
+ expected_load_ok: true,
+ inband_native_rules: []string{
+ `Secrule REQUEST_HEADERS:Content-Type "@rx ^application/x-www-form-urlencoded" "id:100,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=URLENCODED"`,
+ `Secrule REQUEST_HEADERS:Content-Type "@rx ^multipart/form-data" "id:101,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=MULTIPART"`,
+ },
+ afterload_asserts: func(runner AppsecRunner) {
+ require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 2)
+ },
+ },
+ {
+ name: "simple native rule load + dedup",
+ expected_load_ok: true,
+ inband_native_rules: []string{
+ `Secrule REQUEST_HEADERS:Content-Type "@rx ^application/x-www-form-urlencoded" "id:100,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=URLENCODED"`,
+ `Secrule REQUEST_HEADERS:Content-Type "@rx ^multipart/form-data" "id:101,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=MULTIPART"`,
+ `Secrule REQUEST_HEADERS:Content-Type "@rx ^application/x-www-form-urlencoded" "id:100,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=URLENCODED"`,
+ },
+ afterload_asserts: func(runner AppsecRunner) {
+ require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 2)
+ },
+ },
+ {
+ name: "multi simple rule load",
+ expected_load_ok: true,
+ inband_rules: []appsec_rule.CustomRule{
+ {
+ Name: "rule1",
+ Zones: []string{"ARGS"},
+ Match: appsec_rule.Match{Type: "equals", Value: "toto"},
+ },
+ {
+ Name: "rule2",
+ Zones: []string{"ARGS"},
+ Match: appsec_rule.Match{Type: "equals", Value: "toto"},
+ },
+ },
+ afterload_asserts: func(runner AppsecRunner) {
+ require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 2)
+ },
+ },
+ {
+ name: "multi simple rule load",
+ expected_load_ok: true,
+ inband_rules: []appsec_rule.CustomRule{
+ {
+ Name: "rule1",
+ Zones: []string{"ARGS"},
+ Match: appsec_rule.Match{Type: "equals", Value: "toto"},
+ },
+ {
+ Name: "rule2",
+ Zones: []string{"ARGS"},
+ Match: appsec_rule.Match{Type: "equals", Value: "toto"},
+ },
+ },
+ afterload_asserts: func(runner AppsecRunner) {
+ require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 2)
+ },
+ },
+ {
+ name: "imbricated rule load",
+ expected_load_ok: true,
+ inband_rules: []appsec_rule.CustomRule{
+ {
+ Name: "rule1",
+
+ Or: []appsec_rule.CustomRule{
+ {
+ //Name: "rule1",
+ Zones: []string{"ARGS"},
+ Match: appsec_rule.Match{Type: "equals", Value: "toto"},
+ },
+ {
+ //Name: "rule1",
+ Zones: []string{"ARGS"},
+ Match: appsec_rule.Match{Type: "equals", Value: "tutu"},
+ },
+ {
+ //Name: "rule1",
+ Zones: []string{"ARGS"},
+ Match: appsec_rule.Match{Type: "equals", Value: "tata"},
+ }, {
+ //Name: "rule1",
+ Zones: []string{"ARGS"},
+ Match: appsec_rule.Match{Type: "equals", Value: "titi"},
+ },
+ },
+ },
+ },
+ afterload_asserts: func(runner AppsecRunner) {
+ require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 4)
+ },
+ },
+ {
+ name: "invalid inband rule",
+ expected_load_ok: false,
+ inband_native_rules: []string{
+ "this_is_not_a_rule",
+ },
+ },
+ {
+ name: "invalid outofband rule",
+ expected_load_ok: false,
+ outofband_native_rules: []string{
+ "this_is_not_a_rule",
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ loadAppSecEngine(test, t)
+ })
+ }
+}
diff --git a/pkg/acquisition/modules/appsec/appsec_test.go b/pkg/acquisition/modules/appsec/appsec_test.go
index d2079b43726..c0af1002f49 100644
--- a/pkg/acquisition/modules/appsec/appsec_test.go
+++ b/pkg/acquisition/modules/appsec/appsec_test.go
@@ -18,6 +18,8 @@ type appsecRuleTest struct {
expected_load_ok bool
inband_rules []appsec_rule.CustomRule
outofband_rules []appsec_rule.CustomRule
+ inband_native_rules []string
+ outofband_native_rules []string
on_load []appsec.Hook
pre_eval []appsec.Hook
post_eval []appsec.Hook
@@ -28,6 +30,7 @@ type appsecRuleTest struct {
DefaultRemediation string
DefaultPassAction string
input_request appsec.ParsedRequest
+ afterload_asserts func(runner AppsecRunner)
output_asserts func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int)
}
@@ -53,6 +56,8 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) {
inbandRules = append(inbandRules, strRule)
}
+ inbandRules = append(inbandRules, test.inband_native_rules...)
+ outofbandRules = append(outofbandRules, test.outofband_native_rules...)
for ridx, rule := range test.outofband_rules {
strRule, _, err := rule.Convert(appsec_rule.ModsecurityRuleType, rule.Name)
if err != nil {
@@ -61,7 +66,8 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) {
outofbandRules = append(outofbandRules, strRule)
}
- appsecCfg := appsec.AppsecConfig{Logger: logger,
+ appsecCfg := appsec.AppsecConfig{
+ Logger: logger,
OnLoad: test.on_load,
PreEval: test.pre_eval,
PostEval: test.post_eval,
@@ -70,7 +76,8 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) {
UserBlockedHTTPCode: test.UserBlockedHTTPCode,
UserPassedHTTPCode: test.UserPassedHTTPCode,
DefaultRemediation: test.DefaultRemediation,
- DefaultPassAction: test.DefaultPassAction}
+ DefaultPassAction: test.DefaultPassAction,
+ }
AppsecRuntime, err := appsecCfg.Build()
if err != nil {
t.Fatalf("unable to build appsec runtime : %s", err)
@@ -91,8 +98,21 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) {
}
err = runner.Init("/tmp/")
if err != nil {
+ if !test.expected_load_ok {
+ return
+ }
t.Fatalf("unable to initialize runner : %s", err)
}
+ if !test.expected_load_ok {
+ t.Fatalf("expected load to fail but it didn't")
+ }
+
+ if test.afterload_asserts != nil {
+ //afterload asserts are just to evaluate the state of the runner after the rules have been loaded
+ //if it's present, don't try to process requests
+ test.afterload_asserts(runner)
+ return
+ }
input := test.input_request
input.ResponseChannel = make(chan appsec.AppsecTempResponse)
diff --git a/pkg/acquisition/modules/appsec/bodyprocessors/raw.go b/pkg/acquisition/modules/appsec/bodyprocessors/raw.go
index e2e23eb57ae..aa467ecf048 100644
--- a/pkg/acquisition/modules/appsec/bodyprocessors/raw.go
+++ b/pkg/acquisition/modules/appsec/bodyprocessors/raw.go
@@ -9,8 +9,7 @@ import (
"github.com/crowdsecurity/coraza/v3/experimental/plugins/plugintypes"
)
-type rawBodyProcessor struct {
-}
+type rawBodyProcessor struct{}
type setterInterface interface {
Set(string)
@@ -33,9 +32,7 @@ func (*rawBodyProcessor) ProcessResponse(reader io.Reader, v plugintypes.Transac
return nil
}
-var (
- _ plugintypes.BodyProcessor = &rawBodyProcessor{}
-)
+var _ plugintypes.BodyProcessor = &rawBodyProcessor{}
//nolint:gochecknoinits //Coraza recommends to use init() for registering plugins
func init() {
diff --git a/pkg/acquisition/modules/appsec/utils.go b/pkg/acquisition/modules/appsec/utils.go
index 4fb1a979d14..8995b305680 100644
--- a/pkg/acquisition/modules/appsec/utils.go
+++ b/pkg/acquisition/modules/appsec/utils.go
@@ -1,10 +1,10 @@
package appsecacquisition
import (
+ "errors"
"fmt"
"net"
- "slices"
- "strconv"
+ "net/http"
"time"
"github.com/oschwald/geoip2-golang"
@@ -22,29 +22,44 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/types"
)
-var appsecMetaKeys = []string{
- "id",
- "name",
- "method",
- "uri",
- "matched_zones",
- "msg",
-}
+func AppsecEventGenerationGeoIPEnrich(src *models.Source) error {
-func appendMeta(meta models.Meta, key string, value string) models.Meta {
- if value == "" {
- return meta
+ if src == nil || src.Scope == nil || *src.Scope != types.Ip {
+ return errors.New("source is nil or not an IP")
}
- meta = append(meta, &models.MetaItems0{
- Key: key,
- Value: value,
- })
+ //GeoIP enrich
+ asndata, err := exprhelpers.GeoIPASNEnrich(src.IP)
+
+ if err != nil {
+ return err
+ } else if asndata != nil {
+ record := asndata.(*geoip2.ASN)
+ src.AsName = record.AutonomousSystemOrganization
+ src.AsNumber = fmt.Sprintf("%d", record.AutonomousSystemNumber)
+ }
- return meta
+ cityData, err := exprhelpers.GeoIPEnrich(src.IP)
+ if err != nil {
+ return err
+ } else if cityData != nil {
+ record := cityData.(*geoip2.City)
+ src.Cn = record.Country.IsoCode
+ src.Latitude = float32(record.Location.Latitude)
+ src.Longitude = float32(record.Location.Longitude)
+ }
+
+ rangeData, err := exprhelpers.GeoIPRangeEnrich(src.IP)
+ if err != nil {
+ return err
+ } else if rangeData != nil {
+ record := rangeData.(*net.IPNet)
+ src.Range = record.String()
+ }
+ return nil
}
-func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) {
+func AppsecEventGeneration(inEvt types.Event, request *http.Request) (*types.Event, error) {
// if the request didnd't trigger inband rules, we don't want to generate an event to LAPI/CAPI
if !inEvt.Appsec.HasInBandMatches {
return nil, nil
@@ -60,34 +75,12 @@ func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) {
Scope: ptr.Of(types.Ip),
}
- asndata, err := exprhelpers.GeoIPASNEnrich(sourceIP)
-
- if err != nil {
- log.Errorf("Unable to enrich ip '%s' for ASN: %s", sourceIP, err)
- } else if asndata != nil {
- record := asndata.(*geoip2.ASN)
- source.AsName = record.AutonomousSystemOrganization
- source.AsNumber = fmt.Sprintf("%d", record.AutonomousSystemNumber)
- }
-
- cityData, err := exprhelpers.GeoIPEnrich(sourceIP)
- if err != nil {
- log.Errorf("Unable to enrich ip '%s' for geo data: %s", sourceIP, err)
- } else if cityData != nil {
- record := cityData.(*geoip2.City)
- source.Cn = record.Country.IsoCode
- source.Latitude = float32(record.Location.Latitude)
- source.Longitude = float32(record.Location.Longitude)
- }
-
- rangeData, err := exprhelpers.GeoIPRangeEnrich(sourceIP)
- if err != nil {
- log.Errorf("Unable to enrich ip '%s' for range: %s", sourceIP, err)
- } else if rangeData != nil {
- record := rangeData.(*net.IPNet)
- source.Range = record.String()
+ // Enrich source with GeoIP data
+ if err := AppsecEventGenerationGeoIPEnrich(&source); err != nil {
+ log.Errorf("unable to enrich source with GeoIP data : %s", err)
}
+ // Build overflow
evt.Overflow.Sources = make(map[string]models.Source)
evt.Overflow.Sources[sourceIP] = source
@@ -95,83 +88,11 @@ func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) {
alert.Capacity = ptr.Of(int32(1))
alert.Events = make([]*models.Event, len(evt.Appsec.GetRuleIDs()))
- now := ptr.Of(time.Now().UTC().Format(time.RFC3339))
-
- tmpAppsecContext := make(map[string][]string)
-
- for _, matched_rule := range inEvt.Appsec.MatchedRules {
- evtRule := models.Event{}
-
- evtRule.Timestamp = now
-
- evtRule.Meta = make(models.Meta, 0)
-
- for _, key := range appsecMetaKeys {
- if tmpAppsecContext[key] == nil {
- tmpAppsecContext[key] = make([]string, 0)
- }
-
- switch value := matched_rule[key].(type) {
- case string:
- evtRule.Meta = appendMeta(evtRule.Meta, key, value)
-
- if value != "" && !slices.Contains(tmpAppsecContext[key], value) {
- tmpAppsecContext[key] = append(tmpAppsecContext[key], value)
- }
- case int:
- val := strconv.Itoa(value)
- evtRule.Meta = appendMeta(evtRule.Meta, key, val)
-
- if val != "" && !slices.Contains(tmpAppsecContext[key], val) {
- tmpAppsecContext[key] = append(tmpAppsecContext[key], val)
- }
- case []string:
- for _, v := range value {
- evtRule.Meta = appendMeta(evtRule.Meta, key, v)
-
- if v != "" && !slices.Contains(tmpAppsecContext[key], v) {
- tmpAppsecContext[key] = append(tmpAppsecContext[key], v)
- }
- }
- case []int:
- for _, v := range value {
- val := strconv.Itoa(v)
- evtRule.Meta = appendMeta(evtRule.Meta, key, val)
-
- if val != "" && !slices.Contains(tmpAppsecContext[key], val) {
- tmpAppsecContext[key] = append(tmpAppsecContext[key], val)
- }
- }
- default:
- val := fmt.Sprintf("%v", value)
- evtRule.Meta = appendMeta(evtRule.Meta, key, val)
-
- if val != "" && !slices.Contains(tmpAppsecContext[key], val) {
- tmpAppsecContext[key] = append(tmpAppsecContext[key], val)
- }
- }
- }
-
- alert.Events = append(alert.Events, &evtRule)
- }
-
- metas := make([]*models.MetaItems0, 0)
-
- for key, values := range tmpAppsecContext {
- if len(values) == 0 {
- continue
- }
-
- valueStr, err := alertcontext.TruncateContext(values, alertcontext.MaxContextValueLen)
- if err != nil {
- log.Warning(err.Error())
+	metas, errs := alertcontext.AppsecEventToContext(inEvt.Appsec, request)
+	if len(errs) > 0 {
+		for _, err := range errs {
+ log.Errorf("failed to generate appsec context: %s", err)
}
-
- meta := models.MetaItems0{
- Key: key,
- Value: valueStr,
- }
- metas = append(metas, &meta)
}
alert.Meta = metas
@@ -195,10 +116,7 @@ func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) {
}
func EventFromRequest(r *appsec.ParsedRequest, labels map[string]string) (types.Event, error) {
- evt := types.Event{}
- // we might want to change this based on in-band vs out-of-band ?
- evt.Type = types.LOG
- evt.ExpectMode = types.LIVE
+ evt := types.MakeEvent(false, types.LOG, true)
// def needs fixing
evt.Stage = "s00-raw"
evt.Parsed = map[string]string{
diff --git a/pkg/acquisition/modules/cloudwatch/cloudwatch.go b/pkg/acquisition/modules/cloudwatch/cloudwatch.go
index 2df70b3312b..ba267c9050b 100644
--- a/pkg/acquisition/modules/cloudwatch/cloudwatch.go
+++ b/pkg/acquisition/modules/cloudwatch/cloudwatch.go
@@ -710,7 +710,7 @@ func (cw *CloudwatchSource) CatLogStream(ctx context.Context, cfg *LogStreamTail
func cwLogToEvent(log *cloudwatchlogs.OutputLogEvent, cfg *LogStreamTailConfig) (types.Event, error) {
l := types.Line{}
- evt := types.Event{}
+ evt := types.MakeEvent(cfg.ExpectMode == types.TIMEMACHINE, types.LOG, true)
if log.Message == nil {
return evt, errors.New("nil message")
}
@@ -726,9 +726,6 @@ func cwLogToEvent(log *cloudwatchlogs.OutputLogEvent, cfg *LogStreamTailConfig)
l.Process = true
l.Module = "cloudwatch"
evt.Line = l
- evt.Process = true
- evt.Type = types.LOG
- evt.ExpectMode = cfg.ExpectMode
cfg.logger.Debugf("returned event labels : %+v", evt.Line.Labels)
return evt, nil
}
diff --git a/pkg/acquisition/modules/docker/docker.go b/pkg/acquisition/modules/docker/docker.go
index 2f79d4dcee6..b27255ec13f 100644
--- a/pkg/acquisition/modules/docker/docker.go
+++ b/pkg/acquisition/modules/docker/docker.go
@@ -334,7 +334,10 @@ func (d *DockerSource) OneShotAcquisition(ctx context.Context, out chan types.Ev
if d.metricsLevel != configuration.METRICS_NONE {
linesRead.With(prometheus.Labels{"source": containerConfig.Name}).Inc()
}
- evt := types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE}
+ evt := types.MakeEvent(true, types.LOG, true)
+ evt.Line = l
+ evt.Process = true
+ evt.Type = types.LOG
out <- evt
d.logger.Debugf("Sent line to parsing: %+v", evt.Line.Raw)
}
@@ -579,12 +582,8 @@ func (d *DockerSource) TailDocker(ctx context.Context, container *ContainerConfi
l.Src = container.Name
l.Process = true
l.Module = d.GetName()
- var evt types.Event
- if !d.Config.UseTimeMachine {
- evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE}
- } else {
- evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE}
- }
+ evt := types.MakeEvent(d.Config.UseTimeMachine, types.LOG, true)
+ evt.Line = l
linesRead.With(prometheus.Labels{"source": container.Name}).Inc()
outChan <- evt
d.logger.Debugf("Sent line to parsing: %+v", evt.Line.Raw)
diff --git a/pkg/acquisition/modules/file/file.go b/pkg/acquisition/modules/file/file.go
index f752d04aada..9f439b0c82e 100644
--- a/pkg/acquisition/modules/file/file.go
+++ b/pkg/acquisition/modules/file/file.go
@@ -621,11 +621,9 @@ func (f *FileSource) tailFile(out chan types.Event, t *tomb.Tomb, tail *tail.Tai
// we're tailing, it must be real time logs
logger.Debugf("pushing %+v", l)
- expectMode := types.LIVE
- if f.config.UseTimeMachine {
- expectMode = types.TIMEMACHINE
- }
- out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: expectMode}
+ evt := types.MakeEvent(f.config.UseTimeMachine, types.LOG, true)
+ evt.Line = l
+ out <- evt
}
}
}
@@ -684,7 +682,7 @@ func (f *FileSource) readFile(filename string, out chan types.Event, t *tomb.Tom
linesRead.With(prometheus.Labels{"source": filename}).Inc()
// we're reading logs at once, it must be time-machine buckets
- out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE}
+ out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE, Unmarshaled: make(map[string]interface{})}
}
}
diff --git a/pkg/acquisition/modules/http/http.go b/pkg/acquisition/modules/http/http.go
new file mode 100644
index 00000000000..6bb8228f32c
--- /dev/null
+++ b/pkg/acquisition/modules/http/http.go
@@ -0,0 +1,414 @@
+package httpacquisition
+
+import (
+ "compress/gzip"
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "os"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus"
+ log "github.com/sirupsen/logrus"
+
+ "gopkg.in/tomb.v2"
+ "gopkg.in/yaml.v3"
+
+ "github.com/crowdsecurity/go-cs-lib/trace"
+
+ "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration"
+ "github.com/crowdsecurity/crowdsec/pkg/types"
+)
+
+var dataSourceName = "http"
+
+var linesRead = prometheus.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "cs_httpsource_hits_total",
+ Help: "Total lines that were read from http source",
+ },
+ []string{"path", "src"})
+
+type HttpConfiguration struct {
+ //IPFilter []string `yaml:"ip_filter"`
+ //ChunkSize *int64 `yaml:"chunk_size"`
+ ListenAddr string `yaml:"listen_addr"`
+ Path string `yaml:"path"`
+ AuthType string `yaml:"auth_type"`
+ BasicAuth *BasicAuthConfig `yaml:"basic_auth"`
+ Headers *map[string]string `yaml:"headers"`
+ TLS *TLSConfig `yaml:"tls"`
+ CustomStatusCode *int `yaml:"custom_status_code"`
+ CustomHeaders *map[string]string `yaml:"custom_headers"`
+ MaxBodySize *int64 `yaml:"max_body_size"`
+ Timeout *time.Duration `yaml:"timeout"`
+ configuration.DataSourceCommonCfg `yaml:",inline"`
+}
+
+type BasicAuthConfig struct {
+ Username string `yaml:"username"`
+ Password string `yaml:"password"`
+}
+
+type TLSConfig struct {
+ InsecureSkipVerify bool `yaml:"insecure_skip_verify"`
+ ServerCert string `yaml:"server_cert"`
+ ServerKey string `yaml:"server_key"`
+ CaCert string `yaml:"ca_cert"`
+}
+
+type HTTPSource struct {
+ metricsLevel int
+ Config HttpConfiguration
+ logger *log.Entry
+ Server *http.Server
+}
+
+func (h *HTTPSource) GetUuid() string {
+ return h.Config.UniqueId
+}
+
+func (h *HTTPSource) UnmarshalConfig(yamlConfig []byte) error {
+ h.Config = HttpConfiguration{}
+ err := yaml.Unmarshal(yamlConfig, &h.Config)
+ if err != nil {
+ return fmt.Errorf("cannot parse %s datasource configuration: %w", dataSourceName, err)
+ }
+
+ if h.Config.Mode == "" {
+ h.Config.Mode = configuration.TAIL_MODE
+ }
+
+ return nil
+}
+
+func (hc *HttpConfiguration) Validate() error {
+ if hc.ListenAddr == "" {
+ return errors.New("listen_addr is required")
+ }
+
+ if hc.Path == "" {
+ hc.Path = "/"
+ }
+ if hc.Path[0] != '/' {
+ return errors.New("path must start with /")
+ }
+
+ switch hc.AuthType {
+ case "basic_auth":
+ baseErr := "basic_auth is selected, but"
+ if hc.BasicAuth == nil {
+ return errors.New(baseErr + " basic_auth is not provided")
+ }
+ if hc.BasicAuth.Username == "" {
+ return errors.New(baseErr + " username is not provided")
+ }
+ if hc.BasicAuth.Password == "" {
+ return errors.New(baseErr + " password is not provided")
+ }
+ case "headers":
+ if hc.Headers == nil {
+ return errors.New("headers is selected, but headers is not provided")
+ }
+ case "mtls":
+ if hc.TLS == nil || hc.TLS.CaCert == "" {
+ return errors.New("mtls is selected, but ca_cert is not provided")
+ }
+ default:
+ return errors.New("invalid auth_type: must be one of basic_auth, headers, mtls")
+ }
+
+ if hc.TLS != nil {
+ if hc.TLS.ServerCert == "" {
+ return errors.New("server_cert is required")
+ }
+ if hc.TLS.ServerKey == "" {
+ return errors.New("server_key is required")
+ }
+ }
+
+ if hc.MaxBodySize != nil && *hc.MaxBodySize <= 0 {
+ return errors.New("max_body_size must be positive")
+ }
+
+ /*
+ if hc.ChunkSize != nil && *hc.ChunkSize <= 0 {
+ return errors.New("chunk_size must be positive")
+ }
+ */
+
+ if hc.CustomStatusCode != nil {
+ statusText := http.StatusText(*hc.CustomStatusCode)
+ if statusText == "" {
+ return errors.New("invalid HTTP status code")
+ }
+ }
+
+ return nil
+}
+
+func (h *HTTPSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error {
+ h.logger = logger
+ h.metricsLevel = MetricsLevel
+ err := h.UnmarshalConfig(yamlConfig)
+ if err != nil {
+ return err
+ }
+
+ if err := h.Config.Validate(); err != nil {
+ return fmt.Errorf("invalid configuration: %w", err)
+ }
+
+ return nil
+}
+
+func (h *HTTPSource) ConfigureByDSN(string, map[string]string, *log.Entry, string) error {
+ return fmt.Errorf("%s datasource does not support command-line acquisition", dataSourceName)
+}
+
+func (h *HTTPSource) GetMode() string {
+ return h.Config.Mode
+}
+
+func (h *HTTPSource) GetName() string {
+ return dataSourceName
+}
+
+func (h *HTTPSource) OneShotAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
+ return fmt.Errorf("%s datasource does not support one-shot acquisition", dataSourceName)
+}
+
+func (h *HTTPSource) CanRun() error {
+ return nil
+}
+
+func (h *HTTPSource) GetMetrics() []prometheus.Collector {
+ return []prometheus.Collector{linesRead}
+}
+
+func (h *HTTPSource) GetAggregMetrics() []prometheus.Collector {
+ return []prometheus.Collector{linesRead}
+}
+
+func (h *HTTPSource) Dump() interface{} {
+ return h
+}
+
+func (hc *HttpConfiguration) NewTLSConfig() (*tls.Config, error) {
+ tlsConfig := tls.Config{
+ InsecureSkipVerify: hc.TLS.InsecureSkipVerify,
+ }
+
+ if hc.TLS.ServerCert != "" && hc.TLS.ServerKey != "" {
+ cert, err := tls.LoadX509KeyPair(hc.TLS.ServerCert, hc.TLS.ServerKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load server cert/key: %w", err)
+ }
+ tlsConfig.Certificates = []tls.Certificate{cert}
+ }
+
+ if hc.AuthType == "mtls" && hc.TLS.CaCert != "" {
+ caCert, err := os.ReadFile(hc.TLS.CaCert)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read ca cert: %w", err)
+ }
+
+ caCertPool, err := x509.SystemCertPool()
+ if err != nil {
+ return nil, fmt.Errorf("failed to load system cert pool: %w", err)
+ }
+
+ if caCertPool == nil {
+ caCertPool = x509.NewCertPool()
+ }
+ caCertPool.AppendCertsFromPEM(caCert)
+ tlsConfig.ClientCAs = caCertPool
+ tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
+ }
+
+ return &tlsConfig, nil
+}
+
+func authorizeRequest(r *http.Request, hc *HttpConfiguration) error {
+ if hc.AuthType == "basic_auth" {
+ username, password, ok := r.BasicAuth()
+ if !ok {
+ return errors.New("missing basic auth")
+ }
+ if username != hc.BasicAuth.Username || password != hc.BasicAuth.Password {
+ return errors.New("invalid basic auth")
+ }
+ }
+ if hc.AuthType == "headers" {
+ for key, value := range *hc.Headers {
+ if r.Header.Get(key) != value {
+ return errors.New("invalid headers")
+ }
+ }
+ }
+ return nil
+}
+
+// processRequest decodes the request body as a stream of JSON documents
+// (single JSON, NDJSON, optionally gzip-compressed) and forwards one event
+// per document to the acquisition pipeline. Error paths write the HTTP
+// status themselves; on success the caller writes the final status.
+func (h *HTTPSource) processRequest(w http.ResponseWriter, r *http.Request, hc *HttpConfiguration, out chan types.Event) error {
+	if hc.MaxBodySize != nil && r.ContentLength > *hc.MaxBodySize {
+		w.WriteHeader(http.StatusRequestEntityTooLarge)
+		return fmt.Errorf("body size exceeds max body size: %d > %d", r.ContentLength, *hc.MaxBodySize)
+	}
+
+	srcHost, _, err := net.SplitHostPort(r.RemoteAddr)
+	if err != nil {
+		return err
+	}
+
+	defer r.Body.Close()
+
+	// ContentLength can be -1 (chunked transfer) or simply wrong, so also
+	// hard-cap the bytes actually read; otherwise max_body_size is trivially
+	// bypassed. Note this caps the on-the-wire (possibly compressed) size;
+	// an over-limit body surfaces as a decode error (400) below.
+	if hc.MaxBodySize != nil {
+		r.Body = http.MaxBytesReader(w, r.Body, *hc.MaxBodySize)
+	}
+
+	reader := r.Body
+
+	if r.Header.Get("Content-Encoding") == "gzip" {
+		reader, err = gzip.NewReader(r.Body)
+		if err != nil {
+			w.WriteHeader(http.StatusBadRequest)
+			return fmt.Errorf("failed to create gzip reader: %w", err)
+		}
+		defer reader.Close()
+	}
+
+	decoder := json.NewDecoder(reader)
+	for {
+		var message json.RawMessage
+
+		if err := decoder.Decode(&message); err != nil {
+			if err == io.EOF {
+				break
+			}
+			w.WriteHeader(http.StatusBadRequest)
+			return fmt.Errorf("failed to decode: %w", err)
+		}
+
+		line := types.Line{
+			Raw:     string(message),
+			Src:     srcHost,
+			Time:    time.Now().UTC(),
+			Labels:  hc.Labels,
+			Process: true,
+			Module:  h.GetName(),
+		}
+
+		// At the aggregated metrics level the per-client source is replaced
+		// by the listen path to keep label cardinality bounded.
+		if h.metricsLevel == configuration.METRICS_AGGREGATE {
+			line.Src = hc.Path
+		}
+
+		evt := types.MakeEvent(h.Config.UseTimeMachine, types.LOG, true)
+		evt.Line = line
+
+		if h.metricsLevel == configuration.METRICS_AGGREGATE {
+			linesRead.With(prometheus.Labels{"path": hc.Path, "src": ""}).Inc()
+		} else if h.metricsLevel == configuration.METRICS_FULL {
+			linesRead.With(prometheus.Labels{"path": hc.Path, "src": srcHost}).Inc()
+		}
+
+		h.logger.Tracef("line to send: %+v", line)
+		out <- evt
+	}
+
+	return nil
+}
+
+// RunServer builds the HTTP(S) server, wires the handler chain
+// (method check -> auth -> body processing -> response) and runs it in the
+// tomb; it blocks until the tomb starts dying, then closes the server.
+func (h *HTTPSource) RunServer(out chan types.Event, t *tomb.Tomb) error {
+	mux := http.NewServeMux()
+	mux.HandleFunc(h.Config.Path, func(w http.ResponseWriter, r *http.Request) {
+		if r.Method != http.MethodPost {
+			h.logger.Errorf("method not allowed: %s", r.Method)
+			http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
+			return
+		}
+		if err := authorizeRequest(r, &h.Config); err != nil {
+			h.logger.Errorf("failed to authorize request from '%s': %s", r.RemoteAddr, err)
+			http.Error(w, "Unauthorized", http.StatusUnauthorized)
+			return
+		}
+		// processRequest writes its own error status codes.
+		if err := h.processRequest(w, r, &h.Config, out); err != nil {
+			h.logger.Errorf("failed to process request from '%s': %s", r.RemoteAddr, err)
+			return
+		}
+
+		// Custom headers must be set before the status line is written.
+		if h.Config.CustomHeaders != nil {
+			for key, value := range *h.Config.CustomHeaders {
+				w.Header().Set(key, value)
+			}
+		}
+		if h.Config.CustomStatusCode != nil {
+			w.WriteHeader(*h.Config.CustomStatusCode)
+		} else {
+			w.WriteHeader(http.StatusOK)
+		}
+
+		w.Write([]byte("OK"))
+	})
+
+	h.Server = &http.Server{
+		Addr:    h.Config.ListenAddr,
+		Handler: mux,
+	}
+
+	if h.Config.Timeout != nil {
+		h.Server.ReadTimeout = *h.Config.Timeout
+	}
+
+	if h.Config.TLS != nil {
+		tlsConfig, err := h.Config.NewTLSConfig()
+		if err != nil {
+			return fmt.Errorf("failed to create tls config: %w", err)
+		}
+		h.logger.Tracef("tls config: %+v", tlsConfig)
+		h.Server.TLSConfig = tlsConfig
+	}
+
+	t.Go(func() error {
+		defer trace.CatchPanic("crowdsec/acquis/http/server")
+		if h.Config.TLS != nil {
+			h.logger.Infof("start https server on %s", h.Config.ListenAddr)
+			if err := h.Server.ListenAndServeTLS(h.Config.TLS.ServerCert, h.Config.TLS.ServerKey); err != nil && !errors.Is(err, http.ErrServerClosed) {
+				return fmt.Errorf("https server failed: %w", err)
+			}
+		} else {
+			h.logger.Infof("start http server on %s", h.Config.ListenAddr)
+			if err := h.Server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+				return fmt.Errorf("http server failed: %w", err)
+			}
+		}
+		return nil
+	})
+
+	// Block until the datasource is asked to stop, then shut the server
+	// down. A plain receive replaces the previous single-case for/select
+	// loop (and its nolint suppression).
+	<-t.Dying()
+	h.logger.Infof("%s datasource stopping", dataSourceName)
+	if err := h.Server.Close(); err != nil {
+		return fmt.Errorf("while closing %s server: %w", dataSourceName, err)
+	}
+	return nil
+}
+
+// StreamingAcquisition launches the HTTP server in the background under the
+// tomb and returns immediately; the server keeps running until the tomb dies.
+func (h *HTTPSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
+	h.logger.Debugf("start http server on %s", h.Config.ListenAddr)
+
+	serve := func() error {
+		defer trace.CatchPanic("crowdsec/acquis/http/live")
+		return h.RunServer(out, t)
+	}
+	t.Go(serve)
+
+	return nil
+}
diff --git a/pkg/acquisition/modules/http/http_test.go b/pkg/acquisition/modules/http/http_test.go
new file mode 100644
index 00000000000..4d99134419f
--- /dev/null
+++ b/pkg/acquisition/modules/http/http_test.go
@@ -0,0 +1,784 @@
+package httpacquisition
+
+import (
+ "compress/gzip"
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/crowdsecurity/crowdsec/pkg/types"
+ "github.com/crowdsecurity/go-cs-lib/cstest"
+ "github.com/prometheus/client_golang/prometheus"
+ log "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gopkg.in/tomb.v2"
+)
+
+const (
+ testHTTPServerAddr = "http://127.0.0.1:8080"
+ testHTTPServerAddrTLS = "https://127.0.0.1:8080"
+)
+
+// TestConfigure exercises configuration validation: each entry is a YAML
+// config and the exact (or empty) error Configure must produce for it.
+func TestConfigure(t *testing.T) {
+	tests := []struct {
+		config      string
+		expectedErr string
+	}{
+		{
+			config: `
+foobar: bla`,
+			expectedErr: "invalid configuration: listen_addr is required",
+		},
+		{
+			config: `
+source: http
+listen_addr: 127.0.0.1:8080
+path: wrongpath`,
+			expectedErr: "invalid configuration: path must start with /",
+		},
+		{
+			config: `
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: basic_auth`,
+			expectedErr: "invalid configuration: basic_auth is selected, but basic_auth is not provided",
+		},
+		{
+			config: `
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers`,
+			expectedErr: "invalid configuration: headers is selected, but headers is not provided",
+		},
+		{
+			config: `
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: basic_auth
+basic_auth:
+  username: 132`,
+			expectedErr: "invalid configuration: basic_auth is selected, but password is not provided",
+		},
+		{
+			config: `
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: basic_auth
+basic_auth:
+  password: 132`,
+			expectedErr: "invalid configuration: basic_auth is selected, but username is not provided",
+		},
+		{
+			config: `
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers
+headers:`,
+			expectedErr: "invalid configuration: headers is selected, but headers is not provided",
+		},
+		{
+			config: `
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: toto`,
+			expectedErr: "invalid configuration: invalid auth_type: must be one of basic_auth, headers, mtls",
+		},
+		{
+			config: `
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers
+headers:
+  key: value
+tls:
+  server_key: key`,
+			expectedErr: "invalid configuration: server_cert is required",
+		},
+		{
+			config: `
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers
+headers:
+  key: value
+tls:
+  server_cert: cert`,
+			expectedErr: "invalid configuration: server_key is required",
+		},
+		{
+			config: `
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: mtls
+tls:
+  server_cert: cert
+  server_key: key`,
+			expectedErr: "invalid configuration: mtls is selected, but ca_cert is not provided",
+		},
+		{
+			config: `
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers
+headers:
+  key: value
+max_body_size: 0`,
+			expectedErr: "invalid configuration: max_body_size must be positive",
+		},
+		{
+			config: `
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers
+headers:
+  key: value
+timeout: toto`,
+			expectedErr: "cannot parse http datasource configuration: yaml: unmarshal errors:\n  line 8: cannot unmarshal !!str `toto` into time.Duration",
+		},
+		{
+			config: `
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers
+headers:
+  key: value
+custom_status_code: 999`,
+			expectedErr: "invalid configuration: invalid HTTP status code",
+		},
+	}
+
+	subLogger := log.WithFields(log.Fields{
+		"type": "http",
+	})
+
+	for _, test := range tests {
+		h := HTTPSource{}
+		err := h.Configure([]byte(test.config), subLogger, 0)
+		cstest.AssertErrorContains(t, err, test.expectedErr)
+	}
+}
+
+// TestGetUuid checks that GetUuid returns the configured unique id verbatim.
+func TestGetUuid(t *testing.T) {
+	h := HTTPSource{}
+	h.Config.UniqueId = "test"
+	assert.Equal(t, "test", h.GetUuid())
+}
+
+// TestUnmarshalConfig checks that a YAML syntax error is reported with the
+// datasource prefix (the config below deliberately contains a tab).
+func TestUnmarshalConfig(t *testing.T) {
+	h := HTTPSource{}
+	err := h.UnmarshalConfig([]byte(`
+source: http
+listen_addr: 127.0.0.1:8080
+path: 15
+	auth_type: headers`))
+	cstest.AssertErrorMessage(t, err, "cannot parse http datasource configuration: yaml: line 4: found a tab character that violates indentation")
+}
+
+// TestConfigureByDSN checks that the http datasource rejects one-shot
+// (command-line) acquisition.
+func TestConfigureByDSN(t *testing.T) {
+	h := HTTPSource{}
+	err := h.ConfigureByDSN("http://localhost:8080/test", map[string]string{}, log.WithFields(log.Fields{
+		"type": "http",
+	}), "test")
+	cstest.AssertErrorMessage(
+		t,
+		err,
+		"http datasource does not support command-line acquisition",
+	)
+}
+
+// TestGetMode checks that GetMode returns the configured mode verbatim.
+func TestGetMode(t *testing.T) {
+	h := HTTPSource{}
+	h.Config.Mode = "test"
+	assert.Equal(t, "test", h.GetMode())
+}
+
+// TestGetName checks the datasource module name.
+func TestGetName(t *testing.T) {
+	h := HTTPSource{}
+	assert.Equal(t, "http", h.GetName())
+}
+
+// SetupAndRunHTTPSource configures the datasource from the given YAML,
+// starts streaming acquisition in a fresh tomb, registers its metrics on
+// the default prometheus registry, and returns the event channel and tomb.
+func SetupAndRunHTTPSource(t *testing.T, h *HTTPSource, config []byte, metricLevel int) (chan types.Event, *tomb.Tomb) {
+	t.Helper()
+
+	ctx := context.Background()
+	subLogger := log.WithFields(log.Fields{
+		"type": "http",
+	})
+	err := h.Configure(config, subLogger, metricLevel)
+	require.NoError(t, err)
+
+	// Named tmb so the tomb package is not shadowed.
+	tmb := tomb.Tomb{}
+	out := make(chan types.Event)
+	err = h.StreamingAcquisition(ctx, out, &tmb)
+	require.NoError(t, err)
+
+	for _, metric := range h.GetMetrics() {
+		// Deliberately ignore AlreadyRegisteredError: the tests in this
+		// file share the default registry across runs.
+		_ = prometheus.Register(metric)
+	}
+
+	return out, &tmb
+}
+
+// TestStreamingAcquisitionWrongHTTPMethod checks that non-POST requests are
+// rejected with 405.
+// NOTE(review): response bodies are not closed in these tests; harmless
+// here but defer resp.Body.Close() would be cleaner.
+func TestStreamingAcquisitionWrongHTTPMethod(t *testing.T) {
+	h := &HTTPSource{}
+	_, tomb := SetupAndRunHTTPSource(t, h, []byte(`
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: basic_auth
+basic_auth:
+  username: test
+  password: test`), 0)
+
+	// Give the server goroutine time to bind the listener.
+	time.Sleep(1 * time.Second)
+
+	res, err := http.Get(fmt.Sprintf("%s/test", testHTTPServerAddr))
+	require.NoError(t, err)
+	assert.Equal(t, http.StatusMethodNotAllowed, res.StatusCode)
+
+	h.Server.Close()
+	tomb.Kill(nil)
+	tomb.Wait()
+}
+
+// TestStreamingAcquisitionUnknownPath checks that requests outside the
+// configured path get a 404 from the mux.
+func TestStreamingAcquisitionUnknownPath(t *testing.T) {
+	h := &HTTPSource{}
+	_, tomb := SetupAndRunHTTPSource(t, h, []byte(`
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: basic_auth
+basic_auth:
+  username: test
+  password: test`), 0)
+
+	time.Sleep(1 * time.Second)
+
+	res, err := http.Get(fmt.Sprintf("%s/unknown", testHTTPServerAddr))
+	require.NoError(t, err)
+	assert.Equal(t, http.StatusNotFound, res.StatusCode)
+
+	h.Server.Close()
+	tomb.Kill(nil)
+	tomb.Wait()
+}
+
+// TestStreamingAcquisitionBasicAuth checks that missing and wrong basic-auth
+// credentials both yield 401.
+func TestStreamingAcquisitionBasicAuth(t *testing.T) {
+	h := &HTTPSource{}
+	_, tomb := SetupAndRunHTTPSource(t, h, []byte(`
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: basic_auth
+basic_auth:
+  username: test
+  password: test`), 0)
+
+	time.Sleep(1 * time.Second)
+
+	client := &http.Client{}
+
+	// No credentials at all.
+	resp, err := http.Post(fmt.Sprintf("%s/test", testHTTPServerAddr), "application/json", strings.NewReader("test"))
+	require.NoError(t, err)
+	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
+
+	// Wrong password.
+	req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test"))
+	require.NoError(t, err)
+	req.SetBasicAuth("test", "WrongPassword")
+
+	resp, err = client.Do(req)
+	require.NoError(t, err)
+	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
+
+	h.Server.Close()
+	tomb.Kill(nil)
+	tomb.Wait()
+}
+
+// TestStreamingAcquisitionBadHeaders checks that a wrong value for a
+// configured auth header yields 401.
+func TestStreamingAcquisitionBadHeaders(t *testing.T) {
+	h := &HTTPSource{}
+	_, tomb := SetupAndRunHTTPSource(t, h, []byte(`
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers
+headers:
+  key: test`), 0)
+
+	time.Sleep(1 * time.Second)
+
+	client := &http.Client{}
+
+	req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test"))
+	require.NoError(t, err)
+
+	req.Header.Add("Key", "wrong")
+	resp, err := client.Do(req)
+	require.NoError(t, err)
+	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
+
+	h.Server.Close()
+	tomb.Kill(nil)
+	tomb.Wait()
+}
+
+// TestStreamingAcquisitionMaxBodySize checks that a body larger than
+// max_body_size (declared via Content-Length) yields 413.
+func TestStreamingAcquisitionMaxBodySize(t *testing.T) {
+	h := &HTTPSource{}
+	_, tomb := SetupAndRunHTTPSource(t, h, []byte(`
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers
+headers:
+  key: test
+max_body_size: 5`), 0)
+
+	time.Sleep(1 * time.Second)
+
+	client := &http.Client{}
+	req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("testtest"))
+	require.NoError(t, err)
+
+	req.Header.Add("Key", "test")
+	resp, err := client.Do(req)
+	require.NoError(t, err)
+
+	assert.Equal(t, http.StatusRequestEntityTooLarge, resp.StatusCode)
+
+	h.Server.Close()
+	tomb.Kill(nil)
+	tomb.Wait()
+}
+
+// TestStreamingAcquisitionSuccess posts a single JSON document with valid
+// header auth and verifies the event content and the full-level metric.
+func TestStreamingAcquisitionSuccess(t *testing.T) {
+	h := &HTTPSource{}
+	out, tomb := SetupAndRunHTTPSource(t, h, []byte(`
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers
+headers:
+  key: test`), 2)
+
+	time.Sleep(1 * time.Second)
+	rawEvt := `{"test": "test"}`
+
+	// Consumer must run concurrently: the server blocks on the unbuffered
+	// out channel while handling the request.
+	errChan := make(chan error)
+	go assertEvents(out, []string{rawEvt}, errChan)
+
+	client := &http.Client{}
+	req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
+	require.NoError(t, err)
+
+	req.Header.Add("Key", "test")
+	resp, err := client.Do(req)
+	require.NoError(t, err)
+	assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+	err = <-errChan
+	require.NoError(t, err)
+
+	assertMetrics(t, h.GetMetrics(), 1)
+
+	h.Server.Close()
+	tomb.Kill(nil)
+	tomb.Wait()
+}
+
+// TestStreamingAcquisitionCustomStatusCodeAndCustomHeaders checks that the
+// configured custom_status_code and custom_headers are applied to the
+// success response.
+func TestStreamingAcquisitionCustomStatusCodeAndCustomHeaders(t *testing.T) {
+	h := &HTTPSource{}
+	out, tomb := SetupAndRunHTTPSource(t, h, []byte(`
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers
+headers:
+  key: test
+custom_status_code: 201
+custom_headers:
+  success: true`), 2)
+
+	time.Sleep(1 * time.Second)
+
+	rawEvt := `{"test": "test"}`
+	errChan := make(chan error)
+	go assertEvents(out, []string{rawEvt}, errChan)
+
+	req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt))
+	require.NoError(t, err)
+
+	req.Header.Add("Key", "test")
+	resp, err := http.DefaultClient.Do(req)
+	require.NoError(t, err)
+
+	assert.Equal(t, http.StatusCreated, resp.StatusCode)
+	assert.Equal(t, "true", resp.Header.Get("Success"))
+
+	err = <-errChan
+	require.NoError(t, err)
+
+	assertMetrics(t, h.GetMetrics(), 1)
+
+	h.Server.Close()
+	tomb.Kill(nil)
+	tomb.Wait()
+}
+
+// slowReader hands its body to the caller one Read at a time, sleeping
+// before each chunk, to simulate a client that trickles its request body.
+type slowReader struct {
+	delay time.Duration
+	body  []byte
+	index int
+}
+
+// Read sleeps for delay, then returns the next unread chunk of body;
+// io.EOF once everything has been consumed.
+func (sr *slowReader) Read(p []byte) (int, error) {
+	remaining := sr.body[sr.index:]
+	if len(remaining) == 0 {
+		return 0, io.EOF
+	}
+	time.Sleep(sr.delay)
+	n := copy(p, remaining)
+	sr.index += n
+	return n, nil
+}
+
+// assertEvents drains len(expected) events from out (2s timeout per event)
+// and checks each raw line, source IP and module; the verdict — nil or an
+// error — is delivered on errChan.
+func assertEvents(out chan types.Event, expected []string, errChan chan error) {
+	readLines := make([]types.Event, 0, len(expected))
+
+	for range expected {
+		select {
+		case event := <-out:
+			readLines = append(readLines, event)
+		case <-time.After(2 * time.Second):
+			errChan <- errors.New("timeout waiting for event")
+			return
+		}
+	}
+
+	// The loop above either fills readLines to len(expected) or returns,
+	// so no separate length check is needed here.
+	for i, evt := range readLines {
+		if evt.Line.Raw != expected[i] {
+			// Report the single expected line, not the whole slice.
+			errChan <- fmt.Errorf("expected %s, got '%+v'", expected[i], evt.Line.Raw)
+			return
+		}
+		if evt.Line.Src != "127.0.0.1" {
+			errChan <- fmt.Errorf("expected '127.0.0.1', got '%s'", evt.Line.Src)
+			return
+		}
+		if evt.Line.Module != "http" {
+			errChan <- fmt.Errorf("expected 'http', got '%s'", evt.Line.Module)
+			return
+		}
+	}
+	errChan <- nil
+}
+
+// TestStreamingAcquisitionTimeout checks that a client trickling its body
+// slower than the server read timeout gets its request aborted (400).
+func TestStreamingAcquisitionTimeout(t *testing.T) {
+	h := &HTTPSource{}
+	_, tomb := SetupAndRunHTTPSource(t, h, []byte(`
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers
+headers:
+  key: test
+timeout: 1s`), 0)
+
+	time.Sleep(1 * time.Second)
+
+	slow := &slowReader{
+		delay: 2 * time.Second,
+		body:  []byte(`{"test": "delayed_payload"}`),
+	}
+
+	req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), slow)
+	require.NoError(t, err)
+
+	req.Header.Add("Key", "test")
+	req.Header.Set("Content-Type", "application/json")
+
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	require.NoError(t, err)
+
+	assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
+
+	h.Server.Close()
+	tomb.Kill(nil)
+	tomb.Wait()
+}
+
+// TestStreamingAcquisitionTLSHTTPRequest checks that a plaintext HTTP
+// request against a TLS-enabled listener is rejected (400).
+func TestStreamingAcquisitionTLSHTTPRequest(t *testing.T) {
+	h := &HTTPSource{}
+	_, tomb := SetupAndRunHTTPSource(t, h, []byte(`
+source: http
+listen_addr: 127.0.0.1:8080
+auth_type: mtls
+path: /test
+tls:
+  server_cert: testdata/server.crt
+  server_key: testdata/server.key
+  ca_cert: testdata/ca.crt`), 0)
+
+	time.Sleep(1 * time.Second)
+
+	resp, err := http.Post(fmt.Sprintf("%s/test", testHTTPServerAddr), "application/json", strings.NewReader("test"))
+	require.NoError(t, err)
+
+	assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
+
+	h.Server.Close()
+	tomb.Kill(nil)
+	tomb.Wait()
+}
+
+// TestStreamingAcquisitionTLSWithHeadersAuthSuccess drives a full HTTPS
+// round-trip with header auth, trusting the (self-signed) server cert
+// directly as a root.
+func TestStreamingAcquisitionTLSWithHeadersAuthSuccess(t *testing.T) {
+	h := &HTTPSource{}
+	out, tomb := SetupAndRunHTTPSource(t, h, []byte(`
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers
+headers:
+  key: test
+tls:
+  server_cert: testdata/server.crt
+  server_key: testdata/server.key
+`), 0)
+
+	time.Sleep(1 * time.Second)
+
+	caCert, err := os.ReadFile("testdata/server.crt")
+	require.NoError(t, err)
+
+	caCertPool := x509.NewCertPool()
+	caCertPool.AppendCertsFromPEM(caCert)
+
+	tlsConfig := &tls.Config{
+		RootCAs: caCertPool,
+	}
+
+	client := &http.Client{
+		Transport: &http.Transport{
+			TLSClientConfig: tlsConfig,
+		},
+	}
+
+	rawEvt := `{"test": "test"}`
+	errChan := make(chan error)
+	go assertEvents(out, []string{rawEvt}, errChan)
+
+	req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt))
+	require.NoError(t, err)
+
+	req.Header.Add("Key", "test")
+	resp, err := client.Do(req)
+	require.NoError(t, err)
+
+	assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+	err = <-errChan
+	require.NoError(t, err)
+
+	// Metrics level 0: no per-line counter is expected.
+	assertMetrics(t, h.GetMetrics(), 0)
+
+	h.Server.Close()
+	tomb.Kill(nil)
+	tomb.Wait()
+}
+
+// TestStreamingAcquisitionMTLS drives a full mutual-TLS round-trip: the
+// client presents a cert signed by the configured CA and no other auth.
+func TestStreamingAcquisitionMTLS(t *testing.T) {
+	h := &HTTPSource{}
+	out, tomb := SetupAndRunHTTPSource(t, h, []byte(`
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: mtls
+tls:
+  server_cert: testdata/server.crt
+  server_key: testdata/server.key
+  ca_cert: testdata/ca.crt`), 0)
+
+	time.Sleep(1 * time.Second)
+
+	// init client cert
+	cert, err := tls.LoadX509KeyPair("testdata/client.crt", "testdata/client.key")
+	require.NoError(t, err)
+
+	caCert, err := os.ReadFile("testdata/ca.crt")
+	require.NoError(t, err)
+
+	caCertPool := x509.NewCertPool()
+	caCertPool.AppendCertsFromPEM(caCert)
+
+	tlsConfig := &tls.Config{
+		Certificates: []tls.Certificate{cert},
+		RootCAs:      caCertPool,
+	}
+
+	client := &http.Client{
+		Transport: &http.Transport{
+			TLSClientConfig: tlsConfig,
+		},
+	}
+
+	rawEvt := `{"test": "test"}`
+	errChan := make(chan error)
+	go assertEvents(out, []string{rawEvt}, errChan)
+
+	req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt))
+	require.NoError(t, err)
+
+	resp, err := client.Do(req)
+	require.NoError(t, err)
+
+	assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+	err = <-errChan
+	require.NoError(t, err)
+
+	assertMetrics(t, h.GetMetrics(), 0)
+
+	h.Server.Close()
+	tomb.Kill(nil)
+	tomb.Wait()
+}
+
+// TestStreamingAcquisitionGzipData posts two gzip-compressed JSON documents
+// in one request and expects two events plus a metric count of 2.
+func TestStreamingAcquisitionGzipData(t *testing.T) {
+	h := &HTTPSource{}
+	out, tomb := SetupAndRunHTTPSource(t, h, []byte(`
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers
+headers:
+  key: test`), 2)
+
+	time.Sleep(1 * time.Second)
+
+	rawEvt := `{"test": "test"}`
+	errChan := make(chan error)
+	go assertEvents(out, []string{rawEvt, rawEvt}, errChan)
+
+	// Two back-to-back JSON documents in one gzip stream.
+	var b strings.Builder
+	gz := gzip.NewWriter(&b)
+
+	_, err := gz.Write([]byte(rawEvt))
+	require.NoError(t, err)
+
+	_, err = gz.Write([]byte(rawEvt))
+	require.NoError(t, err)
+
+	err = gz.Close()
+	require.NoError(t, err)
+
+	// send gzipped compressed data
+	client := &http.Client{}
+	req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(b.String()))
+	require.NoError(t, err)
+
+	req.Header.Add("Key", "test")
+	req.Header.Add("Content-Encoding", "gzip")
+	req.Header.Add("Content-Type", "application/json")
+
+	resp, err := client.Do(req)
+	require.NoError(t, err)
+
+	assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+	err = <-errChan
+	require.NoError(t, err)
+
+	assertMetrics(t, h.GetMetrics(), 2)
+
+	h.Server.Close()
+	tomb.Kill(nil)
+	tomb.Wait()
+}
+
+// TestStreamingAcquisitionNDJson posts two newline-delimited JSON documents
+// and expects two events plus a metric count of 2.
+func TestStreamingAcquisitionNDJson(t *testing.T) {
+	h := &HTTPSource{}
+	out, tomb := SetupAndRunHTTPSource(t, h, []byte(`
+source: http
+listen_addr: 127.0.0.1:8080
+path: /test
+auth_type: headers
+headers:
+  key: test`), 2)
+
+	time.Sleep(1 * time.Second)
+	rawEvt := `{"test": "test"}`
+
+	errChan := make(chan error)
+	go assertEvents(out, []string{rawEvt, rawEvt}, errChan)
+
+	client := &http.Client{}
+	req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(fmt.Sprintf("%s\n%s\n", rawEvt, rawEvt)))
+
+	require.NoError(t, err)
+
+	req.Header.Add("Key", "test")
+	req.Header.Add("Content-Type", "application/x-ndjson")
+
+	resp, err := client.Do(req)
+	require.NoError(t, err)
+	assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+	err = <-errChan
+	require.NoError(t, err)
+
+	assertMetrics(t, h.GetMetrics(), 2)
+
+	h.Server.Close()
+	tomb.Kill(nil)
+	tomb.Wait()
+}
+
+// assertMetrics verifies that the cs_httpsource_hits_total counter carries
+// the expected value with the expected path/src labels, then resets the
+// collectors so the next test starts from zero.
+func assertMetrics(t *testing.T, metrics []prometheus.Collector, expected int) {
+	t.Helper()
+
+	promMetrics, err := prometheus.DefaultGatherer.Gather()
+	require.NoError(t, err)
+
+	isExist := false
+	for _, metricFamily := range promMetrics {
+		if metricFamily.GetName() == "cs_httpsource_hits_total" {
+			isExist = true
+			assert.Len(t, metricFamily.GetMetric(), 1)
+			for _, metric := range metricFamily.GetMetric() {
+				assert.InDelta(t, float64(expected), metric.GetCounter().GetValue(), 0.000001)
+				labels := metric.GetLabel()
+				assert.Len(t, labels, 2)
+				assert.Equal(t, "path", labels[0].GetName())
+				assert.Equal(t, "/test", labels[0].GetValue())
+				assert.Equal(t, "src", labels[1].GetName())
+				assert.Equal(t, "127.0.0.1", labels[1].GetValue())
+			}
+		}
+	}
+	if !isExist && expected > 0 {
+		// Fatal, not Fatalf: the message contains no format verbs (go vet).
+		t.Fatal("expected metric cs_httpsource_hits_total not found")
+	}
+
+	for _, metric := range metrics {
+		metric.(*prometheus.CounterVec).Reset()
+	}
+}
diff --git a/pkg/acquisition/modules/http/testdata/ca.crt b/pkg/acquisition/modules/http/testdata/ca.crt
new file mode 100644
index 00000000000..ac81b9db8a6
--- /dev/null
+++ b/pkg/acquisition/modules/http/testdata/ca.crt
@@ -0,0 +1,23 @@
+-----BEGIN CERTIFICATE-----
+MIIDvzCCAqegAwIBAgIUHQfsFpWkCy7gAmDa3A6O+y5CvAswDQYJKoZIhvcNAQEL
+BQAwbzELMAkGA1UEBhMCRlIxFjAUBgNVBAgTDUlsZS1kZS1GcmFuY2UxDjAMBgNV
+BAcTBVBhcmlzMREwDwYDVQQKEwhDcm93ZHNlYzERMA8GA1UECxMIQ3Jvd2RzZWMx
+EjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0yNDEwMjMxMDAxMDBaFw0yOTEwMjIxMDAx
+MDBaMG8xCzAJBgNVBAYTAkZSMRYwFAYDVQQIEw1JbGUtZGUtRnJhbmNlMQ4wDAYD
+VQQHEwVQYXJpczERMA8GA1UEChMIQ3Jvd2RzZWMxETAPBgNVBAsTCENyb3dkc2Vj
+MRIwEAYDVQQDEwlsb2NhbGhvc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK
+AoIBAQCZSR2/A24bpVHSiEeSlelfdA32uhk9wHkauwy2qxos/G/UmKG/dgWrHzRh
+LawlFVHtVn4u7Hjqz2y2EsH3bX42jC5NMVARgXIOBr1dE6F5/bPqA6SoVgkDm9wh
+ZBigyAMxYsR4+3ahuf0pQflBShKrLZ1UYoe6tQXob7l3x5vThEhNkBawBkLfWpj7
+7Imm1tGyEZdxCMkT400KRtSmJRrnpiOCUosnacwgp7MCbKWOIOng07Eh16cVUiuI
+BthWU/LycIuac2xaD9PFpeK/MpwASRRPXZgPUhiZuaa7vttD0phCdDaS46Oln5/7
+tFRZH0alBZmkpVZJCWAP4ujIA3vLAgMBAAGjUzBRMA4GA1UdDwEB/wQEAwIBBjAP
+BgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBTwpg+WN1nZJs4gj5hfoa+fMSZjGTAP
+BgNVHREECDAGhwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQAZuOWT8zHcwbWvC6Jm
+/ccgB/U7SbeIYFJrCZd9mTyqsgnkFNH8yJ5F4dXXtPXr+SO/uWWa3G5hams3qVFf
+zWzzPDQdyhUhfh5fjUHR2RsSGBmCxcapYHpVvAP5aY1/ujYrXMvAJV0hfDO2tGHb
+rveuJxhe8ymQ1Yb2u9NcmI1HG9IVt3Airz4gAIUJWbFvRigky0bukfddOkfiUiaF
+DMPJQO6HAj8d8ctSHHVZWzhAInZ1pDg6HIHYF44m1tT27pSQoi0ZFocskDi/fC2f
+EIF0nu5fRLUS6BZEfpnDi9U0lbJ/kUrgT5IFHMFqXdRpDqcnXpJZhYtp5l6GoqjM
+gT33
+-----END CERTIFICATE-----
diff --git a/pkg/acquisition/modules/http/testdata/client.crt b/pkg/acquisition/modules/http/testdata/client.crt
new file mode 100644
index 00000000000..55efdddad09
--- /dev/null
+++ b/pkg/acquisition/modules/http/testdata/client.crt
@@ -0,0 +1,24 @@
+-----BEGIN CERTIFICATE-----
+MIID7jCCAtagAwIBAgIUJMTPh3oPJLPgsnb9T85ieb4EuOQwDQYJKoZIhvcNAQEL
+BQAwbzELMAkGA1UEBhMCRlIxFjAUBgNVBAgTDUlsZS1kZS1GcmFuY2UxDjAMBgNV
+BAcTBVBhcmlzMREwDwYDVQQKEwhDcm93ZHNlYzERMA8GA1UECxMIQ3Jvd2RzZWMx
+EjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0yNDEwMjMxMDQ2MDBaFw0yNTEwMjMxMDQ2
+MDBaMHIxCzAJBgNVBAYTAkZSMRYwFAYDVQQIEw1JbGUtZGUtRnJhbmNlMQ4wDAYD
+VQQHEwVQYXJpczERMA8GA1UEChMIQ3Jvd2RzZWMxFzAVBgNVBAsTDlRlc3Rpbmcg
+Y2xpZW50MQ8wDQYDVQQDEwZjbGllbnQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw
+ggEKAoIBAQDAUOdpRieRrrH6krUjgcjLgJg6TzoWAb/iv6rfcioX1L9bj9fZSkwu
+GqKzXX/PceIXElzQgiGJZErbJtnTzhGS80QgtAB8BwWQIT2zgoGcYJf7pPFvmcMM
+qMGFwK0dMC+LHPk+ePtFz8dskI2XJ8jgBdtuZcnDblMuVGtjYT6n0rszvRdo118+
+mlGCLPzOfsO1JdOqLWAR88yZfqCFt1TrwmzpRT1crJQeM6i7muw4aO0L7uSek9QM
+6APHz0QexSq7/zHOtRjA4jnJbDzZJHRlwOdlsNU9cmTz6uWIQXlg+2ovD55YurNy
++jYfmfDYpimhoeGf54zaETp1fTuTJYpxAgMBAAGjfzB9MA4GA1UdDwEB/wQEAwIF
+oDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAd
+BgNVHQ4EFgQUmH0/7RuKnoW7sEK4Cr8eVNGbb8swHwYDVR0jBBgwFoAU8KYPljdZ
+2SbOII+YX6GvnzEmYxkwDQYJKoZIhvcNAQELBQADggEBAHVn9Zuoyxu9iTFoyJ50
+e/XKcmt2uK2M1x+ap2Av7Wb/Omikx/R2YPq7994BfiUCAezY2YtreZzkE6Io1wNM
+qApijEJnlqEmOXiYJqlF89QrCcsAsz6lfaqitYBZSL3o4KT+7/uUDVxgNEjEksRz
+9qy6DFBLvyhxbOM2zDEV+MVfemBWSvNiojHqXzDBkZnBHHclJLuIKsXDZDGhKbNd
+hsoGU00RLevvcUpUJ3a68ekgwiYFJifm0uyfmao9lmiB3i+8ZW3Q4rbwHtD+U7U2
+3n+U5PkhiUAveuMfrvUMzsTolZiop9ZLtcALDUFaqyr4tjfVOf5+CGjiluio7oE1
+UYg=
+-----END CERTIFICATE-----
diff --git a/pkg/acquisition/modules/http/testdata/client.key b/pkg/acquisition/modules/http/testdata/client.key
new file mode 100644
index 00000000000..f8ef2efbd58
--- /dev/null
+++ b/pkg/acquisition/modules/http/testdata/client.key
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEowIBAAKCAQEAwFDnaUYnka6x+pK1I4HIy4CYOk86FgG/4r+q33IqF9S/W4/X
+2UpMLhqis11/z3HiFxJc0IIhiWRK2ybZ084RkvNEILQAfAcFkCE9s4KBnGCX+6Tx
+b5nDDKjBhcCtHTAvixz5Pnj7Rc/HbJCNlyfI4AXbbmXJw25TLlRrY2E+p9K7M70X
+aNdfPppRgiz8zn7DtSXTqi1gEfPMmX6ghbdU68Js6UU9XKyUHjOou5rsOGjtC+7k
+npPUDOgDx89EHsUqu/8xzrUYwOI5yWw82SR0ZcDnZbDVPXJk8+rliEF5YPtqLw+e
+WLqzcvo2H5nw2KYpoaHhn+eM2hE6dX07kyWKcQIDAQABAoIBAQChriKuza0MfBri
+9x3UCRN/is/wDZVe1P+2KL8F9ZvPxytNVeP4qM7c38WzF8MQ6sRR8z0WiqCZOjj4
+f3QX7iG2MlAvUkUqAFk778ZIuUov5sE/bU8RLOrfJKz1vqOLa2w8/xHH5LwS1/jn
+m6t9zZPCSwpMiMSUSZci1xQlS6b6POZMjeqLPqv9cP8PJNv9UNrHFcQnQi1iwKJH
+MJ7CQI3R8FSeGad3P7tB9YDaBm7hHmd/TevuFkymcKNT44XBSgddPDfgKui6sHTY
+QQWgWI9VGVO350ZBLRLkrk8wboY4vc15qbBzYFG66WiR/tNdLt3rDYxpcXaDvcQy
+e47mYNVxAoGBAMFsUmPDssqzmOkmZxHDM+VmgHYPXjDqQdE299FtuClobUW4iU4g
+By7o84aCIBQz2sp9f1KM+10lr+Bqw3s7QBbR5M67PA8Zm45DL9t70NR/NZHGzFRD
+BR/NMbwzCqNtY2UGDhYQLGhW8heAwsYwir8ZqmOfKTd9aY1pu/S8m9AlAoGBAP6I
+483EIN8R5y+beGcGynYeIrH5Gc+W2FxWIW9jh/G7vRbhMlW4z0GxV3uEAYmOlBH2
+AqUkV6+uzU0P4B/m3vCYqLycBVDwifJazDj9nskVL5kGMxia62iwDMXs5nqNS4WJ
+ZM5Gl2xIiwmgWnYnujM3eKF2wbm439wj4na80SldAoGANdIqatA9o+GtntKsw2iJ
+vD91Z2SHVR0aC1k8Q+4/3GXOYiQjMLYAybDQcpEq0/RJ4SZik1nfZ9/gvJV4p4Wp
+I7Br9opq/9ikTEWtv2kIhtiO02151ciAWIUEXdXmE+uQSMASk1kUwkPPQXL2v6cq
+NFqz6tyS33nqMQtG3abNxHECgYA4AEA2nmcpDRRTSh50dG8JC9pQU+EU5jhWIHEc
+w8Y+LjMNHKDpcU7QQkdgGowICsGTLhAo61ULhycORGboPfBg+QVu8djNlQ6Urttt
+0ocj8LBXN6D4UeVnVAyLY3LWFc4+5Bq0s51PKqrEhG5Cvrzd1d+JjspSpVVDZvXF
+cAeI1QKBgC/cMN3+2Sc+2biu46DnkdYpdF/N0VGMOgzz+unSVD4RA2mEJ9UdwGga
+feshtrtcroHtEmc+WDYgTTnAq1MbsVFQYIwZ5fL/GJ1R8ccaWiPuX2HrKALKG4Y3
+CMFpDUWhRgtaBsmuOpUq3FeS5cyPNMHk6axL1KyFoJk9AgfhqhTp
+-----END RSA PRIVATE KEY-----
diff --git a/pkg/acquisition/modules/http/testdata/server.crt b/pkg/acquisition/modules/http/testdata/server.crt
new file mode 100644
index 00000000000..7a02c606c9d
--- /dev/null
+++ b/pkg/acquisition/modules/http/testdata/server.crt
@@ -0,0 +1,23 @@
+-----BEGIN CERTIFICATE-----
+MIID5jCCAs6gAwIBAgIUU3F6URi0oTe9ontkf7JqXOo89QYwDQYJKoZIhvcNAQEL
+BQAwbzELMAkGA1UEBhMCRlIxFjAUBgNVBAgTDUlsZS1kZS1GcmFuY2UxDjAMBgNV
+BAcTBVBhcmlzMREwDwYDVQQKEwhDcm93ZHNlYzERMA8GA1UECxMIQ3Jvd2RzZWMx
+EjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0yNDEwMjMxMDAzMDBaFw0yNTEwMjMxMDAz
+MDBaMG8xCzAJBgNVBAYTAkZSMRYwFAYDVQQIEw1JbGUtZGUtRnJhbmNlMQ4wDAYD
+VQQHEwVQYXJpczERMA8GA1UEChMIQ3Jvd2RzZWMxETAPBgNVBAsTCENyb3dkc2Vj
+MRIwEAYDVQQDEwlsb2NhbGhvc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK
+AoIBAQC/lnUubjBGe5x0LgIE5GeG52LRzj99iLWuvey4qbSwFZ07ECgv+JttVwDm
+AjEeakj2ZR46WHvHAR9eBNkRCORyWX0iKVIzm09PXYi80KtwGLaA8YMEio9/08Cc
++LS0TuP0yiOcw+btrhmvvauDzcQhA6u55q8anCZiF2BlHfX9Sh6QKewA3NhOkzbU
+VTxqrOqfcRsGNub7dheqfP5bfrPkF6Y6l/0Fhyx0NMsu1zaQ0hCls2hkTf0Y3XGt
+IQNWoN22seexR3qRmPf0j3jBa0qOmGgd6kAd+YpsjDblgCNUIJZiVj51fVb0sGRx
+ShkfKGU6t0eznTWPCqswujO/sn+pAgMBAAGjejB4MA4GA1UdDwEB/wQEAwIFoDAd
+BgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNV
+HQ4EFgQUOiIF+7Wzx1J8Ki3DiBfx+E6zlSUwGgYDVR0RBBMwEYIJbG9jYWxob3N0
+hwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQA0dzlhBr/0wXPyj/iWxMOXxZ1FNJ9f
+lxBMhLAgX0WrT2ys+284J7Hcn0lJeqelluYpmeKn9vmCAEj3MmUmHzZyf//lhuUJ
+0DlYWIHUsGaJHJ7A+1hQqrcXHhkcRy5WGIM9VoddKbBbg2b6qzTSvxn8EnuD7H4h
+28wLyGLCzsSXoVcAB8u+svYt29TPuy6xmMAokyIShV8FsE77fjVTgtCuxmx1PKv3
+zd6+uEae7bbZ+GJH1zKF0vokejQvmByt+YuIXlNbMseaMUeDdpy+6qlRvbbN1dyp
+rkQXfWvidMfSue5nH/akAn83v/CdKxG6tfW83d9Rud3naabUkywALDng
+-----END CERTIFICATE-----
diff --git a/pkg/acquisition/modules/http/testdata/server.key b/pkg/acquisition/modules/http/testdata/server.key
new file mode 100644
index 00000000000..4d0ee53b4c2
--- /dev/null
+++ b/pkg/acquisition/modules/http/testdata/server.key
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEpQIBAAKCAQEAv5Z1Lm4wRnucdC4CBORnhudi0c4/fYi1rr3suKm0sBWdOxAo
+L/ibbVcA5gIxHmpI9mUeOlh7xwEfXgTZEQjkcll9IilSM5tPT12IvNCrcBi2gPGD
+BIqPf9PAnPi0tE7j9MojnMPm7a4Zr72rg83EIQOrueavGpwmYhdgZR31/UoekCns
+ANzYTpM21FU8aqzqn3EbBjbm+3YXqnz+W36z5BemOpf9BYcsdDTLLtc2kNIQpbNo
+ZE39GN1xrSEDVqDdtrHnsUd6kZj39I94wWtKjphoHepAHfmKbIw25YAjVCCWYlY+
+dX1W9LBkcUoZHyhlOrdHs501jwqrMLozv7J/qQIDAQABAoIBAF1Vd/rJlV0Q5RQ4
+QaWOe9zdpmedeZK3YgMh5UvE6RCLRxC5+0n7bASlSPvEf5dYofjfJA26g3pcUqKj
+6/d/hIMsk2hsBu67L7TzVSTe51XxxB8nCPPSaLwWNZSDGM1qTWU4gIbjbQHHOh5C
+YWcRfAW1WxhyiEWHYq+QwdYg9XCRrSg1UzvVvW1Yt2wDGcSZP5whbXipfw3BITDs
+XU7ODYNkU1sjIzQZwzVGxOf9qKdhZFZ26Vhoz8OJNMLyJxY7EspuwR7HbDGt11Pb
+CxOt/BV44LwdVYeqh57oIKtckQW33W/6EeaWr7GfMzyH5WSrsOJoK5IJVrZaPTcS
+QiMYLA0CgYEA9vMVsGshBl3TeRGaU3XLHqooXD4kszbdnjfPrwGlfCO/iybhDqo5
+WFypM/bYcIWzbTez/ihufHEHPSCUbFEcN4B+oczGcuxTcZjFyvJYvq2ycxPUiDIi
+JnVUcVxgh1Yn39+CsQ/b6meP7MumTD2P3I87CeQGlWTO5Ys9mdw0BjcCgYEAxpv1
+64l5UoFJGr4yElNKDIKnhEFbJZsLGKiiuVXcS1QVHW5Az5ar9fPxuepyHpz416l3
+ppncuhJiUIP+jbu5e0s0LsN46mLS3wkHLgYJj06CNT3uOSLSg1iFl7DusdbyiaA7
+wEJ/aotS1NZ4XaeryAWHwYJ6Kag3nz6NV3ZYuR8CgYEAxAFCuMj+6F+2RsTa+d1n
+v8oMyNImLPyiQD9KHzyuTW7OTDMqtIoVg/Xf8re9KOpl9I0e1t7eevT3auQeCi8C
+t2bMm7290V+UB3jbnO5n08hn+ADIUuV/x4ie4m8QyrpuYbm0sLbGtTFHwgoNzzuZ
+oNUqZfpP42mk8fpnhWSLAlcCgYEAgpY7XRI4HkJ5ocbav2faMV2a7X/XgWNvKViA
+HeJRhYoUlBRRMuz7xi0OjFKVlIFbsNlxna5fDk1WLWCMd/6tl168Qd8u2tX9lr6l
+5OH9WSeiv4Un5JN73PbQaAvi9jXBpTIg92oBwzk2TlFyNQoxDcRtHZQ/5LIBWIhV
+gOOEtLsCgYEA1wbGc4XlH+/nXVsvx7gmfK8pZG8XA4/ToeIEURwPYrxtQZLB4iZs
+aqWGgIwiB4F4UkuKZIjMrgInU9y0fG6EL96Qty7Yjh7dGy1vJTZl6C+QU6o4sEwl
+r5Id5BNLEaqISWQ0LvzfwdfABYlvFfBdaGbzUzLEitD79eyhxuNEOBw=
+-----END RSA PRIVATE KEY-----
diff --git a/pkg/acquisition/modules/journalctl/journalctl.go b/pkg/acquisition/modules/journalctl/journalctl.go
index e7a35d5a3ba..27f20b9f446 100644
--- a/pkg/acquisition/modules/journalctl/journalctl.go
+++ b/pkg/acquisition/modules/journalctl/journalctl.go
@@ -136,12 +136,9 @@ func (j *JournalCtlSource) runJournalCtl(ctx context.Context, out chan types.Eve
if j.metricsLevel != configuration.METRICS_NONE {
linesRead.With(prometheus.Labels{"source": j.src}).Inc()
}
- var evt types.Event
- if !j.config.UseTimeMachine {
- evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE}
- } else {
- evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE}
- }
+
+ evt := types.MakeEvent(j.config.UseTimeMachine, types.LOG, true)
+ evt.Line = l
out <- evt
case stderrLine := <-stderrChan:
logger.Warnf("Got stderr message : %s", stderrLine)
diff --git a/pkg/acquisition/modules/kafka/kafka.go b/pkg/acquisition/modules/kafka/kafka.go
index a9a5e13e958..77fc44e310d 100644
--- a/pkg/acquisition/modules/kafka/kafka.go
+++ b/pkg/acquisition/modules/kafka/kafka.go
@@ -173,13 +173,8 @@ func (k *KafkaSource) ReadMessage(ctx context.Context, out chan types.Event) err
if k.metricsLevel != configuration.METRICS_NONE {
linesRead.With(prometheus.Labels{"topic": k.Config.Topic}).Inc()
}
- var evt types.Event
-
- if !k.Config.UseTimeMachine {
- evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE}
- } else {
- evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE}
- }
+ evt := types.MakeEvent(k.Config.UseTimeMachine, types.LOG, true)
+ evt.Line = l
out <- evt
}
}
diff --git a/pkg/acquisition/modules/kinesis/kinesis.go b/pkg/acquisition/modules/kinesis/kinesis.go
index 3cfc224aa25..3744e43f38d 100644
--- a/pkg/acquisition/modules/kinesis/kinesis.go
+++ b/pkg/acquisition/modules/kinesis/kinesis.go
@@ -322,12 +322,8 @@ func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan
} else {
l.Src = k.Config.StreamName
}
- var evt types.Event
- if !k.Config.UseTimeMachine {
- evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE}
- } else {
- evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE}
- }
+ evt := types.MakeEvent(k.Config.UseTimeMachine, types.LOG, true)
+ evt.Line = l
out <- evt
}
}
diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go
index 30fc5c467ea..aaa83a3bbb2 100644
--- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go
+++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go
@@ -66,6 +66,7 @@ func (ka *KubernetesAuditSource) GetAggregMetrics() []prometheus.Collector {
func (ka *KubernetesAuditSource) UnmarshalConfig(yamlConfig []byte) error {
k8sConfig := KubernetesAuditConfiguration{}
+
err := yaml.UnmarshalStrict(yamlConfig, &k8sConfig)
if err != nil {
return fmt.Errorf("cannot parse k8s-audit configuration: %w", err)
@@ -92,6 +93,7 @@ func (ka *KubernetesAuditSource) UnmarshalConfig(yamlConfig []byte) error {
if ka.config.Mode == "" {
ka.config.Mode = configuration.TAIL_MODE
}
+
return nil
}
@@ -116,6 +118,7 @@ func (ka *KubernetesAuditSource) Configure(config []byte, logger *log.Entry, Met
}
ka.mux.HandleFunc(ka.config.WebhookPath, ka.webhookHandler)
+
return nil
}
@@ -137,6 +140,7 @@ func (ka *KubernetesAuditSource) OneShotAcquisition(_ context.Context, _ chan ty
func (ka *KubernetesAuditSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
ka.outChan = out
+
t.Go(func() error {
defer trace.CatchPanic("crowdsec/acquis/k8s-audit/live")
ka.logger.Infof("Starting k8s-audit server on %s:%d%s", ka.config.ListenAddr, ka.config.ListenPort, ka.config.WebhookPath)
@@ -145,13 +149,16 @@ func (ka *KubernetesAuditSource) StreamingAcquisition(ctx context.Context, out c
if err != nil && err != http.ErrServerClosed {
return fmt.Errorf("k8s-audit server failed: %w", err)
}
+
return nil
})
<-t.Dying()
ka.logger.Infof("Stopping k8s-audit server on %s:%d%s", ka.config.ListenAddr, ka.config.ListenPort, ka.config.WebhookPath)
ka.server.Shutdown(ctx)
+
return nil
})
+
return nil
}
@@ -167,51 +174,58 @@ func (ka *KubernetesAuditSource) webhookHandler(w http.ResponseWriter, r *http.R
if ka.metricsLevel != configuration.METRICS_NONE {
requestCount.WithLabelValues(ka.addr).Inc()
}
+
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
+
ka.logger.Tracef("webhookHandler called")
+
var auditEvents audit.EventList
jsonBody, err := io.ReadAll(r.Body)
if err != nil {
ka.logger.Errorf("Error reading request body: %v", err)
w.WriteHeader(http.StatusInternalServerError)
+
return
}
+
ka.logger.Tracef("webhookHandler receveid: %s", string(jsonBody))
+
err = json.Unmarshal(jsonBody, &auditEvents)
if err != nil {
ka.logger.Errorf("Error decoding audit events: %s", err)
w.WriteHeader(http.StatusInternalServerError)
+
return
}
remoteIP := strings.Split(r.RemoteAddr, ":")[0]
- for _, auditEvent := range auditEvents.Items {
+
+ for idx := range auditEvents.Items {
if ka.metricsLevel != configuration.METRICS_NONE {
eventCount.WithLabelValues(ka.addr).Inc()
}
- bytesEvent, err := json.Marshal(auditEvent)
+
+ bytesEvent, err := json.Marshal(auditEvents.Items[idx])
if err != nil {
ka.logger.Errorf("Error serializing audit event: %s", err)
continue
}
+
ka.logger.Tracef("Got audit event: %s", string(bytesEvent))
l := types.Line{
Raw: string(bytesEvent),
Labels: ka.config.Labels,
- Time: auditEvent.StageTimestamp.Time,
+ Time: auditEvents.Items[idx].StageTimestamp.Time,
Src: remoteIP,
Process: true,
Module: ka.GetName(),
}
- ka.outChan <- types.Event{
- Line: l,
- Process: true,
- Type: types.LOG,
- ExpectMode: types.LIVE,
- }
+ evt := types.MakeEvent(ka.config.UseTimeMachine, types.LOG, true)
+ evt.Line = l
+ ka.outChan <- evt
}
}
diff --git a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go
index 846e833abea..5996518e191 100644
--- a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go
+++ b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go
@@ -119,7 +119,7 @@ func (lc *LokiClient) queryRange(ctx context.Context, uri string, c chan *LokiQu
case <-lc.t.Dying():
return lc.t.Err()
case <-ticker.C:
- resp, err := lc.Get(uri)
+ resp, err := lc.Get(ctx, uri)
if err != nil {
if ok := lc.shouldRetry(); !ok {
return fmt.Errorf("error querying range: %w", err)
@@ -205,6 +205,7 @@ func (lc *LokiClient) getURLFor(endpoint string, params map[string]string) strin
func (lc *LokiClient) Ready(ctx context.Context) error {
tick := time.NewTicker(500 * time.Millisecond)
url := lc.getURLFor("ready", nil)
+ lc.Logger.Debugf("Using url: %s for ready check", url)
for {
select {
case <-ctx.Done():
@@ -215,7 +216,7 @@ func (lc *LokiClient) Ready(ctx context.Context) error {
return lc.t.Err()
case <-tick.C:
lc.Logger.Debug("Checking if Loki is ready")
- resp, err := lc.Get(url)
+ resp, err := lc.Get(ctx, url)
if err != nil {
lc.Logger.Warnf("Error checking if Loki is ready: %s", err)
continue
@@ -300,8 +301,8 @@ func (lc *LokiClient) QueryRange(ctx context.Context, infinite bool) chan *LokiQ
}
// Create a wrapper for http.Get to be able to set headers and auth
-func (lc *LokiClient) Get(url string) (*http.Response, error) {
- request, err := http.NewRequest(http.MethodGet, url, nil)
+func (lc *LokiClient) Get(ctx context.Context, url string) (*http.Response, error) {
+ request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
diff --git a/pkg/acquisition/modules/loki/loki.go b/pkg/acquisition/modules/loki/loki.go
index e39c76af22c..c57e6a67c94 100644
--- a/pkg/acquisition/modules/loki/loki.go
+++ b/pkg/acquisition/modules/loki/loki.go
@@ -53,6 +53,7 @@ type LokiConfiguration struct {
WaitForReady time.Duration `yaml:"wait_for_ready"` // Retry interval, default is 10 seconds
Auth LokiAuthConfiguration `yaml:"auth"`
MaxFailureDuration time.Duration `yaml:"max_failure_duration"` // Max duration of failure before stopping the source
+ NoReadyCheck bool `yaml:"no_ready_check"` // Bypass /ready check before starting
configuration.DataSourceCommonCfg `yaml:",inline"`
}
@@ -229,6 +230,14 @@ func (l *LokiSource) ConfigureByDSN(dsn string, labels map[string]string, logger
l.logger.Logger.SetLevel(level)
}
+ if noReadyCheck := params.Get("no_ready_check"); noReadyCheck != "" {
+ noReadyCheck, err := strconv.ParseBool(noReadyCheck)
+ if err != nil {
+ return fmt.Errorf("invalid no_ready_check in dsn: %w", err)
+ }
+ l.Config.NoReadyCheck = noReadyCheck
+ }
+
l.Config.URL = fmt.Sprintf("%s://%s", scheme, u.Host)
if u.User != nil {
l.Config.Auth.Username = u.User.Username()
@@ -264,26 +273,28 @@ func (l *LokiSource) GetName() string {
func (l *LokiSource) OneShotAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
l.logger.Debug("Loki one shot acquisition")
l.Client.SetTomb(t)
- readyCtx, cancel := context.WithTimeout(ctx, l.Config.WaitForReady)
- defer cancel()
- err := l.Client.Ready(readyCtx)
- if err != nil {
- return fmt.Errorf("loki is not ready: %w", err)
+
+ if !l.Config.NoReadyCheck {
+ readyCtx, readyCancel := context.WithTimeout(ctx, l.Config.WaitForReady)
+ defer readyCancel()
+ err := l.Client.Ready(readyCtx)
+ if err != nil {
+ return fmt.Errorf("loki is not ready: %w", err)
+ }
}
- ctx, cancel = context.WithCancel(ctx)
- c := l.Client.QueryRange(ctx, false)
+ lokiCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ c := l.Client.QueryRange(lokiCtx, false)
for {
select {
case <-t.Dying():
l.logger.Debug("Loki one shot acquisition stopped")
- cancel()
return nil
case resp, ok := <-c:
if !ok {
l.logger.Info("Loki acquisition done, chan closed")
- cancel()
return nil
}
for _, stream := range resp.Data.Result {
@@ -307,41 +318,33 @@ func (l *LokiSource) readOneEntry(entry lokiclient.Entry, labels map[string]stri
if l.metricsLevel != configuration.METRICS_NONE {
linesRead.With(prometheus.Labels{"source": l.Config.URL}).Inc()
}
- expectMode := types.LIVE
- if l.Config.UseTimeMachine {
- expectMode = types.TIMEMACHINE
- }
- out <- types.Event{
- Line: ll,
- Process: true,
- Type: types.LOG,
- ExpectMode: expectMode,
- }
+ evt := types.MakeEvent(l.Config.UseTimeMachine, types.LOG, true)
+ evt.Line = ll
+ out <- evt
}
func (l *LokiSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
l.Client.SetTomb(t)
- readyCtx, cancel := context.WithTimeout(ctx, l.Config.WaitForReady)
- defer cancel()
- err := l.Client.Ready(readyCtx)
- if err != nil {
- return fmt.Errorf("loki is not ready: %w", err)
+
+ if !l.Config.NoReadyCheck {
+ readyCtx, readyCancel := context.WithTimeout(ctx, l.Config.WaitForReady)
+ defer readyCancel()
+ err := l.Client.Ready(readyCtx)
+ if err != nil {
+ return fmt.Errorf("loki is not ready: %w", err)
+ }
}
ll := l.logger.WithField("websocket_url", l.lokiWebsocket)
t.Go(func() error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
respChan := l.Client.QueryRange(ctx, true)
- if err != nil {
- ll.Errorf("could not start loki tail: %s", err)
- return fmt.Errorf("while starting loki tail: %w", err)
- }
for {
select {
case resp, ok := <-respChan:
if !ok {
ll.Warnf("loki channel closed")
- return err
+ return errors.New("loki channel closed")
}
for _, stream := range resp.Data.Result {
for _, entry := range stream.Entries {
diff --git a/pkg/acquisition/modules/loki/loki_test.go b/pkg/acquisition/modules/loki/loki_test.go
index cacdda32d80..643aefad715 100644
--- a/pkg/acquisition/modules/loki/loki_test.go
+++ b/pkg/acquisition/modules/loki/loki_test.go
@@ -34,6 +34,7 @@ func TestConfiguration(t *testing.T) {
password string
waitForReady time.Duration
delayFor time.Duration
+ noReadyCheck bool
testName string
}{
{
@@ -99,6 +100,19 @@ query: >
mode: tail
source: loki
url: http://localhost:3100/
+no_ready_check: true
+query: >
+ {server="demo"}
+`,
+ expectedErr: "",
+ testName: "Correct config with no_ready_check",
+ noReadyCheck: true,
+ },
+ {
+ config: `
+mode: tail
+source: loki
+url: http://localhost:3100/
auth:
username: foo
password: bar
@@ -148,6 +162,8 @@ query: >
t.Fatalf("Wrong DelayFor %v != %v", lokiSource.Config.DelayFor, test.delayFor)
}
}
+
+ assert.Equal(t, test.noReadyCheck, lokiSource.Config.NoReadyCheck)
})
}
}
@@ -164,6 +180,7 @@ func TestConfigureDSN(t *testing.T) {
scheme string
waitForReady time.Duration
delayFor time.Duration
+ noReadyCheck bool
}{
{
name: "Wrong scheme",
@@ -202,10 +219,11 @@ func TestConfigureDSN(t *testing.T) {
},
{
name: "Correct DSN",
- dsn: `loki://localhost:3100/?query={server="demo"}&wait_for_ready=5s&delay_for=1s`,
+ dsn: `loki://localhost:3100/?query={server="demo"}&wait_for_ready=5s&delay_for=1s&no_ready_check=true`,
expectedErr: "",
waitForReady: 5 * time.Second,
delayFor: 1 * time.Second,
+ noReadyCheck: true,
},
{
name: "SSL DSN",
@@ -256,6 +274,9 @@ func TestConfigureDSN(t *testing.T) {
t.Fatalf("Wrong DelayFor %v != %v", lokiSource.Config.DelayFor, test.delayFor)
}
}
+
+ assert.Equal(t, test.noReadyCheck, lokiSource.Config.NoReadyCheck)
+
}
}
diff --git a/pkg/acquisition/modules/s3/s3.go b/pkg/acquisition/modules/s3/s3.go
index acd78ceba8f..cdc84a8a3ca 100644
--- a/pkg/acquisition/modules/s3/s3.go
+++ b/pkg/acquisition/modules/s3/s3.go
@@ -443,12 +443,8 @@ func (s *S3Source) readFile(bucket string, key string) error {
} else if s.MetricsLevel == configuration.METRICS_AGGREGATE {
l.Src = bucket
}
- var evt types.Event
- if !s.Config.UseTimeMachine {
- evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE}
- } else {
- evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE}
- }
+ evt := types.MakeEvent(s.Config.UseTimeMachine, types.LOG, true)
+ evt.Line = l
s.out <- evt
}
}
diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go
index 66d842ed519..04c7053ef27 100644
--- a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go
+++ b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse.go
@@ -48,7 +48,6 @@ func WithStrictHostname() RFC3164Option {
}
func (r *RFC3164) parsePRI() error {
-
pri := 0
if r.buf[r.position] != '<' {
diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go
index 639e91e1224..c9aa89f7256 100644
--- a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go
+++ b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse.go
@@ -48,7 +48,6 @@ func WithStrictHostname() RFC5424Option {
}
func (r *RFC5424) parsePRI() error {
-
pri := 0
if r.buf[r.position] != '<' {
@@ -94,7 +93,6 @@ func (r *RFC5424) parseVersion() error {
}
func (r *RFC5424) parseTimestamp() error {
-
timestamp := []byte{}
if r.buf[r.position] == NIL_VALUE {
@@ -121,7 +119,6 @@ func (r *RFC5424) parseTimestamp() error {
}
date, err := time.Parse(VALID_TIMESTAMP, string(timestamp))
-
if err != nil {
return errors.New("timestamp is not valid")
}
diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go
index 0938e947fe7..d3a68c196db 100644
--- a/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go
+++ b/pkg/acquisition/modules/syslog/internal/parser/rfc5424/parse_test.go
@@ -94,7 +94,8 @@ func TestParse(t *testing.T) {
}{
{
"valid msg",
- `<13>1 2021-05-18T11:58:40.828081+02:42 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla`, expected{
+ `<13>1 2021-05-18T11:58:40.828081+02:42 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla`,
+ expected{
Timestamp: time.Date(2021, 5, 18, 11, 58, 40, 828081000, time.FixedZone("+0242", 9720)),
Hostname: "mantis",
Tag: "sshd",
@@ -102,11 +103,14 @@ func TestParse(t *testing.T) {
MsgID: "",
Message: "blabla",
PRI: 13,
- }, "", []RFC5424Option{},
+ },
+ "",
+ []RFC5424Option{},
},
{
"valid msg with msgid",
- `<13>1 2021-05-18T11:58:40.828081+02:42 mantis foobar 49340 123123 [timeQuality isSynced="0" tzKnown="1"] blabla`, expected{
+ `<13>1 2021-05-18T11:58:40.828081+02:42 mantis foobar 49340 123123 [timeQuality isSynced="0" tzKnown="1"] blabla`,
+ expected{
Timestamp: time.Date(2021, 5, 18, 11, 58, 40, 828081000, time.FixedZone("+0242", 9720)),
Hostname: "mantis",
Tag: "foobar",
@@ -114,11 +118,14 @@ func TestParse(t *testing.T) {
MsgID: "123123",
Message: "blabla",
PRI: 13,
- }, "", []RFC5424Option{},
+ },
+ "",
+ []RFC5424Option{},
},
{
"valid msg with repeating SD",
- `<13>1 2021-05-18T11:58:40.828081+02:42 mantis foobar 49340 123123 [timeQuality isSynced="0" tzKnown="1"][foo="bar][a] blabla`, expected{
+ `<13>1 2021-05-18T11:58:40.828081+02:42 mantis foobar 49340 123123 [timeQuality isSynced="0" tzKnown="1"][foo="bar][a] blabla`,
+ expected{
Timestamp: time.Date(2021, 5, 18, 11, 58, 40, 828081000, time.FixedZone("+0242", 9720)),
Hostname: "mantis",
Tag: "foobar",
@@ -126,36 +133,53 @@ func TestParse(t *testing.T) {
MsgID: "123123",
Message: "blabla",
PRI: 13,
- }, "", []RFC5424Option{},
+ },
+ "",
+ []RFC5424Option{},
},
{
"invalid SD",
- `<13>1 2021-05-18T11:58:40.828081+02:00 mantis foobar 49340 123123 [timeQuality asd`, expected{}, "structured data must end with ']'", []RFC5424Option{},
+ `<13>1 2021-05-18T11:58:40.828081+02:00 mantis foobar 49340 123123 [timeQuality asd`,
+ expected{},
+ "structured data must end with ']'",
+ []RFC5424Option{},
},
{
"invalid version",
- `<13>42 2021-05-18T11:58:40.828081+02:00 mantis foobar 49340 123123 [timeQuality isSynced="0" tzKnown="1"] blabla`, expected{}, "version must be 1", []RFC5424Option{},
+ `<13>42 2021-05-18T11:58:40.828081+02:00 mantis foobar 49340 123123 [timeQuality isSynced="0" tzKnown="1"] blabla`,
+ expected{},
+ "version must be 1",
+ []RFC5424Option{},
},
{
"invalid message",
- `<13>1`, expected{}, "version must be followed by a space", []RFC5424Option{},
+ `<13>1`,
+ expected{},
+ "version must be followed by a space",
+ []RFC5424Option{},
},
{
"valid msg with empty fields",
- `<13>1 - foo - - - - blabla`, expected{
+ `<13>1 - foo - - - - blabla`,
+ expected{
Timestamp: time.Now().UTC(),
Hostname: "foo",
PRI: 13,
Message: "blabla",
- }, "", []RFC5424Option{},
+ },
+ "",
+ []RFC5424Option{},
},
{
"valid msg with empty fields",
- `<13>1 - - - - - - blabla`, expected{
+ `<13>1 - - - - - - blabla`,
+ expected{
Timestamp: time.Now().UTC(),
PRI: 13,
Message: "blabla",
- }, "", []RFC5424Option{},
+ },
+ "",
+ []RFC5424Option{},
},
{
"valid msg with escaped SD",
@@ -167,7 +191,9 @@ func TestParse(t *testing.T) {
Hostname: "testhostname",
MsgID: `sn="msgid"`,
Message: `testmessage`,
- }, "", []RFC5424Option{},
+ },
+ "",
+ []RFC5424Option{},
},
{
"valid complex msg",
@@ -179,7 +205,9 @@ func TestParse(t *testing.T) {
PRI: 13,
MsgID: `sn="msgid"`,
Message: `source: sn="www.foobar.com" | message: 1.1.1.1 - - [24/May/2022:10:57:37 +0200] "GET /dist/precache-manifest.58b57debe6bc4f96698da0dc314461e9.js HTTP/2.0" 304 0 "https://www.foobar.com/sw.js" "Mozilla/5.0 (Linux; Android 9; ANE-LX1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/101.0.4951.61 Mobile Safari/537.36" "-" "www.foobar.com" sn="www.foobar.com" rt=0.000 ua="-" us="-" ut="-" ul="-" cs=HIT { request: /dist/precache-manifest.58b57debe6bc4f96698da0dc314461e9.js | src_ip_geo_country: DE | MONTH: May | COMMONAPACHELOG: 1.1.1.1 - - [24/May/2022:10:57:37 +0200] "GET /dist/precache-manifest.58b57debe6bc4f96698da0dc314461e9.js HTTP/2.0" 304 0 | auth: - | HOUR: 10 | gl2_remote_ip: 172.31.32.142 | ident: - | gl2_remote_port: 43375 | BASE10NUM: [2.0, 304, 0] | pid: -1 | program: nginx | gl2_source_input: 623ed3440183476d61cff974 | INT: +0200 | is_private_ip: false | YEAR: 2022 | src_ip_geo_city: Achern | clientip: 1.1.1.1 | USERNAME:`,
- }, "", []RFC5424Option{},
+ },
+ "",
+ []RFC5424Option{},
},
{
"partial message",
diff --git a/pkg/acquisition/modules/syslog/internal/server/syslogserver.go b/pkg/acquisition/modules/syslog/internal/server/syslogserver.go
index 7118c295b54..83f5e5a57e5 100644
--- a/pkg/acquisition/modules/syslog/internal/server/syslogserver.go
+++ b/pkg/acquisition/modules/syslog/internal/server/syslogserver.go
@@ -25,7 +25,6 @@ type SyslogMessage struct {
}
func (s *SyslogServer) Listen(listenAddr string, port int) error {
-
s.listenAddr = listenAddr
s.port = port
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", s.listenAddr, s.port))
diff --git a/pkg/acquisition/modules/syslog/syslog.go b/pkg/acquisition/modules/syslog/syslog.go
index 33a2f1542db..fb6a04600c1 100644
--- a/pkg/acquisition/modules/syslog/syslog.go
+++ b/pkg/acquisition/modules/syslog/syslog.go
@@ -235,11 +235,9 @@ func (s *SyslogSource) handleSyslogMsg(out chan types.Event, t *tomb.Tomb, c cha
l.Time = ts
l.Src = syslogLine.Client
l.Process = true
- if !s.config.UseTimeMachine {
- out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE}
- } else {
- out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE}
- }
+ evt := types.MakeEvent(s.config.UseTimeMachine, types.LOG, true)
+ evt.Line = l
+ out <- evt
}
}
}
diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_windows.go b/pkg/acquisition/modules/wineventlog/wineventlog_windows.go
index 887be8b7dd3..8283bcc21a2 100644
--- a/pkg/acquisition/modules/wineventlog/wineventlog_windows.go
+++ b/pkg/acquisition/modules/wineventlog/wineventlog_windows.go
@@ -94,7 +94,7 @@ func (w *WinEventLogSource) getXMLEvents(config *winlog.SubscribeConfig, publish
2000, // Timeout in milliseconds to wait.
0, // Reserved. Must be zero.
&returned) // The number of handles in the array that are set by the API.
- if err == windows.ERROR_NO_MORE_ITEMS {
+ if errors.Is(err, windows.ERROR_NO_MORE_ITEMS) {
return nil, err
} else if err != nil {
return nil, fmt.Errorf("wevtapi.EvtNext failed: %v", err)
@@ -188,7 +188,7 @@ func (w *WinEventLogSource) getEvents(out chan types.Event, t *tomb.Tomb) error
}
if status == syscall.WAIT_OBJECT_0 {
renderedEvents, err := w.getXMLEvents(w.evtConfig, publisherCache, subscription, 500)
- if err == windows.ERROR_NO_MORE_ITEMS {
+ if errors.Is(err, windows.ERROR_NO_MORE_ITEMS) {
windows.ResetEvent(w.evtConfig.SignalEvent)
} else if err != nil {
w.logger.Errorf("getXMLEvents failed: %v", err)
@@ -206,9 +206,9 @@ func (w *WinEventLogSource) getEvents(out chan types.Event, t *tomb.Tomb) error
l.Src = w.name
l.Process = true
if !w.config.UseTimeMachine {
- out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE}
+ out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE, Unmarshaled: make(map[string]interface{})}
} else {
- out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE}
+ out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE, Unmarshaled: make(map[string]interface{})}
}
}
}
@@ -411,7 +411,7 @@ OUTER_LOOP:
return nil
default:
evts, err := w.getXMLEvents(w.evtConfig, publisherCache, handle, 500)
- if err == windows.ERROR_NO_MORE_ITEMS {
+ if errors.Is(err, windows.ERROR_NO_MORE_ITEMS) {
log.Info("No more items")
break OUTER_LOOP
} else if err != nil {
@@ -430,7 +430,9 @@ OUTER_LOOP:
l.Time = time.Now()
l.Src = w.name
l.Process = true
- out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE}
+ csevt := types.MakeEvent(w.config.UseTimeMachine, types.LOG, true)
+ csevt.Line = l
+ out <- csevt
}
}
}
diff --git a/pkg/alertcontext/alertcontext.go b/pkg/alertcontext/alertcontext.go
index 16ebc6d0ac2..1b7d1e20018 100644
--- a/pkg/alertcontext/alertcontext.go
+++ b/pkg/alertcontext/alertcontext.go
@@ -3,6 +3,7 @@ package alertcontext
import (
"encoding/json"
"fmt"
+ "net/http"
"slices"
"strconv"
@@ -30,7 +31,10 @@ type Context struct {
func ValidateContextExpr(key string, expressions []string) error {
for _, expression := range expressions {
- _, err := expr.Compile(expression, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...)
+ _, err := expr.Compile(expression, exprhelpers.GetExprOptions(map[string]interface{}{
+ "evt": &types.Event{},
+ "match": &types.MatchedRule{},
+ "req": &http.Request{}})...)
if err != nil {
return fmt.Errorf("compilation of '%s' failed: %w", expression, err)
}
@@ -72,7 +76,10 @@ func NewAlertContext(contextToSend map[string][]string, valueLength int) error {
}
for _, value := range values {
- valueCompiled, err := expr.Compile(value, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...)
+ valueCompiled, err := expr.Compile(value, exprhelpers.GetExprOptions(map[string]interface{}{
+ "evt": &types.Event{},
+ "match": &types.MatchedRule{},
+ "req": &http.Request{}})...)
if err != nil {
return fmt.Errorf("compilation of '%s' context value failed: %w", value, err)
}
@@ -85,6 +92,32 @@ func NewAlertContext(contextToSend map[string][]string, valueLength int) error {
return nil
}
+// TruncateContextMap truncates each value list in contextMap so its serialized form fits within the context value length, returning the resulting meta items and any truncation errors.
+func TruncateContextMap(contextMap map[string][]string, contextValueLen int) ([]*models.MetaItems0, []error) {
+ metas := make([]*models.MetaItems0, 0)
+ errors := make([]error, 0)
+
+ for key, values := range contextMap {
+ if len(values) == 0 {
+ continue
+ }
+
+		valueStr, err := TruncateContext(values, contextValueLen)
+ if err != nil {
+ errors = append(errors, fmt.Errorf("error truncating content for %s: %w", key, err))
+ continue
+ }
+
+ meta := models.MetaItems0{
+ Key: key,
+ Value: valueStr,
+ }
+ metas = append(metas, &meta)
+ }
+ return metas, errors
+}
+
+// TruncateContext truncates an individual []string to fit within the context value length.
func TruncateContext(values []string, contextValueLen int) (string, error) {
valueByte, err := json.Marshal(values)
if err != nil {
@@ -116,61 +149,102 @@ func TruncateContext(values []string, contextValueLen int) (string, error) {
return ret, nil
}
-func EventToContext(events []types.Event) (models.Meta, []error) {
+func EvalAlertContextRules(evt types.Event, match *types.MatchedRule, request *http.Request, tmpContext map[string][]string) []error {
+
var errors []error
- metas := make([]*models.MetaItems0, 0)
- tmpContext := make(map[string][]string)
+	// If we're evaluating context for an appsec event, match and request will be present;
+	// otherwise, only evt will be.
+ if match == nil {
+ match = types.NewMatchedRule()
+ }
+ if request == nil {
+ request = &http.Request{}
+ }
- for _, evt := range events {
- for key, values := range alertContext.ContextToSendCompiled {
- if _, ok := tmpContext[key]; !ok {
- tmpContext[key] = make([]string, 0)
- }
+ for key, values := range alertContext.ContextToSendCompiled {
- for _, value := range values {
- var val string
+ if _, ok := tmpContext[key]; !ok {
+ tmpContext[key] = make([]string, 0)
+ }
- output, err := expr.Run(value, map[string]interface{}{"evt": evt})
- if err != nil {
- errors = append(errors, fmt.Errorf("failed to get value for %s: %w", key, err))
- continue
- }
+ for _, value := range values {
+ var val string
- switch out := output.(type) {
- case string:
- val = out
- case int:
- val = strconv.Itoa(out)
- default:
- errors = append(errors, fmt.Errorf("unexpected return type for %s: %T", key, output))
- continue
+ output, err := expr.Run(value, map[string]interface{}{"match": match, "evt": evt, "req": request})
+ if err != nil {
+ errors = append(errors, fmt.Errorf("failed to get value for %s: %w", key, err))
+ continue
+ }
+ switch out := output.(type) {
+ case string:
+ val = out
+ if val != "" && !slices.Contains(tmpContext[key], val) {
+ tmpContext[key] = append(tmpContext[key], val)
}
-
+ case []string:
+ for _, v := range out {
+ if v != "" && !slices.Contains(tmpContext[key], v) {
+ tmpContext[key] = append(tmpContext[key], v)
+ }
+ }
+ case int:
+ val = strconv.Itoa(out)
+ if val != "" && !slices.Contains(tmpContext[key], val) {
+ tmpContext[key] = append(tmpContext[key], val)
+ }
+ case []int:
+ for _, v := range out {
+ val = strconv.Itoa(v)
+ if val != "" && !slices.Contains(tmpContext[key], val) {
+ tmpContext[key] = append(tmpContext[key], val)
+ }
+ }
+ default:
+ val := fmt.Sprintf("%v", output)
if val != "" && !slices.Contains(tmpContext[key], val) {
tmpContext[key] = append(tmpContext[key], val)
}
}
}
}
+ return errors
+}
- for key, values := range tmpContext {
- if len(values) == 0 {
- continue
- }
+// Iterate over the individual appsec matched rules to create the needed alert context.
+func AppsecEventToContext(event types.AppsecEvent, request *http.Request) (models.Meta, []error) {
+ var errors []error
- valueStr, err := TruncateContext(values, alertContext.ContextValueLen)
- if err != nil {
- log.Warning(err.Error())
- }
+ tmpContext := make(map[string][]string)
- meta := models.MetaItems0{
- Key: key,
- Value: valueStr,
- }
- metas = append(metas, &meta)
+ evt := types.MakeEvent(false, types.LOG, false)
+ for _, matched_rule := range event.MatchedRules {
+ tmpErrors := EvalAlertContextRules(evt, &matched_rule, request, tmpContext)
+ errors = append(errors, tmpErrors...)
}
+ metas, truncErrors := TruncateContextMap(tmpContext, alertContext.ContextValueLen)
+ errors = append(errors, truncErrors...)
+
+ ret := models.Meta(metas)
+
+ return ret, errors
+}
+
+// Iterate over the individual events to create the needed alert context.
+func EventToContext(events []types.Event) (models.Meta, []error) {
+ var errors []error
+
+ tmpContext := make(map[string][]string)
+
+ for _, evt := range events {
+ tmpErrors := EvalAlertContextRules(evt, nil, nil, tmpContext)
+ errors = append(errors, tmpErrors...)
+ }
+
+ metas, truncErrors := TruncateContextMap(tmpContext, alertContext.ContextValueLen)
+ errors = append(errors, truncErrors...)
+
ret := models.Meta(metas)
return ret, errors
diff --git a/pkg/alertcontext/alertcontext_test.go b/pkg/alertcontext/alertcontext_test.go
index c111d1bbcfb..284ff451bc2 100644
--- a/pkg/alertcontext/alertcontext_test.go
+++ b/pkg/alertcontext/alertcontext_test.go
@@ -2,6 +2,7 @@ package alertcontext
import (
"fmt"
+ "net/http"
"testing"
"github.com/stretchr/testify/assert"
@@ -9,6 +10,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types"
+ "github.com/crowdsecurity/go-cs-lib/ptr"
)
func TestNewAlertContext(t *testing.T) {
@@ -200,3 +202,162 @@ func TestEventToContext(t *testing.T) {
assert.ElementsMatch(t, test.expectedResult, metas)
}
}
+
+func TestValidateContextExpr(t *testing.T) {
+ tests := []struct {
+ name string
+ key string
+ exprs []string
+ expectedErr *string
+ }{
+ {
+ name: "basic config",
+ key: "source_ip",
+ exprs: []string{
+ "evt.Parsed.source_ip",
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "basic config with non existent field",
+ key: "source_ip",
+ exprs: []string{
+ "evt.invalid.source_ip",
+ },
+ expectedErr: ptr.Of("compilation of 'evt.invalid.source_ip' failed: type types.Event has no field invalid"),
+ },
+ }
+ for _, test := range tests {
+ fmt.Printf("Running test '%s'\n", test.name)
+ err := ValidateContextExpr(test.key, test.exprs)
+ if test.expectedErr == nil {
+ require.NoError(t, err)
+ } else {
+ require.ErrorContains(t, err, *test.expectedErr)
+ }
+ }
+}
+
+func TestAppsecEventToContext(t *testing.T) {
+ tests := []struct {
+ name string
+ contextToSend map[string][]string
+ match types.AppsecEvent
+ req *http.Request
+ expectedResult models.Meta
+ expectedErrLen int
+ }{
+ {
+ name: "basic test on match",
+ contextToSend: map[string][]string{
+ "id": {"match.id"},
+ },
+ match: types.AppsecEvent{
+ MatchedRules: types.MatchedRules{
+ {
+ "id": "test",
+ },
+ },
+ },
+ req: &http.Request{},
+ expectedResult: []*models.MetaItems0{
+ {
+ Key: "id",
+ Value: "[\"test\"]",
+ },
+ },
+ expectedErrLen: 0,
+ },
+ {
+ name: "basic test on req",
+ contextToSend: map[string][]string{
+ "ua": {"req.UserAgent()"},
+ },
+ match: types.AppsecEvent{
+ MatchedRules: types.MatchedRules{
+ {
+ "id": "test",
+ },
+ },
+ },
+ req: &http.Request{
+ Header: map[string][]string{
+ "User-Agent": {"test"},
+ },
+ },
+ expectedResult: []*models.MetaItems0{
+ {
+ Key: "ua",
+ Value: "[\"test\"]",
+ },
+ },
+ expectedErrLen: 0,
+ },
+ {
+ name: "test on req -> []string",
+ contextToSend: map[string][]string{
+ "foobarxx": {"req.Header.Values('Foobar')"},
+ },
+ match: types.AppsecEvent{
+ MatchedRules: types.MatchedRules{
+ {
+ "id": "test",
+ },
+ },
+ },
+ req: &http.Request{
+ Header: map[string][]string{
+ "User-Agent": {"test"},
+ "Foobar": {"test1", "test2"},
+ },
+ },
+ expectedResult: []*models.MetaItems0{
+ {
+ Key: "foobarxx",
+ Value: "[\"test1\",\"test2\"]",
+ },
+ },
+ expectedErrLen: 0,
+ },
+ {
+ name: "test on type int",
+ contextToSend: map[string][]string{
+ "foobarxx": {"len(req.Header.Values('Foobar'))"},
+ },
+ match: types.AppsecEvent{
+ MatchedRules: types.MatchedRules{
+ {
+ "id": "test",
+ },
+ },
+ },
+ req: &http.Request{
+ Header: map[string][]string{
+ "User-Agent": {"test"},
+ "Foobar": {"test1", "test2"},
+ },
+ },
+ expectedResult: []*models.MetaItems0{
+ {
+ Key: "foobarxx",
+ Value: "[\"2\"]",
+ },
+ },
+ expectedErrLen: 0,
+ },
+ }
+
+ for _, test := range tests {
+ //reset cache
+ alertContext = Context{}
+ //compile
+ if err := NewAlertContext(test.contextToSend, 100); err != nil {
+ t.Fatalf("failed to compile %s: %s", test.name, err)
+ }
+ //run
+
+ metas, errors := AppsecEventToContext(test.match, test.req)
+ assert.Len(t, errors, test.expectedErrLen)
+ assert.ElementsMatch(t, test.expectedResult, metas)
+ }
+}
diff --git a/pkg/apiclient/allowlists_service.go b/pkg/apiclient/allowlists_service.go
new file mode 100644
index 00000000000..382f8fcfe5b
--- /dev/null
+++ b/pkg/apiclient/allowlists_service.go
@@ -0,0 +1,91 @@
+package apiclient
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+
+ "github.com/crowdsecurity/crowdsec/pkg/models"
+ qs "github.com/google/go-querystring/query"
+ log "github.com/sirupsen/logrus"
+)
+
// AllowlistsService handles communication with the allowlist-related
// endpoints of the CrowdSec LAPI.
type AllowlistsService service

// AllowlistListOpts holds the query parameters accepted when listing allowlists.
type AllowlistListOpts struct {
	// WithContent requests that the items of each allowlist be included in the response.
	WithContent bool `url:"with_content,omitempty"`
}
+
+func (s *AllowlistsService) List(ctx context.Context, opts AllowlistListOpts) (*models.GetAllowlistsResponse, *Response, error) {
+ u := s.client.URLPrefix + "/allowlists"
+
+ params, err := qs.Values(opts)
+ if err != nil {
+ return nil, nil, fmt.Errorf("building query: %w", err)
+ }
+
+ u += "?" + params.Encode()
+
+ req, err := s.client.NewRequest(http.MethodGet, u, nil)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ allowlists := &models.GetAllowlistsResponse{}
+
+ resp, err := s.client.Do(ctx, req, allowlists)
+ if err != nil {
+ return nil, resp, err
+ }
+
+ return allowlists, resp, nil
+}
+
// AllowlistGetOpts holds the query parameters accepted when fetching a single allowlist.
type AllowlistGetOpts struct {
	// WithContent requests that the allowlist items be included in the response.
	WithContent bool `url:"with_content,omitempty"`
}
+
+func (s *AllowlistsService) Get(ctx context.Context, name string, opts AllowlistGetOpts) (*models.GetAllowlistResponse, *Response, error) {
+ u := s.client.URLPrefix + "/allowlists/" + name
+
+ params, err := qs.Values(opts)
+ if err != nil {
+ return nil, nil, fmt.Errorf("building query: %w", err)
+ }
+
+ u += "?" + params.Encode()
+
+ log.Debugf("GET %s", u)
+
+ req, err := s.client.NewRequest(http.MethodGet, u, nil)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ allowlist := &models.GetAllowlistResponse{}
+
+ resp, err := s.client.Do(ctx, req, allowlist)
+ if err != nil {
+ return nil, resp, err
+ }
+
+ return allowlist, resp, nil
+}
+
+func (s *AllowlistsService) CheckIfAllowlisted(ctx context.Context, value string) (bool, *Response, error) {
+ u := s.client.URLPrefix + "/allowlists/check/" + value
+
+ req, err := s.client.NewRequest(http.MethodHead, u, nil)
+ if err != nil {
+ return false, nil, err
+ }
+
+ var discardBody interface{}
+
+ resp, err := s.client.Do(ctx, req, discardBody)
+ if err != nil {
+ return false, resp, err
+ }
+
+ return resp.Response.StatusCode == http.StatusOK, resp, nil
+}
diff --git a/pkg/apiclient/auth_jwt.go b/pkg/apiclient/auth_jwt.go
index 193486ff065..c43e9fc291c 100644
--- a/pkg/apiclient/auth_jwt.go
+++ b/pkg/apiclient/auth_jwt.go
@@ -62,7 +62,6 @@ func (t *JWTTransport) refreshJwtToken() error {
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false)
err = enc.Encode(auth)
-
if err != nil {
return fmt.Errorf("could not encode jwt auth body: %w", err)
}
@@ -169,7 +168,6 @@ func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error)
// RoundTrip implements the RoundTripper interface.
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
-
var resp *http.Response
attemptsCount := make(map[int]int)
@@ -229,7 +227,6 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
}
}
return resp, nil
-
}
func (t *JWTTransport) Client() *http.Client {
diff --git a/pkg/apiclient/client.go b/pkg/apiclient/client.go
index 47d97a28344..b6632594a5b 100644
--- a/pkg/apiclient/client.go
+++ b/pkg/apiclient/client.go
@@ -4,12 +4,15 @@ import (
"context"
"crypto/tls"
"crypto/x509"
+ "errors"
"fmt"
"net"
"net/http"
"net/url"
"strings"
+ "time"
+ "github.com/go-openapi/strfmt"
"github.com/golang-jwt/jwt/v4"
"github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent"
@@ -20,6 +23,7 @@ var (
InsecureSkipVerify = false
Cert *tls.Certificate
CaCertPool *x509.CertPool
+ lapiClient *ApiClient
)
type ApiClient struct {
@@ -36,6 +40,7 @@ type ApiClient struct {
Decisions *DecisionsService
DecisionDelete *DecisionDeleteService
Alerts *AlertsService
+ Allowlists *AllowlistsService
Auth *AuthService
Metrics *MetricsService
Signal *SignalService
@@ -66,6 +71,68 @@ type service struct {
client *ApiClient
}
+func InitLAPIClient(ctx context.Context, apiUrl string, papiUrl string, login string, password string, scenarios []string) error {
+
+ if lapiClient != nil {
+ return errors.New("client already initialized")
+ }
+
+ apiURL, err := url.Parse(apiUrl)
+ if err != nil {
+ return fmt.Errorf("parsing api url ('%s'): %w", apiURL, err)
+ }
+
+ papiURL, err := url.Parse(papiUrl)
+ if err != nil {
+ return fmt.Errorf("parsing polling api url ('%s'): %w", papiURL, err)
+ }
+
+ pwd := strfmt.Password(password)
+
+ client, err := NewClient(&Config{
+ MachineID: login,
+ Password: pwd,
+ Scenarios: scenarios,
+ URL: apiURL,
+ PapiURL: papiURL,
+ VersionPrefix: "v1",
+ UpdateScenario: func(_ context.Context) ([]string, error) {
+ return scenarios, nil
+ },
+ })
+ if err != nil {
+ return fmt.Errorf("new client api: %w", err)
+ }
+
+ authResp, _, err := client.Auth.AuthenticateWatcher(ctx, models.WatcherAuthRequest{
+ MachineID: &login,
+ Password: &pwd,
+ Scenarios: scenarios,
+ })
+ if err != nil {
+ return fmt.Errorf("authenticate watcher (%s): %w", login, err)
+ }
+
+ var expiration time.Time
+ if err := expiration.UnmarshalText([]byte(authResp.Expire)); err != nil {
+ return fmt.Errorf("unable to parse jwt expiration: %w", err)
+ }
+
+ client.GetClient().Transport.(*JWTTransport).Token = authResp.Token
+ client.GetClient().Transport.(*JWTTransport).Expiration = expiration
+
+ lapiClient = client
+
+ return nil
+}
+
+func GetLAPIClient() (*ApiClient, error) {
+ if lapiClient == nil {
+ return nil, errors.New("client not initialized")
+ }
+ return lapiClient, nil
+}
+
func NewClient(config *Config) (*ApiClient, error) {
userAgent := config.UserAgent
if userAgent == "" {
@@ -115,6 +182,7 @@ func NewClient(config *Config) (*ApiClient, error) {
c.common.client = c
c.Decisions = (*DecisionsService)(&c.common)
c.Alerts = (*AlertsService)(&c.common)
+ c.Allowlists = (*AllowlistsService)(&c.common)
c.Auth = (*AuthService)(&c.common)
c.Metrics = (*MetricsService)(&c.common)
c.Signal = (*SignalService)(&c.common)
@@ -157,6 +225,7 @@ func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *htt
c.common.client = c
c.Decisions = (*DecisionsService)(&c.common)
c.Alerts = (*AlertsService)(&c.common)
+ c.Allowlists = (*AllowlistsService)(&c.common)
c.Auth = (*AuthService)(&c.common)
c.Metrics = (*MetricsService)(&c.common)
c.Signal = (*SignalService)(&c.common)
diff --git a/pkg/apiclient/decisions_service.go b/pkg/apiclient/decisions_service.go
index 98f26cad9ae..fea2f39072d 100644
--- a/pkg/apiclient/decisions_service.go
+++ b/pkg/apiclient/decisions_service.go
@@ -31,6 +31,8 @@ type DecisionsListOpts struct {
type DecisionsStreamOpts struct {
Startup bool `url:"startup,omitempty"`
+ CommunityPull bool `url:"community_pull"`
+ AdditionalPull bool `url:"additional_pull"`
Scopes string `url:"scopes,omitempty"`
ScenariosContaining string `url:"scenarios_containing,omitempty"`
ScenariosNotContaining string `url:"scenarios_not_containing,omitempty"`
@@ -43,6 +45,17 @@ func (o *DecisionsStreamOpts) addQueryParamsToURL(url string) (string, error) {
return "", err
}
+ //Those 2 are a bit different
+ //They default to true, and we only want to include them if they are false
+
+ if params.Get("community_pull") == "true" {
+ params.Del("community_pull")
+ }
+
+ if params.Get("additional_pull") == "true" {
+ params.Del("additional_pull")
+ }
+
return fmt.Sprintf("%s?%s", url, params.Encode()), nil
}
diff --git a/pkg/apiclient/decisions_service_test.go b/pkg/apiclient/decisions_service_test.go
index 54c44f43eda..942d14689ff 100644
--- a/pkg/apiclient/decisions_service_test.go
+++ b/pkg/apiclient/decisions_service_test.go
@@ -4,6 +4,7 @@ import (
"context"
"net/http"
"net/url"
+ "strings"
"testing"
log "github.com/sirupsen/logrus"
@@ -87,7 +88,7 @@ func TestDecisionsStream(t *testing.T) {
testMethod(t, r, http.MethodGet)
if r.Method == http.MethodGet {
- if r.URL.RawQuery == "startup=true" {
+ if strings.Contains(r.URL.RawQuery, "startup=true") {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"deleted":null,"new":[{"duration":"3h59m55.756182786s","id":4,"origin":"cscli","scenario":"manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'","scope":"Ip","type":"ban","value":"1.2.3.4"}]}`))
} else {
@@ -160,7 +161,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) {
testMethod(t, r, http.MethodGet)
if r.Method == http.MethodGet {
- if r.URL.RawQuery == "startup=true" {
+ if strings.Contains(r.URL.RawQuery, "startup=true") {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"deleted":[{"scope":"ip","decisions":["1.2.3.5"]}],"new":[{"scope":"ip", "scenario": "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'", "decisions":[{"duration":"3h59m55.756182786s","value":"1.2.3.4"}]}]}`))
} else {
@@ -429,6 +430,8 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
Scopes string
ScenariosContaining string
ScenariosNotContaining string
+ CommunityPull bool
+ AdditionalPull bool
}
tests := []struct {
@@ -440,11 +443,17 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
{
name: "no filter",
expected: baseURLString + "?",
+ fields: fields{
+ CommunityPull: true,
+ AdditionalPull: true,
+ },
},
{
name: "startup=true",
fields: fields{
- Startup: true,
+ Startup: true,
+ CommunityPull: true,
+ AdditionalPull: true,
},
expected: baseURLString + "?startup=true",
},
@@ -455,9 +464,19 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
Scopes: "ip,range",
ScenariosContaining: "ssh",
ScenariosNotContaining: "bf",
+ CommunityPull: true,
+ AdditionalPull: true,
},
expected: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true",
},
+ {
+ name: "pull options",
+ fields: fields{
+ CommunityPull: false,
+ AdditionalPull: false,
+ },
+ expected: baseURLString + "?additional_pull=false&community_pull=false",
+ },
}
for _, tt := range tests {
@@ -467,6 +486,8 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
Scopes: tt.fields.Scopes,
ScenariosContaining: tt.fields.ScenariosContaining,
ScenariosNotContaining: tt.fields.ScenariosNotContaining,
+ CommunityPull: tt.fields.CommunityPull,
+ AdditionalPull: tt.fields.AdditionalPull,
}
got, err := o.addQueryParamsToURL(baseURLString)
diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go
index 4cc215c344f..9b79b0dd311 100644
--- a/pkg/apiserver/alerts_test.go
+++ b/pkg/apiserver/alerts_test.go
@@ -16,27 +16,35 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/csplugin"
+ "github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/pkg/models"
)
+const (
+ passwordAuthType = "password"
+ apiKeyAuthType = "apikey"
+)
+
type LAPI struct {
router *gin.Engine
loginResp models.WatcherAuthResponse
bouncerKey string
DBConfig *csconfig.DatabaseCfg
+ DBClient *database.Client
}
func SetupLAPITest(t *testing.T, ctx context.Context) LAPI {
t.Helper()
router, loginResp, config := InitMachineTest(t, ctx)
- APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig)
+ APIKey, dbClient := CreateTestBouncer(t, ctx, config.API.Server.DbConfig)
return LAPI{
router: router,
loginResp: loginResp,
bouncerKey: APIKey,
DBConfig: config.API.Server.DbConfig,
+ DBClient: dbClient,
}
}
@@ -51,14 +59,17 @@ func (l *LAPI) RecordResponse(t *testing.T, ctx context.Context, verb string, ur
require.NoError(t, err)
switch authType {
- case "apikey":
+ case apiKeyAuthType:
req.Header.Add("X-Api-Key", l.bouncerKey)
- case "password":
+ case passwordAuthType:
AddAuthHeaders(req, l.loginResp)
default:
t.Fatal("auth type not supported")
}
+ // Port is required for gin to properly parse the client IP
+ req.RemoteAddr = "127.0.0.1:1234"
+
l.router.ServeHTTP(w, req)
return w
@@ -118,14 +129,14 @@ func TestCreateAlert(t *testing.T) {
w := lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password")
assert.Equal(t, 400, w.Code)
- assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String())
+ assert.JSONEq(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String())
// Create Alert with invalid input
alertContent := GetAlertReaderFromFile(t, "./tests/invalidAlert_sample.json")
w = lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", alertContent, "password")
assert.Equal(t, 500, w.Code)
- assert.Equal(t,
+ assert.JSONEq(t,
`{"message":"validation failure list:\n0.scenario in body is required\n0.scenario_hash in body is required\n0.scenario_version in body is required\n0.simulated in body is required\n0.source in body is required"}`,
w.Body.String())
@@ -173,7 +184,7 @@ func TestAlertListFilters(t *testing.T) {
w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?test=test", alertContent, "password")
assert.Equal(t, 500, w.Code)
- assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String())
+ assert.JSONEq(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String())
// get without filters
@@ -239,7 +250,7 @@ func TestAlertListFilters(t *testing.T) {
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password")
assert.Equal(t, 500, w.Code)
- assert.Equal(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String())
+ assert.JSONEq(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String())
// test range (ok)
@@ -258,7 +269,7 @@ func TestAlertListFilters(t *testing.T) {
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=ratata", emptyBody, "password")
assert.Equal(t, 500, w.Code)
- assert.Equal(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String())
+ assert.JSONEq(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String())
// test since (ok)
@@ -329,7 +340,7 @@ func TestAlertListFilters(t *testing.T) {
w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password")
assert.Equal(t, 500, w.Code)
- assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String())
+ assert.JSONEq(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String())
}
func TestAlertBulkInsert(t *testing.T) {
@@ -351,7 +362,7 @@ func TestListAlert(t *testing.T) {
w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?test=test", emptyBody, "password")
assert.Equal(t, 500, w.Code)
- assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String())
+ assert.JSONEq(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String())
// List Alert
@@ -394,7 +405,7 @@ func TestDeleteAlert(t *testing.T) {
req.RemoteAddr = "127.0.0.2:4242"
lapi.router.ServeHTTP(w, req)
assert.Equal(t, 403, w.Code)
- assert.Equal(t, `{"message":"access forbidden from this IP (127.0.0.2)"}`, w.Body.String())
+ assert.JSONEq(t, `{"message":"access forbidden from this IP (127.0.0.2)"}`, w.Body.String())
// Delete Alert
w = httptest.NewRecorder()
@@ -403,7 +414,7 @@ func TestDeleteAlert(t *testing.T) {
req.RemoteAddr = "127.0.0.1:4242"
lapi.router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
- assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
+ assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String())
}
func TestDeleteAlertByID(t *testing.T) {
@@ -418,7 +429,7 @@ func TestDeleteAlertByID(t *testing.T) {
req.RemoteAddr = "127.0.0.2:4242"
lapi.router.ServeHTTP(w, req)
assert.Equal(t, 403, w.Code)
- assert.Equal(t, `{"message":"access forbidden from this IP (127.0.0.2)"}`, w.Body.String())
+ assert.JSONEq(t, `{"message":"access forbidden from this IP (127.0.0.2)"}`, w.Body.String())
// Delete Alert
w = httptest.NewRecorder()
@@ -427,7 +438,7 @@ func TestDeleteAlertByID(t *testing.T) {
req.RemoteAddr = "127.0.0.1:4242"
lapi.router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
- assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
+ assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String())
}
func TestDeleteAlertTrustedIPS(t *testing.T) {
@@ -472,7 +483,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
- assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
+ assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String())
}
lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json")
diff --git a/pkg/apiserver/allowlists_test.go b/pkg/apiserver/allowlists_test.go
new file mode 100644
index 00000000000..3575a21897d
--- /dev/null
+++ b/pkg/apiserver/allowlists_test.go
@@ -0,0 +1,119 @@
+package apiserver
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/crowdsecurity/crowdsec/pkg/models"
+ "github.com/go-openapi/strfmt"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAllowlistList(t *testing.T) {
+ ctx := context.Background()
+ lapi := SetupLAPITest(t, ctx)
+
+ _, err := lapi.DBClient.CreateAllowList(ctx, "test", "test", false)
+
+ require.NoError(t, err)
+
+ w := lapi.RecordResponse(t, ctx, http.MethodGet, "/v1/allowlists", emptyBody, passwordAuthType)
+
+ require.Equal(t, http.StatusOK, w.Code)
+
+ allowlists := models.GetAllowlistsResponse{}
+
+ err = json.Unmarshal(w.Body.Bytes(), &allowlists)
+ require.NoError(t, err)
+
+ require.Len(t, allowlists, 1)
+ require.Equal(t, "test", allowlists[0].Name)
+}
+
+func TestGetAllowlist(t *testing.T) {
+ ctx := context.Background()
+ lapi := SetupLAPITest(t, ctx)
+
+ l, err := lapi.DBClient.CreateAllowList(ctx, "test", "test", false)
+
+ require.NoError(t, err)
+
+ lapi.DBClient.AddToAllowlist(ctx, l, []*models.AllowlistItem{
+ {
+ Value: "1.2.3.4",
+ },
+ {
+ Value: "2.3.4.5",
+ Expiration: strfmt.DateTime(time.Now().Add(-time.Hour)), // expired
+ },
+ })
+
+ w := lapi.RecordResponse(t, ctx, http.MethodGet, "/v1/allowlists/test?with_content=true", emptyBody, passwordAuthType)
+
+ require.Equal(t, http.StatusOK, w.Code)
+
+ allowlist := models.GetAllowlistResponse{}
+
+ err = json.Unmarshal(w.Body.Bytes(), &allowlist)
+ require.NoError(t, err)
+
+ require.Equal(t, "test", allowlist.Name)
+ require.Len(t, allowlist.Items, 1)
+ require.Equal(t, allowlist.Items[0].Value, "1.2.3.4")
+}
+
+func TestCheckInAllowlist(t *testing.T) {
+ ctx := context.Background()
+ lapi := SetupLAPITest(t, ctx)
+
+ l, err := lapi.DBClient.CreateAllowList(ctx, "test", "test", false)
+
+ require.NoError(t, err)
+
+ lapi.DBClient.AddToAllowlist(ctx, l, []*models.AllowlistItem{
+ {
+ Value: "1.2.3.4",
+ },
+ {
+ Value: "2.3.4.5",
+ Expiration: strfmt.DateTime(time.Now().Add(-time.Hour)), // expired
+ },
+ })
+
+ // GET request, should return 200 and status in body
+ w := lapi.RecordResponse(t, ctx, http.MethodGet, "/v1/allowlists/check/1.2.3.4", emptyBody, passwordAuthType)
+
+ require.Equal(t, http.StatusOK, w.Code)
+
+ resp := models.CheckAllowlistResponse{}
+
+ err = json.Unmarshal(w.Body.Bytes(), &resp)
+ require.NoError(t, err)
+
+ require.True(t, resp.Allowlisted)
+
+ // GET request, should return 200 and status in body
+ w = lapi.RecordResponse(t, ctx, http.MethodGet, "/v1/allowlists/check/2.3.4.5", emptyBody, passwordAuthType)
+
+ require.Equal(t, http.StatusOK, w.Code)
+
+ resp = models.CheckAllowlistResponse{}
+
+ err = json.Unmarshal(w.Body.Bytes(), &resp)
+
+ require.NoError(t, err)
+ require.False(t, resp.Allowlisted)
+
+ // HEAD request, should return 200
+ w = lapi.RecordResponse(t, ctx, http.MethodHead, "/v1/allowlists/check/1.2.3.4", emptyBody, passwordAuthType)
+
+ require.Equal(t, http.StatusOK, w.Code)
+
+ // HEAD request, should return 204
+ w = lapi.RecordResponse(t, ctx, http.MethodHead, "/v1/allowlists/check/2.3.4.5", emptyBody, passwordAuthType)
+
+ require.Equal(t, http.StatusNoContent, w.Code)
+}
diff --git a/pkg/apiserver/api_key_test.go b/pkg/apiserver/api_key_test.go
index e6ed68a6e0d..89e37cd3852 100644
--- a/pkg/apiserver/api_key_test.go
+++ b/pkg/apiserver/api_key_test.go
@@ -14,34 +14,80 @@ func TestAPIKey(t *testing.T) {
ctx := context.Background()
router, config := NewAPITest(t, ctx)
- APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig)
+ APIKey, _ := CreateTestBouncer(t, ctx, config.API.Server.DbConfig)
// Login with empty token
w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader(""))
req.Header.Add("User-Agent", UserAgent)
+ req.RemoteAddr = "127.0.0.1:1234"
router.ServeHTTP(w, req)
- assert.Equal(t, 403, w.Code)
- assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String())
+ assert.Equal(t, http.StatusForbidden, w.Code)
+ assert.JSONEq(t, `{"message":"access forbidden"}`, w.Body.String())
// Login with invalid token
w = httptest.NewRecorder()
req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader(""))
req.Header.Add("User-Agent", UserAgent)
req.Header.Add("X-Api-Key", "a1b2c3d4e5f6")
+ req.RemoteAddr = "127.0.0.1:1234"
router.ServeHTTP(w, req)
- assert.Equal(t, 403, w.Code)
- assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String())
+ assert.Equal(t, http.StatusForbidden, w.Code)
+ assert.JSONEq(t, `{"message":"access forbidden"}`, w.Body.String())
// Login with valid token
w = httptest.NewRecorder()
req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader(""))
req.Header.Add("User-Agent", UserAgent)
req.Header.Add("X-Api-Key", APIKey)
+ req.RemoteAddr = "127.0.0.1:1234"
router.ServeHTTP(w, req)
- assert.Equal(t, 200, w.Code)
+ assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "null", w.Body.String())
+
+ // Login with valid token from another IP
+ w = httptest.NewRecorder()
+ req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader(""))
+ req.Header.Add("User-Agent", UserAgent)
+ req.Header.Add("X-Api-Key", APIKey)
+ req.RemoteAddr = "4.3.2.1:1234"
+ router.ServeHTTP(w, req)
+
+ assert.Equal(t, http.StatusOK, w.Code)
+ assert.Equal(t, "null", w.Body.String())
+
+ // Make the requests multiple times to make sure we only create one
+ w = httptest.NewRecorder()
+ req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader(""))
+ req.Header.Add("User-Agent", UserAgent)
+ req.Header.Add("X-Api-Key", APIKey)
+ req.RemoteAddr = "4.3.2.1:1234"
+ router.ServeHTTP(w, req)
+
+ assert.Equal(t, http.StatusOK, w.Code)
+ assert.Equal(t, "null", w.Body.String())
+
+ // Use the original bouncer again
+ w = httptest.NewRecorder()
+ req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader(""))
+ req.Header.Add("User-Agent", UserAgent)
+ req.Header.Add("X-Api-Key", APIKey)
+ req.RemoteAddr = "127.0.0.1:1234"
+ router.ServeHTTP(w, req)
+
+ assert.Equal(t, http.StatusOK, w.Code)
+ assert.Equal(t, "null", w.Body.String())
+
+ // Check if our second bouncer was properly created
+ bouncers := GetBouncers(t, config.API.Server.DbConfig)
+
+ assert.Len(t, bouncers, 2)
+ assert.Equal(t, "test@4.3.2.1", bouncers[1].Name)
+ assert.Equal(t, bouncers[0].APIKey, bouncers[1].APIKey)
+ assert.Equal(t, bouncers[0].AuthType, bouncers[1].AuthType)
+ assert.False(t, bouncers[0].AutoCreated)
+ assert.True(t, bouncers[1].AutoCreated)
}
diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go
index fff0ebcacbf..d420ef5ac3c 100644
--- a/pkg/apiserver/apic.go
+++ b/pkg/apiserver/apic.go
@@ -1,7 +1,9 @@
package apiserver
import (
+ "bufio"
"context"
+ "encoding/json"
"errors"
"fmt"
"math/rand"
@@ -69,6 +71,10 @@ type apic struct {
consoleConfig *csconfig.ConsoleConfig
isPulling chan bool
whitelists *csconfig.CapiWhitelist
+
+ pullBlocklists bool
+ pullCommunity bool
+ shareSignals bool
}
// randomDuration returns a duration value between d-delta and d+delta
@@ -198,6 +204,9 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient
usageMetricsIntervalFirst: randomDuration(usageMetricsInterval, usageMetricsIntervalDelta),
isPulling: make(chan bool, 1),
whitelists: apicWhitelist,
+ pullBlocklists: *config.PullConfig.Blocklists,
+ pullCommunity: *config.PullConfig.Community,
+ shareSignals: *config.Sharing,
}
password := strfmt.Password(config.Credentials.Password)
@@ -295,7 +304,7 @@ func (a *apic) Push(ctx context.Context) error {
var signals []*models.AddSignalsRequestItem
for _, alert := range alerts {
- if ok := shouldShareAlert(alert, a.consoleConfig); ok {
+ if ok := shouldShareAlert(alert, a.consoleConfig, a.shareSignals); ok {
signals = append(signals, alertToSignal(alert, getScenarioTrustOfAlert(alert), *a.consoleConfig.ShareContext))
}
}
@@ -324,7 +333,12 @@ func getScenarioTrustOfAlert(alert *models.Alert) string {
return scenarioTrust
}
-func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig) bool {
+func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig, shareSignals bool) bool {
+ if !shareSignals {
+ log.Debugf("sharing signals is disabled")
+ return false
+ }
+
if *alert.Simulated {
log.Debugf("simulation enabled for alert (id:%d), will not be sent to CAPI", alert.ID)
return false
@@ -625,7 +639,9 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
log.Infof("Starting community-blocklist update")
- data, _, err := a.apiClient.Decisions.GetStreamV3(ctx, apiclient.DecisionsStreamOpts{Startup: a.startup})
+ log.Debugf("Community pull: %t | Blocklist pull: %t", a.pullCommunity, a.pullBlocklists)
+
+ data, _, err := a.apiClient.Decisions.GetStreamV3(ctx, apiclient.DecisionsStreamOpts{Startup: a.startup, CommunityPull: a.pullCommunity, AdditionalPull: a.pullBlocklists})
if err != nil {
return fmt.Errorf("get stream: %w", err)
}
@@ -650,28 +666,40 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted)
- if len(data.New) == 0 {
- log.Infof("capi/community-blocklist : received 0 new entries (expected if you just installed crowdsec)")
- return nil
- }
-
- // create one alert for community blocklist using the first decision
- decisions := a.apiClient.Decisions.GetDecisionsFromGroups(data.New)
- // apply APIC specific whitelists
- decisions = a.ApplyApicWhitelists(decisions)
+ if len(data.New) > 0 {
+ // create one alert for community blocklist using the first decision
+ decisions := a.apiClient.Decisions.GetDecisionsFromGroups(data.New)
+ // apply APIC specific whitelists
+ decisions = a.ApplyApicWhitelists(decisions)
- alert := createAlertForDecision(decisions[0])
- alertsFromCapi := []*models.Alert{alert}
- alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters)
+ alert := createAlertForDecision(decisions[0])
+ alertsFromCapi := []*models.Alert{alert}
+ alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters)
- err = a.SaveAlerts(ctx, alertsFromCapi, addCounters, deleteCounters)
- if err != nil {
- return fmt.Errorf("while saving alerts: %w", err)
+ err = a.SaveAlerts(ctx, alertsFromCapi, addCounters, deleteCounters)
+ if err != nil {
+ return fmt.Errorf("while saving alerts: %w", err)
+ }
+ } else {
+ if a.pullCommunity {
+ log.Info("capi/community-blocklist : received 0 new entries (expected if you just installed crowdsec)")
+ } else {
+ log.Debug("capi/community-blocklist : community blocklist pull is disabled")
+ }
}
- // update blocklists
- if err := a.UpdateBlocklists(ctx, data.Links, addCounters, forcePull); err != nil {
- return fmt.Errorf("while updating blocklists: %w", err)
+ // update allowlists/blocklists
+ if data.Links != nil {
+ if len(data.Links.Blocklists) > 0 {
+ if err := a.UpdateBlocklists(ctx, data.Links.Blocklists, addCounters, forcePull); err != nil {
+ return fmt.Errorf("while updating blocklists: %w", err)
+ }
+ }
+ if len(data.Links.Allowlists) > 0 {
+ if err := a.UpdateAllowlists(ctx, data.Links.Allowlists, forcePull); err != nil {
+ return fmt.Errorf("while updating allowlists: %w", err)
+ }
+ }
}
return nil
@@ -680,15 +708,94 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
// we receive a link to a blocklist, we pull the content of the blocklist and we create one alert
func (a *apic) PullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, forcePull bool) error {
addCounters, _ := makeAddAndDeleteCounters()
- if err := a.UpdateBlocklists(ctx, &modelscapi.GetDecisionsStreamResponseLinks{
- Blocklists: []*modelscapi.BlocklistLink{blocklist},
- }, addCounters, forcePull); err != nil {
+ if err := a.UpdateBlocklists(ctx, []*modelscapi.BlocklistLink{blocklist}, addCounters, forcePull); err != nil {
return fmt.Errorf("while pulling blocklist: %w", err)
}
return nil
}
+func (a *apic) PullAllowlist(ctx context.Context, allowlist *modelscapi.AllowlistLink, forcePull bool) error {
+ if err := a.UpdateAllowlists(ctx, []*modelscapi.AllowlistLink{allowlist}, forcePull); err != nil {
+ return fmt.Errorf("while pulling allowlist: %w", err)
+ }
+
+ return nil
+}
+
+func (a *apic) UpdateAllowlists(ctx context.Context, allowlistsLinks []*modelscapi.AllowlistLink, forcePull bool) error {
+ if len(allowlistsLinks) == 0 {
+ return nil
+ }
+
+ defaultClient, err := apiclient.NewDefaultClient(a.apiClient.BaseURL, "", "", nil)
+ if err != nil {
+ return fmt.Errorf("while creating default client: %w", err)
+ }
+
+ for _, link := range allowlistsLinks {
+ if link.URL == nil {
+ log.Warningf("allowlist has no URL")
+ continue
+ }
+ if link.Name == nil {
+ log.Warningf("allowlist has no name")
+ continue
+ }
+
+ description := ""
+ if link.Description != nil {
+ description = *link.Description
+ }
+
+ resp, err := defaultClient.GetClient().Get(*link.URL)
+ if err != nil {
+ log.Errorf("while pulling allowlist: %s", err)
+ continue
+ }
+ defer resp.Body.Close()
+
+ scanner := bufio.NewScanner(resp.Body)
+ items := make([]*models.AllowlistItem, 0)
+ for scanner.Scan() {
+ item := scanner.Text()
+ j := &models.AllowlistItem{}
+ if err := json.Unmarshal([]byte(item), j); err != nil {
+ log.Errorf("while unmarshalling allowlist item: %s", err)
+ continue
+ }
+ items = append(items, j)
+ }
+
+ list, err := a.dbClient.GetAllowList(ctx, *link.Name, false)
+
+ if err != nil {
+ if !ent.IsNotFound(err) {
+ log.Errorf("while getting allowlist %s: %s", *link.Name, err)
+ continue
+ }
+ }
+
+ if list == nil {
+ list, err = a.dbClient.CreateAllowList(ctx, *link.Name, description, true)
+ if err != nil {
+ log.Errorf("while creating allowlist %s: %s", *link.Name, err)
+ continue
+ }
+ }
+
+ err = a.dbClient.ReplaceAllowlist(ctx, list, items, true)
+ if err != nil {
+ log.Errorf("while replacing allowlist %s: %s", *link.Name, err)
+ continue
+ }
+
+ log.Infof("Allowlist %s updated", *link.Name)
+ }
+
+ return nil
+}
+
// if decisions is whitelisted: return representation of the whitelist ip or cidr
// if not whitelisted: empty string
func (a *apic) whitelistedBy(decision *models.Decision) string {
@@ -862,14 +969,11 @@ func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient,
return nil
}
-func (a *apic) UpdateBlocklists(ctx context.Context, links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error {
- if links == nil {
+func (a *apic) UpdateBlocklists(ctx context.Context, blocklists []*modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error {
+ if len(blocklists) == 0 {
return nil
}
- if links.Blocklists == nil {
- return nil
- }
// we must use a different http client than apiClient's because the transport of apiClient is jwtTransport or here we have signed apis that are incompatibles
// we can use the same baseUrl as the urls are absolute and the parse will take care of it
defaultClient, err := apiclient.NewDefaultClient(a.apiClient.BaseURL, "", "", nil)
@@ -877,7 +981,7 @@ func (a *apic) UpdateBlocklists(ctx context.Context, links *modelscapi.GetDecisi
return fmt.Errorf("while creating default client: %w", err)
}
- for _, blocklist := range links.Blocklists {
+ for _, blocklist := range blocklists {
if err := a.updateBlocklist(ctx, defaultClient, blocklist, addCounters, forcePull); err != nil {
return err
}
diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go
index aa8db3f1c85..fe0dfd55821 100644
--- a/pkg/apiserver/apic_metrics.go
+++ b/pkg/apiserver/apic_metrics.go
@@ -368,10 +368,14 @@ func (a *apic) SendUsageMetrics(ctx context.Context) {
if err != nil {
log.Errorf("unable to send usage metrics: %s", err)
- if resp == nil || resp.Response.StatusCode >= http.StatusBadRequest && resp.Response.StatusCode != http.StatusUnprocessableEntity {
+ if resp == nil || resp.Response == nil {
+ // Most likely a transient network error, it will be retried later
+ continue
+ }
+
+ if resp.Response.StatusCode >= http.StatusBadRequest && resp.Response.StatusCode != http.StatusUnprocessableEntity {
// In case of 422, mark the metrics as sent anyway, the API did not like what we sent,
// and it's unlikely we'll be able to fix it
- // also if resp is nil, we should'nt mark the metrics as sent could be network issue
continue
}
}
diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go
index 99fee6e32bf..a8fbb40c4fa 100644
--- a/pkg/apiserver/apic_test.go
+++ b/pkg/apiserver/apic_test.go
@@ -69,7 +69,10 @@ func getAPIC(t *testing.T, ctx context.Context) *apic {
ShareCustomScenarios: ptr.Of(false),
ShareContext: ptr.Of(false),
},
- isPulling: make(chan bool, 1),
+ isPulling: make(chan bool, 1),
+ shareSignals: true,
+ pullBlocklists: true,
+ pullCommunity: true,
}
}
@@ -200,6 +203,11 @@ func TestNewAPIC(t *testing.T) {
Login: "foo",
Password: "bar",
},
+ Sharing: ptr.Of(true),
+ PullConfig: csconfig.CapiPullConfig{
+ Community: ptr.Of(true),
+ Blocklists: ptr.Of(true),
+ },
}
}
@@ -1193,6 +1201,7 @@ func TestShouldShareAlert(t *testing.T) {
tests := []struct {
name string
consoleConfig *csconfig.ConsoleConfig
+ shareSignals bool
alert *models.Alert
expectedRet bool
expectedTrust string
@@ -1203,6 +1212,7 @@ func TestShouldShareAlert(t *testing.T) {
ShareCustomScenarios: ptr.Of(true),
},
alert: &models.Alert{Simulated: ptr.Of(false)},
+ shareSignals: true,
expectedRet: true,
expectedTrust: "custom",
},
@@ -1212,6 +1222,7 @@ func TestShouldShareAlert(t *testing.T) {
ShareCustomScenarios: ptr.Of(false),
},
alert: &models.Alert{Simulated: ptr.Of(false)},
+ shareSignals: true,
expectedRet: false,
expectedTrust: "custom",
},
@@ -1220,6 +1231,7 @@ func TestShouldShareAlert(t *testing.T) {
consoleConfig: &csconfig.ConsoleConfig{
ShareManualDecisions: ptr.Of(true),
},
+ shareSignals: true,
alert: &models.Alert{
Simulated: ptr.Of(false),
Decisions: []*models.Decision{{Origin: ptr.Of(types.CscliOrigin)}},
@@ -1232,6 +1244,7 @@ func TestShouldShareAlert(t *testing.T) {
consoleConfig: &csconfig.ConsoleConfig{
ShareManualDecisions: ptr.Of(false),
},
+ shareSignals: true,
alert: &models.Alert{
Simulated: ptr.Of(false),
Decisions: []*models.Decision{{Origin: ptr.Of(types.CscliOrigin)}},
@@ -1244,6 +1257,7 @@ func TestShouldShareAlert(t *testing.T) {
consoleConfig: &csconfig.ConsoleConfig{
ShareTaintedScenarios: ptr.Of(true),
},
+ shareSignals: true,
alert: &models.Alert{
Simulated: ptr.Of(false),
ScenarioHash: ptr.Of("whateverHash"),
@@ -1256,6 +1270,7 @@ func TestShouldShareAlert(t *testing.T) {
consoleConfig: &csconfig.ConsoleConfig{
ShareTaintedScenarios: ptr.Of(false),
},
+ shareSignals: true,
alert: &models.Alert{
Simulated: ptr.Of(false),
ScenarioHash: ptr.Of("whateverHash"),
@@ -1263,11 +1278,24 @@ func TestShouldShareAlert(t *testing.T) {
expectedRet: false,
expectedTrust: "tainted",
},
+ {
+ name: "manual alert should not be shared if global sharing is disabled",
+ consoleConfig: &csconfig.ConsoleConfig{
+ ShareManualDecisions: ptr.Of(true),
+ },
+ shareSignals: false,
+ alert: &models.Alert{
+ Simulated: ptr.Of(false),
+ ScenarioHash: ptr.Of("whateverHash"),
+ },
+ expectedRet: false,
+ expectedTrust: "manual",
+ },
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
- ret := shouldShareAlert(tc.alert, tc.consoleConfig)
+ ret := shouldShareAlert(tc.alert, tc.consoleConfig, tc.shareSignals)
assert.Equal(t, tc.expectedRet, ret)
})
}
diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go
index 35f9beaf635..05f9150b037 100644
--- a/pkg/apiserver/apiserver.go
+++ b/pkg/apiserver/apiserver.go
@@ -46,20 +46,11 @@ type APIServer struct {
consoleConfig *csconfig.ConsoleConfig
}
-func recoverFromPanic(c *gin.Context) {
- err := recover()
- if err == nil {
- return
- }
-
- // Check for a broken connection, as it is not really a
- // condition that warrants a panic stack trace.
- brokenPipe := false
-
+func isBrokenConnection(err any) bool {
if ne, ok := err.(*net.OpError); ok {
if se, ok := ne.Err.(*os.SyscallError); ok {
if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
- brokenPipe = true
+ return true
}
}
}
@@ -79,11 +70,22 @@ func recoverFromPanic(c *gin.Context) {
errors.Is(strErr, errClosedBody) ||
errors.Is(strErr, errHandlerComplete) ||
errors.Is(strErr, errStreamClosed) {
- brokenPipe = true
+ return true
}
}
- if brokenPipe {
+ return false
+}
+
+func recoverFromPanic(c *gin.Context) {
+ err := recover()
+ if err == nil {
+ return
+ }
+
+ // Check for a broken connection, as it is not really a
+ // condition that warrants a panic stack trace.
+ if isBrokenConnection(err) {
log.Warningf("client %s disconnected: %s", c.ClientIP(), err)
c.Abort()
} else {
diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go
index cdf99462c35..8b1d0d9b401 100644
--- a/pkg/apiserver/apiserver_test.go
+++ b/pkg/apiserver/apiserver_test.go
@@ -24,6 +24,7 @@ import (
middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/database"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types"
)
@@ -62,6 +63,7 @@ func LoadTestConfig(t *testing.T) csconfig.Config {
}
apiServerConfig := csconfig.LocalApiServerCfg{
ListenURI: "http://127.0.0.1:8080",
+ LogLevel: ptr.Of(log.DebugLevel),
DbConfig: &dbconfig,
ProfilesPath: "./tests/profiles.yaml",
ConsoleConfig: &csconfig.ConsoleConfig{
@@ -206,6 +208,18 @@ func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg)
return ""
}
+func GetBouncers(t *testing.T, config *csconfig.DatabaseCfg) []*ent.Bouncer {
+ ctx := context.Background()
+
+ dbClient, err := database.NewClient(ctx, config)
+ require.NoError(t, err)
+
+ bouncers, err := dbClient.ListBouncers(ctx)
+ require.NoError(t, err)
+
+ return bouncers
+}
+
func GetAlertReaderFromFile(t *testing.T, path string) *strings.Reader {
alertContentBytes, err := os.ReadFile(path)
require.NoError(t, err)
@@ -283,17 +297,17 @@ func CreateTestMachine(t *testing.T, ctx context.Context, router *gin.Engine, to
return body
}
-func CreateTestBouncer(t *testing.T, ctx context.Context, config *csconfig.DatabaseCfg) string {
+func CreateTestBouncer(t *testing.T, ctx context.Context, config *csconfig.DatabaseCfg) (string, *database.Client) {
dbClient, err := database.NewClient(ctx, config)
require.NoError(t, err)
apiKey, err := middlewares.GenerateAPIKey(keyLength)
require.NoError(t, err)
- _, err = dbClient.CreateBouncer(ctx, "test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType)
+ _, err = dbClient.CreateBouncer(ctx, "test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType, false)
require.NoError(t, err)
- return apiKey
+ return apiKey, dbClient
}
func TestWithWrongDBConfig(t *testing.T) {
diff --git a/pkg/apiserver/controllers/controller.go b/pkg/apiserver/controllers/controller.go
index 719bb231006..58f757b6657 100644
--- a/pkg/apiserver/controllers/controller.go
+++ b/pkg/apiserver/controllers/controller.go
@@ -123,6 +123,11 @@ func (c *Controller) NewV1() error {
jwtAuth.DELETE("/decisions", c.HandlerV1.DeleteDecisions)
jwtAuth.DELETE("/decisions/:decision_id", c.HandlerV1.DeleteDecisionById)
jwtAuth.GET("/heartbeat", c.HandlerV1.HeartBeat)
+ jwtAuth.GET("/allowlists", c.HandlerV1.GetAllowlists)
+ jwtAuth.GET("/allowlists/:allowlist_name", c.HandlerV1.GetAllowlist)
+ jwtAuth.GET("/allowlists/check/:ip_or_range", c.HandlerV1.CheckInAllowlist)
+ jwtAuth.HEAD("/allowlists/check/:ip_or_range", c.HandlerV1.CheckInAllowlist)
+
}
apiKeyAuth := groupV1.Group("")
diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go
index d1f93228512..a1582cb05b2 100644
--- a/pkg/apiserver/controllers/v1/alerts.go
+++ b/pkg/apiserver/controllers/v1/alerts.go
@@ -141,6 +141,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
}
stopFlush := false
+ alertsToSave := make([]*models.Alert, 0)
for _, alert := range input {
// normalize scope for alert.Source and decisions
@@ -154,6 +155,19 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
}
}
+ if alert.Source.Scope != nil && (*alert.Source.Scope == types.Ip || *alert.Source.Scope == types.Range) && // Allowlist only works for IP/range
+ alert.Source.Value != nil && // Is this possible ?
+ len(alert.Decisions) == 0 { // If there are no decisions, it means the alert comes from crowdsec (not cscli), so we can apply the allowlist
+ isAllowlisted, err := c.DBClient.IsAllowlisted(ctx, *alert.Source.Value)
+ if err == nil && isAllowlisted {
+ log.Infof("alert source %s is allowlisted, skipping", *alert.Source.Value)
+ continue
+ } else if err != nil {
+ // FIXME: do we still want to process the alert normally if we can't check the allowlist?
+ log.Errorf("error while checking allowlist: %s", err)
+ }
+ }
+
alert.MachineID = machineID
// generate uuid here for alert
alert.UUID = uuid.NewString()
@@ -189,6 +203,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
stopFlush = true
}
+ alertsToSave = append(alertsToSave, alert)
continue
}
@@ -234,13 +249,15 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
break
}
}
+
+ alertsToSave = append(alertsToSave, alert)
}
if stopFlush {
c.DBClient.CanFlush = false
}
- alerts, err := c.DBClient.CreateAlert(ctx, machineID, input)
+ alerts, err := c.DBClient.CreateAlert(ctx, machineID, alertsToSave)
c.DBClient.CanFlush = true
if err != nil {
@@ -250,7 +267,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
if c.AlertsAddChan != nil {
select {
- case c.AlertsAddChan <- input:
+ case c.AlertsAddChan <- alertsToSave:
log.Debug("alert sent to CAPI channel")
default:
log.Warning("Cannot send alert to Central API channel")
diff --git a/pkg/apiserver/controllers/v1/allowlist.go b/pkg/apiserver/controllers/v1/allowlist.go
new file mode 100644
index 00000000000..64a409e4ecd
--- /dev/null
+++ b/pkg/apiserver/controllers/v1/allowlist.go
@@ -0,0 +1,126 @@
+package v1
+
+import (
+ "net/http"
+ "time"
+
+ "github.com/crowdsecurity/crowdsec/pkg/models"
+ "github.com/gin-gonic/gin"
+ "github.com/go-openapi/strfmt"
+)
+
+func (c *Controller) CheckInAllowlist(gctx *gin.Context) {
+ value := gctx.Param("ip_or_range")
+
+ if value == "" {
+ gctx.JSON(http.StatusBadRequest, gin.H{"message": "value is required"})
+ return
+ }
+
+ allowlisted, err := c.DBClient.IsAllowlisted(gctx.Request.Context(), value)
+
+ if err != nil {
+ c.HandleDBErrors(gctx, err)
+ return
+ }
+
+ if gctx.Request.Method == http.MethodHead {
+ if allowlisted {
+ gctx.Status(http.StatusOK)
+ } else {
+ gctx.Status(http.StatusNoContent)
+ }
+ return
+ }
+
+ resp := models.CheckAllowlistResponse{
+ Allowlisted: allowlisted,
+ }
+
+ gctx.JSON(http.StatusOK, resp)
+}
+
+func (c *Controller) GetAllowlists(gctx *gin.Context) {
+ params := gctx.Request.URL.Query()
+
+ withContent := params.Get("with_content") == "true"
+
+ allowlists, err := c.DBClient.ListAllowLists(gctx.Request.Context(), withContent)
+
+ if err != nil {
+ c.HandleDBErrors(gctx, err)
+ return
+ }
+
+ resp := models.GetAllowlistsResponse{}
+
+ for _, allowlist := range allowlists {
+ items := make([]*models.AllowlistItem, 0)
+ if withContent {
+ for _, item := range allowlist.Edges.AllowlistItems {
+ if !item.ExpiresAt.IsZero() && item.ExpiresAt.Before(time.Now()) {
+ continue
+ }
+ items = append(items, &models.AllowlistItem{
+ CreatedAt: strfmt.DateTime(item.CreatedAt),
+ Description: item.Comment,
+ Expiration: strfmt.DateTime(item.ExpiresAt),
+ Value: item.Value,
+ })
+ }
+ }
+ resp = append(resp, &models.GetAllowlistResponse{
+ AllowlistID: allowlist.AllowlistID,
+ Name: allowlist.Name,
+ Description: allowlist.Description,
+ CreatedAt: strfmt.DateTime(allowlist.CreatedAt),
+ UpdatedAt: strfmt.DateTime(allowlist.UpdatedAt),
+ ConsoleManaged: allowlist.FromConsole,
+ Items: items,
+ })
+ }
+
+ gctx.JSON(http.StatusOK, resp)
+}
+
+func (c *Controller) GetAllowlist(gctx *gin.Context) {
+ allowlist := gctx.Param("allowlist_name")
+
+ params := gctx.Request.URL.Query()
+ withContent := params.Get("with_content") == "true"
+
+ allowlistModel, err := c.DBClient.GetAllowList(gctx.Request.Context(), allowlist, withContent)
+
+ if err != nil {
+ c.HandleDBErrors(gctx, err)
+ return
+ }
+
+ items := make([]*models.AllowlistItem, 0)
+
+ if withContent {
+ for _, item := range allowlistModel.Edges.AllowlistItems {
+ if !item.ExpiresAt.IsZero() && item.ExpiresAt.Before(time.Now()) {
+ continue
+ }
+ items = append(items, &models.AllowlistItem{
+ CreatedAt: strfmt.DateTime(item.CreatedAt),
+ Description: item.Comment,
+ Expiration: strfmt.DateTime(item.ExpiresAt),
+ Value: item.Value,
+ })
+ }
+ }
+
+ resp := models.GetAllowlistResponse{
+ AllowlistID: allowlistModel.AllowlistID,
+ Name: allowlistModel.Name,
+ Description: allowlistModel.Description,
+ CreatedAt: strfmt.DateTime(allowlistModel.CreatedAt),
+ UpdatedAt: strfmt.DateTime(allowlistModel.UpdatedAt),
+ ConsoleManaged: allowlistModel.FromConsole,
+ Items: items,
+ }
+
+ gctx.JSON(http.StatusOK, resp)
+}
diff --git a/pkg/apiserver/decisions_test.go b/pkg/apiserver/decisions_test.go
index a0af6956443..cb5d2e1c4f1 100644
--- a/pkg/apiserver/decisions_test.go
+++ b/pkg/apiserver/decisions_test.go
@@ -22,19 +22,19 @@ func TestDeleteDecisionRange(t *testing.T) {
// delete by ip wrong
w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
- assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
+ assert.JSONEq(t, `{"nbDeleted":"0"}`, w.Body.String())
// delete by range
w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
- assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String())
+ assert.JSONEq(t, `{"nbDeleted":"2"}`, w.Body.String())
// delete by range : ensure it was already deleted
w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
- assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
+ assert.JSONEq(t, `{"nbDeleted":"0"}`, w.Body.String())
}
func TestDeleteDecisionFilter(t *testing.T) {
@@ -48,19 +48,19 @@ func TestDeleteDecisionFilter(t *testing.T) {
w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
- assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
+ assert.JSONEq(t, `{"nbDeleted":"0"}`, w.Body.String())
// delete by ip good
w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
- assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
+ assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String())
// delete by scope/value
w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
- assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
+ assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String())
}
func TestDeleteDecisionFilterByScenario(t *testing.T) {
@@ -74,13 +74,13 @@ func TestDeleteDecisionFilterByScenario(t *testing.T) {
w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
- assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
+ assert.JSONEq(t, `{"nbDeleted":"0"}`, w.Body.String())
// delete by scenario good
w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
- assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String())
+ assert.JSONEq(t, `{"nbDeleted":"2"}`, w.Body.String())
}
func TestGetDecisionFilters(t *testing.T) {
diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go
index f6f51763975..72ae0302ae4 100644
--- a/pkg/apiserver/jwt_test.go
+++ b/pkg/apiserver/jwt_test.go
@@ -23,7 +23,7 @@ func TestLogin(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
- assert.Equal(t, `{"code":401,"message":"machine test not validated"}`, w.Body.String())
+ assert.JSONEq(t, `{"code":401,"message":"machine test not validated"}`, w.Body.String())
// Login with machine not exist
w = httptest.NewRecorder()
@@ -32,7 +32,7 @@ func TestLogin(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
- assert.Equal(t, `{"code":401,"message":"ent: machine not found"}`, w.Body.String())
+ assert.JSONEq(t, `{"code":401,"message":"ent: machine not found"}`, w.Body.String())
// Login with invalid body
w = httptest.NewRecorder()
@@ -41,7 +41,7 @@ func TestLogin(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
- assert.Equal(t, `{"code":401,"message":"missing: invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String())
+ assert.JSONEq(t, `{"code":401,"message":"missing: invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String())
// Login with invalid format
w = httptest.NewRecorder()
@@ -50,7 +50,7 @@ func TestLogin(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
- assert.Equal(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String())
+ assert.JSONEq(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String())
// Validate machine
ValidateMachine(t, ctx, "test", config.API.Server.DbConfig)
@@ -62,7 +62,7 @@ func TestLogin(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
- assert.Equal(t, `{"code":401,"message":"incorrect Username or Password"}`, w.Body.String())
+ assert.JSONEq(t, `{"code":401,"message":"incorrect Username or Password"}`, w.Body.String())
// Login with valid machine
w = httptest.NewRecorder()
diff --git a/pkg/apiserver/machines_test.go b/pkg/apiserver/machines_test.go
index 969f75707d6..57b96f54ddd 100644
--- a/pkg/apiserver/machines_test.go
+++ b/pkg/apiserver/machines_test.go
@@ -25,7 +25,7 @@ func TestCreateMachine(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
- assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String())
+ assert.JSONEq(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String())
// Create machine with invalid input
w = httptest.NewRecorder()
@@ -34,7 +34,7 @@ func TestCreateMachine(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnprocessableEntity, w.Code)
- assert.Equal(t, `{"message":"validation failure list:\nmachine_id in body is required\npassword in body is required"}`, w.Body.String())
+ assert.JSONEq(t, `{"message":"validation failure list:\nmachine_id in body is required\npassword in body is required"}`, w.Body.String())
// Create machine
b, err := json.Marshal(MachineTest)
@@ -144,7 +144,7 @@ func TestCreateMachineAlreadyExist(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
- assert.Equal(t, `{"message":"user 'test': user already exist"}`, w.Body.String())
+ assert.JSONEq(t, `{"message":"user 'test': user already exist"}`, w.Body.String())
}
func TestAutoRegistration(t *testing.T) {
diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go
index d438c9b15a4..df2f68930d6 100644
--- a/pkg/apiserver/middlewares/v1/api_key.go
+++ b/pkg/apiserver/middlewares/v1/api_key.go
@@ -89,7 +89,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer {
logger.Infof("Creating bouncer %s", bouncerName)
- bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType)
+ bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType, true)
if err != nil {
logger.Errorf("while creating bouncer db entry: %s", err)
return nil
@@ -114,18 +114,68 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer {
return nil
}
+ clientIP := c.ClientIP()
+
ctx := c.Request.Context()
hashStr := HashSHA512(val[0])
- bouncer, err := a.DbClient.SelectBouncer(ctx, hashStr)
+ // Appsec case, we only care if the key is valid; no content is returned, no last_pull update or anything
+ // NOTE(review): bouncer[0] below assumes SelectBouncers returned at least one row — verify it can't return an empty slice with a nil error
+ if c.Request.Method == http.MethodHead {
+ bouncer, err := a.DbClient.SelectBouncers(ctx, hashStr, types.ApiKeyAuthType)
+ if err != nil {
+ logger.Errorf("while fetching bouncer info: %s", err)
+ return nil
+ }
+ return bouncer[0]
+ }
+
+ // most common case, check if this specific bouncer exists
+ bouncer, err := a.DbClient.SelectBouncerWithIP(ctx, hashStr, clientIP)
+ if err != nil && !ent.IsNotFound(err) {
+ logger.Errorf("while fetching bouncer info: %s", err)
+ return nil
+ }
+
+ // We found the bouncer with key and IP, we can use it
+ if bouncer != nil {
+ if bouncer.AuthType != types.ApiKeyAuthType {
+ logger.Errorf("bouncer isn't allowed to auth by API key")
+ return nil
+ }
+ return bouncer
+ }
+
+ // We didn't find the bouncer with key and IP, let's try to find it with the key only
+ bouncers, err := a.DbClient.SelectBouncers(ctx, hashStr, types.ApiKeyAuthType)
if err != nil {
logger.Errorf("while fetching bouncer info: %s", err)
return nil
}
- if bouncer.AuthType != types.ApiKeyAuthType {
- logger.Errorf("bouncer %s attempted to login using an API key but it is configured to auth with %s", bouncer.Name, bouncer.AuthType)
+ if len(bouncers) == 0 {
+ logger.Debugf("no bouncer found with this key")
+ return nil
+ }
+
+ logger.Debugf("found %d bouncers with this key", len(bouncers))
+
+ // We only have one bouncer with this key and no IP
+ // This is the first request made by this bouncer, keep this one
+ if len(bouncers) == 1 && bouncers[0].IPAddress == "" {
+ return bouncers[0]
+ }
+
+ // Bouncers are ordered by ID, first one *should* be the manually created one
+ // Can probably get a bit weird if the user deletes the manually created one
+ bouncerName := fmt.Sprintf("%s@%s", bouncers[0].Name, clientIP)
+
+ logger.Infof("Creating bouncer %s", bouncerName)
+
+ bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, clientIP, hashStr, types.ApiKeyAuthType, true)
+ if err != nil {
+ logger.Errorf("while creating bouncer db entry: %s", err)
return nil
}
@@ -156,27 +206,20 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
return
}
- logger = logger.WithField("name", bouncer.Name)
-
- if bouncer.IPAddress == "" {
- if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil {
- logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
- c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
- c.Abort()
-
- return
- }
+ // Appsec request, return immediately if we found something
+ if c.Request.Method == http.MethodHead {
+ c.Set(BouncerContextKey, bouncer)
+ return
}
- // Don't update IP on HEAD request, as it's used by the appsec to check the validity of the API key provided
- if bouncer.IPAddress != clientIP && bouncer.IPAddress != "" && c.Request.Method != http.MethodHead {
- log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, clientIP, bouncer.IPAddress)
+ logger = logger.WithField("name", bouncer.Name)
+ // 1st time we see this bouncer, we update its IP
+ if bouncer.IPAddress == "" {
if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil {
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
-
return
}
}
diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go
index 7dd6b346aa9..83ba13843b9 100644
--- a/pkg/apiserver/papi.go
+++ b/pkg/apiserver/papi.go
@@ -205,8 +205,8 @@ func reverse(s []longpollclient.Event) []longpollclient.Event {
return a
}
-func (p *Papi) PullOnce(since time.Time, sync bool) error {
- events, err := p.Client.PullOnce(since)
+func (p *Papi) PullOnce(ctx context.Context, since time.Time, sync bool) error {
+ events, err := p.Client.PullOnce(ctx, since)
if err != nil {
return err
}
@@ -261,7 +261,7 @@ func (p *Papi) Pull(ctx context.Context) error {
p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp)
- for event := range p.Client.Start(lastTimestamp) {
+ for event := range p.Client.Start(ctx, lastTimestamp) {
logger := p.Logger.WithField("request-id", event.RequestId)
// update last timestamp in database
newTime := time.Now().UTC()
diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go
index 78f5dc9b0fe..6cef9cb387a 100644
--- a/pkg/apiserver/papi_cmd.go
+++ b/pkg/apiserver/papi_cmd.go
@@ -6,11 +6,13 @@ import (
"fmt"
"time"
+ "github.com/go-openapi/strfmt"
log "github.com/sirupsen/logrus"
"github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/modelscapi"
"github.com/crowdsecurity/crowdsec/pkg/types"
@@ -21,25 +23,18 @@ type deleteDecisions struct {
Decisions []string `json:"decisions"`
}
-type blocklistLink struct {
- // blocklist name
- Name string `json:"name"`
- // blocklist url
- Url string `json:"url"`
- // blocklist remediation
- Remediation string `json:"remediation"`
- // blocklist scope
- Scope string `json:"scope,omitempty"`
- // blocklist duration
- Duration string `json:"duration,omitempty"`
+type forcePull struct {
+ Blocklist *modelscapi.BlocklistLink `json:"blocklist,omitempty"`
+ Allowlist *modelscapi.AllowlistLink `json:"allowlist,omitempty"`
}
-type forcePull struct {
- Blocklist *blocklistLink `json:"blocklist,omitempty"`
+type blocklistUnsubscribe struct {
+ Name string `json:"name"`
}
-type listUnsubscribe struct {
+type allowlistUnsubscribe struct {
Name string `json:"name"`
+ Id string `json:"id"`
}
func DecisionCmd(message *Message, p *Papi, sync bool) error {
@@ -174,10 +169,10 @@ func AlertCmd(message *Message, p *Papi, sync bool) error {
func ManagementCmd(message *Message, p *Papi, sync bool) error {
ctx := context.TODO()
- if sync {
+ /*if sync {
p.Logger.Infof("Ignoring management command from PAPI in sync mode")
return nil
- }
+ }*/
switch message.Header.OperationCmd {
case "blocklist_unsubscribe":
@@ -186,7 +181,7 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
return err
}
- unsubscribeMsg := listUnsubscribe{}
+ unsubscribeMsg := blocklistUnsubscribe{}
if err := json.Unmarshal(data, &unsubscribeMsg); err != nil {
return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err)
}
@@ -224,27 +219,79 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
ctx := context.TODO()
- if forcePullMsg.Blocklist == nil {
+ if forcePullMsg.Blocklist == nil && forcePullMsg.Allowlist == nil {
p.Logger.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists")
err = p.apic.PullTop(ctx, true)
if err != nil {
return fmt.Errorf("failed to force pull operation: %w", err)
}
- } else {
- p.Logger.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name)
+ } else if forcePullMsg.Blocklist != nil {
+ err = forcePullMsg.Blocklist.Validate(strfmt.Default)
+
+ if err != nil {
+ return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err)
+ }
+
+ p.Logger.Infof("Received blocklist force_pull command from PAPI, pulling blocklist %s", *forcePullMsg.Blocklist.Name)
err = p.apic.PullBlocklist(ctx, &modelscapi.BlocklistLink{
- Name: &forcePullMsg.Blocklist.Name,
- URL: &forcePullMsg.Blocklist.Url,
- Remediation: &forcePullMsg.Blocklist.Remediation,
- Scope: &forcePullMsg.Blocklist.Scope,
- Duration: &forcePullMsg.Blocklist.Duration,
+ Name: forcePullMsg.Blocklist.Name,
+ URL: forcePullMsg.Blocklist.URL,
+ Remediation: forcePullMsg.Blocklist.Remediation,
+ Scope: forcePullMsg.Blocklist.Scope,
+ Duration: forcePullMsg.Blocklist.Duration,
+ }, true)
+ if err != nil {
+ return fmt.Errorf("failed to force pull operation: %w", err)
+ }
+ } else if forcePullMsg.Allowlist != nil {
+ err = forcePullMsg.Allowlist.Validate(strfmt.Default)
+
+ if err != nil {
+ return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err)
+ }
+
+ p.Logger.Infof("Received allowlist force_pull command from PAPI, pulling allowlist %s", *forcePullMsg.Allowlist.Name)
+
+ err = p.apic.PullAllowlist(ctx, &modelscapi.AllowlistLink{
+ Name: forcePullMsg.Allowlist.Name,
+ URL: forcePullMsg.Allowlist.URL,
+ ID: forcePullMsg.Allowlist.ID,
+ CreatedAt: forcePullMsg.Allowlist.CreatedAt,
+ UpdatedAt: forcePullMsg.Allowlist.UpdatedAt,
}, true)
if err != nil {
return fmt.Errorf("failed to force pull operation: %w", err)
}
}
+ case "allowlist_unsubscribe":
+ data, err := json.Marshal(message.Data)
+ if err != nil {
+ return err
+ }
+
+ unsubscribeMsg := allowlistUnsubscribe{}
+
+ if err := json.Unmarshal(data, &unsubscribeMsg); err != nil {
+ return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err)
+ }
+
+ if unsubscribeMsg.Name == "" {
+ return fmt.Errorf("message for '%s' contains bad data format: missing allowlist name", message.Header.OperationType)
+ }
+
+ p.Logger.Infof("Received allowlist_unsubscribe command from PAPI, unsubscribing from allowlist %s", unsubscribeMsg.Name)
+
+ err = p.DBClient.DeleteAllowList(ctx, unsubscribeMsg.Name, true)
+
+ if err != nil {
+ if !ent.IsNotFound(err) {
+ return err
+ }
+ p.Logger.Warningf("Allowlist %s not found", unsubscribeMsg.Name)
+ }
+ return nil
default:
return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType)
}
diff --git a/pkg/appsec/appsec.go b/pkg/appsec/appsec.go
index 30784b23db0..5f01f76d993 100644
--- a/pkg/appsec/appsec.go
+++ b/pkg/appsec/appsec.go
@@ -1,7 +1,6 @@
package appsec
import (
- "errors"
"fmt"
"net/http"
"os"
@@ -150,6 +149,17 @@ func (w *AppsecRuntimeConfig) ClearResponse() {
w.Response.SendAlert = true
}
+func (wc *AppsecConfig) SetUpLogger() {
+ if wc.LogLevel == nil {
+ lvl := wc.Logger.Logger.GetLevel()
+ wc.LogLevel = &lvl
+ }
+
+ /* wc.Name is actually the datasource name. */
+ wc.Logger = wc.Logger.Dup().WithField("name", wc.Name)
+ wc.Logger.Logger.SetLevel(*wc.LogLevel)
+}
+
func (wc *AppsecConfig) LoadByPath(file string) error {
wc.Logger.Debugf("loading config %s", file)
@@ -157,20 +167,65 @@ func (wc *AppsecConfig) LoadByPath(file string) error {
if err != nil {
return fmt.Errorf("unable to read file %s : %s", file, err)
}
- err = yaml.UnmarshalStrict(yamlFile, wc)
+
+	//as LoadByPath can be called several times, we append rules/hooks, but override other options
+ var tmp AppsecConfig
+
+ err = yaml.UnmarshalStrict(yamlFile, &tmp)
if err != nil {
return fmt.Errorf("unable to parse yaml file %s : %s", file, err)
}
- if wc.Name == "" {
- return errors.New("name cannot be empty")
+ if wc.Name == "" && tmp.Name != "" {
+ wc.Name = tmp.Name
}
- if wc.LogLevel == nil {
- lvl := wc.Logger.Logger.GetLevel()
- wc.LogLevel = &lvl
+
+ //We can append rules/hooks
+ if tmp.OutOfBandRules != nil {
+ wc.OutOfBandRules = append(wc.OutOfBandRules, tmp.OutOfBandRules...)
}
- wc.Logger = wc.Logger.Dup().WithField("name", wc.Name)
- wc.Logger.Logger.SetLevel(*wc.LogLevel)
+ if tmp.InBandRules != nil {
+ wc.InBandRules = append(wc.InBandRules, tmp.InBandRules...)
+ }
+ if tmp.OnLoad != nil {
+ wc.OnLoad = append(wc.OnLoad, tmp.OnLoad...)
+ }
+ if tmp.PreEval != nil {
+ wc.PreEval = append(wc.PreEval, tmp.PreEval...)
+ }
+ if tmp.PostEval != nil {
+ wc.PostEval = append(wc.PostEval, tmp.PostEval...)
+ }
+ if tmp.OnMatch != nil {
+ wc.OnMatch = append(wc.OnMatch, tmp.OnMatch...)
+ }
+ if tmp.VariablesTracking != nil {
+ wc.VariablesTracking = append(wc.VariablesTracking, tmp.VariablesTracking...)
+ }
+
+ //override other options
+ wc.LogLevel = tmp.LogLevel
+
+ wc.DefaultRemediation = tmp.DefaultRemediation
+ wc.DefaultPassAction = tmp.DefaultPassAction
+ wc.BouncerBlockedHTTPCode = tmp.BouncerBlockedHTTPCode
+ wc.BouncerPassedHTTPCode = tmp.BouncerPassedHTTPCode
+ wc.UserBlockedHTTPCode = tmp.UserBlockedHTTPCode
+ wc.UserPassedHTTPCode = tmp.UserPassedHTTPCode
+
+ if tmp.InbandOptions.DisableBodyInspection {
+ wc.InbandOptions.DisableBodyInspection = true
+ }
+ if tmp.InbandOptions.RequestBodyInMemoryLimit != nil {
+ wc.InbandOptions.RequestBodyInMemoryLimit = tmp.InbandOptions.RequestBodyInMemoryLimit
+ }
+ if tmp.OutOfBandOptions.DisableBodyInspection {
+ wc.OutOfBandOptions.DisableBodyInspection = true
+ }
+ if tmp.OutOfBandOptions.RequestBodyInMemoryLimit != nil {
+ wc.OutOfBandOptions.RequestBodyInMemoryLimit = tmp.OutOfBandOptions.RequestBodyInMemoryLimit
+ }
+
return nil
}
diff --git a/pkg/appsec/appsec_rule/appsec_rule.go b/pkg/appsec/appsec_rule/appsec_rule.go
index 136d8b11cb7..9d47c0eed5c 100644
--- a/pkg/appsec/appsec_rule/appsec_rule.go
+++ b/pkg/appsec/appsec_rule/appsec_rule.go
@@ -47,7 +47,6 @@ type CustomRule struct {
}
func (v *CustomRule) Convert(ruleType string, appsecRuleName string) (string, []uint32, error) {
-
if v.Zones == nil && v.And == nil && v.Or == nil {
return "", nil, errors.New("no zones defined")
}
diff --git a/pkg/appsec/appsec_rule/modsec_rule_test.go b/pkg/appsec/appsec_rule/modsec_rule_test.go
index ffb8a15ff1f..74e9b85426e 100644
--- a/pkg/appsec/appsec_rule/modsec_rule_test.go
+++ b/pkg/appsec/appsec_rule/modsec_rule_test.go
@@ -88,7 +88,6 @@ func TestVPatchRuleString(t *testing.T) {
rule: CustomRule{
And: []CustomRule{
{
-
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: Match{Type: "regex", Value: "[^a-zA-Z]"},
@@ -161,7 +160,6 @@ SecRule ARGS_GET:foo "@rx [^a-zA-Z]" "id:1519945803,phase:2,deny,log,msg:'OR AND
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual, _, err := tt.rule.Convert(ModsecurityRuleType, tt.name)
-
if err != nil {
t.Errorf("Error converting rule: %s", err)
}
diff --git a/pkg/appsec/coraza_logger.go b/pkg/appsec/coraza_logger.go
index d2c1612cbd7..93e31be5876 100644
--- a/pkg/appsec/coraza_logger.go
+++ b/pkg/appsec/coraza_logger.go
@@ -124,7 +124,7 @@ func (e *crzLogEvent) Stringer(key string, val fmt.Stringer) dbg.Event {
return e
}
-func (e crzLogEvent) IsEnabled() bool {
+func (e *crzLogEvent) IsEnabled() bool {
return !e.muted
}
diff --git a/pkg/appsec/request_test.go b/pkg/appsec/request_test.go
index f8333e4e5f9..8b457e24dab 100644
--- a/pkg/appsec/request_test.go
+++ b/pkg/appsec/request_test.go
@@ -3,7 +3,6 @@ package appsec
import "testing"
func TestBodyDumper(t *testing.T) {
-
tests := []struct {
name string
req *ParsedRequest
@@ -159,7 +158,6 @@ func TestBodyDumper(t *testing.T) {
}
for idx, test := range tests {
-
t.Run(test.name, func(t *testing.T) {
orig_dr := test.req.DumpRequest()
result := test.filter(orig_dr).GetFilteredRequest()
@@ -177,5 +175,4 @@ func TestBodyDumper(t *testing.T) {
}
})
}
-
}
diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go
index a4e0bd0127a..4da9fd5bf7b 100644
--- a/pkg/cache/cache_test.go
+++ b/pkg/cache/cache_test.go
@@ -5,26 +5,27 @@ import (
"time"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestCreateSetGet(t *testing.T) {
err := CacheInit(CacheCfg{Name: "test", Size: 100, TTL: 1 * time.Second})
- assert.Empty(t, err)
+ require.NoError(t, err)
//set & get
err = SetKey("test", "testkey0", "testvalue1", nil)
- assert.Empty(t, err)
+ require.NoError(t, err)
ret, err := GetKey("test", "testkey0")
assert.Equal(t, "testvalue1", ret)
- assert.Empty(t, err)
+ require.NoError(t, err)
//re-set
err = SetKey("test", "testkey0", "testvalue2", nil)
- assert.Empty(t, err)
+ require.NoError(t, err)
assert.Equal(t, "testvalue1", ret)
- assert.Empty(t, err)
+ require.NoError(t, err)
//expire
time.Sleep(1500 * time.Millisecond)
ret, err = GetKey("test", "testkey0")
assert.Equal(t, "", ret)
- assert.Empty(t, err)
+ require.NoError(t, err)
}
diff --git a/pkg/csconfig/api.go b/pkg/csconfig/api.go
index 3014b729a9e..5f2f8f9248b 100644
--- a/pkg/csconfig/api.go
+++ b/pkg/csconfig/api.go
@@ -38,10 +38,17 @@ type ApiCredentialsCfg struct {
CertPath string `yaml:"cert_path,omitempty"`
}
-/*global api config (for lapi->oapi)*/
+type CapiPullConfig struct {
+ Community *bool `yaml:"community,omitempty"`
+ Blocklists *bool `yaml:"blocklists,omitempty"`
+}
+
+/*global api config (for lapi->capi)*/
type OnlineApiClientCfg struct {
CredentialsFilePath string `yaml:"credentials_path,omitempty"` // credz will be edited by software, store in diff file
Credentials *ApiCredentialsCfg `yaml:"-"`
+ PullConfig CapiPullConfig `yaml:"pull,omitempty"`
+ Sharing *bool `yaml:"sharing,omitempty"`
}
/*local api config (for crowdsec/cscli->lapi)*/
@@ -344,6 +351,21 @@ func (c *Config) LoadAPIServer(inCli bool) error {
log.Printf("push and pull to Central API disabled")
}
+ //Set default values for CAPI push/pull
+ if c.API.Server.OnlineClient != nil {
+ if c.API.Server.OnlineClient.PullConfig.Community == nil {
+ c.API.Server.OnlineClient.PullConfig.Community = ptr.Of(true)
+ }
+
+ if c.API.Server.OnlineClient.PullConfig.Blocklists == nil {
+ c.API.Server.OnlineClient.PullConfig.Blocklists = ptr.Of(true)
+ }
+
+ if c.API.Server.OnlineClient.Sharing == nil {
+ c.API.Server.OnlineClient.Sharing = ptr.Of(true)
+ }
+ }
+
if err := c.LoadDBConfig(inCli); err != nil {
return err
}
diff --git a/pkg/csconfig/api_test.go b/pkg/csconfig/api_test.go
index dff3c3afc8c..17802ba31dd 100644
--- a/pkg/csconfig/api_test.go
+++ b/pkg/csconfig/api_test.go
@@ -212,6 +212,11 @@ func TestLoadAPIServer(t *testing.T) {
Login: "test",
Password: "testpassword",
},
+ Sharing: ptr.Of(true),
+ PullConfig: CapiPullConfig{
+ Community: ptr.Of(true),
+ Blocklists: ptr.Of(true),
+ },
},
Profiles: tmpLAPI.Profiles,
ProfilesPath: "./testdata/profiles.yaml",
diff --git a/pkg/csplugin/broker.go b/pkg/csplugin/broker.go
index e996fa9b68c..f53c831e186 100644
--- a/pkg/csplugin/broker.go
+++ b/pkg/csplugin/broker.go
@@ -91,7 +91,6 @@ func (pb *PluginBroker) Init(ctx context.Context, pluginCfg *csconfig.PluginCfg,
pb.watcher = PluginWatcher{}
pb.watcher.Init(pb.pluginConfigByName, pb.alertsByPluginName)
return nil
-
}
func (pb *PluginBroker) Kill() {
@@ -166,6 +165,7 @@ func (pb *PluginBroker) addProfileAlert(profileAlert ProfileAlert) {
pb.watcher.Inserts <- pluginName
}
}
+
func (pb *PluginBroker) profilesContainPlugin(pluginName string) bool {
for _, profileCfg := range pb.profileConfigs {
for _, name := range profileCfg.Notifications {
@@ -176,6 +176,7 @@ func (pb *PluginBroker) profilesContainPlugin(pluginName string) bool {
}
return false
}
+
func (pb *PluginBroker) loadConfig(path string) error {
files, err := listFilesAtPath(path)
if err != nil {
@@ -277,7 +278,6 @@ func (pb *PluginBroker) loadPlugins(ctx context.Context, path string) error {
}
func (pb *PluginBroker) loadNotificationPlugin(name string, binaryPath string) (protobufs.NotifierServer, error) {
-
handshake, err := getHandshake()
if err != nil {
return nil, err
diff --git a/pkg/cticlient/types.go b/pkg/cticlient/types.go
index 2ad0a6eb34e..954d24641b4 100644
--- a/pkg/cticlient/types.go
+++ b/pkg/cticlient/types.go
@@ -210,7 +210,6 @@ func (c *SmokeItem) GetFalsePositives() []string {
}
func (c *SmokeItem) IsFalsePositive() bool {
-
if c.Classifications.FalsePositives != nil {
if len(c.Classifications.FalsePositives) > 0 {
return true
@@ -284,7 +283,6 @@ func (c *FireItem) GetFalsePositives() []string {
}
func (c *FireItem) IsFalsePositive() bool {
-
if c.Classifications.FalsePositives != nil {
if len(c.Classifications.FalsePositives) > 0 {
return true
diff --git a/pkg/cwversion/component/component.go b/pkg/cwversion/component/component.go
index 4036b63cf00..7ed596525e0 100644
--- a/pkg/cwversion/component/component.go
+++ b/pkg/cwversion/component/component.go
@@ -7,20 +7,21 @@ package component
// Built is a map of all the known components, and whether they are built-in or not.
// This is populated as soon as possible by the respective init() functions
-var Built = map[string]bool {
- "datasource_appsec": false,
- "datasource_cloudwatch": false,
- "datasource_docker": false,
- "datasource_file": false,
- "datasource_journalctl": false,
- "datasource_k8s-audit": false,
- "datasource_kafka": false,
- "datasource_kinesis": false,
- "datasource_loki": false,
- "datasource_s3": false,
- "datasource_syslog": false,
- "datasource_wineventlog":false,
- "cscli_setup": false,
+var Built = map[string]bool{
+ "datasource_appsec": false,
+ "datasource_cloudwatch": false,
+ "datasource_docker": false,
+ "datasource_file": false,
+ "datasource_journalctl": false,
+ "datasource_k8s-audit": false,
+ "datasource_kafka": false,
+ "datasource_kinesis": false,
+ "datasource_loki": false,
+ "datasource_s3": false,
+ "datasource_syslog": false,
+ "datasource_wineventlog": false,
+ "datasource_http": false,
+ "cscli_setup": false,
}
func Register(name string) {
diff --git a/pkg/database/allowlists.go b/pkg/database/allowlists.go
new file mode 100644
index 00000000000..3dc6ea7306f
--- /dev/null
+++ b/pkg/database/allowlists.go
@@ -0,0 +1,255 @@
+package database
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem"
+ "github.com/crowdsecurity/crowdsec/pkg/models"
+ "github.com/crowdsecurity/crowdsec/pkg/types"
+)
+
+func (c *Client) CreateAllowList(ctx context.Context, name string, description string, fromConsole bool) (*ent.AllowList, error) {
+ allowlist, err := c.Ent.AllowList.Create().
+ SetName(name).
+ SetFromConsole(fromConsole).
+ SetDescription(description).
+ Save(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create allowlist: %w", err)
+ }
+
+ return allowlist, nil
+}
+
+func (c *Client) DeleteAllowList(ctx context.Context, name string, fromConsole bool) error {
+
+ nbDeleted, err := c.Ent.AllowListItem.Delete().Where(allowlistitem.HasAllowlistWith(allowlist.NameEQ(name), allowlist.FromConsoleEQ(fromConsole))).Exec(ctx)
+ if err != nil {
+ return fmt.Errorf("unable to delete allowlist items: %w", err)
+ }
+
+ c.Log.Debugf("deleted %d items from allowlist %s", nbDeleted, name)
+
+ nbDeleted, err = c.Ent.AllowList.
+ Delete().
+ Where(allowlist.NameEQ(name), allowlist.FromConsoleEQ(fromConsole)).
+ Exec(ctx)
+ if err != nil {
+ return fmt.Errorf("unable to delete allowlist: %w", err)
+ }
+
+ if nbDeleted == 0 {
+ return fmt.Errorf("allowlist %s not found", name)
+ }
+
+ return nil
+}
+
+func (c *Client) ListAllowLists(ctx context.Context, withContent bool) ([]*ent.AllowList, error) {
+ q := c.Ent.AllowList.Query()
+ if withContent {
+ q = q.WithAllowlistItems()
+ }
+ result, err := q.All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("unable to list allowlists: %w", err)
+ }
+
+ return result, nil
+}
+
+func (c *Client) GetAllowList(ctx context.Context, name string, withContent bool) (*ent.AllowList, error) {
+ q := c.Ent.AllowList.Query().Where(allowlist.NameEQ(name))
+ if withContent {
+ q = q.WithAllowlistItems()
+ }
+ result, err := q.First(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return result, nil
+}
+
+func (c *Client) AddToAllowlist(ctx context.Context, list *ent.AllowList, items []*models.AllowlistItem) error {
+ successCount := 0
+
+ c.Log.Debugf("adding %d values to allowlist %s", len(items), list.Name)
+ c.Log.Tracef("values: %+v", items)
+
+ for _, item := range items {
+ //FIXME: wrap this in a transaction
+ c.Log.Debugf("adding value %s to allowlist %s", item.Value, list.Name)
+ sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(item.Value)
+ if err != nil {
+ c.Log.Errorf("unable to parse value %s: %s", item.Value, err)
+ continue
+ }
+ query := c.Ent.AllowListItem.Create().
+ SetValue(item.Value).
+ SetIPSize(int64(sz)).
+ SetStartIP(start_ip).
+ SetStartSuffix(start_sfx).
+ SetEndIP(end_ip).
+ SetEndSuffix(end_sfx).
+ SetComment(item.Description)
+
+ if !time.Time(item.Expiration).IsZero() {
+ query = query.SetExpiresAt(time.Time(item.Expiration))
+ }
+
+ content, err := query.Save(ctx)
+ if err != nil {
+ c.Log.Errorf("unable to add value to allowlist: %s", err)
+ }
+
+ c.Log.Debugf("Updating allowlist %s with value %s (exp: %s)", list.Name, item.Value, item.Expiration)
+
+ //We don't have a clean way to handle name conflict from the console, so use id
+ err = c.Ent.AllowList.Update().AddAllowlistItems(content).Where(allowlist.IDEQ(list.ID)).Exec(ctx)
+ if err != nil {
+ c.Log.Errorf("unable to add value to allowlist: %s", err)
+ continue
+ }
+ successCount++
+ }
+
+ c.Log.Infof("added %d values to allowlist %s", successCount, list.Name)
+
+ return nil
+}
+
+func (c *Client) RemoveFromAllowlist(ctx context.Context, list *ent.AllowList, values ...string) (int, error) {
+ c.Log.Debugf("removing %d values from allowlist %s", len(values), list.Name)
+ c.Log.Tracef("values: %v", values)
+
+ nbDeleted, err := c.Ent.AllowListItem.Delete().Where(
+ allowlistitem.HasAllowlistWith(allowlist.IDEQ(list.ID)),
+ allowlistitem.ValueIn(values...),
+ ).Exec(ctx)
+
+ if err != nil {
+ return 0, fmt.Errorf("unable to remove values from allowlist: %w", err)
+ }
+
+ return nbDeleted, nil
+}
+
+func (c *Client) ReplaceAllowlist(ctx context.Context, list *ent.AllowList, items []*models.AllowlistItem, fromConsole bool) error {
+ c.Log.Debugf("replacing values in allowlist %s", list.Name)
+ c.Log.Tracef("items: %+v", items)
+
+ _, err := c.Ent.AllowListItem.Delete().Where(allowlistitem.HasAllowlistWith(allowlist.IDEQ(list.ID))).Exec(ctx)
+
+ if err != nil {
+ return fmt.Errorf("unable to delete allowlist contents: %w", err)
+ }
+
+ err = c.AddToAllowlist(ctx, list, items)
+
+ if err != nil {
+ return fmt.Errorf("unable to add values to allowlist: %w", err)
+ }
+
+ if !list.FromConsole && fromConsole {
+ c.Log.Infof("marking allowlist %s as managed from console", list.Name)
+ err = c.Ent.AllowList.Update().SetFromConsole(fromConsole).Where(allowlist.IDEQ(list.ID)).Exec(ctx)
+ if err != nil {
+ return fmt.Errorf("unable to update allowlist: %w", err)
+ }
+ }
+
+ return nil
+}
+
+func (c *Client) IsAllowlisted(ctx context.Context, value string) (bool, error) {
+ /*
+ Few cases:
+	 - value is an IP/range directly in the allowlist
+ - value is an IP/range in a range in allowlist
+ - value is a range and an IP/range belonging to it is in allowlist
+ */
+
+ sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(value)
+ if err != nil {
+ return false, fmt.Errorf("unable to parse value %s: %w", value, err)
+ }
+
+ c.Log.Debugf("checking if %s is allowlisted", value)
+
+ now := time.Now().UTC()
+ query := c.Ent.AllowListItem.Query().Where(
+ allowlistitem.Or(
+ allowlistitem.ExpiresAtGTE(now),
+ allowlistitem.ExpiresAtIsNil(),
+ ),
+ allowlistitem.IPSizeEQ(int64(sz)),
+ )
+
+ if sz == 4 {
+ query = query.Where(
+ allowlistitem.Or(
+ // Value contained inside a range or exact match
+ allowlistitem.And(
+ allowlistitem.StartIPLTE(start_ip),
+ allowlistitem.EndIPGTE(end_ip),
+ ),
+ // Value contains another allowlisted value
+ allowlistitem.And(
+ allowlistitem.StartIPGTE(start_ip),
+ allowlistitem.EndIPLTE(end_ip),
+ ),
+ ))
+ }
+
+ if sz == 16 {
+ query = query.Where(
+ // Value contained inside a range or exact match
+ allowlistitem.Or(
+ allowlistitem.And(
+ allowlistitem.Or(
+ allowlistitem.StartIPLT(start_ip),
+ allowlistitem.And(
+ allowlistitem.StartIPEQ(start_ip),
+ allowlistitem.StartSuffixLTE(start_sfx),
+ )),
+ allowlistitem.Or(
+ allowlistitem.EndIPGT(end_ip),
+ allowlistitem.And(
+ allowlistitem.EndIPEQ(end_ip),
+ allowlistitem.EndSuffixGTE(end_sfx),
+ ),
+ ),
+ ),
+ // Value contains another allowlisted value
+ allowlistitem.And(
+ allowlistitem.Or(
+ allowlistitem.StartIPGT(start_ip),
+ allowlistitem.And(
+ allowlistitem.StartIPEQ(start_ip),
+ allowlistitem.StartSuffixGTE(start_sfx),
+ )),
+ allowlistitem.Or(
+ allowlistitem.EndIPLT(end_ip),
+ allowlistitem.And(
+ allowlistitem.EndIPEQ(end_ip),
+ allowlistitem.EndSuffixLTE(end_sfx),
+ ),
+ ),
+ ),
+ ),
+ )
+ }
+
+ allowed, err := query.Exist(ctx)
+
+ if err != nil {
+ return false, fmt.Errorf("unable to check if value is allowlisted: %w", err)
+ }
+
+ return allowed, nil
+}
diff --git a/pkg/database/allowlists_test.go b/pkg/database/allowlists_test.go
new file mode 100644
index 00000000000..10f1b3aa9ae
--- /dev/null
+++ b/pkg/database/allowlists_test.go
@@ -0,0 +1,99 @@
+package database
+
+import (
+ "context"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/crowdsecurity/crowdsec/pkg/csconfig"
+ "github.com/crowdsecurity/crowdsec/pkg/models"
+ "github.com/go-openapi/strfmt"
+ "github.com/stretchr/testify/require"
+)
+
+func getDBClient(t *testing.T, ctx context.Context) *Client {
+ t.Helper()
+
+ dbPath, err := os.CreateTemp("", "*sqlite")
+ require.NoError(t, err)
+ dbClient, err := NewClient(ctx, &csconfig.DatabaseCfg{
+ Type: "sqlite",
+ DbName: "crowdsec",
+ DbPath: dbPath.Name(),
+ })
+ require.NoError(t, err)
+
+ return dbClient
+}
+
+func TestCheckAllowlist(t *testing.T) {
+
+ ctx := context.Background()
+ dbClient := getDBClient(t, ctx)
+
+ allowlist, err := dbClient.CreateAllowList(ctx, "test", "test", false)
+
+ require.NoError(t, err)
+
+ err = dbClient.AddToAllowlist(ctx, allowlist, []*models.AllowlistItem{
+ {
+ CreatedAt: strfmt.DateTime(time.Now()),
+ Value: "1.2.3.4",
+ },
+ {
+ CreatedAt: strfmt.DateTime(time.Now()),
+ Value: "8.0.0.0/8",
+ },
+ {
+ CreatedAt: strfmt.DateTime(time.Now()),
+ Value: "2001:db8::/32",
+ },
+ {
+ CreatedAt: strfmt.DateTime(time.Now()),
+ Value: "2.3.4.5",
+ Expiration: strfmt.DateTime(time.Now().Add(-time.Hour)), // expired item
+ },
+ {
+ CreatedAt: strfmt.DateTime(time.Now()),
+ Value: "8a95:c186:9f96:4c75:0dad:49c6:ff62:94b8",
+ },
+ })
+
+ require.NoError(t, err)
+
+	// Exact match
+ allowlisted, err := dbClient.IsAllowlisted(ctx, "1.2.3.4")
+ require.NoError(t, err)
+ require.True(t, allowlisted)
+
+ // CIDR match
+ allowlisted, err = dbClient.IsAllowlisted(ctx, "8.8.8.8")
+ require.NoError(t, err)
+ require.True(t, allowlisted)
+
+ // IPv6 match
+ allowlisted, err = dbClient.IsAllowlisted(ctx, "2001:db8::1")
+ require.NoError(t, err)
+ require.True(t, allowlisted)
+
+ // Expired item
+ allowlisted, err = dbClient.IsAllowlisted(ctx, "2.3.4.5")
+ require.NoError(t, err)
+ require.False(t, allowlisted)
+
+ // Decision on a range that contains an allowlisted value
+ allowlisted, err = dbClient.IsAllowlisted(ctx, "1.2.3.0/24")
+ require.NoError(t, err)
+ require.True(t, allowlisted)
+
+ // No match
+ allowlisted, err = dbClient.IsAllowlisted(ctx, "42.42.42.42")
+ require.NoError(t, err)
+ require.False(t, allowlisted)
+
+ // IPv6 range that contains an allowlisted value
+ allowlisted, err = dbClient.IsAllowlisted(ctx, "8a95:c186:9f96:4c75::/64")
+ require.NoError(t, err)
+ require.True(t, allowlisted)
+}
diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go
index 04ef830ae72..f9e62bc6522 100644
--- a/pkg/database/bouncers.go
+++ b/pkg/database/bouncers.go
@@ -41,8 +41,19 @@ func (c *Client) BouncerUpdateBaseMetrics(ctx context.Context, bouncerName strin
return nil
}
-func (c *Client) SelectBouncer(ctx context.Context, apiKeyHash string) (*ent.Bouncer, error) {
- result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(ctx)
+func (c *Client) SelectBouncers(ctx context.Context, apiKeyHash string, authType string) ([]*ent.Bouncer, error) {
+ //Order by ID so manually created bouncer will be first in the list to use as the base name
+ //when automatically creating a new entry if API keys are shared
+ result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash), bouncer.AuthTypeEQ(authType)).Order(ent.Asc(bouncer.FieldID)).All(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return result, nil
+}
+
+func (c *Client) SelectBouncerWithIP(ctx context.Context, apiKeyHash string, clientIP string) (*ent.Bouncer, error) {
+ result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash), bouncer.IPAddressEQ(clientIP)).First(ctx)
if err != nil {
return nil, err
}
@@ -68,13 +79,15 @@ func (c *Client) ListBouncers(ctx context.Context) ([]*ent.Bouncer, error) {
return result, nil
}
-func (c *Client) CreateBouncer(ctx context.Context, name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) {
+func (c *Client) CreateBouncer(ctx context.Context, name string, ipAddr string, apiKey string, authType string, autoCreated bool) (*ent.Bouncer, error) {
bouncer, err := c.Ent.Bouncer.
Create().
SetName(name).
SetAPIKey(apiKey).
SetRevoked(false).
SetAuthType(authType).
+ SetIPAddress(ipAddr).
+ SetAutoCreated(autoCreated).
Save(ctx)
if err != nil {
if ent.IsConstraintError(err) {
diff --git a/pkg/database/ent/allowlist.go b/pkg/database/ent/allowlist.go
new file mode 100644
index 00000000000..99b36687a7b
--- /dev/null
+++ b/pkg/database/ent/allowlist.go
@@ -0,0 +1,189 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist"
+)
+
+// AllowList is the model entity for the AllowList schema.
+type AllowList struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // Name holds the value of the "name" field.
+ Name string `json:"name,omitempty"`
+ // FromConsole holds the value of the "from_console" field.
+ FromConsole bool `json:"from_console,omitempty"`
+ // Description holds the value of the "description" field.
+ Description string `json:"description,omitempty"`
+ // AllowlistID holds the value of the "allowlist_id" field.
+ AllowlistID string `json:"allowlist_id,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the AllowListQuery when eager-loading is set.
+ Edges AllowListEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// AllowListEdges holds the relations/edges for other nodes in the graph.
+type AllowListEdges struct {
+ // AllowlistItems holds the value of the allowlist_items edge.
+ AllowlistItems []*AllowListItem `json:"allowlist_items,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// AllowlistItemsOrErr returns the AllowlistItems value or an error if the edge
+// was not loaded in eager-loading.
+func (e AllowListEdges) AllowlistItemsOrErr() ([]*AllowListItem, error) {
+ if e.loadedTypes[0] {
+ return e.AllowlistItems, nil
+ }
+ return nil, &NotLoadedError{edge: "allowlist_items"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*AllowList) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case allowlist.FieldFromConsole:
+ values[i] = new(sql.NullBool)
+ case allowlist.FieldID:
+ values[i] = new(sql.NullInt64)
+ case allowlist.FieldName, allowlist.FieldDescription, allowlist.FieldAllowlistID:
+ values[i] = new(sql.NullString)
+ case allowlist.FieldCreatedAt, allowlist.FieldUpdatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the AllowList fields.
+func (al *AllowList) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case allowlist.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ al.ID = int(value.Int64)
+ case allowlist.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ al.CreatedAt = value.Time
+ }
+ case allowlist.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ al.UpdatedAt = value.Time
+ }
+ case allowlist.FieldName:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field name", values[i])
+ } else if value.Valid {
+ al.Name = value.String
+ }
+ case allowlist.FieldFromConsole:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field from_console", values[i])
+ } else if value.Valid {
+ al.FromConsole = value.Bool
+ }
+ case allowlist.FieldDescription:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field description", values[i])
+ } else if value.Valid {
+ al.Description = value.String
+ }
+ case allowlist.FieldAllowlistID:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field allowlist_id", values[i])
+ } else if value.Valid {
+ al.AllowlistID = value.String
+ }
+ default:
+ al.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the AllowList.
+// This includes values selected through modifiers, order, etc.
+func (al *AllowList) Value(name string) (ent.Value, error) {
+ return al.selectValues.Get(name)
+}
+
+// QueryAllowlistItems queries the "allowlist_items" edge of the AllowList entity.
+func (al *AllowList) QueryAllowlistItems() *AllowListItemQuery {
+ return NewAllowListClient(al.config).QueryAllowlistItems(al)
+}
+
+// Update returns a builder for updating this AllowList.
+// Note that you need to call AllowList.Unwrap() before calling this method if this AllowList
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (al *AllowList) Update() *AllowListUpdateOne {
+ return NewAllowListClient(al.config).UpdateOne(al)
+}
+
+// Unwrap unwraps the AllowList entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (al *AllowList) Unwrap() *AllowList {
+ _tx, ok := al.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: AllowList is not a transactional entity")
+ }
+ al.config.driver = _tx.drv
+ return al
+}
+
+// String implements the fmt.Stringer.
+func (al *AllowList) String() string {
+ var builder strings.Builder
+ builder.WriteString("AllowList(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", al.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(al.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(al.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("name=")
+ builder.WriteString(al.Name)
+ builder.WriteString(", ")
+ builder.WriteString("from_console=")
+ builder.WriteString(fmt.Sprintf("%v", al.FromConsole))
+ builder.WriteString(", ")
+ builder.WriteString("description=")
+ builder.WriteString(al.Description)
+ builder.WriteString(", ")
+ builder.WriteString("allowlist_id=")
+ builder.WriteString(al.AllowlistID)
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// AllowLists is a parsable slice of AllowList.
+type AllowLists []*AllowList
diff --git a/pkg/database/ent/allowlist/allowlist.go b/pkg/database/ent/allowlist/allowlist.go
new file mode 100644
index 00000000000..36cac5c1b21
--- /dev/null
+++ b/pkg/database/ent/allowlist/allowlist.go
@@ -0,0 +1,133 @@
+// Code generated by ent, DO NOT EDIT.
+
+package allowlist
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the allowlist type in the database.
+ Label = "allow_list"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldName holds the string denoting the name field in the database.
+ FieldName = "name"
+ // FieldFromConsole holds the string denoting the from_console field in the database.
+ FieldFromConsole = "from_console"
+ // FieldDescription holds the string denoting the description field in the database.
+ FieldDescription = "description"
+ // FieldAllowlistID holds the string denoting the allowlist_id field in the database.
+ FieldAllowlistID = "allowlist_id"
+ // EdgeAllowlistItems holds the string denoting the allowlist_items edge name in mutations.
+ EdgeAllowlistItems = "allowlist_items"
+ // Table holds the table name of the allowlist in the database.
+ Table = "allow_lists"
+ // AllowlistItemsTable is the table that holds the allowlist_items relation/edge. The primary key declared below.
+ AllowlistItemsTable = "allow_list_allowlist_items"
+ // AllowlistItemsInverseTable is the table name for the AllowListItem entity.
+ // It exists in this package in order to avoid circular dependency with the "allowlistitem" package.
+ AllowlistItemsInverseTable = "allow_list_items"
+)
+
+// Columns holds all SQL columns for allowlist fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldName,
+ FieldFromConsole,
+ FieldDescription,
+ FieldAllowlistID,
+}
+
+var (
+ // AllowlistItemsPrimaryKey and AllowlistItemsColumn2 are the table columns denoting the
+ // primary key for the allowlist_items relation (M2M).
+ AllowlistItemsPrimaryKey = []string{"allow_list_id", "allow_list_item_id"}
+)
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the AllowList queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByName orders the results by the name field.
+func ByName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldName, opts...).ToFunc()
+}
+
+// ByFromConsole orders the results by the from_console field.
+func ByFromConsole(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldFromConsole, opts...).ToFunc()
+}
+
+// ByDescription orders the results by the description field.
+func ByDescription(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDescription, opts...).ToFunc()
+}
+
+// ByAllowlistID orders the results by the allowlist_id field.
+func ByAllowlistID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAllowlistID, opts...).ToFunc()
+}
+
+// ByAllowlistItemsCount orders the results by allowlist_items count.
+func ByAllowlistItemsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newAllowlistItemsStep(), opts...)
+ }
+}
+
+// ByAllowlistItems orders the results by allowlist_items terms.
+func ByAllowlistItems(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAllowlistItemsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+func newAllowlistItemsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AllowlistItemsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2M, false, AllowlistItemsTable, AllowlistItemsPrimaryKey...),
+ )
+}
diff --git a/pkg/database/ent/allowlist/where.go b/pkg/database/ent/allowlist/where.go
new file mode 100644
index 00000000000..d8b43be2cf9
--- /dev/null
+++ b/pkg/database/ent/allowlist/where.go
@@ -0,0 +1,429 @@
+// Code generated by ent, DO NOT EDIT.
+
+package allowlist
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int) predicate.AllowList {
+ return predicate.AllowList(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int) predicate.AllowList {
+ return predicate.AllowList(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int) predicate.AllowList {
+ return predicate.AllowList(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int) predicate.AllowList {
+ return predicate.AllowList(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int) predicate.AllowList {
+ return predicate.AllowList(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int) predicate.AllowList {
+ return predicate.AllowList(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int) predicate.AllowList {
+ return predicate.AllowList(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
+func Name(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEQ(FieldName, v))
+}
+
+// FromConsole applies equality check predicate on the "from_console" field. It's identical to FromConsoleEQ.
+func FromConsole(v bool) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEQ(FieldFromConsole, v))
+}
+
+// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
+func Description(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEQ(FieldDescription, v))
+}
+
+// AllowlistID applies equality check predicate on the "allowlist_id" field. It's identical to AllowlistIDEQ.
+func AllowlistID(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEQ(FieldAllowlistID, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.AllowList {
+ return predicate.AllowList(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// NameEQ applies the EQ predicate on the "name" field.
+func NameEQ(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEQ(FieldName, v))
+}
+
+// NameNEQ applies the NEQ predicate on the "name" field.
+func NameNEQ(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldNEQ(FieldName, v))
+}
+
+// NameIn applies the In predicate on the "name" field.
+func NameIn(vs ...string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldIn(FieldName, vs...))
+}
+
+// NameNotIn applies the NotIn predicate on the "name" field.
+func NameNotIn(vs ...string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldNotIn(FieldName, vs...))
+}
+
+// NameGT applies the GT predicate on the "name" field.
+func NameGT(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldGT(FieldName, v))
+}
+
+// NameGTE applies the GTE predicate on the "name" field.
+func NameGTE(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldGTE(FieldName, v))
+}
+
+// NameLT applies the LT predicate on the "name" field.
+func NameLT(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldLT(FieldName, v))
+}
+
+// NameLTE applies the LTE predicate on the "name" field.
+func NameLTE(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldLTE(FieldName, v))
+}
+
+// NameContains applies the Contains predicate on the "name" field.
+func NameContains(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldContains(FieldName, v))
+}
+
+// NameHasPrefix applies the HasPrefix predicate on the "name" field.
+func NameHasPrefix(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldHasPrefix(FieldName, v))
+}
+
+// NameHasSuffix applies the HasSuffix predicate on the "name" field.
+func NameHasSuffix(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldHasSuffix(FieldName, v))
+}
+
+// NameEqualFold applies the EqualFold predicate on the "name" field.
+func NameEqualFold(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEqualFold(FieldName, v))
+}
+
+// NameContainsFold applies the ContainsFold predicate on the "name" field.
+func NameContainsFold(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldContainsFold(FieldName, v))
+}
+
+// FromConsoleEQ applies the EQ predicate on the "from_console" field.
+func FromConsoleEQ(v bool) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEQ(FieldFromConsole, v))
+}
+
+// FromConsoleNEQ applies the NEQ predicate on the "from_console" field.
+func FromConsoleNEQ(v bool) predicate.AllowList {
+ return predicate.AllowList(sql.FieldNEQ(FieldFromConsole, v))
+}
+
+// DescriptionEQ applies the EQ predicate on the "description" field.
+func DescriptionEQ(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEQ(FieldDescription, v))
+}
+
+// DescriptionNEQ applies the NEQ predicate on the "description" field.
+func DescriptionNEQ(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldNEQ(FieldDescription, v))
+}
+
+// DescriptionIn applies the In predicate on the "description" field.
+func DescriptionIn(vs ...string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldIn(FieldDescription, vs...))
+}
+
+// DescriptionNotIn applies the NotIn predicate on the "description" field.
+func DescriptionNotIn(vs ...string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldNotIn(FieldDescription, vs...))
+}
+
+// DescriptionGT applies the GT predicate on the "description" field.
+func DescriptionGT(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldGT(FieldDescription, v))
+}
+
+// DescriptionGTE applies the GTE predicate on the "description" field.
+func DescriptionGTE(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldGTE(FieldDescription, v))
+}
+
+// DescriptionLT applies the LT predicate on the "description" field.
+func DescriptionLT(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldLT(FieldDescription, v))
+}
+
+// DescriptionLTE applies the LTE predicate on the "description" field.
+func DescriptionLTE(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldLTE(FieldDescription, v))
+}
+
+// DescriptionContains applies the Contains predicate on the "description" field.
+func DescriptionContains(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldContains(FieldDescription, v))
+}
+
+// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field.
+func DescriptionHasPrefix(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldHasPrefix(FieldDescription, v))
+}
+
+// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field.
+func DescriptionHasSuffix(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldHasSuffix(FieldDescription, v))
+}
+
+// DescriptionIsNil applies the IsNil predicate on the "description" field.
+func DescriptionIsNil() predicate.AllowList {
+ return predicate.AllowList(sql.FieldIsNull(FieldDescription))
+}
+
+// DescriptionNotNil applies the NotNil predicate on the "description" field.
+func DescriptionNotNil() predicate.AllowList {
+ return predicate.AllowList(sql.FieldNotNull(FieldDescription))
+}
+
+// DescriptionEqualFold applies the EqualFold predicate on the "description" field.
+func DescriptionEqualFold(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEqualFold(FieldDescription, v))
+}
+
+// DescriptionContainsFold applies the ContainsFold predicate on the "description" field.
+func DescriptionContainsFold(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldContainsFold(FieldDescription, v))
+}
+
+// AllowlistIDEQ applies the EQ predicate on the "allowlist_id" field.
+func AllowlistIDEQ(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEQ(FieldAllowlistID, v))
+}
+
+// AllowlistIDNEQ applies the NEQ predicate on the "allowlist_id" field.
+func AllowlistIDNEQ(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldNEQ(FieldAllowlistID, v))
+}
+
+// AllowlistIDIn applies the In predicate on the "allowlist_id" field.
+func AllowlistIDIn(vs ...string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldIn(FieldAllowlistID, vs...))
+}
+
+// AllowlistIDNotIn applies the NotIn predicate on the "allowlist_id" field.
+func AllowlistIDNotIn(vs ...string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldNotIn(FieldAllowlistID, vs...))
+}
+
+// AllowlistIDGT applies the GT predicate on the "allowlist_id" field.
+func AllowlistIDGT(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldGT(FieldAllowlistID, v))
+}
+
+// AllowlistIDGTE applies the GTE predicate on the "allowlist_id" field.
+func AllowlistIDGTE(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldGTE(FieldAllowlistID, v))
+}
+
+// AllowlistIDLT applies the LT predicate on the "allowlist_id" field.
+func AllowlistIDLT(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldLT(FieldAllowlistID, v))
+}
+
+// AllowlistIDLTE applies the LTE predicate on the "allowlist_id" field.
+func AllowlistIDLTE(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldLTE(FieldAllowlistID, v))
+}
+
+// AllowlistIDContains applies the Contains predicate on the "allowlist_id" field.
+func AllowlistIDContains(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldContains(FieldAllowlistID, v))
+}
+
+// AllowlistIDHasPrefix applies the HasPrefix predicate on the "allowlist_id" field.
+func AllowlistIDHasPrefix(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldHasPrefix(FieldAllowlistID, v))
+}
+
+// AllowlistIDHasSuffix applies the HasSuffix predicate on the "allowlist_id" field.
+func AllowlistIDHasSuffix(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldHasSuffix(FieldAllowlistID, v))
+}
+
+// AllowlistIDIsNil applies the IsNil predicate on the "allowlist_id" field.
+func AllowlistIDIsNil() predicate.AllowList {
+ return predicate.AllowList(sql.FieldIsNull(FieldAllowlistID))
+}
+
+// AllowlistIDNotNil applies the NotNil predicate on the "allowlist_id" field.
+func AllowlistIDNotNil() predicate.AllowList {
+ return predicate.AllowList(sql.FieldNotNull(FieldAllowlistID))
+}
+
+// AllowlistIDEqualFold applies the EqualFold predicate on the "allowlist_id" field.
+func AllowlistIDEqualFold(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldEqualFold(FieldAllowlistID, v))
+}
+
+// AllowlistIDContainsFold applies the ContainsFold predicate on the "allowlist_id" field.
+func AllowlistIDContainsFold(v string) predicate.AllowList {
+ return predicate.AllowList(sql.FieldContainsFold(FieldAllowlistID, v))
+}
+
+// HasAllowlistItems applies the HasEdge predicate on the "allowlist_items" edge.
+func HasAllowlistItems() predicate.AllowList {
+ return predicate.AllowList(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2M, false, AllowlistItemsTable, AllowlistItemsPrimaryKey...),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAllowlistItemsWith applies the HasEdge predicate on the "allowlist_items" edge with a given conditions (other predicates).
+func HasAllowlistItemsWith(preds ...predicate.AllowListItem) predicate.AllowList {
+ return predicate.AllowList(func(s *sql.Selector) {
+ step := newAllowlistItemsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.AllowList) predicate.AllowList {
+ return predicate.AllowList(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.AllowList) predicate.AllowList {
+ return predicate.AllowList(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.AllowList) predicate.AllowList {
+ return predicate.AllowList(sql.NotPredicates(p))
+}
diff --git a/pkg/database/ent/allowlist_create.go b/pkg/database/ent/allowlist_create.go
new file mode 100644
index 00000000000..ec9d29b6ae5
--- /dev/null
+++ b/pkg/database/ent/allowlist_create.go
@@ -0,0 +1,321 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem"
+)
+
+// AllowListCreate is the builder for creating a AllowList entity.
+type AllowListCreate struct {
+ config
+ mutation *AllowListMutation
+ hooks []Hook
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (alc *AllowListCreate) SetCreatedAt(t time.Time) *AllowListCreate {
+ alc.mutation.SetCreatedAt(t)
+ return alc
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (alc *AllowListCreate) SetNillableCreatedAt(t *time.Time) *AllowListCreate {
+ if t != nil {
+ alc.SetCreatedAt(*t)
+ }
+ return alc
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (alc *AllowListCreate) SetUpdatedAt(t time.Time) *AllowListCreate {
+ alc.mutation.SetUpdatedAt(t)
+ return alc
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (alc *AllowListCreate) SetNillableUpdatedAt(t *time.Time) *AllowListCreate {
+ if t != nil {
+ alc.SetUpdatedAt(*t)
+ }
+ return alc
+}
+
+// SetName sets the "name" field.
+func (alc *AllowListCreate) SetName(s string) *AllowListCreate {
+ alc.mutation.SetName(s)
+ return alc
+}
+
+// SetFromConsole sets the "from_console" field.
+func (alc *AllowListCreate) SetFromConsole(b bool) *AllowListCreate {
+ alc.mutation.SetFromConsole(b)
+ return alc
+}
+
+// SetDescription sets the "description" field.
+func (alc *AllowListCreate) SetDescription(s string) *AllowListCreate {
+ alc.mutation.SetDescription(s)
+ return alc
+}
+
+// SetNillableDescription sets the "description" field if the given value is not nil.
+func (alc *AllowListCreate) SetNillableDescription(s *string) *AllowListCreate {
+ if s != nil {
+ alc.SetDescription(*s)
+ }
+ return alc
+}
+
+// SetAllowlistID sets the "allowlist_id" field.
+func (alc *AllowListCreate) SetAllowlistID(s string) *AllowListCreate {
+ alc.mutation.SetAllowlistID(s)
+ return alc
+}
+
+// SetNillableAllowlistID sets the "allowlist_id" field if the given value is not nil.
+func (alc *AllowListCreate) SetNillableAllowlistID(s *string) *AllowListCreate {
+ if s != nil {
+ alc.SetAllowlistID(*s)
+ }
+ return alc
+}
+
+// AddAllowlistItemIDs adds the "allowlist_items" edge to the AllowListItem entity by IDs.
+func (alc *AllowListCreate) AddAllowlistItemIDs(ids ...int) *AllowListCreate {
+ alc.mutation.AddAllowlistItemIDs(ids...)
+ return alc
+}
+
+// AddAllowlistItems adds the "allowlist_items" edges to the AllowListItem entity.
+func (alc *AllowListCreate) AddAllowlistItems(a ...*AllowListItem) *AllowListCreate {
+ ids := make([]int, len(a))
+ for i := range a {
+ ids[i] = a[i].ID
+ }
+ return alc.AddAllowlistItemIDs(ids...)
+}
+
+// Mutation returns the AllowListMutation object of the builder.
+func (alc *AllowListCreate) Mutation() *AllowListMutation {
+ return alc.mutation
+}
+
+// Save creates the AllowList in the database.
+func (alc *AllowListCreate) Save(ctx context.Context) (*AllowList, error) {
+ alc.defaults()
+ return withHooks(ctx, alc.sqlSave, alc.mutation, alc.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (alc *AllowListCreate) SaveX(ctx context.Context) *AllowList {
+ v, err := alc.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (alc *AllowListCreate) Exec(ctx context.Context) error {
+ _, err := alc.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (alc *AllowListCreate) ExecX(ctx context.Context) {
+ if err := alc.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (alc *AllowListCreate) defaults() {
+ if _, ok := alc.mutation.CreatedAt(); !ok {
+ v := allowlist.DefaultCreatedAt()
+ alc.mutation.SetCreatedAt(v)
+ }
+ if _, ok := alc.mutation.UpdatedAt(); !ok {
+ v := allowlist.DefaultUpdatedAt()
+ alc.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (alc *AllowListCreate) check() error {
+ if _, ok := alc.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AllowList.created_at"`)}
+ }
+ if _, ok := alc.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AllowList.updated_at"`)}
+ }
+ if _, ok := alc.mutation.Name(); !ok {
+ return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "AllowList.name"`)}
+ }
+ if _, ok := alc.mutation.FromConsole(); !ok {
+ return &ValidationError{Name: "from_console", err: errors.New(`ent: missing required field "AllowList.from_console"`)}
+ }
+ return nil
+}
+
+func (alc *AllowListCreate) sqlSave(ctx context.Context) (*AllowList, error) {
+ if err := alc.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := alc.createSpec()
+ if err := sqlgraph.CreateNode(ctx, alc.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int(id)
+ alc.mutation.id = &_node.ID
+ alc.mutation.done = true
+ return _node, nil
+}
+
+func (alc *AllowListCreate) createSpec() (*AllowList, *sqlgraph.CreateSpec) {
+ var (
+ _node = &AllowList{config: alc.config}
+ _spec = sqlgraph.NewCreateSpec(allowlist.Table, sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt))
+ )
+ if value, ok := alc.mutation.CreatedAt(); ok {
+ _spec.SetField(allowlist.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := alc.mutation.UpdatedAt(); ok {
+ _spec.SetField(allowlist.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := alc.mutation.Name(); ok {
+ _spec.SetField(allowlist.FieldName, field.TypeString, value)
+ _node.Name = value
+ }
+ if value, ok := alc.mutation.FromConsole(); ok {
+ _spec.SetField(allowlist.FieldFromConsole, field.TypeBool, value)
+ _node.FromConsole = value
+ }
+ if value, ok := alc.mutation.Description(); ok {
+ _spec.SetField(allowlist.FieldDescription, field.TypeString, value)
+ _node.Description = value
+ }
+ if value, ok := alc.mutation.AllowlistID(); ok {
+ _spec.SetField(allowlist.FieldAllowlistID, field.TypeString, value)
+ _node.AllowlistID = value
+ }
+ if nodes := alc.mutation.AllowlistItemsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: false,
+ Table: allowlist.AllowlistItemsTable,
+ Columns: allowlist.AllowlistItemsPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// AllowListCreateBulk is the builder for creating many AllowList entities in bulk.
+type AllowListCreateBulk struct {
+ config
+ err error
+ builders []*AllowListCreate
+}
+
+// Save creates the AllowList entities in the database.
+func (alcb *AllowListCreateBulk) Save(ctx context.Context) ([]*AllowList, error) {
+ if alcb.err != nil {
+ return nil, alcb.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(alcb.builders))
+ nodes := make([]*AllowList, len(alcb.builders))
+ mutators := make([]Mutator, len(alcb.builders))
+ for i := range alcb.builders {
+ func(i int, root context.Context) {
+ builder := alcb.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*AllowListMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, alcb.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, alcb.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, alcb.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (alcb *AllowListCreateBulk) SaveX(ctx context.Context) []*AllowList {
+ v, err := alcb.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (alcb *AllowListCreateBulk) Exec(ctx context.Context) error {
+ _, err := alcb.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (alcb *AllowListCreateBulk) ExecX(ctx context.Context) {
+ if err := alcb.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/database/ent/allowlist_delete.go b/pkg/database/ent/allowlist_delete.go
new file mode 100644
index 00000000000..dcfaa214f6f
--- /dev/null
+++ b/pkg/database/ent/allowlist_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate"
+)
+
+// AllowListDelete is the builder for deleting a AllowList entity.
+type AllowListDelete struct {
+ config
+ hooks []Hook
+ mutation *AllowListMutation
+}
+
+// Where appends a list predicates to the AllowListDelete builder.
+func (ald *AllowListDelete) Where(ps ...predicate.AllowList) *AllowListDelete {
+ ald.mutation.Where(ps...)
+ return ald
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (ald *AllowListDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, ald.sqlExec, ald.mutation, ald.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (ald *AllowListDelete) ExecX(ctx context.Context) int {
+ n, err := ald.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (ald *AllowListDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(allowlist.Table, sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt))
+ if ps := ald.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, ald.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ ald.mutation.done = true
+ return affected, err
+}
+
+// AllowListDeleteOne is the builder for deleting a single AllowList entity.
+type AllowListDeleteOne struct {
+ ald *AllowListDelete
+}
+
+// Where appends a list predicates to the AllowListDelete builder.
+func (aldo *AllowListDeleteOne) Where(ps ...predicate.AllowList) *AllowListDeleteOne {
+ aldo.ald.mutation.Where(ps...)
+ return aldo
+}
+
+// Exec executes the deletion query.
+func (aldo *AllowListDeleteOne) Exec(ctx context.Context) error {
+ n, err := aldo.ald.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{allowlist.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (aldo *AllowListDeleteOne) ExecX(ctx context.Context) {
+ if err := aldo.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/database/ent/allowlist_query.go b/pkg/database/ent/allowlist_query.go
new file mode 100644
index 00000000000..511c3c051f7
--- /dev/null
+++ b/pkg/database/ent/allowlist_query.go
@@ -0,0 +1,636 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate"
+)
+
+// AllowListQuery is the builder for querying AllowList entities.
+type AllowListQuery struct {
+ config
+ ctx *QueryContext
+ order []allowlist.OrderOption
+ inters []Interceptor
+ predicates []predicate.AllowList
+ withAllowlistItems *AllowListItemQuery
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the AllowListQuery builder.
+func (alq *AllowListQuery) Where(ps ...predicate.AllowList) *AllowListQuery {
+ alq.predicates = append(alq.predicates, ps...)
+ return alq
+}
+
+// Limit the number of records to be returned by this query.
+func (alq *AllowListQuery) Limit(limit int) *AllowListQuery {
+ alq.ctx.Limit = &limit
+ return alq
+}
+
+// Offset to start from.
+func (alq *AllowListQuery) Offset(offset int) *AllowListQuery {
+ alq.ctx.Offset = &offset
+ return alq
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (alq *AllowListQuery) Unique(unique bool) *AllowListQuery {
+ alq.ctx.Unique = &unique
+ return alq
+}
+
+// Order specifies how the records should be ordered.
+func (alq *AllowListQuery) Order(o ...allowlist.OrderOption) *AllowListQuery {
+ alq.order = append(alq.order, o...)
+ return alq
+}
+
+// QueryAllowlistItems chains the current query on the "allowlist_items" edge.
+func (alq *AllowListQuery) QueryAllowlistItems() *AllowListItemQuery {
+ query := (&AllowListItemClient{config: alq.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := alq.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := alq.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(allowlist.Table, allowlist.FieldID, selector),
+ sqlgraph.To(allowlistitem.Table, allowlistitem.FieldID),
+ sqlgraph.Edge(sqlgraph.M2M, false, allowlist.AllowlistItemsTable, allowlist.AllowlistItemsPrimaryKey...),
+ )
+ fromU = sqlgraph.SetNeighbors(alq.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first AllowList entity from the query.
+// Returns a *NotFoundError when no AllowList was found.
+func (alq *AllowListQuery) First(ctx context.Context) (*AllowList, error) {
+ nodes, err := alq.Limit(1).All(setContextOp(ctx, alq.ctx, "First"))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{allowlist.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (alq *AllowListQuery) FirstX(ctx context.Context) *AllowList {
+ node, err := alq.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first AllowList ID from the query.
+// Returns a *NotFoundError when no AllowList ID was found.
+func (alq *AllowListQuery) FirstID(ctx context.Context) (id int, err error) {
+ var ids []int
+ if ids, err = alq.Limit(1).IDs(setContextOp(ctx, alq.ctx, "FirstID")); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{allowlist.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (alq *AllowListQuery) FirstIDX(ctx context.Context) int {
+ id, err := alq.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single AllowList entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one AllowList entity is found.
+// Returns a *NotFoundError when no AllowList entities are found.
+func (alq *AllowListQuery) Only(ctx context.Context) (*AllowList, error) {
+ nodes, err := alq.Limit(2).All(setContextOp(ctx, alq.ctx, "Only"))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{allowlist.Label}
+ default:
+ return nil, &NotSingularError{allowlist.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (alq *AllowListQuery) OnlyX(ctx context.Context) *AllowList {
+ node, err := alq.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only AllowList ID in the query.
+// Returns a *NotSingularError when more than one AllowList ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (alq *AllowListQuery) OnlyID(ctx context.Context) (id int, err error) {
+ var ids []int
+ if ids, err = alq.Limit(2).IDs(setContextOp(ctx, alq.ctx, "OnlyID")); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{allowlist.Label}
+ default:
+ err = &NotSingularError{allowlist.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (alq *AllowListQuery) OnlyIDX(ctx context.Context) int {
+ id, err := alq.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of AllowLists.
+func (alq *AllowListQuery) All(ctx context.Context) ([]*AllowList, error) {
+ ctx = setContextOp(ctx, alq.ctx, "All")
+ if err := alq.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*AllowList, *AllowListQuery]()
+ return withInterceptors[[]*AllowList](ctx, alq, qr, alq.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (alq *AllowListQuery) AllX(ctx context.Context) []*AllowList {
+ nodes, err := alq.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of AllowList IDs.
+func (alq *AllowListQuery) IDs(ctx context.Context) (ids []int, err error) {
+ if alq.ctx.Unique == nil && alq.path != nil {
+ alq.Unique(true)
+ }
+ ctx = setContextOp(ctx, alq.ctx, "IDs")
+ if err = alq.Select(allowlist.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (alq *AllowListQuery) IDsX(ctx context.Context) []int {
+ ids, err := alq.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (alq *AllowListQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, alq.ctx, "Count")
+ if err := alq.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, alq, querierCount[*AllowListQuery](), alq.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (alq *AllowListQuery) CountX(ctx context.Context) int {
+ count, err := alq.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (alq *AllowListQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, alq.ctx, "Exist")
+ switch _, err := alq.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (alq *AllowListQuery) ExistX(ctx context.Context) bool {
+ exist, err := alq.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the AllowListQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (alq *AllowListQuery) Clone() *AllowListQuery {
+ if alq == nil {
+ return nil
+ }
+ return &AllowListQuery{
+ config: alq.config,
+ ctx: alq.ctx.Clone(),
+ order: append([]allowlist.OrderOption{}, alq.order...),
+ inters: append([]Interceptor{}, alq.inters...),
+ predicates: append([]predicate.AllowList{}, alq.predicates...),
+ withAllowlistItems: alq.withAllowlistItems.Clone(),
+ // clone intermediate query.
+ sql: alq.sql.Clone(),
+ path: alq.path,
+ }
+}
+
+// WithAllowlistItems tells the query-builder to eager-load the nodes that are connected to
+// the "allowlist_items" edge. The optional arguments are used to configure the query builder of the edge.
+func (alq *AllowListQuery) WithAllowlistItems(opts ...func(*AllowListItemQuery)) *AllowListQuery {
+ query := (&AllowListItemClient{config: alq.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ alq.withAllowlistItems = query
+ return alq
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.AllowList.Query().
+// GroupBy(allowlist.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (alq *AllowListQuery) GroupBy(field string, fields ...string) *AllowListGroupBy {
+ alq.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &AllowListGroupBy{build: alq}
+ grbuild.flds = &alq.ctx.Fields
+ grbuild.label = allowlist.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.AllowList.Query().
+// Select(allowlist.FieldCreatedAt).
+// Scan(ctx, &v)
+func (alq *AllowListQuery) Select(fields ...string) *AllowListSelect {
+ alq.ctx.Fields = append(alq.ctx.Fields, fields...)
+ sbuild := &AllowListSelect{AllowListQuery: alq}
+ sbuild.label = allowlist.Label
+ sbuild.flds, sbuild.scan = &alq.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a AllowListSelect configured with the given aggregations.
+func (alq *AllowListQuery) Aggregate(fns ...AggregateFunc) *AllowListSelect {
+ return alq.Select().Aggregate(fns...)
+}
+
+func (alq *AllowListQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range alq.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, alq); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range alq.ctx.Fields {
+ if !allowlist.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if alq.path != nil {
+ prev, err := alq.path(ctx)
+ if err != nil {
+ return err
+ }
+ alq.sql = prev
+ }
+ return nil
+}
+
+func (alq *AllowListQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AllowList, error) {
+ var (
+ nodes = []*AllowList{}
+ _spec = alq.querySpec()
+ loadedTypes = [1]bool{
+ alq.withAllowlistItems != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*AllowList).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &AllowList{config: alq.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, alq.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := alq.withAllowlistItems; query != nil {
+ if err := alq.loadAllowlistItems(ctx, query, nodes,
+ func(n *AllowList) { n.Edges.AllowlistItems = []*AllowListItem{} },
+ func(n *AllowList, e *AllowListItem) { n.Edges.AllowlistItems = append(n.Edges.AllowlistItems, e) }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (alq *AllowListQuery) loadAllowlistItems(ctx context.Context, query *AllowListItemQuery, nodes []*AllowList, init func(*AllowList), assign func(*AllowList, *AllowListItem)) error {
+ edgeIDs := make([]driver.Value, len(nodes))
+ byID := make(map[int]*AllowList)
+ nids := make(map[int]map[*AllowList]struct{})
+ for i, node := range nodes {
+ edgeIDs[i] = node.ID
+ byID[node.ID] = node
+ if init != nil {
+ init(node)
+ }
+ }
+ query.Where(func(s *sql.Selector) {
+ joinT := sql.Table(allowlist.AllowlistItemsTable)
+ s.Join(joinT).On(s.C(allowlistitem.FieldID), joinT.C(allowlist.AllowlistItemsPrimaryKey[1]))
+ s.Where(sql.InValues(joinT.C(allowlist.AllowlistItemsPrimaryKey[0]), edgeIDs...))
+ columns := s.SelectedColumns()
+ s.Select(joinT.C(allowlist.AllowlistItemsPrimaryKey[0]))
+ s.AppendSelect(columns...)
+ s.SetDistinct(false)
+ })
+ if err := query.prepareQuery(ctx); err != nil {
+ return err
+ }
+ qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
+ return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
+ assign := spec.Assign
+ values := spec.ScanValues
+ spec.ScanValues = func(columns []string) ([]any, error) {
+ values, err := values(columns[1:])
+ if err != nil {
+ return nil, err
+ }
+ return append([]any{new(sql.NullInt64)}, values...), nil
+ }
+ spec.Assign = func(columns []string, values []any) error {
+ outValue := int(values[0].(*sql.NullInt64).Int64)
+ inValue := int(values[1].(*sql.NullInt64).Int64)
+ if nids[inValue] == nil {
+ nids[inValue] = map[*AllowList]struct{}{byID[outValue]: {}}
+ return assign(columns[1:], values[1:])
+ }
+ nids[inValue][byID[outValue]] = struct{}{}
+ return nil
+ }
+ })
+ })
+ neighbors, err := withInterceptors[[]*AllowListItem](ctx, query, qr, query.inters)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected "allowlist_items" node returned %v`, n.ID)
+ }
+ for kn := range nodes {
+ assign(kn, n)
+ }
+ }
+ return nil
+}
+
+func (alq *AllowListQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := alq.querySpec()
+ _spec.Node.Columns = alq.ctx.Fields
+ if len(alq.ctx.Fields) > 0 {
+ _spec.Unique = alq.ctx.Unique != nil && *alq.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, alq.driver, _spec)
+}
+
+func (alq *AllowListQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(allowlist.Table, allowlist.Columns, sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt))
+ _spec.From = alq.sql
+ if unique := alq.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if alq.path != nil {
+ _spec.Unique = true
+ }
+ if fields := alq.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, allowlist.FieldID)
+ for i := range fields {
+ if fields[i] != allowlist.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ }
+ if ps := alq.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := alq.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := alq.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := alq.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (alq *AllowListQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(alq.driver.Dialect())
+ t1 := builder.Table(allowlist.Table)
+ columns := alq.ctx.Fields
+ if len(columns) == 0 {
+ columns = allowlist.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if alq.sql != nil {
+ selector = alq.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if alq.ctx.Unique != nil && *alq.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, p := range alq.predicates {
+ p(selector)
+ }
+ for _, p := range alq.order {
+ p(selector)
+ }
+ if offset := alq.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := alq.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// AllowListGroupBy is the group-by builder for AllowList entities.
+type AllowListGroupBy struct {
+ selector
+ build *AllowListQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (algb *AllowListGroupBy) Aggregate(fns ...AggregateFunc) *AllowListGroupBy {
+ algb.fns = append(algb.fns, fns...)
+ return algb
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (algb *AllowListGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, algb.build.ctx, "GroupBy")
+ if err := algb.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AllowListQuery, *AllowListGroupBy](ctx, algb.build, algb, algb.build.inters, v)
+}
+
+func (algb *AllowListGroupBy) sqlScan(ctx context.Context, root *AllowListQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(algb.fns))
+ for _, fn := range algb.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*algb.flds)+len(algb.fns))
+ for _, f := range *algb.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*algb.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := algb.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// AllowListSelect is the builder for selecting fields of AllowList entities.
+type AllowListSelect struct {
+ *AllowListQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (als *AllowListSelect) Aggregate(fns ...AggregateFunc) *AllowListSelect {
+ als.fns = append(als.fns, fns...)
+ return als
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (als *AllowListSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, als.ctx, "Select")
+ if err := als.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AllowListQuery, *AllowListSelect](ctx, als.AllowListQuery, als, als.inters, v)
+}
+
+func (als *AllowListSelect) sqlScan(ctx context.Context, root *AllowListQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(als.fns))
+ for _, fn := range als.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*als.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := als.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/pkg/database/ent/allowlist_update.go b/pkg/database/ent/allowlist_update.go
new file mode 100644
index 00000000000..b2a6bc65c9d
--- /dev/null
+++ b/pkg/database/ent/allowlist_update.go
@@ -0,0 +1,421 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate"
+)
+
+// AllowListUpdate is the builder for updating AllowList entities.
+type AllowListUpdate struct {
+ config
+ hooks []Hook
+ mutation *AllowListMutation
+}
+
+// Where appends a list predicates to the AllowListUpdate builder.
+func (alu *AllowListUpdate) Where(ps ...predicate.AllowList) *AllowListUpdate {
+ alu.mutation.Where(ps...)
+ return alu
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (alu *AllowListUpdate) SetUpdatedAt(t time.Time) *AllowListUpdate {
+ alu.mutation.SetUpdatedAt(t)
+ return alu
+}
+
+// SetFromConsole sets the "from_console" field.
+func (alu *AllowListUpdate) SetFromConsole(b bool) *AllowListUpdate {
+ alu.mutation.SetFromConsole(b)
+ return alu
+}
+
+// SetNillableFromConsole sets the "from_console" field if the given value is not nil.
+func (alu *AllowListUpdate) SetNillableFromConsole(b *bool) *AllowListUpdate {
+ if b != nil {
+ alu.SetFromConsole(*b)
+ }
+ return alu
+}
+
+// AddAllowlistItemIDs adds the "allowlist_items" edge to the AllowListItem entity by IDs.
+func (alu *AllowListUpdate) AddAllowlistItemIDs(ids ...int) *AllowListUpdate {
+ alu.mutation.AddAllowlistItemIDs(ids...)
+ return alu
+}
+
+// AddAllowlistItems adds the "allowlist_items" edges to the AllowListItem entity.
+func (alu *AllowListUpdate) AddAllowlistItems(a ...*AllowListItem) *AllowListUpdate {
+ ids := make([]int, len(a))
+ for i := range a {
+ ids[i] = a[i].ID
+ }
+ return alu.AddAllowlistItemIDs(ids...)
+}
+
+// Mutation returns the AllowListMutation object of the builder.
+func (alu *AllowListUpdate) Mutation() *AllowListMutation {
+ return alu.mutation
+}
+
+// ClearAllowlistItems clears all "allowlist_items" edges to the AllowListItem entity.
+func (alu *AllowListUpdate) ClearAllowlistItems() *AllowListUpdate {
+ alu.mutation.ClearAllowlistItems()
+ return alu
+}
+
+// RemoveAllowlistItemIDs removes the "allowlist_items" edge to AllowListItem entities by IDs.
+func (alu *AllowListUpdate) RemoveAllowlistItemIDs(ids ...int) *AllowListUpdate {
+ alu.mutation.RemoveAllowlistItemIDs(ids...)
+ return alu
+}
+
+// RemoveAllowlistItems removes "allowlist_items" edges to AllowListItem entities.
+func (alu *AllowListUpdate) RemoveAllowlistItems(a ...*AllowListItem) *AllowListUpdate {
+ ids := make([]int, len(a))
+ for i := range a {
+ ids[i] = a[i].ID
+ }
+ return alu.RemoveAllowlistItemIDs(ids...)
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (alu *AllowListUpdate) Save(ctx context.Context) (int, error) {
+ alu.defaults()
+ return withHooks(ctx, alu.sqlSave, alu.mutation, alu.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (alu *AllowListUpdate) SaveX(ctx context.Context) int {
+ affected, err := alu.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (alu *AllowListUpdate) Exec(ctx context.Context) error {
+ _, err := alu.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (alu *AllowListUpdate) ExecX(ctx context.Context) {
+ if err := alu.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (alu *AllowListUpdate) defaults() {
+ if _, ok := alu.mutation.UpdatedAt(); !ok {
+ v := allowlist.UpdateDefaultUpdatedAt()
+ alu.mutation.SetUpdatedAt(v)
+ }
+}
+
+func (alu *AllowListUpdate) sqlSave(ctx context.Context) (n int, err error) {
+ _spec := sqlgraph.NewUpdateSpec(allowlist.Table, allowlist.Columns, sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt))
+ if ps := alu.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := alu.mutation.UpdatedAt(); ok {
+ _spec.SetField(allowlist.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := alu.mutation.FromConsole(); ok {
+ _spec.SetField(allowlist.FieldFromConsole, field.TypeBool, value)
+ }
+ if alu.mutation.DescriptionCleared() {
+ _spec.ClearField(allowlist.FieldDescription, field.TypeString)
+ }
+ if alu.mutation.AllowlistIDCleared() {
+ _spec.ClearField(allowlist.FieldAllowlistID, field.TypeString)
+ }
+ if alu.mutation.AllowlistItemsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: false,
+ Table: allowlist.AllowlistItemsTable,
+ Columns: allowlist.AllowlistItemsPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := alu.mutation.RemovedAllowlistItemsIDs(); len(nodes) > 0 && !alu.mutation.AllowlistItemsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: false,
+ Table: allowlist.AllowlistItemsTable,
+ Columns: allowlist.AllowlistItemsPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := alu.mutation.AllowlistItemsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: false,
+ Table: allowlist.AllowlistItemsTable,
+ Columns: allowlist.AllowlistItemsPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if n, err = sqlgraph.UpdateNodes(ctx, alu.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{allowlist.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ alu.mutation.done = true
+ return n, nil
+}
+
+// AllowListUpdateOne is the builder for updating a single AllowList entity.
+type AllowListUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *AllowListMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (aluo *AllowListUpdateOne) SetUpdatedAt(t time.Time) *AllowListUpdateOne {
+ aluo.mutation.SetUpdatedAt(t)
+ return aluo
+}
+
+// SetFromConsole sets the "from_console" field.
+func (aluo *AllowListUpdateOne) SetFromConsole(b bool) *AllowListUpdateOne {
+ aluo.mutation.SetFromConsole(b)
+ return aluo
+}
+
+// SetNillableFromConsole sets the "from_console" field if the given value is not nil.
+func (aluo *AllowListUpdateOne) SetNillableFromConsole(b *bool) *AllowListUpdateOne {
+ if b != nil {
+ aluo.SetFromConsole(*b)
+ }
+ return aluo
+}
+
+// AddAllowlistItemIDs adds the "allowlist_items" edge to the AllowListItem entity by IDs.
+func (aluo *AllowListUpdateOne) AddAllowlistItemIDs(ids ...int) *AllowListUpdateOne {
+ aluo.mutation.AddAllowlistItemIDs(ids...)
+ return aluo
+}
+
+// AddAllowlistItems adds the "allowlist_items" edges to the AllowListItem entity.
+func (aluo *AllowListUpdateOne) AddAllowlistItems(a ...*AllowListItem) *AllowListUpdateOne {
+ ids := make([]int, len(a))
+ for i := range a {
+ ids[i] = a[i].ID
+ }
+ return aluo.AddAllowlistItemIDs(ids...)
+}
+
+// Mutation returns the AllowListMutation object of the builder.
+func (aluo *AllowListUpdateOne) Mutation() *AllowListMutation {
+ return aluo.mutation
+}
+
+// ClearAllowlistItems clears all "allowlist_items" edges to the AllowListItem entity.
+func (aluo *AllowListUpdateOne) ClearAllowlistItems() *AllowListUpdateOne {
+ aluo.mutation.ClearAllowlistItems()
+ return aluo
+}
+
+// RemoveAllowlistItemIDs removes the "allowlist_items" edge to AllowListItem entities by IDs.
+func (aluo *AllowListUpdateOne) RemoveAllowlistItemIDs(ids ...int) *AllowListUpdateOne {
+ aluo.mutation.RemoveAllowlistItemIDs(ids...)
+ return aluo
+}
+
+// RemoveAllowlistItems removes "allowlist_items" edges to AllowListItem entities.
+func (aluo *AllowListUpdateOne) RemoveAllowlistItems(a ...*AllowListItem) *AllowListUpdateOne {
+ ids := make([]int, len(a))
+ for i := range a {
+ ids[i] = a[i].ID
+ }
+ return aluo.RemoveAllowlistItemIDs(ids...)
+}
+
+// Where appends a list predicates to the AllowListUpdate builder.
+func (aluo *AllowListUpdateOne) Where(ps ...predicate.AllowList) *AllowListUpdateOne {
+ aluo.mutation.Where(ps...)
+ return aluo
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (aluo *AllowListUpdateOne) Select(field string, fields ...string) *AllowListUpdateOne {
+ aluo.fields = append([]string{field}, fields...)
+ return aluo
+}
+
+// Save executes the query and returns the updated AllowList entity.
+func (aluo *AllowListUpdateOne) Save(ctx context.Context) (*AllowList, error) {
+ aluo.defaults()
+ return withHooks(ctx, aluo.sqlSave, aluo.mutation, aluo.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (aluo *AllowListUpdateOne) SaveX(ctx context.Context) *AllowList {
+ node, err := aluo.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (aluo *AllowListUpdateOne) Exec(ctx context.Context) error {
+ _, err := aluo.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (aluo *AllowListUpdateOne) ExecX(ctx context.Context) {
+ if err := aluo.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (aluo *AllowListUpdateOne) defaults() {
+ if _, ok := aluo.mutation.UpdatedAt(); !ok {
+ v := allowlist.UpdateDefaultUpdatedAt()
+ aluo.mutation.SetUpdatedAt(v)
+ }
+}
+
+func (aluo *AllowListUpdateOne) sqlSave(ctx context.Context) (_node *AllowList, err error) {
+ _spec := sqlgraph.NewUpdateSpec(allowlist.Table, allowlist.Columns, sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt))
+ id, ok := aluo.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AllowList.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := aluo.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, allowlist.FieldID)
+ for _, f := range fields {
+ if !allowlist.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != allowlist.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := aluo.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := aluo.mutation.UpdatedAt(); ok {
+ _spec.SetField(allowlist.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := aluo.mutation.FromConsole(); ok {
+ _spec.SetField(allowlist.FieldFromConsole, field.TypeBool, value)
+ }
+ if aluo.mutation.DescriptionCleared() {
+ _spec.ClearField(allowlist.FieldDescription, field.TypeString)
+ }
+ if aluo.mutation.AllowlistIDCleared() {
+ _spec.ClearField(allowlist.FieldAllowlistID, field.TypeString)
+ }
+ if aluo.mutation.AllowlistItemsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: false,
+ Table: allowlist.AllowlistItemsTable,
+ Columns: allowlist.AllowlistItemsPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := aluo.mutation.RemovedAllowlistItemsIDs(); len(nodes) > 0 && !aluo.mutation.AllowlistItemsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: false,
+ Table: allowlist.AllowlistItemsTable,
+ Columns: allowlist.AllowlistItemsPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := aluo.mutation.AllowlistItemsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: false,
+ Table: allowlist.AllowlistItemsTable,
+ Columns: allowlist.AllowlistItemsPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &AllowList{config: aluo.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, aluo.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{allowlist.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ aluo.mutation.done = true
+ return _node, nil
+}
diff --git a/pkg/database/ent/allowlistitem.go b/pkg/database/ent/allowlistitem.go
new file mode 100644
index 00000000000..2c0d997b1d7
--- /dev/null
+++ b/pkg/database/ent/allowlistitem.go
@@ -0,0 +1,231 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem"
+)
+
+// AllowListItem is the model entity for the AllowListItem schema.
+type AllowListItem struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // ExpiresAt holds the value of the "expires_at" field.
+ ExpiresAt time.Time `json:"expires_at,omitempty"`
+ // Comment holds the value of the "comment" field.
+ Comment string `json:"comment,omitempty"`
+ // Value holds the value of the "value" field.
+ Value string `json:"value,omitempty"`
+ // StartIP holds the value of the "start_ip" field.
+ StartIP int64 `json:"start_ip,omitempty"`
+ // EndIP holds the value of the "end_ip" field.
+ EndIP int64 `json:"end_ip,omitempty"`
+ // StartSuffix holds the value of the "start_suffix" field.
+ StartSuffix int64 `json:"start_suffix,omitempty"`
+ // EndSuffix holds the value of the "end_suffix" field.
+ EndSuffix int64 `json:"end_suffix,omitempty"`
+ // IPSize holds the value of the "ip_size" field.
+ IPSize int64 `json:"ip_size,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the AllowListItemQuery when eager-loading is set.
+ Edges AllowListItemEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// AllowListItemEdges holds the relations/edges for other nodes in the graph.
+type AllowListItemEdges struct {
+ // Allowlist holds the value of the allowlist edge.
+ Allowlist []*AllowList `json:"allowlist,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// AllowlistOrErr returns the Allowlist value or an error if the edge
+// was not loaded in eager-loading.
+func (e AllowListItemEdges) AllowlistOrErr() ([]*AllowList, error) {
+ if e.loadedTypes[0] {
+ return e.Allowlist, nil
+ }
+ return nil, &NotLoadedError{edge: "allowlist"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*AllowListItem) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case allowlistitem.FieldID, allowlistitem.FieldStartIP, allowlistitem.FieldEndIP, allowlistitem.FieldStartSuffix, allowlistitem.FieldEndSuffix, allowlistitem.FieldIPSize:
+ values[i] = new(sql.NullInt64)
+ case allowlistitem.FieldComment, allowlistitem.FieldValue:
+ values[i] = new(sql.NullString)
+ case allowlistitem.FieldCreatedAt, allowlistitem.FieldUpdatedAt, allowlistitem.FieldExpiresAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the AllowListItem fields.
+func (ali *AllowListItem) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case allowlistitem.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ ali.ID = int(value.Int64)
+ case allowlistitem.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ ali.CreatedAt = value.Time
+ }
+ case allowlistitem.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ ali.UpdatedAt = value.Time
+ }
+ case allowlistitem.FieldExpiresAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field expires_at", values[i])
+ } else if value.Valid {
+ ali.ExpiresAt = value.Time
+ }
+ case allowlistitem.FieldComment:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field comment", values[i])
+ } else if value.Valid {
+ ali.Comment = value.String
+ }
+ case allowlistitem.FieldValue:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field value", values[i])
+ } else if value.Valid {
+ ali.Value = value.String
+ }
+ case allowlistitem.FieldStartIP:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field start_ip", values[i])
+ } else if value.Valid {
+ ali.StartIP = value.Int64
+ }
+ case allowlistitem.FieldEndIP:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field end_ip", values[i])
+ } else if value.Valid {
+ ali.EndIP = value.Int64
+ }
+ case allowlistitem.FieldStartSuffix:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field start_suffix", values[i])
+ } else if value.Valid {
+ ali.StartSuffix = value.Int64
+ }
+ case allowlistitem.FieldEndSuffix:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field end_suffix", values[i])
+ } else if value.Valid {
+ ali.EndSuffix = value.Int64
+ }
+ case allowlistitem.FieldIPSize:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field ip_size", values[i])
+ } else if value.Valid {
+ ali.IPSize = value.Int64
+ }
+ default:
+ ali.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// GetValue returns the ent.Value that was dynamically selected and assigned to the AllowListItem.
+// This includes values selected through modifiers, order, etc.
+func (ali *AllowListItem) GetValue(name string) (ent.Value, error) {
+ return ali.selectValues.Get(name)
+}
+
+// QueryAllowlist queries the "allowlist" edge of the AllowListItem entity.
+func (ali *AllowListItem) QueryAllowlist() *AllowListQuery {
+ return NewAllowListItemClient(ali.config).QueryAllowlist(ali)
+}
+
+// Update returns a builder for updating this AllowListItem.
+// Note that you need to call AllowListItem.Unwrap() before calling this method if this AllowListItem
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (ali *AllowListItem) Update() *AllowListItemUpdateOne {
+ return NewAllowListItemClient(ali.config).UpdateOne(ali)
+}
+
+// Unwrap unwraps the AllowListItem entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (ali *AllowListItem) Unwrap() *AllowListItem {
+ _tx, ok := ali.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: AllowListItem is not a transactional entity")
+ }
+ ali.config.driver = _tx.drv
+ return ali
+}
+
+// String implements the fmt.Stringer.
+func (ali *AllowListItem) String() string {
+ var builder strings.Builder
+ builder.WriteString("AllowListItem(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", ali.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(ali.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(ali.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("expires_at=")
+ builder.WriteString(ali.ExpiresAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("comment=")
+ builder.WriteString(ali.Comment)
+ builder.WriteString(", ")
+ builder.WriteString("value=")
+ builder.WriteString(ali.Value)
+ builder.WriteString(", ")
+ builder.WriteString("start_ip=")
+ builder.WriteString(fmt.Sprintf("%v", ali.StartIP))
+ builder.WriteString(", ")
+ builder.WriteString("end_ip=")
+ builder.WriteString(fmt.Sprintf("%v", ali.EndIP))
+ builder.WriteString(", ")
+ builder.WriteString("start_suffix=")
+ builder.WriteString(fmt.Sprintf("%v", ali.StartSuffix))
+ builder.WriteString(", ")
+ builder.WriteString("end_suffix=")
+ builder.WriteString(fmt.Sprintf("%v", ali.EndSuffix))
+ builder.WriteString(", ")
+ builder.WriteString("ip_size=")
+ builder.WriteString(fmt.Sprintf("%v", ali.IPSize))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// AllowListItems is a parsable slice of AllowListItem.
+type AllowListItems []*AllowListItem
diff --git a/pkg/database/ent/allowlistitem/allowlistitem.go b/pkg/database/ent/allowlistitem/allowlistitem.go
new file mode 100644
index 00000000000..5474763eac3
--- /dev/null
+++ b/pkg/database/ent/allowlistitem/allowlistitem.go
@@ -0,0 +1,165 @@
+// Code generated by ent, DO NOT EDIT.
+
+package allowlistitem
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the allowlistitem type in the database.
+ Label = "allow_list_item"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldExpiresAt holds the string denoting the expires_at field in the database.
+ FieldExpiresAt = "expires_at"
+ // FieldComment holds the string denoting the comment field in the database.
+ FieldComment = "comment"
+ // FieldValue holds the string denoting the value field in the database.
+ FieldValue = "value"
+ // FieldStartIP holds the string denoting the start_ip field in the database.
+ FieldStartIP = "start_ip"
+ // FieldEndIP holds the string denoting the end_ip field in the database.
+ FieldEndIP = "end_ip"
+ // FieldStartSuffix holds the string denoting the start_suffix field in the database.
+ FieldStartSuffix = "start_suffix"
+ // FieldEndSuffix holds the string denoting the end_suffix field in the database.
+ FieldEndSuffix = "end_suffix"
+ // FieldIPSize holds the string denoting the ip_size field in the database.
+ FieldIPSize = "ip_size"
+ // EdgeAllowlist holds the string denoting the allowlist edge name in mutations.
+ EdgeAllowlist = "allowlist"
+ // Table holds the table name of the allowlistitem in the database.
+ Table = "allow_list_items"
+ // AllowlistTable is the table that holds the allowlist relation/edge. The primary key declared below.
+ AllowlistTable = "allow_list_allowlist_items"
+ // AllowlistInverseTable is the table name for the AllowList entity.
+ // It exists in this package in order to avoid circular dependency with the "allowlist" package.
+ AllowlistInverseTable = "allow_lists"
+)
+
+// Columns holds all SQL columns for allowlistitem fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldExpiresAt,
+ FieldComment,
+ FieldValue,
+ FieldStartIP,
+ FieldEndIP,
+ FieldStartSuffix,
+ FieldEndSuffix,
+ FieldIPSize,
+}
+
+var (
+ // AllowlistPrimaryKey and AllowlistColumn2 are the table columns denoting the
+ // primary key for the allowlist relation (M2M).
+ AllowlistPrimaryKey = []string{"allow_list_id", "allow_list_item_id"}
+)
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the AllowListItem queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByExpiresAt orders the results by the expires_at field.
+func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
+}
+
+// ByComment orders the results by the comment field.
+func ByComment(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldComment, opts...).ToFunc()
+}
+
+// ByValue orders the results by the value field.
+func ByValue(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldValue, opts...).ToFunc()
+}
+
+// ByStartIP orders the results by the start_ip field.
+func ByStartIP(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldStartIP, opts...).ToFunc()
+}
+
+// ByEndIP orders the results by the end_ip field.
+func ByEndIP(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldEndIP, opts...).ToFunc()
+}
+
+// ByStartSuffix orders the results by the start_suffix field.
+func ByStartSuffix(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldStartSuffix, opts...).ToFunc()
+}
+
+// ByEndSuffix orders the results by the end_suffix field.
+func ByEndSuffix(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldEndSuffix, opts...).ToFunc()
+}
+
+// ByIPSize orders the results by the ip_size field.
+func ByIPSize(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIPSize, opts...).ToFunc()
+}
+
+// ByAllowlistCount orders the results by allowlist count.
+func ByAllowlistCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newAllowlistStep(), opts...)
+ }
+}
+
+// ByAllowlist orders the results by allowlist terms.
+func ByAllowlist(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAllowlistStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+func newAllowlistStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AllowlistInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2M, true, AllowlistTable, AllowlistPrimaryKey...),
+ )
+}
diff --git a/pkg/database/ent/allowlistitem/where.go b/pkg/database/ent/allowlistitem/where.go
new file mode 100644
index 00000000000..32a10d77c22
--- /dev/null
+++ b/pkg/database/ent/allowlistitem/where.go
@@ -0,0 +1,664 @@
+// Code generated by ent, DO NOT EDIT.
+
+package allowlistitem
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
+func ExpiresAt(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// Comment applies equality check predicate on the "comment" field. It's identical to CommentEQ.
+func Comment(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldComment, v))
+}
+
+// Value applies equality check predicate on the "value" field. It's identical to ValueEQ.
+func Value(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldValue, v))
+}
+
+// StartIP applies equality check predicate on the "start_ip" field. It's identical to StartIPEQ.
+func StartIP(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldStartIP, v))
+}
+
+// EndIP applies equality check predicate on the "end_ip" field. It's identical to EndIPEQ.
+func EndIP(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldEndIP, v))
+}
+
+// StartSuffix applies equality check predicate on the "start_suffix" field. It's identical to StartSuffixEQ.
+func StartSuffix(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldStartSuffix, v))
+}
+
+// EndSuffix applies equality check predicate on the "end_suffix" field. It's identical to EndSuffixEQ.
+func EndSuffix(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldEndSuffix, v))
+}
+
+// IPSize applies equality check predicate on the "ip_size" field. It's identical to IPSizeEQ.
+func IPSize(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldIPSize, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
+func ExpiresAtEQ(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
+func ExpiresAtNEQ(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtIn applies the In predicate on the "expires_at" field.
+func ExpiresAtIn(vs ...time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
+func ExpiresAtNotIn(vs ...time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtGT applies the GT predicate on the "expires_at" field.
+func ExpiresAtGT(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGT(FieldExpiresAt, v))
+}
+
+// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
+func ExpiresAtGTE(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGTE(FieldExpiresAt, v))
+}
+
+// ExpiresAtLT applies the LT predicate on the "expires_at" field.
+func ExpiresAtLT(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLT(FieldExpiresAt, v))
+}
+
+// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
+func ExpiresAtLTE(v time.Time) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLTE(FieldExpiresAt, v))
+}
+
+// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field.
+func ExpiresAtIsNil() predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIsNull(FieldExpiresAt))
+}
+
+// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field.
+func ExpiresAtNotNil() predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotNull(FieldExpiresAt))
+}
+
+// CommentEQ applies the EQ predicate on the "comment" field.
+func CommentEQ(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldComment, v))
+}
+
+// CommentNEQ applies the NEQ predicate on the "comment" field.
+func CommentNEQ(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNEQ(FieldComment, v))
+}
+
+// CommentIn applies the In predicate on the "comment" field.
+func CommentIn(vs ...string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIn(FieldComment, vs...))
+}
+
+// CommentNotIn applies the NotIn predicate on the "comment" field.
+func CommentNotIn(vs ...string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotIn(FieldComment, vs...))
+}
+
+// CommentGT applies the GT predicate on the "comment" field.
+func CommentGT(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGT(FieldComment, v))
+}
+
+// CommentGTE applies the GTE predicate on the "comment" field.
+func CommentGTE(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGTE(FieldComment, v))
+}
+
+// CommentLT applies the LT predicate on the "comment" field.
+func CommentLT(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLT(FieldComment, v))
+}
+
+// CommentLTE applies the LTE predicate on the "comment" field.
+func CommentLTE(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLTE(FieldComment, v))
+}
+
+// CommentContains applies the Contains predicate on the "comment" field.
+func CommentContains(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldContains(FieldComment, v))
+}
+
+// CommentHasPrefix applies the HasPrefix predicate on the "comment" field.
+func CommentHasPrefix(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldHasPrefix(FieldComment, v))
+}
+
+// CommentHasSuffix applies the HasSuffix predicate on the "comment" field.
+func CommentHasSuffix(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldHasSuffix(FieldComment, v))
+}
+
+// CommentIsNil applies the IsNil predicate on the "comment" field.
+func CommentIsNil() predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIsNull(FieldComment))
+}
+
+// CommentNotNil applies the NotNil predicate on the "comment" field.
+func CommentNotNil() predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotNull(FieldComment))
+}
+
+// CommentEqualFold applies the EqualFold predicate on the "comment" field.
+func CommentEqualFold(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEqualFold(FieldComment, v))
+}
+
+// CommentContainsFold applies the ContainsFold predicate on the "comment" field.
+func CommentContainsFold(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldContainsFold(FieldComment, v))
+}
+
+// ValueEQ applies the EQ predicate on the "value" field.
+func ValueEQ(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldValue, v))
+}
+
+// ValueNEQ applies the NEQ predicate on the "value" field.
+func ValueNEQ(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNEQ(FieldValue, v))
+}
+
+// ValueIn applies the In predicate on the "value" field.
+func ValueIn(vs ...string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIn(FieldValue, vs...))
+}
+
+// ValueNotIn applies the NotIn predicate on the "value" field.
+func ValueNotIn(vs ...string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotIn(FieldValue, vs...))
+}
+
+// ValueGT applies the GT predicate on the "value" field.
+func ValueGT(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGT(FieldValue, v))
+}
+
+// ValueGTE applies the GTE predicate on the "value" field.
+func ValueGTE(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGTE(FieldValue, v))
+}
+
+// ValueLT applies the LT predicate on the "value" field.
+func ValueLT(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLT(FieldValue, v))
+}
+
+// ValueLTE applies the LTE predicate on the "value" field.
+func ValueLTE(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLTE(FieldValue, v))
+}
+
+// ValueContains applies the Contains predicate on the "value" field.
+func ValueContains(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldContains(FieldValue, v))
+}
+
+// ValueHasPrefix applies the HasPrefix predicate on the "value" field.
+func ValueHasPrefix(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldHasPrefix(FieldValue, v))
+}
+
+// ValueHasSuffix applies the HasSuffix predicate on the "value" field.
+func ValueHasSuffix(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldHasSuffix(FieldValue, v))
+}
+
+// ValueEqualFold applies the EqualFold predicate on the "value" field.
+func ValueEqualFold(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEqualFold(FieldValue, v))
+}
+
+// ValueContainsFold applies the ContainsFold predicate on the "value" field.
+func ValueContainsFold(v string) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldContainsFold(FieldValue, v))
+}
+
+// StartIPEQ applies the EQ predicate on the "start_ip" field.
+func StartIPEQ(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldStartIP, v))
+}
+
+// StartIPNEQ applies the NEQ predicate on the "start_ip" field.
+func StartIPNEQ(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNEQ(FieldStartIP, v))
+}
+
+// StartIPIn applies the In predicate on the "start_ip" field.
+func StartIPIn(vs ...int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIn(FieldStartIP, vs...))
+}
+
+// StartIPNotIn applies the NotIn predicate on the "start_ip" field.
+func StartIPNotIn(vs ...int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotIn(FieldStartIP, vs...))
+}
+
+// StartIPGT applies the GT predicate on the "start_ip" field.
+func StartIPGT(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGT(FieldStartIP, v))
+}
+
+// StartIPGTE applies the GTE predicate on the "start_ip" field.
+func StartIPGTE(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGTE(FieldStartIP, v))
+}
+
+// StartIPLT applies the LT predicate on the "start_ip" field.
+func StartIPLT(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLT(FieldStartIP, v))
+}
+
+// StartIPLTE applies the LTE predicate on the "start_ip" field.
+func StartIPLTE(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLTE(FieldStartIP, v))
+}
+
+// StartIPIsNil applies the IsNil predicate on the "start_ip" field.
+func StartIPIsNil() predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIsNull(FieldStartIP))
+}
+
+// StartIPNotNil applies the NotNil predicate on the "start_ip" field.
+func StartIPNotNil() predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotNull(FieldStartIP))
+}
+
+// EndIPEQ applies the EQ predicate on the "end_ip" field.
+func EndIPEQ(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldEndIP, v))
+}
+
+// EndIPNEQ applies the NEQ predicate on the "end_ip" field.
+func EndIPNEQ(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNEQ(FieldEndIP, v))
+}
+
+// EndIPIn applies the In predicate on the "end_ip" field.
+func EndIPIn(vs ...int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIn(FieldEndIP, vs...))
+}
+
+// EndIPNotIn applies the NotIn predicate on the "end_ip" field.
+func EndIPNotIn(vs ...int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotIn(FieldEndIP, vs...))
+}
+
+// EndIPGT applies the GT predicate on the "end_ip" field.
+func EndIPGT(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGT(FieldEndIP, v))
+}
+
+// EndIPGTE applies the GTE predicate on the "end_ip" field.
+func EndIPGTE(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGTE(FieldEndIP, v))
+}
+
+// EndIPLT applies the LT predicate on the "end_ip" field.
+func EndIPLT(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLT(FieldEndIP, v))
+}
+
+// EndIPLTE applies the LTE predicate on the "end_ip" field.
+func EndIPLTE(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLTE(FieldEndIP, v))
+}
+
+// EndIPIsNil applies the IsNil predicate on the "end_ip" field.
+func EndIPIsNil() predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIsNull(FieldEndIP))
+}
+
+// EndIPNotNil applies the NotNil predicate on the "end_ip" field.
+func EndIPNotNil() predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotNull(FieldEndIP))
+}
+
+// StartSuffixEQ applies the EQ predicate on the "start_suffix" field.
+func StartSuffixEQ(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldStartSuffix, v))
+}
+
+// StartSuffixNEQ applies the NEQ predicate on the "start_suffix" field.
+func StartSuffixNEQ(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNEQ(FieldStartSuffix, v))
+}
+
+// StartSuffixIn applies the In predicate on the "start_suffix" field.
+func StartSuffixIn(vs ...int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIn(FieldStartSuffix, vs...))
+}
+
+// StartSuffixNotIn applies the NotIn predicate on the "start_suffix" field.
+func StartSuffixNotIn(vs ...int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotIn(FieldStartSuffix, vs...))
+}
+
+// StartSuffixGT applies the GT predicate on the "start_suffix" field.
+func StartSuffixGT(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGT(FieldStartSuffix, v))
+}
+
+// StartSuffixGTE applies the GTE predicate on the "start_suffix" field.
+func StartSuffixGTE(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGTE(FieldStartSuffix, v))
+}
+
+// StartSuffixLT applies the LT predicate on the "start_suffix" field.
+func StartSuffixLT(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLT(FieldStartSuffix, v))
+}
+
+// StartSuffixLTE applies the LTE predicate on the "start_suffix" field.
+func StartSuffixLTE(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLTE(FieldStartSuffix, v))
+}
+
+// StartSuffixIsNil applies the IsNil predicate on the "start_suffix" field.
+func StartSuffixIsNil() predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIsNull(FieldStartSuffix))
+}
+
+// StartSuffixNotNil applies the NotNil predicate on the "start_suffix" field.
+func StartSuffixNotNil() predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotNull(FieldStartSuffix))
+}
+
+// EndSuffixEQ applies the EQ predicate on the "end_suffix" field.
+func EndSuffixEQ(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldEndSuffix, v))
+}
+
+// EndSuffixNEQ applies the NEQ predicate on the "end_suffix" field.
+func EndSuffixNEQ(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNEQ(FieldEndSuffix, v))
+}
+
+// EndSuffixIn applies the In predicate on the "end_suffix" field.
+func EndSuffixIn(vs ...int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIn(FieldEndSuffix, vs...))
+}
+
+// EndSuffixNotIn applies the NotIn predicate on the "end_suffix" field.
+func EndSuffixNotIn(vs ...int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotIn(FieldEndSuffix, vs...))
+}
+
+// EndSuffixGT applies the GT predicate on the "end_suffix" field.
+func EndSuffixGT(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGT(FieldEndSuffix, v))
+}
+
+// EndSuffixGTE applies the GTE predicate on the "end_suffix" field.
+func EndSuffixGTE(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGTE(FieldEndSuffix, v))
+}
+
+// EndSuffixLT applies the LT predicate on the "end_suffix" field.
+func EndSuffixLT(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLT(FieldEndSuffix, v))
+}
+
+// EndSuffixLTE applies the LTE predicate on the "end_suffix" field.
+func EndSuffixLTE(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLTE(FieldEndSuffix, v))
+}
+
+// EndSuffixIsNil applies the IsNil predicate on the "end_suffix" field.
+func EndSuffixIsNil() predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIsNull(FieldEndSuffix))
+}
+
+// EndSuffixNotNil applies the NotNil predicate on the "end_suffix" field.
+func EndSuffixNotNil() predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotNull(FieldEndSuffix))
+}
+
+// IPSizeEQ applies the EQ predicate on the "ip_size" field.
+func IPSizeEQ(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldEQ(FieldIPSize, v))
+}
+
+// IPSizeNEQ applies the NEQ predicate on the "ip_size" field.
+func IPSizeNEQ(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNEQ(FieldIPSize, v))
+}
+
+// IPSizeIn applies the In predicate on the "ip_size" field.
+func IPSizeIn(vs ...int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIn(FieldIPSize, vs...))
+}
+
+// IPSizeNotIn applies the NotIn predicate on the "ip_size" field.
+func IPSizeNotIn(vs ...int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotIn(FieldIPSize, vs...))
+}
+
+// IPSizeGT applies the GT predicate on the "ip_size" field.
+func IPSizeGT(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGT(FieldIPSize, v))
+}
+
+// IPSizeGTE applies the GTE predicate on the "ip_size" field.
+func IPSizeGTE(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldGTE(FieldIPSize, v))
+}
+
+// IPSizeLT applies the LT predicate on the "ip_size" field.
+func IPSizeLT(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLT(FieldIPSize, v))
+}
+
+// IPSizeLTE applies the LTE predicate on the "ip_size" field.
+func IPSizeLTE(v int64) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldLTE(FieldIPSize, v))
+}
+
+// IPSizeIsNil applies the IsNil predicate on the "ip_size" field.
+func IPSizeIsNil() predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldIsNull(FieldIPSize))
+}
+
+// IPSizeNotNil applies the NotNil predicate on the "ip_size" field.
+func IPSizeNotNil() predicate.AllowListItem {
+ return predicate.AllowListItem(sql.FieldNotNull(FieldIPSize))
+}
+
+// HasAllowlist applies the HasEdge predicate on the "allowlist" edge.
+func HasAllowlist() predicate.AllowListItem {
+ return predicate.AllowListItem(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2M, true, AllowlistTable, AllowlistPrimaryKey...),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAllowlistWith applies the HasEdge predicate on the "allowlist" edge with a given conditions (other predicates).
+func HasAllowlistWith(preds ...predicate.AllowList) predicate.AllowListItem {
+ return predicate.AllowListItem(func(s *sql.Selector) {
+ step := newAllowlistStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.AllowListItem) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.AllowListItem) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.AllowListItem) predicate.AllowListItem {
+ return predicate.AllowListItem(sql.NotPredicates(p))
+}
diff --git a/pkg/database/ent/allowlistitem_create.go b/pkg/database/ent/allowlistitem_create.go
new file mode 100644
index 00000000000..502cec11db7
--- /dev/null
+++ b/pkg/database/ent/allowlistitem_create.go
@@ -0,0 +1,398 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem"
+)
+
+// AllowListItemCreate is the builder for creating a AllowListItem entity.
+type AllowListItemCreate struct {
+ config
+ mutation *AllowListItemMutation
+ hooks []Hook
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (alic *AllowListItemCreate) SetCreatedAt(t time.Time) *AllowListItemCreate {
+ alic.mutation.SetCreatedAt(t)
+ return alic
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (alic *AllowListItemCreate) SetNillableCreatedAt(t *time.Time) *AllowListItemCreate {
+ if t != nil {
+ alic.SetCreatedAt(*t)
+ }
+ return alic
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (alic *AllowListItemCreate) SetUpdatedAt(t time.Time) *AllowListItemCreate {
+ alic.mutation.SetUpdatedAt(t)
+ return alic
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (alic *AllowListItemCreate) SetNillableUpdatedAt(t *time.Time) *AllowListItemCreate {
+ if t != nil {
+ alic.SetUpdatedAt(*t)
+ }
+ return alic
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (alic *AllowListItemCreate) SetExpiresAt(t time.Time) *AllowListItemCreate {
+ alic.mutation.SetExpiresAt(t)
+ return alic
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (alic *AllowListItemCreate) SetNillableExpiresAt(t *time.Time) *AllowListItemCreate {
+ if t != nil {
+ alic.SetExpiresAt(*t)
+ }
+ return alic
+}
+
+// SetComment sets the "comment" field.
+func (alic *AllowListItemCreate) SetComment(s string) *AllowListItemCreate {
+ alic.mutation.SetComment(s)
+ return alic
+}
+
+// SetNillableComment sets the "comment" field if the given value is not nil.
+func (alic *AllowListItemCreate) SetNillableComment(s *string) *AllowListItemCreate {
+ if s != nil {
+ alic.SetComment(*s)
+ }
+ return alic
+}
+
+// SetValue sets the "value" field.
+func (alic *AllowListItemCreate) SetValue(s string) *AllowListItemCreate {
+ alic.mutation.SetValue(s)
+ return alic
+}
+
+// SetStartIP sets the "start_ip" field.
+func (alic *AllowListItemCreate) SetStartIP(i int64) *AllowListItemCreate {
+ alic.mutation.SetStartIP(i)
+ return alic
+}
+
+// SetNillableStartIP sets the "start_ip" field if the given value is not nil.
+func (alic *AllowListItemCreate) SetNillableStartIP(i *int64) *AllowListItemCreate {
+ if i != nil {
+ alic.SetStartIP(*i)
+ }
+ return alic
+}
+
+// SetEndIP sets the "end_ip" field.
+func (alic *AllowListItemCreate) SetEndIP(i int64) *AllowListItemCreate {
+ alic.mutation.SetEndIP(i)
+ return alic
+}
+
+// SetNillableEndIP sets the "end_ip" field if the given value is not nil.
+func (alic *AllowListItemCreate) SetNillableEndIP(i *int64) *AllowListItemCreate {
+ if i != nil {
+ alic.SetEndIP(*i)
+ }
+ return alic
+}
+
+// SetStartSuffix sets the "start_suffix" field.
+func (alic *AllowListItemCreate) SetStartSuffix(i int64) *AllowListItemCreate {
+ alic.mutation.SetStartSuffix(i)
+ return alic
+}
+
+// SetNillableStartSuffix sets the "start_suffix" field if the given value is not nil.
+func (alic *AllowListItemCreate) SetNillableStartSuffix(i *int64) *AllowListItemCreate {
+ if i != nil {
+ alic.SetStartSuffix(*i)
+ }
+ return alic
+}
+
+// SetEndSuffix sets the "end_suffix" field.
+func (alic *AllowListItemCreate) SetEndSuffix(i int64) *AllowListItemCreate {
+ alic.mutation.SetEndSuffix(i)
+ return alic
+}
+
+// SetNillableEndSuffix sets the "end_suffix" field if the given value is not nil.
+func (alic *AllowListItemCreate) SetNillableEndSuffix(i *int64) *AllowListItemCreate {
+ if i != nil {
+ alic.SetEndSuffix(*i)
+ }
+ return alic
+}
+
+// SetIPSize sets the "ip_size" field.
+func (alic *AllowListItemCreate) SetIPSize(i int64) *AllowListItemCreate {
+ alic.mutation.SetIPSize(i)
+ return alic
+}
+
+// SetNillableIPSize sets the "ip_size" field if the given value is not nil.
+func (alic *AllowListItemCreate) SetNillableIPSize(i *int64) *AllowListItemCreate {
+ if i != nil {
+ alic.SetIPSize(*i)
+ }
+ return alic
+}
+
+// AddAllowlistIDs adds the "allowlist" edge to the AllowList entity by IDs.
+func (alic *AllowListItemCreate) AddAllowlistIDs(ids ...int) *AllowListItemCreate {
+ alic.mutation.AddAllowlistIDs(ids...)
+ return alic
+}
+
+// AddAllowlist adds the "allowlist" edges to the AllowList entity.
+func (alic *AllowListItemCreate) AddAllowlist(a ...*AllowList) *AllowListItemCreate {
+ ids := make([]int, len(a))
+ for i := range a {
+ ids[i] = a[i].ID
+ }
+ return alic.AddAllowlistIDs(ids...)
+}
+
+// Mutation returns the AllowListItemMutation object of the builder.
+func (alic *AllowListItemCreate) Mutation() *AllowListItemMutation {
+ return alic.mutation
+}
+
+// Save creates the AllowListItem in the database.
+func (alic *AllowListItemCreate) Save(ctx context.Context) (*AllowListItem, error) {
+ alic.defaults()
+ return withHooks(ctx, alic.sqlSave, alic.mutation, alic.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (alic *AllowListItemCreate) SaveX(ctx context.Context) *AllowListItem {
+ v, err := alic.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (alic *AllowListItemCreate) Exec(ctx context.Context) error {
+ _, err := alic.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (alic *AllowListItemCreate) ExecX(ctx context.Context) {
+ if err := alic.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (alic *AllowListItemCreate) defaults() {
+ if _, ok := alic.mutation.CreatedAt(); !ok {
+ v := allowlistitem.DefaultCreatedAt()
+ alic.mutation.SetCreatedAt(v)
+ }
+ if _, ok := alic.mutation.UpdatedAt(); !ok {
+ v := allowlistitem.DefaultUpdatedAt()
+ alic.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (alic *AllowListItemCreate) check() error {
+ if _, ok := alic.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AllowListItem.created_at"`)}
+ }
+ if _, ok := alic.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AllowListItem.updated_at"`)}
+ }
+ if _, ok := alic.mutation.Value(); !ok {
+ return &ValidationError{Name: "value", err: errors.New(`ent: missing required field "AllowListItem.value"`)}
+ }
+ return nil
+}
+
+func (alic *AllowListItemCreate) sqlSave(ctx context.Context) (*AllowListItem, error) {
+ if err := alic.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := alic.createSpec()
+ if err := sqlgraph.CreateNode(ctx, alic.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int(id)
+ alic.mutation.id = &_node.ID
+ alic.mutation.done = true
+ return _node, nil
+}
+
+func (alic *AllowListItemCreate) createSpec() (*AllowListItem, *sqlgraph.CreateSpec) {
+ var (
+ _node = &AllowListItem{config: alic.config}
+ _spec = sqlgraph.NewCreateSpec(allowlistitem.Table, sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt))
+ )
+ if value, ok := alic.mutation.CreatedAt(); ok {
+ _spec.SetField(allowlistitem.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := alic.mutation.UpdatedAt(); ok {
+ _spec.SetField(allowlistitem.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := alic.mutation.ExpiresAt(); ok {
+ _spec.SetField(allowlistitem.FieldExpiresAt, field.TypeTime, value)
+ _node.ExpiresAt = value
+ }
+ if value, ok := alic.mutation.Comment(); ok {
+ _spec.SetField(allowlistitem.FieldComment, field.TypeString, value)
+ _node.Comment = value
+ }
+ if value, ok := alic.mutation.Value(); ok {
+ _spec.SetField(allowlistitem.FieldValue, field.TypeString, value)
+ _node.Value = value
+ }
+ if value, ok := alic.mutation.StartIP(); ok {
+ _spec.SetField(allowlistitem.FieldStartIP, field.TypeInt64, value)
+ _node.StartIP = value
+ }
+ if value, ok := alic.mutation.EndIP(); ok {
+ _spec.SetField(allowlistitem.FieldEndIP, field.TypeInt64, value)
+ _node.EndIP = value
+ }
+ if value, ok := alic.mutation.StartSuffix(); ok {
+ _spec.SetField(allowlistitem.FieldStartSuffix, field.TypeInt64, value)
+ _node.StartSuffix = value
+ }
+ if value, ok := alic.mutation.EndSuffix(); ok {
+ _spec.SetField(allowlistitem.FieldEndSuffix, field.TypeInt64, value)
+ _node.EndSuffix = value
+ }
+ if value, ok := alic.mutation.IPSize(); ok {
+ _spec.SetField(allowlistitem.FieldIPSize, field.TypeInt64, value)
+ _node.IPSize = value
+ }
+ if nodes := alic.mutation.AllowlistIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: true,
+ Table: allowlistitem.AllowlistTable,
+ Columns: allowlistitem.AllowlistPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// AllowListItemCreateBulk is the builder for creating many AllowListItem entities in bulk.
+type AllowListItemCreateBulk struct {
+ config
+ err error
+ builders []*AllowListItemCreate
+}
+
+// Save creates the AllowListItem entities in the database.
+func (alicb *AllowListItemCreateBulk) Save(ctx context.Context) ([]*AllowListItem, error) {
+ if alicb.err != nil {
+ return nil, alicb.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(alicb.builders))
+ nodes := make([]*AllowListItem, len(alicb.builders))
+ mutators := make([]Mutator, len(alicb.builders))
+ for i := range alicb.builders {
+ func(i int, root context.Context) {
+ builder := alicb.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*AllowListItemMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, alicb.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, alicb.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, alicb.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (alicb *AllowListItemCreateBulk) SaveX(ctx context.Context) []*AllowListItem {
+ v, err := alicb.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (alicb *AllowListItemCreateBulk) Exec(ctx context.Context) error {
+ _, err := alicb.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (alicb *AllowListItemCreateBulk) ExecX(ctx context.Context) {
+ if err := alicb.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/database/ent/allowlistitem_delete.go b/pkg/database/ent/allowlistitem_delete.go
new file mode 100644
index 00000000000..87b340012df
--- /dev/null
+++ b/pkg/database/ent/allowlistitem_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate"
+)
+
+// AllowListItemDelete is the builder for deleting a AllowListItem entity.
+type AllowListItemDelete struct {
+ config
+ hooks []Hook
+ mutation *AllowListItemMutation
+}
+
+// Where appends a list predicates to the AllowListItemDelete builder.
+func (alid *AllowListItemDelete) Where(ps ...predicate.AllowListItem) *AllowListItemDelete {
+ alid.mutation.Where(ps...)
+ return alid
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (alid *AllowListItemDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, alid.sqlExec, alid.mutation, alid.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (alid *AllowListItemDelete) ExecX(ctx context.Context) int {
+ n, err := alid.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (alid *AllowListItemDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(allowlistitem.Table, sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt))
+ if ps := alid.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, alid.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ alid.mutation.done = true
+ return affected, err
+}
+
+// AllowListItemDeleteOne is the builder for deleting a single AllowListItem entity.
+type AllowListItemDeleteOne struct {
+ alid *AllowListItemDelete
+}
+
+// Where appends a list predicates to the AllowListItemDelete builder.
+func (alido *AllowListItemDeleteOne) Where(ps ...predicate.AllowListItem) *AllowListItemDeleteOne {
+ alido.alid.mutation.Where(ps...)
+ return alido
+}
+
+// Exec executes the deletion query.
+func (alido *AllowListItemDeleteOne) Exec(ctx context.Context) error {
+ n, err := alido.alid.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{allowlistitem.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (alido *AllowListItemDeleteOne) ExecX(ctx context.Context) {
+ if err := alido.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/database/ent/allowlistitem_query.go b/pkg/database/ent/allowlistitem_query.go
new file mode 100644
index 00000000000..628b680a27c
--- /dev/null
+++ b/pkg/database/ent/allowlistitem_query.go
@@ -0,0 +1,636 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate"
+)
+
+// AllowListItemQuery is the builder for querying AllowListItem entities.
+type AllowListItemQuery struct {
+ config
+ ctx *QueryContext
+ order []allowlistitem.OrderOption
+ inters []Interceptor
+ predicates []predicate.AllowListItem
+ withAllowlist *AllowListQuery
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the AllowListItemQuery builder.
+func (aliq *AllowListItemQuery) Where(ps ...predicate.AllowListItem) *AllowListItemQuery {
+ aliq.predicates = append(aliq.predicates, ps...)
+ return aliq
+}
+
+// Limit the number of records to be returned by this query.
+func (aliq *AllowListItemQuery) Limit(limit int) *AllowListItemQuery {
+ aliq.ctx.Limit = &limit
+ return aliq
+}
+
+// Offset to start from.
+func (aliq *AllowListItemQuery) Offset(offset int) *AllowListItemQuery {
+ aliq.ctx.Offset = &offset
+ return aliq
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (aliq *AllowListItemQuery) Unique(unique bool) *AllowListItemQuery {
+ aliq.ctx.Unique = &unique
+ return aliq
+}
+
+// Order specifies how the records should be ordered.
+func (aliq *AllowListItemQuery) Order(o ...allowlistitem.OrderOption) *AllowListItemQuery {
+ aliq.order = append(aliq.order, o...)
+ return aliq
+}
+
+// QueryAllowlist chains the current query on the "allowlist" edge.
+func (aliq *AllowListItemQuery) QueryAllowlist() *AllowListQuery {
+ query := (&AllowListClient{config: aliq.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := aliq.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := aliq.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(allowlistitem.Table, allowlistitem.FieldID, selector),
+ sqlgraph.To(allowlist.Table, allowlist.FieldID),
+ sqlgraph.Edge(sqlgraph.M2M, true, allowlistitem.AllowlistTable, allowlistitem.AllowlistPrimaryKey...),
+ )
+ fromU = sqlgraph.SetNeighbors(aliq.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first AllowListItem entity from the query.
+// Returns a *NotFoundError when no AllowListItem was found.
+func (aliq *AllowListItemQuery) First(ctx context.Context) (*AllowListItem, error) {
+ nodes, err := aliq.Limit(1).All(setContextOp(ctx, aliq.ctx, "First"))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{allowlistitem.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (aliq *AllowListItemQuery) FirstX(ctx context.Context) *AllowListItem {
+ node, err := aliq.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first AllowListItem ID from the query.
+// Returns a *NotFoundError when no AllowListItem ID was found.
+func (aliq *AllowListItemQuery) FirstID(ctx context.Context) (id int, err error) {
+ var ids []int
+ if ids, err = aliq.Limit(1).IDs(setContextOp(ctx, aliq.ctx, "FirstID")); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{allowlistitem.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (aliq *AllowListItemQuery) FirstIDX(ctx context.Context) int {
+ id, err := aliq.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single AllowListItem entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one AllowListItem entity is found.
+// Returns a *NotFoundError when no AllowListItem entities are found.
+func (aliq *AllowListItemQuery) Only(ctx context.Context) (*AllowListItem, error) {
+ nodes, err := aliq.Limit(2).All(setContextOp(ctx, aliq.ctx, "Only"))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{allowlistitem.Label}
+ default:
+ return nil, &NotSingularError{allowlistitem.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (aliq *AllowListItemQuery) OnlyX(ctx context.Context) *AllowListItem {
+ node, err := aliq.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only AllowListItem ID in the query.
+// Returns a *NotSingularError when more than one AllowListItem ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (aliq *AllowListItemQuery) OnlyID(ctx context.Context) (id int, err error) {
+ var ids []int
+ if ids, err = aliq.Limit(2).IDs(setContextOp(ctx, aliq.ctx, "OnlyID")); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{allowlistitem.Label}
+ default:
+ err = &NotSingularError{allowlistitem.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (aliq *AllowListItemQuery) OnlyIDX(ctx context.Context) int {
+ id, err := aliq.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of AllowListItems.
+func (aliq *AllowListItemQuery) All(ctx context.Context) ([]*AllowListItem, error) {
+ ctx = setContextOp(ctx, aliq.ctx, "All")
+ if err := aliq.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*AllowListItem, *AllowListItemQuery]()
+ return withInterceptors[[]*AllowListItem](ctx, aliq, qr, aliq.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (aliq *AllowListItemQuery) AllX(ctx context.Context) []*AllowListItem {
+ nodes, err := aliq.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of AllowListItem IDs.
+func (aliq *AllowListItemQuery) IDs(ctx context.Context) (ids []int, err error) {
+ if aliq.ctx.Unique == nil && aliq.path != nil {
+ aliq.Unique(true)
+ }
+ ctx = setContextOp(ctx, aliq.ctx, "IDs")
+ if err = aliq.Select(allowlistitem.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (aliq *AllowListItemQuery) IDsX(ctx context.Context) []int {
+ ids, err := aliq.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (aliq *AllowListItemQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, aliq.ctx, "Count")
+ if err := aliq.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, aliq, querierCount[*AllowListItemQuery](), aliq.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (aliq *AllowListItemQuery) CountX(ctx context.Context) int {
+ count, err := aliq.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (aliq *AllowListItemQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, aliq.ctx, "Exist")
+ switch _, err := aliq.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (aliq *AllowListItemQuery) ExistX(ctx context.Context) bool {
+ exist, err := aliq.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the AllowListItemQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (aliq *AllowListItemQuery) Clone() *AllowListItemQuery {
+ if aliq == nil {
+ return nil
+ }
+ return &AllowListItemQuery{
+ config: aliq.config,
+ ctx: aliq.ctx.Clone(),
+ order: append([]allowlistitem.OrderOption{}, aliq.order...),
+ inters: append([]Interceptor{}, aliq.inters...),
+ predicates: append([]predicate.AllowListItem{}, aliq.predicates...),
+ withAllowlist: aliq.withAllowlist.Clone(),
+ // clone intermediate query.
+ sql: aliq.sql.Clone(),
+ path: aliq.path,
+ }
+}
+
+// WithAllowlist tells the query-builder to eager-load the nodes that are connected to
+// the "allowlist" edge. The optional arguments are used to configure the query builder of the edge.
+func (aliq *AllowListItemQuery) WithAllowlist(opts ...func(*AllowListQuery)) *AllowListItemQuery {
+ query := (&AllowListClient{config: aliq.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ aliq.withAllowlist = query
+ return aliq
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.AllowListItem.Query().
+// GroupBy(allowlistitem.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (aliq *AllowListItemQuery) GroupBy(field string, fields ...string) *AllowListItemGroupBy {
+ aliq.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &AllowListItemGroupBy{build: aliq}
+ grbuild.flds = &aliq.ctx.Fields
+ grbuild.label = allowlistitem.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.AllowListItem.Query().
+// Select(allowlistitem.FieldCreatedAt).
+// Scan(ctx, &v)
+func (aliq *AllowListItemQuery) Select(fields ...string) *AllowListItemSelect {
+ aliq.ctx.Fields = append(aliq.ctx.Fields, fields...)
+ sbuild := &AllowListItemSelect{AllowListItemQuery: aliq}
+ sbuild.label = allowlistitem.Label
+ sbuild.flds, sbuild.scan = &aliq.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a AllowListItemSelect configured with the given aggregations.
+func (aliq *AllowListItemQuery) Aggregate(fns ...AggregateFunc) *AllowListItemSelect {
+ return aliq.Select().Aggregate(fns...)
+}
+
+func (aliq *AllowListItemQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range aliq.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, aliq); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range aliq.ctx.Fields {
+ if !allowlistitem.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if aliq.path != nil {
+ prev, err := aliq.path(ctx)
+ if err != nil {
+ return err
+ }
+ aliq.sql = prev
+ }
+ return nil
+}
+
+func (aliq *AllowListItemQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AllowListItem, error) {
+ var (
+ nodes = []*AllowListItem{}
+ _spec = aliq.querySpec()
+ loadedTypes = [1]bool{
+ aliq.withAllowlist != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*AllowListItem).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &AllowListItem{config: aliq.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, aliq.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := aliq.withAllowlist; query != nil {
+ if err := aliq.loadAllowlist(ctx, query, nodes,
+ func(n *AllowListItem) { n.Edges.Allowlist = []*AllowList{} },
+ func(n *AllowListItem, e *AllowList) { n.Edges.Allowlist = append(n.Edges.Allowlist, e) }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (aliq *AllowListItemQuery) loadAllowlist(ctx context.Context, query *AllowListQuery, nodes []*AllowListItem, init func(*AllowListItem), assign func(*AllowListItem, *AllowList)) error {
+ edgeIDs := make([]driver.Value, len(nodes))
+ byID := make(map[int]*AllowListItem)
+ nids := make(map[int]map[*AllowListItem]struct{})
+ for i, node := range nodes {
+ edgeIDs[i] = node.ID
+ byID[node.ID] = node
+ if init != nil {
+ init(node)
+ }
+ }
+ query.Where(func(s *sql.Selector) {
+ joinT := sql.Table(allowlistitem.AllowlistTable)
+ s.Join(joinT).On(s.C(allowlist.FieldID), joinT.C(allowlistitem.AllowlistPrimaryKey[0]))
+ s.Where(sql.InValues(joinT.C(allowlistitem.AllowlistPrimaryKey[1]), edgeIDs...))
+ columns := s.SelectedColumns()
+ s.Select(joinT.C(allowlistitem.AllowlistPrimaryKey[1]))
+ s.AppendSelect(columns...)
+ s.SetDistinct(false)
+ })
+ if err := query.prepareQuery(ctx); err != nil {
+ return err
+ }
+ qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
+ return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
+ assign := spec.Assign
+ values := spec.ScanValues
+ spec.ScanValues = func(columns []string) ([]any, error) {
+ values, err := values(columns[1:])
+ if err != nil {
+ return nil, err
+ }
+ return append([]any{new(sql.NullInt64)}, values...), nil
+ }
+ spec.Assign = func(columns []string, values []any) error {
+ outValue := int(values[0].(*sql.NullInt64).Int64)
+ inValue := int(values[1].(*sql.NullInt64).Int64)
+ if nids[inValue] == nil {
+ nids[inValue] = map[*AllowListItem]struct{}{byID[outValue]: {}}
+ return assign(columns[1:], values[1:])
+ }
+ nids[inValue][byID[outValue]] = struct{}{}
+ return nil
+ }
+ })
+ })
+ neighbors, err := withInterceptors[[]*AllowList](ctx, query, qr, query.inters)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected "allowlist" node returned %v`, n.ID)
+ }
+ for kn := range nodes {
+ assign(kn, n)
+ }
+ }
+ return nil
+}
+
+func (aliq *AllowListItemQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := aliq.querySpec()
+ _spec.Node.Columns = aliq.ctx.Fields
+ if len(aliq.ctx.Fields) > 0 {
+ _spec.Unique = aliq.ctx.Unique != nil && *aliq.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, aliq.driver, _spec)
+}
+
+func (aliq *AllowListItemQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(allowlistitem.Table, allowlistitem.Columns, sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt))
+ _spec.From = aliq.sql
+ if unique := aliq.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if aliq.path != nil {
+ _spec.Unique = true
+ }
+ if fields := aliq.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, allowlistitem.FieldID)
+ for i := range fields {
+ if fields[i] != allowlistitem.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ }
+ if ps := aliq.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := aliq.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := aliq.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := aliq.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (aliq *AllowListItemQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(aliq.driver.Dialect())
+ t1 := builder.Table(allowlistitem.Table)
+ columns := aliq.ctx.Fields
+ if len(columns) == 0 {
+ columns = allowlistitem.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if aliq.sql != nil {
+ selector = aliq.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if aliq.ctx.Unique != nil && *aliq.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, p := range aliq.predicates {
+ p(selector)
+ }
+ for _, p := range aliq.order {
+ p(selector)
+ }
+ if offset := aliq.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := aliq.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// AllowListItemGroupBy is the group-by builder for AllowListItem entities.
+type AllowListItemGroupBy struct {
+ selector
+ build *AllowListItemQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (aligb *AllowListItemGroupBy) Aggregate(fns ...AggregateFunc) *AllowListItemGroupBy {
+ aligb.fns = append(aligb.fns, fns...)
+ return aligb
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (aligb *AllowListItemGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, aligb.build.ctx, "GroupBy")
+ if err := aligb.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AllowListItemQuery, *AllowListItemGroupBy](ctx, aligb.build, aligb, aligb.build.inters, v)
+}
+
+func (aligb *AllowListItemGroupBy) sqlScan(ctx context.Context, root *AllowListItemQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(aligb.fns))
+ for _, fn := range aligb.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*aligb.flds)+len(aligb.fns))
+ for _, f := range *aligb.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*aligb.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := aligb.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// AllowListItemSelect is the builder for selecting fields of AllowListItem entities.
+type AllowListItemSelect struct {
+ *AllowListItemQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (alis *AllowListItemSelect) Aggregate(fns ...AggregateFunc) *AllowListItemSelect {
+ alis.fns = append(alis.fns, fns...)
+ return alis
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (alis *AllowListItemSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, alis.ctx, "Select")
+ if err := alis.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AllowListItemQuery, *AllowListItemSelect](ctx, alis.AllowListItemQuery, alis, alis.inters, v)
+}
+
+func (alis *AllowListItemSelect) sqlScan(ctx context.Context, root *AllowListItemQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(alis.fns))
+ for _, fn := range alis.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*alis.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := alis.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/pkg/database/ent/allowlistitem_update.go b/pkg/database/ent/allowlistitem_update.go
new file mode 100644
index 00000000000..e6878955afe
--- /dev/null
+++ b/pkg/database/ent/allowlistitem_update.go
@@ -0,0 +1,463 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate"
+)
+
+// AllowListItemUpdate is the builder for updating AllowListItem entities.
+type AllowListItemUpdate struct {
+ config
+ hooks []Hook
+ mutation *AllowListItemMutation
+}
+
+// Where appends a list predicates to the AllowListItemUpdate builder.
+func (aliu *AllowListItemUpdate) Where(ps ...predicate.AllowListItem) *AllowListItemUpdate {
+ aliu.mutation.Where(ps...)
+ return aliu
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (aliu *AllowListItemUpdate) SetUpdatedAt(t time.Time) *AllowListItemUpdate {
+ aliu.mutation.SetUpdatedAt(t)
+ return aliu
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (aliu *AllowListItemUpdate) SetExpiresAt(t time.Time) *AllowListItemUpdate {
+ aliu.mutation.SetExpiresAt(t)
+ return aliu
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (aliu *AllowListItemUpdate) SetNillableExpiresAt(t *time.Time) *AllowListItemUpdate {
+ if t != nil {
+ aliu.SetExpiresAt(*t)
+ }
+ return aliu
+}
+
+// ClearExpiresAt clears the value of the "expires_at" field.
+func (aliu *AllowListItemUpdate) ClearExpiresAt() *AllowListItemUpdate {
+ aliu.mutation.ClearExpiresAt()
+ return aliu
+}
+
+// AddAllowlistIDs adds the "allowlist" edge to the AllowList entity by IDs.
+func (aliu *AllowListItemUpdate) AddAllowlistIDs(ids ...int) *AllowListItemUpdate {
+ aliu.mutation.AddAllowlistIDs(ids...)
+ return aliu
+}
+
+// AddAllowlist adds the "allowlist" edges to the AllowList entity.
+func (aliu *AllowListItemUpdate) AddAllowlist(a ...*AllowList) *AllowListItemUpdate {
+ ids := make([]int, len(a))
+ for i := range a {
+ ids[i] = a[i].ID
+ }
+ return aliu.AddAllowlistIDs(ids...)
+}
+
+// Mutation returns the AllowListItemMutation object of the builder.
+func (aliu *AllowListItemUpdate) Mutation() *AllowListItemMutation {
+ return aliu.mutation
+}
+
+// ClearAllowlist clears all "allowlist" edges to the AllowList entity.
+func (aliu *AllowListItemUpdate) ClearAllowlist() *AllowListItemUpdate {
+ aliu.mutation.ClearAllowlist()
+ return aliu
+}
+
+// RemoveAllowlistIDs removes the "allowlist" edge to AllowList entities by IDs.
+func (aliu *AllowListItemUpdate) RemoveAllowlistIDs(ids ...int) *AllowListItemUpdate {
+ aliu.mutation.RemoveAllowlistIDs(ids...)
+ return aliu
+}
+
+// RemoveAllowlist removes "allowlist" edges to AllowList entities.
+func (aliu *AllowListItemUpdate) RemoveAllowlist(a ...*AllowList) *AllowListItemUpdate {
+ ids := make([]int, len(a))
+ for i := range a {
+ ids[i] = a[i].ID
+ }
+ return aliu.RemoveAllowlistIDs(ids...)
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (aliu *AllowListItemUpdate) Save(ctx context.Context) (int, error) {
+ aliu.defaults()
+ return withHooks(ctx, aliu.sqlSave, aliu.mutation, aliu.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (aliu *AllowListItemUpdate) SaveX(ctx context.Context) int {
+ affected, err := aliu.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (aliu *AllowListItemUpdate) Exec(ctx context.Context) error {
+ _, err := aliu.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (aliu *AllowListItemUpdate) ExecX(ctx context.Context) {
+ if err := aliu.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (aliu *AllowListItemUpdate) defaults() {
+ if _, ok := aliu.mutation.UpdatedAt(); !ok {
+ v := allowlistitem.UpdateDefaultUpdatedAt()
+ aliu.mutation.SetUpdatedAt(v)
+ }
+}
+
+func (aliu *AllowListItemUpdate) sqlSave(ctx context.Context) (n int, err error) {
+ _spec := sqlgraph.NewUpdateSpec(allowlistitem.Table, allowlistitem.Columns, sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt))
+ if ps := aliu.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := aliu.mutation.UpdatedAt(); ok {
+ _spec.SetField(allowlistitem.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := aliu.mutation.ExpiresAt(); ok {
+ _spec.SetField(allowlistitem.FieldExpiresAt, field.TypeTime, value)
+ }
+ if aliu.mutation.ExpiresAtCleared() {
+ _spec.ClearField(allowlistitem.FieldExpiresAt, field.TypeTime)
+ }
+ if aliu.mutation.CommentCleared() {
+ _spec.ClearField(allowlistitem.FieldComment, field.TypeString)
+ }
+ if aliu.mutation.StartIPCleared() {
+ _spec.ClearField(allowlistitem.FieldStartIP, field.TypeInt64)
+ }
+ if aliu.mutation.EndIPCleared() {
+ _spec.ClearField(allowlistitem.FieldEndIP, field.TypeInt64)
+ }
+ if aliu.mutation.StartSuffixCleared() {
+ _spec.ClearField(allowlistitem.FieldStartSuffix, field.TypeInt64)
+ }
+ if aliu.mutation.EndSuffixCleared() {
+ _spec.ClearField(allowlistitem.FieldEndSuffix, field.TypeInt64)
+ }
+ if aliu.mutation.IPSizeCleared() {
+ _spec.ClearField(allowlistitem.FieldIPSize, field.TypeInt64)
+ }
+ if aliu.mutation.AllowlistCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: true,
+ Table: allowlistitem.AllowlistTable,
+ Columns: allowlistitem.AllowlistPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := aliu.mutation.RemovedAllowlistIDs(); len(nodes) > 0 && !aliu.mutation.AllowlistCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: true,
+ Table: allowlistitem.AllowlistTable,
+ Columns: allowlistitem.AllowlistPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := aliu.mutation.AllowlistIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: true,
+ Table: allowlistitem.AllowlistTable,
+ Columns: allowlistitem.AllowlistPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if n, err = sqlgraph.UpdateNodes(ctx, aliu.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{allowlistitem.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ aliu.mutation.done = true
+ return n, nil
+}
+
+// AllowListItemUpdateOne is the builder for updating a single AllowListItem entity.
+type AllowListItemUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *AllowListItemMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (aliuo *AllowListItemUpdateOne) SetUpdatedAt(t time.Time) *AllowListItemUpdateOne {
+ aliuo.mutation.SetUpdatedAt(t)
+ return aliuo
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (aliuo *AllowListItemUpdateOne) SetExpiresAt(t time.Time) *AllowListItemUpdateOne {
+ aliuo.mutation.SetExpiresAt(t)
+ return aliuo
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (aliuo *AllowListItemUpdateOne) SetNillableExpiresAt(t *time.Time) *AllowListItemUpdateOne {
+ if t != nil {
+ aliuo.SetExpiresAt(*t)
+ }
+ return aliuo
+}
+
+// ClearExpiresAt clears the value of the "expires_at" field.
+func (aliuo *AllowListItemUpdateOne) ClearExpiresAt() *AllowListItemUpdateOne {
+ aliuo.mutation.ClearExpiresAt()
+ return aliuo
+}
+
+// AddAllowlistIDs adds the "allowlist" edge to the AllowList entity by IDs.
+func (aliuo *AllowListItemUpdateOne) AddAllowlistIDs(ids ...int) *AllowListItemUpdateOne {
+ aliuo.mutation.AddAllowlistIDs(ids...)
+ return aliuo
+}
+
+// AddAllowlist adds the "allowlist" edges to the AllowList entity.
+func (aliuo *AllowListItemUpdateOne) AddAllowlist(a ...*AllowList) *AllowListItemUpdateOne {
+ ids := make([]int, len(a))
+ for i := range a {
+ ids[i] = a[i].ID
+ }
+ return aliuo.AddAllowlistIDs(ids...)
+}
+
+// Mutation returns the AllowListItemMutation object of the builder.
+func (aliuo *AllowListItemUpdateOne) Mutation() *AllowListItemMutation {
+ return aliuo.mutation
+}
+
+// ClearAllowlist clears all "allowlist" edges to the AllowList entity.
+func (aliuo *AllowListItemUpdateOne) ClearAllowlist() *AllowListItemUpdateOne {
+ aliuo.mutation.ClearAllowlist()
+ return aliuo
+}
+
+// RemoveAllowlistIDs removes the "allowlist" edge to AllowList entities by IDs.
+func (aliuo *AllowListItemUpdateOne) RemoveAllowlistIDs(ids ...int) *AllowListItemUpdateOne {
+ aliuo.mutation.RemoveAllowlistIDs(ids...)
+ return aliuo
+}
+
+// RemoveAllowlist removes "allowlist" edges to AllowList entities.
+func (aliuo *AllowListItemUpdateOne) RemoveAllowlist(a ...*AllowList) *AllowListItemUpdateOne {
+ ids := make([]int, len(a))
+ for i := range a {
+ ids[i] = a[i].ID
+ }
+ return aliuo.RemoveAllowlistIDs(ids...)
+}
+
+// Where appends a list predicates to the AllowListItemUpdate builder.
+func (aliuo *AllowListItemUpdateOne) Where(ps ...predicate.AllowListItem) *AllowListItemUpdateOne {
+ aliuo.mutation.Where(ps...)
+ return aliuo
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (aliuo *AllowListItemUpdateOne) Select(field string, fields ...string) *AllowListItemUpdateOne {
+ aliuo.fields = append([]string{field}, fields...)
+ return aliuo
+}
+
+// Save executes the query and returns the updated AllowListItem entity.
+func (aliuo *AllowListItemUpdateOne) Save(ctx context.Context) (*AllowListItem, error) {
+ aliuo.defaults()
+ return withHooks(ctx, aliuo.sqlSave, aliuo.mutation, aliuo.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (aliuo *AllowListItemUpdateOne) SaveX(ctx context.Context) *AllowListItem {
+ node, err := aliuo.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (aliuo *AllowListItemUpdateOne) Exec(ctx context.Context) error {
+ _, err := aliuo.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (aliuo *AllowListItemUpdateOne) ExecX(ctx context.Context) {
+ if err := aliuo.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (aliuo *AllowListItemUpdateOne) defaults() {
+ if _, ok := aliuo.mutation.UpdatedAt(); !ok {
+ v := allowlistitem.UpdateDefaultUpdatedAt()
+ aliuo.mutation.SetUpdatedAt(v)
+ }
+}
+
+func (aliuo *AllowListItemUpdateOne) sqlSave(ctx context.Context) (_node *AllowListItem, err error) {
+ _spec := sqlgraph.NewUpdateSpec(allowlistitem.Table, allowlistitem.Columns, sqlgraph.NewFieldSpec(allowlistitem.FieldID, field.TypeInt))
+ id, ok := aliuo.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AllowListItem.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := aliuo.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, allowlistitem.FieldID)
+ for _, f := range fields {
+ if !allowlistitem.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != allowlistitem.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := aliuo.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := aliuo.mutation.UpdatedAt(); ok {
+ _spec.SetField(allowlistitem.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := aliuo.mutation.ExpiresAt(); ok {
+ _spec.SetField(allowlistitem.FieldExpiresAt, field.TypeTime, value)
+ }
+ if aliuo.mutation.ExpiresAtCleared() {
+ _spec.ClearField(allowlistitem.FieldExpiresAt, field.TypeTime)
+ }
+ if aliuo.mutation.CommentCleared() {
+ _spec.ClearField(allowlistitem.FieldComment, field.TypeString)
+ }
+ if aliuo.mutation.StartIPCleared() {
+ _spec.ClearField(allowlistitem.FieldStartIP, field.TypeInt64)
+ }
+ if aliuo.mutation.EndIPCleared() {
+ _spec.ClearField(allowlistitem.FieldEndIP, field.TypeInt64)
+ }
+ if aliuo.mutation.StartSuffixCleared() {
+ _spec.ClearField(allowlistitem.FieldStartSuffix, field.TypeInt64)
+ }
+ if aliuo.mutation.EndSuffixCleared() {
+ _spec.ClearField(allowlistitem.FieldEndSuffix, field.TypeInt64)
+ }
+ if aliuo.mutation.IPSizeCleared() {
+ _spec.ClearField(allowlistitem.FieldIPSize, field.TypeInt64)
+ }
+ if aliuo.mutation.AllowlistCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: true,
+ Table: allowlistitem.AllowlistTable,
+ Columns: allowlistitem.AllowlistPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := aliuo.mutation.RemovedAllowlistIDs(); len(nodes) > 0 && !aliuo.mutation.AllowlistCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: true,
+ Table: allowlistitem.AllowlistTable,
+ Columns: allowlistitem.AllowlistPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := aliuo.mutation.AllowlistIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2M,
+ Inverse: true,
+ Table: allowlistitem.AllowlistTable,
+ Columns: allowlistitem.AllowlistPrimaryKey,
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(allowlist.FieldID, field.TypeInt),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &AllowListItem{config: aliuo.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, aliuo.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{allowlistitem.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ aliuo.mutation.done = true
+ return _node, nil
+}
diff --git a/pkg/database/ent/bouncer.go b/pkg/database/ent/bouncer.go
index 3b4d619e384..197f61cde19 100644
--- a/pkg/database/ent/bouncer.go
+++ b/pkg/database/ent/bouncer.go
@@ -43,6 +43,8 @@ type Bouncer struct {
Osversion string `json:"osversion,omitempty"`
// Featureflags holds the value of the "featureflags" field.
Featureflags string `json:"featureflags,omitempty"`
+ // AutoCreated holds the value of the "auto_created" field.
+ AutoCreated bool `json:"auto_created"`
selectValues sql.SelectValues
}
@@ -51,7 +53,7 @@ func (*Bouncer) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
- case bouncer.FieldRevoked:
+ case bouncer.FieldRevoked, bouncer.FieldAutoCreated:
values[i] = new(sql.NullBool)
case bouncer.FieldID:
values[i] = new(sql.NullInt64)
@@ -159,6 +161,12 @@ func (b *Bouncer) assignValues(columns []string, values []any) error {
} else if value.Valid {
b.Featureflags = value.String
}
+ case bouncer.FieldAutoCreated:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field auto_created", values[i])
+ } else if value.Valid {
+ b.AutoCreated = value.Bool
+ }
default:
b.selectValues.Set(columns[i], values[i])
}
@@ -234,6 +242,9 @@ func (b *Bouncer) String() string {
builder.WriteString(", ")
builder.WriteString("featureflags=")
builder.WriteString(b.Featureflags)
+ builder.WriteString(", ")
+ builder.WriteString("auto_created=")
+ builder.WriteString(fmt.Sprintf("%v", b.AutoCreated))
builder.WriteByte(')')
return builder.String()
}
diff --git a/pkg/database/ent/bouncer/bouncer.go b/pkg/database/ent/bouncer/bouncer.go
index a6f62aeadd5..f25b5a5815a 100644
--- a/pkg/database/ent/bouncer/bouncer.go
+++ b/pkg/database/ent/bouncer/bouncer.go
@@ -39,6 +39,8 @@ const (
FieldOsversion = "osversion"
// FieldFeatureflags holds the string denoting the featureflags field in the database.
FieldFeatureflags = "featureflags"
+ // FieldAutoCreated holds the string denoting the auto_created field in the database.
+ FieldAutoCreated = "auto_created"
// Table holds the table name of the bouncer in the database.
Table = "bouncers"
)
@@ -59,6 +61,7 @@ var Columns = []string{
FieldOsname,
FieldOsversion,
FieldFeatureflags,
+ FieldAutoCreated,
}
// ValidColumn reports if the column name is valid (part of the table columns).
@@ -82,6 +85,8 @@ var (
DefaultIPAddress string
// DefaultAuthType holds the default value on creation for the "auth_type" field.
DefaultAuthType string
+ // DefaultAutoCreated holds the default value on creation for the "auto_created" field.
+ DefaultAutoCreated bool
)
// OrderOption defines the ordering options for the Bouncer queries.
@@ -156,3 +161,8 @@ func ByOsversion(opts ...sql.OrderTermOption) OrderOption {
func ByFeatureflags(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFeatureflags, opts...).ToFunc()
}
+
+// ByAutoCreated orders the results by the auto_created field.
+func ByAutoCreated(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAutoCreated, opts...).ToFunc()
+}
diff --git a/pkg/database/ent/bouncer/where.go b/pkg/database/ent/bouncer/where.go
index e02199bc0a9..79b8999354f 100644
--- a/pkg/database/ent/bouncer/where.go
+++ b/pkg/database/ent/bouncer/where.go
@@ -119,6 +119,11 @@ func Featureflags(v string) predicate.Bouncer {
return predicate.Bouncer(sql.FieldEQ(FieldFeatureflags, v))
}
+// AutoCreated applies equality check predicate on the "auto_created" field. It's identical to AutoCreatedEQ.
+func AutoCreated(v bool) predicate.Bouncer {
+ return predicate.Bouncer(sql.FieldEQ(FieldAutoCreated, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Bouncer {
return predicate.Bouncer(sql.FieldEQ(FieldCreatedAt, v))
@@ -904,6 +909,16 @@ func FeatureflagsContainsFold(v string) predicate.Bouncer {
return predicate.Bouncer(sql.FieldContainsFold(FieldFeatureflags, v))
}
+// AutoCreatedEQ applies the EQ predicate on the "auto_created" field.
+func AutoCreatedEQ(v bool) predicate.Bouncer {
+ return predicate.Bouncer(sql.FieldEQ(FieldAutoCreated, v))
+}
+
+// AutoCreatedNEQ applies the NEQ predicate on the "auto_created" field.
+func AutoCreatedNEQ(v bool) predicate.Bouncer {
+ return predicate.Bouncer(sql.FieldNEQ(FieldAutoCreated, v))
+}
+
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.Bouncer) predicate.Bouncer {
return predicate.Bouncer(sql.AndPredicates(predicates...))
diff --git a/pkg/database/ent/bouncer_create.go b/pkg/database/ent/bouncer_create.go
index 29b23f87cf1..9ff4c0e0820 100644
--- a/pkg/database/ent/bouncer_create.go
+++ b/pkg/database/ent/bouncer_create.go
@@ -178,6 +178,20 @@ func (bc *BouncerCreate) SetNillableFeatureflags(s *string) *BouncerCreate {
return bc
}
+// SetAutoCreated sets the "auto_created" field.
+func (bc *BouncerCreate) SetAutoCreated(b bool) *BouncerCreate {
+ bc.mutation.SetAutoCreated(b)
+ return bc
+}
+
+// SetNillableAutoCreated sets the "auto_created" field if the given value is not nil.
+func (bc *BouncerCreate) SetNillableAutoCreated(b *bool) *BouncerCreate {
+ if b != nil {
+ bc.SetAutoCreated(*b)
+ }
+ return bc
+}
+
// Mutation returns the BouncerMutation object of the builder.
func (bc *BouncerCreate) Mutation() *BouncerMutation {
return bc.mutation
@@ -229,6 +243,10 @@ func (bc *BouncerCreate) defaults() {
v := bouncer.DefaultAuthType
bc.mutation.SetAuthType(v)
}
+ if _, ok := bc.mutation.AutoCreated(); !ok {
+ v := bouncer.DefaultAutoCreated
+ bc.mutation.SetAutoCreated(v)
+ }
}
// check runs all checks and user-defined validators on the builder.
@@ -251,6 +269,9 @@ func (bc *BouncerCreate) check() error {
if _, ok := bc.mutation.AuthType(); !ok {
return &ValidationError{Name: "auth_type", err: errors.New(`ent: missing required field "Bouncer.auth_type"`)}
}
+ if _, ok := bc.mutation.AutoCreated(); !ok {
+ return &ValidationError{Name: "auto_created", err: errors.New(`ent: missing required field "Bouncer.auto_created"`)}
+ }
return nil
}
@@ -329,6 +350,10 @@ func (bc *BouncerCreate) createSpec() (*Bouncer, *sqlgraph.CreateSpec) {
_spec.SetField(bouncer.FieldFeatureflags, field.TypeString, value)
_node.Featureflags = value
}
+ if value, ok := bc.mutation.AutoCreated(); ok {
+ _spec.SetField(bouncer.FieldAutoCreated, field.TypeBool, value)
+ _node.AutoCreated = value
+ }
return _node, _spec
}
diff --git a/pkg/database/ent/client.go b/pkg/database/ent/client.go
index 59686102ebe..bc7c0330459 100644
--- a/pkg/database/ent/client.go
+++ b/pkg/database/ent/client.go
@@ -16,6 +16,8 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
@@ -33,6 +35,10 @@ type Client struct {
Schema *migrate.Schema
// Alert is the client for interacting with the Alert builders.
Alert *AlertClient
+ // AllowList is the client for interacting with the AllowList builders.
+ AllowList *AllowListClient
+ // AllowListItem is the client for interacting with the AllowListItem builders.
+ AllowListItem *AllowListItemClient
// Bouncer is the client for interacting with the Bouncer builders.
Bouncer *BouncerClient
// ConfigItem is the client for interacting with the ConfigItem builders.
@@ -61,6 +67,8 @@ func NewClient(opts ...Option) *Client {
func (c *Client) init() {
c.Schema = migrate.NewSchema(c.driver)
c.Alert = NewAlertClient(c.config)
+ c.AllowList = NewAllowListClient(c.config)
+ c.AllowListItem = NewAllowListItemClient(c.config)
c.Bouncer = NewBouncerClient(c.config)
c.ConfigItem = NewConfigItemClient(c.config)
c.Decision = NewDecisionClient(c.config)
@@ -159,17 +167,19 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
cfg := c.config
cfg.driver = tx
return &Tx{
- ctx: ctx,
- config: cfg,
- Alert: NewAlertClient(cfg),
- Bouncer: NewBouncerClient(cfg),
- ConfigItem: NewConfigItemClient(cfg),
- Decision: NewDecisionClient(cfg),
- Event: NewEventClient(cfg),
- Lock: NewLockClient(cfg),
- Machine: NewMachineClient(cfg),
- Meta: NewMetaClient(cfg),
- Metric: NewMetricClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ Alert: NewAlertClient(cfg),
+ AllowList: NewAllowListClient(cfg),
+ AllowListItem: NewAllowListItemClient(cfg),
+ Bouncer: NewBouncerClient(cfg),
+ ConfigItem: NewConfigItemClient(cfg),
+ Decision: NewDecisionClient(cfg),
+ Event: NewEventClient(cfg),
+ Lock: NewLockClient(cfg),
+ Machine: NewMachineClient(cfg),
+ Meta: NewMetaClient(cfg),
+ Metric: NewMetricClient(cfg),
}, nil
}
@@ -187,17 +197,19 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
cfg := c.config
cfg.driver = &txDriver{tx: tx, drv: c.driver}
return &Tx{
- ctx: ctx,
- config: cfg,
- Alert: NewAlertClient(cfg),
- Bouncer: NewBouncerClient(cfg),
- ConfigItem: NewConfigItemClient(cfg),
- Decision: NewDecisionClient(cfg),
- Event: NewEventClient(cfg),
- Lock: NewLockClient(cfg),
- Machine: NewMachineClient(cfg),
- Meta: NewMetaClient(cfg),
- Metric: NewMetricClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ Alert: NewAlertClient(cfg),
+ AllowList: NewAllowListClient(cfg),
+ AllowListItem: NewAllowListItemClient(cfg),
+ Bouncer: NewBouncerClient(cfg),
+ ConfigItem: NewConfigItemClient(cfg),
+ Decision: NewDecisionClient(cfg),
+ Event: NewEventClient(cfg),
+ Lock: NewLockClient(cfg),
+ Machine: NewMachineClient(cfg),
+ Meta: NewMetaClient(cfg),
+ Metric: NewMetricClient(cfg),
}, nil
}
@@ -227,8 +239,8 @@ func (c *Client) Close() error {
// In order to add hooks to a specific client, call: `client.Node.Use(...)`.
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
- c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Lock, c.Machine,
- c.Meta, c.Metric,
+ c.Alert, c.AllowList, c.AllowListItem, c.Bouncer, c.ConfigItem, c.Decision,
+ c.Event, c.Lock, c.Machine, c.Meta, c.Metric,
} {
n.Use(hooks...)
}
@@ -238,8 +250,8 @@ func (c *Client) Use(hooks ...Hook) {
// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`.
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
- c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Lock, c.Machine,
- c.Meta, c.Metric,
+ c.Alert, c.AllowList, c.AllowListItem, c.Bouncer, c.ConfigItem, c.Decision,
+ c.Event, c.Lock, c.Machine, c.Meta, c.Metric,
} {
n.Intercept(interceptors...)
}
@@ -250,6 +262,10 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
switch m := m.(type) {
case *AlertMutation:
return c.Alert.mutate(ctx, m)
+ case *AllowListMutation:
+ return c.AllowList.mutate(ctx, m)
+ case *AllowListItemMutation:
+ return c.AllowListItem.mutate(ctx, m)
case *BouncerMutation:
return c.Bouncer.mutate(ctx, m)
case *ConfigItemMutation:
@@ -468,6 +484,304 @@ func (c *AlertClient) mutate(ctx context.Context, m *AlertMutation) (Value, erro
}
}
+// AllowListClient is a client for the AllowList schema.
+type AllowListClient struct {
+ config
+}
+
+// NewAllowListClient returns a client for the AllowList from the given config.
+func NewAllowListClient(c config) *AllowListClient {
+ return &AllowListClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `allowlist.Hooks(f(g(h())))`.
+func (c *AllowListClient) Use(hooks ...Hook) {
+ c.hooks.AllowList = append(c.hooks.AllowList, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `allowlist.Intercept(f(g(h())))`.
+func (c *AllowListClient) Intercept(interceptors ...Interceptor) {
+ c.inters.AllowList = append(c.inters.AllowList, interceptors...)
+}
+
+// Create returns a builder for creating a AllowList entity.
+func (c *AllowListClient) Create() *AllowListCreate {
+ mutation := newAllowListMutation(c.config, OpCreate)
+ return &AllowListCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of AllowList entities.
+func (c *AllowListClient) CreateBulk(builders ...*AllowListCreate) *AllowListCreateBulk {
+ return &AllowListCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *AllowListClient) MapCreateBulk(slice any, setFunc func(*AllowListCreate, int)) *AllowListCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &AllowListCreateBulk{err: fmt.Errorf("calling to AllowListClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*AllowListCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &AllowListCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for AllowList.
+func (c *AllowListClient) Update() *AllowListUpdate {
+ mutation := newAllowListMutation(c.config, OpUpdate)
+ return &AllowListUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *AllowListClient) UpdateOne(al *AllowList) *AllowListUpdateOne {
+ mutation := newAllowListMutation(c.config, OpUpdateOne, withAllowList(al))
+ return &AllowListUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *AllowListClient) UpdateOneID(id int) *AllowListUpdateOne {
+ mutation := newAllowListMutation(c.config, OpUpdateOne, withAllowListID(id))
+ return &AllowListUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for AllowList.
+func (c *AllowListClient) Delete() *AllowListDelete {
+ mutation := newAllowListMutation(c.config, OpDelete)
+ return &AllowListDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *AllowListClient) DeleteOne(al *AllowList) *AllowListDeleteOne {
+ return c.DeleteOneID(al.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *AllowListClient) DeleteOneID(id int) *AllowListDeleteOne {
+ builder := c.Delete().Where(allowlist.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &AllowListDeleteOne{builder}
+}
+
+// Query returns a query builder for AllowList.
+func (c *AllowListClient) Query() *AllowListQuery {
+ return &AllowListQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeAllowList},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a AllowList entity by its id.
+func (c *AllowListClient) Get(ctx context.Context, id int) (*AllowList, error) {
+ return c.Query().Where(allowlist.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *AllowListClient) GetX(ctx context.Context, id int) *AllowList {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryAllowlistItems queries the allowlist_items edge of a AllowList.
+func (c *AllowListClient) QueryAllowlistItems(al *AllowList) *AllowListItemQuery {
+ query := (&AllowListItemClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := al.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(allowlist.Table, allowlist.FieldID, id),
+ sqlgraph.To(allowlistitem.Table, allowlistitem.FieldID),
+ sqlgraph.Edge(sqlgraph.M2M, false, allowlist.AllowlistItemsTable, allowlist.AllowlistItemsPrimaryKey...),
+ )
+ fromV = sqlgraph.Neighbors(al.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *AllowListClient) Hooks() []Hook {
+ return c.hooks.AllowList
+}
+
+// Interceptors returns the client interceptors.
+func (c *AllowListClient) Interceptors() []Interceptor {
+ return c.inters.AllowList
+}
+
+func (c *AllowListClient) mutate(ctx context.Context, m *AllowListMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&AllowListCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&AllowListUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&AllowListUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&AllowListDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown AllowList mutation op: %q", m.Op())
+ }
+}
+
+// AllowListItemClient is a client for the AllowListItem schema.
+type AllowListItemClient struct {
+ config
+}
+
+// NewAllowListItemClient returns a client for the AllowListItem from the given config.
+func NewAllowListItemClient(c config) *AllowListItemClient {
+ return &AllowListItemClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `allowlistitem.Hooks(f(g(h())))`.
+func (c *AllowListItemClient) Use(hooks ...Hook) {
+ c.hooks.AllowListItem = append(c.hooks.AllowListItem, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `allowlistitem.Intercept(f(g(h())))`.
+func (c *AllowListItemClient) Intercept(interceptors ...Interceptor) {
+ c.inters.AllowListItem = append(c.inters.AllowListItem, interceptors...)
+}
+
+// Create returns a builder for creating a AllowListItem entity.
+func (c *AllowListItemClient) Create() *AllowListItemCreate {
+ mutation := newAllowListItemMutation(c.config, OpCreate)
+ return &AllowListItemCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of AllowListItem entities.
+func (c *AllowListItemClient) CreateBulk(builders ...*AllowListItemCreate) *AllowListItemCreateBulk {
+ return &AllowListItemCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *AllowListItemClient) MapCreateBulk(slice any, setFunc func(*AllowListItemCreate, int)) *AllowListItemCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &AllowListItemCreateBulk{err: fmt.Errorf("calling to AllowListItemClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*AllowListItemCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &AllowListItemCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for AllowListItem.
+func (c *AllowListItemClient) Update() *AllowListItemUpdate {
+ mutation := newAllowListItemMutation(c.config, OpUpdate)
+ return &AllowListItemUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *AllowListItemClient) UpdateOne(ali *AllowListItem) *AllowListItemUpdateOne {
+ mutation := newAllowListItemMutation(c.config, OpUpdateOne, withAllowListItem(ali))
+ return &AllowListItemUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *AllowListItemClient) UpdateOneID(id int) *AllowListItemUpdateOne {
+ mutation := newAllowListItemMutation(c.config, OpUpdateOne, withAllowListItemID(id))
+ return &AllowListItemUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for AllowListItem.
+func (c *AllowListItemClient) Delete() *AllowListItemDelete {
+ mutation := newAllowListItemMutation(c.config, OpDelete)
+ return &AllowListItemDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *AllowListItemClient) DeleteOne(ali *AllowListItem) *AllowListItemDeleteOne {
+ return c.DeleteOneID(ali.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *AllowListItemClient) DeleteOneID(id int) *AllowListItemDeleteOne {
+ builder := c.Delete().Where(allowlistitem.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &AllowListItemDeleteOne{builder}
+}
+
+// Query returns a query builder for AllowListItem.
+func (c *AllowListItemClient) Query() *AllowListItemQuery {
+ return &AllowListItemQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeAllowListItem},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a AllowListItem entity by its id.
+func (c *AllowListItemClient) Get(ctx context.Context, id int) (*AllowListItem, error) {
+ return c.Query().Where(allowlistitem.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *AllowListItemClient) GetX(ctx context.Context, id int) *AllowListItem {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryAllowlist queries the allowlist edge of a AllowListItem.
+func (c *AllowListItemClient) QueryAllowlist(ali *AllowListItem) *AllowListQuery {
+ query := (&AllowListClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := ali.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(allowlistitem.Table, allowlistitem.FieldID, id),
+ sqlgraph.To(allowlist.Table, allowlist.FieldID),
+ sqlgraph.Edge(sqlgraph.M2M, true, allowlistitem.AllowlistTable, allowlistitem.AllowlistPrimaryKey...),
+ )
+ fromV = sqlgraph.Neighbors(ali.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *AllowListItemClient) Hooks() []Hook {
+ return c.hooks.AllowListItem
+}
+
+// Interceptors returns the client interceptors.
+func (c *AllowListItemClient) Interceptors() []Interceptor {
+ return c.inters.AllowListItem
+}
+
+func (c *AllowListItemClient) mutate(ctx context.Context, m *AllowListItemMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&AllowListItemCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&AllowListItemUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&AllowListItemUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&AllowListItemDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown AllowListItem mutation op: %q", m.Op())
+ }
+}
+
// BouncerClient is a client for the Bouncer schema.
type BouncerClient struct {
config
@@ -1599,11 +1913,11 @@ func (c *MetricClient) mutate(ctx context.Context, m *MetricMutation) (Value, er
// hooks and interceptors per client, for fast access.
type (
hooks struct {
- Alert, Bouncer, ConfigItem, Decision, Event, Lock, Machine, Meta,
- Metric []ent.Hook
+ Alert, AllowList, AllowListItem, Bouncer, ConfigItem, Decision, Event, Lock,
+ Machine, Meta, Metric []ent.Hook
}
inters struct {
- Alert, Bouncer, ConfigItem, Decision, Event, Lock, Machine, Meta,
- Metric []ent.Interceptor
+ Alert, AllowList, AllowListItem, Bouncer, ConfigItem, Decision, Event, Lock,
+ Machine, Meta, Metric []ent.Interceptor
}
)
diff --git a/pkg/database/ent/ent.go b/pkg/database/ent/ent.go
index 2a5ad188197..e38db54aa59 100644
--- a/pkg/database/ent/ent.go
+++ b/pkg/database/ent/ent.go
@@ -13,6 +13,8 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
@@ -81,15 +83,17 @@ var (
func checkColumn(table, column string) error {
initCheck.Do(func() {
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
- alert.Table: alert.ValidColumn,
- bouncer.Table: bouncer.ValidColumn,
- configitem.Table: configitem.ValidColumn,
- decision.Table: decision.ValidColumn,
- event.Table: event.ValidColumn,
- lock.Table: lock.ValidColumn,
- machine.Table: machine.ValidColumn,
- meta.Table: meta.ValidColumn,
- metric.Table: metric.ValidColumn,
+ alert.Table: alert.ValidColumn,
+ allowlist.Table: allowlist.ValidColumn,
+ allowlistitem.Table: allowlistitem.ValidColumn,
+ bouncer.Table: bouncer.ValidColumn,
+ configitem.Table: configitem.ValidColumn,
+ decision.Table: decision.ValidColumn,
+ event.Table: event.ValidColumn,
+ lock.Table: lock.ValidColumn,
+ machine.Table: machine.ValidColumn,
+ meta.Table: meta.ValidColumn,
+ metric.Table: metric.ValidColumn,
})
})
return columnCheck(table, column)
diff --git a/pkg/database/ent/hook/hook.go b/pkg/database/ent/hook/hook.go
index 62cc07820d0..b5ddfc81290 100644
--- a/pkg/database/ent/hook/hook.go
+++ b/pkg/database/ent/hook/hook.go
@@ -21,6 +21,30 @@ func (f AlertFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AlertMutation", m)
}
+// The AllowListFunc type is an adapter to allow the use of ordinary
+// function as AllowList mutator.
+type AllowListFunc func(context.Context, *ent.AllowListMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f AllowListFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.AllowListMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AllowListMutation", m)
+}
+
+// The AllowListItemFunc type is an adapter to allow the use of ordinary
+// function as AllowListItem mutator.
+type AllowListItemFunc func(context.Context, *ent.AllowListItemMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f AllowListItemFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.AllowListItemMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AllowListItemMutation", m)
+}
+
// The BouncerFunc type is an adapter to allow the use of ordinary
// function as Bouncer mutator.
type BouncerFunc func(context.Context, *ent.BouncerMutation) (ent.Value, error)
diff --git a/pkg/database/ent/migrate/schema.go b/pkg/database/ent/migrate/schema.go
index 986f5bc8c67..932c27dd7a6 100644
--- a/pkg/database/ent/migrate/schema.go
+++ b/pkg/database/ent/migrate/schema.go
@@ -58,6 +58,66 @@ var (
},
},
}
+ // AllowListsColumns holds the columns for the "allow_lists" table.
+ AllowListsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt, Increment: true},
+ {Name: "created_at", Type: field.TypeTime},
+ {Name: "updated_at", Type: field.TypeTime},
+ {Name: "name", Type: field.TypeString},
+ {Name: "from_console", Type: field.TypeBool},
+ {Name: "description", Type: field.TypeString, Nullable: true},
+ {Name: "allowlist_id", Type: field.TypeString, Nullable: true},
+ }
+ // AllowListsTable holds the schema information for the "allow_lists" table.
+ AllowListsTable = &schema.Table{
+ Name: "allow_lists",
+ Columns: AllowListsColumns,
+ PrimaryKey: []*schema.Column{AllowListsColumns[0]},
+ Indexes: []*schema.Index{
+ {
+ Name: "allowlist_id",
+ Unique: true,
+ Columns: []*schema.Column{AllowListsColumns[0]},
+ },
+ {
+ Name: "allowlist_name",
+ Unique: true,
+ Columns: []*schema.Column{AllowListsColumns[3]},
+ },
+ },
+ }
+ // AllowListItemsColumns holds the columns for the "allow_list_items" table.
+ AllowListItemsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt, Increment: true},
+ {Name: "created_at", Type: field.TypeTime},
+ {Name: "updated_at", Type: field.TypeTime},
+ {Name: "expires_at", Type: field.TypeTime, Nullable: true},
+ {Name: "comment", Type: field.TypeString, Nullable: true},
+ {Name: "value", Type: field.TypeString},
+ {Name: "start_ip", Type: field.TypeInt64, Nullable: true},
+ {Name: "end_ip", Type: field.TypeInt64, Nullable: true},
+ {Name: "start_suffix", Type: field.TypeInt64, Nullable: true},
+ {Name: "end_suffix", Type: field.TypeInt64, Nullable: true},
+ {Name: "ip_size", Type: field.TypeInt64, Nullable: true},
+ }
+ // AllowListItemsTable holds the schema information for the "allow_list_items" table.
+ AllowListItemsTable = &schema.Table{
+ Name: "allow_list_items",
+ Columns: AllowListItemsColumns,
+ PrimaryKey: []*schema.Column{AllowListItemsColumns[0]},
+ Indexes: []*schema.Index{
+ {
+ Name: "allowlistitem_id",
+ Unique: false,
+ Columns: []*schema.Column{AllowListItemsColumns[0]},
+ },
+ {
+ Name: "allowlistitem_start_ip_end_ip",
+ Unique: false,
+ Columns: []*schema.Column{AllowListItemsColumns[6], AllowListItemsColumns[7]},
+ },
+ },
+ }
// BouncersColumns holds the columns for the "bouncers" table.
BouncersColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt, Increment: true},
@@ -74,6 +134,7 @@ var (
{Name: "osname", Type: field.TypeString, Nullable: true},
{Name: "osversion", Type: field.TypeString, Nullable: true},
{Name: "featureflags", Type: field.TypeString, Nullable: true},
+ {Name: "auto_created", Type: field.TypeBool, Default: false},
}
// BouncersTable holds the schema information for the "bouncers" table.
BouncersTable = &schema.Table{
@@ -264,9 +325,36 @@ var (
Columns: MetricsColumns,
PrimaryKey: []*schema.Column{MetricsColumns[0]},
}
+ // AllowListAllowlistItemsColumns holds the columns for the "allow_list_allowlist_items" table.
+ AllowListAllowlistItemsColumns = []*schema.Column{
+ {Name: "allow_list_id", Type: field.TypeInt},
+ {Name: "allow_list_item_id", Type: field.TypeInt},
+ }
+ // AllowListAllowlistItemsTable holds the schema information for the "allow_list_allowlist_items" table.
+ AllowListAllowlistItemsTable = &schema.Table{
+ Name: "allow_list_allowlist_items",
+ Columns: AllowListAllowlistItemsColumns,
+ PrimaryKey: []*schema.Column{AllowListAllowlistItemsColumns[0], AllowListAllowlistItemsColumns[1]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "allow_list_allowlist_items_allow_list_id",
+ Columns: []*schema.Column{AllowListAllowlistItemsColumns[0]},
+ RefColumns: []*schema.Column{AllowListsColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ {
+ Symbol: "allow_list_allowlist_items_allow_list_item_id",
+ Columns: []*schema.Column{AllowListAllowlistItemsColumns[1]},
+ RefColumns: []*schema.Column{AllowListItemsColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ }
// Tables holds all the tables in the schema.
Tables = []*schema.Table{
AlertsTable,
+ AllowListsTable,
+ AllowListItemsTable,
BouncersTable,
ConfigItemsTable,
DecisionsTable,
@@ -275,6 +363,7 @@ var (
MachinesTable,
MetaTable,
MetricsTable,
+ AllowListAllowlistItemsTable,
}
)
@@ -283,4 +372,6 @@ func init() {
DecisionsTable.ForeignKeys[0].RefTable = AlertsTable
EventsTable.ForeignKeys[0].RefTable = AlertsTable
MetaTable.ForeignKeys[0].RefTable = AlertsTable
+ AllowListAllowlistItemsTable.ForeignKeys[0].RefTable = AllowListsTable
+ AllowListAllowlistItemsTable.ForeignKeys[1].RefTable = AllowListItemsTable
}
diff --git a/pkg/database/ent/mutation.go b/pkg/database/ent/mutation.go
index 5c6596f3db4..f45bd47a5fb 100644
--- a/pkg/database/ent/mutation.go
+++ b/pkg/database/ent/mutation.go
@@ -12,6 +12,8 @@ import (
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
@@ -33,15 +35,17 @@ const (
OpUpdateOne = ent.OpUpdateOne
// Node types.
- TypeAlert = "Alert"
- TypeBouncer = "Bouncer"
- TypeConfigItem = "ConfigItem"
- TypeDecision = "Decision"
- TypeEvent = "Event"
- TypeLock = "Lock"
- TypeMachine = "Machine"
- TypeMeta = "Meta"
- TypeMetric = "Metric"
+ TypeAlert = "Alert"
+ TypeAllowList = "AllowList"
+ TypeAllowListItem = "AllowListItem"
+ TypeBouncer = "Bouncer"
+ TypeConfigItem = "ConfigItem"
+ TypeDecision = "Decision"
+ TypeEvent = "Event"
+ TypeLock = "Lock"
+ TypeMachine = "Machine"
+ TypeMeta = "Meta"
+ TypeMetric = "Metric"
)
// AlertMutation represents an operation that mutates the Alert nodes in the graph.
@@ -2452,6 +2456,1950 @@ func (m *AlertMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown Alert edge %s", name)
}
+// AllowListMutation represents an operation that mutates the AllowList nodes in the graph.
+type AllowListMutation struct {
+ config
+ op Op
+ typ string
+ id *int
+ created_at *time.Time
+ updated_at *time.Time
+ name *string
+ from_console *bool
+ description *string
+ allowlist_id *string
+ clearedFields map[string]struct{}
+ allowlist_items map[int]struct{}
+ removedallowlist_items map[int]struct{}
+ clearedallowlist_items bool
+ done bool
+ oldValue func(context.Context) (*AllowList, error)
+ predicates []predicate.AllowList
+}
+
+var _ ent.Mutation = (*AllowListMutation)(nil)
+
+// allowlistOption allows management of the mutation configuration using functional options.
+type allowlistOption func(*AllowListMutation)
+
+// newAllowListMutation creates a new mutation for the AllowList entity.
+func newAllowListMutation(c config, op Op, opts ...allowlistOption) *AllowListMutation {
+ m := &AllowListMutation{
+ config: c,
+ op: op,
+ typ: TypeAllowList,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withAllowListID sets the ID field of the mutation.
+func withAllowListID(id int) allowlistOption {
+ return func(m *AllowListMutation) {
+ var (
+ err error
+ once sync.Once
+ value *AllowList
+ )
+ m.oldValue = func(ctx context.Context) (*AllowList, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().AllowList.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withAllowList sets the old AllowList of the mutation.
+func withAllowList(node *AllowList) allowlistOption {
+ return func(m *AllowListMutation) {
+ m.oldValue = func(context.Context) (*AllowList, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m AllowListMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m AllowListMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *AllowListMutation) ID() (id int, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or deleted by the mutation.
+func (m *AllowListMutation) IDs(ctx context.Context) ([]int, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().AllowList.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *AllowListMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *AllowListMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the AllowList entity.
+// If the AllowList object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *AllowListMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *AllowListMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *AllowListMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the AllowList entity.
+// If the AllowList object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *AllowListMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetName sets the "name" field.
+func (m *AllowListMutation) SetName(s string) {
+ m.name = &s
+}
+
+// Name returns the value of the "name" field in the mutation.
+func (m *AllowListMutation) Name() (r string, exists bool) {
+ v := m.name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldName returns the old "name" field's value of the AllowList entity.
+// If the AllowList object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListMutation) OldName(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldName: %w", err)
+ }
+ return oldValue.Name, nil
+}
+
+// ResetName resets all changes to the "name" field.
+func (m *AllowListMutation) ResetName() {
+ m.name = nil
+}
+
+// SetFromConsole sets the "from_console" field.
+func (m *AllowListMutation) SetFromConsole(b bool) {
+ m.from_console = &b
+}
+
+// FromConsole returns the value of the "from_console" field in the mutation.
+func (m *AllowListMutation) FromConsole() (r bool, exists bool) {
+ v := m.from_console
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldFromConsole returns the old "from_console" field's value of the AllowList entity.
+// If the AllowList object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListMutation) OldFromConsole(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldFromConsole is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldFromConsole requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldFromConsole: %w", err)
+ }
+ return oldValue.FromConsole, nil
+}
+
+// ResetFromConsole resets all changes to the "from_console" field.
+func (m *AllowListMutation) ResetFromConsole() {
+ m.from_console = nil
+}
+
+// SetDescription sets the "description" field.
+func (m *AllowListMutation) SetDescription(s string) {
+ m.description = &s
+}
+
+// Description returns the value of the "description" field in the mutation.
+func (m *AllowListMutation) Description() (r string, exists bool) {
+ v := m.description
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDescription returns the old "description" field's value of the AllowList entity.
+// If the AllowList object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListMutation) OldDescription(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDescription is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDescription requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDescription: %w", err)
+ }
+ return oldValue.Description, nil
+}
+
+// ClearDescription clears the value of the "description" field.
+func (m *AllowListMutation) ClearDescription() {
+ m.description = nil
+ m.clearedFields[allowlist.FieldDescription] = struct{}{}
+}
+
+// DescriptionCleared returns if the "description" field was cleared in this mutation.
+func (m *AllowListMutation) DescriptionCleared() bool {
+ _, ok := m.clearedFields[allowlist.FieldDescription]
+ return ok
+}
+
+// ResetDescription resets all changes to the "description" field.
+func (m *AllowListMutation) ResetDescription() {
+ m.description = nil
+ delete(m.clearedFields, allowlist.FieldDescription)
+}
+
+// SetAllowlistID sets the "allowlist_id" field.
+func (m *AllowListMutation) SetAllowlistID(s string) {
+ m.allowlist_id = &s
+}
+
+// AllowlistID returns the value of the "allowlist_id" field in the mutation.
+func (m *AllowListMutation) AllowlistID() (r string, exists bool) {
+ v := m.allowlist_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAllowlistID returns the old "allowlist_id" field's value of the AllowList entity.
+// If the AllowList object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListMutation) OldAllowlistID(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAllowlistID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAllowlistID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAllowlistID: %w", err)
+ }
+ return oldValue.AllowlistID, nil
+}
+
+// ClearAllowlistID clears the value of the "allowlist_id" field.
+func (m *AllowListMutation) ClearAllowlistID() {
+ m.allowlist_id = nil
+ m.clearedFields[allowlist.FieldAllowlistID] = struct{}{}
+}
+
+// AllowlistIDCleared returns if the "allowlist_id" field was cleared in this mutation.
+func (m *AllowListMutation) AllowlistIDCleared() bool {
+ _, ok := m.clearedFields[allowlist.FieldAllowlistID]
+ return ok
+}
+
+// ResetAllowlistID resets all changes to the "allowlist_id" field.
+func (m *AllowListMutation) ResetAllowlistID() {
+ m.allowlist_id = nil
+ delete(m.clearedFields, allowlist.FieldAllowlistID)
+}
+
+// AddAllowlistItemIDs adds the "allowlist_items" edge to the AllowListItem entity by ids.
+func (m *AllowListMutation) AddAllowlistItemIDs(ids ...int) {
+ if m.allowlist_items == nil {
+ m.allowlist_items = make(map[int]struct{})
+ }
+ for i := range ids {
+ m.allowlist_items[ids[i]] = struct{}{}
+ }
+}
+
+// ClearAllowlistItems clears the "allowlist_items" edge to the AllowListItem entity.
+func (m *AllowListMutation) ClearAllowlistItems() {
+ m.clearedallowlist_items = true
+}
+
+// AllowlistItemsCleared reports if the "allowlist_items" edge to the AllowListItem entity was cleared.
+func (m *AllowListMutation) AllowlistItemsCleared() bool {
+ return m.clearedallowlist_items
+}
+
+// RemoveAllowlistItemIDs removes the "allowlist_items" edge to the AllowListItem entity by IDs.
+func (m *AllowListMutation) RemoveAllowlistItemIDs(ids ...int) {
+ if m.removedallowlist_items == nil {
+ m.removedallowlist_items = make(map[int]struct{})
+ }
+ for i := range ids {
+ delete(m.allowlist_items, ids[i])
+ m.removedallowlist_items[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedAllowlistItemsIDs returns the removed IDs of the "allowlist_items" edge to the AllowListItem entity.
+func (m *AllowListMutation) RemovedAllowlistItemsIDs() (ids []int) {
+ for id := range m.removedallowlist_items {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// AllowlistItemsIDs returns the "allowlist_items" edge IDs in the mutation.
+func (m *AllowListMutation) AllowlistItemsIDs() (ids []int) {
+ for id := range m.allowlist_items {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetAllowlistItems resets all changes to the "allowlist_items" edge.
+func (m *AllowListMutation) ResetAllowlistItems() {
+ m.allowlist_items = nil
+ m.clearedallowlist_items = false
+ m.removedallowlist_items = nil
+}
+
+// Where appends a list predicates to the AllowListMutation builder.
+func (m *AllowListMutation) Where(ps ...predicate.AllowList) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the AllowListMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *AllowListMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.AllowList, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *AllowListMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *AllowListMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (AllowList).
+func (m *AllowListMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *AllowListMutation) Fields() []string {
+ fields := make([]string, 0, 6)
+ if m.created_at != nil {
+ fields = append(fields, allowlist.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, allowlist.FieldUpdatedAt)
+ }
+ if m.name != nil {
+ fields = append(fields, allowlist.FieldName)
+ }
+ if m.from_console != nil {
+ fields = append(fields, allowlist.FieldFromConsole)
+ }
+ if m.description != nil {
+ fields = append(fields, allowlist.FieldDescription)
+ }
+ if m.allowlist_id != nil {
+ fields = append(fields, allowlist.FieldAllowlistID)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *AllowListMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case allowlist.FieldCreatedAt:
+ return m.CreatedAt()
+ case allowlist.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case allowlist.FieldName:
+ return m.Name()
+ case allowlist.FieldFromConsole:
+ return m.FromConsole()
+ case allowlist.FieldDescription:
+ return m.Description()
+ case allowlist.FieldAllowlistID:
+ return m.AllowlistID()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *AllowListMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case allowlist.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case allowlist.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case allowlist.FieldName:
+ return m.OldName(ctx)
+ case allowlist.FieldFromConsole:
+ return m.OldFromConsole(ctx)
+ case allowlist.FieldDescription:
+ return m.OldDescription(ctx)
+ case allowlist.FieldAllowlistID:
+ return m.OldAllowlistID(ctx)
+ }
+ return nil, fmt.Errorf("unknown AllowList field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AllowListMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case allowlist.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case allowlist.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case allowlist.FieldName:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetName(v)
+ return nil
+ case allowlist.FieldFromConsole:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetFromConsole(v)
+ return nil
+ case allowlist.FieldDescription:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDescription(v)
+ return nil
+ case allowlist.FieldAllowlistID:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAllowlistID(v)
+ return nil
+ }
+ return fmt.Errorf("unknown AllowList field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *AllowListMutation) AddedFields() []string {
+ return nil
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *AllowListMutation) AddedField(name string) (ent.Value, bool) {
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AllowListMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown AllowList numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *AllowListMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(allowlist.FieldDescription) {
+ fields = append(fields, allowlist.FieldDescription)
+ }
+ if m.FieldCleared(allowlist.FieldAllowlistID) {
+ fields = append(fields, allowlist.FieldAllowlistID)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *AllowListMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *AllowListMutation) ClearField(name string) error {
+ switch name {
+ case allowlist.FieldDescription:
+ m.ClearDescription()
+ return nil
+ case allowlist.FieldAllowlistID:
+ m.ClearAllowlistID()
+ return nil
+ }
+ return fmt.Errorf("unknown AllowList nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *AllowListMutation) ResetField(name string) error {
+ switch name {
+ case allowlist.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case allowlist.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case allowlist.FieldName:
+ m.ResetName()
+ return nil
+ case allowlist.FieldFromConsole:
+ m.ResetFromConsole()
+ return nil
+ case allowlist.FieldDescription:
+ m.ResetDescription()
+ return nil
+ case allowlist.FieldAllowlistID:
+ m.ResetAllowlistID()
+ return nil
+ }
+ return fmt.Errorf("unknown AllowList field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *AllowListMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.allowlist_items != nil {
+ edges = append(edges, allowlist.EdgeAllowlistItems)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *AllowListMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case allowlist.EdgeAllowlistItems:
+ ids := make([]ent.Value, 0, len(m.allowlist_items))
+ for id := range m.allowlist_items {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *AllowListMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.removedallowlist_items != nil {
+ edges = append(edges, allowlist.EdgeAllowlistItems)
+ }
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *AllowListMutation) RemovedIDs(name string) []ent.Value {
+ switch name {
+ case allowlist.EdgeAllowlistItems:
+ ids := make([]ent.Value, 0, len(m.removedallowlist_items))
+ for id := range m.removedallowlist_items {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *AllowListMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.clearedallowlist_items {
+ edges = append(edges, allowlist.EdgeAllowlistItems)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *AllowListMutation) EdgeCleared(name string) bool {
+ switch name {
+ case allowlist.EdgeAllowlistItems:
+ return m.clearedallowlist_items
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *AllowListMutation) ClearEdge(name string) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown AllowList unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *AllowListMutation) ResetEdge(name string) error {
+ switch name {
+ case allowlist.EdgeAllowlistItems:
+ m.ResetAllowlistItems()
+ return nil
+ }
+ return fmt.Errorf("unknown AllowList edge %s", name)
+}
+
+// AllowListItemMutation represents an operation that mutates the AllowListItem nodes in the graph.
+type AllowListItemMutation struct {
+ config
+ op Op
+ typ string
+ id *int
+ created_at *time.Time
+ updated_at *time.Time
+ expires_at *time.Time
+ comment *string
+ value *string
+ start_ip *int64
+ addstart_ip *int64
+ end_ip *int64
+ addend_ip *int64
+ start_suffix *int64
+ addstart_suffix *int64
+ end_suffix *int64
+ addend_suffix *int64
+ ip_size *int64
+ addip_size *int64
+ clearedFields map[string]struct{}
+ allowlist map[int]struct{}
+ removedallowlist map[int]struct{}
+ clearedallowlist bool
+ done bool
+ oldValue func(context.Context) (*AllowListItem, error)
+ predicates []predicate.AllowListItem
+}
+
+var _ ent.Mutation = (*AllowListItemMutation)(nil)
+
+// allowlistitemOption allows management of the mutation configuration using functional options.
+type allowlistitemOption func(*AllowListItemMutation)
+
+// newAllowListItemMutation creates a new mutation for the AllowListItem entity.
+func newAllowListItemMutation(c config, op Op, opts ...allowlistitemOption) *AllowListItemMutation {
+ m := &AllowListItemMutation{
+ config: c,
+ op: op,
+ typ: TypeAllowListItem,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withAllowListItemID sets the ID field of the mutation.
+func withAllowListItemID(id int) allowlistitemOption {
+ return func(m *AllowListItemMutation) {
+ var (
+ err error
+ once sync.Once
+ value *AllowListItem
+ )
+ m.oldValue = func(ctx context.Context) (*AllowListItem, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().AllowListItem.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withAllowListItem sets the old AllowListItem of the mutation.
+func withAllowListItem(node *AllowListItem) allowlistitemOption {
+ return func(m *AllowListItemMutation) {
+ m.oldValue = func(context.Context) (*AllowListItem, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m AllowListItemMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m AllowListItemMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *AllowListItemMutation) ID() (id int, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or deleted by the mutation.
+func (m *AllowListItemMutation) IDs(ctx context.Context) ([]int, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().AllowListItem.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *AllowListItemMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *AllowListItemMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the AllowListItem entity.
+// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListItemMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *AllowListItemMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *AllowListItemMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *AllowListItemMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the AllowListItem entity.
+// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListItemMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *AllowListItemMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (m *AllowListItemMutation) SetExpiresAt(t time.Time) {
+ m.expires_at = &t
+}
+
+// ExpiresAt returns the value of the "expires_at" field in the mutation.
+func (m *AllowListItemMutation) ExpiresAt() (r time.Time, exists bool) {
+ v := m.expires_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExpiresAt returns the old "expires_at" field's value of the AllowListItem entity.
+// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListItemMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExpiresAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
+ }
+ return oldValue.ExpiresAt, nil
+}
+
+// ClearExpiresAt clears the value of the "expires_at" field.
+func (m *AllowListItemMutation) ClearExpiresAt() {
+ m.expires_at = nil
+ m.clearedFields[allowlistitem.FieldExpiresAt] = struct{}{}
+}
+
+// ExpiresAtCleared returns if the "expires_at" field was cleared in this mutation.
+func (m *AllowListItemMutation) ExpiresAtCleared() bool {
+ _, ok := m.clearedFields[allowlistitem.FieldExpiresAt]
+ return ok
+}
+
+// ResetExpiresAt resets all changes to the "expires_at" field.
+func (m *AllowListItemMutation) ResetExpiresAt() {
+ m.expires_at = nil
+ delete(m.clearedFields, allowlistitem.FieldExpiresAt)
+}
+
+// SetComment sets the "comment" field.
+func (m *AllowListItemMutation) SetComment(s string) {
+ m.comment = &s
+}
+
+// Comment returns the value of the "comment" field in the mutation.
+func (m *AllowListItemMutation) Comment() (r string, exists bool) {
+ v := m.comment
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldComment returns the old "comment" field's value of the AllowListItem entity.
+// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListItemMutation) OldComment(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldComment is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldComment requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldComment: %w", err)
+ }
+ return oldValue.Comment, nil
+}
+
+// ClearComment clears the value of the "comment" field.
+func (m *AllowListItemMutation) ClearComment() {
+ m.comment = nil
+ m.clearedFields[allowlistitem.FieldComment] = struct{}{}
+}
+
+// CommentCleared returns if the "comment" field was cleared in this mutation.
+func (m *AllowListItemMutation) CommentCleared() bool {
+ _, ok := m.clearedFields[allowlistitem.FieldComment]
+ return ok
+}
+
+// ResetComment resets all changes to the "comment" field.
+func (m *AllowListItemMutation) ResetComment() {
+ m.comment = nil
+ delete(m.clearedFields, allowlistitem.FieldComment)
+}
+
+// SetValue sets the "value" field.
+func (m *AllowListItemMutation) SetValue(s string) {
+ m.value = &s
+}
+
+// Value returns the value of the "value" field in the mutation.
+func (m *AllowListItemMutation) Value() (r string, exists bool) {
+ v := m.value
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldValue returns the old "value" field's value of the AllowListItem entity.
+// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListItemMutation) OldValue(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldValue is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldValue requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldValue: %w", err)
+ }
+ return oldValue.Value, nil
+}
+
+// ResetValue resets all changes to the "value" field.
+func (m *AllowListItemMutation) ResetValue() {
+ m.value = nil
+}
+
+// SetStartIP sets the "start_ip" field.
+func (m *AllowListItemMutation) SetStartIP(i int64) {
+ m.start_ip = &i
+ m.addstart_ip = nil
+}
+
+// StartIP returns the value of the "start_ip" field in the mutation.
+func (m *AllowListItemMutation) StartIP() (r int64, exists bool) {
+ v := m.start_ip
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldStartIP returns the old "start_ip" field's value of the AllowListItem entity.
+// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListItemMutation) OldStartIP(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldStartIP is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldStartIP requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldStartIP: %w", err)
+ }
+ return oldValue.StartIP, nil
+}
+
+// AddStartIP adds i to the "start_ip" field.
+func (m *AllowListItemMutation) AddStartIP(i int64) {
+ if m.addstart_ip != nil {
+ *m.addstart_ip += i
+ } else {
+ m.addstart_ip = &i
+ }
+}
+
+// AddedStartIP returns the value that was added to the "start_ip" field in this mutation.
+func (m *AllowListItemMutation) AddedStartIP() (r int64, exists bool) {
+ v := m.addstart_ip
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearStartIP clears the value of the "start_ip" field.
+func (m *AllowListItemMutation) ClearStartIP() {
+ m.start_ip = nil
+ m.addstart_ip = nil
+ m.clearedFields[allowlistitem.FieldStartIP] = struct{}{}
+}
+
+// StartIPCleared returns if the "start_ip" field was cleared in this mutation.
+func (m *AllowListItemMutation) StartIPCleared() bool {
+ _, ok := m.clearedFields[allowlistitem.FieldStartIP]
+ return ok
+}
+
+// ResetStartIP resets all changes to the "start_ip" field.
+func (m *AllowListItemMutation) ResetStartIP() {
+ m.start_ip = nil
+ m.addstart_ip = nil
+ delete(m.clearedFields, allowlistitem.FieldStartIP)
+}
+
+// SetEndIP sets the "end_ip" field.
+func (m *AllowListItemMutation) SetEndIP(i int64) {
+ m.end_ip = &i
+ m.addend_ip = nil
+}
+
+// EndIP returns the value of the "end_ip" field in the mutation.
+func (m *AllowListItemMutation) EndIP() (r int64, exists bool) {
+ v := m.end_ip
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldEndIP returns the old "end_ip" field's value of the AllowListItem entity.
+// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListItemMutation) OldEndIP(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldEndIP is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldEndIP requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldEndIP: %w", err)
+ }
+ return oldValue.EndIP, nil
+}
+
+// AddEndIP adds i to the "end_ip" field.
+func (m *AllowListItemMutation) AddEndIP(i int64) {
+ if m.addend_ip != nil {
+ *m.addend_ip += i
+ } else {
+ m.addend_ip = &i
+ }
+}
+
+// AddedEndIP returns the value that was added to the "end_ip" field in this mutation.
+func (m *AllowListItemMutation) AddedEndIP() (r int64, exists bool) {
+ v := m.addend_ip
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearEndIP clears the value of the "end_ip" field.
+func (m *AllowListItemMutation) ClearEndIP() {
+ m.end_ip = nil
+ m.addend_ip = nil
+ m.clearedFields[allowlistitem.FieldEndIP] = struct{}{}
+}
+
+// EndIPCleared returns if the "end_ip" field was cleared in this mutation.
+func (m *AllowListItemMutation) EndIPCleared() bool {
+ _, ok := m.clearedFields[allowlistitem.FieldEndIP]
+ return ok
+}
+
+// ResetEndIP resets all changes to the "end_ip" field.
+func (m *AllowListItemMutation) ResetEndIP() {
+ m.end_ip = nil
+ m.addend_ip = nil
+ delete(m.clearedFields, allowlistitem.FieldEndIP)
+}
+
+// SetStartSuffix sets the "start_suffix" field.
+func (m *AllowListItemMutation) SetStartSuffix(i int64) {
+ m.start_suffix = &i
+ m.addstart_suffix = nil
+}
+
+// StartSuffix returns the value of the "start_suffix" field in the mutation.
+func (m *AllowListItemMutation) StartSuffix() (r int64, exists bool) {
+ v := m.start_suffix
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldStartSuffix returns the old "start_suffix" field's value of the AllowListItem entity.
+// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListItemMutation) OldStartSuffix(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldStartSuffix is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldStartSuffix requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldStartSuffix: %w", err)
+ }
+ return oldValue.StartSuffix, nil
+}
+
+// AddStartSuffix adds i to the "start_suffix" field.
+func (m *AllowListItemMutation) AddStartSuffix(i int64) {
+ if m.addstart_suffix != nil {
+ *m.addstart_suffix += i
+ } else {
+ m.addstart_suffix = &i
+ }
+}
+
+// AddedStartSuffix returns the value that was added to the "start_suffix" field in this mutation.
+func (m *AllowListItemMutation) AddedStartSuffix() (r int64, exists bool) {
+ v := m.addstart_suffix
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearStartSuffix clears the value of the "start_suffix" field.
+func (m *AllowListItemMutation) ClearStartSuffix() {
+ m.start_suffix = nil
+ m.addstart_suffix = nil
+ m.clearedFields[allowlistitem.FieldStartSuffix] = struct{}{}
+}
+
+// StartSuffixCleared returns if the "start_suffix" field was cleared in this mutation.
+func (m *AllowListItemMutation) StartSuffixCleared() bool {
+ _, ok := m.clearedFields[allowlistitem.FieldStartSuffix]
+ return ok
+}
+
+// ResetStartSuffix resets all changes to the "start_suffix" field.
+func (m *AllowListItemMutation) ResetStartSuffix() {
+ m.start_suffix = nil
+ m.addstart_suffix = nil
+ delete(m.clearedFields, allowlistitem.FieldStartSuffix)
+}
+
+// SetEndSuffix sets the "end_suffix" field.
+func (m *AllowListItemMutation) SetEndSuffix(i int64) {
+ m.end_suffix = &i
+ m.addend_suffix = nil
+}
+
+// EndSuffix returns the value of the "end_suffix" field in the mutation.
+func (m *AllowListItemMutation) EndSuffix() (r int64, exists bool) {
+ v := m.end_suffix
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldEndSuffix returns the old "end_suffix" field's value of the AllowListItem entity.
+// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListItemMutation) OldEndSuffix(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldEndSuffix is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldEndSuffix requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldEndSuffix: %w", err)
+ }
+ return oldValue.EndSuffix, nil
+}
+
+// AddEndSuffix adds i to the "end_suffix" field.
+func (m *AllowListItemMutation) AddEndSuffix(i int64) {
+ if m.addend_suffix != nil {
+ *m.addend_suffix += i
+ } else {
+ m.addend_suffix = &i
+ }
+}
+
+// AddedEndSuffix returns the value that was added to the "end_suffix" field in this mutation.
+func (m *AllowListItemMutation) AddedEndSuffix() (r int64, exists bool) {
+ v := m.addend_suffix
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearEndSuffix clears the value of the "end_suffix" field.
+func (m *AllowListItemMutation) ClearEndSuffix() {
+ m.end_suffix = nil
+ m.addend_suffix = nil
+ m.clearedFields[allowlistitem.FieldEndSuffix] = struct{}{}
+}
+
+// EndSuffixCleared returns if the "end_suffix" field was cleared in this mutation.
+func (m *AllowListItemMutation) EndSuffixCleared() bool {
+ _, ok := m.clearedFields[allowlistitem.FieldEndSuffix]
+ return ok
+}
+
+// ResetEndSuffix resets all changes to the "end_suffix" field.
+func (m *AllowListItemMutation) ResetEndSuffix() {
+ m.end_suffix = nil
+ m.addend_suffix = nil
+ delete(m.clearedFields, allowlistitem.FieldEndSuffix)
+}
+
+// SetIPSize sets the "ip_size" field.
+func (m *AllowListItemMutation) SetIPSize(i int64) {
+ m.ip_size = &i
+ m.addip_size = nil
+}
+
+// IPSize returns the value of the "ip_size" field in the mutation.
+func (m *AllowListItemMutation) IPSize() (r int64, exists bool) {
+ v := m.ip_size
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIPSize returns the old "ip_size" field's value of the AllowListItem entity.
+// If the AllowListItem object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AllowListItemMutation) OldIPSize(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIPSize is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIPSize requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIPSize: %w", err)
+ }
+ return oldValue.IPSize, nil
+}
+
+// AddIPSize adds i to the "ip_size" field.
+func (m *AllowListItemMutation) AddIPSize(i int64) {
+ if m.addip_size != nil {
+ *m.addip_size += i
+ } else {
+ m.addip_size = &i
+ }
+}
+
+// AddedIPSize returns the value that was added to the "ip_size" field in this mutation.
+func (m *AllowListItemMutation) AddedIPSize() (r int64, exists bool) {
+ v := m.addip_size
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearIPSize clears the value of the "ip_size" field.
+func (m *AllowListItemMutation) ClearIPSize() {
+ m.ip_size = nil
+ m.addip_size = nil
+ m.clearedFields[allowlistitem.FieldIPSize] = struct{}{}
+}
+
+// IPSizeCleared returns if the "ip_size" field was cleared in this mutation.
+func (m *AllowListItemMutation) IPSizeCleared() bool {
+ _, ok := m.clearedFields[allowlistitem.FieldIPSize]
+ return ok
+}
+
+// ResetIPSize resets all changes to the "ip_size" field.
+func (m *AllowListItemMutation) ResetIPSize() {
+ m.ip_size = nil
+ m.addip_size = nil
+ delete(m.clearedFields, allowlistitem.FieldIPSize)
+}
+
+// AddAllowlistIDs adds the "allowlist" edge to the AllowList entity by ids.
+func (m *AllowListItemMutation) AddAllowlistIDs(ids ...int) {
+ if m.allowlist == nil {
+ m.allowlist = make(map[int]struct{})
+ }
+ for i := range ids {
+ m.allowlist[ids[i]] = struct{}{}
+ }
+}
+
+// ClearAllowlist clears the "allowlist" edge to the AllowList entity.
+func (m *AllowListItemMutation) ClearAllowlist() {
+ m.clearedallowlist = true
+}
+
+// AllowlistCleared reports if the "allowlist" edge to the AllowList entity was cleared.
+func (m *AllowListItemMutation) AllowlistCleared() bool {
+ return m.clearedallowlist
+}
+
+// RemoveAllowlistIDs removes the "allowlist" edge to the AllowList entity by IDs.
+func (m *AllowListItemMutation) RemoveAllowlistIDs(ids ...int) {
+ if m.removedallowlist == nil {
+ m.removedallowlist = make(map[int]struct{})
+ }
+ for i := range ids {
+ delete(m.allowlist, ids[i])
+ m.removedallowlist[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedAllowlist returns the removed IDs of the "allowlist" edge to the AllowList entity.
+func (m *AllowListItemMutation) RemovedAllowlistIDs() (ids []int) {
+ for id := range m.removedallowlist {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// AllowlistIDs returns the "allowlist" edge IDs in the mutation.
+func (m *AllowListItemMutation) AllowlistIDs() (ids []int) {
+ for id := range m.allowlist {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetAllowlist resets all changes to the "allowlist" edge.
+func (m *AllowListItemMutation) ResetAllowlist() {
+ m.allowlist = nil
+ m.clearedallowlist = false
+ m.removedallowlist = nil
+}
+
+// Where appends a list predicates to the AllowListItemMutation builder.
+func (m *AllowListItemMutation) Where(ps ...predicate.AllowListItem) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the AllowListItemMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *AllowListItemMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.AllowListItem, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *AllowListItemMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *AllowListItemMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (AllowListItem).
+func (m *AllowListItemMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *AllowListItemMutation) Fields() []string {
+ fields := make([]string, 0, 10)
+ if m.created_at != nil {
+ fields = append(fields, allowlistitem.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, allowlistitem.FieldUpdatedAt)
+ }
+ if m.expires_at != nil {
+ fields = append(fields, allowlistitem.FieldExpiresAt)
+ }
+ if m.comment != nil {
+ fields = append(fields, allowlistitem.FieldComment)
+ }
+ if m.value != nil {
+ fields = append(fields, allowlistitem.FieldValue)
+ }
+ if m.start_ip != nil {
+ fields = append(fields, allowlistitem.FieldStartIP)
+ }
+ if m.end_ip != nil {
+ fields = append(fields, allowlistitem.FieldEndIP)
+ }
+ if m.start_suffix != nil {
+ fields = append(fields, allowlistitem.FieldStartSuffix)
+ }
+ if m.end_suffix != nil {
+ fields = append(fields, allowlistitem.FieldEndSuffix)
+ }
+ if m.ip_size != nil {
+ fields = append(fields, allowlistitem.FieldIPSize)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *AllowListItemMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case allowlistitem.FieldCreatedAt:
+ return m.CreatedAt()
+ case allowlistitem.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case allowlistitem.FieldExpiresAt:
+ return m.ExpiresAt()
+ case allowlistitem.FieldComment:
+ return m.Comment()
+ case allowlistitem.FieldValue:
+ return m.Value()
+ case allowlistitem.FieldStartIP:
+ return m.StartIP()
+ case allowlistitem.FieldEndIP:
+ return m.EndIP()
+ case allowlistitem.FieldStartSuffix:
+ return m.StartSuffix()
+ case allowlistitem.FieldEndSuffix:
+ return m.EndSuffix()
+ case allowlistitem.FieldIPSize:
+ return m.IPSize()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *AllowListItemMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case allowlistitem.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case allowlistitem.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case allowlistitem.FieldExpiresAt:
+ return m.OldExpiresAt(ctx)
+ case allowlistitem.FieldComment:
+ return m.OldComment(ctx)
+ case allowlistitem.FieldValue:
+ return m.OldValue(ctx)
+ case allowlistitem.FieldStartIP:
+ return m.OldStartIP(ctx)
+ case allowlistitem.FieldEndIP:
+ return m.OldEndIP(ctx)
+ case allowlistitem.FieldStartSuffix:
+ return m.OldStartSuffix(ctx)
+ case allowlistitem.FieldEndSuffix:
+ return m.OldEndSuffix(ctx)
+ case allowlistitem.FieldIPSize:
+ return m.OldIPSize(ctx)
+ }
+ return nil, fmt.Errorf("unknown AllowListItem field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AllowListItemMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case allowlistitem.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case allowlistitem.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case allowlistitem.FieldExpiresAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExpiresAt(v)
+ return nil
+ case allowlistitem.FieldComment:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetComment(v)
+ return nil
+ case allowlistitem.FieldValue:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetValue(v)
+ return nil
+ case allowlistitem.FieldStartIP:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetStartIP(v)
+ return nil
+ case allowlistitem.FieldEndIP:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetEndIP(v)
+ return nil
+ case allowlistitem.FieldStartSuffix:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetStartSuffix(v)
+ return nil
+ case allowlistitem.FieldEndSuffix:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetEndSuffix(v)
+ return nil
+ case allowlistitem.FieldIPSize:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIPSize(v)
+ return nil
+ }
+ return fmt.Errorf("unknown AllowListItem field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *AllowListItemMutation) AddedFields() []string {
+ var fields []string
+ if m.addstart_ip != nil {
+ fields = append(fields, allowlistitem.FieldStartIP)
+ }
+ if m.addend_ip != nil {
+ fields = append(fields, allowlistitem.FieldEndIP)
+ }
+ if m.addstart_suffix != nil {
+ fields = append(fields, allowlistitem.FieldStartSuffix)
+ }
+ if m.addend_suffix != nil {
+ fields = append(fields, allowlistitem.FieldEndSuffix)
+ }
+ if m.addip_size != nil {
+ fields = append(fields, allowlistitem.FieldIPSize)
+ }
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *AllowListItemMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case allowlistitem.FieldStartIP:
+ return m.AddedStartIP()
+ case allowlistitem.FieldEndIP:
+ return m.AddedEndIP()
+ case allowlistitem.FieldStartSuffix:
+ return m.AddedStartSuffix()
+ case allowlistitem.FieldEndSuffix:
+ return m.AddedEndSuffix()
+ case allowlistitem.FieldIPSize:
+ return m.AddedIPSize()
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AllowListItemMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case allowlistitem.FieldStartIP:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddStartIP(v)
+ return nil
+ case allowlistitem.FieldEndIP:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddEndIP(v)
+ return nil
+ case allowlistitem.FieldStartSuffix:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddStartSuffix(v)
+ return nil
+ case allowlistitem.FieldEndSuffix:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddEndSuffix(v)
+ return nil
+ case allowlistitem.FieldIPSize:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddIPSize(v)
+ return nil
+ }
+ return fmt.Errorf("unknown AllowListItem numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *AllowListItemMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(allowlistitem.FieldExpiresAt) {
+ fields = append(fields, allowlistitem.FieldExpiresAt)
+ }
+ if m.FieldCleared(allowlistitem.FieldComment) {
+ fields = append(fields, allowlistitem.FieldComment)
+ }
+ if m.FieldCleared(allowlistitem.FieldStartIP) {
+ fields = append(fields, allowlistitem.FieldStartIP)
+ }
+ if m.FieldCleared(allowlistitem.FieldEndIP) {
+ fields = append(fields, allowlistitem.FieldEndIP)
+ }
+ if m.FieldCleared(allowlistitem.FieldStartSuffix) {
+ fields = append(fields, allowlistitem.FieldStartSuffix)
+ }
+ if m.FieldCleared(allowlistitem.FieldEndSuffix) {
+ fields = append(fields, allowlistitem.FieldEndSuffix)
+ }
+ if m.FieldCleared(allowlistitem.FieldIPSize) {
+ fields = append(fields, allowlistitem.FieldIPSize)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *AllowListItemMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *AllowListItemMutation) ClearField(name string) error {
+ switch name {
+ case allowlistitem.FieldExpiresAt:
+ m.ClearExpiresAt()
+ return nil
+ case allowlistitem.FieldComment:
+ m.ClearComment()
+ return nil
+ case allowlistitem.FieldStartIP:
+ m.ClearStartIP()
+ return nil
+ case allowlistitem.FieldEndIP:
+ m.ClearEndIP()
+ return nil
+ case allowlistitem.FieldStartSuffix:
+ m.ClearStartSuffix()
+ return nil
+ case allowlistitem.FieldEndSuffix:
+ m.ClearEndSuffix()
+ return nil
+ case allowlistitem.FieldIPSize:
+ m.ClearIPSize()
+ return nil
+ }
+ return fmt.Errorf("unknown AllowListItem nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *AllowListItemMutation) ResetField(name string) error {
+ switch name {
+ case allowlistitem.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case allowlistitem.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case allowlistitem.FieldExpiresAt:
+ m.ResetExpiresAt()
+ return nil
+ case allowlistitem.FieldComment:
+ m.ResetComment()
+ return nil
+ case allowlistitem.FieldValue:
+ m.ResetValue()
+ return nil
+ case allowlistitem.FieldStartIP:
+ m.ResetStartIP()
+ return nil
+ case allowlistitem.FieldEndIP:
+ m.ResetEndIP()
+ return nil
+ case allowlistitem.FieldStartSuffix:
+ m.ResetStartSuffix()
+ return nil
+ case allowlistitem.FieldEndSuffix:
+ m.ResetEndSuffix()
+ return nil
+ case allowlistitem.FieldIPSize:
+ m.ResetIPSize()
+ return nil
+ }
+ return fmt.Errorf("unknown AllowListItem field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *AllowListItemMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.allowlist != nil {
+ edges = append(edges, allowlistitem.EdgeAllowlist)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *AllowListItemMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case allowlistitem.EdgeAllowlist:
+ ids := make([]ent.Value, 0, len(m.allowlist))
+ for id := range m.allowlist {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *AllowListItemMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.removedallowlist != nil {
+ edges = append(edges, allowlistitem.EdgeAllowlist)
+ }
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *AllowListItemMutation) RemovedIDs(name string) []ent.Value {
+ switch name {
+ case allowlistitem.EdgeAllowlist:
+ ids := make([]ent.Value, 0, len(m.removedallowlist))
+ for id := range m.removedallowlist {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *AllowListItemMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.clearedallowlist {
+ edges = append(edges, allowlistitem.EdgeAllowlist)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *AllowListItemMutation) EdgeCleared(name string) bool {
+ switch name {
+ case allowlistitem.EdgeAllowlist:
+ return m.clearedallowlist
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *AllowListItemMutation) ClearEdge(name string) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown AllowListItem unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *AllowListItemMutation) ResetEdge(name string) error {
+ switch name {
+ case allowlistitem.EdgeAllowlist:
+ m.ResetAllowlist()
+ return nil
+ }
+ return fmt.Errorf("unknown AllowListItem edge %s", name)
+}
+
// BouncerMutation represents an operation that mutates the Bouncer nodes in the graph.
type BouncerMutation struct {
config
@@ -2471,6 +4419,7 @@ type BouncerMutation struct {
osname *string
osversion *string
featureflags *string
+ auto_created *bool
clearedFields map[string]struct{}
done bool
oldValue func(context.Context) (*Bouncer, error)
@@ -3134,6 +5083,42 @@ func (m *BouncerMutation) ResetFeatureflags() {
delete(m.clearedFields, bouncer.FieldFeatureflags)
}
+// SetAutoCreated sets the "auto_created" field.
+func (m *BouncerMutation) SetAutoCreated(b bool) {
+ m.auto_created = &b
+}
+
+// AutoCreated returns the value of the "auto_created" field in the mutation.
+func (m *BouncerMutation) AutoCreated() (r bool, exists bool) {
+ v := m.auto_created
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAutoCreated returns the old "auto_created" field's value of the Bouncer entity.
+// If the Bouncer object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *BouncerMutation) OldAutoCreated(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAutoCreated is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAutoCreated requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAutoCreated: %w", err)
+ }
+ return oldValue.AutoCreated, nil
+}
+
+// ResetAutoCreated resets all changes to the "auto_created" field.
+func (m *BouncerMutation) ResetAutoCreated() {
+ m.auto_created = nil
+}
+
// Where appends a list predicates to the BouncerMutation builder.
func (m *BouncerMutation) Where(ps ...predicate.Bouncer) {
m.predicates = append(m.predicates, ps...)
@@ -3168,7 +5153,7 @@ func (m *BouncerMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *BouncerMutation) Fields() []string {
- fields := make([]string, 0, 13)
+ fields := make([]string, 0, 14)
if m.created_at != nil {
fields = append(fields, bouncer.FieldCreatedAt)
}
@@ -3208,6 +5193,9 @@ func (m *BouncerMutation) Fields() []string {
if m.featureflags != nil {
fields = append(fields, bouncer.FieldFeatureflags)
}
+ if m.auto_created != nil {
+ fields = append(fields, bouncer.FieldAutoCreated)
+ }
return fields
}
@@ -3242,6 +5230,8 @@ func (m *BouncerMutation) Field(name string) (ent.Value, bool) {
return m.Osversion()
case bouncer.FieldFeatureflags:
return m.Featureflags()
+ case bouncer.FieldAutoCreated:
+ return m.AutoCreated()
}
return nil, false
}
@@ -3277,6 +5267,8 @@ func (m *BouncerMutation) OldField(ctx context.Context, name string) (ent.Value,
return m.OldOsversion(ctx)
case bouncer.FieldFeatureflags:
return m.OldFeatureflags(ctx)
+ case bouncer.FieldAutoCreated:
+ return m.OldAutoCreated(ctx)
}
return nil, fmt.Errorf("unknown Bouncer field %s", name)
}
@@ -3377,6 +5369,13 @@ func (m *BouncerMutation) SetField(name string, value ent.Value) error {
}
m.SetFeatureflags(v)
return nil
+ case bouncer.FieldAutoCreated:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAutoCreated(v)
+ return nil
}
return fmt.Errorf("unknown Bouncer field %s", name)
}
@@ -3510,6 +5509,9 @@ func (m *BouncerMutation) ResetField(name string) error {
case bouncer.FieldFeatureflags:
m.ResetFeatureflags()
return nil
+ case bouncer.FieldAutoCreated:
+ m.ResetAutoCreated()
+ return nil
}
return fmt.Errorf("unknown Bouncer field %s", name)
}
diff --git a/pkg/database/ent/predicate/predicate.go b/pkg/database/ent/predicate/predicate.go
index 8ad03e2fc48..97e574aa167 100644
--- a/pkg/database/ent/predicate/predicate.go
+++ b/pkg/database/ent/predicate/predicate.go
@@ -9,6 +9,12 @@ import (
// Alert is the predicate function for alert builders.
type Alert func(*sql.Selector)
+// AllowList is the predicate function for allowlist builders.
+type AllowList func(*sql.Selector)
+
+// AllowListItem is the predicate function for allowlistitem builders.
+type AllowListItem func(*sql.Selector)
+
// Bouncer is the predicate function for bouncer builders.
type Bouncer func(*sql.Selector)
diff --git a/pkg/database/ent/runtime.go b/pkg/database/ent/runtime.go
index 15413490633..989e67fda7d 100644
--- a/pkg/database/ent/runtime.go
+++ b/pkg/database/ent/runtime.go
@@ -6,6 +6,8 @@ import (
"time"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlist"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
@@ -56,6 +58,30 @@ func init() {
alertDescSimulated := alertFields[21].Descriptor()
// alert.DefaultSimulated holds the default value on creation for the simulated field.
alert.DefaultSimulated = alertDescSimulated.Default.(bool)
+ allowlistFields := schema.AllowList{}.Fields()
+ _ = allowlistFields
+ // allowlistDescCreatedAt is the schema descriptor for created_at field.
+ allowlistDescCreatedAt := allowlistFields[0].Descriptor()
+ // allowlist.DefaultCreatedAt holds the default value on creation for the created_at field.
+ allowlist.DefaultCreatedAt = allowlistDescCreatedAt.Default.(func() time.Time)
+ // allowlistDescUpdatedAt is the schema descriptor for updated_at field.
+ allowlistDescUpdatedAt := allowlistFields[1].Descriptor()
+ // allowlist.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ allowlist.DefaultUpdatedAt = allowlistDescUpdatedAt.Default.(func() time.Time)
+ // allowlist.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ allowlist.UpdateDefaultUpdatedAt = allowlistDescUpdatedAt.UpdateDefault.(func() time.Time)
+ allowlistitemFields := schema.AllowListItem{}.Fields()
+ _ = allowlistitemFields
+ // allowlistitemDescCreatedAt is the schema descriptor for created_at field.
+ allowlistitemDescCreatedAt := allowlistitemFields[0].Descriptor()
+ // allowlistitem.DefaultCreatedAt holds the default value on creation for the created_at field.
+ allowlistitem.DefaultCreatedAt = allowlistitemDescCreatedAt.Default.(func() time.Time)
+ // allowlistitemDescUpdatedAt is the schema descriptor for updated_at field.
+ allowlistitemDescUpdatedAt := allowlistitemFields[1].Descriptor()
+ // allowlistitem.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ allowlistitem.DefaultUpdatedAt = allowlistitemDescUpdatedAt.Default.(func() time.Time)
+ // allowlistitem.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ allowlistitem.UpdateDefaultUpdatedAt = allowlistitemDescUpdatedAt.UpdateDefault.(func() time.Time)
bouncerFields := schema.Bouncer{}.Fields()
_ = bouncerFields
// bouncerDescCreatedAt is the schema descriptor for created_at field.
@@ -76,6 +102,10 @@ func init() {
bouncerDescAuthType := bouncerFields[9].Descriptor()
// bouncer.DefaultAuthType holds the default value on creation for the auth_type field.
bouncer.DefaultAuthType = bouncerDescAuthType.Default.(string)
+ // bouncerDescAutoCreated is the schema descriptor for auto_created field.
+ bouncerDescAutoCreated := bouncerFields[13].Descriptor()
+ // bouncer.DefaultAutoCreated holds the default value on creation for the auto_created field.
+ bouncer.DefaultAutoCreated = bouncerDescAutoCreated.Default.(bool)
configitemFields := schema.ConfigItem{}.Fields()
_ = configitemFields
// configitemDescCreatedAt is the schema descriptor for created_at field.
diff --git a/pkg/database/ent/schema/allowlist.go b/pkg/database/ent/schema/allowlist.go
new file mode 100644
index 00000000000..1df878ac87a
--- /dev/null
+++ b/pkg/database/ent/schema/allowlist.go
@@ -0,0 +1,44 @@
+package schema
+
+import (
+ "entgo.io/ent"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+
+ "github.com/crowdsecurity/crowdsec/pkg/types"
+)
+
+// AllowList holds the schema definition for the AllowList entity.
+type AllowList struct {
+ ent.Schema
+}
+
+// Fields of the AllowList.
+func (AllowList) Fields() []ent.Field {
+ return []ent.Field{
+ field.Time("created_at").
+ Default(types.UtcNow).
+ Immutable(),
+ field.Time("updated_at").
+ Default(types.UtcNow).
+ UpdateDefault(types.UtcNow),
+ field.String("name").Immutable(),
+ field.Bool("from_console"),
+ field.String("description").Optional().Immutable(),
+ field.String("allowlist_id").Optional().Immutable(),
+ }
+}
+
+func (AllowList) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("id").Unique(),
+ index.Fields("name").Unique(),
+ }
+}
+
+func (AllowList) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.To("allowlist_items", AllowListItem.Type),
+ }
+}
diff --git a/pkg/database/ent/schema/allowlist_item.go b/pkg/database/ent/schema/allowlist_item.go
new file mode 100644
index 00000000000..907b8a44729
--- /dev/null
+++ b/pkg/database/ent/schema/allowlist_item.go
@@ -0,0 +1,51 @@
+package schema
+
+import (
+ "entgo.io/ent"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+
+ "github.com/crowdsecurity/crowdsec/pkg/types"
+)
+
+// AllowListItem holds the schema definition for the AllowListItem entity.
+type AllowListItem struct {
+ ent.Schema
+}
+
+// Fields of the AllowListItem.
+func (AllowListItem) Fields() []ent.Field {
+ return []ent.Field{
+ field.Time("created_at").
+ Default(types.UtcNow).
+ Immutable(),
+ field.Time("updated_at").
+ Default(types.UtcNow).
+ UpdateDefault(types.UtcNow),
+ field.Time("expires_at").
+ Optional(),
+ field.String("comment").Optional().Immutable(),
+ field.String("value").Immutable(), // For textual representation of the IP/range
+ // Use the same fields as the decision table
+ field.Int64("start_ip").Optional().Immutable(),
+ field.Int64("end_ip").Optional().Immutable(),
+ field.Int64("start_suffix").Optional().Immutable(),
+ field.Int64("end_suffix").Optional().Immutable(),
+ field.Int64("ip_size").Optional().Immutable(),
+ }
+}
+
+func (AllowListItem) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("id"),
+ index.Fields("start_ip", "end_ip"),
+ }
+}
+
+func (AllowListItem) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("allowlist", AllowList.Type).
+ Ref("allowlist_items"),
+ }
+}
diff --git a/pkg/database/ent/schema/bouncer.go b/pkg/database/ent/schema/bouncer.go
index 599c4c404fc..c176bf0f766 100644
--- a/pkg/database/ent/schema/bouncer.go
+++ b/pkg/database/ent/schema/bouncer.go
@@ -33,6 +33,8 @@ func (Bouncer) Fields() []ent.Field {
field.String("osname").Optional(),
field.String("osversion").Optional(),
field.String("featureflags").Optional(),
+ // Old auto-created TLS bouncers will have a wrong value for this field
+ field.Bool("auto_created").StructTag(`json:"auto_created"`).Default(false).Immutable(),
}
}
diff --git a/pkg/database/ent/tx.go b/pkg/database/ent/tx.go
index bf8221ce4a5..69983beebc5 100644
--- a/pkg/database/ent/tx.go
+++ b/pkg/database/ent/tx.go
@@ -14,6 +14,10 @@ type Tx struct {
config
// Alert is the client for interacting with the Alert builders.
Alert *AlertClient
+ // AllowList is the client for interacting with the AllowList builders.
+ AllowList *AllowListClient
+ // AllowListItem is the client for interacting with the AllowListItem builders.
+ AllowListItem *AllowListItemClient
// Bouncer is the client for interacting with the Bouncer builders.
Bouncer *BouncerClient
// ConfigItem is the client for interacting with the ConfigItem builders.
@@ -162,6 +166,8 @@ func (tx *Tx) Client() *Client {
func (tx *Tx) init() {
tx.Alert = NewAlertClient(tx.config)
+ tx.AllowList = NewAllowListClient(tx.config)
+ tx.AllowListItem = NewAllowListItemClient(tx.config)
tx.Bouncer = NewBouncerClient(tx.config)
tx.ConfigItem = NewConfigItemClient(tx.config)
tx.Decision = NewDecisionClient(tx.config)
diff --git a/pkg/database/flush.go b/pkg/database/flush.go
index 8f646ddc961..e66e47a7635 100644
--- a/pkg/database/flush.go
+++ b/pkg/database/flush.go
@@ -13,6 +13,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
+ "github.com/crowdsecurity/crowdsec/pkg/database/ent/allowlistitem"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/event"
@@ -115,6 +116,13 @@ func (c *Client) StartFlushScheduler(ctx context.Context, config *csconfig.Flush
metricsJob.SingletonMode()
+ allowlistsJob, err := scheduler.Every(flushInterval).Do(c.flushAllowlists, ctx)
+ if err != nil {
+ return nil, fmt.Errorf("while starting FlushAllowlists scheduler: %w", err)
+ }
+
+ allowlistsJob.SingletonMode()
+
scheduler.StartAsync()
return scheduler, nil
@@ -309,3 +317,17 @@ func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) e
return nil
}
+
+func (c *Client) flushAllowlists(ctx context.Context) {
+ deleted, err := c.Ent.AllowListItem.Delete().Where(
+ allowlistitem.ExpiresAtLTE(time.Now().UTC()),
+ ).Exec(ctx)
+ if err != nil {
+ c.Log.Errorf("while flushing allowlists: %s", err)
+ return
+ }
+
+ if deleted > 0 {
+ c.Log.Debugf("flushed %d allowlists", deleted)
+ }
+}
diff --git a/pkg/exprhelpers/crowdsec_cti.go b/pkg/exprhelpers/crowdsec_cti.go
index ccd67b27a49..9b9eac4b95c 100644
--- a/pkg/exprhelpers/crowdsec_cti.go
+++ b/pkg/exprhelpers/crowdsec_cti.go
@@ -12,16 +12,20 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/types"
)
-var CTIUrl = "https://cti.api.crowdsec.net"
-var CTIUrlSuffix = "/v2/smoke/"
-var CTIApiKey = ""
+var (
+ CTIUrl = "https://cti.api.crowdsec.net"
+ CTIUrlSuffix = "/v2/smoke/"
+ CTIApiKey = ""
+)
// this is set for non-recoverable errors, such as 403 when querying API or empty API key
var CTIApiEnabled = false
// when hitting quotas or auth errors, we temporarily disable the API
-var CTIBackOffUntil time.Time
-var CTIBackOffDuration = 5 * time.Minute
+var (
+ CTIBackOffUntil time.Time
+ CTIBackOffDuration = 5 * time.Minute
+)
var ctiClient *cticlient.CrowdsecCTIClient
@@ -62,8 +66,10 @@ func ShutdownCrowdsecCTI() {
}
// Cache for responses
-var CTICache gcache.Cache
-var CacheExpiration time.Duration
+var (
+ CTICache gcache.Cache
+ CacheExpiration time.Duration
+)
func CrowdsecCTIInitCache(size int, ttl time.Duration) {
CTICache = gcache.New(size).LRU().Build()
diff --git a/pkg/exprhelpers/geoip.go b/pkg/exprhelpers/geoip.go
index fb0c344d884..6d8813dc0ad 100644
--- a/pkg/exprhelpers/geoip.go
+++ b/pkg/exprhelpers/geoip.go
@@ -14,7 +14,6 @@ func GeoIPEnrich(params ...any) (any, error) {
parsedIP := net.ParseIP(ip)
city, err := geoIPCityReader.City(parsedIP)
-
if err != nil {
return nil, err
}
@@ -31,7 +30,6 @@ func GeoIPASNEnrich(params ...any) (any, error) {
parsedIP := net.ParseIP(ip)
asn, err := geoIPASNReader.ASN(parsedIP)
-
if err != nil {
return nil, err
}
@@ -50,7 +48,6 @@ func GeoIPRangeEnrich(params ...any) (any, error) {
parsedIP := net.ParseIP(ip)
rangeIP, ok, err := geoIPRangeReader.LookupNetwork(parsedIP, &dummy)
-
if err != nil {
return nil, err
}
diff --git a/pkg/fflag/crowdsec.go b/pkg/fflag/crowdsec.go
index d42d6a05ef6..ea397bfe5bc 100644
--- a/pkg/fflag/crowdsec.go
+++ b/pkg/fflag/crowdsec.go
@@ -2,12 +2,14 @@ package fflag
var Crowdsec = FeatureRegister{EnvPrefix: "CROWDSEC_FEATURE_"}
-var CscliSetup = &Feature{Name: "cscli_setup", Description: "Enable cscli setup command (service detection)"}
-var DisableHttpRetryBackoff = &Feature{Name: "disable_http_retry_backoff", Description: "Disable http retry backoff"}
-var ChunkedDecisionsStream = &Feature{Name: "chunked_decisions_stream", Description: "Enable chunked decisions stream"}
-var PapiClient = &Feature{Name: "papi_client", Description: "Enable Polling API client", State: DeprecatedState}
-var Re2GrokSupport = &Feature{Name: "re2_grok_support", Description: "Enable RE2 support for GROK patterns"}
-var Re2RegexpInfileSupport = &Feature{Name: "re2_regexp_in_file_support", Description: "Enable RE2 support for RegexpInFile expr helper"}
+var (
+ CscliSetup = &Feature{Name: "cscli_setup", Description: "Enable cscli setup command (service detection)"}
+ DisableHttpRetryBackoff = &Feature{Name: "disable_http_retry_backoff", Description: "Disable http retry backoff"}
+ ChunkedDecisionsStream = &Feature{Name: "chunked_decisions_stream", Description: "Enable chunked decisions stream"}
+ PapiClient = &Feature{Name: "papi_client", Description: "Enable Polling API client", State: DeprecatedState}
+ Re2GrokSupport = &Feature{Name: "re2_grok_support", Description: "Enable RE2 support for GROK patterns"}
+ Re2RegexpInfileSupport = &Feature{Name: "re2_regexp_in_file_support", Description: "Enable RE2 support for RegexpInFile expr helper"}
+)
func RegisterAllFeatures() error {
err := Crowdsec.RegisterFeature(CscliSetup)
diff --git a/pkg/leakybucket/blackhole.go b/pkg/leakybucket/blackhole.go
index b12f169acd9..bda2e7c9ed1 100644
--- a/pkg/leakybucket/blackhole.go
+++ b/pkg/leakybucket/blackhole.go
@@ -49,7 +49,6 @@ func (bl *Blackhole) OnBucketOverflow(bucketFactory *BucketFactory) func(*Leaky,
tmp = append(tmp, element)
} else {
leaky.logger.Debugf("%s left blackhole %s ago", element.key, leaky.Ovflw_ts.Sub(element.expiration))
-
}
}
bl.hiddenKeys = tmp
@@ -64,5 +63,4 @@ func (bl *Blackhole) OnBucketOverflow(bucketFactory *BucketFactory) func(*Leaky,
leaky.logger.Debugf("Adding overflow to blackhole (%s)", leaky.First_ts)
return alert, queue
}
-
}
diff --git a/pkg/leakybucket/bucket.go b/pkg/leakybucket/bucket.go
index e981551af8f..bc81a505925 100644
--- a/pkg/leakybucket/bucket.go
+++ b/pkg/leakybucket/bucket.go
@@ -204,7 +204,6 @@ func FromFactory(bucketFactory BucketFactory) *Leaky {
/* for now mimic a leak routine */
//LeakRoutine us the life of a bucket. It dies when the bucket underflows or overflows
func LeakRoutine(leaky *Leaky) error {
-
var (
durationTickerChan = make(<-chan time.Time)
durationTicker *time.Ticker
diff --git a/pkg/leakybucket/buckets.go b/pkg/leakybucket/buckets.go
index cfe8d7c302e..72948da1ad7 100644
--- a/pkg/leakybucket/buckets.go
+++ b/pkg/leakybucket/buckets.go
@@ -25,5 +25,4 @@ func NewBuckets() *Buckets {
func GetKey(bucketCfg BucketFactory, stackkey string) string {
return fmt.Sprintf("%x", sha1.Sum([]byte(bucketCfg.Filter+stackkey+bucketCfg.Name)))
-
}
diff --git a/pkg/leakybucket/conditional.go b/pkg/leakybucket/conditional.go
index a203a639743..b3a84b07c21 100644
--- a/pkg/leakybucket/conditional.go
+++ b/pkg/leakybucket/conditional.go
@@ -11,8 +11,10 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/types"
)
-var conditionalExprCache map[string]vm.Program
-var conditionalExprCacheLock sync.Mutex
+var (
+ conditionalExprCache map[string]vm.Program
+ conditionalExprCacheLock sync.Mutex
+)
type ConditionalOverflow struct {
ConditionalFilter string
diff --git a/pkg/leakybucket/manager_load_test.go b/pkg/leakybucket/manager_load_test.go
index 513f11ff373..9d207da164e 100644
--- a/pkg/leakybucket/manager_load_test.go
+++ b/pkg/leakybucket/manager_load_test.go
@@ -51,93 +51,86 @@ func TestBadBucketsConfig(t *testing.T) {
}
func TestLeakyBucketsConfig(t *testing.T) {
- var CfgTests = []cfgTest{
- //leaky with bad capacity
+ CfgTests := []cfgTest{
+ // leaky with bad capacity
{BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 0}, false, false},
- //leaky with empty leakspeed
+ // leaky with empty leakspeed
{BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 1}, false, false},
- //leaky with missing filter
+ // leaky with missing filter
{BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 1, LeakSpeed: "1s"}, false, true},
- //leaky with invalid leakspeed
+ // leaky with invalid leakspeed
{BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 1, LeakSpeed: "abs", Filter: "true"}, false, false},
- //leaky with valid filter
+ // leaky with valid filter
{BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 1, LeakSpeed: "1s", Filter: "true"}, true, true},
- //leaky with invalid filter
+ // leaky with invalid filter
{BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 1, LeakSpeed: "1s", Filter: "xu"}, false, true},
- //leaky with valid filter
+ // leaky with valid filter
{BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 1, LeakSpeed: "1s", Filter: "true"}, true, true},
- //leaky with bad overflow filter
+ // leaky with bad overflow filter
{BucketFactory{Name: "test", Description: "test1", Type: "leaky", Capacity: 1, LeakSpeed: "1s", Filter: "true", OverflowFilter: "xu"}, false, true},
}
if err := runTest(CfgTests); err != nil {
t.Fatalf("%s", err)
}
-
}
func TestBlackholeConfig(t *testing.T) {
- var CfgTests = []cfgTest{
- //basic bh
+ CfgTests := []cfgTest{
+ // basic bh
{BucketFactory{Name: "test", Description: "test1", Type: "trigger", Filter: "true", Blackhole: "15s"}, true, true},
- //bad bh
+ // bad bh
{BucketFactory{Name: "test", Description: "test1", Type: "trigger", Filter: "true", Blackhole: "abc"}, false, true},
}
if err := runTest(CfgTests); err != nil {
t.Fatalf("%s", err)
}
-
}
func TestTriggerBucketsConfig(t *testing.T) {
- var CfgTests = []cfgTest{
- //basic valid counter
+ CfgTests := []cfgTest{
+ // basic valid counter
{BucketFactory{Name: "test", Description: "test1", Type: "trigger", Filter: "true"}, true, true},
}
if err := runTest(CfgTests); err != nil {
t.Fatalf("%s", err)
}
-
}
func TestCounterBucketsConfig(t *testing.T) {
- var CfgTests = []cfgTest{
-
- //basic valid counter
+ CfgTests := []cfgTest{
+ // basic valid counter
{BucketFactory{Name: "test", Description: "test1", Type: "counter", Capacity: -1, Duration: "5s", Filter: "true"}, true, true},
- //missing duration
+ // missing duration
{BucketFactory{Name: "test", Description: "test1", Type: "counter", Capacity: -1, Filter: "true"}, false, false},
- //bad duration
+ // bad duration
{BucketFactory{Name: "test", Description: "test1", Type: "counter", Capacity: -1, Duration: "abc", Filter: "true"}, false, false},
- //capacity must be -1
+ // capacity must be -1
{BucketFactory{Name: "test", Description: "test1", Type: "counter", Capacity: 0, Duration: "5s", Filter: "true"}, false, false},
}
if err := runTest(CfgTests); err != nil {
t.Fatalf("%s", err)
}
-
}
func TestBayesianBucketsConfig(t *testing.T) {
- var CfgTests = []cfgTest{
-
- //basic valid counter
+ CfgTests := []cfgTest{
+ // basic valid counter
{BucketFactory{Name: "test", Description: "test1", Type: "bayesian", Capacity: -1, Filter: "true", BayesianPrior: 0.5, BayesianThreshold: 0.5, BayesianConditions: []RawBayesianCondition{{ConditionalFilterName: "true", ProbGivenEvil: 0.5, ProbGivenBenign: 0.5}}}, true, true},
- //bad capacity
+ // bad capacity
{BucketFactory{Name: "test", Description: "test1", Type: "bayesian", Capacity: 1, Filter: "true", BayesianPrior: 0.5, BayesianThreshold: 0.5, BayesianConditions: []RawBayesianCondition{{ConditionalFilterName: "true", ProbGivenEvil: 0.5, ProbGivenBenign: 0.5}}}, false, false},
- //missing prior
+ // missing prior
{BucketFactory{Name: "test", Description: "test1", Type: "bayesian", Capacity: -1, Filter: "true", BayesianThreshold: 0.5, BayesianConditions: []RawBayesianCondition{{ConditionalFilterName: "true", ProbGivenEvil: 0.5, ProbGivenBenign: 0.5}}}, false, false},
- //missing threshold
+ // missing threshold
{BucketFactory{Name: "test", Description: "test1", Type: "bayesian", Capacity: -1, Filter: "true", BayesianPrior: 0.5, BayesianConditions: []RawBayesianCondition{{ConditionalFilterName: "true", ProbGivenEvil: 0.5, ProbGivenBenign: 0.5}}}, false, false},
- //bad prior
+ // bad prior
{BucketFactory{Name: "test", Description: "test1", Type: "bayesian", Capacity: -1, Filter: "true", BayesianPrior: 1.5, BayesianThreshold: 0.5, BayesianConditions: []RawBayesianCondition{{ConditionalFilterName: "true", ProbGivenEvil: 0.5, ProbGivenBenign: 0.5}}}, false, false},
- //bad threshold
+ // bad threshold
{BucketFactory{Name: "test", Description: "test1", Type: "bayesian", Capacity: -1, Filter: "true", BayesianPrior: 0.5, BayesianThreshold: 1.5, BayesianConditions: []RawBayesianCondition{{ConditionalFilterName: "true", ProbGivenEvil: 0.5, ProbGivenBenign: 0.5}}}, false, false},
}
if err := runTest(CfgTests); err != nil {
t.Fatalf("%s", err)
}
-
}
diff --git a/pkg/leakybucket/manager_run.go b/pkg/leakybucket/manager_run.go
index 2858d8b5635..e6712e6e47e 100644
--- a/pkg/leakybucket/manager_run.go
+++ b/pkg/leakybucket/manager_run.go
@@ -17,9 +17,11 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/types"
)
-var serialized map[string]Leaky
-var BucketPourCache map[string][]types.Event
-var BucketPourTrack bool
+var (
+ serialized map[string]Leaky
+ BucketPourCache map[string][]types.Event
+ BucketPourTrack bool
+)
/*
The leaky routines lifecycle are based on "real" time.
@@ -243,7 +245,6 @@ func PourItemToBucket(bucket *Leaky, holder BucketFactory, buckets *Buckets, par
}
func LoadOrStoreBucketFromHolder(partitionKey string, buckets *Buckets, holder BucketFactory, expectMode int) (*Leaky, error) {
-
biface, ok := buckets.Bucket_map.Load(partitionKey)
/* the bucket doesn't exist, create it !*/
@@ -283,9 +284,7 @@ func LoadOrStoreBucketFromHolder(partitionKey string, buckets *Buckets, holder B
var orderEvent map[string]*sync.WaitGroup
func PourItemToHolders(parsed types.Event, holders []BucketFactory, buckets *Buckets) (bool, error) {
- var (
- ok, condition, poured bool
- )
+ var ok, condition, poured bool
if BucketPourTrack {
if BucketPourCache == nil {
diff --git a/pkg/leakybucket/overflows.go b/pkg/leakybucket/overflows.go
index 39b0e6a0ec4..126bcd05685 100644
--- a/pkg/leakybucket/overflows.go
+++ b/pkg/leakybucket/overflows.go
@@ -198,22 +198,24 @@ func eventSources(evt types.Event, leaky *Leaky) (map[string]models.Source, erro
func EventsFromQueue(queue *types.Queue) []*models.Event {
events := []*models.Event{}
- for _, evt := range queue.Queue {
- if evt.Meta == nil {
+ qEvents := queue.GetQueue()
+
+ for idx := range qEvents {
+ if qEvents[idx].Meta == nil {
continue
}
meta := models.Meta{}
// we want consistence
- skeys := make([]string, 0, len(evt.Meta))
- for k := range evt.Meta {
+ skeys := make([]string, 0, len(qEvents[idx].Meta))
+ for k := range qEvents[idx].Meta {
skeys = append(skeys, k)
}
sort.Strings(skeys)
for _, k := range skeys {
- v := evt.Meta[k]
+ v := qEvents[idx].Meta[k]
subMeta := models.MetaItems0{Key: k, Value: v}
meta = append(meta, &subMeta)
}
@@ -223,15 +225,15 @@ func EventsFromQueue(queue *types.Queue) []*models.Event {
Meta: meta,
}
// either MarshaledTime is present and is extracted from log
- if evt.MarshaledTime != "" {
- tmpTimeStamp := evt.MarshaledTime
+ if qEvents[idx].MarshaledTime != "" {
+ tmpTimeStamp := qEvents[idx].MarshaledTime
ovflwEvent.Timestamp = &tmpTimeStamp
- } else if !evt.Time.IsZero() { // or .Time has been set during parse as time.Now().UTC()
+ } else if !qEvents[idx].Time.IsZero() { // or .Time has been set during parse as time.Now().UTC()
ovflwEvent.Timestamp = new(string)
- raw, err := evt.Time.MarshalText()
+ raw, err := qEvents[idx].Time.MarshalText()
if err != nil {
- log.Warningf("while serializing time '%s' : %s", evt.Time.String(), err)
+ log.Warningf("while serializing time '%s' : %s", qEvents[idx].Time.String(), err)
} else {
*ovflwEvent.Timestamp = string(raw)
}
@@ -253,8 +255,9 @@ func alertFormatSource(leaky *Leaky, queue *types.Queue) (map[string]models.Sour
log.Debugf("Formatting (%s) - scope Info : scope_type:%s / scope_filter:%s", leaky.Name, leaky.scopeType.Scope, leaky.scopeType.Filter)
- for _, evt := range queue.Queue {
- srcs, err := SourceFromEvent(evt, leaky)
+ qEvents := queue.GetQueue()
+ for idx := range qEvents {
+ srcs, err := SourceFromEvent(qEvents[idx], leaky)
if err != nil {
return nil, "", fmt.Errorf("while extracting scope from bucket %s: %w", leaky.Name, err)
}
diff --git a/pkg/leakybucket/processor.go b/pkg/leakybucket/processor.go
index 81af3000c1c..dc5330a612e 100644
--- a/pkg/leakybucket/processor.go
+++ b/pkg/leakybucket/processor.go
@@ -10,8 +10,7 @@ type Processor interface {
AfterBucketPour(Bucket *BucketFactory) func(types.Event, *Leaky) *types.Event
}
-type DumbProcessor struct {
-}
+type DumbProcessor struct{}
func (d *DumbProcessor) OnBucketInit(bucketFactory *BucketFactory) error {
return nil
diff --git a/pkg/leakybucket/reset_filter.go b/pkg/leakybucket/reset_filter.go
index 452ccc085b1..3b9b876aff4 100644
--- a/pkg/leakybucket/reset_filter.go
+++ b/pkg/leakybucket/reset_filter.go
@@ -23,10 +23,12 @@ type CancelOnFilter struct {
Debug bool
}
-var cancelExprCacheLock sync.Mutex
-var cancelExprCache map[string]struct {
- CancelOnFilter *vm.Program
-}
+var (
+ cancelExprCacheLock sync.Mutex
+ cancelExprCache map[string]struct {
+ CancelOnFilter *vm.Program
+ }
+)
func (u *CancelOnFilter) OnBucketPour(bucketFactory *BucketFactory) func(types.Event, *Leaky) *types.Event {
return func(msg types.Event, leaky *Leaky) *types.Event {
diff --git a/pkg/leakybucket/uniq.go b/pkg/leakybucket/uniq.go
index 0cc0583390b..3a4683ae309 100644
--- a/pkg/leakybucket/uniq.go
+++ b/pkg/leakybucket/uniq.go
@@ -16,8 +16,10 @@ import (
// on overflow
// on leak
-var uniqExprCache map[string]vm.Program
-var uniqExprCacheLock sync.Mutex
+var (
+ uniqExprCache map[string]vm.Program
+ uniqExprCacheLock sync.Mutex
+)
type Uniq struct {
DistinctCompiled *vm.Program
diff --git a/pkg/longpollclient/client.go b/pkg/longpollclient/client.go
index 5a7af0bfa63..5c395185b20 100644
--- a/pkg/longpollclient/client.go
+++ b/pkg/longpollclient/client.go
@@ -1,6 +1,7 @@
package longpollclient
import (
+ "context"
"encoding/json"
"errors"
"fmt"
@@ -50,7 +51,7 @@ var errUnauthorized = errors.New("user is not authorized to use PAPI")
const timeoutMessage = "no events before timeout"
-func (c *LongPollClient) doQuery() (*http.Response, error) {
+func (c *LongPollClient) doQuery(ctx context.Context) (*http.Response, error) {
logger := c.logger.WithField("method", "doQuery")
query := c.url.Query()
query.Set("since_time", fmt.Sprintf("%d", c.since))
@@ -59,7 +60,7 @@ func (c *LongPollClient) doQuery() (*http.Response, error) {
logger.Debugf("Query parameters: %s", c.url.RawQuery)
- req, err := http.NewRequest(http.MethodGet, c.url.String(), nil)
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url.String(), nil)
if err != nil {
logger.Errorf("failed to create request: %s", err)
return nil, err
@@ -73,10 +74,10 @@ func (c *LongPollClient) doQuery() (*http.Response, error) {
return resp, nil
}
-func (c *LongPollClient) poll() error {
+func (c *LongPollClient) poll(ctx context.Context) error {
logger := c.logger.WithField("method", "poll")
- resp, err := c.doQuery()
+ resp, err := c.doQuery(ctx)
if err != nil {
return err
}
@@ -146,7 +147,7 @@ func (c *LongPollClient) poll() error {
}
}
-func (c *LongPollClient) pollEvents() error {
+func (c *LongPollClient) pollEvents(ctx context.Context) error {
for {
select {
case <-c.t.Dying():
@@ -154,7 +155,7 @@ func (c *LongPollClient) pollEvents() error {
return nil
default:
c.logger.Debug("Polling PAPI")
- err := c.poll()
+ err := c.poll(ctx)
if err != nil {
c.logger.Errorf("failed to poll: %s", err)
if errors.Is(err, errUnauthorized) {
@@ -168,12 +169,12 @@ func (c *LongPollClient) pollEvents() error {
}
}
-func (c *LongPollClient) Start(since time.Time) chan Event {
+func (c *LongPollClient) Start(ctx context.Context, since time.Time) chan Event {
c.logger.Infof("starting polling client")
c.c = make(chan Event)
c.since = since.Unix() * 1000
c.timeout = "45"
- c.t.Go(c.pollEvents)
+ c.t.Go(func() error { return c.pollEvents(ctx) })
return c.c
}
@@ -182,11 +183,11 @@ func (c *LongPollClient) Stop() error {
return nil
}
-func (c *LongPollClient) PullOnce(since time.Time) ([]Event, error) {
+func (c *LongPollClient) PullOnce(ctx context.Context, since time.Time) ([]Event, error) {
c.logger.Debug("Pulling PAPI once")
c.since = since.Unix() * 1000
c.timeout = "1"
- resp, err := c.doQuery()
+ resp, err := c.doQuery(ctx)
if err != nil {
return nil, err
}
diff --git a/pkg/models/allowlist_item.go b/pkg/models/allowlist_item.go
new file mode 100644
index 00000000000..3d688d52e5d
--- /dev/null
+++ b/pkg/models/allowlist_item.go
@@ -0,0 +1,100 @@
+// Code generated by go-swagger; DO NOT EDIT.
+
+package models
+
+// This file was generated by the swagger tool.
+// Editing this file might prove futile when you re-run the swagger generate command
+
+import (
+ "context"
+
+ "github.com/go-openapi/errors"
+ "github.com/go-openapi/strfmt"
+ "github.com/go-openapi/swag"
+ "github.com/go-openapi/validate"
+)
+
+// AllowlistItem AllowlistItem
+//
+// swagger:model AllowlistItem
+type AllowlistItem struct {
+
+ // creation date of the allowlist item
+ // Format: date-time
+ CreatedAt strfmt.DateTime `json:"created_at,omitempty"`
+
+ // description of the allowlist item
+ Description string `json:"description,omitempty"`
+
+ // expiration date of the allowlist item
+ // Format: date-time
+ Expiration strfmt.DateTime `json:"expiration,omitempty"`
+
+ // value of the allowlist item
+ Value string `json:"value,omitempty"`
+}
+
+// Validate validates this allowlist item
+func (m *AllowlistItem) Validate(formats strfmt.Registry) error {
+ var res []error
+
+ if err := m.validateCreatedAt(formats); err != nil {
+ res = append(res, err)
+ }
+
+ if err := m.validateExpiration(formats); err != nil {
+ res = append(res, err)
+ }
+
+ if len(res) > 0 {
+ return errors.CompositeValidationError(res...)
+ }
+ return nil
+}
+
+func (m *AllowlistItem) validateCreatedAt(formats strfmt.Registry) error {
+ if swag.IsZero(m.CreatedAt) { // not required
+ return nil
+ }
+
+ if err := validate.FormatOf("created_at", "body", "date-time", m.CreatedAt.String(), formats); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (m *AllowlistItem) validateExpiration(formats strfmt.Registry) error {
+ if swag.IsZero(m.Expiration) { // not required
+ return nil
+ }
+
+ if err := validate.FormatOf("expiration", "body", "date-time", m.Expiration.String(), formats); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// ContextValidate validates this allowlist item based on context it is used
+func (m *AllowlistItem) ContextValidate(ctx context.Context, formats strfmt.Registry) error {
+ return nil
+}
+
+// MarshalBinary interface implementation
+func (m *AllowlistItem) MarshalBinary() ([]byte, error) {
+ if m == nil {
+ return nil, nil
+ }
+ return swag.WriteJSON(m)
+}
+
+// UnmarshalBinary interface implementation
+func (m *AllowlistItem) UnmarshalBinary(b []byte) error {
+ var res AllowlistItem
+ if err := swag.ReadJSON(b, &res); err != nil {
+ return err
+ }
+ *m = res
+ return nil
+}
diff --git a/pkg/models/check_allowlist_response.go b/pkg/models/check_allowlist_response.go
new file mode 100644
index 00000000000..ab57d4fea71
--- /dev/null
+++ b/pkg/models/check_allowlist_response.go
@@ -0,0 +1,50 @@
+// Code generated by go-swagger; DO NOT EDIT.
+
+package models
+
+// This file was generated by the swagger tool.
+// Editing this file might prove futile when you re-run the swagger generate command
+
+import (
+ "context"
+
+ "github.com/go-openapi/strfmt"
+ "github.com/go-openapi/swag"
+)
+
+// CheckAllowlistResponse CheckAllowlistResponse
+//
+// swagger:model CheckAllowlistResponse
+type CheckAllowlistResponse struct {
+
+ // true if the IP or range is in the allowlist
+ Allowlisted bool `json:"allowlisted,omitempty"`
+}
+
+// Validate validates this check allowlist response
+func (m *CheckAllowlistResponse) Validate(formats strfmt.Registry) error {
+ return nil
+}
+
+// ContextValidate validates this check allowlist response based on context it is used
+func (m *CheckAllowlistResponse) ContextValidate(ctx context.Context, formats strfmt.Registry) error {
+ return nil
+}
+
+// MarshalBinary interface implementation
+func (m *CheckAllowlistResponse) MarshalBinary() ([]byte, error) {
+ if m == nil {
+ return nil, nil
+ }
+ return swag.WriteJSON(m)
+}
+
+// UnmarshalBinary interface implementation
+func (m *CheckAllowlistResponse) UnmarshalBinary(b []byte) error {
+ var res CheckAllowlistResponse
+ if err := swag.ReadJSON(b, &res); err != nil {
+ return err
+ }
+ *m = res
+ return nil
+}
diff --git a/pkg/models/get_allowlist_response.go b/pkg/models/get_allowlist_response.go
new file mode 100644
index 00000000000..4459457ecb3
--- /dev/null
+++ b/pkg/models/get_allowlist_response.go
@@ -0,0 +1,174 @@
+// Code generated by go-swagger; DO NOT EDIT.
+
+package models
+
+// This file was generated by the swagger tool.
+// Editing this file might prove futile when you re-run the swagger generate command
+
+import (
+ "context"
+ "strconv"
+
+ "github.com/go-openapi/errors"
+ "github.com/go-openapi/strfmt"
+ "github.com/go-openapi/swag"
+ "github.com/go-openapi/validate"
+)
+
+// GetAllowlistResponse GetAllowlistResponse
+//
+// swagger:model GetAllowlistResponse
+type GetAllowlistResponse struct {
+
+ // id of the allowlist
+ AllowlistID string `json:"allowlist_id,omitempty"`
+
+ // true if the allowlist is managed by the console
+ ConsoleManaged bool `json:"console_managed,omitempty"`
+
+ // creation date of the allowlist
+ // Format: date-time
+ CreatedAt strfmt.DateTime `json:"created_at,omitempty"`
+
+ // description of the allowlist
+ Description string `json:"description,omitempty"`
+
+ // items in the allowlist
+ Items []*AllowlistItem `json:"items"`
+
+ // name of the allowlist
+ Name string `json:"name,omitempty"`
+
+ // last update date of the allowlist
+ // Format: date-time
+ UpdatedAt strfmt.DateTime `json:"updated_at,omitempty"`
+}
+
+// Validate validates this get allowlist response
+func (m *GetAllowlistResponse) Validate(formats strfmt.Registry) error {
+ var res []error
+
+ if err := m.validateCreatedAt(formats); err != nil {
+ res = append(res, err)
+ }
+
+ if err := m.validateItems(formats); err != nil {
+ res = append(res, err)
+ }
+
+ if err := m.validateUpdatedAt(formats); err != nil {
+ res = append(res, err)
+ }
+
+ if len(res) > 0 {
+ return errors.CompositeValidationError(res...)
+ }
+ return nil
+}
+
+func (m *GetAllowlistResponse) validateCreatedAt(formats strfmt.Registry) error {
+ if swag.IsZero(m.CreatedAt) { // not required
+ return nil
+ }
+
+ if err := validate.FormatOf("created_at", "body", "date-time", m.CreatedAt.String(), formats); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (m *GetAllowlistResponse) validateItems(formats strfmt.Registry) error {
+ if swag.IsZero(m.Items) { // not required
+ return nil
+ }
+
+ for i := 0; i < len(m.Items); i++ {
+ if swag.IsZero(m.Items[i]) { // not required
+ continue
+ }
+
+ if m.Items[i] != nil {
+ if err := m.Items[i].Validate(formats); err != nil {
+ if ve, ok := err.(*errors.Validation); ok {
+ return ve.ValidateName("items" + "." + strconv.Itoa(i))
+ } else if ce, ok := err.(*errors.CompositeError); ok {
+ return ce.ValidateName("items" + "." + strconv.Itoa(i))
+ }
+ return err
+ }
+ }
+
+ }
+
+ return nil
+}
+
+func (m *GetAllowlistResponse) validateUpdatedAt(formats strfmt.Registry) error {
+ if swag.IsZero(m.UpdatedAt) { // not required
+ return nil
+ }
+
+ if err := validate.FormatOf("updated_at", "body", "date-time", m.UpdatedAt.String(), formats); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// ContextValidate validate this get allowlist response based on the context it is used
+func (m *GetAllowlistResponse) ContextValidate(ctx context.Context, formats strfmt.Registry) error {
+ var res []error
+
+ if err := m.contextValidateItems(ctx, formats); err != nil {
+ res = append(res, err)
+ }
+
+ if len(res) > 0 {
+ return errors.CompositeValidationError(res...)
+ }
+ return nil
+}
+
+func (m *GetAllowlistResponse) contextValidateItems(ctx context.Context, formats strfmt.Registry) error {
+
+ for i := 0; i < len(m.Items); i++ {
+
+ if m.Items[i] != nil {
+
+ if swag.IsZero(m.Items[i]) { // not required
+ return nil
+ }
+
+ if err := m.Items[i].ContextValidate(ctx, formats); err != nil {
+ if ve, ok := err.(*errors.Validation); ok {
+ return ve.ValidateName("items" + "." + strconv.Itoa(i))
+ } else if ce, ok := err.(*errors.CompositeError); ok {
+ return ce.ValidateName("items" + "." + strconv.Itoa(i))
+ }
+ return err
+ }
+ }
+
+ }
+
+ return nil
+}
+
+// MarshalBinary interface implementation
+func (m *GetAllowlistResponse) MarshalBinary() ([]byte, error) {
+ if m == nil {
+ return nil, nil
+ }
+ return swag.WriteJSON(m)
+}
+
+// UnmarshalBinary interface implementation
+func (m *GetAllowlistResponse) UnmarshalBinary(b []byte) error {
+ var res GetAllowlistResponse
+ if err := swag.ReadJSON(b, &res); err != nil {
+ return err
+ }
+ *m = res
+ return nil
+}
diff --git a/pkg/models/get_allowlists_response.go b/pkg/models/get_allowlists_response.go
new file mode 100644
index 00000000000..dd6a80918c6
--- /dev/null
+++ b/pkg/models/get_allowlists_response.go
@@ -0,0 +1,78 @@
+// Code generated by go-swagger; DO NOT EDIT.
+
+package models
+
+// This file was generated by the swagger tool.
+// Editing this file might prove futile when you re-run the swagger generate command
+
+import (
+ "context"
+ "strconv"
+
+ "github.com/go-openapi/errors"
+ "github.com/go-openapi/strfmt"
+ "github.com/go-openapi/swag"
+)
+
+// GetAllowlistsResponse GetAllowlistsResponse
+//
+// swagger:model GetAllowlistsResponse
+type GetAllowlistsResponse []*GetAllowlistResponse
+
+// Validate validates this get allowlists response
+func (m GetAllowlistsResponse) Validate(formats strfmt.Registry) error {
+ var res []error
+
+ for i := 0; i < len(m); i++ {
+ if swag.IsZero(m[i]) { // not required
+ continue
+ }
+
+ if m[i] != nil {
+ if err := m[i].Validate(formats); err != nil {
+ if ve, ok := err.(*errors.Validation); ok {
+ return ve.ValidateName(strconv.Itoa(i))
+ } else if ce, ok := err.(*errors.CompositeError); ok {
+ return ce.ValidateName(strconv.Itoa(i))
+ }
+ return err
+ }
+ }
+
+ }
+
+ if len(res) > 0 {
+ return errors.CompositeValidationError(res...)
+ }
+ return nil
+}
+
+// ContextValidate validate this get allowlists response based on the context it is used
+func (m GetAllowlistsResponse) ContextValidate(ctx context.Context, formats strfmt.Registry) error {
+ var res []error
+
+ for i := 0; i < len(m); i++ {
+
+ if m[i] != nil {
+
+ if swag.IsZero(m[i]) { // not required
+ return nil
+ }
+
+ if err := m[i].ContextValidate(ctx, formats); err != nil {
+ if ve, ok := err.(*errors.Validation); ok {
+ return ve.ValidateName(strconv.Itoa(i))
+ } else if ce, ok := err.(*errors.CompositeError); ok {
+ return ce.ValidateName(strconv.Itoa(i))
+ }
+ return err
+ }
+ }
+
+ }
+
+ if len(res) > 0 {
+ return errors.CompositeValidationError(res...)
+ }
+ return nil
+}
diff --git a/pkg/models/localapi_swagger.yaml b/pkg/models/localapi_swagger.yaml
index 01bbe6f8bde..5d156e5791f 100644
--- a/pkg/models/localapi_swagger.yaml
+++ b/pkg/models/localapi_swagger.yaml
@@ -719,6 +719,120 @@ paths:
security:
- APIKeyAuthorizer: []
- JWTAuthorizer: []
+ /allowlists:
+ get:
+ description: Get a list of all allowlists
+ summary: getAllowlists
+ tags:
+ - watchers
+ operationId: getAllowlists
+ produces:
+ - application/json
+ responses:
+ '200':
+ description: successful operation
+ schema:
+ $ref: '#/definitions/GetAllowlistsResponse'
+ headers: {}
+ /allowlists/{allowlist_name}:
+ get:
+ description: Get a specific allowlist
+ summary: getAllowlist
+ tags:
+ - watchers
+ operationId: getAllowlist
+ produces:
+ - application/json
+ parameters:
+ - name: allowlist_name
+ in: path
+ required: true
+ type: string
+ description: 'name of the allowlist to retrieve'
+ responses:
+ '200':
+ description: successful operation
+ schema:
+ $ref: '#/definitions/GetAllowlistResponse'
+ headers: {}
+ '404':
+ description: "404 response"
+ schema:
+ $ref: "#/definitions/ErrorResponse"
+ head:
+ description: Get a specific allowlist
+ summary: headAllowlist
+ tags:
+ - watchers
+ operationId: headAllowlist
+ produces:
+ - application/json
+ parameters:
+ - name: allowlist_name
+ in: path
+ required: true
+ type: string
+ description: 'name of the allowlist to check for existence'
+ - name: with_content
+ in: query
+ required: false
+ type: boolean
+ description: 'if true, the content of the allowlist will be returned as well'
+ responses:
+ '200':
+ description: successful operation
+ headers: {}
+ '404':
+ description: "404 response"
+ /allowlists/check/{ip_or_range}:
+ get:
+ description: Check if an IP or range is in an allowlist
+ summary: checkAllowlist
+ tags:
+ - watchers
+ operationId: checkAllowlist
+ produces:
+ - application/json
+ parameters:
+ - name: ip_or_range
+ in: path
+ required: true
+ type: string
+ description: 'IP address or range to check against the allowlists'
+ responses:
+ '200':
+ description: successful operation
+ schema:
+ $ref: '#/definitions/CheckAllowlistResponse'
+ headers: {}
+ '400':
+ description: "missing ip_or_range"
+ schema:
+ $ref: "#/definitions/ErrorResponse"
+ head:
+ description: Check if an IP or range is in an allowlist
+ summary: headCheckAllowlist
+ tags:
+ - watchers
+ operationId: headCheckAllowlist
+ produces:
+ - application/json
+ parameters:
+ - name: ip_or_range
+ in: path
+ required: true
+ type: string
+ description: 'IP address or range to check against the allowlists'
+ responses:
+ '200':
+ description: IP or range is in an allowlist
+ headers: {}
+ '204':
+ description: "IP or range is not in an allowlist"
+ '400':
+ description: "missing ip_or_range"
+ schema:
+ $ref: "#/definitions/ErrorResponse"
definitions:
WatcherRegistrationRequest:
title: WatcherRegistrationRequest
@@ -1220,6 +1334,65 @@ definitions:
status:
type: string
description: status of the hub item (official, custom, tainted, etc.)
+ GetAllowlistsResponse:
+ title: GetAllowlistsResponse
+ type: array
+ items:
+ $ref: '#/definitions/GetAllowlistResponse'
+ GetAllowlistResponse:
+ title: GetAllowlistResponse
+ type: object
+ properties:
+ name:
+ type: string
+ description: name of the allowlist
+ allowlist_id:
+ type: string
+ description: id of the allowlist
+ description:
+ type: string
+ description: description of the allowlist
+ items:
+ type: array
+ items:
+ $ref: '#/definitions/AllowlistItem'
+ description: items in the allowlist
+ created_at:
+ type: string
+ format: date-time
+ description: creation date of the allowlist
+ updated_at:
+ type: string
+ format: date-time
+ description: last update date of the allowlist
+ console_managed:
+ type: boolean
+ description: true if the allowlist is managed by the console
+ AllowlistItem:
+ title: AllowlistItem
+ type: object
+ properties:
+ value:
+ type: string
+ description: value of the allowlist item
+ description:
+ type: string
+ description: description of the allowlist item
+ created_at:
+ type: string
+ format: date-time
+ description: creation date of the allowlist item
+ expiration:
+ type: string
+ format: date-time
+ description: expiration date of the allowlist item
+ CheckAllowlistResponse:
+ title: CheckAllowlistResponse
+ type: object
+ properties:
+ allowlisted:
+ type: boolean
+ description: 'true if the IP or range is in the allowlist'
ErrorResponse:
type: "object"
required:
diff --git a/pkg/modelscapi/allowlist_link.go b/pkg/modelscapi/allowlist_link.go
new file mode 100644
index 00000000000..ce9fce17357
--- /dev/null
+++ b/pkg/modelscapi/allowlist_link.go
@@ -0,0 +1,166 @@
+// Code generated by go-swagger; DO NOT EDIT.
+
+package modelscapi
+
+// This file was generated by the swagger tool.
+// Editing this file might prove futile when you re-run the swagger generate command
+
+import (
+ "context"
+
+ "github.com/go-openapi/errors"
+ "github.com/go-openapi/strfmt"
+ "github.com/go-openapi/swag"
+ "github.com/go-openapi/validate"
+)
+
+// AllowlistLink allowlist link
+//
+// swagger:model AllowlistLink
+type AllowlistLink struct {
+
+ // the creation date of the allowlist
+ // Required: true
+ // Format: date-time
+ CreatedAt *strfmt.DateTime `json:"created_at"`
+
+ // the description of the allowlist
+ // Required: true
+ Description *string `json:"description"`
+
+ // the id of the allowlist
+ // Required: true
+ ID *string `json:"id"`
+
+ // the name of the allowlist
+ // Required: true
+ Name *string `json:"name"`
+
+ // the last update date of the allowlist
+ // Required: true
+ // Format: date-time
+ UpdatedAt *strfmt.DateTime `json:"updated_at"`
+
+ // the url from which the allowlist content can be downloaded
+ // Required: true
+ URL *string `json:"url"`
+}
+
+// Validate validates this allowlist link
+func (m *AllowlistLink) Validate(formats strfmt.Registry) error {
+ var res []error
+
+ if err := m.validateCreatedAt(formats); err != nil {
+ res = append(res, err)
+ }
+
+ if err := m.validateDescription(formats); err != nil {
+ res = append(res, err)
+ }
+
+ if err := m.validateID(formats); err != nil {
+ res = append(res, err)
+ }
+
+ if err := m.validateName(formats); err != nil {
+ res = append(res, err)
+ }
+
+ if err := m.validateUpdatedAt(formats); err != nil {
+ res = append(res, err)
+ }
+
+ if err := m.validateURL(formats); err != nil {
+ res = append(res, err)
+ }
+
+ if len(res) > 0 {
+ return errors.CompositeValidationError(res...)
+ }
+ return nil
+}
+
+func (m *AllowlistLink) validateCreatedAt(formats strfmt.Registry) error {
+
+ if err := validate.Required("created_at", "body", m.CreatedAt); err != nil {
+ return err
+ }
+
+ if err := validate.FormatOf("created_at", "body", "date-time", m.CreatedAt.String(), formats); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (m *AllowlistLink) validateDescription(formats strfmt.Registry) error {
+
+ if err := validate.Required("description", "body", m.Description); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (m *AllowlistLink) validateID(formats strfmt.Registry) error {
+
+ if err := validate.Required("id", "body", m.ID); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (m *AllowlistLink) validateName(formats strfmt.Registry) error {
+
+ if err := validate.Required("name", "body", m.Name); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (m *AllowlistLink) validateUpdatedAt(formats strfmt.Registry) error {
+
+ if err := validate.Required("updated_at", "body", m.UpdatedAt); err != nil {
+ return err
+ }
+
+ if err := validate.FormatOf("updated_at", "body", "date-time", m.UpdatedAt.String(), formats); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (m *AllowlistLink) validateURL(formats strfmt.Registry) error {
+
+ if err := validate.Required("url", "body", m.URL); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// ContextValidate validates this allowlist link based on context it is used
+func (m *AllowlistLink) ContextValidate(ctx context.Context, formats strfmt.Registry) error {
+ return nil
+}
+
+// MarshalBinary interface implementation
+func (m *AllowlistLink) MarshalBinary() ([]byte, error) {
+ if m == nil {
+ return nil, nil
+ }
+ return swag.WriteJSON(m)
+}
+
+// UnmarshalBinary interface implementation
+func (m *AllowlistLink) UnmarshalBinary(b []byte) error {
+ var res AllowlistLink
+ if err := swag.ReadJSON(b, &res); err != nil {
+ return err
+ }
+ *m = res
+ return nil
+}
diff --git a/pkg/modelscapi/centralapi_swagger.yaml b/pkg/modelscapi/centralapi_swagger.yaml
index bd695894f2b..6a830ee3820 100644
--- a/pkg/modelscapi/centralapi_swagger.yaml
+++ b/pkg/modelscapi/centralapi_swagger.yaml
@@ -55,6 +55,19 @@ paths:
description: "returns list of top decisions to add or delete"
produces:
- "application/json"
+ parameters:
+ - in: query
+ name: "community_pull"
+ type: "boolean"
+ default: true
+ required: false
+ description: "Fetch the community blocklist content"
+ - in: query
+ name: "additional_pull"
+ type: "boolean"
+ default: true
+ required: false
+ description: "Fetch additional blocklists content"
responses:
"200":
description: "200 response"
@@ -535,6 +548,37 @@ definitions:
description: "the scope of decisions in the blocklist"
duration:
type: string
+ AllowlistLink:
+ type: object
+ required:
+ - name
+ - description
+ - url
+ - id
+ - created_at
+ - updated_at
+ properties:
+ name:
+ type: string
+ description: "the name of the allowlist"
+ description:
+ type: string
+ description: "the description of the allowlist"
+ url:
+ type: string
+ description: "the url from which the allowlist content can be downloaded"
+ id:
+ type: string
+ description: "the id of the allowlist"
+ created_at:
+ type: string
+ format: date-time
+ description: "the creation date of the allowlist"
+ updated_at:
+ type: string
+ format: date-time
+ description: "the last update date of the allowlist"
+
AddSignalsRequestItemDecisionsItem:
type: "object"
required:
@@ -872,4 +916,8 @@ definitions:
type: array
items:
$ref: "#/definitions/BlocklistLink"
+ allowlists:
+ type: array
+ items:
+ $ref: "#/definitions/AllowlistLink"
diff --git a/pkg/modelscapi/get_decisions_stream_response_links.go b/pkg/modelscapi/get_decisions_stream_response_links.go
index 6b9054574f1..f9e320aee38 100644
--- a/pkg/modelscapi/get_decisions_stream_response_links.go
+++ b/pkg/modelscapi/get_decisions_stream_response_links.go
@@ -19,6 +19,9 @@ import (
// swagger:model GetDecisionsStreamResponseLinks
type GetDecisionsStreamResponseLinks struct {
+ // allowlists
+ Allowlists []*AllowlistLink `json:"allowlists"`
+
// blocklists
Blocklists []*BlocklistLink `json:"blocklists"`
}
@@ -27,6 +30,10 @@ type GetDecisionsStreamResponseLinks struct {
func (m *GetDecisionsStreamResponseLinks) Validate(formats strfmt.Registry) error {
var res []error
+ if err := m.validateAllowlists(formats); err != nil {
+ res = append(res, err)
+ }
+
if err := m.validateBlocklists(formats); err != nil {
res = append(res, err)
}
@@ -37,6 +44,32 @@ func (m *GetDecisionsStreamResponseLinks) Validate(formats strfmt.Registry) erro
return nil
}
+func (m *GetDecisionsStreamResponseLinks) validateAllowlists(formats strfmt.Registry) error {
+ if swag.IsZero(m.Allowlists) { // not required
+ return nil
+ }
+
+ for i := 0; i < len(m.Allowlists); i++ {
+ if swag.IsZero(m.Allowlists[i]) { // not required
+ continue
+ }
+
+ if m.Allowlists[i] != nil {
+ if err := m.Allowlists[i].Validate(formats); err != nil {
+ if ve, ok := err.(*errors.Validation); ok {
+ return ve.ValidateName("allowlists" + "." + strconv.Itoa(i))
+ } else if ce, ok := err.(*errors.CompositeError); ok {
+ return ce.ValidateName("allowlists" + "." + strconv.Itoa(i))
+ }
+ return err
+ }
+ }
+
+ }
+
+ return nil
+}
+
func (m *GetDecisionsStreamResponseLinks) validateBlocklists(formats strfmt.Registry) error {
if swag.IsZero(m.Blocklists) { // not required
return nil
@@ -67,6 +100,10 @@ func (m *GetDecisionsStreamResponseLinks) validateBlocklists(formats strfmt.Regi
func (m *GetDecisionsStreamResponseLinks) ContextValidate(ctx context.Context, formats strfmt.Registry) error {
var res []error
+ if err := m.contextValidateAllowlists(ctx, formats); err != nil {
+ res = append(res, err)
+ }
+
if err := m.contextValidateBlocklists(ctx, formats); err != nil {
res = append(res, err)
}
@@ -77,6 +114,31 @@ func (m *GetDecisionsStreamResponseLinks) ContextValidate(ctx context.Context, f
return nil
}
+func (m *GetDecisionsStreamResponseLinks) contextValidateAllowlists(ctx context.Context, formats strfmt.Registry) error {
+
+ for i := 0; i < len(m.Allowlists); i++ {
+
+ if m.Allowlists[i] != nil {
+
+ if swag.IsZero(m.Allowlists[i]) { // not required
+ return nil
+ }
+
+ if err := m.Allowlists[i].ContextValidate(ctx, formats); err != nil {
+ if ve, ok := err.(*errors.Validation); ok {
+ return ve.ValidateName("allowlists" + "." + strconv.Itoa(i))
+ } else if ce, ok := err.(*errors.CompositeError); ok {
+ return ce.ValidateName("allowlists" + "." + strconv.Itoa(i))
+ }
+ return err
+ }
+ }
+
+ }
+
+ return nil
+}
+
func (m *GetDecisionsStreamResponseLinks) contextValidateBlocklists(ctx context.Context, formats strfmt.Registry) error {
for i := 0; i < len(m.Blocklists); i++ {
diff --git a/pkg/parser/enrich.go b/pkg/parser/enrich.go
index 661410d20d3..a69cd963813 100644
--- a/pkg/parser/enrich.go
+++ b/pkg/parser/enrich.go
@@ -7,8 +7,10 @@ import (
)
/* should be part of a package shared with enrich/geoip.go */
-type EnrichFunc func(string, *types.Event, *log.Entry) (map[string]string, error)
-type InitFunc func(map[string]string) (interface{}, error)
+type (
+ EnrichFunc func(string, *types.Event, *log.Entry) (map[string]string, error)
+ InitFunc func(map[string]string) (interface{}, error)
+)
type EnricherCtx struct {
Registered map[string]*Enricher
diff --git a/pkg/parser/enrich_geoip.go b/pkg/parser/enrich_geoip.go
index 1756927bc4b..79a70077283 100644
--- a/pkg/parser/enrich_geoip.go
+++ b/pkg/parser/enrich_geoip.go
@@ -18,7 +18,6 @@ func IpToRange(field string, p *types.Event, plog *log.Entry) (map[string]string
}
r, err := exprhelpers.GeoIPRangeEnrich(field)
-
if err != nil {
plog.Errorf("Unable to enrich ip '%s'", field)
return nil, nil //nolint:nilerr
@@ -47,7 +46,6 @@ func GeoIpASN(field string, p *types.Event, plog *log.Entry) (map[string]string,
}
r, err := exprhelpers.GeoIPASNEnrich(field)
-
if err != nil {
plog.Debugf("Unable to enrich ip '%s'", field)
return nil, nil //nolint:nilerr
@@ -81,7 +79,6 @@ func GeoIpCity(field string, p *types.Event, plog *log.Entry) (map[string]string
}
r, err := exprhelpers.GeoIPEnrich(field)
-
if err != nil {
plog.Debugf("Unable to enrich ip '%s'", field)
return nil, nil //nolint:nilerr
diff --git a/pkg/parser/node.go b/pkg/parser/node.go
index 26046ae4fd6..62a1ff6c4e2 100644
--- a/pkg/parser/node.go
+++ b/pkg/parser/node.go
@@ -3,6 +3,7 @@ package parser
import (
"errors"
"fmt"
+ "strconv"
"strings"
"time"
@@ -236,7 +237,7 @@ func (n *Node) processGrok(p *types.Event, cachedExprEnv map[string]any) (bool,
case string:
gstr = out
case int:
- gstr = fmt.Sprintf("%d", out)
+ gstr = strconv.Itoa(out)
case float64, float32:
gstr = fmt.Sprintf("%f", out)
default:
@@ -357,16 +358,17 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri
}
// Iterate on leafs
- for _, leaf := range n.LeavesNodes {
- ret, err := leaf.process(p, ctx, cachedExprEnv)
+ leaves := n.LeavesNodes
+ for idx := range leaves {
+ ret, err := leaves[idx].process(p, ctx, cachedExprEnv)
if err != nil {
- clog.Tracef("\tNode (%s) failed : %v", leaf.rn, err)
+ clog.Tracef("\tNode (%s) failed : %v", leaves[idx].rn, err)
clog.Debugf("Event leaving node : ko")
return false, err
}
- clog.Tracef("\tsub-node (%s) ret : %v (strategy:%s)", leaf.rn, ret, n.OnSuccess)
+ clog.Tracef("\tsub-node (%s) ret : %v (strategy:%s)", leaves[idx].rn, ret, n.OnSuccess)
if ret {
NodeState = true
@@ -593,7 +595,7 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error {
/* compile leafs if present */
for idx := range n.LeavesNodes {
if n.LeavesNodes[idx].Name == "" {
- n.LeavesNodes[idx].Name = fmt.Sprintf("child-%s", n.Name)
+ n.LeavesNodes[idx].Name = "child-" + n.Name
}
/*propagate debug/stats to child nodes*/
if !n.LeavesNodes[idx].Debug && n.Debug {
diff --git a/pkg/parser/parsing_test.go b/pkg/parser/parsing_test.go
index 269d51a1ba2..5f6f924e7df 100644
--- a/pkg/parser/parsing_test.go
+++ b/pkg/parser/parsing_test.go
@@ -151,7 +151,7 @@ func testOneParser(pctx *UnixParserCtx, ectx EnricherCtx, dir string, b *testing
b.ResetTimer()
}
- for range(count) {
+ for range count {
if !testFile(tests, *pctx, pnodes) {
return errors.New("test failed")
}
@@ -285,7 +285,7 @@ func matchEvent(expected types.Event, out types.Event, debug bool) ([]string, bo
valid = true
- for mapIdx := range(len(expectMaps)) {
+ for mapIdx := range len(expectMaps) {
for expKey, expVal := range expectMaps[mapIdx] {
outVal, ok := outMaps[mapIdx][expKey]
if !ok {
diff --git a/pkg/parser/runtime.go b/pkg/parser/runtime.go
index 8068690b68f..7af82a71535 100644
--- a/pkg/parser/runtime.go
+++ b/pkg/parser/runtime.go
@@ -29,10 +29,11 @@ func SetTargetByName(target string, value string, evt *types.Event) bool {
return false
}
- //it's a hack, we do it for the user
+ // it's a hack, we do it for the user
target = strings.TrimPrefix(target, "evt.")
log.Debugf("setting target %s to %s", target, value)
+
defer func() {
if r := recover(); r != nil {
log.Errorf("Runtime error while trying to set '%s': %+v", target, r)
@@ -46,6 +47,7 @@ func SetTargetByName(target string, value string, evt *types.Event) bool {
//event is nil
return false
}
+
for _, f := range strings.Split(target, ".") {
/*
** According to current Event layout we only have to handle struct and map
@@ -57,7 +59,9 @@ func SetTargetByName(target string, value string, evt *types.Event) bool {
if (tmp == reflect.Value{}) || tmp.IsZero() {
log.Debugf("map entry is zero in '%s'", target)
}
+
iter.SetMapIndex(reflect.ValueOf(f), reflect.ValueOf(value))
+
return true
case reflect.Struct:
tmp := iter.FieldByName(f)
@@ -65,9 +69,11 @@ func SetTargetByName(target string, value string, evt *types.Event) bool {
log.Debugf("'%s' is not a valid target because '%s' is not valid", target, f)
return false
}
+
if tmp.Kind() == reflect.Ptr {
tmp = reflect.Indirect(tmp)
}
+
iter = tmp
case reflect.Ptr:
tmp := iter.Elem()
@@ -82,11 +88,14 @@ func SetTargetByName(target string, value string, evt *types.Event) bool {
log.Errorf("'%s' can't be set", target)
return false
}
+
if iter.Kind() != reflect.String {
log.Errorf("Expected string, got %v when handling '%s'", iter.Kind(), target)
return false
}
+
iter.Set(reflect.ValueOf(value))
+
return true
}
@@ -248,14 +257,18 @@ func stageidx(stage string, stages []string) int {
return -1
}
-var ParseDump bool
-var DumpFolder string
+var (
+ ParseDump bool
+ DumpFolder string
+)
-var StageParseCache dumps.ParserResults
-var StageParseMutex sync.Mutex
+var (
+ StageParseCache dumps.ParserResults
+ StageParseMutex sync.Mutex
+)
func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error) {
- var event = xp
+ event := xp
/* the stage is undefined, probably line is freshly acquired, set to first stage !*/
if event.Stage == "" && len(ctx.Stages) > 0 {
@@ -317,46 +330,46 @@ func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error)
}
isStageOK := false
- for idx, node := range nodes {
+ for idx := range nodes {
//Only process current stage's nodes
- if event.Stage != node.Stage {
+ if event.Stage != nodes[idx].Stage {
continue
}
clog := log.WithFields(log.Fields{
- "node-name": node.rn,
+ "node-name": nodes[idx].rn,
"stage": event.Stage,
})
- clog.Tracef("Processing node %d/%d -> %s", idx, len(nodes), node.rn)
+ clog.Tracef("Processing node %d/%d -> %s", idx, len(nodes), nodes[idx].rn)
if ctx.Profiling {
- node.Profiling = true
+ nodes[idx].Profiling = true
}
- ret, err := node.process(&event, ctx, map[string]interface{}{"evt": &event})
+ ret, err := nodes[idx].process(&event, ctx, map[string]interface{}{"evt": &event})
if err != nil {
clog.Errorf("Error while processing node : %v", err)
return event, err
}
- clog.Tracef("node (%s) ret : %v", node.rn, ret)
+ clog.Tracef("node (%s) ret : %v", nodes[idx].rn, ret)
if ParseDump {
var parserIdxInStage int
StageParseMutex.Lock()
- if len(StageParseCache[stage][node.Name]) == 0 {
- StageParseCache[stage][node.Name] = make([]dumps.ParserResult, 0)
+ if len(StageParseCache[stage][nodes[idx].Name]) == 0 {
+ StageParseCache[stage][nodes[idx].Name] = make([]dumps.ParserResult, 0)
parserIdxInStage = len(StageParseCache[stage])
} else {
- parserIdxInStage = StageParseCache[stage][node.Name][0].Idx
+ parserIdxInStage = StageParseCache[stage][nodes[idx].Name][0].Idx
}
StageParseMutex.Unlock()
evtcopy := deepcopy.Copy(event)
parserInfo := dumps.ParserResult{Evt: evtcopy.(types.Event), Success: ret, Idx: parserIdxInStage}
StageParseMutex.Lock()
- StageParseCache[stage][node.Name] = append(StageParseCache[stage][node.Name], parserInfo)
+ StageParseCache[stage][nodes[idx].Name] = append(StageParseCache[stage][nodes[idx].Name], parserInfo)
StageParseMutex.Unlock()
}
if ret {
isStageOK = true
}
- if ret && node.OnSuccess == "next_stage" {
+ if ret && nodes[idx].OnSuccess == "next_stage" {
clog.Debugf("node successful, stop end stage %s", stage)
break
}
diff --git a/pkg/parser/unix_parser.go b/pkg/parser/unix_parser.go
index 351de8ade56..f0f26a06645 100644
--- a/pkg/parser/unix_parser.go
+++ b/pkg/parser/unix_parser.go
@@ -43,7 +43,7 @@ func Init(c map[string]interface{}) (*UnixParserCtx, error) {
}
r.DataFolder = c["data"].(string)
for _, f := range files {
- if strings.Contains(f.Name(), ".") {
+ if strings.Contains(f.Name(), ".") || f.IsDir() {
continue
}
if err := r.Grok.AddFromFile(filepath.Join(c["patterns"].(string), f.Name())); err != nil {
diff --git a/pkg/types/appsec_event.go b/pkg/types/appsec_event.go
index dc81c63b344..54163f53fef 100644
--- a/pkg/types/appsec_event.go
+++ b/pkg/types/appsec_event.go
@@ -18,7 +18,9 @@ len(evt.Waf.ByTagRx("*CVE*").ByConfidence("high").ByAction("block")) > 1
*/
-type MatchedRules []map[string]interface{}
+type MatchedRules []MatchedRule
+
+type MatchedRule map[string]interface{}
type AppsecEvent struct {
HasInBandMatches, HasOutBandMatches bool
@@ -45,6 +47,10 @@ const (
Kind Field = "kind"
)
+func NewMatchedRule() *MatchedRule {
+ return &MatchedRule{}
+}
+
func (w AppsecEvent) GetVar(varName string) string {
if w.Vars == nil {
return ""
@@ -54,7 +60,6 @@ func (w AppsecEvent) GetVar(varName string) string {
}
log.Infof("var %s not found. Available variables: %+v", varName, w.Vars)
return ""
-
}
// getters
diff --git a/pkg/types/constants.go b/pkg/types/constants.go
index acb5b5bfacf..2421b076b97 100644
--- a/pkg/types/constants.go
+++ b/pkg/types/constants.go
@@ -1,23 +1,29 @@
package types
-const ApiKeyAuthType = "api-key"
-const TlsAuthType = "tls"
-const PasswordAuthType = "password"
+const (
+ ApiKeyAuthType = "api-key"
+ TlsAuthType = "tls"
+ PasswordAuthType = "password"
+)
-const PAPIBaseURL = "https://papi.api.crowdsec.net/"
-const PAPIVersion = "v1"
-const PAPIPollUrl = "/decisions/stream/poll"
-const PAPIPermissionsUrl = "/permissions"
+const (
+ PAPIBaseURL = "https://papi.api.crowdsec.net/"
+ PAPIVersion = "v1"
+ PAPIPollUrl = "/decisions/stream/poll"
+ PAPIPermissionsUrl = "/permissions"
+)
const CAPIBaseURL = "https://api.crowdsec.net/"
-const CscliOrigin = "cscli"
-const CrowdSecOrigin = "crowdsec"
-const ConsoleOrigin = "console"
-const CscliImportOrigin = "cscli-import"
-const ListOrigin = "lists"
-const CAPIOrigin = "CAPI"
-const CommunityBlocklistPullSourceScope = "crowdsecurity/community-blocklist"
+const (
+ CscliOrigin = "cscli"
+ CrowdSecOrigin = "crowdsec"
+ ConsoleOrigin = "console"
+ CscliImportOrigin = "cscli-import"
+ ListOrigin = "lists"
+ CAPIOrigin = "CAPI"
+ CommunityBlocklistPullSourceScope = "crowdsecurity/community-blocklist"
+)
const DecisionTypeBan = "ban"
diff --git a/pkg/types/event.go b/pkg/types/event.go
index e016d0294c4..0b09bf7cbdf 100644
--- a/pkg/types/event.go
+++ b/pkg/types/event.go
@@ -47,6 +47,23 @@ type Event struct {
Meta map[string]string `yaml:"Meta,omitempty" json:"Meta,omitempty"`
}
+func MakeEvent(timeMachine bool, evtType int, process bool) Event {
+ evt := Event{
+ Parsed: make(map[string]string),
+ Meta: make(map[string]string),
+ Unmarshaled: make(map[string]interface{}),
+ Enriched: make(map[string]string),
+ ExpectMode: LIVE,
+ Process: process,
+ Type: evtType,
+ }
+ if timeMachine {
+ evt.ExpectMode = TIMEMACHINE
+ }
+
+ return evt
+}
+
func (e *Event) SetMeta(key string, value string) bool {
if e.Meta == nil {
e.Meta = make(map[string]string)
@@ -81,8 +98,9 @@ func (e *Event) GetType() string {
func (e *Event) GetMeta(key string) string {
if e.Type == OVFLW {
- for _, alert := range e.Overflow.APIAlerts {
- for _, event := range alert.Events {
+ alerts := e.Overflow.APIAlerts
+ for idx := range alerts {
+ for _, event := range alerts[idx].Events {
if event.GetMeta(key) != "" {
return event.GetMeta(key)
}
diff --git a/pkg/types/event_test.go b/pkg/types/event_test.go
index 97b13f96d9a..638e42fe757 100644
--- a/pkg/types/event_test.go
+++ b/pkg/types/event_test.go
@@ -46,7 +46,6 @@ func TestSetParsed(t *testing.T) {
assert.Equal(t, tt.value, tt.evt.Parsed[tt.key])
})
}
-
}
func TestSetMeta(t *testing.T) {
@@ -86,7 +85,6 @@ func TestSetMeta(t *testing.T) {
assert.Equal(t, tt.value, tt.evt.GetMeta(tt.key))
})
}
-
}
func TestParseIPSources(t *testing.T) {
diff --git a/pkg/types/getfstype.go b/pkg/types/getfstype.go
index 728e986bed0..c16fe86ec9c 100644
--- a/pkg/types/getfstype.go
+++ b/pkg/types/getfstype.go
@@ -100,7 +100,6 @@ func GetFSType(path string) (string, error) {
var buf unix.Statfs_t
err := unix.Statfs(path, &buf)
-
if err != nil {
return "", err
}
diff --git a/pkg/types/ip.go b/pkg/types/ip.go
index 9d08afd8809..47fb3fc83a5 100644
--- a/pkg/types/ip.go
+++ b/pkg/types/ip.go
@@ -23,7 +23,8 @@ func LastAddress(n net.IPNet) net.IP {
ip[6] | ^n.Mask[6], ip[7] | ^n.Mask[7], ip[8] | ^n.Mask[8],
ip[9] | ^n.Mask[9], ip[10] | ^n.Mask[10], ip[11] | ^n.Mask[11],
ip[12] | ^n.Mask[12], ip[13] | ^n.Mask[13], ip[14] | ^n.Mask[14],
- ip[15] | ^n.Mask[15]}
+ ip[15] | ^n.Mask[15],
+ }
}
return net.IPv4(
diff --git a/pkg/types/ip_test.go b/pkg/types/ip_test.go
index f8c14b12e3c..b9298ba487f 100644
--- a/pkg/types/ip_test.go
+++ b/pkg/types/ip_test.go
@@ -8,21 +8,20 @@ import (
)
func TestIP2Int(t *testing.T) {
-
tEmpty := net.IP{}
_, _, _, err := IP2Ints(tEmpty)
if !strings.Contains(err.Error(), "unexpected len 0 for ") {
t.Fatalf("unexpected: %s", err)
}
}
+
func TestRange2Int(t *testing.T) {
tEmpty := net.IPNet{}
- //empty item
+ // empty item
_, _, _, _, _, err := Range2Ints(tEmpty)
if !strings.Contains(err.Error(), "converting first ip in range") {
t.Fatalf("unexpected: %s", err)
}
-
}
func TestAdd2Int(t *testing.T) {
diff --git a/pkg/types/utils.go b/pkg/types/utils.go
index 712d44ba12d..3e1ae4f7547 100644
--- a/pkg/types/utils.go
+++ b/pkg/types/utils.go
@@ -10,9 +10,11 @@ import (
"gopkg.in/natefinch/lumberjack.v2"
)
-var logFormatter log.Formatter
-var LogOutput *lumberjack.Logger //io.Writer
-var logLevel log.Level
+var (
+ logFormatter log.Formatter
+ LogOutput *lumberjack.Logger // io.Writer
+ logLevel log.Level
+)
func SetDefaultLoggerConfig(cfgMode string, cfgFolder string, cfgLevel log.Level, maxSize int, maxFiles int, maxAge int, compress *bool, forceColors bool) error {
/*Configure logs*/
diff --git a/rpm/SPECS/crowdsec.spec b/rpm/SPECS/crowdsec.spec
index ab71b650d11..ac438ad0c14 100644
--- a/rpm/SPECS/crowdsec.spec
+++ b/rpm/SPECS/crowdsec.spec
@@ -12,7 +12,7 @@ Patch0: user.patch
BuildRoot: %{_tmppath}/%{name}-%{version}-%{release}-root-%(%{__id_u} -n)
BuildRequires: systemd
-Requires: crontabs
+Requires: (crontabs or cron)
%{?fc33:BuildRequires: systemd-rpm-macros}
%{?fc34:BuildRequires: systemd-rpm-macros}
%{?fc35:BuildRequires: systemd-rpm-macros}
diff --git a/test/ansible/vagrant/fedora-40/Vagrantfile b/test/ansible/vagrant/fedora-40/Vagrantfile
index ec03661fe39..5541d453acf 100644
--- a/test/ansible/vagrant/fedora-40/Vagrantfile
+++ b/test/ansible/vagrant/fedora-40/Vagrantfile
@@ -1,7 +1,7 @@
# frozen_string_literal: true
Vagrant.configure('2') do |config|
- config.vm.box = "fedora/39-cloud-base"
+ config.vm.box = "fedora/40-cloud-base"
config.vm.provision "shell", inline: <<-SHELL
SHELL
end
diff --git a/test/ansible/vagrant/fedora-41/Vagrantfile b/test/ansible/vagrant/fedora-41/Vagrantfile
new file mode 100644
index 00000000000..3f905f51671
--- /dev/null
+++ b/test/ansible/vagrant/fedora-41/Vagrantfile
@@ -0,0 +1,13 @@
+# frozen_string_literal: true
+
+Vagrant.configure('2') do |config|
+ config.vm.box = "fedora/40-cloud-base"
+ config.vm.provision "shell", inline: <<-SHELL
+ SHELL
+ config.vm.provision "shell" do |s|
+ s.inline = "sudo dnf upgrade --refresh -y && sudo dnf install dnf-plugin-system-upgrade -y && sudo dnf system-upgrade download --releasever=41 -y && sudo dnf system-upgrade reboot -y"
+ end
+end
+
+common = '../common'
+load common if File.exist?(common)
diff --git a/test/ansible/vagrant/fedora-41/skip b/test/ansible/vagrant/fedora-41/skip
new file mode 100644
index 00000000000..4f1a9063d2b
--- /dev/null
+++ b/test/ansible/vagrant/fedora-41/skip
@@ -0,0 +1,9 @@
+#!/bin/sh
+
+die() {
+ echo "$@" >&2
+ exit 1
+}
+
+[ "${DB_BACKEND}" = "mysql" ] && die "mysql role does not support this distribution"
+exit 0
diff --git a/test/ansible/vagrant/opensuse-leap-15/Vagrantfile b/test/ansible/vagrant/opensuse-leap-15/Vagrantfile
new file mode 100644
index 00000000000..d10e68a50a7
--- /dev/null
+++ b/test/ansible/vagrant/opensuse-leap-15/Vagrantfile
@@ -0,0 +1,10 @@
+# frozen_string_literal: true
+
+Vagrant.configure('2') do |config|
+ config.vm.box = "opensuse/Leap-15.6.x86_64"
+ config.vm.provision "shell", inline: <<-SHELL
+ SHELL
+end
+
+common = '../common'
+load common if File.exist?(common)
diff --git a/test/ansible/vagrant/opensuse-leap-15/skip b/test/ansible/vagrant/opensuse-leap-15/skip
new file mode 100644
index 00000000000..4f1a9063d2b
--- /dev/null
+++ b/test/ansible/vagrant/opensuse-leap-15/skip
@@ -0,0 +1,9 @@
+#!/bin/sh
+
+die() {
+ echo "$@" >&2
+ exit 1
+}
+
+[ "${DB_BACKEND}" = "mysql" ] && die "mysql role does not support this distribution"
+exit 0
diff --git a/test/bats/10_bouncers.bats b/test/bats/10_bouncers.bats
index f99913dcee5..b1c90116dd2 100644
--- a/test/bats/10_bouncers.bats
+++ b/test/bats/10_bouncers.bats
@@ -63,7 +63,7 @@ teardown() {
@test "delete non-existent bouncer" {
# this is a fatal error, which is not consistent with "machines delete"
rune -1 cscli bouncers delete something
- assert_stderr --partial "unable to delete bouncer: 'something' does not exist"
+ assert_stderr --partial "unable to delete bouncer something: ent: bouncer not found"
rune -0 cscli bouncers delete something --ignore-missing
refute_stderr
}
@@ -144,3 +144,56 @@ teardown() {
rune -0 cscli bouncers prune
assert_output 'No bouncers to prune.'
}
+
+curl_localhost() {
+ [[ -z "$API_KEY" ]] && { fail "${FUNCNAME[0]}: missing API_KEY"; }
+ local path=$1
+ shift
+ curl "localhost:8080$path" -sS --fail-with-body -H "X-Api-Key: $API_KEY" "$@"
+}
+
+# We can't use curl-with-key here, as we want to query localhost, not 127.0.0.1
+@test "multiple bouncers sharing api key" {
+ export API_KEY=bouncerkey
+
+ # crowdsec needs to listen on all interfaces
+ rune -0 ./instance-crowdsec stop
+ rune -0 config_set 'del(.api.server.listen_socket) | del(.api.server.listen_uri)'
+ echo "{'api':{'server':{'listen_uri':'0.0.0.0:8080'}}}" >"${CONFIG_YAML}.local"
+
+ rune -0 ./instance-crowdsec start
+
+ # add a decision for our bouncers
+ rune -0 cscli decisions add -i '1.2.3.5'
+
+ rune -0 cscli bouncers add test-auto -k "$API_KEY"
+
+ # query with 127.0.0.1 as source ip
+ rune -0 curl_localhost "/v1/decisions/stream" -4
+ rune -0 jq -r '.new' <(output)
+ assert_output --partial '1.2.3.5'
+
+ # now with ::1, we should get the same IP, even though we are using the same key
+ rune -0 curl_localhost "/v1/decisions/stream" -6
+ rune -0 jq -r '.new' <(output)
+ assert_output --partial '1.2.3.5'
+
+ rune -0 cscli bouncers list -o json
+ rune -0 jq -c '[.[] | [.name,.revoked,.ip_address,.auto_created]]' <(output)
+ assert_json '[["test-auto",false,"127.0.0.1",false],["test-auto@::1",false,"::1",true]]'
+
+ # check the 2nd bouncer was created automatically
+ rune -0 cscli bouncers inspect "test-auto@::1" -o json
+ rune -0 jq -r '.ip_address' <(output)
+ assert_output --partial '::1'
+
+ # attempt to delete the auto-created bouncer, it should fail
+ rune -1 cscli bouncers delete 'test-auto@::1'
+ assert_stderr --partial 'cannot be deleted'
+
+ # delete the "real" bouncer, it should delete both
+ rune -0 cscli bouncers delete 'test-auto'
+
+ rune -0 cscli bouncers list -o json
+ assert_json []
+}
diff --git a/test/lib/init/crowdsec-daemon b/test/lib/init/crowdsec-daemon
index a232f344b6a..ba8e98992db 100755
--- a/test/lib/init/crowdsec-daemon
+++ b/test/lib/init/crowdsec-daemon
@@ -51,7 +51,11 @@ stop() {
PGID="$(ps -o pgid= -p "$(cat "${DAEMON_PID}")" | tr -d ' ')"
# ps above should work on linux, freebsd, busybox..
if [[ -n "${PGID}" ]]; then
- kill -- "-${PGID}"
+ kill -- "-${PGID}"
+
+ while pgrep -g "${PGID}" >/dev/null; do
+ sleep .05
+ done
fi
rm -f -- "${DAEMON_PID}"