Compare commits


26 Commits

Author            SHA1        Date                        Message
Willi Ballenthin  d89083ab31  2025-01-15 12:09:17 +00:00  wip
Willi Ballenthin  891fa8aaa3  2025-01-15 12:09:17 +00:00  codecut: torch loader
Willi Ballenthin  e94147b4c2  2025-01-15 12:09:17 +00:00  codecut: import to torch
Willi Ballenthin  6fc4567f0c  2025-01-15 12:09:17 +00:00  codecut: better graph structure
Willi Ballenthin  3b1a8f5b5a  2025-01-15 12:09:17 +00:00  codecut: bulk generate graphs
Willi Ballenthin  f296e7d423  2025-01-15 12:09:17 +00:00  lints
Willi Ballenthin  3e02b67480  2025-01-15 12:09:17 +00:00  split codecut into subcommands
Willi Ballenthin  536526f61d  2025-01-15 12:09:16 +00:00  use networkx for the graph
Willi Ballenthin  bcd2c3fb35  2025-01-15 12:09:16 +00:00  show-object-layout: make it generate a graph of code/data instead
Willi Ballenthin  f340b93a02  2025-01-15 12:09:16 +00:00  init show-object-layout using assemblage
Willi Ballenthin  8bd6f8b99a  2025-01-15 12:09:16 +00:00  detect user code via entry points (main function name)
Willi Ballenthin  8b4d5d3d22  2025-01-15 12:09:16 +00:00  mute unknown lines
Willi Ballenthin  bc6e18ed85  2025-01-15 12:09:16 +00:00  consider thunks library functions
mr-tz             2426aba223  2025-01-15 12:09:16 +00:00  render from doc
mr-tz             405e189267  2025-01-15 12:09:16 +00:00  simplify and include thunks
mr-tz             cfb632edc8  2025-01-15 12:09:16 +00:00  remove Python 3.12 syntax
Willi Ballenthin  70c96a29b4  2025-01-15 12:09:16 +00:00  library-detection: add json output format
Willi Ballenthin  c005de0a0d  2025-01-15 12:09:16 +00:00  library-detection: cleanup script
Moritz            8d42b14b20  2025-01-15 12:09:16 +00:00  Add LookupError exception
mr-tz             bad32b91fb  2025-01-15 12:09:16 +00:00  init add result structure and render
mr-tz             9716da4765  2025-01-15 12:09:16 +00:00  add option to save idb
mr-tz             e0784f2e85  2025-01-15 12:09:16 +00:00  ease/document extra dependency
Willi Ballenthin  4a775bab2e  2025-01-15 12:09:16 +00:00  library detection: rough integration of algorithms
Willi Ballenthin  2de7830f5e  2025-01-15 12:09:16 +00:00  library detection: merge flirt and string branches
mr-tz             9d67e133c9  2025-01-15 12:09:16 +00:00  initial commit of out-of-the box flirt-based library id
Willi Ballenthin  fa18b4e201  2025-01-15 12:09:16 +00:00  initial attempt at library identification via known strings
31 changed files with 1756 additions and 0 deletions


@@ -108,6 +108,7 @@ repos:
- "--check-untyped-defs"
- "--ignore-missing-imports"
- "--config-file=.github/mypy/mypy.ini"
- "--enable-incomplete-feature=NewGenericSyntax"
- "capa/"
- "scripts/"
- "tests/"


capa/analysis/flirt.py (new file, +38)

@@ -0,0 +1,38 @@
# Copyright (C) 2024 Mandiant, Inc. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at: [package root]/LICENSE.txt
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
from pydantic import BaseModel
import capa.features.extractors.ida.idalib as idalib
if not idalib.has_idalib():
raise RuntimeError("cannot find IDA idalib module.")
if not idalib.load_idalib():
raise RuntimeError("failed to load IDA idalib module.")
import idaapi
import idautils
class FunctionId(BaseModel):
va: int
is_library: bool
name: str
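# yield a FunctionId for every function in the currently open IDA database.
# with lib_only=True (the default), only functions that IDA's FLIRT signatures
# flagged as library code (FUNC_LIB) are reported.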
def get_flirt_matches(lib_only=True):
for fva in idautils.Functions():
f = idaapi.get_func(fva)
is_lib = bool(f.flags & idaapi.FUNC_LIB)
fname = idaapi.get_func_name(fva)
if lib_only and not is_lib:
continue
yield FunctionId(va=fva, is_library=is_lib, name=fname)

capa/analysis/libraries.py (new file, +242)

@@ -0,0 +1,242 @@
# Copyright (C) 2024 Mandiant, Inc. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at: [package root]/LICENSE.txt
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
import io
import sys
import logging
import argparse
import tempfile
import contextlib
from enum import Enum
from typing import List, Optional
from pathlib import Path
import rich
from pydantic import BaseModel
from rich.text import Text
from rich.console import Console
import capa.main
import capa.helpers
import capa.analysis.flirt
import capa.analysis.strings
import capa.features.extractors.ida.idalib as idalib
if not idalib.has_idalib():
raise RuntimeError("cannot find IDA idalib module.")
if not idalib.load_idalib():
raise RuntimeError("failed to load IDA idalib module.")
import idaapi
import idapro
import ida_auto
import idautils
logger = logging.getLogger(__name__)
class Classification(str, Enum):
USER = "user"
LIBRARY = "library"
UNKNOWN = "unknown"
class Method(str, Enum):
FLIRT = "flirt"
STRINGS = "strings"
THUNK = "thunk"
ENTRYPOINT = "entrypoint"
class FunctionClassification(BaseModel):
va: int
classification: Classification
# name per the disassembler/analysis tool
# may be combined with the recovered/suspected name TODO below
name: str
# if is library, this must be provided
method: Optional[Method]
# TODO if is library, recovered/suspected name?
# if is library, these can optionally be provided.
library_name: Optional[str] = None
library_version: Optional[str] = None
class FunctionIdResults(BaseModel):
function_classifications: List[FunctionClassification]
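# open the input file with idalib for the duration of the with-block:
# copy it into a temporary directory (unless the IDB should be kept next to the input),
# run auto-analysis and wait for it to finish, and always close the database
# (and remove the temporary copy) on exit.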
@contextlib.contextmanager
def ida_session(input_path: Path, use_temp_dir=True):
if use_temp_dir:
t = Path(tempfile.mkdtemp(prefix="ida-")) / input_path.name
else:
t = input_path
logger.debug("using %s", str(t))
# stderr=True is used here to redirect the spinner banner to stderr,
# so that users can redirect capa's output.
console = Console(stderr=True, quiet=False)
try:
if use_temp_dir:
t.write_bytes(input_path.read_bytes())
# idalib writes to stdout (ugh), so we have to capture that
# so as not to screw up structured output.
with capa.helpers.stdout_redirector(io.BytesIO()):
idapro.enable_console_messages(False)
with capa.main.timing("analyze program"):
with console.status("analyzing program...", spinner="dots"):
if idapro.open_database(str(t.absolute()), run_auto_analysis=True):
raise RuntimeError("failed to analyze input file")
logger.debug("idalib: waiting for analysis...")
ida_auto.auto_wait()
logger.debug("idalib: opened database.")
yield
finally:
idapro.close_database()
if use_temp_dir:
t.unlink()
def is_thunk_function(fva):
f = idaapi.get_func(fva)
return bool(f.flags & idaapi.FUNC_THUNK)
def main(argv=None):
if argv is None:
argv = sys.argv[1:]
parser = argparse.ArgumentParser(description="Identify library functions using various strategies.")
capa.main.install_common_args(parser, wanted={"input_file"})
parser.add_argument("--store-idb", action="store_true", default=False, help="store IDA database file")
parser.add_argument("--min-string-length", type=int, default=8, help="minimum string length")
parser.add_argument("-j", "--json", action="store_true", help="emit JSON instead of text")
args = parser.parse_args(args=argv)
try:
capa.main.handle_common_args(args)
except capa.main.ShouldExitError as e:
return e.status_code
dbs = capa.analysis.strings.get_default_databases()
capa.analysis.strings.prune_databases(dbs, n=args.min_string_length)
function_classifications: List[FunctionClassification] = []
with ida_session(args.input_file, use_temp_dir=not args.store_idb):
with capa.main.timing("FLIRT-based library identification"):
            # TODO: add more signature files
            # TODO: apply more signatures
for flirt_match in capa.analysis.flirt.get_flirt_matches():
function_classifications.append(
FunctionClassification(
va=flirt_match.va,
name=flirt_match.name,
classification=Classification.LIBRARY,
method=Method.FLIRT,
# note: we cannot currently include which signature matched per function via the IDA API
)
)
# thunks
for fva in idautils.Functions():
if is_thunk_function(fva):
function_classifications.append(
FunctionClassification(
va=fva,
name=idaapi.get_func_name(fva),
classification=Classification.LIBRARY,
method=Method.THUNK,
)
)
with capa.main.timing("string-based library identification"):
for string_match in capa.analysis.strings.get_string_matches(dbs):
function_classifications.append(
FunctionClassification(
va=string_match.va,
name=idaapi.get_func_name(string_match.va),
classification=Classification.LIBRARY,
method=Method.STRINGS,
library_name=string_match.metadata.library_name,
library_version=string_match.metadata.library_version,
)
)
for va in idautils.Functions():
name = idaapi.get_func_name(va)
if name not in {
"WinMain",
}:
continue
function_classifications.append(
FunctionClassification(
va=va,
name=name,
classification=Classification.USER,
method=Method.ENTRYPOINT,
)
)
doc = FunctionIdResults(function_classifications=[])
classifications_by_va = capa.analysis.strings.create_index(function_classifications, "va")
for va in idautils.Functions():
if classifications := classifications_by_va.get(va):
doc.function_classifications.extend(classifications)
else:
doc.function_classifications.append(
FunctionClassification(
va=va,
name=idaapi.get_func_name(va),
classification=Classification.UNKNOWN,
method=None,
)
)
if args.json:
print(doc.model_dump_json()) # noqa: T201 print found
else:
table = rich.table.Table()
table.add_column("FVA")
table.add_column("CLASSIFICATION")
table.add_column("METHOD")
table.add_column("FNAME")
table.add_column("EXTRA INFO")
classifications_by_va = capa.analysis.strings.create_index(doc.function_classifications, "va", sorted_=True)
for va, classifications in classifications_by_va.items():
name = ", ".join({c.name for c in classifications})
if "sub_" in name:
name = Text(name, style="grey53")
classification = {c.classification for c in classifications}
method = {c.method for c in classifications if c.method}
extra = {f"{c.library_name}@{c.library_version}" for c in classifications if c.library_name}
table.add_row(
hex(va),
", ".join(classification) if classification != {"unknown"} else Text("unknown", style="grey53"),
", ".join(method),
name,
", ".join(extra),
)
rich.print(table)
if __name__ == "__main__":
sys.exit(main())


@@ -0,0 +1,2 @@
# temporary extra file to track dependencies of the analysis directory
nltk==3.9.1


@@ -0,0 +1,269 @@
# Copyright (C) 2024 Mandiant, Inc. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at: [package root]/LICENSE.txt
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
"""
further requirements:
- nltk
"""
import gzip
import logging
import collections
from typing import Any, Dict, Mapping
from pathlib import Path
from dataclasses import dataclass
import msgspec
import capa.features.extractors.strings
logger = logging.getLogger(__name__)
class LibraryString(msgspec.Struct):
string: str
library_name: str
library_version: str
file_path: str | None = None
function_name: str | None = None
line_number: int | None = None
@dataclass
class LibraryStringDatabase:
metadata_by_string: Dict[str, LibraryString]
def __len__(self) -> int:
return len(self.metadata_by_string)
@classmethod
def from_file(cls, path: Path) -> "LibraryStringDatabase":
metadata_by_string: Dict[str, LibraryString] = {}
decoder = msgspec.json.Decoder(type=LibraryString)
for line in gzip.decompress(path.read_bytes()).split(b"\n"):
if not line:
continue
s = decoder.decode(line)
metadata_by_string[s.string] = s
return cls(metadata_by_string=metadata_by_string)
DEFAULT_FILENAMES = (
"brotli.jsonl.gz",
"bzip2.jsonl.gz",
"cryptopp.jsonl.gz",
"curl.jsonl.gz",
"detours.jsonl.gz",
"jemalloc.jsonl.gz",
"jsoncpp.jsonl.gz",
"kcp.jsonl.gz",
"liblzma.jsonl.gz",
"libsodium.jsonl.gz",
"libpcap.jsonl.gz",
"mbedtls.jsonl.gz",
"openssl.jsonl.gz",
"sqlite3.jsonl.gz",
"tomcrypt.jsonl.gz",
"wolfssl.jsonl.gz",
"zlib.jsonl.gz",
)
DEFAULT_PATHS = tuple(Path(__file__).parent / "data" / "oss" / filename for filename in DEFAULT_FILENAMES) + (
Path(__file__).parent / "data" / "crt" / "msvc_v143.jsonl.gz",
)
def get_default_databases() -> list[LibraryStringDatabase]:
return [LibraryStringDatabase.from_file(path) for path in DEFAULT_PATHS]
@dataclass
class WindowsApiStringDatabase:
dll_names: set[str]
api_names: set[str]
def __len__(self) -> int:
return len(self.dll_names) + len(self.api_names)
@classmethod
def from_dir(cls, path: Path) -> "WindowsApiStringDatabase":
dll_names: set[str] = set()
api_names: set[str] = set()
for line in gzip.decompress((path / "dlls.txt.gz").read_bytes()).decode("utf-8").splitlines():
if not line:
continue
dll_names.add(line)
for line in gzip.decompress((path / "apis.txt.gz").read_bytes()).decode("utf-8").splitlines():
if not line:
continue
api_names.add(line)
return cls(dll_names=dll_names, api_names=api_names)
@classmethod
def from_defaults(cls) -> "WindowsApiStringDatabase":
return cls.from_dir(Path(__file__).parent / "data" / "winapi")
def extract_strings(buf, n=4):
yield from capa.features.extractors.strings.extract_ascii_strings(buf, n=n)
yield from capa.features.extractors.strings.extract_unicode_strings(buf, n=n)
def prune_databases(dbs: list[LibraryStringDatabase], n=8):
    """remove less trustworthy database entries.
such as:
- those found in multiple databases
- those that are English words
- those that are too short
- Windows API and DLL names
"""
# TODO: consider applying these filters directly to the persisted databases, not at load time.
winapi = WindowsApiStringDatabase.from_defaults()
try:
from nltk.corpus import words as nltk_words
nltk_words.words()
except (ImportError, LookupError):
# one-time download of dataset.
# this probably doesn't work well for embedded use.
import nltk
nltk.download("words")
from nltk.corpus import words as nltk_words
words = set(nltk_words.words())
counter: collections.Counter[str] = collections.Counter()
to_remove = set()
for db in dbs:
for string in db.metadata_by_string.keys():
counter[string] += 1
if string in words:
to_remove.add(string)
continue
if len(string) < n:
to_remove.add(string)
continue
if string in winapi.api_names:
to_remove.add(string)
continue
if string in winapi.dll_names:
to_remove.add(string)
continue
for string, count in counter.most_common():
if count <= 1:
break
# remove strings that are seen in more than one database
to_remove.add(string)
for db in dbs:
for string in to_remove:
if string in db.metadata_by_string:
del db.metadata_by_string[string]
def get_function_strings():
import idaapi
import idautils
import capa.features.extractors.ida.helpers as ida_helpers
strings_by_function = collections.defaultdict(set)
for ea in idautils.Functions():
f = idaapi.get_func(ea)
# ignore library functions and thunk functions as identified by IDA
if f.flags & idaapi.FUNC_THUNK:
continue
if f.flags & idaapi.FUNC_LIB:
continue
for bb in ida_helpers.get_function_blocks(f):
for insn in ida_helpers.get_instructions_in_range(bb.start_ea, bb.end_ea):
ref = capa.features.extractors.ida.helpers.find_data_reference_from_insn(insn)
if ref == insn.ea:
continue
string = capa.features.extractors.ida.helpers.find_string_at(ref)
if not string:
continue
strings_by_function[ea].add(string)
return strings_by_function
@dataclass
class LibraryStringClassification:
va: int
string: str
library_name: str
metadata: LibraryString
def create_index(s: list, k: str, sorted_: bool = False) -> Mapping[Any, list]:
"""create an index of the elements in `s` using the key `k`, optionally sorted by `k`"""
if sorted_:
s = sorted(s, key=lambda x: getattr(x, k))
s_by_k = collections.defaultdict(list)
for v in s:
p = getattr(v, k)
s_by_k[p].append(v)
return s_by_k
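# match each function's referenced strings against the library string databases,
# then drop libraries that matched only a few strings and functions whose matches
# point at more than one library.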
def get_string_matches(dbs: list[LibraryStringDatabase]) -> list[LibraryStringClassification]:
matches: list[LibraryStringClassification] = []
for function, strings in sorted(get_function_strings().items()):
for string in strings:
for db in dbs:
if metadata := db.metadata_by_string.get(string):
matches.append(
LibraryStringClassification(
va=function,
string=string,
library_name=metadata.library_name,
metadata=metadata,
)
)
    # if a library matches only a handful of strings (five or fewer), ignore that library
matches_by_library = create_index(matches, "library_name")
for library_name, library_matches in matches_by_library.items():
if len(library_matches) > 5:
continue
        logger.info("pruning library %s: only %d matched strings", library_name, len(library_matches))
matches = [m for m in matches if m.library_name != library_name]
# if there are conflicts within a single function, don't label it
matches_by_function = create_index(matches, "va")
for va, function_matches in matches_by_function.items():
library_names = {m.library_name for m in function_matches}
if len(library_names) == 1:
continue
logger.info("conflicting matches: 0x%x: %s", va, sorted(library_names))
# this is potentially slow (O(n**2)) but hopefully fast enough in practice.
matches = [m for m in matches if m.va != va]
return matches


@@ -0,0 +1,130 @@
# Copyright (C) 2024 Mandiant, Inc. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at: [package root]/LICENSE.txt
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
import sys
import logging
import collections
from pathlib import Path
import rich
from rich.text import Text
import capa.analysis.strings
import capa.features.extractors.strings
import capa.features.extractors.ida.helpers as ida_helpers
logger = logging.getLogger(__name__)
def open_ida(input_path: Path):
import tempfile
import idapro
t = Path(tempfile.mkdtemp(prefix="ida-")) / input_path.name
t.write_bytes(input_path.read_bytes())
# resource leak: we should delete this upon exit
idapro.enable_console_messages(False)
idapro.open_database(str(t.absolute()), run_auto_analysis=True)
import ida_auto
ida_auto.auto_wait()
def main():
logging.basicConfig(level=logging.DEBUG)
# use n=8 to ignore common words
N = 8
input_path = Path(sys.argv[1])
dbs = capa.analysis.strings.get_default_databases()
capa.analysis.strings.prune_databases(dbs, n=N)
strings_by_library = collections.defaultdict(set)
for string in capa.analysis.strings.extract_strings(input_path.read_bytes(), n=N):
for db in dbs:
if metadata := db.metadata_by_string.get(string.s):
strings_by_library[metadata.library_name].add(string.s)
console = rich.get_console()
console.print("found libraries:", style="bold")
for library, strings in sorted(strings_by_library.items(), key=lambda p: len(p[1]), reverse=True):
console.print(f" - [b]{library}[/] ({len(strings)} strings)")
for string in sorted(strings)[:10]:
console.print(f" - {string}", markup=False, style="grey37")
if len(strings) > 10:
console.print(" ...", style="grey37")
if not strings_by_library:
console.print(" (none)", style="grey37")
# since we're not going to find any strings
# return early and don't do IDA analysis
return
open_ida(input_path)
import idaapi
import idautils
import ida_funcs
strings_by_function = collections.defaultdict(set)
for ea in idautils.Functions():
f = idaapi.get_func(ea)
# ignore library functions and thunk functions as identified by IDA
if f.flags & idaapi.FUNC_THUNK:
continue
if f.flags & idaapi.FUNC_LIB:
continue
for bb in ida_helpers.get_function_blocks(f):
for insn in ida_helpers.get_instructions_in_range(bb.start_ea, bb.end_ea):
ref = capa.features.extractors.ida.helpers.find_data_reference_from_insn(insn)
if ref == insn.ea:
continue
string = capa.features.extractors.ida.helpers.find_string_at(ref)
if not string:
continue
for db in dbs:
if metadata := db.metadata_by_string.get(string):
strings_by_function[ea].add(string)
# ensure there are at least XXX functions renamed, or ignore those entries
console.print("functions:", style="bold")
for function, strings in sorted(strings_by_function.items()):
if strings:
name = ida_funcs.get_func_name(function)
console.print(f" [b]{name}[/]@{function:08x}:")
for string in strings:
for db in dbs:
if metadata := db.metadata_by_string.get(string):
location = Text(
f"{metadata.library_name}@{metadata.library_version}::{metadata.function_name}",
style="grey37",
)
console.print(" - ", location, ": ", string.rstrip())
console.print()
console.print(
f"found {len(strings_by_function)} library functions across {len(list(idautils.Functions()))} functions"
)
if __name__ == "__main__":
main()

Binary file not shown.


@@ -0,0 +1,3 @@
*.csv
*.jsonl
*.jsonl.gz

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.


@@ -0,0 +1,99 @@
# Strings from Open Source libraries
This directory contains databases of strings extracted from open source software.
capa uses these databases to ignore functions that are likely library code.
There is one file per database. Each database is a gzip-compressed JSONL file (one JSON document per line).
Each JSON document looks like this:
string: "1.0.8, 13-Jul-2019"
library_name: "bzip2"
library_version: "1.0.8#3"
file_path: "CMakeFiles/bz2.dir/bzlib.c.obj"
function_name: "BZ2_bzlibVersion"
line_number: null
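For a quick look at one of these databases outside of capa, the format can be read with nothing but the Python standard library (capa itself loads them via `msgspec`, see `capa/analysis/strings.py`). This is only an illustrative sketch; the `zlib` path below is an example, not a required entry point:

```python
import gzip
import json
from pathlib import Path

# any of the databases listed below follows the same format;
# the path is relative to the repository root.
path = Path("capa/analysis/data/oss/zlib.jsonl.gz")

metadata_by_string = {}
for line in gzip.decompress(path.read_bytes()).splitlines():
    if not line:
        continue
    # fields: string, library_name, library_version, file_path, function_name, line_number
    record = json.loads(line)
    metadata_by_string[record["string"]] = record

print(f"loaded {len(metadata_by_string)} strings from {path.name}")
```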
The following databases were extracted via the vcpkg & jh technique:
- brotli 1.0.9#5
- bzip2 1.0.8#3
- cryptopp 8.7.0
- curl 7.86.0#1
- detours 4.0.1#7
- jemalloc 5.3.0#1
- jsoncpp 1.9.5
- kcp 1.7
- liblzma 5.2.5#6
- libsodium 1.0.18#8
- libpcap 1.10.1#3
- mbedtls 2.28.1
- openssl 3.0.7#1
- sqlite3 3.40.0#1
- tomcrypt 1.18.2#2
- wolfssl 5.5.0
- zlib 1.2.13
This code was originally developed in FLOSS and imported into capa.
## The vcpkg & jh technique
Major steps:
1. build static libraries via vcpkg
2. extract features via jh
3. convert to JSONL format with `jh_to_qs.py`
4. compress with gzip
### Build static libraries via vcpkg
[vcpkg](https://vcpkg.io/en/) is a free C/C++ package manager for acquiring and managing libraries.
We use it to easily build common open source libraries, like zlib.
Use the triplet `x64-windows-static` to build static archives (.lib files that are AR archives containing COFF object files):
```console
PS > C:\vcpkg\vcpkg.exe install --triplet x64-windows-static zlib
```
### Extract features via jh
[jh](https://github.com/williballenthin/lancelot/blob/master/bin/src/bin/jh.rs)
is a lancelot-based utility that parses AR archives containing COFF object files,
reconstructs their control flow, finds functions, and extracts features.
jh extracts numbers, API calls, and strings; we are only interested in the string features.
For each feature, jh emits a CSV line with the fields
- target triplet
- compiler
- library
- version
- build profile
- path
- function
- feature type
- feature value
For example:
```csv
x64-windows-static,msvc143,bzip2,1.0.8#3,release,CMakeFiles/bz2.dir/bzlib.c.obj,BZ2_bzBuffToBuffCompress,number,0x00000100
```
For example, to invoke jh:
```console
$ ~/lancelot/target/release/jh x64-windows-static msvc143 zlib 1.2.13 release /mnt/c/vcpkg/installed/x64-windows-static/lib/zlib.lib > ~/flare-floss/floss/qs/db/data/oss/zlib.csv
```
### Convert to OSS database format
We use the script `jh_to_qs.py` to convert these CSV lines into a JSONL file prepared for FLOSS:
```console
$ python3 jh_to_qs.py zlib.csv > zlib.jsonl
```
These files are then gzip'd:
```console
$ gzip -c zlib.jsonl > zlib.jsonl.gz
```

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.


@@ -77,6 +77,8 @@ dependencies = [
"protobuf>=5",
"msgspec>=0.18.6",
"xmltodict>=0.13.0",
# for library detection (in development)
"nltk>=3",
# ---------------------------------------
# Dependencies that we develop

scripts/codecut.py (new file, +970)

@@ -0,0 +1,970 @@
import os
import sys
import json
import logging
import sqlite3
import argparse
import subprocess
from typing import Iterator, Optional, Literal
from pathlib import Path
from dataclasses import dataclass
from multiprocessing import Pool
import pefile
import lancelot
import networkx as nx
import lancelot.be2utils
from lancelot.be2utils import AddressSpace, BinExport2Index, ReadMemoryError
from lancelot.be2utils.binexport2_pb2 import BinExport2
import capa.main
import capa.features.extractors.strings
logger = logging.getLogger(__name__)
def is_vertex_type(vertex: BinExport2.CallGraph.Vertex, type_: BinExport2.CallGraph.Vertex.Type.ValueType) -> bool:
return vertex.HasField("type") and vertex.type == type_
def is_vertex_thunk(vertex: BinExport2.CallGraph.Vertex) -> bool:
return is_vertex_type(vertex, BinExport2.CallGraph.Vertex.Type.THUNK)
THUNK_CHAIN_DEPTH_DELTA = 5
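# map each thunk's address to the address of the function it ultimately targets,
# following chains of thunks for up to THUNK_CHAIN_DEPTH_DELTA hops.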
def compute_thunks(be2: BinExport2, idx: BinExport2Index) -> dict[int, int]:
# from thunk address to target function address
thunks: dict[int, int] = {}
for addr, vertex_idx in idx.vertex_index_by_address.items():
vertex: BinExport2.CallGraph.Vertex = be2.call_graph.vertex[vertex_idx]
if not is_vertex_thunk(vertex):
continue
curr_vertex_idx: int = vertex_idx
for _ in range(THUNK_CHAIN_DEPTH_DELTA):
thunk_callees: list[int] = idx.callees_by_vertex_index[curr_vertex_idx]
            # if this doesn't hold, then this doesn't look like a thunk,
            # because the number of callees is either:
            #   0: the thunk doesn't point to anything, such as `jmp eax`, or
            #  >1: the thunk may end up at many functions.
if not thunk_callees:
# maybe we have an indirect jump, like `jmp eax`
# that we can't actually resolve here.
break
if len(thunk_callees) != 1:
for thunk_callee in thunk_callees:
logger.warning("%s", hex(be2.call_graph.vertex[thunk_callee].address))
assert len(thunk_callees) == 1, f"thunk @ {hex(addr)} failed"
thunked_vertex_idx: int = thunk_callees[0]
thunked_vertex: BinExport2.CallGraph.Vertex = be2.call_graph.vertex[thunked_vertex_idx]
if not is_vertex_thunk(thunked_vertex):
assert thunked_vertex.HasField("address")
thunks[addr] = thunked_vertex.address
break
curr_vertex_idx = thunked_vertex_idx
return thunks
def read_string(address_space: AddressSpace, address: int) -> Optional[str]:
try:
# if at end of segment then there might be an overrun here.
buf: bytes = address_space.read_memory(address, 0x100)
except ReadMemoryError:
logger.debug("failed to read memory: 0x%x", address)
return None
# note: we *always* break after the first iteration
for s in capa.features.extractors.strings.extract_ascii_strings(buf):
if s.offset != 0:
break
return s.s
# note: we *always* break after the first iteration
for s in capa.features.extractors.strings.extract_unicode_strings(buf):
if s.offset != 0:
break
return s.s
return None
@dataclass
class AssemblageRow:
# from table: binaries
binary_id: int
file_name: str
platform: str
build_mode: str
toolset_version: str
github_url: str
optimization: str
repo_last_update: int
size: int
path: str
license: str
binary_hash: str
repo_commit_hash: str
# from table: functions
function_id: int
function_name: str
function_hash: str
top_comments: str
source_codes: str
prototype: str
_source_file: str
# from table: rvas
rva_id: int
start_rva: int
end_rva: int
@property
def source_file(self):
# cleanup some extra metadata provided by assemblage
return self._source_file.partition(" (MD5: ")[0].partition(" (0x3: ")[0]
class Assemblage:
conn: sqlite3.Connection
samples: Path
def __init__(self, db: Path, samples: Path):
super().__init__()
self.db = db
self.samples = samples
self.conn = sqlite3.connect(self.db)
with self.conn:
self.conn.executescript(
"""
PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
PRAGMA busy_timeout = 5000;
PRAGMA cache_size = -20000; -- 20MB
PRAGMA foreign_keys = true;
PRAGMA temp_store = memory;
BEGIN IMMEDIATE TRANSACTION;
CREATE INDEX IF NOT EXISTS idx__functions__binary_id ON functions (binary_id);
CREATE INDEX IF NOT EXISTS idx__rvas__function_id ON rvas (function_id);
CREATE VIEW IF NOT EXISTS assemblage AS
SELECT
binaries.id AS binary_id,
binaries.file_name AS file_name,
binaries.platform AS platform,
binaries.build_mode AS build_mode,
binaries.toolset_version AS toolset_version,
binaries.github_url AS github_url,
binaries.optimization AS optimization,
binaries.repo_last_update AS repo_last_update,
binaries.size AS size,
binaries.path AS path,
binaries.license AS license,
binaries.hash AS hash,
binaries.repo_commit_hash AS repo_commit_hash,
functions.id AS function_id,
functions.name AS function_name,
functions.hash AS function_hash,
functions.top_comments AS top_comments,
functions.source_codes AS source_codes,
functions.prototype AS prototype,
functions.source_file AS source_file,
rvas.id AS rva_id,
rvas.start AS start_rva,
rvas.end AS end_rva
FROM binaries
JOIN functions ON binaries.id = functions.binary_id
JOIN rvas ON functions.id = rvas.function_id;
"""
)
def get_row_by_binary_id(self, binary_id: int) -> AssemblageRow:
with self.conn:
cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ? LIMIT 1;", (binary_id,))
return AssemblageRow(*cur.fetchone())
def get_rows_by_binary_id(self, binary_id: int) -> Iterator[AssemblageRow]:
with self.conn:
cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ?;", (binary_id,))
row = cur.fetchone()
while row:
yield AssemblageRow(*row)
row = cur.fetchone()
def get_path_by_binary_id(self, binary_id: int) -> Path:
with self.conn:
cur = self.conn.execute("""SELECT path FROM assemblage WHERE binary_id = ? LIMIT 1""", (binary_id,))
return self.samples / cur.fetchone()[0]
def get_pe_by_binary_id(self, binary_id: int) -> pefile.PE:
path = self.get_path_by_binary_id(binary_id)
return pefile.PE(data=path.read_bytes(), fast_load=True)
def get_binary_ids(self) -> Iterator[int]:
with self.conn:
cur = self.conn.execute("SELECT DISTINCT binary_id FROM assemblage ORDER BY binary_id ASC;")
row = cur.fetchone()
while row:
yield row[0]
row = cur.fetchone()
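# build the graph for a single Assemblage binary: one node per known function and
# referenced data item, call/reference edges from the BinExport2 call graph, and
# neighbor edges between adjacent items within each PE section; the result is
# printed to stdout as GEXF.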
def generate_main(args: argparse.Namespace) -> int:
if not args.assemblage_database.is_file():
raise ValueError("database doesn't exist")
db = Assemblage(args.assemblage_database, args.assemblage_directory)
pe = db.get_pe_by_binary_id(args.binary_id)
base_address: int = pe.OPTIONAL_HEADER.ImageBase
functions_by_address = {
base_address + function.start_rva: function for function in db.get_rows_by_binary_id(args.binary_id)
}
hash = db.get_row_by_binary_id(args.binary_id).binary_hash
def make_node_id(address: int) -> str:
return f"{hash}:{address:x}"
pe_path = db.get_path_by_binary_id(args.binary_id)
be2: BinExport2 = lancelot.get_binexport2_from_bytes(
pe_path.read_bytes(), function_hints=list(functions_by_address.keys())
)
idx = lancelot.be2utils.BinExport2Index(be2)
address_space = lancelot.be2utils.AddressSpace.from_pe(pe, base_address)
thunks = compute_thunks(be2, idx)
g = nx.MultiDiGraph()
# ensure all functions from ground truth have an entry
for address, function in functions_by_address.items():
g.add_node(
make_node_id(address),
address=address,
type="function",
)
for flow_graph in be2.flow_graph:
datas: set[int] = set()
callees: set[int] = set()
entry_basic_block_index: int = flow_graph.entry_basic_block_index
flow_graph_address: int = idx.get_basic_block_address(entry_basic_block_index)
for basic_block_index in flow_graph.basic_block_index:
basic_block: BinExport2.BasicBlock = be2.basic_block[basic_block_index]
for instruction_index, instruction, _ in idx.basic_block_instructions(basic_block):
for addr in instruction.call_target:
addr = thunks.get(addr, addr)
if addr not in idx.vertex_index_by_address:
# disassembler did not define function at address
logger.debug("0x%x is not a vertex", addr)
continue
vertex_idx: int = idx.vertex_index_by_address[addr]
vertex: BinExport2.CallGraph.Vertex = be2.call_graph.vertex[vertex_idx]
callees.add(vertex.address)
for data_reference_index in idx.data_reference_index_by_source_instruction_index.get(
instruction_index, []
):
data_reference: BinExport2.DataReference = be2.data_reference[data_reference_index]
data_reference_address: int = data_reference.address
if data_reference_address in idx.insn_address_by_index:
# appears to be code
continue
datas.add(data_reference_address)
vertex_index = idx.vertex_index_by_address[flow_graph_address]
name = idx.get_function_name_by_vertex(vertex_index)
g.add_node(
make_node_id(flow_graph_address),
address=flow_graph_address,
type="function",
)
if datas or callees:
logger.info("%s @ 0x%X:", name, flow_graph_address)
for data_address in sorted(datas):
logger.info(" - 0x%X", data_address)
# TODO: check if this is already a function
g.add_node(
make_node_id(data_address),
address=data_address,
type="data",
)
g.add_edge(
make_node_id(flow_graph_address),
make_node_id(data_address),
key="reference",
)
for callee in sorted(callees):
logger.info(" - %s", idx.get_function_name_by_address(callee))
g.add_node(
make_node_id(callee),
address=callee,
type="function",
)
g.add_edge(
make_node_id(flow_graph_address),
make_node_id(callee),
key="call",
)
else:
logger.info("%s @ 0x%X: (none)", name, flow_graph_address)
# set ground truth node attributes from source data
for node, attrs in g.nodes(data=True):
if attrs["type"] != "function":
continue
if f := functions_by_address.get(attrs["address"]):
attrs["name"] = f.function_name
attrs["file"] = f.file_name
for section in pe.sections:
# Within each section, emit a neighbor edge for each pair of neighbors.
# Neighbors only link nodes of the same type, because assemblage doesn't
# have ground truth for data items, so we don't quite know where to split.
# Consider this situation:
#
# moduleA::func1
# --- cut ---
# moduleB::func1
#
# that one is ok, but this is hard:
#
# moduleA::func1
# --- cut??? ---
# dataZ
# --- or cut here??? ---
# moduleB::func1
#
# Does the cut go before or after dataZ?
# So, we only have neighbor graphs within functions, and within datas.
# For datas, we don't allow interspersed functions.
section_nodes = sorted(
[
(node, attrs)
for node, attrs in g.nodes(data=True)
if (section.VirtualAddress + base_address)
<= attrs["address"]
< (base_address + section.VirtualAddress + section.Misc_VirtualSize)
],
key=lambda p: p[1]["address"],
)
# add neighbor edges between data items.
# the data items must not be separated by any functions.
for i in range(1, len(section_nodes)):
a, a_attrs = section_nodes[i - 1]
b, b_attrs = section_nodes[i]
if a_attrs["type"] != "data":
continue
if b_attrs["type"] != "data":
continue
g.add_edge(a, b, key="neighbor")
g.add_edge(b, a, key="neighbor")
section_functions = [
(node, attrs)
for node, attrs in section_nodes
if attrs["type"] == "function"
# we only have ground truth for the known functions
# so only consider those in the function neighbor graph.
and attrs["address"] in functions_by_address
]
# add neighbor edges between functions.
# we drop the potentially interspersed data items before computing these edges.
for i in range(1, len(section_functions)):
a, a_attrs = section_functions[i - 1]
b, b_attrs = section_functions[i]
            # edge attribute: is_source_file_boundary
            # (the edge crosses a source file boundary when the neighboring functions come from different files)
            is_boundary = a_attrs["file"] != b_attrs["file"]
g.add_edge(a, b, key="neighbor", is_source_file_boundary=is_boundary)
g.add_edge(b, a, key="neighbor", is_source_file_boundary=is_boundary)
# rename unknown functions like: sub_401000
for n, attrs in g.nodes(data=True):
if attrs["type"] != "function":
continue
if "name" in attrs:
continue
attrs["name"] = f"sub_{attrs['address']:x}"
    # assign a human-readable repr to all nodes
# assign is_import=bool to functions
# assign is_string=bool to datas
for n, attrs in g.nodes(data=True):
match attrs["type"]:
case "function":
attrs["repr"] = attrs["name"]
attrs["is_import"] = "!" in attrs["name"]
case "data":
if string := read_string(address_space, attrs["address"]):
attrs["repr"] = json.dumps(string)
attrs["is_string"] = True
else:
attrs["repr"] = f"data_{attrs['address']:x}"
attrs["is_string"] = False
for line in nx.generate_gexf(g):
print(line)
# db.conn.close()
return 0
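# worker for generate_all: invoke this script's `generate` subcommand in a subprocess
# for a single binary and write the resulting GEXF to the output directory, skipping
# binaries that already have a graph on disk.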
def _worker(args):
assemblage_database: Path
assemblage_directory: Path
graph_file: Path
binary_id: int
(assemblage_database, assemblage_directory, graph_file, binary_id) = args
if graph_file.is_file():
return
logger.info("processing: %d", binary_id)
process = subprocess.run(
["python", __file__, "--debug", "generate", assemblage_database, assemblage_directory, str(binary_id)],
capture_output=True,
encoding="utf-8",
)
if process.returncode != 0:
logger.warning("failed: %d", binary_id)
logger.debug("%s", process.stderr)
return
graph_file.parent.mkdir(exist_ok=True)
graph = process.stdout
graph_file.write_text(graph)
def generate_all_main(args: argparse.Namespace) -> int:
if not args.assemblage_database.is_file():
raise ValueError("database doesn't exist")
db = Assemblage(args.assemblage_database, args.assemblage_directory)
binary_ids = list(db.get_binary_ids())
with Pool(args.num_workers) as p:
_ = list(
p.imap_unordered(
_worker,
(
(
args.assemblage_database,
args.assemblage_directory,
args.output_directory / str(binary_id) / "graph.gexf",
binary_id,
)
for binary_id in binary_ids
),
)
)
return 0
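# partition an existing graph into communities via Louvain community detection and
# print the members of each community (with their source files, where known).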
def cluster_main(args: argparse.Namespace) -> int:
if not args.graph.is_file():
raise ValueError("graph file doesn't exist")
g = nx.read_gexf(args.graph)
communities = nx.algorithms.community.louvain_communities(g)
for i, community in enumerate(communities):
print(f"[{i}]:")
for node in community:
if "name" in g.nodes[node]:
print(f" - {hex(int(node, 0))}: {g.nodes[node]['file']}")
else:
print(f" - {hex(int(node, 0))}")
return 0
# uv pip install torch --index-url https://download.pytorch.org/whl/cpu
# uv pip install torch-geometric pandas numpy scikit-learn
# import torch # do this on-demand below, because its slow
# from torch_geometric.data import HeteroData
@dataclass
class NodeType:
type: str
attributes: dict[str, Literal[False] | Literal[""] | Literal[0] | float]
@dataclass
class EdgeType:
key: str
source_type: NodeType
destination_type: NodeType
attributes: dict[str, Literal[False] | Literal[""] | Literal[0] | float]
NODE_TYPES = {
node.type: node
for node in [
NodeType(
type="function",
attributes={
"is_import": False,
"does_reference_string": False,
# "ground_truth": False,
# unused:
# - repr: str
# - address: int
# - name: str
# - file: str
},
),
NodeType(
type="data",
attributes={
"is_string": False,
# unused:
# - repr: str
# - address: int
},
),
]
}
FUNCTION_NODE = NODE_TYPES["function"]
DATA_NODE = NODE_TYPES["data"]
EDGE_TYPES = {
(edge.source_type.type, edge.key, edge.destination_type.type): edge
for edge in [
EdgeType(
key="call",
source_type=FUNCTION_NODE,
destination_type=FUNCTION_NODE,
attributes={},
),
EdgeType(
key="reference",
source_type=FUNCTION_NODE,
destination_type=DATA_NODE,
attributes={},
),
EdgeType(
# When functions reference other functions as data,
# such as passing a function pointer as a callback.
#
# Example:
# __scrt_set_unhandled_exception_filter > reference > __scrt_unhandled_exception_filter
key="reference",
source_type=FUNCTION_NODE,
destination_type=FUNCTION_NODE,
attributes={},
),
EdgeType(
key="neighbor",
source_type=FUNCTION_NODE,
destination_type=FUNCTION_NODE,
attributes={
# this is the attribute to predict (ultimately)
# "is_source_file_boundary": False,
"distance": 1,
},
),
EdgeType(
key="neighbor",
source_type=DATA_NODE,
destination_type=DATA_NODE,
# attributes={
# },
attributes={
# this is the attribute to predict (ultimately)
# "is_source_file_boundary": False,
"distance": 1,
},
),
]
}
@dataclass
class LoadedGraph:
data: "HeteroData"
# map from node type to:
# map from node id (str) to node index (int), and node index (int) to node id (str).
mapping: dict[str, dict[str | int, int | str]]
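# convert the networkx MultiDiGraph into a torch_geometric HeteroData object:
# assign a per-type index to each node, then build node feature tensors and
# per-edge-type edge_index/edge_attr tensors according to NODE_TYPES and EDGE_TYPES.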
def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
import torch
from torch_geometric.data import HeteroData
# Our networkx graph identifies nodes by str ("sha256:address").
# Torch identifies nodes by index, from 0 to #nodes, for each type of node.
# Map one to another.
node_indexes_by_node: dict[str, dict[str, int]] = {n: {} for n in NODE_TYPES.keys()}
# Because the types are different (str and int),
# here's a single mapping where the type of the key implies
# the sort of lookup you're doing (by index (int) or by node id (str)).
node_mapping: dict[str, dict[str | int, int | str]] = {n: {} for n in NODE_TYPES.keys()}
for node_type in NODE_TYPES.keys():
def is_this_node_type(node_attrs):
node, attrs = node_attrs
return attrs["type"] == node_type
ns = g.nodes(data=True)
ns = sorted(ns)
ns = filter(is_this_node_type, ns)
ns = map(lambda p: p[0], ns)
for i, node in enumerate(ns):
node_indexes_by_node[node_type][node] = i
node_mapping[node_type][node] = i
node_mapping[node_type][i] = node
data = HeteroData()
for node_type in NODE_TYPES.values():
logger.debug("loading nodes: %s", node_type.type)
node_indexes: list[int] = []
attr_values: dict[str, list] = {attribute: [] for attribute in node_type.attributes.keys()}
for node, attrs in g.nodes(data=True):
if attrs["type"] != node_type.type:
continue
node_index = node_indexes_by_node[node_type.type][node]
node_indexes.append(node_index)
for attribute, default_value in node_type.attributes.items():
value = attrs.get(attribute, default_value)
attr_values[attribute].append(value)
data[node_type.type].node_id = torch.tensor(node_indexes)
if attr_values:
# attribute order is implicit in the NODE_TYPES data model above.
data[node_type.type].x = torch.stack([torch.tensor(values) for values in attr_values.values()], dim=-1).float()
for edge_type in EDGE_TYPES.values():
logger.debug(
"loading edges: %s > %s > %s",
edge_type.source_type.type, edge_type.key, edge_type.destination_type.type
)
source_indexes: list[int] = []
destination_indexes: list[int] = []
attr_values: dict[str, list] = {attribute: [] for attribute in edge_type.attributes.keys()}
for source, destination, key, attrs in g.edges(data=True, keys=True):
if key != edge_type.key:
continue
if g.nodes[source]["type"] != edge_type.source_type.type:
continue
if g.nodes[destination]["type"] != edge_type.destination_type.type:
continue
# These are global node indexes
# but we need to provide the node type-local index.
# That is, functions have their own node indexes, 0 to N. data have their own node indexes, 0 to N.
source_index = node_indexes_by_node[g.nodes[source]["type"]][source]
destination_index = node_indexes_by_node[g.nodes[destination]["type"]][destination]
source_indexes.append(source_index)
destination_indexes.append(destination_index)
for attribute, default_value in edge_type.attributes.items():
value = attrs.get(attribute, default_value)
attr_values[attribute].append(value)
data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_index = torch.stack(
[
torch.tensor(source_indexes),
torch.tensor(destination_indexes),
]
)
if attr_values:
# attribute order is implicit in the EDGE_TYPES data model above.
data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_attr = torch.stack(
[torch.tensor(values) for values in attr_values.values()], dim=-1
).float()
return LoadedGraph(
data,
node_mapping,
)
def train_main(args: argparse.Namespace) -> int:
if not args.graph.is_file():
raise ValueError("graph file doesn't exist")
logger.debug("loading torch")
import torch
import random
import numpy as np
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
logger.debug("reading graph from disk")
g = nx.read_gexf(args.graph)
# Initial model: learn to find functions that reference a string.
#
# Once this works, then we can try a more complex model (edge features),
# and ultimately an edge classifier.
#
# Ground truth from existing patterns like:
#
# function > references > data (:is_string=True)
for a, b, key, attrs in g.edges(data=True, keys=True):
match (g.nodes[a]["type"], key, g.nodes[b]["type"]):
case ("function", "reference", "data"):
if g.nodes[b].get("is_string"):
g.nodes[a]["does_reference_string"] = True
logger.debug("%s > reference > %s (string)", g.nodes[a]["repr"], g.nodes[b]["repr"])
case ("function", "reference", "function"):
# The data model supports this.
# Like passing a function pointer as a callback
continue
case ("data", "reference", "data"):
# We don't support this.
continue
case ("data", "reference", "function"):
# We don't support this.
continue
case (_, "call", _):
continue
case (_, "neighbor", _):
continue
case _:
print(a, b, key, attrs, g.nodes[a], g.nodes[b])
raise ValueError("unexpected structure")
# map existing attributes to the ground_truth attribute
# for ease of updating the model/training.
for node, attrs in g.nodes(data=True):
if attrs["type"] != "function":
continue
attrs["ground_truth"] = attrs.get("does_reference_string", False)
logger.debug("loading graph into torch")
lg = load_graph(g)
data = lg.data
data['data'].y = torch.zeros(data['data'].num_nodes, dtype=torch.long)
data['function'].y = torch.zeros(data['function'].num_nodes, dtype=torch.long)
true_indices = []
for node, attrs in g.nodes(data=True):
if attrs.get("ground_truth"):
print("true: ", g.nodes[node]["repr"])
node_index = lg.mapping[attrs["type"]][node]
print("index", attrs["type"], node_index)
print(" ", node)
print(" ", lg.mapping[attrs["type"]][node_index])
true_indices.append(node_index)
# true_indices.append(data['function'].node_id[node_index].item())
# print("true index: ", node_index, data['function'].node_id[node_index].item())
data['function'].y[true_indices] = 1
print(data['function'].y)
# TODO
import torch_geometric.transforms as T
data = T.ToUndirected()(data)
# data = T.AddSelfLoops()(data)
data = T.NormalizeFeatures()(data)
print(data)
from torch_geometric.nn import RGCNConv, to_hetero, SAGEConv, Linear
import torch.nn.functional as F
class GNN(torch.nn.Module):
def __init__(self, hidden_channels, out_channels):
super().__init__()
self.conv1 = SAGEConv((-1, -1), hidden_channels)
self.conv2 = SAGEConv((-1, -1), hidden_channels)
self.lin = Linear(hidden_channels, out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index)
x = self.lin(x)
return x
model = GNN(hidden_channels=4, out_channels=2)
# metadata: tuple[list of node types, list of edge types (source, key, dest)]
model = to_hetero(model, data.metadata(), aggr='sum')
# model.print_readable()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
from sklearn.model_selection import train_test_split
train_nodes, test_nodes = train_test_split(
torch.arange(data['function'].num_nodes), test_size=0.2, random_state=42
)
train_mask = torch.zeros(data['function'].num_nodes, dtype=torch.bool)
# train_mask[train_nodes] = True
train_mask[:] = True
test_mask = torch.zeros(data['function'].num_nodes, dtype=torch.bool)
# test_mask[test_nodes] = True
test_mask[:] = True
data['function'].train_mask = train_mask
data['function'].test_mask = test_mask
logger.debug("training")
for epoch in range(999):
model.train()
optimizer.zero_grad()
# don't use edge attrs right now.
out = model(data.x_dict, data.edge_index_dict) # data.edge_attr_dict)
out_function = out['function']
y_function = data['function'].y
mask = data['function'].train_mask
# When classifying "function has string reference"
# there is a major class imbalance, because 95% of function's don't reference a string,
# so the model just learns to predict "no".
# Therefore, weight the classes so that a "yes" prediction is much more valuable.
class_counts = torch.bincount(data['function'].y[mask])
class_weights = 1.0 / class_counts.float()
class_weights = class_weights / class_weights.sum() * len(class_counts)
# CrossEntropyLoss(): the most common choice for node classification with mutually exclusive classes.
# BCEWithLogitsLoss(): multi-label node classification
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
loss = criterion(out_function[mask], y_function[mask])
loss.backward()
optimizer.step()
logger.info(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
if loss <= 0.0001:
logger.info("no more loss")
break
logger.debug("evaluating")
model.eval()
with torch.no_grad():
out = model(data.x_dict, data.edge_index_dict) # TODO: edge attrs
mask = data['function'].test_mask
pred = torch.argmax(out['function'][mask], dim=1)
truth = data['function'].y[mask].int()
print("pred", pred[:32])
print("truth", truth[:32])
# print("index", data['function'].node_id[mask])
# print("83: ", g.nodes[lg.mapping['function'][83]]['repr'])
accuracy = (pred == truth).float().mean()
# pred = (out[data['function'].test_mask] > 0).int().squeeze()
# accuracy = (pred == data['function'].y[data['function'].test_mask]).float().mean()
print(f'Accuracy: {accuracy:.4f}')
return 0
def main(argv=None) -> int:
if argv is None:
argv = sys.argv[1:]
parser = argparse.ArgumentParser(description="Identify object boundaries in compiled programs")
capa.main.install_common_args(parser, wanted={})
subparsers = parser.add_subparsers(title="subcommands", required=True)
generate_parser = subparsers.add_parser("generate", help="generate graph for a sample")
generate_parser.add_argument("assemblage_database", type=Path, help="path to Assemblage database")
generate_parser.add_argument("assemblage_directory", type=Path, help="path to Assemblage samples directory")
generate_parser.add_argument("binary_id", type=int, help="primary key of binary to inspect")
generate_parser.set_defaults(func=generate_main)
num_cores = os.cpu_count() or 1
default_workers = max(1, num_cores - 2)
generate_all_parser = subparsers.add_parser("generate_all", help="generate graphs for all samples")
generate_all_parser.add_argument("assemblage_database", type=Path, help="path to Assemblage database")
generate_all_parser.add_argument("assemblage_directory", type=Path, help="path to Assemblage samples directory")
generate_all_parser.add_argument("output_directory", type=Path, help="path to output directory")
generate_all_parser.add_argument(
"--num_workers", type=int, default=default_workers, help="number of workers to use"
)
generate_all_parser.set_defaults(func=generate_all_main)
cluster_parser = subparsers.add_parser("cluster", help="cluster an existing graph")
cluster_parser.add_argument("graph", type=Path, help="path to a graph file")
cluster_parser.set_defaults(func=cluster_main)
train_parser = subparsers.add_parser("train", help="train using an existing graph")
train_parser.add_argument("graph", type=Path, help="path to a graph file")
train_parser.set_defaults(func=train_main)
args = parser.parse_args(args=argv)
try:
capa.main.handle_common_args(args)
except capa.main.ShouldExitError as e:
return e.status_code
logging.getLogger("goblin.pe").setLevel(logging.WARNING)
return args.func(args)
if __name__ == "__main__":
sys.exit(main())