Willi Ballenthin
2024-11-06 12:18:41 +00:00
parent 3e02b67480
commit f296e7d423
2 changed files with 36 additions and 33 deletions


@@ -177,7 +177,9 @@ def main(argv=None):
     for va in idautils.Functions():
         name = idaapi.get_func_name(va)

-        if name not in {"WinMain", }:
+        if name not in {
+            "WinMain",
+        }:
             continue

         function_classifications.append(
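
For context, idautils.Functions() yields the start address of every function IDA has recognized, and idaapi.get_func_name() returns its current name (e.g. "sub_401000" for unnamed functions). A minimal standalone IDAPython sketch of the same enumeration:

import idaapi
import idautils

# list every recognized function with its current name
for va in idautils.Functions():
    print(hex(va), idaapi.get_func_name(va))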


@@ -3,17 +3,15 @@ import json
 import logging
 import sqlite3
 import argparse
-from typing import Optional
+from typing import Iterator, Optional
 from pathlib import Path
 from dataclasses import dataclass

 import rich
 import rich.table
 import pefile
 import lancelot
-import lancelot.be2utils
 import networkx as nx
-from lancelot.be2utils import BinExport2Index,ReadMemoryError, AddressSpace
+import lancelot.be2utils
+from lancelot.be2utils import AddressSpace, BinExport2Index, ReadMemoryError
 from lancelot.be2utils.binexport2_pb2 import BinExport2

 import capa.main
@@ -21,7 +19,6 @@ import capa.main
 logger = logging.getLogger(__name__)

-
 def is_vertex_type(vertex: BinExport2.CallGraph.Vertex, type_: BinExport2.CallGraph.Vertex.Type.ValueType) -> bool:
     return vertex.HasField("type") and vertex.type == type_
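
The BinExport2 call graph stores vertices with an optional type field, so the HasField() guard avoids treating an unset field as type 0. A usage sketch, assuming the standard vertex types from the BinExport2 proto (NORMAL, LIBRARY, THUNK, IMPORTED):

# collect the addresses of thunk vertices
for vertex in be2.call_graph.vertex:
    if is_vertex_type(vertex, BinExport2.CallGraph.Vertex.Type.THUNK):
        print(hex(vertex.address))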
@@ -144,7 +141,8 @@ class Assemblage:
         self.conn = sqlite3.connect(self.db)

         with self.conn:
-            self.conn.executescript("""
+            self.conn.executescript(
+                """
                 PRAGMA journal_mode = WAL;
                 PRAGMA synchronous = NORMAL;
                 PRAGMA busy_timeout = 5000;
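
These pragmas tune SQLite for concurrent use: WAL journaling lets readers proceed while a writer is active, synchronous = NORMAL reduces fsync frequency (safe under WAL), and busy_timeout makes blocked connections wait rather than fail immediately. A minimal standalone sketch (database path hypothetical):

import sqlite3

conn = sqlite3.connect("assemblage.db")  # hypothetical path
with conn:
    conn.executescript(
        """
        PRAGMA journal_mode = WAL;    -- readers don't block on the writer
        PRAGMA synchronous = NORMAL;  -- fewer fsyncs; durable enough under WAL
        PRAGMA busy_timeout = 5000;   -- wait up to 5s for a locked database
        """
    )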
@@ -156,8 +154,8 @@ class Assemblage:
                 CREATE INDEX IF NOT EXISTS idx__functions__binary_id ON functions (binary_id);
                 CREATE INDEX IF NOT EXISTS idx__rvas__function_id ON rvas (function_id);

-            CREATE VIEW IF NOT EXISTS assemblage AS
-            SELECT
+                CREATE VIEW IF NOT EXISTS assemblage AS
+                SELECT
                     binaries.id AS binary_id,
                     binaries.file_name AS file_name,
                     binaries.platform AS platform,
@@ -183,19 +181,20 @@ class Assemblage:
                     rvas.id AS rva_id,
                     rvas.start AS start_rva,
                     rvas.end AS end_rva
-            FROM binaries
+                FROM binaries
                     JOIN functions ON binaries.id = functions.binary_id
                     JOIN rvas ON functions.id = rvas.function_id;
-            """)
+                """
+            )

     def get_row_by_binary_id(self, binary_id: int) -> AssemblageRow:
         with self.conn:
-            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ? LIMIT 1;", (binary_id, ))
+            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ? LIMIT 1;", (binary_id,))
             return AssemblageRow(*cur.fetchone())

-    def get_rows_by_binary_id(self, binary_id: int) -> AssemblageRow:
+    def get_rows_by_binary_id(self, binary_id: int) -> Iterator[AssemblageRow]:
         with self.conn:
-            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ?;", (binary_id, ))
+            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ?;", (binary_id,))
             row = cur.fetchone()
             while row:
                 yield AssemblageRow(*row)
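
The corrected annotation reflects that this method is a generator. Since sqlite3 cursors are themselves iterable, the fetchone() loop could equivalently be written as (a sketch, not part of this change):

def get_rows_by_binary_id(self, binary_id: int) -> Iterator[AssemblageRow]:
    with self.conn:
        cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ?;", (binary_id,))
        for row in cur:  # the cursor yields row tuples lazily
            yield AssemblageRow(*row)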
@@ -203,14 +202,13 @@ class Assemblage:
     def get_path_by_binary_id(self, binary_id: int) -> Path:
         with self.conn:
-            cur = self.conn.execute("""SELECT path FROM assemblage WHERE binary_id = ? LIMIT 1""", (binary_id, ))
+            cur = self.conn.execute("""SELECT path FROM assemblage WHERE binary_id = ? LIMIT 1""", (binary_id,))
             return self.samples / cur.fetchone()[0]

     def get_pe_by_binary_id(self, binary_id: int) -> pefile.PE:
         path = self.get_path_by_binary_id(binary_id)
         return pefile.PE(data=path.read_bytes(), fast_load=True)

-
 def generate_main(args: argparse.Namespace) -> int:
     if not args.assemblage_database.is_file():
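
In get_pe_by_binary_id(), fast_load=True parses only the PE headers and section table, deferring the data directories; they can be parsed later on demand. A sketch (file name hypothetical):

import pefile

pe = pefile.PE(data=open("sample.exe", "rb").read(), fast_load=True)  # headers only
# parse just the import directory when it is actually needed
pe.parse_data_directories(directories=[pefile.DIRECTORY_ENTRY["IMAGE_DIRECTORY_ENTRY_IMPORT"]])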
@@ -240,11 +238,7 @@ def generate_main(args: argparse.Namespace) -> int:
     pe_path = db.get_path_by_binary_id(args.binary_id)

     be2: BinExport2 = lancelot.get_binexport2_from_bytes(
-        pe_path.read_bytes(),
-        function_hints=[
-            base_address + function.start_rva
-            for function in functions
-        ]
+        pe_path.read_bytes(), function_hints=[base_address + function.start_rva for function in functions]
     )

     idx = lancelot.be2utils.BinExport2Index(be2)
@@ -253,7 +247,7 @@ def generate_main(args: argparse.Namespace) -> int:
     g = nx.MultiDiGraph()

-    for flow_graph_index, flow_graph in enumerate(be2.flow_graph):
+    for flow_graph in be2.flow_graph:
         datas: set[int] = set()
         callees: set[str] = set()
@@ -263,7 +257,7 @@ def generate_main(args: argparse.Namespace) -> int:
         for basic_block_index in flow_graph.basic_block_index:
             basic_block: BinExport2.BasicBlock = be2.basic_block[basic_block_index]

-            for instruction_index, instruction, instruction_address in idx.basic_block_instructions(basic_block):
+            for instruction_index, instruction, _ in idx.basic_block_instructions(basic_block):
                 for addr in instruction.call_target:
                     addr = thunks.get(addr, addr)
@@ -277,7 +271,9 @@ def generate_main(args: argparse.Namespace) -> int:
                     callees.add(vertex.address)

-                for data_reference_index in idx.data_reference_index_by_source_instruction_index.get(instruction_index, []):
+                for data_reference_index in idx.data_reference_index_by_source_instruction_index.get(
+                    instruction_index, []
+                ):
                     data_reference: BinExport2.DataReference = be2.data_reference[data_reference_index]
                     data_reference_address: int = data_reference.address
@@ -336,12 +332,15 @@ def generate_main(args: argparse.Namespace) -> int:
         # within each section, emit a neighbor edge for each pair of neighbors.
         section_nodes = [
-            node for node, attrs in g.nodes(data=True)
-            if (section.VirtualAddress + base_address) <= attrs["address"] < (base_address + section.VirtualAddress + section.Misc_VirtualSize)
+            node
+            for node, attrs in g.nodes(data=True)
+            if (section.VirtualAddress + base_address)
+            <= attrs["address"]
+            < (base_address + section.VirtualAddress + section.Misc_VirtualSize)
         ]

         for i in range(1, len(section_nodes)):
-            a = section_nodes[i-1]
+            a = section_nodes[i - 1]
             b = section_nodes[i]

             g.add_edge(
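
The index loop pairs each node with its predecessor in address order. An equivalent formulation zips the list against itself shifted by one (edge attributes illustrative, since the add_edge() call is truncated here):

# assuming section_nodes is ordered by address, as above
for a, b in zip(section_nodes, section_nodes[1:]):
    g.add_edge(a, b, type="neighbor")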
@@ -353,8 +352,8 @@ def generate_main(args: argparse.Namespace) -> int:
             )

     for function in functions:
-        g.nodes[base_address+function.start_rva]["name"] = function.name
-        g.nodes[base_address+function.start_rva]["file"] = function.file
+        g.nodes[base_address + function.start_rva]["name"] = function.name
+        g.nodes[base_address + function.start_rva]["file"] = function.file

     # rename unknown functions like: sub_401000
     for n, attrs in g.nodes(data=True):
@@ -373,7 +372,7 @@ def generate_main(args: argparse.Namespace) -> int:
attrs["repr"] = attrs["name"]
attrs["is_import"] = "!" in attrs["name"]
case "data":
if (string := read_string(address_space, n)):
if string := read_string(address_space, n):
attrs["repr"] = json.dumps(string)
attrs["is_string"] = True
else:
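
read_string() is not shown in this diff; given the AddressSpace and ReadMemoryError imports above, a plausible (entirely hypothetical) shape is a bounded read that stops at a NUL and rejects non-printable data:

def read_string(address_space: AddressSpace, va: int, limit: int = 256) -> Optional[str]:
    # hypothetical helper: the method name read_memory is an assumption.
    try:
        buf = address_space.read_memory(va, limit)
    except ReadMemoryError:
        return None
    s, _, _ = bytes(buf).partition(b"\x00")
    if not s or any(c < 0x20 or c > 0x7E for c in s):
        return None
    return s.decode("ascii")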
@@ -384,6 +383,7 @@ def generate_main(args: argparse.Namespace) -> int:
         print(line)

     # db.conn.close()
     return 0

+
 def cluster_main(args: argparse.Namespace) -> int:
@@ -391,7 +391,7 @@ def cluster_main(args: argparse.Namespace) -> int:
         raise ValueError("graph file doesn't exist")

     g = nx.read_gexf(args.graph)

     communities = nx.algorithms.community.louvain_communities(g)
     for i, community in enumerate(communities):
         print(f"[{i}]:")
@@ -401,6 +401,8 @@ def cluster_main(args: argparse.Namespace) -> int:
         else:
             print(f" - {hex(int(node, 0))}")

     return 0

+
+
 def main(argv=None) -> int:
     if argv is None:
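
nx.algorithms.community.louvain_communities() (networkx >= 2.8) performs Louvain modularity maximization and returns a list of sets of node ids; a seed can be passed for reproducible partitions. A standalone sketch of the clustering step (graph path hypothetical):

import networkx as nx

g = nx.read_gexf("graph.gexf")  # hypothetical path
communities = nx.algorithms.community.louvain_communities(g, seed=1)
for i, community in enumerate(communities):
    print(f"[{i}]: {len(community)} nodes")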
@@ -416,7 +418,6 @@ def main(argv=None) -> int:
     generate_parser.add_argument("binary_id", type=int, help="primary key of binary to inspect")
     generate_parser.set_defaults(func=generate_main)

     cluster_parser = subparsers.add_parser("cluster", help="cluster an existing graph")
-
     cluster_parser.add_argument("graph", type=Path, help="path to a graph file")
     cluster_parser.set_defaults(func=cluster_main)