// Source: trivy/pkg/module/module.go (Go, 710 lines, 20 KiB)

package module
import (
"context"
"encoding/json"
"io"
"io/fs"
"os"
"path/filepath"
"regexp"
"github.com/liamg/memoryfs"
"github.com/mailru/easyjson"
"github.com/samber/lo"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental"
wasi "github.com/tetratelabs/wazero/wasi_snapshot_preview1"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/log"
tapi "github.com/aquasecurity/trivy/pkg/module/api"
"github.com/aquasecurity/trivy/pkg/module/serialize"
"github.com/aquasecurity/trivy/pkg/scanner/post"
"github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/utils"
)
var (
	// RelativeDir is the location of locally installed WASM modules,
	// relative to the user's home directory.
	RelativeDir = filepath.Join(".trivy", "modules")

	// exportFunctions are the host functions offered to every WASM module
	// through the "env" host module, keyed by the name the guest imports.
	// Each one logs a string read from the guest's memory at a given level.
	exportFunctions = map[string]any{
		"debug": logDebug,
		"info":  logInfo,
		"warn":  logWarn,
		"error": logError,
	}
)
// logDebug is exported to WASM modules as env.debug. It logs the size bytes at
// offset in the calling module's memory, as a string, at debug level.
func logDebug(ctx context.Context, m api.Module, offset, size uint32) {
	logFromMemory(ctx, m, offset, size, log.Logger.Debug)
}

// logInfo is exported to WASM modules as env.info. It logs the size bytes at
// offset in the calling module's memory, as a string, at info level.
func logInfo(ctx context.Context, m api.Module, offset, size uint32) {
	logFromMemory(ctx, m, offset, size, log.Logger.Info)
}

// logWarn is exported to WASM modules as env.warn. It logs the size bytes at
// offset in the calling module's memory, as a string, at warn level.
func logWarn(ctx context.Context, m api.Module, offset, size uint32) {
	logFromMemory(ctx, m, offset, size, log.Logger.Warn)
}

// logError is exported to WASM modules as env.error. It logs the size bytes at
// offset in the calling module's memory, as a string, at error level.
func logError(ctx context.Context, m api.Module, offset, size uint32) {
	logFromMemory(ctx, m, offset, size, log.Logger.Error)
}

