forked from elastic/detection-rules
-
Notifications
You must be signed in to change notification settings - Fork 0
/
optimizer.py
129 lines (98 loc) · 4.5 KB
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License;
# you may not use this file except in compliance with the Elastic License.
import functools
from eql import Walker, DepthFirstWalker
from .ast import AndValues, NotValue, Value, OrValues, NotExpr, FieldComparison
class Optimizer(DepthFirstWalker):
    """Simplify a KQL syntax tree by flattening, grouping and normalizing nodes.

    Each ``_walk_<node>`` method below rewrites one kind of node; the traversal
    itself comes from ``DepthFirstWalker`` (``eql``). Rewrites build new nodes
    rather than modifying the nodes of the tree being optimized.
    """

    def flat_optimize(self, tree):
        """Optimize ``tree`` with its own ``_walk_*`` handler.

        This calls the base ``Walker.walk`` rather than ``self.walk``,
        presumably so that only ``tree``'s handler runs and the
        ``DepthFirstWalker`` child traversal is skipped -- confirm against eql.
        """
        return Walker.walk(self, tree)

    def _walk_default(self, tree, *args, **kwargs):
        """Return nodes without a dedicated ``_walk_*`` handler unchanged."""
        return tree

    def group_fields(self, tree, value_cls):  # type: (List, type) -> KqlNode
        """Merge comparisons against the same field into a single comparison.

        For example ``f:a and f:b`` becomes ``f:(a and b)`` when ``value_cls``
        is ``AndValues``, and likewise for ``or`` with ``OrValues``. A ``not``
        wrapping a single field comparison is first pushed into its value, so
        negated comparisons are grouped as well.

        :param tree: boolean list node (``AndExpr``/``OrExpr``) whose
            ``items`` have already been flattened.
        :param value_cls: value-list class used to combine grouped values.
        :return: a new node of ``type(tree)``, or the only remaining term when
            exactly one is left. The nodes of ``tree`` are not modified.
        """
        cls = type(tree)

        # field name -> every comparison against that field, in source order
        field_groups = {}
        # terms kept in the output; for each field only its first comparison
        # is kept here and later receives the combined value of the group
        ungrouped = []

        for term in tree.items:
            # move a `not` inwards before grouping
            if isinstance(term, NotExpr) and isinstance(term.expr, FieldComparison):
                term = FieldComparison(term.expr.field, NotValue(term.expr.value))

            if isinstance(term, FieldComparison):
                name = term.field.name
                if name in field_groups:
                    field_groups[name].append(term)
                    continue
                field_groups[name] = [term]

            ungrouped.append(term)

        merged = []
        for term in ungrouped:
            if isinstance(term, FieldComparison):
                values = value_cls([t.value for t in field_groups[term.field.name]])
                # Build a new node instead of assigning `term.value`: `term` may
                # still belong to the caller's tree, which must stay unaltered.
                term = FieldComparison(term.field, self.flat_optimize(values))
            merged.append(self.flat_optimize(term))

        # `!= 1` (not `> 1`) so an empty list yields an empty node, not IndexError
        return cls(merged) if len(merged) != 1 else merged[0]

    @staticmethod
    def sort_key(a, b):
        """Order list items so that plain values come first, deterministically.

        Intended for :func:`functools.cmp_to_key`. ``Value`` instances sort
        before anything else. Two values are ordered by their ``.value`` when
        they share a class, otherwise by class name. Any other pair compares
        equal, so a stable sort keeps its original relative order.

        :return: negative, zero or positive, like a classic ``cmp`` function.
        """
        a_is_value = isinstance(a, Value)
        b_is_value = isinstance(b, Value)

        if a_is_value != b_is_value:
            return -1 if a_is_value else +1
        if not a_is_value:
            # unable to compare
            return 0

        if type(a) is type(b):
            return (a.value > b.value) - (a.value < b.value)
        name_a, name_b = type(a).__name__, type(b).__name__
        return (name_a > name_b) - (name_a < name_b)

    def _walk_field_comparison(self, tree):  # type: (FieldComparison) -> KqlNode
        """Rewrite ``field:(not v)`` as ``not field:v``; otherwise keep ``tree``."""
        # if there's a single `not`, then pull it out of the expression
        if isinstance(tree.value, NotValue):
            return NotExpr(FieldComparison(tree.field, tree.value.value))
        return tree

    def flatten(self, tree):  # type: (List) -> List
        """Splice directly nested lists of the same type into ``tree``.

        Only one level of nesting is spliced. Every resulting item is then
        optimized with :meth:`flat_optimize`.

        :return: a new node of ``type(tree)``; always a list node, even when
            it holds a single item.
        """
        cls = type(tree)
        flattened = []
        for node in tree.items:
            if isinstance(node, cls):
                flattened.extend(node.items)
            else:
                flattened.append(node)
        flattened = [self.flat_optimize(t) for t in flattened]
        return cls(flattened)

    def flatten_values(self, tree, dual_cls):  # type: (List, type) -> List
        """Flatten a value list, merge its negations and sort its items.

        All ``not <value>`` items are merged into a single negation of the
        dual list (De Morgan): in an ``AndValues`` list ``not a and not b``
        becomes ``not (a or b)``, and dually for ``OrValues``. The merged
        negation takes the position of the first one before sorting with
        :meth:`sort_key`.

        :param tree: ``AndValues`` or ``OrValues`` node to optimize.
        :param dual_cls: the opposite value-list class (``OrValues`` for
            ``AndValues`` and vice versa).
        :return: a new node of ``type(tree)``, or the only remaining item when
            exactly one is left.
        """
        cls = type(tree)
        flattened = []
        not_term = None
        # operands of every negated plain value, collected flat so that three
        # or more negations give one dual_cls([...]) instead of nested lists
        negated = []

        for term in self.flatten(tree).items:
            if isinstance(term, NotValue) and isinstance(term.value, Value):
                negated.append(term.value)
                if not_term is None:
                    # create a copy to leave the source tree unaltered
                    not_term = NotValue(term.value)
                    flattened.append(not_term)
                continue
            flattened.append(term)

        if not_term is not None:
            combined = negated[0] if len(negated) == 1 else dual_cls(negated)
            not_term.value = self.flat_optimize(combined)

        flattened = [self.flat_optimize(t) for t in flattened]
        flattened.sort(key=functools.cmp_to_key(self.sort_key))
        # `!= 1` (not `> 1`) so an empty list yields an empty node, not IndexError
        return cls(flattened) if len(flattened) != 1 else flattened[0]

    def _walk_not_value(self, tree):  # type: (NotValue) -> KqlNode
        """Cancel a double negation ``not not v`` down to ``v``."""
        if isinstance(tree.value, NotValue):
            return tree.value.value
        return tree

    def _walk_or_values(self, tree):
        """Optimize an ``OrValues`` list; negations merge under ``AndValues``."""
        return self.flatten_values(tree, AndValues)

    def _walk_and_values(self, tree):
        """Optimize an ``AndValues`` list; negations merge under ``OrValues``."""
        return self.flatten_values(tree, OrValues)

    def _walk_not_expr(self, tree):  # type: (NotExpr) -> KqlNode
        """Cancel a double negation ``not not e`` down to ``e``."""
        if isinstance(tree.expr, NotExpr):
            return tree.expr.expr
        return tree

    def _walk_and_expr(self, tree):  # type: (AndExpr) -> KqlNode
        """Flatten an ``and`` expression and group its per-field comparisons."""
        return self.group_fields(self.flatten(tree), value_cls=AndValues)

    def _walk_or_expr(self, tree):  # type: (OrExpr) -> KqlNode
        """Flatten an ``or`` expression and group its per-field comparisons."""
        return self.group_fields(self.flatten(tree), value_cls=OrValues)