mirror of
https://github.com/mandiant/capa.git
synced 2025-12-12 15:49:46 -08:00
codecut: import to torch
This commit is contained in:
@@ -338,6 +338,7 @@ def generate_main(args: argparse.Namespace) -> int:
|
||||
else:
|
||||
logger.info("%s @ 0x%X: (none)", name, flow_graph_address)
|
||||
|
||||
# set ground truth node attributes from source data
|
||||
for node, attrs in g.nodes(data=True):
|
||||
if attrs["type"] != "function":
|
||||
continue
|
||||
@@ -410,6 +411,7 @@ def generate_main(args: argparse.Namespace) -> int:
|
||||
b, b_attrs = section_functions[i]
|
||||
is_boundary = a_attrs["file"] == b_attrs["file"]
|
||||
|
||||
# edge attribute: is_source_file_boundary
|
||||
g.add_edge(a, b, key="neighbor", is_source_file_boundary=is_boundary)
|
||||
g.add_edge(b, a, key="neighbor", is_source_file_boundary=is_boundary)
|
||||
|
||||
@@ -518,6 +520,226 @@ def cluster_main(args: argparse.Namespace) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
# uv pip install torch --index-url https://download.pytorch.org/whl/cpu
|
||||
# uv pip install torch-geometric pandas numpy
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import torch_geometric.utils.convert
|
||||
|
||||
import torch
|
||||
from torch_geometric.data import HeteroData
|
||||
|
||||
|
||||
@dataclass
class NodeType:
    """Schema entry describing one kind of node in the graph.

    Instances are collected in NODE_TYPES, keyed by `type`.
    """

    # name of the node kind; load_graph compares it against each node's "type" attribute.
    type: str
    # attribute name -> value load_graph falls back to when a node lacks that attribute.
    # iteration order of this dict determines the column order of the node feature tensor.
    attributes: Dict[str, int | bool]
|
||||
|
||||
|
||||
@dataclass
class EdgeType:
    """Schema entry describing one kind of edge in the graph.

    Instances are collected in EDGE_TYPES, keyed by
    (source node type name, edge key, destination node type name).
    """

    # edge key; load_graph compares it against the key of each multigraph edge.
    key: str
    # node kind required at the source end of the edge.
    source_type: NodeType
    # node kind required at the destination end of the edge.
    destination_type: NodeType
    # attribute name -> value load_graph falls back to when an edge lacks that attribute.
    # iteration order of this dict determines the column order of the edge attribute tensor.
    attributes: Dict[str, int | bool]
|
||||
|
||||
|
||||
# Schema of the node kinds present in the graph, keyed by node type name.
#
# Each attribute maps to the default *value* that load_graph substitutes when
# a node does not carry that attribute (it is passed as the fallback of
# `attrs.get(attribute, default_value)`), so the values must be real values
# such as `False`, not type objects such as `bool`.
#
# The attribute order here fixes the column order of the node feature tensor.
NODE_TYPES = {
    node.type: node
    for node in [
        NodeType(
            type="function",
            attributes={
                "is_import": False,
                # unused:
                #  - repr: str
                #  - address: int
                #  - name: str
                #  - file: str
            },
        ),
        NodeType(
            type="data",
            attributes={
                "is_string": False,
                # unused:
                #  - repr: str
                #  - address: int
            },
        ),
    ]
}

# Convenience handles used when declaring the edge schema.
FUNCTION_NODE = NODE_TYPES["function"]
DATA_NODE = NODE_TYPES["data"]
|
||||
|
||||
# Schema of the edge kinds present in the graph, keyed by the triple
# (source node type name, edge key, destination node type name).
# EdgeType arguments are positional: key, source_type, destination_type, attributes.
EDGE_TYPES = dict(
    ((edge.source_type.type, edge.key, edge.destination_type.type), edge)
    for edge in (
        # function -> function call
        EdgeType("call", FUNCTION_NODE, FUNCTION_NODE, {}),
        # function -> data reference
        EdgeType("reference", FUNCTION_NODE, DATA_NODE, {}),
        # function <-> function adjacency
        EdgeType(
            "neighbor",
            FUNCTION_NODE,
            FUNCTION_NODE,
            {
                # this is the attribute to predict
                "is_source_file_boundary": False,
            },
        ),
        # data <-> data adjacency
        EdgeType(
            "neighbor",
            DATA_NODE,
            DATA_NODE,
            {
                # this is the attribute to predict
                "is_source_file_boundary": False,
            },
        ),
    )
)
|
||||
|
||||
|
||||
@dataclass
class LoadedGraph:
    """Result of load_graph: the torch_geometric graph plus its node index mapping."""

    # the heterogeneous graph; `HeteroData` is the class imported from
    # torch_geometric.data (the previous annotation named an undefined `HeteroGraph`).
    data: HeteroData
    # map from node id (str) to node index (int), and node index (int) to node id (str).
    mapping: dict[str | int, int | str]
