Skip to content

Commit

Permalink
updating to account for symbols
Browse files Browse the repository at this point in the history
  • Loading branch information
acx1729 committed Dec 24, 2024
1 parent d527a6c commit 22fe669
Showing 1 changed file with 52 additions and 69 deletions.
121 changes: 52 additions & 69 deletions pkg/opengovernance-es-sdk/elasticsearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,6 @@ func BuildFilterWithDefaultFieldName(ctx context.Context, queryContext *plugin.Q
filters = append(filters, NewTermFilter(fieldName, qualValue(qual.GetValue())))
}
}
// Range operators
if oprStr == ">" {
filters = append(filters, NewRangeFilter(fieldName, qualValue(qual.GetValue()), "", "", ""))
}
Expand Down Expand Up @@ -299,7 +298,6 @@ func BuildFilterWithDefaultFieldName(ctx context.Context, queryContext *plugin.Q
} else {
esResourceGroupFilters := make([]BoolFilter, 0, len(resourceGroupFilters)+1)

// If clientType is "compliance", add a MustFilter for tagless resource types
if clientType != nil && len(*clientType) > 0 && *clientType == "compliance" {
taglessTypes := make([]string, 0, len(awsTaglessResourceTypes)+len(azureTaglessResourceTypes))
for _, awsTaglessResourceType := range awsTaglessResourceTypes {
Expand All @@ -312,11 +310,9 @@ func BuildFilterWithDefaultFieldName(ctx context.Context, queryContext *plugin.Q
esResourceGroupFilters = append(esResourceGroupFilters, NewBoolMustFilter(taglessTermsFilter))
}

// Build filters for each resourceGroupFilter
for _, resourceGroupFilter := range resourceGroupFilters {
andFilters := make([]BoolFilter, 0, 5)

// Various "TermsFilter" for Connectors, AccountIDs, ResourceTypes, etc.
if len(resourceGroupFilter.Connectors) > 0 {
andFilters = append(andFilters, NewTermsFilter("source_type", resourceGroupFilter.Connectors))
}
Expand All @@ -327,17 +323,15 @@ func BuildFilterWithDefaultFieldName(ctx context.Context, queryContext *plugin.Q
andFilters = append(andFilters, NewTermsFilter("metadata.ResourceType", resourceGroupFilter.ResourceTypes))
}

// Regions or Locations (AWS vs. Azure) => bool should
if len(resourceGroupFilter.Regions) > 0 {
andFilters = append(andFilters,
NewBoolShouldFilter(
NewTermsFilter("metadata.Region", resourceGroupFilter.Regions), // AWS
NewTermsFilter("metadata.Location", resourceGroupFilter.Regions), // Azure
NewTermsFilter("metadata.Region", resourceGroupFilter.Regions),
NewTermsFilter("metadata.Location", resourceGroupFilter.Regions),
),
)
}

// For each Tag key/value => nested filter on canonical_tags
if len(resourceGroupFilter.Tags) > 0 {
for k, v := range resourceGroupFilter.Tags {
k := strings.ToLower(k)
Expand All @@ -353,30 +347,24 @@ func BuildFilterWithDefaultFieldName(ctx context.Context, queryContext *plugin.Q
}
}

// If we have any subfilters, combine them in a MustFilter
if len(andFilters) > 0 {
esResourceGroupFilters = append(esResourceGroupFilters, NewBoolMustFilter(andFilters...))
}
}

// If we built any resourceGroupFilters, wrap them in a ShouldFilter
// so that at least one of them must match
if len(esResourceGroupFilters) > 0 {
filters = append(filters, NewBoolShouldFilter(esResourceGroupFilters...))
}
}
}
}

// Log the final filters for debug
jsonFilters, _ := json.Marshal(filters)
plugin.Logger(ctx).Trace("BuildFilter", "filters", filters, "jsonFilters", string(jsonFilters))

return filters
}

// qualValue extracts the actual string value from a proto.QualValue (Steampipe representation)
// and converts it to a simple string for ES queries.
func qualValue(qual *proto.QualValue) string {
var valStr string
val := qual.Value
Expand All @@ -399,6 +387,15 @@ func qualValue(qual *proto.QualValue) string {
return valStr
}

// containsSpecialSymbol reports whether val contains at least one of the
// characters / \ < > , - _ ( ) [ ] =.
//
// TermFilter.MarshalJSON calls this to decide whether the value should also
// be matched against the "<field>.keyword" sub-field. Extend the symbol set
// below if further characters need the same treatment.
func containsSpecialSymbol(val string) bool {
	// Raw string literal, so the backslash stands for itself.
	const symbols = `/\<>,-_()[]=`
	return strings.IndexAny(val, symbols) >= 0
}

// TermFilter represents a "term" query in Elasticsearch, e.g.:
// { "term": { "<field>": "<value>" } }
type TermFilter struct {
Expand All @@ -414,9 +411,40 @@ func NewTermFilter(field, value string) BoolFilter {
}
}

// MarshalJSON is called automatically when building the ES query JSON.
// Produces { "term": { "<field>": "<value>" } }.
// MarshalJSON checks the value for special symbols (see containsSpecialSymbol).
// If any are found, it emits a bool.should that matches the value on either
// "<field>" or "<field>.keyword"; otherwise it emits a plain term filter.
func (t TermFilter) MarshalJSON() ([]byte, error) {
if containsSpecialSymbol(t.value) {
// Produce:
// {
// "bool": {
// "should": [
// { "term": { "<field>": "<value>" } },
// { "term": { "<field>.keyword": "<value>" } }
// ],
// "minimum_should_match": 1
// }
// }
return json.Marshal(map[string]any{
"bool": map[string]any{
"should": []map[string]any{
{
"term": map[string]any{
t.field: t.value,
},
},
{
"term": map[string]any{
t.field + ".keyword": t.value,
},
},
},
"minimum_should_match": 1,
},
})
}

// Otherwise, standard single-term query
return json.Marshal(map[string]any{
"term": map[string]string{
t.field: t.value,
Expand All @@ -425,22 +453,22 @@ func (t TermFilter) MarshalJSON() ([]byte, error) {
}
// IsBoolFilter is an empty marker method on TermFilter.
// NOTE(review): presumably required by the BoolFilter interface, which is
// declared outside this view — confirm.
func (t TermFilter) IsBoolFilter() {}

// TermsFilter represents a "terms" query in Elasticsearch, e.g.:
// { "terms": { "<field>": [ "<val1>", "<val2>", ... ] } }
// TermsFilter, TermsSetMatchAllFilter, etc. remain unchanged ...
// (no modifications below this for other filters)

// TermsFilter holds the data for an Elasticsearch "terms" query, e.g.:
// { "terms": { "<field>": [ "<val1>", "<val2>", ... ] } }
// Values are created by NewTermsFilter and serialized by TermsFilter.MarshalJSON.
type TermsFilter struct {
// field is the field name passed to NewTermsFilter.
field string
// values is the list of values passed to NewTermsFilter.
values []string
}

// NewTermsFilter builds a TermsFilter for the given field and list of values
// and returns it as a BoolFilter. The values slice is stored as-is, not copied.
func NewTermsFilter(field string, values []string) BoolFilter {
	var filter TermsFilter
	filter.field = field
	filter.values = values
	return filter
}

// MarshalJSON produces { "terms": { "<field>": [ ... ] } }.
func (t TermsFilter) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"terms": map[string][]string{
Expand All @@ -450,22 +478,18 @@ func (t TermsFilter) MarshalJSON() ([]byte, error) {
}
// IsBoolFilter is an empty marker method on TermsFilter.
// NOTE(review): presumably required by the BoolFilter interface, which is
// declared outside this view — confirm.
func (t TermsFilter) IsBoolFilter() {}

// TermsSetMatchAllFilter is used for matching all provided terms in an array field.
// Its MarshalJSON emits a "terms_set" query with
// "minimum_should_match_script": "params.num_terms".
// Values are created by NewTermsSetMatchAllFilter.
type TermsSetMatchAllFilter struct {
// field is the field name passed to NewTermsSetMatchAllFilter.
field string
// values is the list of terms passed to NewTermsSetMatchAllFilter.
values []string
}

// NewTermsSetMatchAllFilter builds a TermsSetMatchAllFilter for the given
// field and terms and returns it as a BoolFilter. The values slice is stored
// as-is, not copied.
func NewTermsSetMatchAllFilter(field string, values []string) BoolFilter {
	var filter TermsSetMatchAllFilter
	filter.field = field
	filter.values = values
	return filter
}

// MarshalJSON produces a "terms_set" query with a script ensuring all terms must match.
func (t TermsSetMatchAllFilter) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"terms_set": map[string]any{
Expand All @@ -480,8 +504,6 @@ func (t TermsSetMatchAllFilter) MarshalJSON() ([]byte, error) {
}
// IsBoolFilter is an empty marker method on TermsSetMatchAllFilter.
// NOTE(review): presumably required by the BoolFilter interface, which is
// declared outside this view — confirm.
func (t TermsSetMatchAllFilter) IsBoolFilter() {}

// RangeFilter represents a "range" query (>, >=, <, <=).
// Only the relevant keys (gt/gte/lt/lte) are set based on user input.
type RangeFilter struct {
field string
gt string
Expand All @@ -490,7 +512,6 @@ type RangeFilter struct {
lte string
}

// NewRangeFilter constructs a range filter for a single field with optional gt/gte/lt/lte.
func NewRangeFilter(field, gt, gte, lt, lte string) BoolFilter {
return RangeFilter{
field: field,
Expand All @@ -501,18 +522,6 @@ func NewRangeFilter(field, gt, gte, lt, lte string) BoolFilter {
}
}

// MarshalJSON produces something like:
//
// {
// "range": {
// "<field>": {
// "gt": "...",
// "gte": "...",
// "lt": "...",
// "lte": "..."
// }
// }
// }
func (t RangeFilter) MarshalJSON() ([]byte, error) {
fieldMap := map[string]interface{}{}
if len(t.gt) > 0 {
Expand All @@ -536,20 +545,16 @@ func (t RangeFilter) MarshalJSON() ([]byte, error) {
}
// IsBoolFilter is an empty marker method on RangeFilter.
// NOTE(review): presumably required by the BoolFilter interface, which is
// declared outside this view — confirm.
func (t RangeFilter) IsBoolFilter() {}

// BoolShouldFilter is a logical OR over multiple subfilters, equivalent to
// "should" inside an Elasticsearch "bool" query.
// Values are created by NewBoolShouldFilter.
type BoolShouldFilter struct {
// filters holds the subfilters passed to NewBoolShouldFilter.
filters []BoolFilter
}

// NewBoolShouldFilter wraps the given subfilters in a BoolShouldFilter, so
// that any one of them matching is sufficient, and returns it as a BoolFilter.
func NewBoolShouldFilter(filters ...BoolFilter) BoolFilter {
	var should BoolShouldFilter
	should.filters = filters
	return should
}

// MarshalJSON creates a structure like:
// { "bool": { "should": [ <subfilters> ] } }
func (t BoolShouldFilter) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"bool": map[string][]BoolFilter{
Expand All @@ -559,20 +564,16 @@ func (t BoolShouldFilter) MarshalJSON() ([]byte, error) {
}
// IsBoolFilter is an empty marker method on BoolShouldFilter.
// NOTE(review): presumably required by the BoolFilter interface, which is
// declared outside this view — confirm.
func (t BoolShouldFilter) IsBoolFilter() {}

// BoolMustFilter is a logical AND over multiple subfilters, equivalent to
// "must" inside an Elasticsearch "bool" query.
// Values are created by NewBoolMustFilter.
type BoolMustFilter struct {
// filters holds the subfilters passed to NewBoolMustFilter.
filters []BoolFilter
}

// NewBoolMustFilter wraps the given subfilters in a BoolMustFilter, so that
// all of them must match (logical AND), and returns it as a BoolFilter.
func NewBoolMustFilter(filters ...BoolFilter) BoolFilter {
	var must BoolMustFilter
	must.filters = filters
	return must
}

// MarshalJSON produces:
// { "bool": { "must": [ <subfilters> ] } }
func (t BoolMustFilter) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"bool": map[string][]BoolFilter{
Expand All @@ -582,20 +583,16 @@ func (t BoolMustFilter) MarshalJSON() ([]byte, error) {
}
// IsBoolFilter is an empty marker method on BoolMustFilter.
// NOTE(review): presumably required by the BoolFilter interface, which is
// declared outside this view — confirm.
func (t BoolMustFilter) IsBoolFilter() {}

// BoolMustNotFilter is the logical NOT over its subfilters, equivalent to
// "must_not" inside an Elasticsearch "bool" query.
// Values are created by NewBoolMustNotFilter.
type BoolMustNotFilter struct {
// filters holds the subfilters passed to NewBoolMustNotFilter.
filters []BoolFilter
}

// NewBoolMustNotFilter wraps the given subfilters in a BoolMustNotFilter
// ("must_not") and returns it as a BoolFilter.
func NewBoolMustNotFilter(filters ...BoolFilter) BoolFilter {
	var mustNot BoolMustNotFilter
	mustNot.filters = filters
	return mustNot
}

// MarshalJSON yields:
// { "bool": { "must_not": [ <subfilters> ] } }
func (t BoolMustNotFilter) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"bool": map[string][]BoolFilter{
Expand All @@ -605,29 +602,18 @@ func (t BoolMustNotFilter) MarshalJSON() ([]byte, error) {
}
// IsBoolFilter is an empty marker method on BoolMustNotFilter.
// NOTE(review): presumably required by the BoolFilter interface, which is
// declared outside this view — confirm.
func (t BoolMustNotFilter) IsBoolFilter() {}

// NestedFilter is used for queries on nested fields (Elasticsearch "nested" type).
// Example: { "nested": { "path": "...", "query": <subfilter> } }
// Values are created by NewNestedFilter.
type NestedFilter struct {
// path is the nested path passed to NewNestedFilter.
path string
// query is the subfilter passed to NewNestedFilter.
query BoolFilter
}

// NewNestedFilter builds a NestedFilter that applies query to the nested
// object at path and returns it as a BoolFilter.
func NewNestedFilter(path string, query BoolFilter) BoolFilter {
	var nested NestedFilter
	nested.path = path
	nested.query = query
	return nested
}

// MarshalJSON produces JSON like:
//
// {
// "nested": {
// "path": "<path>",
// "query": <subfilter>
// }
// }
func (t NestedFilter) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"nested": map[string]any{
Expand All @@ -638,7 +624,7 @@ func (t NestedFilter) MarshalJSON() ([]byte, error) {
}
// IsBoolFilter is an empty marker method on NestedFilter.
// NOTE(review): presumably required by the BoolFilter interface, which is
// declared outside this view — confirm.
func (t NestedFilter) IsBoolFilter() {}

// Healthcheck pings the cluster Health API to verify status is "green" or "yellow."
// Healthcheck checks cluster health, returning an error if "red" or otherwise failing.
func (c Client) Healthcheck(ctx context.Context) error {
opts := []func(request *opensearchapi.ClusterHealthRequest){
c.es.Cluster.Health.WithContext(ctx),
Expand Down Expand Up @@ -666,15 +652,14 @@ func (c Client) Healthcheck(ctx context.Context) error {
return fmt.Errorf("failed to unmarshal due to %v", err)
}

// We consider "green" or "yellow" as acceptable. "red" is unhealthy.
if js["status"] != "green" && js["status"] != "yellow" {
return errors.New("unhealthy")
}

return nil
}

// CreateIndexTemplate creates an index template in OpenSearch using the provided name and JSON body.
// CreateIndexTemplate sets up an index template in OpenSearch.
func (c Client) CreateIndexTemplate(ctx context.Context, name string, body string) error {
opts := []func(request *opensearchapi.IndicesPutIndexTemplateRequest){
c.es.Indices.PutIndexTemplate.WithContext(ctx),
Expand All @@ -695,7 +680,7 @@ func (c Client) CreateIndexTemplate(ctx context.Context, name string, body strin
return nil
}

// CreateComponentTemplate sets up a component template in OpenSearch with the given name and body.
// CreateComponentTemplate sets up a component template in OpenSearch.
func (c Client) CreateComponentTemplate(ctx context.Context, name string, body string) error {
opts := []func(request *opensearchapi.ClusterPutComponentTemplateRequest){
c.es.Cluster.PutComponentTemplate.WithContext(ctx),
Expand All @@ -716,7 +701,6 @@ func (c Client) CreateComponentTemplate(ctx context.Context, name string, body s
return nil
}

// DeleteByQueryResponse matches the JSON shape returned by the _delete_by_query API.
type DeleteByQueryResponse struct {
Took int `json:"took"`
TimedOut bool `json:"timed_out"`
Expand All @@ -735,7 +719,7 @@ type DeleteByQueryResponse struct {
Failures []any `json:"failures"`
}

// DeleteByQuery runs an ES "delete_by_query" request on the specified indices using the provided JSON query.
// DeleteByQuery runs _delete_by_query on the specified indices using the provided JSON query.
func DeleteByQuery(ctx context.Context, es *opensearch.Client, indices []string, query any, opts ...func(*opensearchapi.DeleteByQueryRequest)) (DeleteByQueryResponse, error) {
defaultOpts := []func(*opensearchapi.DeleteByQueryRequest){
es.DeleteByQuery.WithContext(ctx),
Expand All @@ -752,7 +736,6 @@ func DeleteByQuery(ctx context.Context, es *opensearch.Client, indices []string,
return DeleteByQueryResponse{}, err
} else if err := CheckError(resp); err != nil {
if IsIndexNotFoundErr(err) {
// If the index doesn't exist, just return an empty response.
return DeleteByQueryResponse{}, nil
}
return DeleteByQueryResponse{}, err
Expand Down

0 comments on commit 22fe669

Please sign in to comment.