codecut: import to torch

This commit is contained in:
Willi Ballenthin
2024-11-14 10:05:49 +00:00
parent 6fc4567f0c
commit e94147b4c2

View File

@@ -338,6 +338,7 @@ def generate_main(args: argparse.Namespace) -> int:
else:
logger.info("%s @ 0x%X: (none)", name, flow_graph_address)
# set ground truth node attributes from source data
for node, attrs in g.nodes(data=True):
if attrs["type"] != "function":
continue
@@ -410,6 +411,7 @@ def generate_main(args: argparse.Namespace) -> int:
b, b_attrs = section_functions[i]
is_boundary = a_attrs["file"] == b_attrs["file"]
# edge attribute: is_source_file_boundary
g.add_edge(a, b, key="neighbor", is_source_file_boundary=is_boundary)
g.add_edge(b, a, key="neighbor", is_source_file_boundary=is_boundary)
@@ -518,6 +520,226 @@ def cluster_main(args: argparse.Namespace) -> int:
return 0
# uv pip install torch --index-url https://download.pytorch.org/whl/cpu
# uv pip install torch-geometric pandas numpy
import pandas as pd
import numpy as np
import torch_geometric.utils.convert
import torch
from torch_geometric.data import HeteroData
@dataclass
class NodeType:
type: str
attributes: Dict[str, int | bool]
@dataclass
class EdgeType:
key: str
source_type: NodeType
destination_type: NodeType
attributes: Dict[str, int | bool]
NODE_TYPES = {
node.type: node
for node in [
NodeType(
type="function",
attributes={
"is_import": bool,
# unused:
# - repr: str
# - address: int
# - name: str
# - file: str
}
),
NodeType(
type="data",
attributes={
"is_string": bool,
# unused:
# - repr: str
# - address: int
}
),
]
}
FUNCTION_NODE = NODE_TYPES["function"]
DATA_NODE = NODE_TYPES["data"]
EDGE_TYPES = {
(edge.source_type.type, edge.key, edge.destination_type.type): edge
for edge in [
EdgeType(
key="call",
source_type=FUNCTION_NODE,
destination_type=FUNCTION_NODE,
attributes={},
),
EdgeType(
key="reference",
source_type=FUNCTION_NODE,
destination_type=DATA_NODE,
attributes={},
),
EdgeType(
key="neighbor",
source_type=FUNCTION_NODE,
destination_type=FUNCTION_NODE,
attributes={
# this is the attribute to predict
"is_source_file_boundary": False,
},
),
EdgeType(
key="neighbor",
source_type=DATA_NODE,
destination_type=DATA_NODE,
attributes={
# this is the attribute to predict
"is_source_file_boundary": False,
},
),
]
}
@dataclass
class LoadedGraph:
data: HeteroGraph
# map from node id (str) to node index (int), and node index (int) to node id (str).
mapping: dict[str | int, int | str]
def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
# Our networkx graph identifies nodes by str ("sha256:address").
# Torch identifies nodes by index, from 0 to #nodes.
# Map one to another.
node_indexes_by_node: dict[str, int] = {}
nodes_by_node_index: dict[int, str] = {}
# Because the types are different (str and int),
# here's a single mapping where the type of the key implies
# the sort of lookup you're doing (by index (int) or by node id (str)).
node_mapping = dict[str | int, int | str] = {}
for i, node in enumerate(sort(g.nodes)):
node_indexes_by_node[node] = i
nodes_by_node_index[i] = node
node_mapping[node] = i
node_mapping[i] = node
data = HeteroData()
for node_type in NODE_TYPES.values():
node_indexes: list[int] = []
attr_values: dict[str, list] = {
attribute: []
for attribute in node_type.attributes.keys()
}
for node, attrs in g.nodes(data=True):
if attrs["type"] != node_type.type:
continue
node_index = node_indexes_by_node[node]
node_indexes.append(node_index)
for attribute, default_value in node_type.attributes.items():
value = attrs.get(attribute, default_value)
attr_values[attribute].append(value)
data[node_type.type].node_id = torch.tensor(node_indexes)
# attribute order is implicit in the NODE_TYPES data model above.
data[node_type.type].x = torch.stack([torch.tensor(values) for values in attr_values.values()], dim=-1).float()
for edge_type in EDGE_TYPES.values():
source_indexes: list[int] = []
destination_indexes: list[int] = []
attr_values: dict[str, list] = {
attribute: []
for attribute in edge_type.attributes.keys()
}
for source, destination, key, attrs in g.edges(data=True, keys=True):
if key != edge_type.key:
continue
if g.nodes[source]["type"] != edge_type.source_type.type:
continue
if g.nodes[destination]["type"] != edge_type.destination_type.type:
continue
source_index = node_indexes_by_node[source]
destination_index = node_indexes_by_node[destination]
source_indexes.append(source_index)
destination_indexes.append(destination_index)
for attribute, default_value in edge_type.attributes.items():
value = attrs.get(attribute, default_value)
attr_values[attribute].append(value)
data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_index = torch.stack([
torch.tensor(source_indexes),
torch.tensor(destination_indexes),
])
# attribute order is implicit in the EDGE_TYPES data model above.
data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_attr = torch.stack([
torch.tensor(values) for values in attr_values.values()
], dim=-1).float()
return LoadedData(
data,
node_mapping,
)
def train_main(args: argparse.Namespace) -> int:
if not args.graph.is_file():
raise ValueError("graph file doesn't exist")
g = nx.read_gexf(args.graph)
# set node default attributes
for _, attrs in g.nodes(data=True):
if "type" not in attrs:
raise ValueError("node missing `type`")
for key, value in {
"name": "",
"file": "",
"repr": "(unknown)",
"is_import": False,
"is_string": False,
}.items():
if key not in attrs:
attrs[key] = value
# set edge default attributes
for a, b, key, attrs in g.edges(data=True, keys=True):
if "key" not in attrs:
raise ValueError("edge missing `key`", a, b, key, attrs)
for key, value in {
# TODO: this should only be on neighbor edges, for multi-graphs
"is_source_file_boundary": false,
}.items():
if key not in attrs:
attrs[key] = value
# TODO: this is only for Graph or DiGraph, not MultiGraph
graph = torch_geometric.utils.convert.from_networkx(g)
print(graph)
return 0
def main(argv=None) -> int:
if argv is None:
argv = sys.argv[1:]
@@ -547,6 +769,10 @@ def main(argv=None) -> int:
cluster_parser.add_argument("graph", type=Path, help="path to a graph file")
cluster_parser.set_defaults(func=cluster_main)
train_parser = subparsers.add_parser("train", help="train using an existing graph")
train_parser.add_argument("graph", type=Path, help="path to a graph file")
train_parser.set_defaults(func=train_main)
args = parser.parse_args(args=argv)
try: