Compare commits

...

2 Commits

Author            SHA1        Message                                  Date
Willi Ballenthin  036e157474  ida: use ida-domain api                  2026-01-14 12:04:13 +01:00
Moritz            3919475728  Merge branch 'master' into idalib-tests  2026-01-14 12:04:13 +01:00
14 changed files with 372 additions and 298 deletions
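The unifying change across these files: module-level idaapi/idc/idautils calls are replaced by methods on an ida_domain.Database handle that is threaded explicitly through every extractor and helper. A minimal sketch of that pattern, using only calls that appear in this diff (exact ida-domain signatures are assumptions to be checked against its documentation):

from ida_domain import Database

db = Database.open()  # GUI: current database; headless: Database.open(str(path))
for seg in db.segments.get_all():
    # get_bytes_at may return None, so normalize to bytes as the helpers below do
    data = db.bytes.get_bytes_at(seg.start_ea, 0x10) or b""
    print(db.segments.get_name(seg), data.hex())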

View File

@@ -18,6 +18,7 @@ import struct
from typing import Iterator
import idaapi
from ida_domain import Database
import capa.features.extractors.ida.helpers
from capa.features.common import Feature, Characteristic
@@ -59,7 +60,7 @@ def get_printable_len(op: idaapi.op_t) -> int:
return 0
def is_mov_imm_to_stack(insn: idaapi.insn_t) -> bool:
def is_mov_imm_to_stack(db: Database, insn: idaapi.insn_t) -> bool:
"""verify instruction moves immediate onto stack"""
if insn.Op2.type != idaapi.o_imm:
return False
@@ -67,42 +68,43 @@ def is_mov_imm_to_stack(insn: idaapi.insn_t) -> bool:
if not helpers.is_op_stack_var(insn.ea, 0):
return False
if not insn.get_canon_mnem().startswith("mov"):
mnem = db.instructions.get_mnemonic(insn)
if not mnem.startswith("mov"):
return False
return True
def bb_contains_stackstring(f: idaapi.func_t, bb: idaapi.BasicBlock) -> bool:
def bb_contains_stackstring(db: Database, f: idaapi.func_t, bb: idaapi.BasicBlock) -> bool:
"""check basic block for stackstring indicators
true if basic block contains enough moves of constant bytes to the stack
"""
count = 0
for insn in capa.features.extractors.ida.helpers.get_instructions_in_range(bb.start_ea, bb.end_ea):
if is_mov_imm_to_stack(insn):
for insn in capa.features.extractors.ida.helpers.get_instructions_in_range(db, bb.start_ea, bb.end_ea):
if is_mov_imm_to_stack(db, insn):
count += get_printable_len(insn.Op2)
if count > MIN_STACKSTRING_LEN:
return True
return False
def extract_bb_stackstring(fh: FunctionHandle, bbh: BBHandle) -> Iterator[tuple[Feature, Address]]:
def extract_bb_stackstring(db: Database, fh: FunctionHandle, bbh: BBHandle) -> Iterator[tuple[Feature, Address]]:
"""extract stackstring indicators from basic block"""
if bb_contains_stackstring(fh.inner, bbh.inner):
if bb_contains_stackstring(db, fh.inner, bbh.inner):
yield Characteristic("stack string"), bbh.address
def extract_bb_tight_loop(fh: FunctionHandle, bbh: BBHandle) -> Iterator[tuple[Feature, Address]]:
def extract_bb_tight_loop(db: Database, fh: FunctionHandle, bbh: BBHandle) -> Iterator[tuple[Feature, Address]]:
"""extract tight loop indicators from a basic block"""
if capa.features.extractors.ida.helpers.is_basic_block_tight_loop(bbh.inner):
if capa.features.extractors.ida.helpers.is_basic_block_tight_loop(db, bbh.inner):
yield Characteristic("tight loop"), bbh.address
def extract_features(fh: FunctionHandle, bbh: BBHandle) -> Iterator[tuple[Feature, Address]]:
def extract_features(db: Database, fh: FunctionHandle, bbh: BBHandle) -> Iterator[tuple[Feature, Address]]:
"""extract basic block features"""
for bb_handler in BASIC_BLOCK_HANDLERS:
for feature, addr in bb_handler(fh, bbh):
for feature, addr in bb_handler(db, fh, bbh):
yield feature, addr
yield BasicBlock(), bbh.address

View File

@@ -13,8 +13,9 @@
# limitations under the License.
from typing import Iterator
from pathlib import Path
import idaapi
from ida_domain import Database
import capa.ida.helpers
import capa.features.extractors.elf
@@ -35,56 +36,68 @@ from capa.features.extractors.base_extractor import (
class IdaFeatureExtractor(StaticFeatureExtractor):
def __init__(self):
def __init__(self, db: Database):
self.db = db
super().__init__(
hashes=SampleHashes(
md5=capa.ida.helpers.retrieve_input_file_md5(),
md5=db.md5,
sha1="(unknown)",
sha256=capa.ida.helpers.retrieve_input_file_sha256(),
sha256=db.sha256,
)
)
self.global_features: list[tuple[Feature, Address]] = []
self.global_features.extend(capa.features.extractors.ida.file.extract_file_format())
self.global_features.extend(capa.features.extractors.ida.global_.extract_os())
self.global_features.extend(capa.features.extractors.ida.global_.extract_arch())
self.global_features.extend(capa.features.extractors.ida.file.extract_file_format(self.db))
self.global_features.extend(capa.features.extractors.ida.global_.extract_os(self.db))
self.global_features.extend(capa.features.extractors.ida.global_.extract_arch(self.db))
@classmethod
def from_current_database(cls) -> "IdaFeatureExtractor":
"""Create extractor for interactive IDA GUI use."""
db = Database.open()
return cls(db)
@classmethod
def from_file(cls, path: Path) -> "IdaFeatureExtractor":
"""Create extractor for idalib/headless use."""
db = Database.open(str(path))
return cls(db)
def get_base_address(self):
return AbsoluteVirtualAddress(idaapi.get_imagebase())
return AbsoluteVirtualAddress(self.db.base_address)
def extract_global_features(self):
yield from self.global_features
def extract_file_features(self):
yield from capa.features.extractors.ida.file.extract_features()
yield from capa.features.extractors.ida.file.extract_features(self.db)
def get_functions(self) -> Iterator[FunctionHandle]:
import capa.features.extractors.ida.helpers as ida_helpers
# ignore library functions and thunk functions as identified by IDA
yield from ida_helpers.get_functions(skip_thunks=True, skip_libs=True)
yield from ida_helpers.get_functions(self.db, skip_thunks=True, skip_libs=True)
@staticmethod
def get_function(ea: int) -> FunctionHandle:
f = idaapi.get_func(ea)
def get_function(self, ea: int) -> FunctionHandle:
f = self.db.functions.get_at(ea)
return FunctionHandle(address=AbsoluteVirtualAddress(f.start_ea), inner=f)
def extract_function_features(self, fh: FunctionHandle) -> Iterator[tuple[Feature, Address]]:
yield from capa.features.extractors.ida.function.extract_features(fh)
yield from capa.features.extractors.ida.function.extract_features(self.db, fh)
def get_basic_blocks(self, fh: FunctionHandle) -> Iterator[BBHandle]:
import capa.features.extractors.ida.helpers as ida_helpers
for bb in ida_helpers.get_function_blocks(fh.inner):
for bb in ida_helpers.get_function_blocks(self.db, fh.inner):
yield BBHandle(address=AbsoluteVirtualAddress(bb.start_ea), inner=bb)
def extract_basic_block_features(self, fh: FunctionHandle, bbh: BBHandle) -> Iterator[tuple[Feature, Address]]:
yield from capa.features.extractors.ida.basicblock.extract_features(fh, bbh)
yield from capa.features.extractors.ida.basicblock.extract_features(self.db, fh, bbh)
def get_instructions(self, fh: FunctionHandle, bbh: BBHandle) -> Iterator[InsnHandle]:
import capa.features.extractors.ida.helpers as ida_helpers
for insn in ida_helpers.get_instructions_in_range(bbh.inner.start_ea, bbh.inner.end_ea):
for insn in ida_helpers.get_instructions_in_range(self.db, bbh.inner.start_ea, bbh.inner.end_ea):
yield InsnHandle(address=AbsoluteVirtualAddress(insn.ea), inner=insn)
def extract_insn_features(self, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle):
yield from capa.features.extractors.ida.insn.extract_features(fh, bbh, ih)
yield from capa.features.extractors.ida.insn.extract_features(self.db, fh, bbh, ih)
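The two new classmethods make the GUI and headless construction paths explicit; a hedged usage sketch based on the signatures above (the input path is illustrative only):

from pathlib import Path
from capa.features.extractors.ida.extractor import IdaFeatureExtractor

# headless / idalib: open a database for a given input file
extractor = IdaFeatureExtractor.from_file(Path("sample.bin"))

# interactive IDA GUI: attach to the database that is already open
# extractor = IdaFeatureExtractor.from_current_database()

for fh in extractor.get_functions():
    for feature, addr in extractor.extract_function_features(fh):
        print(feature, addr)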

View File

@@ -16,10 +16,9 @@
import struct
from typing import Iterator
import idc
import idaapi
import idautils
import ida_entry
from ida_domain import Database
from ida_domain.functions import FunctionFlags
import capa.ida.helpers
import capa.features.extractors.common
@@ -33,7 +32,7 @@ from capa.features.address import NO_ADDRESS, Address, FileOffsetAddress, Absolu
MAX_OFFSET_PE_AFTER_MZ = 0x200
def check_segment_for_pe(seg: idaapi.segment_t) -> Iterator[tuple[int, int]]:
def check_segment_for_pe(db: Database, seg) -> Iterator[tuple[int, int]]:
"""check segment for embedded PE
adapted for IDA from:
@@ -51,8 +50,7 @@ def check_segment_for_pe(seg: idaapi.segment_t) -> Iterator[tuple[int, int]]:
todo = []
for mzx, pex, i in mz_xor:
# find all segment offsets containing XOR'd "MZ" bytes
for off in capa.features.extractors.ida.helpers.find_byte_sequence(seg.start_ea, seg.end_ea, mzx):
for off in capa.features.extractors.ida.helpers.find_byte_sequence(db, seg.start_ea, seg.end_ea, mzx):
todo.append((off, mzx, pex, i))
while len(todo):
@@ -64,9 +62,11 @@ def check_segment_for_pe(seg: idaapi.segment_t) -> Iterator[tuple[int, int]]:
if seg_max < (e_lfanew + 4):
continue
newoff = struct.unpack("<I", capa.features.extractors.helpers.xor_static(idc.get_bytes(e_lfanew, 4), i))[0]
raw_bytes = db.bytes.get_bytes_at(e_lfanew, 4)
if not raw_bytes:
continue
newoff = struct.unpack("<I", capa.features.extractors.helpers.xor_static(raw_bytes, i))[0]
# assume XOR'd "PE" bytes exist within threshold
if newoff > MAX_OFFSET_PE_AFTER_MZ:
continue
@@ -74,35 +74,35 @@ def check_segment_for_pe(seg: idaapi.segment_t) -> Iterator[tuple[int, int]]:
if seg_max < (peoff + 2):
continue
if idc.get_bytes(peoff, 2) == pex:
pe_bytes = db.bytes.get_bytes_at(peoff, 2)
if pe_bytes == pex:
yield off, i
def extract_file_embedded_pe() -> Iterator[tuple[Feature, Address]]:
def extract_file_embedded_pe(db: Database) -> Iterator[tuple[Feature, Address]]:
"""extract embedded PE features
IDA must load resource sections for this to be complete
- '-R' from console
- Check 'Load resource sections' when opening binary in IDA manually
"""
for seg in capa.features.extractors.ida.helpers.get_segments(skip_header_segments=True):
for ea, _ in check_segment_for_pe(seg):
for seg in capa.features.extractors.ida.helpers.get_segments(db, skip_header_segments=True):
for ea, _ in check_segment_for_pe(db, seg):
yield Characteristic("embedded pe"), FileOffsetAddress(ea)
def extract_file_export_names() -> Iterator[tuple[Feature, Address]]:
def extract_file_export_names(db: Database) -> Iterator[tuple[Feature, Address]]:
"""extract function exports"""
for _, ordinal, ea, name in idautils.Entries():
forwarded_name = ida_entry.get_entry_forwarder(ordinal)
if forwarded_name is None:
yield Export(name), AbsoluteVirtualAddress(ea)
for entry in db.entries.get_all():
if entry.has_forwarder():
forwarded_name = capa.features.extractors.helpers.reformat_forwarded_export_name(entry.forwarder_name)
yield Export(forwarded_name), AbsoluteVirtualAddress(entry.address)
yield Characteristic("forwarded export"), AbsoluteVirtualAddress(entry.address)
else:
forwarded_name = capa.features.extractors.helpers.reformat_forwarded_export_name(forwarded_name)
yield Export(forwarded_name), AbsoluteVirtualAddress(ea)
yield Characteristic("forwarded export"), AbsoluteVirtualAddress(ea)
yield Export(entry.name), AbsoluteVirtualAddress(entry.address)
def extract_file_import_names() -> Iterator[tuple[Feature, Address]]:
def extract_file_import_names(db: Database) -> Iterator[tuple[Feature, Address]]:
"""extract function imports
1. imports by ordinal:
@@ -113,7 +113,7 @@ def extract_file_import_names() -> Iterator[tuple[Feature, Address]]:
- modulename.importname
- importname
"""
for ea, info in capa.features.extractors.ida.helpers.get_file_imports().items():
for ea, info in capa.features.extractors.ida.helpers.get_file_imports(db).items():
addr = AbsoluteVirtualAddress(ea)
if info[1] and info[2]:
# e.g. in mimikatz: ('cabinet', 'FCIAddFile', 11L)
@@ -134,30 +134,31 @@ def extract_file_import_names() -> Iterator[tuple[Feature, Address]]:
for name in capa.features.extractors.helpers.generate_symbols(dll, symbol, include_dll=True):
yield Import(name), addr
for ea, info in capa.features.extractors.ida.helpers.get_file_externs().items():
for ea, info in capa.features.extractors.ida.helpers.get_file_externs(db).items():
yield Import(info[1]), AbsoluteVirtualAddress(ea)
def extract_file_section_names() -> Iterator[tuple[Feature, Address]]:
def extract_file_section_names(db: Database) -> Iterator[tuple[Feature, Address]]:
"""extract section names
IDA must load resource sections for this to be complete
- '-R' from console
- Check 'Load resource sections' when opening binary in IDA manually
"""
for seg in capa.features.extractors.ida.helpers.get_segments(skip_header_segments=True):
yield Section(idaapi.get_segm_name(seg)), AbsoluteVirtualAddress(seg.start_ea)
for seg in capa.features.extractors.ida.helpers.get_segments(db, skip_header_segments=True):
name = db.segments.get_name(seg)
yield Section(name), AbsoluteVirtualAddress(seg.start_ea)
def extract_file_strings() -> Iterator[tuple[Feature, Address]]:
def extract_file_strings(db: Database) -> Iterator[tuple[Feature, Address]]:
"""extract ASCII and UTF-16 LE strings
IDA must load resource sections for this to be complete
- '-R' from console
- Check 'Load resource sections' when opening binary in IDA manually
"""
for seg in capa.features.extractors.ida.helpers.get_segments():
seg_buff = capa.features.extractors.ida.helpers.get_segment_buffer(seg)
for seg in capa.features.extractors.ida.helpers.get_segments(db):
seg_buff = capa.features.extractors.ida.helpers.get_segment_buffer(db, seg)
# differing to common string extractor factor in segment offset here
for s in capa.features.extractors.strings.extract_ascii_strings(seg_buff):
@@ -167,41 +168,40 @@ def extract_file_strings() -> Iterator[tuple[Feature, Address]]:
yield String(s.s), FileOffsetAddress(seg.start_ea + s.offset)
def extract_file_function_names() -> Iterator[tuple[Feature, Address]]:
"""
extract the names of statically-linked library functions.
"""
for ea in idautils.Functions():
addr = AbsoluteVirtualAddress(ea)
if idaapi.get_func(ea).flags & idaapi.FUNC_LIB:
name = idaapi.get_name(ea)
yield FunctionName(name), addr
if name.startswith("_"):
# some linkers may prefix linked routines with a `_` to avoid name collisions.
# extract features for both the mangled and un-mangled representations.
# e.g. `_fwrite` -> `fwrite`
# see: https://stackoverflow.com/a/2628384/87207
yield FunctionName(name[1:]), addr
def extract_file_function_names(db: Database) -> Iterator[tuple[Feature, Address]]:
"""extract the names of statically-linked library functions."""
for f in db.functions.get_all():
flags = db.functions.get_flags(f)
if flags & FunctionFlags.LIB:
addr = AbsoluteVirtualAddress(f.start_ea)
name = db.names.get_at(f.start_ea)
if name:
yield FunctionName(name), addr
if name.startswith("_"):
# some linkers may prefix linked routines with a `_` to avoid name collisions.
# extract features for both the mangled and un-mangled representations.
# e.g. `_fwrite` -> `fwrite`
# see: https://stackoverflow.com/a/2628384/87207
yield FunctionName(name[1:]), addr
def extract_file_format() -> Iterator[tuple[Feature, Address]]:
filetype = capa.ida.helpers.get_filetype()
def extract_file_format(db: Database) -> Iterator[tuple[Feature, Address]]:
format_name = db.format
if filetype in (idaapi.f_PE, idaapi.f_COFF):
if "PE" in format_name or "COFF" in format_name:
yield Format(FORMAT_PE), NO_ADDRESS
elif filetype == idaapi.f_ELF:
elif "ELF" in format_name:
yield Format(FORMAT_ELF), NO_ADDRESS
elif filetype == idaapi.f_BIN:
# no file type to return when processing a binary file, but we want to continue processing
elif "Binary" in format_name:
return
else:
raise NotImplementedError(f"unexpected file format: {filetype}")
raise NotImplementedError(f"unexpected file format: {format_name}")
def extract_features() -> Iterator[tuple[Feature, Address]]:
def extract_features(db: Database) -> Iterator[tuple[Feature, Address]]:
"""extract file features"""
for file_handler in FILE_HANDLERS:
for feature, addr in file_handler():
for feature, addr in file_handler(db):
yield feature, addr
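Export enumeration moves from idautils.Entries()/ida_entry to db.entries; a sketch of the new iteration, assuming the entry attributes used above (has_forwarder, forwarder_name, name, address):

from ida_domain import Database

db = Database.open()
for entry in db.entries.get_all():
    if entry.has_forwarder():
        # forwarded export, e.g. kernel32!HeapAlloc -> ntdll.RtlAllocateHeap
        print(hex(entry.address), entry.name, "->", entry.forwarder_name)
    else:
        print(hex(entry.address), entry.name)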

View File

@@ -15,7 +15,7 @@
from typing import Iterator
import idaapi
import idautils
from ida_domain import Database
import capa.features.extractors.ida.helpers
from capa.features.file import FunctionName
@@ -25,19 +25,20 @@ from capa.features.extractors import loops
from capa.features.extractors.base_extractor import FunctionHandle
def extract_function_calls_to(fh: FunctionHandle):
def extract_function_calls_to(db: Database, fh: FunctionHandle):
"""extract callers to a function"""
for ea in idautils.CodeRefsTo(fh.inner.start_ea, True):
for ea in db.xrefs.code_refs_to_ea(fh.inner.start_ea):
yield Characteristic("calls to"), AbsoluteVirtualAddress(ea)
def extract_function_loop(fh: FunctionHandle):
def extract_function_loop(db: Database, fh: FunctionHandle):
"""extract loop indicators from a function"""
f: idaapi.func_t = fh.inner
edges = []
# construct control flow graph
for bb in idaapi.FlowChart(f):
flowchart = db.functions.get_flowchart(f)
for bb in flowchart:
for succ in bb.succs():
edges.append((bb.start_ea, succ.start_ea))
@@ -45,16 +46,16 @@ def extract_function_loop(fh: FunctionHandle):
yield Characteristic("loop"), fh.address
def extract_recursive_call(fh: FunctionHandle):
def extract_recursive_call(db: Database, fh: FunctionHandle):
"""extract recursive function call"""
if capa.features.extractors.ida.helpers.is_function_recursive(fh.inner):
if capa.features.extractors.ida.helpers.is_function_recursive(db, fh.inner):
yield Characteristic("recursive call"), fh.address
def extract_function_name(fh: FunctionHandle) -> Iterator[tuple[Feature, Address]]:
def extract_function_name(db: Database, fh: FunctionHandle) -> Iterator[tuple[Feature, Address]]:
ea = fh.inner.start_ea
name = idaapi.get_name(ea)
if name.startswith("sub_"):
name = db.names.get_at(ea)
if not name or name.startswith("sub_"):
# skip default names, like "sub_401000"
return
@@ -67,16 +68,15 @@ def extract_function_name(fh: FunctionHandle) -> Iterator[tuple[Feature, Address
yield FunctionName(name[1:]), fh.address
def extract_function_alternative_names(fh: FunctionHandle):
def extract_function_alternative_names(db: Database, fh: FunctionHandle):
"""Get all alternative names for an address."""
for aname in capa.features.extractors.ida.helpers.get_function_alternative_names(fh.inner.start_ea):
for aname in capa.features.extractors.ida.helpers.get_function_alternative_names(db, fh.inner.start_ea):
yield FunctionName(aname), fh.address
def extract_features(fh: FunctionHandle) -> Iterator[tuple[Feature, Address]]:
def extract_features(db: Database, fh: FunctionHandle) -> Iterator[tuple[Feature, Address]]:
for func_handler in FUNCTION_HANDLERS:
for feature, addr in func_handler(fh):
for feature, addr in func_handler(db, fh):
yield feature, addr

View File

@@ -16,7 +16,7 @@ import logging
import contextlib
from typing import Iterator
import ida_loader
from ida_domain import Database
import capa.ida.helpers
import capa.features.extractors.elf
@@ -26,8 +26,8 @@ from capa.features.address import NO_ADDRESS, Address
logger = logging.getLogger(__name__)
def extract_os() -> Iterator[tuple[Feature, Address]]:
format_name: str = ida_loader.get_file_type_name()
def extract_os(db: Database) -> Iterator[tuple[Feature, Address]]:
format_name: str = db.format
if "PE" in format_name:
yield OS(OS_WINDOWS), NO_ADDRESS
@@ -53,13 +53,14 @@ def extract_os() -> Iterator[tuple[Feature, Address]]:
return
def extract_arch() -> Iterator[tuple[Feature, Address]]:
procname = capa.ida.helpers.get_processor_name()
if procname == "metapc" and capa.ida.helpers.is_64bit():
def extract_arch(db: Database) -> Iterator[tuple[Feature, Address]]:
bitness = db.bitness
arch = db.architecture
if arch == "metapc" and bitness == 64:
yield Arch(ARCH_AMD64), NO_ADDRESS
elif procname == "metapc" and capa.ida.helpers.is_32bit():
elif arch == "metapc" and bitness == 32:
yield Arch(ARCH_I386), NO_ADDRESS
elif procname == "metapc":
elif arch == "metapc":
logger.debug("unsupported architecture: non-32-bit nor non-64-bit intel")
return
else:
@@ -67,5 +68,5 @@ def extract_arch() -> Iterator[tuple[Feature, Address]]:
# 1. handling a new architecture (e.g. aarch64)
#
# for (1), this logic will need to be updated as the format is implemented.
logger.debug("unsupported architecture: %s", procname)
logger.debug("unsupported architecture: %s", arch)
return
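Condensed, architecture detection now keys off db.architecture and db.bitness instead of processor-name helpers; a sketch assuming those properties behave as used above (plain strings stand in for capa's ARCH_* constants):

def detect_arch(db):
    if db.architecture != "metapc":
        return None
    if db.bitness == 64:
        return "amd64"
    if db.bitness == 32:
        return "i386"
    return None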

View File

@@ -22,109 +22,86 @@ import idautils
import ida_bytes
import ida_funcs
import ida_segment
from ida_domain import Database
from ida_domain.functions import FunctionFlags
from capa.features.address import AbsoluteVirtualAddress
from capa.features.extractors.base_extractor import FunctionHandle
IDA_NALT_ENCODING = ida_nalt.get_default_encoding_idx(ida_nalt.BPU_1B) # use one byte-per-character encoding
def find_byte_sequence(db: Database, start: int, end: int, seq: bytes) -> Iterator[int]:
"""yield all ea of a given byte sequence
if hasattr(ida_bytes, "parse_binpat_str"):
# TODO (mr): use find_bytes
# https://github.com/mandiant/capa/issues/2339
def find_byte_sequence(start: int, end: int, seq: bytes) -> Iterator[int]:
"""yield all ea of a given byte sequence
args:
start: min virtual address
end: max virtual address
seq: bytes to search e.g. b"\x01\x03"
"""
patterns = ida_bytes.compiled_binpat_vec_t()
seqstr = " ".join([f"{b:02x}" for b in seq])
err = ida_bytes.parse_binpat_str(patterns, 0, seqstr, 16, IDA_NALT_ENCODING)
if err:
return
while True:
ea = ida_bytes.bin_search(start, end, patterns, ida_bytes.BIN_SEARCH_FORWARD)
if isinstance(ea, int):
# "ea_t" in IDA 8.4, 8.3
pass
elif isinstance(ea, tuple):
# "drc_t" in IDA 9
ea = ea[0]
else:
raise NotImplementedError(f"bin_search returned unhandled type: {type(ea)}")
if ea == idaapi.BADADDR:
break
start = ea + 1
yield ea
else:
# for IDA 7.5 and older; using deprecated find_binary instead of bin_search
def find_byte_sequence(start: int, end: int, seq: bytes) -> Iterator[int]:
"""yield all ea of a given byte sequence
args:
start: min virtual address
end: max virtual address
seq: bytes to search e.g. b"\x01\x03"
"""
seqstr = " ".join([f"{b:02x}" for b in seq])
while True:
ea = idaapi.find_binary(start, end, seqstr, 0, idaapi.SEARCH_DOWN)
if ea == idaapi.BADADDR:
break
start = ea + 1
yield ea
args:
db: IDA Domain Database handle
start: min virtual address
end: max virtual address
seq: bytes to search e.g. b"\x01\x03"
"""
for match in db.bytes.find_binary_sequence(seq, start, end):
yield match
def get_functions(
start: Optional[int] = None, end: Optional[int] = None, skip_thunks: bool = False, skip_libs: bool = False
db: Database,
start: Optional[int] = None,
end: Optional[int] = None,
skip_thunks: bool = False,
skip_libs: bool = False,
) -> Iterator[FunctionHandle]:
"""get functions, range optional
args:
db: IDA Domain Database handle
start: min virtual address
end: max virtual address
skip_thunks: skip thunk functions
skip_libs: skip library functions
"""
for ea in idautils.Functions(start=start, end=end):
f = idaapi.get_func(ea)
if not (skip_thunks and (f.flags & idaapi.FUNC_THUNK) or skip_libs and (f.flags & idaapi.FUNC_LIB)):
yield FunctionHandle(address=AbsoluteVirtualAddress(ea), inner=f)
if start is not None and end is not None:
funcs = db.functions.get_between(start, end)
else:
funcs = db.functions.get_all()
for f in funcs:
flags = db.functions.get_flags(f)
if skip_thunks and (flags & FunctionFlags.THUNK):
continue
if skip_libs and (flags & FunctionFlags.LIB):
continue
yield FunctionHandle(address=AbsoluteVirtualAddress(f.start_ea), inner=f)
def get_segments(skip_header_segments=False) -> Iterator[idaapi.segment_t]:
def get_segments(db: Database, skip_header_segments: bool = False):
"""get list of segments (sections) in the binary image
args:
db: IDA Domain Database handle
skip_header_segments: IDA may load header segments - skip if set
"""
for n in range(idaapi.get_segm_qty()):
seg = idaapi.getnseg(n)
if seg and not (skip_header_segments and seg.is_header_segm()):
yield seg
for seg in db.segments.get_all():
if skip_header_segments and seg.is_header_segm():
continue
yield seg
def get_segment_buffer(seg: idaapi.segment_t) -> bytes:
def get_segment_buffer(db: Database, seg) -> bytes:
"""return bytes stored in a given segment
decrease buffer size until IDA is able to read bytes from the segment
args:
db: IDA Domain Database handle
seg: segment object
"""
buff = b""
sz = seg.end_ea - seg.start_ea
# decrease buffer size until IDA is able to read bytes from the segment
while sz > 0:
buff = idaapi.get_bytes(seg.start_ea, sz)
buff = db.bytes.get_bytes_at(seg.start_ea, sz)
if buff:
break
return buff
sz -= 0x1000
# IDA returns None if get_bytes fails, so convert for consistent return type
return buff if buff else b""
return b""
def inspect_import(imports, library, ea, function, ordinal):
@@ -140,8 +117,14 @@ def inspect_import(imports, library, ea, function, ordinal):
return True
def get_file_imports() -> dict[int, tuple[str, str, int]]:
"""get file imports"""
def get_file_imports(db: Database) -> dict[int, tuple[str, str, int]]:
"""get file imports
Note: import enumeration has no Domain API equivalent, using SDK fallback.
args:
db: IDA Domain Database handle (unused, kept for API consistency)
"""
imports: dict[int, tuple[str, str, int]] = {}
for idx in range(idaapi.get_import_module_qty()):
@@ -163,28 +146,35 @@ def get_file_imports() -> dict[int, tuple[str, str, int]]:
return imports
def get_file_externs() -> dict[int, tuple[str, str, int]]:
def get_file_externs(db: Database) -> dict[int, tuple[str, str, int]]:
"""get extern functions
args:
db: IDA Domain Database handle
"""
externs = {}
for seg in get_segments(skip_header_segments=True):
for seg in get_segments(db, skip_header_segments=True):
if seg.type != ida_segment.SEG_XTRN:
continue
for ea in idautils.Functions(seg.start_ea, seg.end_ea):
externs[ea] = ("", idaapi.get_func_name(ea), -1)
for f in db.functions.get_between(seg.start_ea, seg.end_ea):
name = db.functions.get_name(f)
externs[f.start_ea] = ("", name, -1)
return externs
def get_instructions_in_range(start: int, end: int) -> Iterator[idaapi.insn_t]:
def get_instructions_in_range(db: Database, start: int, end: int) -> Iterator[idaapi.insn_t]:
"""yield instructions in range
args:
db: IDA Domain Database handle
start: virtual address (inclusive)
end: virtual address (exclusive)
"""
for head in idautils.Heads(start, end):
insn = idautils.DecodeInstruction(head)
for head in db.heads.get_between(start, end):
insn = db.instructions.get_at(head)
if insn:
yield insn
@@ -234,21 +224,38 @@ def basic_block_size(bb: idaapi.BasicBlock) -> int:
return bb.end_ea - bb.start_ea
def read_bytes_at(ea: int, count: int) -> bytes:
""" """
# check if byte has a value, see get_wide_byte doc
if not idc.is_loaded(ea):
def read_bytes_at(db: Database, ea: int, count: int) -> bytes:
"""read bytes at address
args:
db: IDA Domain Database handle
ea: effective address
count: number of bytes to read
"""
if not db.bytes.is_value_initialized_at(ea):
return b""
segm_end = idc.get_segm_end(ea)
if ea + count > segm_end:
return idc.get_bytes(ea, segm_end - ea)
seg = db.segments.get_at(ea)
if seg is None:
return b""
if ea + count > seg.end_ea:
return db.bytes.get_bytes_at(ea, seg.end_ea - ea) or b""
else:
return idc.get_bytes(ea, count)
return db.bytes.get_bytes_at(ea, count) or b""
def find_string_at(ea: int, min_: int = 4) -> str:
"""check if ASCII string exists at a given virtual address"""
def find_string_at(db: Database, ea: int, min_: int = 4) -> str:
"""check if string exists at a given virtual address
Note: Uses SDK fallback as Domain API get_string_at only works for
addresses where IDA has already identified a string.
args:
db: IDA Domain Database handle (unused, kept for API consistency)
ea: effective address
min_: minimum string length
"""
found = idaapi.get_strlit_contents(ea, -1, idaapi.STRTYPE_C)
if found and len(found) >= min_:
try:
@@ -375,31 +382,51 @@ def mask_op_val(op: idaapi.op_t) -> int:
return masks.get(op.dtype, op.value) & op.value
def is_function_recursive(f: idaapi.func_t) -> bool:
"""check if function is recursive"""
return any(f.contains(ref) for ref in idautils.CodeRefsTo(f.start_ea, True))
def is_function_recursive(db: Database, f: idaapi.func_t) -> bool:
"""check if function is recursive
args:
db: IDA Domain Database handle
f: function object
"""
for ref in db.xrefs.code_refs_to_ea(f.start_ea):
if f.contains(ref):
return True
return False
def is_basic_block_tight_loop(bb: idaapi.BasicBlock) -> bool:
def is_basic_block_tight_loop(db: Database, bb: idaapi.BasicBlock) -> bool:
"""check basic block loops to self
args:
db: IDA Domain Database handle
bb: basic block object
true if last instruction in basic block branches to basic block start
"""
bb_end = idc.prev_head(bb.end_ea)
bb_end = db.heads.get_previous(bb.end_ea)
if bb_end is None:
return False
if bb.start_ea < bb_end:
for ref in idautils.CodeRefsFrom(bb_end, True):
for ref in db.xrefs.code_refs_from_ea(bb_end):
if ref == bb.start_ea:
return True
return False
def find_data_reference_from_insn(insn: idaapi.insn_t, max_depth: int = 10) -> int:
"""search for data reference from instruction, return address of instruction if no reference exists"""
def find_data_reference_from_insn(db: Database, insn: idaapi.insn_t, max_depth: int = 10) -> int:
"""search for data reference from instruction, return address of instruction if no reference exists
args:
db: IDA Domain Database handle
insn: instruction object
max_depth: maximum depth to follow references
"""
depth = 0
ea = insn.ea
while True:
data_refs = list(idautils.DataRefsFrom(ea))
data_refs = list(db.xrefs.data_refs_from_ea(ea))
if len(data_refs) != 1:
# break if no refs or more than one ref (assume nested pointers only have one data reference)
@@ -409,7 +436,7 @@ def find_data_reference_from_insn(insn: idaapi.insn_t, max_depth: int = 10) -> i
# break if circular reference
break
if not idaapi.is_mapped(data_refs[0]):
if not db.is_valid_ea(data_refs[0]):
# break if address is not mapped
break
@@ -423,10 +450,16 @@ def find_data_reference_from_insn(insn: idaapi.insn_t, max_depth: int = 10) -> i
return ea
def get_function_blocks(f: idaapi.func_t) -> Iterator[idaapi.BasicBlock]:
"""yield basic blocks contained in specified function"""
def get_function_blocks(db: Database, f: idaapi.func_t) -> Iterator[idaapi.BasicBlock]:
"""yield basic blocks contained in specified function
args:
db: IDA Domain Database handle
f: function object
"""
# leverage idaapi.FC_NOEXT flag to ignore useless external blocks referenced by the function
yield from idaapi.FlowChart(f, flags=(idaapi.FC_PREDS | idaapi.FC_NOEXT))
flowchart = db.functions.get_flowchart(f, flags=(idaapi.FC_PREDS | idaapi.FC_NOEXT))
yield from flowchart
def is_basic_block_return(bb: idaapi.BasicBlock) -> bool:
@@ -446,7 +479,17 @@ def find_alternative_names(cmt: str):
yield name
def get_function_alternative_names(fva: int):
"""Get all alternative names for an address."""
yield from find_alternative_names(ida_bytes.get_cmt(fva, False) or "")
yield from find_alternative_names(ida_funcs.get_func_cmt(idaapi.get_func(fva), False) or "")
def get_function_alternative_names(db: Database, fva: int):
"""Get all alternative names for an address.
args:
db: IDA Domain Database handle
fva: function virtual address
"""
cmt_info = db.comments.get_at(fva)
cmt = cmt_info.comment if cmt_info else ""
yield from find_alternative_names(cmt)
f = db.functions.get_at(fva)
if f:
func_cmt = db.functions.get_comment(f, False)
yield from find_alternative_names(func_cmt or "")
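A short usage sketch of the reworked helpers, assuming the signatures above; the b"MZ" pattern is only an illustrative search input:

from ida_domain import Database
import capa.features.extractors.ida.helpers as ida_helpers

db = Database.open()
for fh in ida_helpers.get_functions(db, skip_thunks=True, skip_libs=True):
    print("function:", hex(fh.inner.start_ea))

for seg in ida_helpers.get_segments(db, skip_header_segments=True):
    for ea in ida_helpers.find_byte_sequence(db, seg.start_ea, seg.end_ea, b"MZ"):
        print("MZ at", hex(ea))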

View File

@@ -18,7 +18,8 @@ from typing import Any, Iterator, Optional
import idc
import ida_ua
import idaapi
import idautils
from ida_domain import Database
from ida_domain.functions import FunctionFlags
import capa.features.extractors.helpers
import capa.features.extractors.ida.helpers
@@ -33,19 +34,19 @@ from capa.features.extractors.base_extractor import BBHandle, InsnHandle, Functi
SECURITY_COOKIE_BYTES_DELTA = 0x40
def get_imports(ctx: dict[str, Any]) -> dict[int, Any]:
def get_imports(db: Database, ctx: dict[str, Any]) -> dict[int, Any]:
if "imports_cache" not in ctx:
ctx["imports_cache"] = capa.features.extractors.ida.helpers.get_file_imports()
ctx["imports_cache"] = capa.features.extractors.ida.helpers.get_file_imports(db)
return ctx["imports_cache"]
def get_externs(ctx: dict[str, Any]) -> dict[int, Any]:
def get_externs(db: Database, ctx: dict[str, Any]) -> dict[int, Any]:
if "externs_cache" not in ctx:
ctx["externs_cache"] = capa.features.extractors.ida.helpers.get_file_externs()
ctx["externs_cache"] = capa.features.extractors.ida.helpers.get_file_externs(db)
return ctx["externs_cache"]
def check_for_api_call(insn: idaapi.insn_t, funcs: dict[int, Any]) -> Optional[tuple[str, str]]:
def check_for_api_call(db: Database, insn: idaapi.insn_t, funcs: dict[int, Any]) -> Optional[tuple[str, str]]:
"""check instruction for API call"""
info = None
ref = insn.ea
@@ -53,27 +54,32 @@ def check_for_api_call(insn: idaapi.insn_t, funcs: dict[int, Any]) -> Optional[t
# attempt to resolve API calls by following chained thunks to a reasonable depth
for _ in range(THUNK_CHAIN_DEPTH_DELTA):
# assume only one code/data ref when resolving "call" or "jmp"
try:
ref = tuple(idautils.CodeRefsFrom(ref, False))[0]
except IndexError:
try:
# thunks may be marked as data refs
ref = tuple(idautils.DataRefsFrom(ref))[0]
except IndexError:
code_refs = list(db.xrefs.code_refs_from_ea(ref, flow=False))
if code_refs:
ref = code_refs[0]
else:
# thunks may be marked as data refs
data_refs = list(db.xrefs.data_refs_from_ea(ref))
if data_refs:
ref = data_refs[0]
else:
break
info = funcs.get(ref)
if info:
break
f = idaapi.get_func(ref)
if not f or not (f.flags & idaapi.FUNC_THUNK):
f = db.functions.get_at(ref)
if f is None:
break
flags = db.functions.get_flags(f)
if not (flags & FunctionFlags.THUNK):
break
return info
def extract_insn_api_features(fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle) -> Iterator[tuple[Feature, Address]]:
def extract_insn_api_features(db: Database, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle) -> Iterator[tuple[Feature, Address]]:
"""
parse instruction API features
@@ -82,35 +88,30 @@ def extract_insn_api_features(fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle)
"""
insn: idaapi.insn_t = ih.inner
if insn.get_canon_mnem() not in ("call", "jmp"):
mnem = db.instructions.get_mnemonic(insn)
if mnem not in ("call", "jmp"):
return
# check call to imported functions
api = check_for_api_call(insn, get_imports(fh.ctx))
api = check_for_api_call(db, insn, get_imports(db, fh.ctx))
if api:
# tuple (<module>, <function>, <ordinal>)
for name in capa.features.extractors.helpers.generate_symbols(api[0], api[1]):
yield API(name), ih.address
# a call instruction should only call one function, stop if a call to an import is extracted
return
# check call to extern functions
api = check_for_api_call(insn, get_externs(fh.ctx))
api = check_for_api_call(db, insn, get_externs(db, fh.ctx))
if api:
# tuple (<module>, <function>, <ordinal>)
yield API(api[1]), ih.address
# a call instruction should only call one function, stop if a call to an extern is extracted
return
# extract dynamically resolved APIs stored in renamed globals (renamed for example using `renimp.idc`)
# examples: `CreateProcessA`, `HttpSendRequestA`
if insn.Op1.type == ida_ua.o_mem:
op_addr = insn.Op1.addr
op_name = idaapi.get_name(op_addr)
op_name = db.names.get_at(op_addr)
# when renaming a global using an API name, IDA assigns it the function type
# ensure we do not extract something wrong by checking that the address has a name and a type
# we could check that the type is a function definition, but that complicates the code
if (not op_name.startswith("off_")) and idc.get_type(op_addr):
if op_name and (not op_name.startswith("off_")) and idc.get_type(op_addr):
# Remove suffix used in repeated names, for example _0 in VirtualFree_0
match = re.match(r"(.+)_\d+", op_name)
if match:
@@ -119,19 +120,21 @@ def extract_insn_api_features(fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle)
for name in capa.features.extractors.helpers.generate_symbols("", op_name):
yield API(name), ih.address
# extract IDA/FLIRT recognized API functions
targets = tuple(idautils.CodeRefsFrom(insn.ea, False))
targets = list(db.xrefs.code_refs_from_ea(insn.ea, flow=False))
if not targets:
return
target = targets[0]
target_func = idaapi.get_func(target)
target_func = db.functions.get_at(target)
if not target_func or target_func.start_ea != target:
# not a function (start)
return
name = idaapi.get_name(target_func.start_ea)
if target_func.flags & idaapi.FUNC_LIB or not name.startswith("sub_"):
name = db.names.get_at(target_func.start_ea)
if not name:
return
flags = db.functions.get_flags(target_func)
if flags & FunctionFlags.LIB or not name.startswith("sub_"):
yield API(name), ih.address
if name.startswith("_"):
# some linkers may prefix linked routines with a `_` to avoid name collisions.
@@ -140,13 +143,13 @@ def extract_insn_api_features(fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle)
# see: https://stackoverflow.com/a/2628384/87207
yield API(name[1:]), ih.address
for altname in capa.features.extractors.ida.helpers.get_function_alternative_names(target_func.start_ea):
for altname in capa.features.extractors.ida.helpers.get_function_alternative_names(db, target_func.start_ea):
yield FunctionName(altname), ih.address
yield API(altname), ih.address
def extract_insn_number_features(
fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
db: Database, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
) -> Iterator[tuple[Feature, Address]]:
"""
parse instruction number features
@@ -155,7 +158,7 @@ def extract_insn_number_features(
"""
insn: idaapi.insn_t = ih.inner
if idaapi.is_ret_insn(insn):
if db.instructions.breaks_sequential_flow(insn):
# skip things like:
# .text:0042250E retn 8
return
@@ -183,7 +186,8 @@ def extract_insn_number_features(
yield Number(const), ih.address
yield OperandNumber(i, const), ih.address
if insn.itype == idaapi.NN_add and 0 < const < MAX_STRUCTURE_SIZE and op.type == idaapi.o_imm:
mnem = db.instructions.get_mnemonic(insn)
if mnem == "add" and 0 < const < MAX_STRUCTURE_SIZE and op.type == idaapi.o_imm:
# for pattern like:
#
# add eax, 0x10
@@ -193,7 +197,7 @@ def extract_insn_number_features(
yield OperandOffset(i, const), ih.address
def extract_insn_bytes_features(fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle) -> Iterator[tuple[Feature, Address]]:
def extract_insn_bytes_features(db: Database, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle) -> Iterator[tuple[Feature, Address]]:
"""
parse referenced byte sequences
example:
@@ -201,20 +205,20 @@ def extract_insn_bytes_features(fh: FunctionHandle, bbh: BBHandle, ih: InsnHandl
"""
insn: idaapi.insn_t = ih.inner
if idaapi.is_call_insn(insn):
if db.instructions.is_call_instruction(insn):
return
ref = capa.features.extractors.ida.helpers.find_data_reference_from_insn(insn)
ref = capa.features.extractors.ida.helpers.find_data_reference_from_insn(db, insn)
if ref != insn.ea:
extracted_bytes = capa.features.extractors.ida.helpers.read_bytes_at(ref, MAX_BYTES_FEATURE_SIZE)
extracted_bytes = capa.features.extractors.ida.helpers.read_bytes_at(db, ref, MAX_BYTES_FEATURE_SIZE)
if extracted_bytes and not capa.features.extractors.helpers.all_zeros(extracted_bytes):
if not capa.features.extractors.ida.helpers.find_string_at(ref):
if not capa.features.extractors.ida.helpers.find_string_at(db, ref):
# don't extract byte features for obvious strings
yield Bytes(extracted_bytes), ih.address
def extract_insn_string_features(
fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
db: Database, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
) -> Iterator[tuple[Feature, Address]]:
"""
parse instruction string features
@@ -224,15 +228,15 @@ def extract_insn_string_features(
"""
insn: idaapi.insn_t = ih.inner
ref = capa.features.extractors.ida.helpers.find_data_reference_from_insn(insn)
ref = capa.features.extractors.ida.helpers.find_data_reference_from_insn(db, insn)
if ref != insn.ea:
found = capa.features.extractors.ida.helpers.find_string_at(ref)
found = capa.features.extractors.ida.helpers.find_string_at(db, ref)
if found:
yield String(found), ih.address
def extract_insn_offset_features(
fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
db: Database, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
) -> Iterator[tuple[Feature, Address]]:
"""
parse instruction structure offset features
@@ -256,7 +260,7 @@ def extract_insn_offset_features(
if op_off is None:
continue
if idaapi.is_mapped(op_off):
if db.is_valid_ea(op_off):
# Ignore:
# mov esi, dword_1005B148[esi]
continue
@@ -269,8 +273,9 @@ def extract_insn_offset_features(
yield Offset(op_off), ih.address
yield OperandOffset(i, op_off), ih.address
mnem = db.instructions.get_mnemonic(insn)
if (
insn.itype == idaapi.NN_lea
mnem == "lea"
and i == 1
# o_displ is used for both:
# [eax+1]
@@ -305,7 +310,7 @@ def contains_stack_cookie_keywords(s: str) -> bool:
return any(keyword in s for keyword in ("stack", "security"))
def bb_stack_cookie_registers(bb: idaapi.BasicBlock) -> Iterator[int]:
def bb_stack_cookie_registers(db: Database, bb: idaapi.BasicBlock) -> Iterator[int]:
"""scan basic block for stack cookie operations
yield registers ids that may have been used for stack cookie operations
@@ -331,21 +336,22 @@ def bb_stack_cookie_registers(bb: idaapi.BasicBlock) -> Iterator[int]:
TODO: this is expensive, but necessary?...
"""
for insn in capa.features.extractors.ida.helpers.get_instructions_in_range(bb.start_ea, bb.end_ea):
if contains_stack_cookie_keywords(idc.GetDisasm(insn.ea)):
for insn in capa.features.extractors.ida.helpers.get_instructions_in_range(db, bb.start_ea, bb.end_ea):
disasm = db.instructions.get_disassembly(insn)
if contains_stack_cookie_keywords(disasm):
for op in capa.features.extractors.ida.helpers.get_insn_ops(insn, target_ops=(idaapi.o_reg,)):
if capa.features.extractors.ida.helpers.is_op_write(insn, op):
# only include modified registers
yield op.reg
def is_nzxor_stack_cookie_delta(f: idaapi.func_t, bb: idaapi.BasicBlock, insn: idaapi.insn_t) -> bool:
def is_nzxor_stack_cookie_delta(db: Database, f: idaapi.func_t, bb: idaapi.BasicBlock, insn: idaapi.insn_t) -> bool:
"""check if nzxor exists within stack cookie delta"""
# security cookie check should use SP or BP
if not capa.features.extractors.ida.helpers.is_frame_register(insn.Op2.reg):
return False
f_bbs = tuple(capa.features.extractors.ida.helpers.get_function_blocks(f))
f_bbs = tuple(capa.features.extractors.ida.helpers.get_function_blocks(db, f))
# expect security cookie init in first basic block within first bytes (instructions)
if capa.features.extractors.ida.helpers.is_basic_block_equal(bb, f_bbs[0]) and insn.ea < (
@@ -362,15 +368,17 @@ def is_nzxor_stack_cookie_delta(f: idaapi.func_t, bb: idaapi.BasicBlock, insn: i
return False
def is_nzxor_stack_cookie(f: idaapi.func_t, bb: idaapi.BasicBlock, insn: idaapi.insn_t) -> bool:
def is_nzxor_stack_cookie(db: Database, f: idaapi.func_t, bb: idaapi.BasicBlock, insn: idaapi.insn_t) -> bool:
"""check if nzxor is related to stack cookie"""
if contains_stack_cookie_keywords(idaapi.get_cmt(insn.ea, False)):
cmt_info = db.comments.get_at(insn.ea)
cmt = cmt_info.comment if cmt_info else ""
if contains_stack_cookie_keywords(cmt):
# Example:
# xor ecx, ebp ; StackCookie
return True
if is_nzxor_stack_cookie_delta(f, bb, insn):
if is_nzxor_stack_cookie_delta(db, f, bb, insn):
return True
stack_cookie_regs = tuple(bb_stack_cookie_registers(bb))
stack_cookie_regs = tuple(bb_stack_cookie_registers(db, bb))
if any(op_reg in stack_cookie_regs for op_reg in (insn.Op1.reg, insn.Op2.reg)):
# Example:
# mov eax, ___security_cookie
@@ -380,7 +388,7 @@ def is_nzxor_stack_cookie(f: idaapi.func_t, bb: idaapi.BasicBlock, insn: idaapi.
def extract_insn_nzxor_characteristic_features(
fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
db: Database, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
) -> Iterator[tuple[Feature, Address]]:
"""
parse instruction non-zeroing XOR instruction
@@ -388,31 +396,33 @@ def extract_insn_nzxor_characteristic_features(
"""
insn: idaapi.insn_t = ih.inner
if insn.itype not in (idaapi.NN_xor, idaapi.NN_xorpd, idaapi.NN_xorps, idaapi.NN_pxor):
mnem = db.instructions.get_mnemonic(insn)
if mnem not in ("xor", "xorpd", "xorps", "pxor"):
return
if capa.features.extractors.ida.helpers.is_operand_equal(insn.Op1, insn.Op2):
return
if is_nzxor_stack_cookie(fh.inner, bbh.inner, insn):
if is_nzxor_stack_cookie(db, fh.inner, bbh.inner, insn):
return
yield Characteristic("nzxor"), ih.address
def extract_insn_mnemonic_features(
fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
db: Database, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
) -> Iterator[tuple[Feature, Address]]:
"""parse instruction mnemonic features"""
yield Mnemonic(idc.print_insn_mnem(ih.inner.ea)), ih.address
mnem = db.instructions.get_mnemonic(ih.inner)
yield Mnemonic(mnem), ih.address
def extract_insn_obfs_call_plus_5_characteristic_features(
fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
db: Database, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
) -> Iterator[tuple[Feature, Address]]:
"""
parse call $+5 instruction from the given instruction.
"""
insn: idaapi.insn_t = ih.inner
if not idaapi.is_call_insn(insn):
if not db.instructions.is_call_instruction(insn):
return
if insn.ea + 5 == idc.get_operand_value(insn.ea, 0):
@@ -420,7 +430,7 @@ def extract_insn_obfs_call_plus_5_characteristic_features(
def extract_insn_peb_access_characteristic_features(
fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
db: Database, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
) -> Iterator[tuple[Feature, Address]]:
"""parse instruction peb access
@@ -431,14 +441,15 @@ def extract_insn_peb_access_characteristic_features(
"""
insn: idaapi.insn_t = ih.inner
if insn.itype not in (idaapi.NN_push, idaapi.NN_mov):
mnem = db.instructions.get_mnemonic(insn)
if mnem not in ("push", "mov"):
return
if all(op.type != idaapi.o_mem for op in insn.ops):
# try to optimize for only memory references
return
disasm = idc.GetDisasm(insn.ea)
disasm = db.instructions.get_disassembly(insn)
if " fs:30h" in disasm or " gs:60h" in disasm:
# TODO(mike-hunhoff): use proper IDA API for fetching segment access
@@ -448,7 +459,7 @@ def extract_insn_peb_access_characteristic_features(
def extract_insn_segment_access_features(
fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
db: Database, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
) -> Iterator[tuple[Feature, Address]]:
"""parse instruction fs or gs access
@@ -461,7 +472,7 @@ def extract_insn_segment_access_features(
# try to optimize for only memory references
return
disasm = idc.GetDisasm(insn.ea)
disasm = db.instructions.get_disassembly(insn)
if " fs:" in disasm:
# TODO(mike-hunhoff): use proper IDA API for fetching segment access
@@ -477,37 +488,39 @@ def extract_insn_segment_access_features(
def extract_insn_cross_section_cflow(
fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
db: Database, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
) -> Iterator[tuple[Feature, Address]]:
"""inspect the instruction for a CALL or JMP that crosses section boundaries"""
insn: idaapi.insn_t = ih.inner
for ref in idautils.CodeRefsFrom(insn.ea, False):
if ref in get_imports(fh.ctx):
for ref in db.xrefs.code_refs_from_ea(insn.ea, flow=False):
if ref in get_imports(db, fh.ctx):
# ignore API calls
continue
if not idaapi.getseg(ref):
ref_seg = db.segments.get_at(ref)
if ref_seg is None:
# handle IDA API bug
continue
if idaapi.getseg(ref) == idaapi.getseg(insn.ea):
insn_seg = db.segments.get_at(insn.ea)
if ref_seg == insn_seg:
continue
yield Characteristic("cross section flow"), ih.address
def extract_function_calls_from(fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle) -> Iterator[tuple[Feature, Address]]:
def extract_function_calls_from(db: Database, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle) -> Iterator[tuple[Feature, Address]]:
"""extract functions calls from features
most relevant at the function scope, however, its most efficient to extract at the instruction scope
"""
insn: idaapi.insn_t = ih.inner
if idaapi.is_call_insn(insn):
for ref in idautils.CodeRefsFrom(insn.ea, False):
if db.instructions.is_call_instruction(insn):
for ref in db.xrefs.code_refs_from_ea(insn.ea, flow=False):
yield Characteristic("calls from"), AbsoluteVirtualAddress(ref)
def extract_function_indirect_call_characteristic_features(
fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
db: Database, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
) -> Iterator[tuple[Feature, Address]]:
"""extract indirect function calls (e.g., call eax or call dword ptr [edx+4])
does not include calls like => call ds:dword_ABD4974
@@ -517,14 +530,14 @@ def extract_function_indirect_call_characteristic_features(
"""
insn: idaapi.insn_t = ih.inner
if idaapi.is_call_insn(insn) and idc.get_operand_type(insn.ea, 0) in (idc.o_reg, idc.o_phrase, idc.o_displ):
if db.instructions.is_call_instruction(insn) and idc.get_operand_type(insn.ea, 0) in (idc.o_reg, idc.o_phrase, idc.o_displ):
yield Characteristic("indirect call"), ih.address
def extract_features(f: FunctionHandle, bbh: BBHandle, insn: InsnHandle) -> Iterator[tuple[Feature, Address]]:
def extract_features(db: Database, f: FunctionHandle, bbh: BBHandle, insn: InsnHandle) -> Iterator[tuple[Feature, Address]]:
"""extract instruction features"""
for inst_handler in INSTRUCTION_HANDLERS:
for feature, ea in inst_handler(f, bbh, insn):
for feature, ea in inst_handler(db, f, bbh, insn):
yield feature, ea
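Instruction-level queries now route through db.instructions rather than idc/idaapi; a rough sketch that walks functions, blocks, and instructions with the helpers shown above, under the same assumptions:

from ida_domain import Database
import capa.features.extractors.ida.helpers as ida_helpers

db = Database.open()
for fh in ida_helpers.get_functions(db, skip_thunks=True, skip_libs=True):
    for bb in ida_helpers.get_function_blocks(db, fh.inner):
        for insn in ida_helpers.get_instructions_in_range(db, bb.start_ea, bb.end_ea):
            if db.instructions.is_call_instruction(insn):
                print(hex(insn.ea), db.instructions.get_disassembly(insn))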

View File

@@ -14,6 +14,7 @@
import ida_kernwin
from ida_domain import Database
from capa.ida.plugin.error import UserCancelledError
from capa.ida.plugin.qt_compat import QtCore, Signal
@@ -43,7 +44,8 @@ class CapaExplorerFeatureExtractor(IdaFeatureExtractor):
"""
def __init__(self):
super().__init__()
db = Database.open()
super().__init__(db)
self.indicator = CapaExplorerProgressIndicator()
def extract_function_features(self, fh: FunctionHandle):

View File

@@ -357,7 +357,7 @@ def get_extractor(
ida_auto.auto_wait()
logger.debug("idalib: opened database.")
return capa.features.extractors.ida.extractor.IdaFeatureExtractor()
return capa.features.extractors.ida.extractor.IdaFeatureExtractor.from_current_database()
elif backend == BACKEND_GHIDRA:
import pyghidra

View File

@@ -1094,7 +1094,7 @@ def ida_main():
meta = capa.ida.helpers.collect_metadata([rules_path])
capabilities = find_capabilities(rules, capa.features.extractors.ida.extractor.IdaFeatureExtractor())
capabilities = find_capabilities(rules, capa.features.extractors.ida.extractor.IdaFeatureExtractor.from_current_database())
meta.analysis.feature_counts = capabilities.feature_counts
meta.analysis.library_functions = capabilities.library_functions

View File

@@ -275,7 +275,7 @@ def ida_main():
function = idc.get_func_attr(idc.here(), idc.FUNCATTR_START)
print(f"getting features for current function {hex(function)}")
extractor = capa.features.extractors.ida.extractor.IdaFeatureExtractor()
extractor = capa.features.extractors.ida.extractor.IdaFeatureExtractor.from_current_database()
if not function:
for feature, addr in extractor.extract_file_features():

View File

@@ -175,7 +175,7 @@ def ida_main():
function = idc.get_func_attr(idc.here(), idc.FUNCATTR_START)
print(f"getting features for current function {hex(function)}")
extractor = capa.features.extractors.ida.extractor.IdaFeatureExtractor()
extractor = capa.features.extractors.ida.extractor.IdaFeatureExtractor.from_current_database()
feature_map: Counter[Feature] = Counter()
feature_map.update([feature for feature, _ in extractor.extract_file_features()])

View File

@@ -241,7 +241,7 @@ def get_idalib_extractor(path: Path):
ida_auto.auto_wait()
logger.debug("idalib: opened database.")
extractor = capa.features.extractors.ida.extractor.IdaFeatureExtractor()
extractor = capa.features.extractors.ida.extractor.IdaFeatureExtractor.from_current_database()
fixup_idalib(path, extractor)
return extractor

View File

@@ -95,7 +95,7 @@ def get_ida_extractor(_path):
# have to import this inline so pytest doesn't bail outside of IDA
import capa.features.extractors.ida.extractor
return capa.features.extractors.ida.extractor.IdaFeatureExtractor()
return capa.features.extractors.ida.extractor.IdaFeatureExtractor.from_current_database()
def nocollect(f):