codecut: torch loader
@@ -5,7 +5,7 @@ import logging
import sqlite3
import argparse
import subprocess
from typing import Iterator, Optional
from typing import Iterator, Optional, Literal
from pathlib import Path
from dataclasses import dataclass
from multiprocessing import Pool
@@ -521,20 +521,16 @@ def cluster_main(args: argparse.Namespace) -> int:



# uv pip install torch --index-url https://download.pytorch.org/whl/cpu
# uv pip install torch --index-url https://download.pytorch.org/whl/cpu
# uv pip install torch-geometric pandas numpy
import pandas as pd
import numpy as np
import torch_geometric.utils.convert

import torch
from torch_geometric.data import HeteroData
# import torch # do this on-demand below, because its slow
# from torch_geometric.data import HeteroData
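The patch replaces these module-level torch imports with on-demand imports inside the functions that need them (see load_graph below), so subcommands that never touch torch don't pay its slow startup cost. A minimal sketch of that deferred-import pattern, with an illustrative helper name not taken from the patch; Python caches modules, so repeated calls only pay the import cost once:

    def as_tensor(indexes: list[int]):
        # defer the heavy import until a torch-backed code path actually runs
        import torch

        return torch.tensor(indexes)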


@dataclass
class NodeType:
type: str
attributes: Dict[str, int | bool]
attributes: dict[str, Literal[False] | Literal[""] | Literal[0] | float]


@dataclass
@@ -542,7 +538,7 @@ class EdgeType:
key: str
source_type: NodeType
destination_type: NodeType
attributes: Dict[str, int | bool]
attributes: dict[str, Literal[False] | Literal[""] | Literal[0] | float]


NODE_TYPES = {
@@ -551,22 +547,22 @@ NODE_TYPES = {
NodeType(
type="function",
attributes={
"is_import": bool,
"is_import": False,
# unused:
# - repr: str
# - address: int
# - name: str
# - file: str
}
},
),
NodeType(
type="data",
attributes={
"is_string": bool,
"is_string": False,
# unused:
# - repr: str
# - address: int
}
},
),
]
}
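After this change, the attributes mapping on NodeType (and EdgeType) carries concrete per-attribute default values (False, "", 0, or a float) rather than Python type objects, which lets the loader fall back to a sensible value when a node lacks an attribute. A small sketch of how such a declaration can be consumed, assuming the NodeType dataclass above; the sample node dict is made up:

    node_type = NodeType(
        type="function",
        attributes={"is_import": False},
    )

    attrs = {"type": "function"}  # node that never had "is_import" set
    for attribute, default_value in node_type.attributes.items():
        # mirrors the attrs.get(attribute, default_value) lookup used in load_graph
        value = attrs.get(attribute, default_value)
        print(attribute, value)  # -> is_import False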
@@ -613,12 +609,15 @@ EDGE_TYPES = {

@dataclass
class LoadedGraph:
data: HeteroGraph
data: "HeteroData"
# map from node id (str) to node index (int), and node index (int) to node id (str).
mapping: dict[str | int, int | str]


def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
import torch
from torch_geometric.data import HeteroData

# Our networkx graph identifies nodes by str ("sha256:address").
# Torch identifies nodes by index, from 0 to #nodes.
# Map one to another.
@@ -627,8 +626,8 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
# Because the types are different (str and int),
# here's a single mapping where the type of the key implies
# the sort of lookup you're doing (by index (int) or by node id (str)).
node_mapping = dict[str | int, int | str] = {}
for i, node in enumerate(sort(g.nodes)):
node_mapping: dict[str | int, int | str] = {}
for i, node in enumerate(sorted(g.nodes)):
node_indexes_by_node[node] = i
nodes_by_node_index[i] = node
node_mapping[node] = i
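The two separate lookup dicts are collapsed into one node_mapping whose key type encodes the direction of the lookup: string keys map node ids to integer indexes, and integer keys map indexes back to node ids (the reverse node_mapping[i] = node assignment is implied but falls outside this hunk). A tiny illustration with made-up node ids:

    node_mapping: dict[str | int, int | str] = {}
    for i, node in enumerate(sorted(["sha256:0x402000", "sha256:0x401000"])):
        node_mapping[node] = i  # node id (str) -> node index (int)
        node_mapping[i] = node  # node index (int) -> node id (str)

    assert node_mapping["sha256:0x401000"] == 0
    assert node_mapping[0] == "sha256:0x401000"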
@@ -637,11 +636,10 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
data = HeteroData()

for node_type in NODE_TYPES.values():
logger.debug("loading nodes: %s", node_type.type)

node_indexes: list[int] = []
attr_values: dict[str, list] = {
attribute: []
for attribute in node_type.attributes.keys()
}
attr_values: dict[str, list] = {attribute: [] for attribute in node_type.attributes.keys()}

for node, attrs in g.nodes(data=True):
if attrs["type"] != node_type.type:
@@ -655,16 +653,19 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
attr_values[attribute].append(value)

data[node_type.type].node_id = torch.tensor(node_indexes)
# attribute order is implicit in the NODE_TYPES data model above.
data[node_type.type].x = torch.stack([torch.tensor(values) for values in attr_values.values()], dim=-1).float()
if attr_values:
# attribute order is implicit in the NODE_TYPES data model above.
data[node_type.type].x = torch.stack([torch.tensor(values) for values in attr_values.values()], dim=-1).float()

for edge_type in EDGE_TYPES.values():
logger.debug(
"loading edges: %s > %s > %s",
edge_type.source_type.type, edge_type.key, edge_type.destination_type.type
)

source_indexes: list[int] = []
destination_indexes: list[int] = []
attr_values: dict[str, list] = {
attribute: []
for attribute in edge_type.attributes.keys()
}
attr_values: dict[str, list] = {attribute: [] for attribute in edge_type.attributes.keys()}

for source, destination, key, attrs in g.edges(data=True, keys=True):
if key != edge_type.key:
@@ -684,58 +685,36 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
value = attrs.get(attribute, default_value)
attr_values[attribute].append(value)

data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_index = torch.stack([
torch.tensor(source_indexes),
torch.tensor(destination_indexes),
])
# attribute order is implicit in the EDGE_TYPES data model above.
data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_attr = torch.stack([
torch.tensor(values) for values in attr_values.values()
], dim=-1).float()
data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_index = torch.stack(
[
torch.tensor(source_indexes),
torch.tensor(destination_indexes),
]
)
if attr_values:
# attribute order is implicit in the EDGE_TYPES data model above.
data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_attr = torch.stack(
[torch.tensor(values) for values in attr_values.values()], dim=-1
).float()

return LoadedData(
return LoadedGraph(
data,
node_mapping,
)
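For reference, the stacking above yields the [2, num_edges] edge_index layout that PyTorch Geometric expects (row 0 = source indexes, row 1 = destination indexes), and edge_attr gets one column per declared attribute. A standalone sketch with toy indexes; the ("function", "calls", "function") edge type is illustrative, not necessarily one of the EDGE_TYPES in the patch:

    import torch
    from torch_geometric.data import HeteroData

    source_indexes = [0, 1, 2]
    destination_indexes = [1, 2, 0]

    edge_index = torch.stack(
        [
            torch.tensor(source_indexes),
            torch.tensor(destination_indexes),
        ]
    )
    assert edge_index.shape == (2, 3)  # [2, num_edges]

    data = HeteroData()
    data["function", "calls", "function"].edge_index = edge_index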



def train_main(args: argparse.Namespace) -> int:
if not args.graph.is_file():
raise ValueError("graph file doesn't exist")

logger.debug("loading torch")
import torch

g = nx.read_gexf(args.graph)

# set node default attributes
for _, attrs in g.nodes(data=True):
if "type" not in attrs:
raise ValueError("node missing `type`")
lg = load_graph(g)

for key, value in {
"name": "",
"file": "",
"repr": "(unknown)",
"is_import": False,
"is_string": False,
}.items():
if key not in attrs:
attrs[key] = value

# set edge default attributes
for a, b, key, attrs in g.edges(data=True, keys=True):
if "key" not in attrs:
raise ValueError("edge missing `key`", a, b, key, attrs)

for key, value in {
# TODO: this should only be on neighbor edges, for multi-graphs
"is_source_file_boundary": false,
}.items():
if key not in attrs:
attrs[key] = value

# TODO: this is only for Graph or DiGraph, not MultiGraph
graph = torch_geometric.utils.convert.from_networkx(g)

print(graph)
print(lg.data)

return 0
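End to end, the new train_main path is: read the GEXF graph with networkx, validate node and edge attributes, then hand the MultiDiGraph to load_graph, which returns the HeteroData plus the id/index mapping. A hedged usage sketch, assuming the functions in this patch are importable and with a made-up graph path:

    import networkx as nx

    g = nx.read_gexf("graph.gexf")  # hypothetical path; train_main gets this from argparse
    lg = load_graph(g)

    print(lg.data)  # the torch_geometric HeteroData built above
    print(lg.mapping.get("sha256:0x401000"))  # node id (str) -> node index (int), or None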