lots of mypy

This commit is contained in:
Willi Ballenthin
2022-12-14 10:37:39 +01:00
parent b1d6fcd6c8
commit b819033da0
29 changed files with 410 additions and 233 deletions

View File

@@ -8,7 +8,7 @@
import copy
import collections
from typing import TYPE_CHECKING, Set, Dict, List, Tuple, Mapping, Iterable
from typing import TYPE_CHECKING, Set, Dict, List, Tuple, Mapping, Iterable, Iterator, Union, cast
import capa.perf
import capa.features.common
@@ -60,17 +60,24 @@ class Statement:
"""
raise NotImplementedError()
def get_children(self):
def get_children(self) -> Iterator[Union["Statement", Feature]]:
if hasattr(self, "child"):
yield self.child
# this really confuses mypy because the property may not exist
# since its defined in the subclasses.
child = self.child # type: ignore
assert isinstance(child, (Statement, Feature))
yield child
if hasattr(self, "children"):
for child in getattr(self, "children"):
assert isinstance(child, (Statement, Feature))
yield child
def replace_child(self, existing, new):
if hasattr(self, "child"):
if self.child is existing:
# this really confuses mypy because the property may not exist
# since its defined in the subclasses.
if self.child is existing: # type: ignore
self.child = new
if hasattr(self, "children"):

View File

@@ -200,8 +200,9 @@ class Substring(String):
# mapping from string value to list of locations.
# will unique the locations later on.
matches = collections.defaultdict(list)
matches: collections.defaultdict[str, Set[Address]] = collections.defaultdict(set)
assert isinstance(self.value, str)
for feature, locations in ctx.items():
if not isinstance(feature, (String,)):
continue
@@ -211,32 +212,29 @@ class Substring(String):
raise ValueError("unexpected feature value type")
if self.value in feature.value:
matches[feature.value].extend(locations)
matches[feature.value].update(locations)
if short_circuit:
# we found one matching string, thats sufficient to match.
# don't collect other matching strings in this mode.
break
if matches:
# finalize: defaultdict -> dict
# which makes json serialization easier
matches = dict(matches)
# collect all locations
locations = set()
for s in matches.keys():
matches[s] = list(set(matches[s]))
locations.update(matches[s])
for locs in matches.values():
locations.update(locs)
# unlike other features, we cannot return put a reference to `self` directly in a `Result`.
# this is because `self` may match on many strings, so we can't stuff the matched value into it.
# instead, return a new instance that has a reference to both the substring and the matched values.
return Result(True, _MatchedSubstring(self, matches), [], locations=locations)
return Result(True, _MatchedSubstring(self, dict(matches)), [], locations=locations)
else:
return Result(False, _MatchedSubstring(self, {}), [])
def __str__(self):
return "substring(%s)" % self.value
v = self.value
assert isinstance(v, str)
return "substring(%s)" % v
class _MatchedSubstring(Substring):
@@ -261,6 +259,7 @@ class _MatchedSubstring(Substring):
self.matches = matches
def __str__(self):
assert isinstance(self.value, str)
return 'substring("%s", matches = %s)' % (
self.value,
", ".join(map(lambda s: '"' + s + '"', (self.matches or {}).keys())),
@@ -292,7 +291,7 @@ class Regex(String):
# mapping from string value to list of locations.
# will unique the locations later on.
matches = collections.defaultdict(list)
matches: collections.defaultdict[str, Set[Address]] = collections.defaultdict(set)
for feature, locations in ctx.items():
if not isinstance(feature, (String,)):
@@ -307,32 +306,28 @@ class Regex(String):
# using this mode cleans is more convenient for rule authors,
# so that they don't have to prefix/suffix their terms like: /.*foo.*/.
if self.re.search(feature.value):
matches[feature.value].extend(locations)
matches[feature.value].update(locations)
if short_circuit:
# we found one matching string, thats sufficient to match.
# don't collect other matching strings in this mode.
break
if matches:
# finalize: defaultdict -> dict
# which makes json serialization easier
matches = dict(matches)
# collect all locations
locations = set()
for s in matches.keys():
matches[s] = list(set(matches[s]))
locations.update(matches[s])
for locs in matches.values():
locations.update(locs)
# unlike other features, we cannot return put a reference to `self` directly in a `Result`.
# this is because `self` may match on many strings, so we can't stuff the matched value into it.
# instead, return a new instance that has a reference to both the regex and the matched values.
# see #262.
return Result(True, _MatchedRegex(self, matches), [], locations=locations)
return Result(True, _MatchedRegex(self, dict(matches)), [], locations=locations)
else:
return Result(False, _MatchedRegex(self, {}), [])
def __str__(self):
assert isinstance(self.value, str)
return "regex(string =~ %s)" % self.value
@@ -358,6 +353,7 @@ class _MatchedRegex(Regex):
self.matches = matches
def __str__(self):
assert isinstance(self.value, str)
return "regex(string =~ %s, matches = %s)" % (
self.value,
", ".join(map(lambda s: '"' + s + '"', (self.matches or {}).keys())),
@@ -380,16 +376,19 @@ class Bytes(Feature):
capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.bytes"] += 1
assert isinstance(self.value, bytes)
for feature, locations in ctx.items():
if not isinstance(feature, (Bytes,)):
continue
assert isinstance(feature.value, bytes)
if feature.value.startswith(self.value):
return Result(True, self, [], locations=locations)
return Result(False, self, [])
def get_value_str(self):
assert isinstance(self.value, bytes)
return hex_string(bytes_to_str(self.value))

View File

@@ -107,8 +107,18 @@ class DnUnmanagedMethod:
return f"{module}.{method}"
def validate_has_dotnet(pe: dnfile.dnPE):
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.Flags is not None
def resolve_dotnet_token(pe: dnfile.dnPE, token: Token) -> Any:
"""map generic token to string or table row"""
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.mdtables is not None
if isinstance(token, StringToken):
user_string: Optional[str] = read_dotnet_user_string(pe, token)
if user_string is None:
@@ -143,6 +153,10 @@ def read_dotnet_method_body(pe: dnfile.dnPE, row: dnfile.mdtable.MethodDefRow) -
def read_dotnet_user_string(pe: dnfile.dnPE, token: StringToken) -> Optional[str]:
"""read user string from #US stream"""
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.user_strings is not None
try:
user_string: Optional[dnfile.stream.UserString] = pe.net.user_strings.get_us(token.rid)
except UnicodeDecodeError as e:
@@ -169,6 +183,11 @@ def get_dotnet_managed_imports(pe: dnfile.dnPE) -> Iterator[DnType]:
TypeName (index into String heap)
TypeNamespace (index into String heap)
"""
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.mdtables.MemberRef is not None
for (rid, row) in enumerate(iter_dotnet_table(pe, "MemberRef")):
if not isinstance(row.Class.row, dnfile.mdtable.TypeRefRow):
continue
@@ -258,6 +277,11 @@ def get_dotnet_properties(pe: dnfile.dnPE) -> Iterator[DnType]:
def get_dotnet_managed_method_bodies(pe: dnfile.dnPE) -> Iterator[Tuple[int, CilMethodBody]]:
"""get managed methods from MethodDef table"""
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.mdtables.MethodDef is not None
if not hasattr(pe.net.mdtables, "MethodDef"):
return
@@ -307,15 +331,28 @@ def calculate_dotnet_token_value(table: int, rid: int) -> int:
def is_dotnet_table_valid(pe: dnfile.dnPE, table_name: str) -> bool:
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.mdtables is not None
return bool(getattr(pe.net.mdtables, table_name, None))
def is_dotnet_mixed_mode(pe: dnfile.dnPE) -> bool:
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.Flags is not None
return not bool(pe.net.Flags.CLR_ILONLY)
def iter_dotnet_table(pe: dnfile.dnPE, name: str) -> Iterator[Any]:
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.mdtables is not None
if not is_dotnet_table_valid(pe, name):
return
for row in getattr(pe.net.mdtables, name):
yield row

View File

@@ -19,9 +19,19 @@ def extract_file_os(**kwargs) -> Iterator[Tuple[Feature, Address]]:
yield OS(OS_ANY), NO_ADDRESS
def extract_file_arch(pe, **kwargs) -> Iterator[Tuple[Feature, Address]]:
def validate_has_dotnet(pe: dnfile.dnPE):
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.Flags is not None
def extract_file_arch(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Feature, Address]]:
# to distinguish in more detail, see https://stackoverflow.com/a/23614024/10548020
# .NET 4.5 added option: any CPU, 32-bit preferred
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.Flags is not None
if pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE:
yield Arch(ARCH_I386), NO_ADDRESS
elif not pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE_PLUS:
@@ -71,6 +81,10 @@ class DnfileFeatureExtractor(FeatureExtractor):
# self.pe.net.Flags.CLT_NATIVE_ENTRYPOINT
# True: native EP: Token
# False: managed EP: RVA
validate_has_dotnet(self.pe)
assert self.pe.net is not None
assert self.pe.net.struct is not None
return self.pe.net.struct.EntryPointTokenOrRva
def extract_global_features(self):
@@ -83,13 +97,32 @@ class DnfileFeatureExtractor(FeatureExtractor):
return bool(self.pe.net)
def is_mixed_mode(self) -> bool:
validate_has_dotnet(self.pe)
assert self.pe is not None
assert self.pe.net is not None
assert self.pe.net.Flags is not None
return not bool(self.pe.net.Flags.CLR_ILONLY)
def get_runtime_version(self) -> Tuple[int, int]:
validate_has_dotnet(self.pe)
assert self.pe is not None
assert self.pe.net is not None
assert self.pe.net.struct is not None
return self.pe.net.struct.MajorRuntimeVersion, self.pe.net.struct.MinorRuntimeVersion
def get_meta_version_string(self) -> str:
return self.pe.net.metadata.struct.Version.rstrip(b"\x00").decode("utf-8")
validate_has_dotnet(self.pe)
assert self.pe.net is not None
assert self.pe.net.metadata is not None
assert self.pe.net.metadata.struct is not None
assert self.pe.net.metadata.struct.Version is not None
vbuf = self.pe.net.metadata.struct.Version
assert isinstance(vbuf, bytes)
return vbuf.rstrip(b"\x00").decode("utf-8")
def get_functions(self):
raise NotImplementedError("DnfileFeatureExtractor can only be used to extract file features")

View File

@@ -40,6 +40,12 @@ def extract_file_format(**kwargs) -> Iterator[Tuple[Format, Address]]:
yield Format(FORMAT_DOTNET), NO_ADDRESS
def validate_has_dotnet(pe: dnfile.dnPE):
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.Flags is not None
def extract_file_import_names(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Import, Address]]:
for method in get_dotnet_managed_imports(pe):
# like System.IO.File::OpenRead
@@ -78,6 +84,12 @@ def extract_file_namespace_features(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple
def extract_file_class_features(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Class, Address]]:
"""emit class features from TypeRef and TypeDef tables"""
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.mdtables.TypeDef is not None
assert pe.net.mdtables.TypeRef is not None
for (rid, row) in enumerate(iter_dotnet_table(pe, "TypeDef")):
token = calculate_dotnet_token_value(pe.net.mdtables.TypeDef.number, rid + 1)
yield Class(DnType.format_name(row.TypeName, namespace=row.TypeNamespace)), DNTokenAddress(token)
@@ -94,6 +106,10 @@ def extract_file_os(**kwargs) -> Iterator[Tuple[OS, Address]]:
def extract_file_arch(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Arch, Address]]:
# to distinguish in more detail, see https://stackoverflow.com/a/23614024/10548020
# .NET 4.5 added option: any CPU, 32-bit preferred
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.Flags is not None
if pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE:
yield Arch(ARCH_I386), NO_ADDRESS
elif not pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE_PLUS:
@@ -155,6 +171,10 @@ class DotnetFileFeatureExtractor(FeatureExtractor):
# self.pe.net.Flags.CLT_NATIVE_ENTRYPOINT
# True: native EP: Token
# False: managed EP: RVA
validate_has_dotnet(self.pe)
assert self.pe.net is not None
assert self.pe.net.struct is not None
return self.pe.net.struct.EntryPointTokenOrRva
def extract_global_features(self):
@@ -170,10 +190,25 @@ class DotnetFileFeatureExtractor(FeatureExtractor):
return is_dotnet_mixed_mode(self.pe)
def get_runtime_version(self) -> Tuple[int, int]:
validate_has_dotnet(self.pe)
assert self.pe.net is not None
assert self.pe.net.struct is not None
assert self.pe.net.struct.MajorRuntimeVersion is not None
assert self.pe.net.struct.MinorRuntimeVersion is not None
return self.pe.net.struct.MajorRuntimeVersion, self.pe.net.struct.MinorRuntimeVersion
def get_meta_version_string(self) -> str:
return self.pe.net.metadata.struct.Version.rstrip(b"\x00").decode("utf-8")
validate_has_dotnet(self.pe)
assert self.pe.net is not None
assert self.pe.net.metadata is not None
assert self.pe.net.metadata.struct is not None
assert self.pe.net.metadata.struct.Version is not None
vbuf = self.pe.net.metadata.struct.Version
assert isinstance(vbuf, bytes)
return vbuf.rstrip(b"\x00").decode("utf-8")
def get_functions(self):
raise NotImplementedError("DotnetFileFeatureExtractor can only be used to extract file features")

View File

@@ -52,26 +52,26 @@ class NullFeatureExtractor(FeatureExtractor):
yield FunctionHandle(address, None)
def extract_function_features(self, f):
for address, feature in self.functions.get(f.address, {}).features:
for address, feature in self.functions[f.address].features:
yield feature, address
def get_basic_blocks(self, f):
for address in sorted(self.functions.get(f.address, {}).basic_blocks.keys()):
for address in sorted(self.functions[f.address].basic_blocks.keys()):
yield BBHandle(address, None)
def extract_basic_block_features(self, f, bb):
for address, feature in self.functions.get(f.address, {}).basic_blocks.get(bb.address, {}).features:
for address, feature in self.functions[f.address].basic_blocks[bb.address].features:
yield feature, address
def get_instructions(self, f, bb):
for address in sorted(self.functions.get(f.address, {}).basic_blocks.get(bb.address, {}).instructions.keys()):
for address in sorted(self.functions[f.address].basic_blocks[bb.address].instructions.keys()):
yield InsnHandle(address, None)
def extract_insn_features(self, f, bb, insn):
for address, feature in (
self.functions.get(f.address, {})
.basic_blocks.get(bb.address, {})
.instructions.get(insn.address, {})
self.functions[f.address]
.basic_blocks[bb.address]
.instructions[insn.address]
.features
):
yield feature, address

View File

@@ -133,7 +133,8 @@ def extract_file_features(pe, buf):
"""
for file_handler in FILE_HANDLERS:
for feature, va in file_handler(pe=pe, buf=buf):
# file_handler: type: (pe, bytes) -> Iterable[Tuple[Feature, Address]]
for feature, va in file_handler(pe=pe, buf=buf): # type: ignore
yield feature, va
@@ -160,7 +161,8 @@ def extract_global_features(pe, buf):
Tuple[Feature, VA]: a feature and its location.
"""
for handler in GLOBAL_HANDLERS:
for feature, va in handler(pe=pe, buf=buf):
# file_handler: type: (pe, bytes) -> Iterable[Tuple[Feature, Address]]
for feature, va in handler(pe=pe, buf=buf): # type: ignore
yield feature, va

View File

@@ -88,7 +88,8 @@ def extract_features(smda_report, buf):
"""
for file_handler in FILE_HANDLERS:
for feature, addr in file_handler(smda_report=smda_report, buf=buf):
# file_handler: type: (smda_report, bytes) -> Iterable[Tuple[Feature, Address]]
for feature, addr in file_handler(smda_report=smda_report, buf=buf): # type: ignore
yield feature, addr

