diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c7c3ac6..665c3c1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ ### New Features +- engine: short circuit logic nodes for better performance #824 @williballenthin + ### Breaking Changes ### New Rules (3) diff --git a/capa/engine.py b/capa/engine.py index 97f6d295..29c0dc65 100644 --- a/capa/engine.py +++ b/capa/engine.py @@ -46,9 +46,12 @@ class Statement: def __repr__(self): return str(self) - def evaluate(self, features: FeatureSet) -> Result: + def evaluate(self, features: FeatureSet, short_circuit=True) -> Result: """ classes that inherit `Statement` must implement `evaluate` + + args: + short_circuit (bool): if true, then statements like and/or/some may short circuit. """ raise NotImplementedError() @@ -73,35 +76,69 @@ class Statement: class And(Statement): - """match if all of the children evaluate to True.""" + """ + match if all of the children evaluate to True. + + the order of evaluation is dictated by the property + `And.children` (type: List[Statement|Feature]). + a query optimizer may safely manipulate the order of these children. + """ def __init__(self, children, description=None): super(And, self).__init__(description=description) self.children = children - def evaluate(self, ctx): + def evaluate(self, ctx, short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.and"] += 1 - results = [child.evaluate(ctx) for child in self.children] - success = all(results) - return Result(success, self, results) + if short_circuit: + results = [] + for child in self.children: + result = child.evaluate(ctx, short_circuit=short_circuit) + results.append(result) + if not result: + # short circuit + return Result(False, self, results) + + return Result(True, self, results) + else: + results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children] + success = all(results) + return Result(success, self, results) class Or(Statement): - """match if any of the children evaluate to True.""" + """ + match if any of the children evaluate to True. + + the order of evaluation is dictated by the property + `Or.children` (type: List[Statement|Feature]). + a query optimizer may safely manipulate the order of these children. + """ def __init__(self, children, description=None): super(Or, self).__init__(description=description) self.children = children - def evaluate(self, ctx): + def evaluate(self, ctx, short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.or"] += 1 - results = [child.evaluate(ctx) for child in self.children] - success = any(results) - return Result(success, self, results) + if short_circuit: + results = [] + for child in self.children: + result = child.evaluate(ctx, short_circuit=short_circuit) + results.append(result) + if result: + # short circuit as soon as we hit one match + return Result(True, self, results) + + return Result(False, self, results) + else: + results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children] + success = any(results) + return Result(success, self, results) class Not(Statement): @@ -111,34 +148,55 @@ class Not(Statement): super(Not, self).__init__(description=description) self.child = child - def evaluate(self, ctx): + def evaluate(self, ctx, short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.not"] += 1 - results = [self.child.evaluate(ctx)] + results = [self.child.evaluate(ctx, short_circuit=short_circuit)] success = not results[0] return Result(success, self, results) class Some(Statement): - """match if at least N of the children evaluate to True.""" + """ + match if at least N of the children evaluate to True. + + the order of evaluation is dictated by the property + `Some.children` (type: List[Statement|Feature]). + a query optimizer may safely manipulate the order of these children. + """ def __init__(self, count, children, description=None): super(Some, self).__init__(description=description) self.count = count self.children = children - def evaluate(self, ctx): + def evaluate(self, ctx, short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.some"] += 1 - results = [child.evaluate(ctx) for child in self.children] - # note that here we cast the child result as a bool - # because we've overridden `__bool__` above. - # - # we can't use `if child is True` because the instance is not True. - success = sum([1 for child in results if bool(child) is True]) >= self.count - return Result(success, self, results) + if short_circuit: + results = [] + satisfied_children_count = 0 + for child in self.children: + result = child.evaluate(ctx, short_circuit=short_circuit) + results.append(result) + if result: + satisfied_children_count += 1 + + if satisfied_children_count >= self.count: + # short circuit as soon as we hit the threshold + return Result(True, self, results) + + return Result(False, self, results) + else: + results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children] + # note that here we cast the child result as a bool + # because we've overridden `__bool__` above. + # + # we can't use `if child is True` because the instance is not True. + success = sum([1 for child in results if bool(child) is True]) >= self.count + return Result(success, self, results) class Range(Statement): @@ -150,7 +208,7 @@ class Range(Statement): self.min = min if min is not None else 0 self.max = max if max is not None else (1 << 64 - 1) - def evaluate(self, ctx): + def evaluate(self, ctx, **kwargs): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.range"] += 1 @@ -178,7 +236,7 @@ class Subscope(Statement): self.scope = scope self.child = child - def evaluate(self, ctx): + def evaluate(self, ctx, **kwargs): raise ValueError("cannot evaluate a subscope directly!") @@ -236,8 +294,18 @@ def match(rules: List["capa.rules.Rule"], features: FeatureSet, va: int) -> Tupl features = collections.defaultdict(set, copy.copy(features)) for rule in rules: - res = rule.evaluate(features) + res = rule.evaluate(features, short_circuit=True) if res: + # we first matched the rule with short circuiting enabled. + # this is much faster than without short circuiting. + # however, we want to collect all results thoroughly, + # so once we've found a match quickly, + # go back and capture results without short circuiting. + res = rule.evaluate(features, short_circuit=False) + + # sanity check + assert bool(res) is True + results[rule.name].append((va, res)) # we need to update the current `features` # because subsequent iterations of this loop may use newly added features, diff --git a/capa/features/common.py b/capa/features/common.py index 3a4e71e9..6b867766 100644 --- a/capa/features/common.py +++ b/capa/features/common.py @@ -146,7 +146,7 @@ class Feature: def __repr__(self): return str(self) - def evaluate(self, ctx: Dict["Feature", Set[int]]) -> Result: + def evaluate(self, ctx: Dict["Feature", Set[int]], **kwargs) -> Result: capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature." + self.name] += 1 return Result(self in ctx, self, [], locations=ctx.get(self, [])) @@ -192,7 +192,7 @@ class Substring(String): super(Substring, self).__init__(value, description=description) self.value = value - def evaluate(self, ctx): + def evaluate(self, ctx, short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.substring"] += 1 @@ -210,6 +210,10 @@ class Substring(String): if self.value in feature.value: matches[feature.value].extend(locations) + if short_circuit: + # we found one matching string, thats sufficient to match. + # don't collect other matching strings in this mode. + break if matches: # finalize: defaultdict -> dict @@ -280,7 +284,7 @@ class Regex(String): "invalid regular expression: %s it should use Python syntax, try it at https://pythex.org" % value ) - def evaluate(self, ctx): + def evaluate(self, ctx, short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.regex"] += 1 @@ -302,6 +306,10 @@ class Regex(String): # so that they don't have to prefix/suffix their terms like: /.*foo.*/. if self.re.search(feature.value): matches[feature.value].extend(locations) + if short_circuit: + # we found one matching string, thats sufficient to match. + # don't collect other matching strings in this mode. + break if matches: # finalize: defaultdict -> dict @@ -366,7 +374,7 @@ class Bytes(Feature): super(Bytes, self).__init__(value, description=description) self.value = value - def evaluate(self, ctx): + def evaluate(self, ctx, **kwargs): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.bytes"] += 1 diff --git a/capa/rules.py b/capa/rules.py index 2753f19d..00dc0837 100644 --- a/capa/rules.py +++ b/capa/rules.py @@ -620,10 +620,10 @@ class Rule: for new_rule in self._extract_subscope_rules_rec(self.statement): yield new_rule - def evaluate(self, features: FeatureSet): + def evaluate(self, features: FeatureSet, short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.rule"] += 1 - return self.statement.evaluate(features) + return self.statement.evaluate(features, short_circuit=short_circuit) @classmethod def from_dict(cls, d, definition): diff --git a/tests/test_engine.py b/tests/test_engine.py index ce421759..b07c89e6 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -533,3 +533,29 @@ def test_render_offset(): assert str(capa.features.insn.Offset(1)) == "offset(0x1)" assert str(capa.features.insn.Offset(1, bitness=capa.features.common.BITNESS_X32)) == "offset/x32(0x1)" assert str(capa.features.insn.Offset(1, bitness=capa.features.common.BITNESS_X64)) == "offset/x64(0x1)" + + +def test_short_circuit(): + assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True + + # with short circuiting, only the children up until the first satisfied child are captured. + assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}, short_circuit=True).children) == 1 + assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}, short_circuit=False).children) == 2 + + +def test_eval_order(): + # base cases. + assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True + assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}) == True + + # with short circuiting, only the children up until the first satisfied child are captured. + assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children) == 1 + assert len(Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children) == 2 + assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}, Number(2): {1}}).children) == 1 + + # and its guaranteed that children are evaluated in order. + assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children[0].statement == Number(1) + assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children[0].statement != Number(2) + + assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children[1].statement == Number(2) + assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children[1].statement != Number(1)