|
||||
|
||||
|
||||
def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
    """Convert a networkx multigraph into a torch_geometric HeteroData.

    Nodes are bucketed by their "type" attribute according to NODE_TYPES, and
    edges by (source type, edge key, destination type) according to EDGE_TYPES.
    Attributes listed in those schemas become float feature columns, in schema
    order; attributes missing on a node/edge take the schema's default value.

    Args:
        g: graph whose nodes each carry a "type" attribute.

    Returns:
        LoadedGraph holding the populated HeteroData and a bidirectional
        mapping between node ids (str) and node indexes (int).
    """
    # Our networkx graph identifies nodes by str ("sha256:address").
    # Torch identifies nodes by index, from 0 to #nodes.
    # Map one to another.
    node_indexes_by_node: dict[str, int] = {}
    # Because the types are different (str and int),
    # here's a single mapping where the type of the key implies
    # the sort of lookup you're doing (by index (int) or by node id (str)).
    node_mapping: dict[str | int, int | str] = {}
    # sorted() so that the index assignment is deterministic across runs.
    for i, node in enumerate(sorted(g.nodes)):
        node_indexes_by_node[node] = i
        node_mapping[node] = i
        node_mapping[i] = node

    data = HeteroData()

    for node_type in NODE_TYPES.values():
        node_indexes: list[int] = []
        node_attr_values: dict[str, list] = {attribute: [] for attribute in node_type.attributes}

        for node, attrs in g.nodes(data=True):
            if attrs["type"] != node_type.type:
                continue

            node_indexes.append(node_indexes_by_node[node])

            for attribute, default_value in node_type.attributes.items():
                node_attr_values[attribute].append(attrs.get(attribute, default_value))

        # NOTE(review): these indexes are global across all node types, while
        # torch_geometric conventionally has edge_index refer to rows of each
        # node type's own `x` - confirm this is the intended indexing scheme.
        # Explicit dtype keeps the tensor integral even when the list is empty.
        data[node_type.type].node_id = torch.tensor(node_indexes, dtype=torch.long)

        # torch.stack rejects an empty list, so only build features when the
        # schema declares at least one attribute.
        if node_attr_values:
            # attribute order is implicit in the NODE_TYPES data model above.
            data[node_type.type].x = torch.stack(
                [torch.tensor(values) for values in node_attr_values.values()],
                dim=-1,
            ).float()

    for edge_type in EDGE_TYPES.values():
        source_indexes: list[int] = []
        destination_indexes: list[int] = []
        edge_attr_values: dict[str, list] = {attribute: [] for attribute in edge_type.attributes}

        for source, destination, key, attrs in g.edges(data=True, keys=True):
            if key != edge_type.key:
                continue
            if g.nodes[source]["type"] != edge_type.source_type.type:
                continue
            if g.nodes[destination]["type"] != edge_type.destination_type.type:
                continue

            source_indexes.append(node_indexes_by_node[source])
            destination_indexes.append(node_indexes_by_node[destination])

            for attribute, default_value in edge_type.attributes.items():
                edge_attr_values[attribute].append(attrs.get(attribute, default_value))

        edge_key = (edge_type.source_type.type, edge_type.key, edge_type.destination_type.type)

        # Row 0 holds source indexes, row 1 destination indexes.
        # Explicit dtype keeps the tensor integral even when there are no edges.
        data[edge_key].edge_index = torch.stack(
            [
                torch.tensor(source_indexes, dtype=torch.long),
                torch.tensor(destination_indexes, dtype=torch.long),
            ]
        )

        # Edge types such as "call" and "reference" declare no attributes;
        # torch.stack rejects an empty list, so skip edge_attr for them.
        if edge_attr_values:
            # attribute order is implicit in the EDGE_TYPES data model above.
            data[edge_key].edge_attr = torch.stack(
                [torch.tensor(values) for values in edge_attr_values.values()],
                dim=-1,
            ).float()

    return LoadedGraph(
        data,
        node_mapping,
    )
|
||||
|
||||
|
||||
def train_main(args: argparse.Namespace) -> int:
    """Load a GEXF graph, fill in default attributes, and print its torch_geometric conversion.

    Args:
        args: parsed CLI arguments; `args.graph` is the path to the graph file.

    Returns:
        0 on success.

    Raises:
        ValueError: if the graph file does not exist, a node lacks `type`,
            or an edge lacks `key`.
    """
    if not args.graph.is_file():
        raise ValueError("graph file doesn't exist")

    g = nx.read_gexf(args.graph)

    # set node default attributes
    node_defaults = {
        "name": "",
        "file": "",
        "repr": "(unknown)",
        "is_import": False,
        "is_string": False,
    }
    for _, attrs in g.nodes(data=True):
        if "type" not in attrs:
            raise ValueError("node missing `type`")

        for attribute, value in node_defaults.items():
            attrs.setdefault(attribute, value)

    # set edge default attributes
    edge_defaults = {
        # TODO: this should only be on neighbor edges, for multi-graphs
        "is_source_file_boundary": False,
    }
    for a, b, key, attrs in g.edges(data=True, keys=True):
        if "key" not in attrs:
            raise ValueError("edge missing `key`", a, b, key, attrs)

        # distinct loop variable so the edge `key` above is not shadowed.
        for attribute, value in edge_defaults.items():
            attrs.setdefault(attribute, value)

    # TODO: this is only for Graph or DiGraph, not MultiGraph
    graph = torch_geometric.utils.convert.from_networkx(g)

    print(graph)

    return 0
|
||||
|
||||
|
||||
def main(argv=None) -> int:
|
||||
if argv is None:
|
||||
argv = sys.argv[1:]
|
||||
@@ -547,6 +769,10 @@ def main(argv=None) -> int:
|
||||
cluster_parser.add_argument("graph", type=Path, help="path to a graph file")
|
||||
cluster_parser.set_defaults(func=cluster_main)
|
||||
|
||||
train_parser = subparsers.add_parser("train", help="train using an existing graph")
|
||||
train_parser.add_argument("graph", type=Path, help="path to a graph file")
|
||||
train_parser.set_defaults(func=train_main)
|
||||
|
||||
args = parser.parse_args(args=argv)
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user