View File

@@ -11,7 +11,7 @@ import copy
import logging
import itertools
import collections
from typing import Set, Dict, Optional
from typing import Set, Dict, Optional, List, Any
import idaapi
import ida_kernwin
@@ -72,14 +72,14 @@ def trim_function_name(f, max_length=25):
def find_func_features(fh: FunctionHandle, extractor):
""" """
func_features: Dict[Feature, Set] = collections.defaultdict(set)
bb_features: Dict[Address, Dict] = collections.defaultdict(dict)
func_features: Dict[Feature, Set[Address]] = collections.defaultdict(set)
bb_features: Dict[Address, Dict[Feature, Set[Address]]] = collections.defaultdict(dict)
for (feature, addr) in extractor.extract_function_features(fh):
func_features[feature].add(addr)
for bbh in extractor.get_basic_blocks(fh):
_bb_features = collections.defaultdict(set)
_bb_features: Dict[Feature, Set[Address]] = collections.defaultdict(set)
for (feature, addr) in extractor.extract_basic_block_features(fh, bbh):
_bb_features[feature].add(addr)
@@ -239,53 +239,52 @@ class CapaSettingsInputDialog(QtWidgets.QDialog):
class CapaExplorerForm(idaapi.PluginForm):
"""form element for plugin interface"""
def __init__(self, name, option=Options.DEFAULT):
def __init__(self, name: str, option=Options.DEFAULT):
"""initialize form elements"""
super().__init__()
self.form_title = name
self.process_total = 0
self.process_count = 0
self.form_title: str = name
self.process_total: int = 0
self.process_count: int = 0
self.parent = None
self.ida_hooks = None
self.parent: Any # QtWidget
self.ida_hooks: CapaExplorerIdaHooks
self.doc: Optional[capa.render.result_document.ResultDocument] = None
self.rule_paths = None
self.rules_cache = None
self.ruleset_cache = None
self.rule_paths: Optional[List[str]]
self.rules_cache: Optional[List[capa.rules.Rule]]
self.ruleset_cache: Optional[capa.rules.RuleSet]
# models
self.model_data = None
self.range_model_proxy = None
self.search_model_proxy = None
self.model_data: CapaExplorerDataModel
self.range_model_proxy: CapaExplorerRangeProxyModel
self.search_model_proxy: CapaExplorerSearchProxyModel
# UI controls
self.view_limit_results_by_function = None
self.view_show_results_by_function = None
self.view_search_bar = None
self.view_tree = None
self.view_rulegen = None
self.view_tabs = None
self.view_limit_results_by_function: QtWidgets.QCheckBox
self.view_show_results_by_function: QtWidgets.QCheckBox
self.view_search_bar: QtWidgets.QLineEdit
self.view_tree: CapaExplorerQtreeView
self.view_tabs: QtWidgets.QTabWidget
self.view_tab_rulegen = None
self.view_status_label = None
self.view_buttons = None
self.view_analyze_button = None
self.view_reset_button = None
self.view_settings_button = None
self.view_save_button = None
self.view_status_label: QtWidgets.QLabel
self.view_buttons: QtWidgets.QHBoxLayout
self.view_analyze_button: QtWidgets.QPushButton
self.view_reset_button: QtWidgets.QPushButton
self.view_settings_button: QtWidgets.QPushButton
self.view_save_button: QtWidgets.QPushButton
self.view_rulegen_preview = None
self.view_rulegen_features = None
self.view_rulegen_editor = None
self.view_rulegen_header_label = None
self.view_rulegen_search = None
self.view_rulegen_limit_features_by_ea = None
self.rulegen_current_function = None
self.rulegen_bb_features_cache = {}
self.rulegen_func_features_cache = {}
self.rulegen_file_features_cache = {}
self.view_rulegen_status_label = None
self.view_rulegen_preview: CapaExplorerRulegenPreview
self.view_rulegen_features: CapaExplorerRulegenFeatures
self.view_rulegen_editor: CapaExplorerRulegenEditor
self.view_rulegen_header_label: QtWidgets.QLabel
self.view_rulegen_search: QtWidgets.QLineEdit
self.view_rulegen_limit_features_by_ea: QtWidgets.QCheckBox
self.rulegen_current_function: Optional[FunctionHandle]
self.rulegen_bb_features_cache: Dict[Address, Dict[Feature, Set[Address]]] = {}
self.rulegen_func_features_cache: Dict[Feature, Set[Address]] = {}
self.rulegen_file_features_cache: Dict[Feature, Set[Address]] = {}
self.view_rulegen_status_label: QtWidgets.QLabel
self.Show()
@@ -762,6 +761,9 @@ class CapaExplorerForm(idaapi.PluginForm):
if not self.load_capa_rules():
return False
assert self.rules_cache is not None
assert self.ruleset_cache is not None
if ida_kernwin.user_cancelled():
logger.info("User cancelled analysis.")
return False
@@ -822,6 +824,13 @@ class CapaExplorerForm(idaapi.PluginForm):
return False
try:
# either the results are cached and the doc already exists,
# or the doc was just created above
assert self.doc is not None
# same with rules cache, either it's cached or it was just loaded
assert self.rules_cache is not None
assert self.ruleset_cache is not None
self.model_data.render_capa_doc(self.doc, self.view_show_results_by_function.isChecked())
self.set_view_status_label(
"capa rules directory: %s (%d rules)" % (settings.user[CAPA_SETTINGS_RULE_PATH], len(self.rules_cache))
@@ -871,6 +880,9 @@ class CapaExplorerForm(idaapi.PluginForm):
else:
logger.info('Using cached ruleset, click "Reset" to reload rules from disk.')
assert self.rules_cache is not None
assert self.ruleset_cache is not None
if ida_kernwin.user_cancelled():
logger.info("User cancelled analysis.")
return False
@@ -891,7 +903,8 @@ class CapaExplorerForm(idaapi.PluginForm):
try:
f = idaapi.get_func(idaapi.get_screen_ea())
if f:
fh: FunctionHandle = extractor.get_function(f.start_ea)
fh: Optional[FunctionHandle] = extractor.get_function(f.start_ea)
assert fh is not None
self.rulegen_current_function = fh
func_features, bb_features = find_func_features(fh, extractor)
@@ -1053,6 +1066,8 @@ class CapaExplorerForm(idaapi.PluginForm):
def update_rule_status(self, rule_text):
""" """
assert self.rules_cache is not None
if not self.view_rulegen_editor.invisibleRootItem().childCount():
self.set_rulegen_preview_border_neutral()
self.view_rulegen_status_label.clear()
@@ -1077,7 +1092,7 @@ class CapaExplorerForm(idaapi.PluginForm):
rules.append(rule)
try:
file_features = copy.copy(self.rulegen_file_features_cache)
file_features = copy.copy(dict(self.rulegen_file_features_cache))
if self.rulegen_current_function:
func_matches, bb_matches = find_func_matches(
self.rulegen_current_function,
@@ -1093,7 +1108,7 @@ class CapaExplorerForm(idaapi.PluginForm):
_, file_matches = capa.engine.match(
capa.rules.RuleSet(list(capa.rules.get_rules_and_dependencies(rules, rule.name))).file_rules,
file_features,
0x0,
NO_ADDRESS
)
except Exception as e:
self.set_rulegen_status("Failed to match rule (%s)" % e)

View File

@@ -36,7 +36,7 @@ def ea_to_hex(ea):
class CapaExplorerDataItem:
"""store data for CapaExplorerDataModel"""
def __init__(self, parent: "CapaExplorerDataItem", data: List[str], can_check=True):
def __init__(self, parent: Optional["CapaExplorerDataItem"], data: List[str], can_check=True):
"""initialize item"""
self.pred = parent
self._data = data
@@ -110,7 +110,7 @@ class CapaExplorerDataItem:
except IndexError:
return None
def parent(self) -> "CapaExplorerDataItem":
def parent(self) -> Optional["CapaExplorerDataItem"]:
"""get parent"""
return self.pred

View File

@@ -92,7 +92,7 @@ class CapaExplorerRangeProxyModel(QtCore.QSortFilterProxyModel):
@param parent: QModelIndex of parent
"""
# filter not set
if self.min_ea is None and self.max_ea is None:
if self.min_ea is None or self.max_ea is None:
return True
index = self.sourceModel().index(row, 0, parent)

View File

@@ -18,7 +18,7 @@ import capa.ida.helpers
import capa.features.common
import capa.features.basicblock
from capa.ida.plugin.item import CapaExplorerFunctionItem
from capa.features.address import Address, _NoAddress
from capa.features.address import _NoAddress, AbsoluteVirtualAddress
from capa.ida.plugin.model import CapaExplorerDataModel
MAX_SECTION_SIZE = 750
@@ -1013,8 +1013,10 @@ class CapaExplorerRulegenFeatures(QtWidgets.QTreeWidget):
self.parent_items = {}
def format_address(e):
assert isinstance(e, Address)
return "%X" % e if not isinstance(e, _NoAddress) else ""
if isinstance(e, AbsoluteVirtualAddress):
return "%X" % int(e)
else:
return ""
def format_feature(feature):
""" """

View File

@@ -66,7 +66,7 @@ from capa.features.common import (
FORMAT_DOTNET,
FORMAT_FREEZE,
)
from capa.features.address import NO_ADDRESS
from capa.features.address import NO_ADDRESS, Address
from capa.features.extractors.base_extractor import BBHandle, InsnHandle, FunctionHandle, FeatureExtractor
RULES_PATH_DEFAULT_STRING = "(embedded rules)"
@@ -718,8 +718,8 @@ def compute_layout(rules, extractor, capabilities):
otherwise, we may pollute the json document with
a large amount of un-referenced data.
"""
functions_by_bb = {}
bbs_by_function = {}
functions_by_bb: Dict[Address, Address] = {}
bbs_by_function: Dict[Address, List[Address]] = {}
for f in extractor.get_functions():
bbs_by_function[f.address] = []
for bb in extractor.get_basic_blocks(f):
@@ -1016,8 +1016,7 @@ def main(argv=None):
return E_INVALID_FILE_TYPE
try:
rules = get_rules(args.rules, disable_progress=args.quiet)
rules = capa.rules.RuleSet(rules)
rules = capa.rules.RuleSet(get_rules(args.rules, disable_progress=args.quiet))
logger.debug(
"successfully loaded %s rules",
@@ -1167,8 +1166,7 @@ def ida_main():
rules_path = os.path.join(get_default_root(), "rules")
logger.debug("rule path: %s", rules_path)
rules = get_rules([rules_path])
rules = capa.rules.RuleSet(rules)
rules = capa.rules.RuleSet(get_rules([rules_path]))
meta = capa.ida.helpers.collect_metadata([rules_path])

View File

@@ -2,7 +2,7 @@ import collections
from typing import Dict
# this structure is unstable and may change before the next major release.
counters: Dict[str, int] = collections.Counter()
counters: collections.Counter[str] = collections.Counter()
def reset():

View File

@@ -634,7 +634,7 @@ class Rule:
Returns:
List[str]: names of rules upon which this rule depends.
"""
deps = set([])
deps: Set[str] = set([])
def rec(statement):
if isinstance(statement, capa.features.common.MatchedRule):
@@ -651,6 +651,7 @@ class Rule:
deps.update(map(lambda r: r.name, namespaces[statement.value]))
else:
# not a namespace, assume its a rule name.
assert isinstance(statement.value, str)
deps.add(statement.value)
elif isinstance(statement, ceng.Statement):
@@ -666,7 +667,11 @@ class Rule:
def _extract_subscope_rules_rec(self, statement):
if isinstance(statement, ceng.Statement):
# for each child that is a subscope,
for subscope in filter(lambda statement: isinstance(statement, ceng.Subscope), statement.get_children()):
for child in statement.get_children():
if not isinstance(child, ceng.Subscope):
continue
subscope = child
# create a new rule from it.
# the name is a randomly generated, hopefully unique value.
@@ -737,7 +742,7 @@ class Rule:
return self.statement.evaluate(features, short_circuit=short_circuit)
@classmethod
def from_dict(cls, d, definition):
def from_dict(cls, d, definition) -> "Rule":
meta = d["rule"]["meta"]
name = meta["name"]
# if scope is not specified, default to function scope.
@@ -771,14 +776,12 @@ class Rule:
# prefer to use CLoader to be fast, see #306
# on Linux, make sure you install libyaml-dev or similar
# on Windows, get WHLs from pyyaml.org/pypi
loader = yaml.CLoader
logger.debug("using libyaml CLoader.")
return yaml.CLoader
except:
loader = yaml.Loader
logger.debug("unable to import libyaml CLoader, falling back to Python yaml parser.")
logger.debug("this will be slower to load rules.")
return loader
return yaml.Loader
@staticmethod
def _get_ruamel_yaml_parser():
@@ -790,8 +793,9 @@ class Rule:
# use block mode, not inline json-like mode
y.default_flow_style = False
# leave quotes unchanged
y.preserve_quotes = True
# leave quotes unchanged.
# manually verified this property exists, even if mypy complains.
y.preserve_quotes = True # type: ignore
# indent lists by two spaces below their parent
#
@@ -802,12 +806,13 @@ class Rule:
y.indent(sequence=2, offset=2)
# avoid word wrapping
y.width = 4096
# manually verified this property exists, even if mypy complains.
y.width = 4096 # type: ignore
return y
@classmethod
def from_yaml(cls, s, use_ruamel=False):
def from_yaml(cls, s, use_ruamel=False) -> "Rule":
if use_ruamel:
# ruamel enables nice formatting and doc roundtripping with comments
doc = cls._get_ruamel_yaml_parser().load(s)
@@ -817,7 +822,7 @@ class Rule:
return cls.from_dict(doc, s)
@classmethod
def from_yaml_file(cls, path, use_ruamel=False):
def from_yaml_file(cls, path, use_ruamel=False) -> "Rule":
with open(path, "rb") as f:
try:
rule = cls.from_yaml(f.read().decode("utf-8"), use_ruamel=use_ruamel)
@@ -832,7 +837,7 @@ class Rule:
except pydantic.ValidationError as e:
raise InvalidRuleWithPath(path, str(e)) from e
def to_yaml(self):
def to_yaml(self) -> str:
# reformat the yaml document with a common style.
# this includes:
# - ordering the meta elements
@@ -1261,7 +1266,7 @@ class RuleSet:
return (easy_rules_by_feature, hard_rules)
@staticmethod
def _get_rules_for_scope(rules, scope):
def _get_rules_for_scope(rules, scope) -> List[Rule]:
"""
given a collection of rules, collect the rules that are needed at the given scope.
these rules are ordered topologically.
@@ -1269,7 +1274,7 @@ class RuleSet:
don't include auto-generated "subscope" rules.
we want to include general "lib" rules here - even if they are not dependencies of other rules, see #398
"""
scope_rules = set([])
scope_rules: Set[Rule] = set([])
# we need to process all rules, not just rules with the given scope.
# this is because rules with a higher scope, e.g. file scope, may have subscope rules
@@ -1283,7 +1288,7 @@ class RuleSet:
return get_rules_with_scope(topologically_order_rules(list(scope_rules)), scope)
@staticmethod
def _extract_subscope_rules(rules):
def _extract_subscope_rules(rules) -> List[Rule]:
"""
process the given sequence of rules.
for each one, extract any embedded subscope rules into their own rule.

2
rules

Submodule rules updated: 2bc58afb51...5ba70c97d2

View File

@@ -152,8 +152,7 @@ def main(argv=None):
capa.main.handle_common_args(args)
try:
rules = capa.main.get_rules(args.rules)
rules = capa.rules.RuleSet(rules)
rules = capa.rules.RuleSet(capa.main.get_rules(args.rules))
logger.info("successfully loaded %s rules", len(rules))
except (IOError, capa.rules.InvalidRule, capa.rules.InvalidRuleSet) as e:
logger.error("%s", str(e))

View File

@@ -64,7 +64,6 @@ unsupported = ["characteristic", "mnemonic", "offset", "subscope", "Range"]
# collect all converted rules to be able to check if we have needed sub rules for match:
converted_rules = []
count_incomplete = 0
default_tags = "CAPA "
@@ -537,7 +536,8 @@ def output_unsupported_capa_rules(yaml, capa_rulename, url, reason):
unsupported_capa_rules_names.write(url.encode("utf-8") + b"\n")
def convert_rules(rules, namespaces, cround):
def convert_rules(rules, namespaces, cround, make_priv):
count_incomplete = 0
for rule in rules.rules.values():
rule_name = convert_rule_name(rule.name)
@@ -652,7 +652,6 @@ def convert_rules(rules, namespaces, cround):
if meta_name and meta_value:
yara_meta += "\t" + meta_name + ' = "' + meta_value + '"\n'
rule_name_bonus = ""
if rule_comment:
yara_meta += '\tcomment = "' + rule_comment + '"\n'
yara_meta += '\tdate = "' + today + '"\n'
@@ -679,12 +678,13 @@ def convert_rules(rules, namespaces, cround):
# TODO: now the rule is finished and could be automatically checked with the capa-testfile(s) named in meta (doing it for all of them using yara-ci upload at the moment)
output_yar(yara)
converted_rules.append(rule_name)
global count_incomplete
count_incomplete += incomplete
else:
output_unsupported_capa_rules(rule.to_yaml(), rule.name, url, yara_condition)
pass
return count_incomplete
def main(argv=None):
if argv is None:
@@ -696,7 +696,6 @@ def main(argv=None):
capa.main.install_common_args(parser, wanted={"tag"})
args = parser.parse_args(args=argv)
global make_priv
make_priv = args.private
if args.verbose:
@@ -710,9 +709,9 @@ def main(argv=None):
logging.getLogger("capa2yara").setLevel(level)
try:
rules = capa.main.get_rules([args.rules], disable_progress=True)
namespaces = capa.rules.index_rules_by_namespace(list(rules))
rules = capa.rules.RuleSet(rules)
rules_ = capa.main.get_rules([args.rules], disable_progress=True)
namespaces = capa.rules.index_rules_by_namespace(rules_)
rules = capa.rules.RuleSet(rules_)
logger.info("successfully loaded %s rules (including subscope rules which will be ignored)", len(rules))
if args.tag:
rules = rules.filter_rules_by_meta(args.tag)
@@ -745,14 +744,15 @@ def main(argv=None):
# do several rounds of converting rules because some rules for match: might not be converted in the 1st run
num_rules = 9999999
cround = 0
count_incomplete = 0
while num_rules != len(converted_rules) or cround < min_rounds:
cround += 1
logger.info("doing convert_rules(), round: " + str(cround))
num_rules = len(converted_rules)
convert_rules(rules, namespaces, cround)
count_incomplete += convert_rules(rules, namespaces, cround, make_priv)
# one last round to collect all unconverted rules
convert_rules(rules, namespaces, 9000)
count_incomplete += convert_rules(rules, namespaces, 9000, make_priv)
stats = "\n// converted rules : " + str(len(converted_rules))
stats += "\n// among those are incomplete : " + str(count_incomplete)

View File

@@ -172,7 +172,7 @@ def capa_details(rules_path, file_path, output_format="dictionary"):
meta["analysis"].update(counts)
meta["analysis"]["layout"] = capa.main.compute_layout(rules, extractor, capabilities)
capa_output = False
capa_output: Any = False
if output_format == "dictionary":
# ...as python dictionary, simplified as textable but in dictionary
doc = rd.ResultDocument.from_capa(meta, rules, capabilities)

View File

@@ -28,7 +28,7 @@ def main(argv=None):
if capa.helpers.is_runtime_ida():
from capa.ida.helpers import IDAIO
f: BinaryIO = IDAIO()
f: BinaryIO = IDAIO() # type: ignore
else:
if argv is None:

View File

@@ -902,11 +902,15 @@ def redirecting_print_to_tqdm():
old_print(*args, **kwargs)
try:
# Globaly replace print with new_print
inspect.builtins.print = new_print
# Globaly replace print with new_print.
# Verified this works manually on Python 3.11:
# >>> import inspect
# >>> inspect.builtins
# <module 'builtins' (built-in)>
inspect.builtins.print = new_print # type: ignore
yield
finally:
inspect.builtins.print = old_print
inspect.builtins.print = old_print # type: ignore
def lint(ctx: Context):
@@ -998,10 +1002,8 @@ def main(argv=None):
time0 = time.time()
try:
rules = capa.main.get_rules(args.rules, disable_progress=True)
rule_count = len(rules)
rules = capa.rules.RuleSet(rules)
logger.info("successfully loaded %s rules", rule_count)
rules = capa.rules.RuleSet(capa.main.get_rules(args.rules, disable_progress=True))
logger.info("successfully loaded %s rules", len(rules))
if args.tag:
rules = rules.filter_rules_by_meta(args.tag)
logger.debug("selected %s rules", len(rules))

View File

@@ -141,8 +141,7 @@ def main(argv=None):
return -1
try:
rules = capa.main.get_rules(args.rules)
rules = capa.rules.RuleSet(rules)
rules = capa.rules.RuleSet(capa.main.get_rules(args.rules))
logger.info("successfully loaded %s rules", len(rules))
if args.tag:
rules = rules.filter_rules_by_meta(args.tag)

View File

@@ -136,7 +136,7 @@ def main(argv=None):
for feature, addr in extractor.extract_file_features():
print("file: %s: %s" % (format_address(addr), feature))
function_handles = extractor.get_functions()
function_handles = tuple(extractor.get_functions())
if args.function:
if args.format == "freeze":
@@ -173,7 +173,7 @@ def ida_main():
print("file: %s: %s" % (format_address(addr), feature))
return
function_handles = extractor.get_functions()
function_handles = tuple(extractor.get_functions())
if function:
function_handles = tuple(filter(lambda fh: fh.inner.start_ea == function, function_handles))

View File

@@ -8,58 +8,63 @@
from capa.engine import *
from capa.features import *
from capa.features.insn import *
import capa.features.address
ADDR1 = capa.features.address.AbsoluteVirtualAddress(0x401001)
ADDR2 = capa.features.address.AbsoluteVirtualAddress(0x401002)
ADDR3 = capa.features.address.AbsoluteVirtualAddress(0x401003)
ADDR4 = capa.features.address.AbsoluteVirtualAddress(0x401004)
def test_number():
assert Number(1).evaluate({Number(0): {1}}) == False
assert Number(1).evaluate({Number(1): {1}}) == True
assert Number(1).evaluate({Number(2): {1, 2}}) == False
assert Number(1).evaluate({Number(0): {ADDR1}}) == False
assert Number(1).evaluate({Number(1): {ADDR1}}) == True
assert Number(1).evaluate({Number(2): {ADDR1, ADDR2}}) == False
def test_and():
assert And([Number(1)]).evaluate({Number(0): {1}}) == False
assert And([Number(1)]).evaluate({Number(1): {1}}) == True
assert And([Number(1), Number(2)]).evaluate({Number(0): {1}}) == False
assert And([Number(1), Number(2)]).evaluate({Number(1): {1}}) == False
assert And([Number(1), Number(2)]).evaluate({Number(2): {1}}) == False
assert And([Number(1), Number(2)]).evaluate({Number(1): {1}, Number(2): {2}}) == True
assert And([Number(1)]).evaluate({Number(0): {ADDR1}}) == False
assert And([Number(1)]).evaluate({Number(1): {ADDR1}}) == True
assert And([Number(1), Number(2)]).evaluate({Number(0): {ADDR1}}) == False
assert And([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}) == False
assert And([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}) == False
assert And([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}, Number(2): {ADDR2}}) == True
def test_or():
assert Or([Number(1)]).evaluate({Number(0): {1}}) == False
assert Or([Number(1)]).evaluate({Number(1): {1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(0): {1}}) == False
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}, Number(2): {2}}) == True
assert Or([Number(1)]).evaluate({Number(0): {ADDR1}}) == False
assert Or([Number(1)]).evaluate({Number(1): {ADDR1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(0): {ADDR1}}) == False
assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}, Number(2): {ADDR2}}) == True
def test_not():
assert Not(Number(1)).evaluate({Number(0): {1}}) == True
assert Not(Number(1)).evaluate({Number(1): {1}}) == False
assert Not(Number(1)).evaluate({Number(0): {ADDR1}}) == True
assert Not(Number(1)).evaluate({Number(1): {ADDR1}}) == False
def test_some():
assert Some(0, [Number(1)]).evaluate({Number(0): {1}}) == True
assert Some(1, [Number(1)]).evaluate({Number(0): {1}}) == False
assert Some(0, [Number(1)]).evaluate({Number(0): {ADDR1}}) == True
assert Some(1, [Number(1)]).evaluate({Number(0): {ADDR1}}) == False
assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {1}}) == False
assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {1}, Number(1): {1}}) == False
assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {1}, Number(1): {1}, Number(2): {1}}) == True
assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {ADDR1}}) == False
assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}}) == False
assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}, Number(2): {ADDR1}}) == True
assert (
Some(2, [Number(1), Number(2), Number(3)]).evaluate(
{Number(0): {1}, Number(1): {1}, Number(2): {1}, Number(3): {1}}
{Number(0): {ADDR1}, Number(1): {ADDR1}, Number(2): {ADDR1}, Number(3): {ADDR1}}
)
== True
)
assert (
Some(2, [Number(1), Number(2), Number(3)]).evaluate(
{
Number(0): {1},
Number(1): {1},
Number(2): {1},
Number(3): {1},
Number(4): {1},
Number(0): {ADDR1},
Number(1): {ADDR1},
Number(2): {ADDR1},
Number(3): {ADDR1},
Number(4): {ADDR1},
}
)
== True
@@ -69,10 +74,10 @@ def test_some():
def test_complex():
assert True == Or(
[And([Number(1), Number(2)]), Or([Number(3), Some(2, [Number(4), Number(5), Number(6)])])]
).evaluate({Number(5): {1}, Number(6): {1}, Number(7): {1}, Number(8): {1}})
).evaluate({Number(5): {ADDR1}, Number(6): {ADDR1}, Number(7): {ADDR1}, Number(8): {ADDR1}})
assert False == Or([And([Number(1), Number(2)]), Or([Number(3), Some(2, [Number(4), Number(5)])])]).evaluate(
{Number(5): {1}, Number(6): {1}, Number(7): {1}, Number(8): {1}}
{Number(5): {ADDR1}, Number(6): {ADDR1}, Number(7): {ADDR1}, Number(8): {ADDR1}}
)
@@ -83,54 +88,54 @@ def test_range():
# unbounded range with matching feature should always match
assert Range(Number(1)).evaluate({Number(1): {}}) == True
assert Range(Number(1)).evaluate({Number(1): {0}}) == True
assert Range(Number(1)).evaluate({Number(1): {ADDR1}}) == True
# unbounded max
assert Range(Number(1), min=1).evaluate({Number(1): {0}}) == True
assert Range(Number(1), min=2).evaluate({Number(1): {0}}) == False
assert Range(Number(1), min=2).evaluate({Number(1): {0, 1}}) == True
assert Range(Number(1), min=1).evaluate({Number(1): {ADDR1}}) == True
assert Range(Number(1), min=2).evaluate({Number(1): {ADDR1}}) == False
assert Range(Number(1), min=2).evaluate({Number(1): {ADDR1, ADDR2}}) == True
# unbounded min
assert Range(Number(1), max=0).evaluate({Number(1): {0}}) == False
assert Range(Number(1), max=1).evaluate({Number(1): {0}}) == True
assert Range(Number(1), max=2).evaluate({Number(1): {0}}) == True
assert Range(Number(1), max=2).evaluate({Number(1): {0, 1}}) == True
assert Range(Number(1), max=2).evaluate({Number(1): {0, 1, 3}}) == False
assert Range(Number(1), max=0).evaluate({Number(1): {ADDR1}}) == False
assert Range(Number(1), max=1).evaluate({Number(1): {ADDR1}}) == True
assert Range(Number(1), max=2).evaluate({Number(1): {ADDR1}}) == True
assert Range(Number(1), max=2).evaluate({Number(1): {ADDR1, ADDR2}}) == True
assert Range(Number(1), max=2).evaluate({Number(1): {ADDR1, ADDR2, ADDR3}}) == False
# we can do an exact match by setting min==max
assert Range(Number(1), min=1, max=1).evaluate({Number(1): {}}) == False
assert Range(Number(1), min=1, max=1).evaluate({Number(1): {1}}) == True
assert Range(Number(1), min=1, max=1).evaluate({Number(1): {1, 2}}) == False
assert Range(Number(1), min=1, max=1).evaluate({Number(1): {ADDR1}}) == True
assert Range(Number(1), min=1, max=1).evaluate({Number(1): {ADDR1, ADDR2}}) == False
# bounded range
assert Range(Number(1), min=1, max=3).evaluate({Number(1): {}}) == False
assert Range(Number(1), min=1, max=3).evaluate({Number(1): {1}}) == True
assert Range(Number(1), min=1, max=3).evaluate({Number(1): {1, 2}}) == True
assert Range(Number(1), min=1, max=3).evaluate({Number(1): {1, 2, 3}}) == True
assert Range(Number(1), min=1, max=3).evaluate({Number(1): {1, 2, 3, 4}}) == False
assert Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1}}) == True
assert Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1, ADDR2}}) == True
assert Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1, ADDR2, ADDR3}}) == True
assert Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1, ADDR2, ADDR3, ADDR4}}) == False
def test_short_circuit():
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}) == True
# with short circuiting, only the children up until the first satisfied child are captured.
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}, short_circuit=True).children) == 1
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}, short_circuit=False).children) == 2
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}, short_circuit=True).children) == 1
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}, short_circuit=False).children) == 2
def test_eval_order():
# base cases.
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}) == True
# with short circuiting, only the children up until the first satisfied child are captured.
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children) == 1
assert len(Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children) == 2
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}, Number(2): {1}}).children) == 1
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}).children) == 1
assert len(Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}).children) == 2
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}, Number(2): {ADDR1}}).children) == 1
# and its guaranteed that children are evaluated in order.
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children[0].statement == Number(1)
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children[0].statement != Number(2)
assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}).children[0].statement == Number(1)
assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}).children[0].statement != Number(2)
assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children[1].statement == Number(2)
assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children[1].statement != Number(1)
assert Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}).children[1].statement == Number(2)
assert Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}).children[1].statement != Number(1)