// logFromMemory is the shared body of the log host functions above. It reads
// size bytes at offset from m's memory and passes them, as a string, to logFn.
// When the read fails, readMemory has already reported it and nothing else is
// logged.
func logFromMemory(ctx context.Context, m api.Module, offset, size uint32, logFn func(args ...any)) {
	if buf := readMemory(ctx, m, offset, size); buf != nil {
		logFn(string(buf))
	}
}
// readMemory returns the size bytes starting at offset in m's linear memory.
//
// It returns nil, after logging an error, when the module has no memory or
// when the requested range is out of bounds. Callers use nil as the failure
// signal.
//
// NOTE(review): wazero documents Memory.Read as returning a view of module
// memory rather than a copy. Every caller in this file converts or decodes it
// immediately; confirm this before retaining the slice anywhere.
func readMemory(ctx context.Context, m api.Module, offset, size uint32) []byte {
	mem := m.Memory()
	if mem == nil {
		// A module that defines no memory would otherwise cause a nil dereference here.
		log.Logger.Errorf("Memory.Read(%d, %d): module has no memory", offset, size)
		return nil
	}
	buf, ok := mem.Read(ctx, offset, size)
	if !ok {
		log.Logger.Errorf("Memory.Read(%d, %d) out of range", offset, size)
		return nil
	}
	return buf
}
// Manager owns the wazero runtime that executes WASM modules, together with
// the modules loaded from the local module directory.
type Manager struct {
	runtime wazero.Runtime // every module is compiled and instantiated through this runtime
	modules []*wasmModule  // filled by loadModules, in filepath.Walk order
}
// NewManager creates a wazero runtime configured for TinyGo-built modules and
// loads every ".wasm" file found under the module directory.
//
// When loading fails, the runtime is closed before the error is returned. The
// caller never receives the Manager in that case, so nothing else could close
// it.
func NewManager(ctx context.Context) (*Manager, error) {
	m := &Manager{}

	// The runtime must enable the following features because TinyGo uses these features to build.
	// cf. https://github.com/tinygo-org/tinygo/blob/b65447c7d567eea495805656f45472cc3c483e03/targets/wasi.json#L4
	c := wazero.NewRuntimeConfig().
		WithFeatureBulkMemoryOperations(true).
		WithFeatureNonTrappingFloatToIntConversion(true).
		WithFeatureSignExtensionOps(true)

	// Create a new WebAssembly Runtime.
	m.runtime = wazero.NewRuntimeWithConfig(c)

	// Load WASM modules in local
	if err := m.loadModules(ctx); err != nil {
		// Best-effort cleanup; the load error is the one worth reporting.
		_ = m.runtime.Close(ctx)
		return nil, xerrors.Errorf("module load error: %w", err)
	}
	return m, nil
}
// loadModules walks the module directory and instantiates every file with a
// ".wasm" extension, appending the resulting modules to m.modules.
//
// A missing module directory is not an error. A module built for a different
// module API version is skipped: newWASMPlugin returns (nil, nil) for it. Any
// other failure aborts the walk and is returned.
func (m *Manager) loadModules(ctx context.Context) error {
	moduleDir := dir()
	_, err := os.Stat(moduleDir)
	if os.IsNotExist(err) {
		return nil
	}
	log.Logger.Debugf("Module dir: %s", moduleDir)

	err = filepath.Walk(moduleDir, func(path string, info fs.FileInfo, err error) error {
		if err != nil {
			return err
		} else if info.IsDir() || filepath.Ext(info.Name()) != ".wasm" {
			return nil
		}

		rel, err := filepath.Rel(moduleDir, path)
		if err != nil {
			return xerrors.Errorf("failed to get a relative path: %w", err)
		}
		log.Logger.Infof("Loading %s...", rel)

		wasmCode, err := os.ReadFile(path)
		if err != nil {
			return xerrors.Errorf("file read error: %w", err)
		}

		p, err := newWASMPlugin(ctx, m.runtime, wasmCode)
		if err != nil {
			return xerrors.Errorf("WASM module init error %s: %w", rel, err)
		}
		// (nil, nil) means the module was skipped because of an API version
		// mismatch (already logged). Storing the nil would make Register()
		// panic later.
		if p == nil {
			return nil
		}

		m.modules = append(m.modules, p)
		return nil
	})
	if err != nil {
		return xerrors.Errorf("module walk error: %w", err)
	}
	return nil
}
// Register asks each loaded module to register itself as an analyzer and/or a
// post scanner, according to the capabilities it reported at load time.
func (m *Manager) Register() {
	for i := range m.modules {
		m.modules[i].Register()
	}
}
// Close shuts down the WebAssembly runtime owned by the manager and returns
// any error reported by the runtime.
func (m *Manager) Close(ctx context.Context) error {
	if err := m.runtime.Close(ctx); err != nil {
		return err
	}
	return nil
}
// splitPtrSize unpacks a value returned by a WASM function. The upper 32 bits
// hold a linear-memory offset and the lower 32 bits hold a byte length.
func splitPtrSize(u uint64) (uint32, uint32) {
	return uint32(u >> 32), uint32(u & 0xFFFFFFFF)
}
// ptrSizeToString reads the string referenced by a packed pointer/size value
// (see splitPtrSize) out of the module's linear memory.
func ptrSizeToString(ctx context.Context, m api.Module, ptrSize uint64) (string, error) {
	offset, length := splitPtrSize(ptrSize)
	if data := readMemory(ctx, m, offset, length); data != nil {
		return string(data), nil
	}
	return "", xerrors.New("unable to read memory")
}
// stringToPtrSize copies s into the module's linear memory. It returns the
// pointer (linear-memory offset) and size pair, both widened to uint64 so they
// can be passed directly as WebAssembly numeric arguments.
//
// The buffer comes from the module-exported malloc and is NOT freed here. The
// caller must release it with the module's free export.
func stringToPtrSize(ctx context.Context, s string, mod api.Module, malloc api.Function) (uint64, uint64, error) {
	if malloc == nil {
		return 0, 0, xerrors.New("malloc() is not exported")
	}
	size := uint64(len(s))
	results, err := malloc.Call(ctx, size)
	if err != nil {
		return 0, 0, xerrors.Errorf("malloc error: %w", err)
	} else if len(results) != 1 {
		// Guard the results[0] access below against an unexpected malloc signature.
		return 0, 0, xerrors.New("invalid signature: malloc")
	}

	// The pointer is a linear memory offset, which is where we write the string.
	ptr := results[0]
	if !mod.Memory().Write(ctx, uint32(ptr), []byte(s)) {
		return 0, 0, xerrors.Errorf("Memory.Write(%d, %d) out of range of memory size %d",
			ptr, size, mod.Memory().Size(ctx))
	}
	return ptr, size, nil
}
// unmarshal decodes the JSON document referenced by a packed pointer/size
// value (see splitPtrSize) in the module's linear memory into v.
func unmarshal(ctx context.Context, m api.Module, ptrSize uint64, v any) error {
	offset, length := splitPtrSize(ptrSize)
	if raw := readMemory(ctx, m, offset, length); raw != nil {
		if err := json.Unmarshal(raw, v); err != nil {
			return xerrors.Errorf("unmarshal error: %w", err)
		}
		return nil
	}
	return xerrors.New("unable to read memory")
}
// marshal encodes v as JSON with easyjson and copies the bytes into the
// module's linear memory. It returns the pointer (linear-memory offset) and
// size pair as uint64 WebAssembly arguments.
//
// The buffer comes from the module-exported malloc and is NOT freed here. The
// caller must release it with the module's free export.
func marshal(ctx context.Context, m api.Module, malloc api.Function, v easyjson.Marshaler) (uint64, uint64, error) {
	if malloc == nil {
		return 0, 0, xerrors.New("malloc() is not exported")
	}
	b, err := easyjson.Marshal(v)
	if err != nil {
		return 0, 0, xerrors.Errorf("marshal error: %w", err)
	}

	size := uint64(len(b))
	results, err := malloc.Call(ctx, size)
	if err != nil {
		return 0, 0, xerrors.Errorf("malloc error: %w", err)
	} else if len(results) != 1 {
		// Guard the results[0] access below against an unexpected malloc signature.
		return 0, 0, xerrors.New("invalid signature: malloc")
	}

	// The pointer is a linear memory offset, which is where we write the marshaled value.
	ptr := results[0]
	if !m.Memory().Write(ctx, uint32(ptr), b) {
		return 0, 0, xerrors.Errorf("Memory.Write(%d, %d) out of range of memory size %d",
			ptr, size, m.Memory().Size(ctx))
	}
	return ptr, size, nil
}
// wasmModule is one instantiated WASM module. It holds the metadata the module
// reported at load time and the exported functions the host calls into.
type wasmModule struct {
	mod api.Module // the instantiated module; released by Close

	name    string // value returned by the module's name() export
	version int    // value returned by the module's version() export

	// requiredFiles holds the compiled patterns returned by required().
	// It is only populated for analyzers.
	requiredFiles []*regexp.Regexp
	isAnalyzer    bool // is_analyzer() returned a non-zero value
	isPostScanner bool // is_post_scanner() returned a non-zero value
	// postScanSpec is the decoded result of post_scan_spec(). It is only
	// populated for post scanners.
	postScanSpec serialize.PostScanSpec

	// Exported functions
	analyze  api.Function
	postScan api.Function
	malloc   api.Function // TinyGo specific
	free     api.Function // TinyGo specific
}
// newWASMPlugin instantiates one WASM module and collects the metadata and
// exported functions the host needs to call it.
//
// Each module gets its own namespace, so that multiple modules do not
// conflict. The namespace contains:
//   - an "env" host module, which exports the log functions and a memory named "mem";
//   - WASI.
//
// Instantiation runs the module's "_start" function, which is TinyGo's main.
//
// It returns (nil, nil) when the module was built for a module API version
// other than tapi.Version; the mismatch is logged and the module is skipped.
// On that path, and on every error after instantiation, the instantiated
// module is closed before returning.
func newWASMPlugin(ctx context.Context, r wazero.Runtime, code []byte) (*wasmModule, error) {
	// Unlike the defaults (which discard stdout and provide no file system),
	// forward the module's stdout to ours and mount an empty in-memory file system.
	config := wazero.NewModuleConfig().WithStdout(os.Stdout).WithFS(memoryfs.New())

	// Create an empty namespace so that multiple modules will not conflict
	ns := r.NewNamespace(ctx)

	// Instantiate a Go-defined module named "env" that exports functions.
	_, err := r.NewModuleBuilder("env").
		ExportMemory("mem", 100).
		ExportFunctions(exportFunctions).
		Instantiate(ctx, ns)
	if err != nil {
		return nil, xerrors.Errorf("wasm module build error: %w", err)
	}

	if _, err = wasi.NewBuilder(r).Instantiate(ctx, ns); err != nil {
		return nil, xerrors.Errorf("WASI init error: %w", err)
	}

	// Compile the WebAssembly module using the default configuration.
	compiled, err := r.CompileModule(ctx, code, wazero.NewCompileConfig())
	if err != nil {
		return nil, xerrors.Errorf("module compile error: %w", err)
	}

	// InstantiateModule runs the "_start" function which is what TinyGo compiles "main" to.
	mod, err := ns.InstantiateModule(ctx, compiled, config)
	if err != nil {
		return nil, xerrors.Errorf("module init error: %w", err)
	}

	// From here on, close the module on every path that does not hand it over
	// to the returned *wasmModule.
	keepModule := false
	defer func() {
		if !keepModule {
			_ = mod.Close(ctx) // best-effort cleanup; the original result is returned unchanged
		}
	}()

	// These are undocumented, but exported. See tinygo-org/tinygo#2788
	// TODO: improve TinyGo specific code
	malloc := mod.ExportedFunction("malloc")
	free := mod.ExportedFunction("free")

	// Get a module name
	name, err := moduleName(ctx, mod)
	if err != nil {
		return nil, xerrors.Errorf("failed to get a module name: %w", err)
	}

	// Get a module version
	version, err := moduleVersion(ctx, mod)
	if err != nil {
		return nil, xerrors.Errorf("failed to get a module version: %w", err)
	}

	// Get a module API version
	apiVersion, err := moduleAPIVersion(ctx, mod)
	if err != nil {
		return nil, xerrors.Errorf("failed to get a module API version: %w", err)
	}

	if apiVersion != tapi.Version {
		log.Logger.Infof("Ignore %s@v%d module due to API version mismatch, got: %d, want: %d",
			name, version, apiVersion, tapi.Version)
		return nil, nil
	}

	isAnalyzer, err := moduleIsAnalyzer(ctx, mod)
	if err != nil {
		return nil, xerrors.Errorf("failed to check if the module is an analyzer: %w", err)
	}

	isPostScanner, err := moduleIsPostScanner(ctx, mod)
	if err != nil {
		return nil, xerrors.Errorf("failed to check if the module is a post scanner: %w", err)
	}

	// Analyze and PostScan pass their arguments through malloc/free. Reject the
	// module here rather than panicking on a nil function at call time.
	if (isAnalyzer || isPostScanner) && (malloc == nil || free == nil) {
		return nil, xerrors.New("malloc() and free() must be exported")
	}

	// Get exported functions by WASM module
	analyzeFunc := mod.ExportedFunction("analyze")
	if analyzeFunc == nil {
		return nil, xerrors.New("analyze() must be exported")
	}
	postScanFunc := mod.ExportedFunction("post_scan")
	if postScanFunc == nil {
		return nil, xerrors.New("post_scan() must be exported")
	}

	var requiredFiles []*regexp.Regexp
	if isAnalyzer {
		// Get required files
		requiredFiles, err = moduleRequiredFiles(ctx, mod)
		if err != nil {
			return nil, xerrors.Errorf("failed to get required files: %w", err)
		}
	}

	var postScanSpec serialize.PostScanSpec
	if isPostScanner {
		// This spec defines how the module works in post scanning like INSERT, UPDATE and DELETE.
		postScanSpec, err = modulePostScanSpec(ctx, mod)
		if err != nil {
			return nil, xerrors.Errorf("failed to get a post scan spec: %w", err)
		}
	}

	keepModule = true
	return &wasmModule{
		mod:           mod,
		name:          name,
		version:       version,
		requiredFiles: requiredFiles,
		isAnalyzer:    isAnalyzer,
		isPostScanner: isPostScanner,
		postScanSpec:  postScanSpec,
		analyze:       analyzeFunc,
		postScan:      postScanFunc,
		malloc:        malloc,
		free:          free,
	}, nil
}
// Register adds the module to the analyzer registry and/or the post-scanner
// registry, depending on the capabilities it reported at load time.
func (m *wasmModule) Register() {
	id, ver := m.name, m.version
	log.Logger.Infof("Registering WASM module: %s@v%d", id, ver)

	if m.isAnalyzer {
		log.Logger.Debugf("Registering custom analyzer in %s@v%d", id, ver)
		analyzer.RegisterAnalyzer(m)
	}

	if m.isPostScanner {
		log.Logger.Debugf("Registering custom post scanner in %s@v%d", id, ver)
		post.RegisterPostScanner(m)
	}
}
// Close releases the instantiated WASM module.
func (m *wasmModule) Close(ctx context.Context) error {
	instance := m.mod
	return instance.Close(ctx)
}

