codecut: better graph structure

commit 6fc4567f0c
parent 3b1a8f5b5a
Author: Willi Ballenthin
Date:   2024-11-12 14:43:32 +00:00


@@ -230,29 +230,21 @@ def generate_main(args: argparse.Namespace) -> int:
     db = Assemblage(args.assemblage_database, args.assemblage_directory)
 
-    @dataclass
-    class Function:
-        file: str
-        name: str
-        start_rva: int
-        end_rva: int
-
-    functions = [
-        Function(
-            file=m.source_file,
-            name=m.function_name,
-            start_rva=m.start_rva,
-            end_rva=m.end_rva,
-        )
-        for m in db.get_rows_by_binary_id(args.binary_id)
-    ]
-
     pe = db.get_pe_by_binary_id(args.binary_id)
     base_address: int = pe.OPTIONAL_HEADER.ImageBase
 
+    functions_by_address = {
+        base_address + function.start_rva: function for function in db.get_rows_by_binary_id(args.binary_id)
+    }
+
+    hash = db.get_row_by_binary_id(args.binary_id).binary_hash
+
+    def make_node_id(address: int) -> str:
+        return f"{hash}:{address:x}"
+
     pe_path = db.get_path_by_binary_id(args.binary_id)
     be2: BinExport2 = lancelot.get_binexport2_from_bytes(
-        pe_path.read_bytes(), function_hints=[base_address + function.start_rva for function in functions]
+        pe_path.read_bytes(), function_hints=list(functions_by_address.keys())
     )
     idx = lancelot.be2utils.BinExport2Index(be2)
@@ -262,16 +254,16 @@ def generate_main(args: argparse.Namespace) -> int:
     g = nx.MultiDiGraph()
 
     # ensure all functions from ground truth have an entry
-    for function in functions:
+    for address, function in functions_by_address.items():
         g.add_node(
-            base_address + function.start_rva,
-            address=base_address + function.start_rva,
+            make_node_id(address),
+            address=address,
             type="function",
         )
 
     for flow_graph in be2.flow_graph:
         datas: set[int] = set()
-        callees: set[str] = set()
+        callees: set[int] = set()
 
         entry_basic_block_index: int = flow_graph.entry_basic_block_index
         flow_graph_address: int = idx.get_basic_block_address(entry_basic_block_index)
@@ -309,81 +301,127 @@ def generate_main(args: argparse.Namespace) -> int:
         name = idx.get_function_name_by_vertex(vertex_index)
 
         g.add_node(
-            flow_graph_address,
+            make_node_id(flow_graph_address),
             address=flow_graph_address,
             type="function",
         )
 
         if datas or callees:
             logger.info("%s @ 0x%X:", name, flow_graph_address)
-            for data in sorted(datas):
-                logger.info(" - 0x%X", data)
+            for data_address in sorted(datas):
+                logger.info(" - 0x%X", data_address)
                 g.add_node(
-                    data,
-                    address=data,
+                    make_node_id(data_address),
+                    address=data_address,
                     type="data",
                 )
                 g.add_edge(
-                    flow_graph_address,
-                    data,
+                    make_node_id(flow_graph_address),
+                    make_node_id(data_address),
                     key="reference",
                     source_address=flow_graph_address,
-                    destination_address=data,
+                    destination_address=data_address,
                 )
 
             for callee in sorted(callees):
                 logger.info(" - %s", idx.get_function_name_by_address(callee))
                 g.add_node(
-                    callee,
+                    make_node_id(callee),
                     address=callee,
                     type="function",
                 )
                 g.add_edge(
-                    flow_graph_address,
-                    callee,
+                    make_node_id(flow_graph_address),
+                    make_node_id(callee),
                     key="call",
                     source_address=flow_graph_address,
                     destination_address=callee,
                 )
         else:
             logger.info("%s @ 0x%X: (none)", name, flow_graph_address)
 
-    for section in pe.sections:
-        # within each section, emit a neighbor edge for each pair of neighbors.
-        section_nodes = [
-            node
-            for node, attrs in g.nodes(data=True)
-            if (section.VirtualAddress + base_address)
-            <= attrs["address"]
-            < (base_address + section.VirtualAddress + section.Misc_VirtualSize)
-        ]
-
-        for i in range(1, len(section_nodes)):
-            a = section_nodes[i - 1]
-            b = section_nodes[i]
-
-            g.add_edge(
-                a,
-                b,
-                key="neighbor",
-                source_address=a,
-                destination_address=b,
-            )
-
-    for function in functions:
-        g.nodes[base_address + function.start_rva]["name"] = function.name
-        g.nodes[base_address + function.start_rva]["file"] = function.file
+    for node, attrs in g.nodes(data=True):
+        if attrs["type"] != "function":
+            continue
+
+        if f := functions_by_address.get(attrs["address"]):
+            attrs["name"] = f.function_name
+            attrs["file"] = f.source_file
+
+    for section in pe.sections:
+        # Within each section, emit a neighbor edge for each pair of neighbors.
+        # Neighbors only link nodes of the same type, because assemblage doesn't
+        # have ground truth for data items, so we don't quite know where to split.
+        # Consider this situation:
+        #
+        #     moduleA::func1
+        #     --- cut ---
+        #     moduleB::func1
+        #
+        # that one is ok, but this is hard:
+        #
+        #     moduleA::func1
+        #     --- cut??? ---
+        #     dataZ
+        #     --- or cut here??? ---
+        #     moduleB::func1
+        #
+        # Does the cut go before or after dataZ?
+        # So, we only have neighbor graphs within functions, and within datas.
+        # For datas, we don't allow interspersed functions.
+        section_nodes = sorted(
+            [
+                (node, attrs)
+                for node, attrs in g.nodes(data=True)
+                if (section.VirtualAddress + base_address)
+                <= attrs["address"]
+                < (base_address + section.VirtualAddress + section.Misc_VirtualSize)
+            ],
+            key=lambda p: p[1]["address"],
+        )
+
+        # add neighbor edges between data items.
+        # the data items must not be separated by any functions.
+        for i in range(1, len(section_nodes)):
+            a, a_attrs = section_nodes[i - 1]
+            b, b_attrs = section_nodes[i]
+
+            if a_attrs["type"] != "data":
+                continue
+            if b_attrs["type"] != "data":
+                continue
+
+            g.add_edge(a, b, key="neighbor")
+            g.add_edge(b, a, key="neighbor")
+
+        section_functions = [
+            (node, attrs)
+            for node, attrs in section_nodes
+            if attrs["type"] == "function"
+            # we only have ground truth for the known functions,
+            # so only consider those in the function neighbor graph.
+            and attrs["address"] in functions_by_address
+        ]
+
+        # add neighbor edges between functions.
+        # we drop the potentially interspersed data items before computing these edges.
+        for i in range(1, len(section_functions)):
+            a, a_attrs = section_functions[i - 1]
+            b, b_attrs = section_functions[i]
+
+            is_boundary = a_attrs["file"] != b_attrs["file"]
+
+            g.add_edge(a, b, key="neighbor", is_source_file_boundary=is_boundary)
+            g.add_edge(b, a, key="neighbor", is_source_file_boundary=is_boundary)
 
     # rename unknown functions like: sub_401000
     for n, attrs in g.nodes(data=True):
         if attrs["type"] != "function":
             continue
         if "name" in attrs:
             continue
 
-        attrs["name"] = f"sub_{n:x}"
+        attrs["name"] = f"sub_{attrs['address']:x}"
 
     # assign human-readable repr to all nodes
     # assign is_import=bool to functions
@@ -394,11 +432,11 @@ def generate_main(args: argparse.Namespace) -> int:
                 attrs["repr"] = attrs["name"]
                 attrs["is_import"] = "!" in attrs["name"]
             case "data":
-                if string := read_string(address_space, n):
+                if string := read_string(address_space, attrs["address"]):
                     attrs["repr"] = json.dumps(string)
                     attrs["is_string"] = True
                 else:
-                    attrs["repr"] = f"data_{n:x}"
+                    attrs["repr"] = f"data_{attrs['address']:x}"
                     attrs["is_string"] = False
 
     for line in nx.generate_gexf(g):
@@ -441,21 +479,26 @@ def generate_all_main(args: argparse.Namespace) -> int:
     db = Assemblage(args.assemblage_database, args.assemblage_directory)
     output_directory: Path = args.output_directory
 
     binary_ids = list(db.get_binary_ids())
 
     with Pool(args.num_workers) as p:
-        _ = list(p.imap_unordered(
-            _worker,
-            ((
-                args.assemblage_database,
-                args.assemblage_directory,
-                args.output_directory / str(binary_id) / "graph.gexf",
-                binary_id
-            )
-            for binary_id in binary_ids)))
+        _ = list(
+            p.imap_unordered(
+                _worker,
+                (
+                    (
+                        args.assemblage_database,
+                        args.assemblage_directory,
+                        args.output_directory / str(binary_id) / "graph.gexf",
+                        binary_id,
+                    )
+                    for binary_id in binary_ids
+                ),
+            )
+        )
 
     return 0
 
 
 def cluster_main(args: argparse.Namespace) -> int:
     if not args.graph.is_file():
@@ -489,13 +532,15 @@ def main(argv=None) -> int:
     generate_parser.add_argument("binary_id", type=int, help="primary key of binary to inspect")
     generate_parser.set_defaults(func=generate_main)
 
-    num_cores = os.cpu_count()
+    num_cores = os.cpu_count() or 1
     default_workers = max(1, num_cores - 2)
 
     generate_all_parser = subparsers.add_parser("generate_all", help="generate graphs for all samples")
     generate_all_parser.add_argument("assemblage_database", type=Path, help="path to Assemblage database")
     generate_all_parser.add_argument("assemblage_directory", type=Path, help="path to Assemblage samples directory")
     generate_all_parser.add_argument("output_directory", type=Path, help="path to output directory")
-    generate_all_parser.add_argument("--num_workers", type=int, default=default_workers, help="number of workers to use")
+    generate_all_parser.add_argument(
+        "--num_workers", type=int, default=default_workers, help="number of workers to use"
+    )
     generate_all_parser.set_defaults(func=generate_all_main)
 
     cluster_parser = subparsers.add_parser("cluster", help="cluster an existing graph")
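
A minimal sketch (not part of this commit) of how a generated graph might be consumed: load one emitted graph.gexf with networkx and list the neighbor edges flagged as source-file boundaries, i.e. the candidate cut points between object files. The "graph.gexf" path is hypothetical, and the sketch assumes the GEXF round-trip preserves the node and edge attributes written above.

import networkx as nx

# hypothetical path to one graph emitted by `generate` / `generate_all`
g = nx.read_gexf("graph.gexf")

for u, v, attrs in g.edges(data=True):
    # neighbor edges between adjacent known functions carry this flag;
    # True means the pair straddles two different source files.
    if not attrs.get("is_source_file_boundary"):
        continue

    a, b = g.nodes[u], g.nodes[v]
    print(f"cut between {a.get('file')} ({a.get('name')}) and {b.get('file')} ({b.get('name')})")

Because neighbor edges are added in both directions, each boundary prints twice; deduplicate on the unordered node pair if that matters.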