update type annotations

tmp
This commit is contained in:
mr-tz
2024-10-22 07:40:26 +00:00
parent cebf8e7274
commit 2987eeb0ac
116 changed files with 874 additions and 905 deletions

View File

@@ -9,7 +9,7 @@
import json
import collections
from typing import Any, Set, Dict
from typing import Any
from pathlib import Path
import capa.main
@@ -34,7 +34,7 @@ def render_meta(doc: rd.ResultDocument, result):
result["path"] = doc.meta.sample.path
def find_subrule_matches(doc: rd.ResultDocument) -> Set[str]:
def find_subrule_matches(doc: rd.ResultDocument) -> set[str]:
"""
collect the rule names that have been matched as a subrule match.
this way we can avoid displaying entries for things that are too specific.
@@ -158,8 +158,8 @@ def render_mbc(doc, result):
result["MBC"].setdefault(objective.upper(), inner_rows)
def render_dictionary(doc: rd.ResultDocument) -> Dict[str, Any]:
result: Dict[str, Any] = {}
def render_dictionary(doc: rd.ResultDocument) -> dict[str, Any]:
result: dict[str, Any] = {}
render_meta(doc, result)
render_attack(doc, result)
render_mbc(doc, result)

View File

@@ -25,7 +25,7 @@ import sys
import json
import logging
import argparse
from typing import List, Optional
from typing import Optional
from pathlib import Path
from capa.version import __version__
@@ -241,7 +241,7 @@ def _populate_invocations(sarif_log: dict, meta_data: dict) -> None:
sarif_log["runs"][0]["invocations"].append(invoke)
def _enumerate_evidence(node: dict, related_count: int) -> List[dict]:
def _enumerate_evidence(node: dict, related_count: int) -> list[dict]:
related_locations = []
if node.get("success") and node.get("node", {}).get("type") != "statement":
label = ""

View File

@@ -15,7 +15,7 @@ import contextlib
import statistics
import subprocess
import multiprocessing
from typing import Set, Dict, List, Optional
from typing import Optional
from pathlib import Path
from collections import Counter
from dataclasses import dataclass
@@ -183,8 +183,8 @@ def report(args):
for backend in BACKENDS:
samples.update(doc[backend].keys())
failures_by_backend: Dict[str, Set[str]] = {backend: set() for backend in BACKENDS}
durations_by_backend: Dict[str, List[float]] = {backend: [] for backend in BACKENDS}
failures_by_backend: dict[str, set[str]] = {backend: set() for backend in BACKENDS}
durations_by_backend: dict[str, list[float]] = {backend: [] for backend in BACKENDS}
console = rich.get_console()
for key in sorted(samples):
@@ -193,7 +193,7 @@ def report(args):
seen_rules: Counter[str] = Counter()
rules_by_backend: Dict[str, Set[str]] = {backend: set() for backend in BACKENDS}
rules_by_backend: dict[str, set[str]] = {backend: set() for backend in BACKENDS}
for backend in BACKENDS:
if key not in doc[backend]:

View File

@@ -8,7 +8,6 @@
import sys
import logging
import argparse
from typing import Set
from pathlib import Path
import capa.main
@@ -18,7 +17,7 @@ from capa.features.common import Feature
logger = logging.getLogger("detect_duplicate_features")
def get_features(rule_path: str) -> Set[Feature]:
def get_features(rule_path: str) -> set[Feature]:
"""
Extracts all features from a given rule file.

View File

@@ -14,7 +14,7 @@ import time
import logging
import argparse
import contextlib
from typing import Dict, List, Optional
from typing import Optional
import capa.main
import capa.features.extractors.binexport2
@@ -71,14 +71,14 @@ class Renderer:
def _render_expression_tree(
be2: BinExport2,
operand: BinExport2.Operand,
expression_tree: List[List[int]],
expression_tree: list[list[int]],
tree_index: int,
o: io.StringIO,
):
expression_index = operand.expression_index[tree_index]
expression = be2.expression[expression_index]
children_tree_indexes: List[int] = expression_tree[tree_index]
children_tree_indexes: list[int] = expression_tree[tree_index]
if expression.type == BinExport2.Expression.REGISTER:
o.write(expression.symbol)
@@ -177,7 +177,7 @@ def _render_expression_tree(
raise NotImplementedError(expression.type)
_OPERAND_CACHE: Dict[int, str] = {}
_OPERAND_CACHE: dict[int, str] = {}
def render_operand(be2: BinExport2, operand: BinExport2.Operand, index: Optional[int] = None) -> str:
@@ -223,7 +223,7 @@ def inspect_operand(be2: BinExport2, operand: BinExport2.Operand):
def rec(tree_index, indent=0):
expression_index = operand.expression_index[tree_index]
expression = be2.expression[expression_index]
children_tree_indexes: List[int] = expression_tree[tree_index]
children_tree_indexes: list[int] = expression_tree[tree_index]
NEWLINE = "\n"
print(f" {' ' * indent}expression: {str(expression).replace(NEWLINE, ', ')}")
@@ -435,7 +435,7 @@ def main(argv=None):
# appears to be code
continue
data_xrefs: List[int] = []
data_xrefs: list[int] = []
for data_reference_index in idx.data_reference_index_by_target_address[data_address]:
data_reference = be2.data_reference[data_reference_index]
instruction_address = idx.get_insn_address(data_reference.instruction_index)

View File

@@ -27,7 +27,6 @@ import logging
import argparse
import itertools
import posixpath
from typing import Set, Dict, List
from pathlib import Path
from dataclasses import field, dataclass
@@ -59,10 +58,10 @@ class Context:
capabilities_by_sample: cache of results, indexed by file path.
"""
samples: Dict[str, Path]
samples: dict[str, Path]
rules: RuleSet
is_thorough: bool
capabilities_by_sample: Dict[Path, Set[str]] = field(default_factory=dict)
capabilities_by_sample: dict[Path, set[str]] = field(default_factory=dict)
class Lint:
@@ -330,7 +329,7 @@ class InvalidAttckOrMbcTechnique(Lint):
DEFAULT_SIGNATURES = capa.main.get_default_signatures()
def get_sample_capabilities(ctx: Context, path: Path) -> Set[str]:
def get_sample_capabilities(ctx: Context, path: Path) -> set[str]:
nice_path = path.resolve().absolute()
if path in ctx.capabilities_by_sample:
logger.debug("found cached results: %s: %d capabilities", nice_path, len(ctx.capabilities_by_sample[path]))
@@ -541,7 +540,7 @@ class FeatureStringTooShort(Lint):
name = "feature string too short"
recommendation = 'capa only extracts strings with length >= 4; will not match on "{:s}"'
def check_features(self, ctx: Context, features: List[Feature]):
def check_features(self, ctx: Context, features: list[Feature]):
for feature in features:
if isinstance(feature, (String, Substring)):
assert isinstance(feature.value, str)
@@ -559,7 +558,7 @@ class FeatureNegativeNumber(Lint):
+ 'representation; will not match on "{:d}"'
)
def check_features(self, ctx: Context, features: List[Feature]):
def check_features(self, ctx: Context, features: list[Feature]):
for feature in features:
if isinstance(feature, (capa.features.insn.Number,)):
assert isinstance(feature.value, int)
@@ -577,7 +576,7 @@ class FeatureNtdllNtoskrnlApi(Lint):
+ "module requirement to improve detection"
)
def check_features(self, ctx: Context, features: List[Feature]):
def check_features(self, ctx: Context, features: list[Feature]):
for feature in features:
if isinstance(feature, capa.features.insn.API):
assert isinstance(feature.value, str)
@@ -712,7 +711,7 @@ def run_lints(lints, ctx: Context, rule: Rule):
yield lint
def run_feature_lints(lints, ctx: Context, features: List[Feature]):
def run_feature_lints(lints, ctx: Context, features: list[Feature]):
for lint in lints:
if lint.check_features(ctx, features):
yield lint
@@ -900,7 +899,7 @@ def width(s, count):
def lint(ctx: Context):
"""
Returns: Dict[string, Tuple(int, int)]
Returns: dict[str, tuple[int, int]]
- # lints failed
- # lints warned
"""
@@ -920,7 +919,7 @@ def lint(ctx: Context):
return ret
def collect_samples(samples_path: Path) -> Dict[str, Path]:
def collect_samples(samples_path: Path) -> dict[str, Path]:
"""
recurse through the given path, collecting all file paths, indexed by their content sha256, md5, and filename.
"""

View File

@@ -43,7 +43,6 @@ import json
import logging
import argparse
from sys import argv
from typing import Dict, List
from pathlib import Path
import requests
@@ -77,7 +76,7 @@ class MitreExtractor:
self._memory_store = MemoryStore(stix_data=stix_json["objects"])
@staticmethod
def _remove_deprecated_objects(stix_objects) -> List[AttackPattern]:
def _remove_deprecated_objects(stix_objects) -> list[AttackPattern]:
"""Remove any revoked or deprecated objects from queries made to the data source."""
return list(
filter(
@@ -86,7 +85,7 @@ class MitreExtractor:
)
)
def _get_tactics(self) -> List[Dict]:
def _get_tactics(self) -> list[dict]:
"""Get tactics IDs from Mitre matrix."""
# Only one matrix for enterprise att&ck framework
matrix = self._remove_deprecated_objects(
@@ -98,7 +97,7 @@ class MitreExtractor:
)[0]
return list(map(self._memory_store.get, matrix["tactic_refs"]))
def _get_techniques_from_tactic(self, tactic: str) -> List[AttackPattern]:
def _get_techniques_from_tactic(self, tactic: str) -> list[AttackPattern]:
"""Get techniques and sub techniques from a Mitre tactic (kill_chain_phases->phase_name)"""
techniques = self._remove_deprecated_objects(
self._memory_store.query(
@@ -124,12 +123,12 @@ class MitreExtractor:
)[0]
return parent_technique
def run(self) -> Dict[str, Dict[str, str]]:
def run(self) -> dict[str, dict[str, str]]:
"""Iterate over every technique over every tactic. If the technique is a sub technique, then
we also search for the parent technique name.
"""
logging.info("Starting extraction...")
data: Dict[str, Dict[str, str]] = {}
data: dict[str, dict[str, str]] = {}
for tactic in self._get_tactics():
data[tactic["name"]] = {}
for technique in sorted(
@@ -159,7 +158,7 @@ class MbcExtractor(MitreExtractor):
url = "https://raw.githubusercontent.com/MBCProject/mbc-stix2/master/mbc/mbc.json"
kill_chain_name = "mitre-mbc"
def _get_tactics(self) -> List[Dict]:
def _get_tactics(self) -> list[dict]:
"""Override _get_tactics to edit the tactic name for Micro-objective"""
tactics = super()._get_tactics()
# We don't want the Micro-objective string inside objective names

View File

@@ -59,7 +59,6 @@ import sys
import logging
import argparse
import collections
from typing import Dict
import colorama
@@ -99,7 +98,7 @@ def render_matches_by_function(doc: rd.ResultDocument):
- connect to HTTP server
"""
assert isinstance(doc.meta.analysis, rd.StaticAnalysis)
functions_by_bb: Dict[Address, Address] = {}
functions_by_bb: dict[Address, Address] = {}
for finfo in doc.meta.analysis.layout.functions:
faddress = finfo.address

View File

@@ -67,7 +67,6 @@ Example::
import sys
import logging
import argparse
from typing import Tuple
import capa.main
import capa.rules
@@ -136,7 +135,7 @@ def print_static_analysis(extractor: StaticFeatureExtractor, args):
for feature, addr in extractor.extract_file_features():
print(f"file: {format_address(addr)}: {feature}")
function_handles: Tuple[FunctionHandle, ...]
function_handles: tuple[FunctionHandle, ...]
if isinstance(extractor, capa.features.extractors.pefile.PefileFeatureExtractor):
# pefile extractor doesn't extract function features
function_handles = ()

View File

@@ -9,10 +9,8 @@ Unless required by applicable law or agreed to in writing, software distributed
See the License for the specific language governing permissions and limitations under the License.
"""
import sys
import typing
import logging
import argparse
from typing import Set, List, Tuple
from collections import Counter
from rich import print
@@ -40,8 +38,8 @@ def format_address(addr: capa.features.address.Address) -> str:
return v.format_address(capa.features.freeze.Address.from_capa((addr)))
def get_rules_feature_set(rules: capa.rules.RuleSet) -> Set[Feature]:
rules_feature_set: Set[Feature] = set()
def get_rules_feature_set(rules: capa.rules.RuleSet) -> set[Feature]:
rules_feature_set: set[Feature] = set()
for _, rule in rules.rules.items():
rules_feature_set.update(rule.extract_all_features())
@@ -49,9 +47,9 @@ def get_rules_feature_set(rules: capa.rules.RuleSet) -> Set[Feature]:
def get_file_features(
functions: Tuple[FunctionHandle, ...], extractor: capa.features.extractors.base_extractor.StaticFeatureExtractor
) -> typing.Counter[Feature]:
feature_map: typing.Counter[Feature] = Counter()
functions: tuple[FunctionHandle, ...], extractor: capa.features.extractors.base_extractor.StaticFeatureExtractor
) -> Counter[Feature]:
feature_map: Counter[Feature] = Counter()
for f in functions:
if extractor.is_library_function(f.address):
@@ -86,8 +84,8 @@ def get_colored(s: str) -> Text:
return Text(s, style="cyan")
def print_unused_features(feature_map: typing.Counter[Feature], rules_feature_set: Set[Feature]):
unused_features: List[Tuple[str, Text]] = []
def print_unused_features(feature_map: Counter[Feature], rules_feature_set: set[Feature]):
unused_features: list[tuple[str, Text]] = []
for feature, count in reversed(feature_map.most_common()):
if feature in rules_feature_set:
continue
@@ -130,11 +128,11 @@ def main(argv=None):
assert isinstance(extractor, StaticFeatureExtractor), "only static analysis supported today"
feature_map: typing.Counter[Feature] = Counter()
feature_map: Counter[Feature] = Counter()
feature_map.update([feature for feature, _ in extractor.extract_global_features()])
function_handles: Tuple[FunctionHandle, ...]
function_handles: tuple[FunctionHandle, ...]
if isinstance(extractor, capa.features.extractors.pefile.PefileFeatureExtractor):
# pefile extractor doesn't extract function features
function_handles = ()
@@ -173,7 +171,7 @@ def ida_main():
print(f"getting features for current function {hex(function)}")
extractor = capa.features.extractors.ida.extractor.IdaFeatureExtractor()
feature_map: typing.Counter[Feature] = Counter()
feature_map: Counter[Feature] = Counter()
feature_map.update([feature for feature, _ in extractor.extract_file_features()])