// Type returns the analyzer type, which is the module's self-reported name.
func (m *wasmModule) Type() analyzer.Type {
	t := analyzer.Type(m.name)
	return t
}

// Name returns the name reported by the module's name() export.
func (m *wasmModule) Name() string {
	name := m.name
	return name
}

// Version returns the version reported by the module's version() export.
func (m *wasmModule) Version() int {
	v := m.version
	return v
}
// Required reports whether filePath matches at least one of the regular
// expressions the module returned from its required() export. It stops at the
// first match.
func (m *wasmModule) Required(filePath string, _ os.FileInfo) bool {
	matched := false
	for i := 0; i < len(m.requiredFiles) && !matched; i++ {
		matched = m.requiredFiles[i].MatchString(filePath)
	}
	return matched
}
// Analyze runs the module's analyze() export against a single file.
//
// A fresh in-memory file system containing only that file is built for each
// call. The file lives at "/" + the slash-separated input path, and the FS is
// attached to ctx for the duration of the call. The path itself is passed to
// analyze() as a (ptr, size) string in module memory. analyze() must return a
// single packed pointer/size value that references JSON, which is decoded
// into an AnalysisResult.
func (m *wasmModule) Analyze(ctx context.Context, input analyzer.AnalysisInput) (*analyzer.AnalysisResult, error) {
	filePath := "/" + filepath.ToSlash(input.FilePath)
	log.Logger.Debugf("Module %s: analyzing %s...", m.name, filePath)

	memfs := memoryfs.New()
	if err := memfs.MkdirAll(filepath.Dir(filePath), fs.ModePerm); err != nil {
		return nil, xerrors.Errorf("memory fs mkdir error: %w", err)
	}
	// The file is registered with a callback that yields input.Content instead
	// of bytes copied here.
	// NOTE(review): presumably consumed when the module opens the file; confirm memoryfs semantics.
	err := memfs.WriteLazyFile(filePath, func() (io.Reader, error) {
		return input.Content, nil
	}, fs.ModePerm)
	if err != nil {
		return nil, xerrors.Errorf("memory fs write error: %w", err)
	}

	// Pass memory fs to the analyze() function
	ctx, closer := experimental.WithFS(ctx, memfs)
	defer closer.Close(ctx)

	// Copy the path into module memory; the buffer is allocated with the module's malloc.
	inputPtr, inputSize, err := stringToPtrSize(ctx, filePath, m.mod, m.malloc)
	if err != nil {
		return nil, xerrors.Errorf("failed to write string to memory: %w", err)
	}
	// Deferred calls run in reverse order, so free() runs before the FS closer above.
	defer m.free.Call(ctx, inputPtr) // nolint: errcheck

	analyzeRes, err := m.analyze.Call(ctx, inputPtr, inputSize)
	if err != nil {
		return nil, xerrors.Errorf("analyze error: %w", err)
	} else if len(analyzeRes) != 1 {
		return nil, xerrors.New("invalid signature: analyze")
	}

	// The single return value packs a pointer/size pair referencing JSON.
	var result analyzer.AnalysisResult
	if err = unmarshal(ctx, m.mod, analyzeRes[0], &result); err != nil {
		return nil, xerrors.Errorf("invalid return value: %w", err)
	}
	return &result, nil
}
// PostScan performs post scanning
// e.g. Remove a vulnerability, change severity, etc.
//
// The module's post_scan() export receives a JSON array. Its first element is
// the first custom-class result in results, or an empty result when there is
// none. For Update and Delete actions, the results containing the IDs from the
// module's post-scan spec follow. The module's response is merged back into
// results according to the spec's action:
//   - Insert: returned results are appended, except custom-class or class-less ones;
//   - Update: matching vulnerabilities/misconfigurations are updated in place;
//   - Delete: matching vulnerabilities/misconfigurations are removed.
func (m *wasmModule) PostScan(ctx context.Context, results types.Results) (types.Results, error) {
	action := m.postScanSpec.Action

	// Find custom resources
	var custom serialize.Result
	for i := range results {
		if results[i].Class == types.ClassCustom {
			custom = serialize.Result(results[i])
			break
		}
	}

	arg := serialize.Results{custom}
	if action == tapi.ActionUpdate || action == tapi.ActionDelete {
		// Pass the relevant results to the module
		arg = append(arg, findIDs(m.postScanSpec.IDs, results)...)
	}

	// Marshal the argument into WASM memory so that the WASM module can read it.
	inputPtr, inputSize, err := marshal(ctx, m.mod, m.malloc, arg)
	if err != nil {
		return nil, xerrors.Errorf("post scan marshal error: %w", err)
	}
	defer m.free.Call(ctx, inputPtr) //nolint: errcheck

	scanRes, err := m.postScan.Call(ctx, inputPtr, inputSize)
	switch {
	case err != nil:
		return nil, xerrors.Errorf("post scan invocation error: %w", err)
	case len(scanRes) != 1:
		return nil, xerrors.New("invalid signature: post_scan")
	}

	var got types.Results
	if err = unmarshal(ctx, m.mod, scanRes[0], &got); err != nil {
		return nil, xerrors.Errorf("post scan unmarshal error: %w", err)
	}

	switch action {
	case tapi.ActionInsert:
		inserted := lo.Filter(got, func(r types.Result, _ int) bool {
			return r.Class != types.ClassCustom && r.Class != ""
		})
		results = append(results, inserted...)
	case tapi.ActionUpdate:
		updateResults(got, results)
	case tapi.ActionDelete:
		deleteResults(got, results)
	}
	return results, nil
}
// findIDs builds, for each non-custom result, a copy restricted to the
// vulnerabilities and misconfigurations whose IDs appear in ids. Only Target,
// Class and Type are carried over. Results with no matching entries are
// omitted.
func findIDs(ids []string, results types.Results) serialize.Results {
	wanted := func(id string) bool { return slices.Contains(ids, id) }

	var filtered serialize.Results
	for _, res := range results {
		if res.Class == types.ClassCustom {
			continue
		}

		vulns := lo.Filter(res.Vulnerabilities, func(v types.DetectedVulnerability, _ int) bool {
			return wanted(v.VulnerabilityID)
		})
		misconfs := lo.Filter(res.Misconfigurations, func(mc types.DetectedMisconfiguration, _ int) bool {
			return wanted(mc.ID)
		})
		if len(vulns) == 0 && len(misconfs) == 0 {
			continue
		}

		filtered = append(filtered, serialize.Result{
			Target:            res.Target,
			Class:             res.Class,
			Type:              res.Type,
			Vulnerabilities:   vulns,
			Misconfigurations: misconfs,
		})
	}
	return filtered
}
// updateResults applies the post_scan() output in gotResults to results, in place.
//
// A returned result is matched to an existing one by Target, Class and Type.
// Within a matched result:
//   - a vulnerability whose VulnerabilityID, PkgName, PkgPath and
//     InstalledVersion all equal a returned one takes the returned
//     SeveritySource and Vulnerability details;
//   - a misconfiguration whose ID, CauseMetadata.StartLine and
//     CauseMetadata.EndLine equal a returned one takes the returned
//     CauseMetadata, Severity and Status.
//
// If several returned entries match, the last one wins. Returned entries that
// match nothing are ignored; nothing is added.
func updateResults(gotResults, results types.Results) {
	for _, g := range gotResults {
		for i, result := range results {
			if g.Target == result.Target && g.Class == result.Class && g.Type == result.Type {
				// The slices are rebuilt with lo.Map and assigned back through results[i].
				results[i].Vulnerabilities = lo.Map(result.Vulnerabilities, func(v types.DetectedVulnerability, _ int) types.DetectedVulnerability {
					// Update vulnerabilities in the existing result
					for _, got := range g.Vulnerabilities {
						if got.VulnerabilityID == v.VulnerabilityID && got.PkgName == v.PkgName &&
							got.PkgPath == v.PkgPath && got.InstalledVersion == v.InstalledVersion {
							// Override vulnerability details
							v.SeveritySource = got.SeveritySource
							v.Vulnerability = got.Vulnerability
						}
					}
					return v
				})

				results[i].Misconfigurations = lo.Map(result.Misconfigurations, func(m types.DetectedMisconfiguration, _ int) types.DetectedMisconfiguration {
					// Update misconfigurations in the existing result
					for _, got := range g.Misconfigurations {
						if got.ID == m.ID &&
							got.CauseMetadata.StartLine == m.CauseMetadata.StartLine &&
							got.CauseMetadata.EndLine == m.CauseMetadata.EndLine {
							// Override misconfiguration details
							m.CauseMetadata = got.CauseMetadata
							m.Severity = got.Severity
							m.Status = got.Status
						}
					}
					return m
				})
			}
		}
	}
}
// deleteResults removes from results, in place, every vulnerability and
// misconfiguration that post_scan() returned in gotResults. The results
// themselves are kept; only their entries are filtered.
//
// Matching keys:
//   - results: Target, Class and Type;
//   - vulnerabilities: VulnerabilityID, PkgName, PkgPath and InstalledVersion;
//   - misconfigurations: ID, Status, CauseMetadata.StartLine and CauseMetadata.EndLine.
//
// Unlike updateResults, Status is part of the misconfiguration key here.
func deleteResults(gotResults, results types.Results) {
	for _, gotResult := range gotResults {
		for i, result := range results {
			// Remove vulnerabilities and misconfigurations from the existing result
			if gotResult.Target == result.Target && gotResult.Class == result.Class && gotResult.Type == result.Type {
				results[i].Vulnerabilities = lo.Reject(result.Vulnerabilities, func(v types.DetectedVulnerability, _ int) bool {
					for _, got := range gotResult.Vulnerabilities {
						if got.VulnerabilityID == v.VulnerabilityID && got.PkgName == v.PkgName &&
							got.PkgPath == v.PkgPath && got.InstalledVersion == v.InstalledVersion {
							return true
						}
					}
					return false
				})

				results[i].Misconfigurations = lo.Reject(result.Misconfigurations, func(v types.DetectedMisconfiguration, _ int) bool {
					for _, got := range gotResult.Misconfigurations {
						if got.ID == v.ID && got.Status == v.Status &&
							got.CauseMetadata.StartLine == v.CauseMetadata.StartLine &&
							got.CauseMetadata.EndLine == v.CauseMetadata.EndLine {
							return true
						}
					}
					return false
				})
			}
		}
	}
}
// moduleName calls the module's name() export and decodes the returned packed
// pointer/size value as a string.
func moduleName(ctx context.Context, mod api.Module) (string, error) {
	fn := mod.ExportedFunction("name")
	if fn == nil {
		return "", xerrors.New("name() must be exported")
	}

	res, err := fn.Call(ctx)
	switch {
	case err != nil:
		return "", xerrors.Errorf("wasm function name() invocation error: %w", err)
	case len(res) != 1:
		return "", xerrors.New("invalid signature: name()")
	}

	name, err := ptrSizeToString(ctx, mod, res[0])
	if err != nil {
		return "", xerrors.Errorf("invalid return value: %w", err)
	}
	return name, nil
}
// moduleVersion calls the module's version() export and returns its single
// numeric result as an int.
func moduleVersion(ctx context.Context, mod api.Module) (int, error) {
	fn := mod.ExportedFunction("version")
	if fn == nil {
		return 0, xerrors.New("version() must be exported")
	}

	res, err := fn.Call(ctx)
	switch {
	case err != nil:
		return 0, xerrors.Errorf("wasm function version() invocation error: %w", err)
	case len(res) != 1:
		return 0, xerrors.New("invalid signature: version")
	default:
		return int(res[0]), nil
	}
}
// moduleAPIVersion calls the module's api_version() export and returns its
// single numeric result as an int. The caller compares it against tapi.Version.
func moduleAPIVersion(ctx context.Context, mod api.Module) (int, error) {
	fn := mod.ExportedFunction("api_version")
	if fn == nil {
		return 0, xerrors.New("api_version() must be exported")
	}

	res, err := fn.Call(ctx)
	switch {
	case err != nil:
		return 0, xerrors.Errorf("wasm function api_version() invocation error: %w", err)
	case len(res) != 1:
		return 0, xerrors.New("invalid signature: api_version")
	default:
		return int(res[0]), nil
	}
}
// moduleRequiredFiles calls the module's required() export and compiles the
// patterns it returns. The export must return a packed pointer/size value
// referencing a JSON array of regular expressions. wasmModule.Required matches
// file paths against the compiled patterns.
//
// An error is returned when the export is missing or fails, when its result
// cannot be decoded, or when any pattern does not compile. In the last case
// the error includes the offending pattern.
func moduleRequiredFiles(ctx context.Context, mod api.Module) ([]*regexp.Regexp, error) {
	requiredFilesFunc := mod.ExportedFunction("required")
	if requiredFilesFunc == nil {
		return nil, xerrors.New("required() must be exported")
	}

	requiredFilesRes, err := requiredFilesFunc.Call(ctx)
	if err != nil {
		return nil, xerrors.Errorf("wasm function required() invocation error: %w", err)
	} else if len(requiredFilesRes) != 1 {
		// Name the export that was actually called.
		return nil, xerrors.New("invalid signature: required")
	}

	var fileRegexps serialize.StringSlice
	if err = unmarshal(ctx, mod, requiredFilesRes[0], &fileRegexps); err != nil {
		return nil, xerrors.Errorf("invalid return value: %w", err)
	}

	var requiredFiles []*regexp.Regexp
	for _, file := range fileRegexps {
		re, err := regexp.Compile(file)
		if err != nil {
			return nil, xerrors.Errorf("regexp compile error (%q): %w", file, err)
		}
		requiredFiles = append(requiredFiles, re)
	}
	return requiredFiles, nil
}
func moduleIsAnalyzer(ctx context.Context, mod api.Module) (bool, error) {
return isType(ctx, mod, "is_analyzer")
}
func moduleIsPostScanner(ctx context.Context, mod api.Module) (bool, error) {
return isType(ctx, mod, "is_post_scanner")
}
// isType calls the zero-argument export called name and reports whether its
// single result is non-zero. It backs capability flags such as is_analyzer()
// and is_post_scanner().
func isType(ctx context.Context, mod api.Module, name string) (bool, error) {
	fn := mod.ExportedFunction(name)
	if fn == nil {
		return false, xerrors.Errorf("%s() must be exported", name)
	}

	res, err := fn.Call(ctx)
	switch {
	case err != nil:
		return false, xerrors.Errorf("wasm function %s() invocation error: %w", name, err)
	case len(res) != 1:
		return false, xerrors.Errorf("invalid signature: %s", name)
	default:
		return res[0] != 0, nil
	}
}
// dir returns the directory modules are loaded from: RelativeDir under the
// user's home directory.
func dir() string {
	home := utils.HomeDir()
	return filepath.Join(home, RelativeDir)
}
// modulePostScanSpec calls the module's post_scan_spec() export and decodes
// the JSON it references into a PostScanSpec. On any error, the zero
// PostScanSpec is returned.
func modulePostScanSpec(ctx context.Context, mod api.Module) (serialize.PostScanSpec, error) {
	fn := mod.ExportedFunction("post_scan_spec")
	if fn == nil {
		return serialize.PostScanSpec{}, xerrors.New("post_scan_spec() must be exported")
	}

	res, err := fn.Call(ctx)
	switch {
	case err != nil:
		return serialize.PostScanSpec{}, xerrors.Errorf("wasm function post_scan_spec() invocation error: %w", err)
	case len(res) != 1:
		return serialize.PostScanSpec{}, xerrors.New("invalid signature: post_scan_spec")
	}

	var spec serialize.PostScanSpec
	if err = unmarshal(ctx, mod, res[0], &spec); err != nil {
		// Discard any partially decoded spec.
		return serialize.PostScanSpec{}, xerrors.Errorf("invalid return value: %w", err)
	}
	return spec, nil
}