mirror of
https://github.com/mandiant/capa.git
synced 2025-12-12 15:49:46 -08:00
lints
@@ -177,7 +177,9 @@ def main(argv=None):
     for va in idautils.Functions():
         name = idaapi.get_func_name(va)
-        if name not in {"WinMain", }:
+        if name not in {
+            "WinMain",
+        }:
             continue
 
         function_classifications.append(
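Note: both forms of the set literal behave identically; the magic trailing comma just tells black to keep one entry per line so future names diff cleanly. A minimal sketch of the same filter, runnable only inside IDA Pro since idautils and idaapi are IDA's bundled modules:

    import idaapi
    import idautils

    # names of interest; the trailing comma keeps the formatter from
    # collapsing the set literal back onto one line
    INTERESTING = {
        "WinMain",
    }

    for va in idautils.Functions():
        name = idaapi.get_func_name(va)
        if name in INTERESTING:
            print(hex(va), name)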
@@ -3,17 +3,15 @@ import json
 import logging
 import sqlite3
 import argparse
-from typing import Optional
+from typing import Iterator, Optional
 from pathlib import Path
 from dataclasses import dataclass
 
 import rich
 import rich.table
 import pefile
 import lancelot
-import lancelot.be2utils
 import networkx as nx
-from lancelot.be2utils import BinExport2Index,ReadMemoryError, AddressSpace
+import lancelot.be2utils
+from lancelot.be2utils import AddressSpace, BinExport2Index, ReadMemoryError
 from lancelot.be2utils.binexport2_pb2 import BinExport2
 
 import capa.main
@@ -21,7 +19,6 @@ import capa.main
 logger = logging.getLogger(__name__)
 
 
-
 def is_vertex_type(vertex: BinExport2.CallGraph.Vertex, type_: BinExport2.CallGraph.Vertex.Type.ValueType) -> bool:
     return vertex.HasField("type") and vertex.type == type_
 
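Note: is_vertex_type leans on protobuf presence checking: HasField distinguishes an unset proto2 field from one that merely reads back a default value. A small sketch, assuming the binexport2_pb2 bindings shipped with lancelot are importable and that the vertex type enum includes IMPORTED:

    from lancelot.be2utils.binexport2_pb2 import BinExport2

    v = BinExport2.CallGraph.Vertex()
    # unset field: HasField is False even though v.type reads a default
    print(v.HasField("type"))  # False
    v.type = BinExport2.CallGraph.Vertex.IMPORTED
    print(v.HasField("type") and v.type == BinExport2.CallGraph.Vertex.IMPORTED)  # True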
@@ -144,7 +141,8 @@ class Assemblage:
 
         self.conn = sqlite3.connect(self.db)
         with self.conn:
-            self.conn.executescript("""
+            self.conn.executescript(
+                """
                 PRAGMA journal_mode = WAL;
                 PRAGMA synchronous = NORMAL;
                 PRAGMA busy_timeout = 5000;
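Note: executescript runs several statements in one call, which is why the PRAGMAs and DDL live in a single string. A self-contained sketch of the same pattern (an in-memory database here, so the WAL pragma is omitted since it only applies to file-backed databases):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    with conn:
        conn.executescript(
            """
            PRAGMA synchronous = NORMAL;
            PRAGMA busy_timeout = 5000;
            CREATE TABLE IF NOT EXISTS binaries (id INTEGER PRIMARY KEY, file_name TEXT);
            """
        )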
@@ -156,8 +154,8 @@ class Assemblage:
                 CREATE INDEX IF NOT EXISTS idx__functions__binary_id ON functions (binary_id);
                 CREATE INDEX IF NOT EXISTS idx__rvas__function_id ON rvas (function_id);
 
-            CREATE VIEW IF NOT EXISTS assemblage AS
-            SELECT
+                CREATE VIEW IF NOT EXISTS assemblage AS
+                SELECT
                     binaries.id AS binary_id,
                     binaries.file_name AS file_name,
                     binaries.platform AS platform,
@@ -183,19 +181,20 @@ class Assemblage:
                     rvas.id AS rva_id,
                     rvas.start AS start_rva,
                     rvas.end AS end_rva
-            FROM binaries
+                FROM binaries
                 JOIN functions ON binaries.id = functions.binary_id
                 JOIN rvas ON functions.id = rvas.function_id;
-            """)
+            """
+            )
 
     def get_row_by_binary_id(self, binary_id: int) -> AssemblageRow:
         with self.conn:
-            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ? LIMIT 1;", (binary_id, ))
+            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ? LIMIT 1;", (binary_id,))
             return AssemblageRow(*cur.fetchone())
 
-    def get_rows_by_binary_id(self, binary_id: int) -> AssemblageRow:
+    def get_rows_by_binary_id(self, binary_id: int) -> Iterator[AssemblageRow]:
         with self.conn:
-            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ?;", (binary_id, ))
+            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ?;", (binary_id,))
             row = cur.fetchone()
             while row:
                 yield AssemblageRow(*row)
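Note: the annotation fix matters because get_rows_by_binary_id contains yield, making it a generator; callers receive an iterator of rows, not a single row. A minimal sketch of the shape, with a hypothetical two-field row type standing in for AssemblageRow:

    from dataclasses import dataclass
    from typing import Iterator

    @dataclass
    class Row:  # hypothetical stand-in for AssemblageRow
        binary_id: int
        file_name: str

    def rows_for(records: list[tuple[int, str]], binary_id: int) -> Iterator[Row]:
        # yield makes this a generator, hence Iterator[Row], not Row
        for record in records:
            if record[0] == binary_id:
                yield Row(*record)

    for row in rows_for([(1, "a.exe"), (2, "b.exe"), (1, "c.exe")], 1):
        print(row)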
@@ -203,14 +202,13 @@ class Assemblage:
 
     def get_path_by_binary_id(self, binary_id: int) -> Path:
         with self.conn:
-            cur = self.conn.execute("""SELECT path FROM assemblage WHERE binary_id = ? LIMIT 1""", (binary_id, ))
+            cur = self.conn.execute("""SELECT path FROM assemblage WHERE binary_id = ? LIMIT 1""", (binary_id,))
             return self.samples / cur.fetchone()[0]
 
     def get_pe_by_binary_id(self, binary_id: int) -> pefile.PE:
         path = self.get_path_by_binary_id(binary_id)
         return pefile.PE(data=path.read_bytes(), fast_load=True)
 
 
-
 def generate_main(args: argparse.Namespace) -> int:
     if not args.assemblage_database.is_file():
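Note: fast_load=True tells pefile to skip parsing the data directories up front, which is much cheaper for large binaries; individual directories can be parsed on demand later. A sketch, with sample.exe as a placeholder path:

    from pathlib import Path

    import pefile

    data = Path("sample.exe").read_bytes()  # placeholder path
    pe = pefile.PE(data=data, fast_load=True)
    # parse just the import directory when it is actually needed
    pe.parse_data_directories(
        directories=[pefile.DIRECTORY_ENTRY["IMAGE_DIRECTORY_ENTRY_IMPORT"]]
    )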
@@ -240,11 +238,7 @@ def generate_main(args: argparse.Namespace) -> int:
 
     pe_path = db.get_path_by_binary_id(args.binary_id)
     be2: BinExport2 = lancelot.get_binexport2_from_bytes(
-        pe_path.read_bytes(),
-        function_hints=[
-            base_address + function.start_rva
-            for function in functions
-        ]
+        pe_path.read_bytes(), function_hints=[base_address + function.start_rva for function in functions]
    )
 
     idx = lancelot.be2utils.BinExport2Index(be2)
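Note: the collapsed comprehension is pure arithmetic: function starts are stored in the database as RVAs, and the BinExport2 hints want absolute virtual addresses, so each hint is image base plus RVA. Worked with illustrative numbers:

    base_address = 0x400000        # illustrative image base
    start_rvas = [0x1000, 0x1730]  # illustrative function start RVAs

    function_hints = [base_address + rva for rva in start_rvas]
    assert function_hints == [0x401000, 0x401730]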
@@ -253,7 +247,7 @@ def generate_main(args: argparse.Namespace) -> int:
 
     g = nx.MultiDiGraph()
 
-    for flow_graph_index, flow_graph in enumerate(be2.flow_graph):
+    for flow_graph in be2.flow_graph:
         datas: set[int] = set()
         callees: set[str] = set()
 
@@ -263,7 +257,7 @@ def generate_main(args: argparse.Namespace) -> int:
         for basic_block_index in flow_graph.basic_block_index:
             basic_block: BinExport2.BasicBlock = be2.basic_block[basic_block_index]
 
-            for instruction_index, instruction, instruction_address in idx.basic_block_instructions(basic_block):
+            for instruction_index, instruction, _ in idx.basic_block_instructions(basic_block):
                 for addr in instruction.call_target:
                     addr = thunks.get(addr, addr)
 
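Note: the thunks.get(addr, addr) idiom resolves a call through a thunk to its real target while leaving non-thunk addresses untouched, since dict.get falls back to the supplied default:

    thunks = {0x401000: 0x402000}  # illustrative thunk -> target map

    for addr in (0x401000, 0x403000):
        resolved = thunks.get(addr, addr)  # identity when not a thunk
        print(hex(addr), "->", hex(resolved))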
@@ -277,7 +271,9 @@ def generate_main(args: argparse.Namespace) -> int:
 
                     callees.add(vertex.address)
 
-                for data_reference_index in idx.data_reference_index_by_source_instruction_index.get(instruction_index, []):
+                for data_reference_index in idx.data_reference_index_by_source_instruction_index.get(
+                    instruction_index, []
+                ):
                     data_reference: BinExport2.DataReference = be2.data_reference[data_reference_index]
                     data_reference_address: int = data_reference.address
 
@@ -336,12 +332,15 @@ def generate_main(args: argparse.Namespace) -> int:
         # within each section, emit a neighbor edge for each pair of neighbors.
 
         section_nodes = [
-            node for node, attrs in g.nodes(data=True)
-            if (section.VirtualAddress + base_address) <= attrs["address"] < (base_address + section.VirtualAddress + section.Misc_VirtualSize)
+            node
+            for node, attrs in g.nodes(data=True)
+            if (section.VirtualAddress + base_address)
+            <= attrs["address"]
+            < (base_address + section.VirtualAddress + section.Misc_VirtualSize)
         ]
 
         for i in range(1, len(section_nodes)):
-            a = section_nodes[i-1]
+            a = section_nodes[i - 1]
             b = section_nodes[i]
 
             g.add_edge(
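Note: the reflowed chained comparison is a half-open bounds check: a node belongs to a section when its address falls in [base + VirtualAddress, base + VirtualAddress + Misc_VirtualSize). Isolated as a helper with illustrative values:

    def in_section(va: int, base: int, section_va: int, section_vsize: int) -> bool:
        start = base + section_va
        # half-open interval, matching the chained comparison in the diff
        return start <= va < start + section_vsize

    assert in_section(0x401010, 0x400000, 0x1000, 0x200)
    assert not in_section(0x401200, 0x400000, 0x1000, 0x200)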
@@ -353,8 +352,8 @@ def generate_main(args: argparse.Namespace) -> int:
             )
 
     for function in functions:
-        g.nodes[base_address+function.start_rva]["name"] = function.name
-        g.nodes[base_address+function.start_rva]["file"] = function.file
+        g.nodes[base_address + function.start_rva]["name"] = function.name
+        g.nodes[base_address + function.start_rva]["file"] = function.file
 
     # rename unknown functions like: sub_401000
     for n, attrs in g.nodes(data=True):
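Note: g.nodes[n][key] = value is networkx's way of attaching metadata to an existing node, and exporters such as GEXF carry those attributes along. A tiny sketch with hypothetical values:

    import networkx as nx

    g = nx.MultiDiGraph()
    g.add_node(0x401000, address=0x401000)

    # attach metadata after the node exists
    g.nodes[0x401000]["name"] = "WinMain"   # hypothetical
    g.nodes[0x401000]["file"] = "main.cpp"  # hypothetical

    print(g.nodes[0x401000])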
@@ -373,7 +372,7 @@ def generate_main(args: argparse.Namespace) -> int:
                 attrs["repr"] = attrs["name"]
                 attrs["is_import"] = "!" in attrs["name"]
             case "data":
-                if (string := read_string(address_space, n)):
+                if string := read_string(address_space, n):
                     attrs["repr"] = json.dumps(string)
                     attrs["is_string"] = True
                 else:
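Note: the walrus operator already binds within the if, so the extra parentheses were redundant (the sort of thing linters flag). A sketch, with a simplified stand-in for the script's AddressSpace-based string reader:

    def read_string(blob: bytes, offset: int) -> str | None:
        # simplified stand-in: read a NUL-terminated ASCII string
        end = blob.find(b"\x00", offset)
        if end <= offset:
            return None
        return blob[offset:end].decode("ascii", errors="replace")

    if string := read_string(b"hello\x00world\x00", 0):
        print(string)  # hello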
@@ -384,6 +383,7 @@ def generate_main(args: argparse.Namespace) -> int:
         print(line)
 
     # db.conn.close()
     return 0
 
+
 def cluster_main(args: argparse.Namespace) -> int:
@@ -391,7 +391,7 @@ def cluster_main(args: argparse.Namespace) -> int:
         raise ValueError("graph file doesn't exist")
 
     g = nx.read_gexf(args.graph)
-
 
     communities = nx.algorithms.community.louvain_communities(g)
     for i, community in enumerate(communities):
         print(f"[{i}]:")
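Note: louvain_communities partitions the graph's nodes into communities by modularity maximization; each element of the returned list is a set of node ids. Runnable against networkx's built-in example graph:

    import networkx as nx

    g = nx.karate_club_graph()  # small built-in demo graph
    communities = nx.algorithms.community.louvain_communities(g, seed=1)
    for i, community in enumerate(communities):
        print(f"[{i}]: {sorted(community)}")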
@@ -401,6 +401,8 @@ def cluster_main(args: argparse.Namespace) -> int:
         else:
             print(f" - {hex(int(node, 0))}")
 
+    return 0
+
 
 def main(argv=None) -> int:
     if argv is None:
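Note: node ids round-trip through GEXF as strings, which is why the script parses them back with int(node, 0); base 0 honors the 0x prefix:

    node = "0x401000"  # GEXF hands node ids back as strings
    print(hex(int(node, 0)))  # base 0 auto-detects the 0x prefix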
@@ -416,7 +418,6 @@ def main(argv=None) -> int:
     generate_parser.add_argument("binary_id", type=int, help="primary key of binary to inspect")
     generate_parser.set_defaults(func=generate_main)
 
-
     cluster_parser = subparsers.add_parser("cluster", help="cluster an existing graph")
     cluster_parser.add_argument("graph", type=Path, help="path to a graph file")
     cluster_parser.set_defaults(func=cluster_main)
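Note: the set_defaults(func=...) pattern lets main dispatch by calling args.func(args) without a chain of if/elif on subcommand names. A self-contained sketch of the same wiring:

    import argparse
    from pathlib import Path

    def cluster_main(args: argparse.Namespace) -> int:
        print(f"clustering {args.graph}")
        return 0

    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(required=True)

    cluster_parser = subparsers.add_parser("cluster", help="cluster an existing graph")
    cluster_parser.add_argument("graph", type=Path, help="path to a graph file")
    cluster_parser.set_defaults(func=cluster_main)

    args = parser.parse_args(["cluster", "graph.gexf"])
    raise SystemExit(args.func(args))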