Merge pull request #827 from mandiant/perf/short-circuit

perf: short circuit logic nodes when appropriate
This commit is contained in:
Willi Ballenthin
2021-11-09 16:10:20 -07:00
committed by GitHub
5 changed files with 135 additions and 31 deletions

View File

@@ -4,6 +4,8 @@
### New Features
- engine: short circuit logic nodes for better performance #824 @williballenthin
### Breaking Changes
### New Rules (3)

View File

@@ -46,9 +46,12 @@ class Statement:
def __repr__(self):
return str(self)
def evaluate(self, features: FeatureSet) -> Result:
def evaluate(self, features: FeatureSet, short_circuit=True) -> Result:
"""
classes that inherit `Statement` must implement `evaluate`
args:
short_circuit (bool): if true, then statements like and/or/some may short circuit.
"""
raise NotImplementedError()
@@ -73,33 +76,67 @@ class Statement:
class And(Statement):
"""match if all of the children evaluate to True."""
"""
match if all of the children evaluate to True.
the order of evaluation is dictated by the property
`And.children` (type: List[Statement|Feature]).
a query optimizer may safely manipulate the order of these children.
"""
def __init__(self, children, description=None):
super(And, self).__init__(description=description)
self.children = children
def evaluate(self, ctx):
def evaluate(self, ctx, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.and"] += 1
results = [child.evaluate(ctx) for child in self.children]
if short_circuit:
results = []
for child in self.children:
result = child.evaluate(ctx, short_circuit=short_circuit)
results.append(result)
if not result:
# short circuit
return Result(False, self, results)
return Result(True, self, results)
else:
results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children]
success = all(results)
return Result(success, self, results)
class Or(Statement):
"""match if any of the children evaluate to True."""
"""
match if any of the children evaluate to True.
the order of evaluation is dictated by the property
`Or.children` (type: List[Statement|Feature]).
a query optimizer may safely manipulate the order of these children.
"""
def __init__(self, children, description=None):
super(Or, self).__init__(description=description)
self.children = children
def evaluate(self, ctx):
def evaluate(self, ctx, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.or"] += 1
results = [child.evaluate(ctx) for child in self.children]
if short_circuit:
results = []
for child in self.children:
result = child.evaluate(ctx, short_circuit=short_circuit)
results.append(result)
if result:
# short circuit as soon as we hit one match
return Result(True, self, results)
return Result(False, self, results)
else:
results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children]
success = any(results)
return Result(success, self, results)
@@ -111,28 +148,49 @@ class Not(Statement):
super(Not, self).__init__(description=description)
self.child = child
def evaluate(self, ctx):
def evaluate(self, ctx, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.not"] += 1
results = [self.child.evaluate(ctx)]
results = [self.child.evaluate(ctx, short_circuit=short_circuit)]
success = not results[0]
return Result(success, self, results)
class Some(Statement):
"""match if at least N of the children evaluate to True."""
"""
match if at least N of the children evaluate to True.
the order of evaluation is dictated by the property
`Some.children` (type: List[Statement|Feature]).
a query optimizer may safely manipulate the order of these children.
"""
def __init__(self, count, children, description=None):
super(Some, self).__init__(description=description)
self.count = count
self.children = children
def evaluate(self, ctx):
def evaluate(self, ctx, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.some"] += 1
results = [child.evaluate(ctx) for child in self.children]
if short_circuit:
results = []
satisfied_children_count = 0
for child in self.children:
result = child.evaluate(ctx, short_circuit=short_circuit)
results.append(result)
if result:
satisfied_children_count += 1
if satisfied_children_count >= self.count:
# short circuit as soon as we hit the threshold
return Result(True, self, results)
return Result(False, self, results)
else:
results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children]
# note that here we cast the child result as a bool
# because we've overridden `__bool__` above.
#
@@ -150,7 +208,7 @@ class Range(Statement):
self.min = min if min is not None else 0
self.max = max if max is not None else (1 << 64 - 1)
def evaluate(self, ctx):
def evaluate(self, ctx, **kwargs):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.range"] += 1
@@ -178,7 +236,7 @@ class Subscope(Statement):
self.scope = scope
self.child = child
def evaluate(self, ctx):
def evaluate(self, ctx, **kwargs):
raise ValueError("cannot evaluate a subscope directly!")
@@ -236,8 +294,18 @@ def match(rules: List["capa.rules.Rule"], features: FeatureSet, va: int) -> Tupl
features = collections.defaultdict(set, copy.copy(features))
for rule in rules:
res = rule.evaluate(features)
res = rule.evaluate(features, short_circuit=True)
if res:
# we first matched the rule with short circuiting enabled.
# this is much faster than without short circuiting.
# however, we want to collect all results thoroughly,
# so once we've found a match quickly,
# go back and capture results without short circuiting.
res = rule.evaluate(features, short_circuit=False)
# sanity check
assert bool(res) is True
results[rule.name].append((va, res))
# we need to update the current `features`
# because subsequent iterations of this loop may use newly added features,

View File

@@ -146,7 +146,7 @@ class Feature:
def __repr__(self):
return str(self)
def evaluate(self, ctx: Dict["Feature", Set[int]]) -> Result:
def evaluate(self, ctx: Dict["Feature", Set[int]], **kwargs) -> Result:
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature." + self.name] += 1
return Result(self in ctx, self, [], locations=ctx.get(self, []))
@@ -192,7 +192,7 @@ class Substring(String):
super(Substring, self).__init__(value, description=description)
self.value = value
def evaluate(self, ctx):
def evaluate(self, ctx, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.substring"] += 1
@@ -210,6 +210,10 @@ class Substring(String):
if self.value in feature.value:
matches[feature.value].extend(locations)
if short_circuit:
# we found one matching string, thats sufficient to match.
# don't collect other matching strings in this mode.
break
if matches:
# finalize: defaultdict -> dict
@@ -280,7 +284,7 @@ class Regex(String):
"invalid regular expression: %s it should use Python syntax, try it at https://pythex.org" % value
)
def evaluate(self, ctx):
def evaluate(self, ctx, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.regex"] += 1
@@ -302,6 +306,10 @@ class Regex(String):
# so that they don't have to prefix/suffix their terms like: /.*foo.*/.
if self.re.search(feature.value):
matches[feature.value].extend(locations)
if short_circuit:
# we found one matching string, thats sufficient to match.
# don't collect other matching strings in this mode.
break
if matches:
# finalize: defaultdict -> dict
@@ -366,7 +374,7 @@ class Bytes(Feature):
super(Bytes, self).__init__(value, description=description)
self.value = value
def evaluate(self, ctx):
def evaluate(self, ctx, **kwargs):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.bytes"] += 1

View File

@@ -620,10 +620,10 @@ class Rule:
for new_rule in self._extract_subscope_rules_rec(self.statement):
yield new_rule
def evaluate(self, features: FeatureSet):
def evaluate(self, features: FeatureSet, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.rule"] += 1
return self.statement.evaluate(features)
return self.statement.evaluate(features, short_circuit=short_circuit)
@classmethod
def from_dict(cls, d, definition):

View File

@@ -533,3 +533,29 @@ def test_render_offset():
assert str(capa.features.insn.Offset(1)) == "offset(0x1)"
assert str(capa.features.insn.Offset(1, bitness=capa.features.common.BITNESS_X32)) == "offset/x32(0x1)"
assert str(capa.features.insn.Offset(1, bitness=capa.features.common.BITNESS_X64)) == "offset/x64(0x1)"
def test_short_circuit():
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True
# with short circuiting, only the children up until the first satisfied child are captured.
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}, short_circuit=True).children) == 1
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}, short_circuit=False).children) == 2
def test_eval_order():
# base cases.
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}) == True
# with short circuiting, only the children up until the first satisfied child are captured.
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children) == 1
assert len(Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children) == 2
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}, Number(2): {1}}).children) == 1
# and its guaranteed that children are evaluated in order.
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children[0].statement == Number(1)
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children[0].statement != Number(2)
assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children[1].statement == Number(2)
assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children[1].statement != Number(1)