View File

@@ -98,7 +98,7 @@ def test_rule_reformat_order():
def test_rule_reformat_meta_update():
# test updating the rule content after parsing
rule = textwrap.dedent(
src = textwrap.dedent(
"""
rule:
meta:
@@ -116,7 +116,7 @@ def test_rule_reformat_meta_update():
"""
)
rule = capa.rules.Rule.from_yaml(rule)
rule = capa.rules.Rule.from_yaml(src)
rule.name = "test rule"
assert rule.to_yaml() == EXPECTED

View File

@@ -218,7 +218,7 @@ def test_match_matched_rules():
# the ordering of the rules must not matter,
# the engine should match rules in an appropriate order.
features, _ = match(
capa.rules.topologically_order_rules(reversed(rules)),
capa.rules.topologically_order_rules(list(reversed(rules))),
{capa.features.insn.Number(100): {1}},
0x0,
)

View File

@@ -19,6 +19,7 @@ def test_optional_node_from_capa():
[],
)
)
assert isinstance(node, rdoc.StatementNode)
assert isinstance(node.statement, rdoc.CompoundStatement)
assert node.statement.type == rdoc.CompoundStatementType.OPTIONAL
@@ -32,6 +33,7 @@ def test_some_node_from_capa():
],
)
)
assert isinstance(node, rdoc.StatementNode)
assert isinstance(node.statement, rdoc.SomeStatement)
@@ -41,6 +43,7 @@ def test_range_node_from_capa():
capa.features.insn.Number(0),
)
)
assert isinstance(node, rdoc.StatementNode)
assert isinstance(node.statement, rdoc.RangeStatement)
@@ -51,6 +54,7 @@ def test_subscope_node_from_capa():
capa.features.insn.Number(0),
)
)
assert isinstance(node, rdoc.StatementNode)
assert isinstance(node.statement, rdoc.SubscopeStatement)
@@ -62,6 +66,7 @@ def test_and_node_from_capa():
],
)
)
assert isinstance(node, rdoc.StatementNode)
assert isinstance(node.statement, rdoc.CompoundStatement)
assert node.statement.type == rdoc.CompoundStatementType.AND
@@ -74,6 +79,7 @@ def test_or_node_from_capa():
],
)
)
assert isinstance(node, rdoc.StatementNode)
assert isinstance(node.statement, rdoc.CompoundStatement)
assert node.statement.type == rdoc.CompoundStatementType.OR
@@ -86,115 +92,138 @@ def test_not_node_from_capa():
],
)
)
assert isinstance(node, rdoc.StatementNode)
assert isinstance(node.statement, rdoc.CompoundStatement)
assert node.statement.type == rdoc.CompoundStatementType.NOT
def test_os_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.OS(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.OSFeature)
def test_arch_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Arch(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.ArchFeature)
def test_format_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Format(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.FormatFeature)
def test_match_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.MatchedRule(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.MatchFeature)
def test_characteristic_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Characteristic(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.CharacteristicFeature)
def test_substring_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Substring(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.SubstringFeature)
def test_regex_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Regex(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.RegexFeature)
def test_class_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Class(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.ClassFeature)
def test_namespace_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Namespace(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.NamespaceFeature)
def test_bytes_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Bytes(b""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.BytesFeature)
def test_export_node_from_capa():
node = rdoc.node_from_capa(capa.features.file.Export(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.ExportFeature)
def test_import_node_from_capa():
node = rdoc.node_from_capa(capa.features.file.Import(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.ImportFeature)
def test_section_node_from_capa():
node = rdoc.node_from_capa(capa.features.file.Section(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.SectionFeature)
def test_function_name_node_from_capa():
node = rdoc.node_from_capa(capa.features.file.FunctionName(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.FunctionNameFeature)
def test_api_node_from_capa():
node = rdoc.node_from_capa(capa.features.insn.API(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.APIFeature)
def test_property_node_from_capa():
node = rdoc.node_from_capa(capa.features.insn.Property(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.PropertyFeature)
def test_number_node_from_capa():
node = rdoc.node_from_capa(capa.features.insn.Number(0))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.NumberFeature)
def test_offset_node_from_capa():
node = rdoc.node_from_capa(capa.features.insn.Offset(0))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.OffsetFeature)
def test_mnemonic_node_from_capa():
node = rdoc.node_from_capa(capa.features.insn.Mnemonic(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.MnemonicFeature)
def test_operand_number_node_from_capa():
node = rdoc.node_from_capa(capa.features.insn.OperandNumber(0, 0))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.OperandNumberFeature)
def test_operand_offset_node_from_capa():
node = rdoc.node_from_capa(capa.features.insn.OperandOffset(0, 0))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.OperandOffsetFeature)
def test_basic_block_node_from_capa():
node = rdoc.node_from_capa(capa.features.basicblock.BasicBlock(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.BasicBlockFeature)

View File

@@ -13,8 +13,10 @@ import pytest
import capa.rules
import capa.engine
import capa.features.common
from capa.features.address import AbsoluteVirtualAddress
from capa.features.file import FunctionName
from capa.features.insn import Number, Offset, Property
from capa.engine import Or
from capa.features.common import (
OS,
OS_LINUX,
@@ -29,12 +31,19 @@ from capa.features.common import (
Substring,
FeatureAccess,
)
import capa.features.address
ADDR1 = capa.features.address.AbsoluteVirtualAddress(0x401001)
ADDR2 = capa.features.address.AbsoluteVirtualAddress(0x401002)
ADDR3 = capa.features.address.AbsoluteVirtualAddress(0x401003)
ADDR4 = capa.features.address.AbsoluteVirtualAddress(0x401004)
def test_rule_ctor():
r = capa.rules.Rule("test rule", capa.rules.FUNCTION_SCOPE, Number(1), {})
assert r.evaluate({Number(0): {1}}) == False
assert r.evaluate({Number(1): {1}}) == True
r = capa.rules.Rule("test rule", capa.rules.FUNCTION_SCOPE, Or(Number(1)), {})
assert r.evaluate({Number(0): {ADDR1}}) == False
assert r.evaluate({Number(1): {ADDR2}}) == True
def test_rule_yaml():
@@ -56,10 +65,10 @@ def test_rule_yaml():
"""
)
r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({Number(0): {1}}) == False
assert r.evaluate({Number(0): {1}, Number(1): {1}}) == False
assert r.evaluate({Number(0): {1}, Number(1): {1}, Number(2): {1}}) == True
assert r.evaluate({Number(0): {1}, Number(1): {1}, Number(2): {1}, Number(3): {1}}) == True
assert r.evaluate({Number(0): {ADDR1}}) == False
assert r.evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}}) == False
assert r.evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}, Number(2): {ADDR1}}) == True
assert r.evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}, Number(2): {ADDR1}, Number(3): {ADDR1}}) == True
def test_rule_yaml_complex():
@@ -82,8 +91,8 @@ def test_rule_yaml_complex():
"""
)
r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({Number(5): {1}, Number(6): {1}, Number(7): {1}, Number(8): {1}}) == True
assert r.evaluate({Number(6): {1}, Number(7): {1}, Number(8): {1}}) == False
assert r.evaluate({Number(5): {ADDR1}, Number(6): {ADDR1}, Number(7): {ADDR1}, Number(8): {ADDR1}}) == True
assert r.evaluate({Number(6): {ADDR1}, Number(7): {ADDR1}, Number(8): {ADDR1}}) == False
def test_rule_descriptions():
@@ -160,8 +169,8 @@ def test_rule_yaml_not():
"""
)
r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({Number(1): {1}}) == True
assert r.evaluate({Number(1): {1}, Number(2): {1}}) == False
assert r.evaluate({Number(1): {ADDR1}}) == True
assert r.evaluate({Number(1): {ADDR1}, Number(2): {ADDR1}}) == False
def test_rule_yaml_count():
@@ -175,9 +184,9 @@ def test_rule_yaml_count():
"""
)
r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({Number(100): {}}) == False
assert r.evaluate({Number(100): {1}}) == True
assert r.evaluate({Number(100): {1, 2}}) == False
assert r.evaluate({Number(100): set()}) == False
assert r.evaluate({Number(100): {ADDR1}}) == True
assert r.evaluate({Number(100): {ADDR1, ADDR2}}) == False
def test_rule_yaml_count_range():
@@ -191,10 +200,10 @@ def test_rule_yaml_count_range():
"""
)
r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({Number(100): {}}) == False
assert r.evaluate({Number(100): {1}}) == True
assert r.evaluate({Number(100): {1, 2}}) == True
assert r.evaluate({Number(100): {1, 2, 3}}) == False
assert r.evaluate({Number(100): set()}) == False
assert r.evaluate({Number(100): {ADDR1}}) == True
assert r.evaluate({Number(100): {ADDR1, ADDR2}}) == True
assert r.evaluate({Number(100): {ADDR1, ADDR2, ADDR3}}) == False
def test_rule_yaml_count_string():
@@ -208,10 +217,10 @@ def test_rule_yaml_count_string():
"""
)
r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({String("foo"): {}}) == False
assert r.evaluate({String("foo"): {1}}) == False
assert r.evaluate({String("foo"): {1, 2}}) == True
assert r.evaluate({String("foo"): {1, 2, 3}}) == False
assert r.evaluate({String("foo"): set()}) == False
assert r.evaluate({String("foo"): {ADDR1}}) == False
assert r.evaluate({String("foo"): {ADDR1, ADDR2}}) == True
assert r.evaluate({String("foo"): {ADDR1, ADDR2, ADDR3}}) == False
def test_invalid_rule_feature():
@@ -481,11 +490,11 @@ def test_count_number_symbol():
"""
)
r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({Number(2): {}}) == False
assert r.evaluate({Number(2): {1}}) == True
assert r.evaluate({Number(2): {1, 2}}) == False
assert r.evaluate({Number(0x100, description="symbol name"): {1}}) == False
assert r.evaluate({Number(0x100, description="symbol name"): {1, 2, 3}}) == True
assert r.evaluate({Number(2): set()}) == False
assert r.evaluate({Number(2): {ADDR1}}) == True
assert r.evaluate({Number(2): {ADDR1, ADDR2}}) == False
assert r.evaluate({Number(0x100, description="symbol name"): {ADDR1}}) == False
assert r.evaluate({Number(0x100, description="symbol name"): {ADDR1, ADDR2, ADDR3}}) == True
def test_invalid_number():
@@ -567,11 +576,11 @@ def test_count_offset_symbol():
"""
)
r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({Offset(2): {}}) == False
assert r.evaluate({Offset(2): {1}}) == True
assert r.evaluate({Offset(2): {1, 2}}) == False
assert r.evaluate({Offset(0x100, description="symbol name"): {1}}) == False
assert r.evaluate({Offset(0x100, description="symbol name"): {1, 2, 3}}) == True
assert r.evaluate({Offset(2): set()}) == False
assert r.evaluate({Offset(2): {ADDR1}}) == True
assert r.evaluate({Offset(2): {ADDR1, ADDR2}}) == False
assert r.evaluate({Offset(0x100, description="symbol name"): {ADDR1}}) == False
assert r.evaluate({Offset(0x100, description="symbol name"): {ADDR1, ADDR2, ADDR3}}) == True
def test_invalid_offset():
@@ -966,10 +975,10 @@ def test_property_access():
"""
)
)
assert r.evaluate({Property("System.IO.FileInfo::Length", access=FeatureAccess.READ): {1}}) == True
assert r.evaluate({Property("System.IO.FileInfo::Length", access=FeatureAccess.READ): {ADDR1}}) == True
assert r.evaluate({Property("System.IO.FileInfo::Length"): {1}}) == False
assert r.evaluate({Property("System.IO.FileInfo::Length", access=FeatureAccess.WRITE): {1}}) == False
assert r.evaluate({Property("System.IO.FileInfo::Length"): {ADDR1}}) == False
assert r.evaluate({Property("System.IO.FileInfo::Length", access=FeatureAccess.WRITE): {ADDR1}}) == False
def test_property_access_symbol():
@@ -986,7 +995,7 @@ def test_property_access_symbol():
)
assert (
r.evaluate(
{Property("System.IO.FileInfo::Length", access=FeatureAccess.READ, description="some property"): {1}}
{Property("System.IO.FileInfo::Length", access=FeatureAccess.READ, description="some property"): {ADDR1}}
)
== True
)