engine, common: use FeatureSet type annotation for evaluate signature

It was used in some places already, but now used everywhere consistently.
This should make it easier to refactor the FeatureSet type, if necessary,
because its easier to see all the places its used.
This commit is contained in:
Willi Ballenthin
2024-05-06 12:01:59 +02:00
committed by Willi Ballenthin
parent 4fbd2ba2b8
commit 6869ef6520
3 changed files with 30 additions and 30 deletions

View File

@@ -102,14 +102,14 @@ class And(Statement):
super().__init__(description=description)
self.children = children
def evaluate(self, ctx, short_circuit=True):
def evaluate(self, features: FeatureSet, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.and"] += 1
if short_circuit:
results = []
for child in self.children:
result = child.evaluate(ctx, short_circuit=short_circuit)
result = child.evaluate(features, short_circuit=short_circuit)
results.append(result)
if not result:
# short circuit
@@ -117,7 +117,7 @@ class And(Statement):
return Result(True, self, results)
else:
results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children]
results = [child.evaluate(features, short_circuit=short_circuit) for child in self.children]
success = all(results)
return Result(success, self, results)
@@ -135,14 +135,14 @@ class Or(Statement):
super().__init__(description=description)
self.children = children
def evaluate(self, ctx, short_circuit=True):
def evaluate(self, features: FeatureSet, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.or"] += 1
if short_circuit:
results = []
for child in self.children:
result = child.evaluate(ctx, short_circuit=short_circuit)
result = child.evaluate(features, short_circuit=short_circuit)
results.append(result)
if result:
# short circuit as soon as we hit one match
@@ -150,7 +150,7 @@ class Or(Statement):
return Result(False, self, results)
else:
results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children]
results = [child.evaluate(features, short_circuit=short_circuit) for child in self.children]
success = any(results)
return Result(success, self, results)
@@ -162,11 +162,11 @@ class Not(Statement):
super().__init__(description=description)
self.child = child
def evaluate(self, ctx, short_circuit=True):
def evaluate(self, features: FeatureSet, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.not"] += 1
results = [self.child.evaluate(ctx, short_circuit=short_circuit)]
results = [self.child.evaluate(features, short_circuit=short_circuit)]
success = not results[0]
return Result(success, self, results)
@@ -185,7 +185,7 @@ class Some(Statement):
self.count = count
self.children = children
def evaluate(self, ctx, short_circuit=True):
def evaluate(self, features: FeatureSet, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.some"] += 1
@@ -193,7 +193,7 @@ class Some(Statement):
results = []
satisfied_children_count = 0
for child in self.children:
result = child.evaluate(ctx, short_circuit=short_circuit)
result = child.evaluate(features, short_circuit=short_circuit)
results.append(result)
if result:
satisfied_children_count += 1
@@ -204,7 +204,7 @@ class Some(Statement):
return Result(False, self, results)
else:
results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children]
results = [child.evaluate(features, short_circuit=short_circuit) for child in self.children]
# note that here we cast the child result as a bool
# because we've overridden `__bool__` above.
#
@@ -214,7 +214,7 @@ class Some(Statement):
class Range(Statement):
"""match if the child is contained in the ctx set with a count in the given range."""
"""match if the child is contained in the feature set with a count in the given range."""
def __init__(self, child, min=None, max=None, description=None):
super().__init__(description=description)
@@ -222,15 +222,15 @@ class Range(Statement):
self.min = min if min is not None else 0
self.max = max if max is not None else (1 << 64 - 1)
def evaluate(self, ctx, **kwargs):
def evaluate(self, features: FeatureSet, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.range"] += 1
count = len(ctx.get(self.child, []))
count = len(features.get(self.child, []))
if self.min == 0 and count == 0:
return Result(True, self, [])
return Result(self.min <= count <= self.max, self, [], locations=ctx.get(self.child))
return Result(self.min <= count <= self.max, self, [], locations=features.get(self.child))
def __str__(self):
if self.max == (1 << 64 - 1):
@@ -250,7 +250,7 @@ class Subscope(Statement):
self.scope = scope
self.child = child
def evaluate(self, ctx, **kwargs):
def evaluate(self, features: FeatureSet, short_circuit=True):
raise ValueError("cannot evaluate a subscope directly!")

View File

@@ -166,10 +166,10 @@ class Feature(abc.ABC): # noqa: B024
def __repr__(self):
return str(self)
def evaluate(self, ctx: Dict["Feature", Set[Address]], **kwargs) -> Result:
def evaluate(self, features: "capa.engine.FeatureSet", short_circuit=True) -> Result:
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature." + self.name] += 1
return Result(self in ctx, self, [], locations=ctx.get(self, set()))
return Result(self in features, self, [], locations=features.get(self, set()))
class MatchedRule(Feature):
@@ -207,7 +207,7 @@ class Substring(String):
super().__init__(value, description=description)
self.value = value
def evaluate(self, ctx, short_circuit=True):
def evaluate(self, features: "capa.engine.FeatureSet", short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.substring"] += 1
@@ -216,7 +216,7 @@ class Substring(String):
matches: typing.DefaultDict[str, Set[Address]] = collections.defaultdict(set)
assert isinstance(self.value, str)
for feature, locations in ctx.items():
for feature, locations in features.items():
if not isinstance(feature, (String,)):
continue
@@ -299,7 +299,7 @@ class Regex(String):
f"invalid regular expression: {value} it should use Python syntax, try it at https://pythex.org"
) from exc
def evaluate(self, ctx, short_circuit=True):
def evaluate(self, features: "capa.engine.FeatureSet", short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.regex"] += 1
@@ -307,7 +307,7 @@ class Regex(String):
# will unique the locations later on.
matches: typing.DefaultDict[str, Set[Address]] = collections.defaultdict(set)
for feature, locations in ctx.items():
for feature, locations in features.items():
if not isinstance(feature, (String,)):
continue
@@ -384,12 +384,12 @@ class Bytes(Feature):
super().__init__(value, description=description)
self.value = value
def evaluate(self, ctx, **kwargs):
def evaluate(self, features: "capa.engine.FeatureSet", short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.bytes"] += 1
assert isinstance(self.value, bytes)
for feature, locations in ctx.items():
for feature, locations in features.items():
if not isinstance(feature, (Bytes,)):
continue
@@ -434,11 +434,11 @@ class OS(Feature):
super().__init__(value, description=description)
self.name = "os"
def evaluate(self, ctx, **kwargs):
def evaluate(self, features: "capa.engine.FeatureSet", short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature." + self.name] += 1
for feature, locations in ctx.items():
for feature, locations in features.items():
if not isinstance(feature, (OS,)):
continue

View File

@@ -93,10 +93,10 @@ def test_complex():
def test_range():
# unbounded range, but no matching feature
# since the lower bound is zero, and there are zero matches, ok
assert bool(Range(Number(1)).evaluate({Number(2): {}})) is True
assert bool(Range(Number(1)).evaluate({Number(2): {}})) is True # type: ignore
# unbounded range with matching feature should always match
assert bool(Range(Number(1)).evaluate({Number(1): {}})) is True
assert bool(Range(Number(1)).evaluate({Number(1): {}})) is True # type: ignore
assert bool(Range(Number(1)).evaluate({Number(1): {ADDR1}})) is True
# unbounded max
@@ -112,12 +112,12 @@ def test_range():
assert bool(Range(Number(1), max=2).evaluate({Number(1): {ADDR1, ADDR2, ADDR3}})) is False
# we can do an exact match by setting min==max
assert bool(Range(Number(1), min=1, max=1).evaluate({Number(1): {}})) is False
assert bool(Range(Number(1), min=1, max=1).evaluate({Number(1): {}})) is False # type: ignore
assert bool(Range(Number(1), min=1, max=1).evaluate({Number(1): {ADDR1}})) is True
assert bool(Range(Number(1), min=1, max=1).evaluate({Number(1): {ADDR1, ADDR2}})) is False
# bounded range
assert bool(Range(Number(1), min=1, max=3).evaluate({Number(1): {}})) is False
assert bool(Range(Number(1), min=1, max=3).evaluate({Number(1): {}})) is False # type: ignore
assert bool(Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1}})) is True
assert bool(Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1, ADDR2}})) is True
assert bool(Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1, ADDR2, ADDR3}})) is True