codecut: torch loader

This commit is contained in:
Willi Ballenthin
2024-11-14 10:32:07 +00:00
parent e94147b4c2
commit 891fa8aaa3

View File

@@ -5,7 +5,7 @@ import logging
import sqlite3
import argparse
import subprocess
from typing import Iterator, Optional
from typing import Iterator, Optional, Literal
from pathlib import Path
from dataclasses import dataclass
from multiprocessing import Pool
@@ -521,20 +521,16 @@ def cluster_main(args: argparse.Namespace) -> int:
# uv pip install torch --index-url https://download.pytorch.org/whl/cpu
# uv pip install torch --index-url https://download.pytorch.org/whl/cpu
# uv pip install torch-geometric pandas numpy
import pandas as pd
import numpy as np
import torch_geometric.utils.convert
import torch
from torch_geometric.data import HeteroData
# import torch # do this on-demand below, because it's slow
# from torch_geometric.data import HeteroData
@dataclass
class NodeType:
type: str
attributes: Dict[str, int | bool]
attributes: dict[str, Literal[False] | Literal[""] | Literal[0] | float]
@dataclass
@@ -542,7 +538,7 @@ class EdgeType:
key: str
source_type: NodeType
destination_type: NodeType
attributes: Dict[str, int | bool]
attributes: dict[str, Literal[False] | Literal[""] | Literal[0] | float]
NODE_TYPES = {
@@ -551,22 +547,22 @@ NODE_TYPES = {
NodeType(
type="function",
attributes={
"is_import": bool,
"is_import": False,
# unused:
# - repr: str
# - address: int
# - name: str
# - file: str
}
},
),
NodeType(
type="data",
attributes={
"is_string": bool,
"is_string": False,
# unused:
# - repr: str
# - address: int
}
},
),
]
}
@@ -613,12 +609,15 @@ EDGE_TYPES = {
@dataclass
class LoadedGraph:
data: HeteroGraph
data: "HeteroData"
# map from node id (str) to node index (int), and node index (int) to node id (str).
mapping: dict[str | int, int | str]
def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
import torch
from torch_geometric.data import HeteroData
# Our networkx graph identifies nodes by str ("sha256:address").
# Torch identifies nodes by index, from 0 to #nodes.
# Map one to another.
@@ -627,8 +626,8 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
# Because the types are different (str and int),
# here's a single mapping where the type of the key implies
# the sort of lookup you're doing (by index (int) or by node id (str)).
node_mapping = dict[str | int, int | str] = {}
for i, node in enumerate(sort(g.nodes)):
node_mapping: dict[str | int, int | str] = {}
for i, node in enumerate(sorted(g.nodes)):
node_indexes_by_node[node] = i
nodes_by_node_index[i] = node
node_mapping[node] = i
@@ -637,11 +636,10 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
data = HeteroData()
for node_type in NODE_TYPES.values():
logger.debug("loading nodes: %s", node_type.type)
node_indexes: list[int] = []
attr_values: dict[str, list] = {
attribute: []
for attribute in node_type.attributes.keys()
}
attr_values: dict[str, list] = {attribute: [] for attribute in node_type.attributes.keys()}
for node, attrs in g.nodes(data=True):
if attrs["type"] != node_type.type:
@@ -655,16 +653,19 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
attr_values[attribute].append(value)
data[node_type.type].node_id = torch.tensor(node_indexes)
# attribute order is implicit in the NODE_TYPES data model above.
data[node_type.type].x = torch.stack([torch.tensor(values) for values in attr_values.values()], dim=-1).float()
if attr_values:
# attribute order is implicit in the NODE_TYPES data model above.
data[node_type.type].x = torch.stack([torch.tensor(values) for values in attr_values.values()], dim=-1).float()
for edge_type in EDGE_TYPES.values():
logger.debug(
"loading edges: %s > %s > %s",
edge_type.source_type.type, edge_type.key, edge_type.destination_type.type
)
source_indexes: list[int] = []
destination_indexes: list[int] = []
attr_values: dict[str, list] = {
attribute: []
for attribute in edge_type.attributes.keys()
}
attr_values: dict[str, list] = {attribute: [] for attribute in edge_type.attributes.keys()}
for source, destination, key, attrs in g.edges(data=True, keys=True):
if key != edge_type.key:
@@ -684,58 +685,36 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
value = attrs.get(attribute, default_value)
attr_values[attribute].append(value)
data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_index = torch.stack([
torch.tensor(source_indexes),
torch.tensor(destination_indexes),
])
# attribute order is implicit in the EDGE_TYPES data model above.
data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_attr = torch.stack([
torch.tensor(values) for values in attr_values.values()
], dim=-1).float()
data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_index = torch.stack(
[
torch.tensor(source_indexes),
torch.tensor(destination_indexes),
]
)
if attr_values:
# attribute order is implicit in the EDGE_TYPES data model above.
data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_attr = torch.stack(
[torch.tensor(values) for values in attr_values.values()], dim=-1
).float()
return LoadedData(
return LoadedGraph(
data,
node_mapping,
)
def train_main(args: argparse.Namespace) -> int:
if not args.graph.is_file():
raise ValueError("graph file doesn't exist")
logger.debug("loading torch")
import torch
g = nx.read_gexf(args.graph)
# set node default attributes
for _, attrs in g.nodes(data=True):
if "type" not in attrs:
raise ValueError("node missing `type`")
lg = load_graph(g)
for key, value in {
"name": "",
"file": "",
"repr": "(unknown)",
"is_import": False,
"is_string": False,
}.items():
if key not in attrs:
attrs[key] = value
# set edge default attributes
for a, b, key, attrs in g.edges(data=True, keys=True):
if "key" not in attrs:
raise ValueError("edge missing `key`", a, b, key, attrs)
for key, value in {
# TODO: this should only be on neighbor edges, for multi-graphs
"is_source_file_boundary": false,
}.items():
if key not in attrs:
attrs[key] = value
# TODO: this is only for Graph or DiGraph, not MultiGraph
graph = torch_geometric.utils.convert.from_networkx(g)
print(graph)
print(lg.data)
return 0