FeatureExtractor alias: fix mypy typing issues by adding isinstance-based assert statements

This commit is contained in:
Yacine Elhamer
2023-06-26 22:46:27 +01:00
parent 63e4d3d5eb
commit b172f9a354
4 changed files with 15 additions and 13 deletions

View File

@@ -23,9 +23,9 @@ import capa.features.insn
import capa.features.common
import capa.features.address
import capa.features.basicblock
import capa.features.extractors.base_extractor
from capa.helpers import assert_never
from capa.features.freeze.features import Feature, feature_from_capa
from capa.features.extractors.base_extractor import FeatureExtractor, StaticFeatureExtractor
logger = logging.getLogger(__name__)
@@ -226,7 +226,7 @@ class Freeze(BaseModel):
allow_population_by_field_name = True
def dumps(extractor: capa.features.extractors.base_extractor.StaticFeatureExtractor) -> str:
def dumps(extractor: StaticFeatureExtractor) -> str:
"""
serialize the given extractor to a string
"""
@@ -327,7 +327,7 @@ def dumps(extractor: capa.features.extractors.base_extractor.StaticFeatureExtrac
return freeze.json()
def loads(s: str) -> capa.features.extractors.base_extractor.StaticFeatureExtractor:
def loads(s: str) -> StaticFeatureExtractor:
"""deserialize a set of features (as a NullFeatureExtractor) from a string."""
import capa.features.extractors.null as null
@@ -363,8 +363,9 @@ def loads(s: str) -> capa.features.extractors.base_extractor.StaticFeatureExtrac
MAGIC = "capa0000".encode("ascii")
def dump(extractor: capa.features.extractors.base_extractor.StaticFeatureExtractor) -> bytes:
def dump(extractor: FeatureExtractor) -> bytes:
"""serialize the given extractor to a byte array."""
assert isinstance(extractor, StaticFeatureExtractor)
return MAGIC + zlib.compress(dumps(extractor).encode("utf-8"))
@@ -372,7 +373,7 @@ def is_freeze(buf: bytes) -> bool:
return buf[: len(MAGIC)] == MAGIC
def load(buf: bytes) -> capa.features.extractors.base_extractor.StaticFeatureExtractor:
def load(buf: bytes) -> StaticFeatureExtractor:
"""deserialize a set of features (as a NullFeatureExtractor) from a byte array."""
if not is_freeze(buf):
raise ValueError("missing magic header")

View File

@@ -46,7 +46,7 @@ import capa.helpers
import capa.features
import capa.features.common
import capa.features.freeze
from capa.features.extractors.base_extractor import StaticFeatureExtractor
from capa.features.extractors.base_extractor import FeatureExtractor, StaticFeatureExtractor
logger = logging.getLogger("capa.profile")
@@ -104,13 +104,14 @@ def main(argv=None):
args.format == capa.features.common.FORMAT_AUTO and capa.features.freeze.is_freeze(taste)
):
with open(args.sample, "rb") as f:
extractor = capa.features.freeze.load(f.read())
extractor: FeatureExtractor = capa.features.freeze.load(f.read())
assert isinstance(extractor, StaticFeatureExtractor)
else:
extractor = capa.main.get_extractor(
args.sample, args.format, args.os, capa.main.BACKEND_VIV, sig_paths, should_save_workspace=False
)
assert isinstance(extractor, StaticFeatureExtractor)
with tqdm.tqdm(total=args.number * args.repeat) as pbar:
def do_iteration():

View File

@@ -70,7 +70,7 @@ import capa.render.result_document as rd
from capa.helpers import get_file_taste
from capa.features.common import FORMAT_AUTO
from capa.features.freeze import Address
from capa.features.extractors.base_extractor import StaticFeatureExtractor
from capa.features.extractors.base_extractor import FeatureExtractor, StaticFeatureExtractor
logger = logging.getLogger("capa.show-capabilities-by-function")
@@ -161,7 +161,7 @@ def main(argv=None):
if (args.format == "freeze") or (args.format == FORMAT_AUTO and capa.features.freeze.is_freeze(taste)):
format_ = "freeze"
with open(args.sample, "rb") as f:
extractor = capa.features.freeze.load(f.read())
extractor: FeatureExtractor = capa.features.freeze.load(f.read())
else:
format_ = args.format
should_save_workspace = os.environ.get("CAPA_SAVE_WORKSPACE") not in ("0", "no", "NO", "n", None)

View File

@@ -80,8 +80,8 @@ import capa.render.verbose as v
import capa.features.common
import capa.features.freeze
import capa.features.address
import capa.features.extractors.base_extractor
from capa.helpers import log_unsupported_runtime_error
from capa.features.extractors.base_extractor import FeatureExtractor, StaticFeatureExtractor
logger = logging.getLogger("capa.show-features")
@@ -117,14 +117,13 @@ def main(argv=None):
args.format == capa.features.common.FORMAT_AUTO and capa.features.freeze.is_freeze(taste)
):
with open(args.sample, "rb") as f:
extractor = capa.features.freeze.load(f.read())
extractor: FeatureExtractor = capa.features.freeze.load(f.read())
else:
should_save_workspace = os.environ.get("CAPA_SAVE_WORKSPACE") not in ("0", "no", "NO", "n", None)
try:
extractor = capa.main.get_extractor(
args.sample, args.format, args.os, args.backend, sig_paths, should_save_workspace
)
assert isinstance(extractor, capa.features.extractors.base_extractor.StaticFeatureExtractor)
except capa.exceptions.UnsupportedFormatError:
capa.helpers.log_unsupported_format_error()
return -1
@@ -132,6 +131,7 @@ def main(argv=None):
log_unsupported_runtime_error()
return -1
assert isinstance(extractor, StaticFeatureExtractor)
for feature, addr in extractor.extract_global_features():
print(f"global: {format_address(addr)}: {feature}")
@@ -190,7 +190,7 @@ def ida_main():
return 0
def print_features(functions, extractor: capa.features.extractors.base_extractor.FeatureExtractor):
def print_features(functions, extractor: StaticFeatureExtractor):
for f in functions:
if extractor.is_library_function(f.address):
function_name = extractor.get_function_name(f.address)