refactor: add hook interface for extended functionality (#8585)

This commit is contained in:
Teppei Fukuda
2025-04-08 15:49:16 +04:00
committed by GitHub
parent 9dcd06fda7
commit a0dc3b688e
14 changed files with 795 additions and 198 deletions

View File

@@ -3,12 +3,12 @@
package integration
import (
"github.com/aquasecurity/trivy/pkg/types"
"path/filepath"
"testing"
"github.com/aquasecurity/trivy/pkg/extension"
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/scan/post"
"github.com/aquasecurity/trivy/pkg/types"
)
func TestModule(t *testing.T) {
@@ -52,7 +52,7 @@ func TestModule(t *testing.T) {
t.Cleanup(func() {
analyzer.DeregisterAnalyzer("spring4shell")
post.DeregisterPostScanner("spring4shell")
extension.DeregisterHook("spring4shell")
})
// Run Trivy

96
internal/hooktest/hook.go Normal file
View File

@@ -0,0 +1,96 @@
package hooktest
import (
"context"
"errors"
"testing"
"github.com/aquasecurity/trivy/pkg/extension"
"github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/types"
)
type testHook struct{}
func (*testHook) Name() string {
return "test"
}
func (*testHook) Version() int {
return 1
}
// RunHook implementation
func (*testHook) PreRun(ctx context.Context, opts flag.Options) error {
if opts.GlobalOptions.ConfigFile == "bad-config" {
return errors.New("bad pre-run")
}
return nil
}
func (*testHook) PostRun(ctx context.Context, opts flag.Options) error {
if opts.GlobalOptions.ConfigFile == "bad-config" {
return errors.New("bad post-run")
}
return nil
}
// ScanHook implementation
func (*testHook) PreScan(ctx context.Context, target *types.ScanTarget, options types.ScanOptions) error {
if target.Name == "bad-pre" {
return errors.New("bad pre-scan")
}
target.Name += " (pre-scan)"
return nil
}
func (*testHook) PostScan(ctx context.Context, results types.Results) (types.Results, error) {
for i, r := range results {
if r.Target == "bad" {
return nil, errors.New("bad")
}
for j := range r.Vulnerabilities {
results[i].Vulnerabilities[j].References = []string{
"https://example.com/post-scan",
}
}
}
return results, nil
}
// ReportHook implementation
func (*testHook) PreReport(ctx context.Context, report *types.Report, opts flag.Options) error {
if report.ArtifactName == "bad-report" {
return errors.New("bad pre-report")
}
// Modify the report
for i := range report.Results {
for j := range report.Results[i].Vulnerabilities {
report.Results[i].Vulnerabilities[j].Title = "Modified by pre-report hook"
}
}
return nil
}
func (*testHook) PostReport(ctx context.Context, report *types.Report, opts flag.Options) error {
if report.ArtifactName == "bad-report" {
return errors.New("bad post-report")
}
// Modify the report
for i := range report.Results {
for j := range report.Results[i].Vulnerabilities {
report.Results[i].Vulnerabilities[j].Description = "Modified by post-report hook"
}
}
return nil
}
func Init(t *testing.T) {
h := &testHook{}
extension.RegisterHook(h)
t.Cleanup(func() {
extension.DeregisterHook(h.Name())
})
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/aquasecurity/trivy/pkg/cache"
"github.com/aquasecurity/trivy/pkg/commands/operation"
"github.com/aquasecurity/trivy/pkg/db"
"github.com/aquasecurity/trivy/pkg/extension"
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/fanal/artifact"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
@@ -277,7 +278,6 @@ func (r *runner) Report(ctx context.Context, opts flag.Options, report types.Rep
if err := pkgReport.Write(ctx, report, opts); err != nil {
return xerrors.Errorf("unable to write results: %w", err)
}
return nil
}
@@ -375,12 +375,32 @@ func Run(ctx context.Context, opts flag.Options, targetKind TargetKind) (err err
return v.SafeWriteConfigAs("trivy-default.yaml")
}
// Call pre-run hooks
if err := extension.PreRun(ctx, opts); err != nil {
return xerrors.Errorf("pre run error: %w", err)
}
// Run the application
report, err := run(ctx, opts, targetKind)
if err != nil {
return xerrors.Errorf("run error: %w", err)
}
// Call post-run hooks
if err := extension.PostRun(ctx, opts); err != nil {
return xerrors.Errorf("post run error: %w", err)
}
return operation.Exit(opts, report.Results.Failed(), report.Metadata)
}
func run(ctx context.Context, opts flag.Options, targetKind TargetKind) (types.Report, error) {
r, err := NewRunner(ctx, opts)
if err != nil {
if errors.Is(err, SkipScan) {
return nil
return types.Report{}, nil
}
return xerrors.Errorf("init error: %w", err)
return types.Report{}, xerrors.Errorf("init error: %w", err)
}
defer r.Close(ctx)
@@ -395,24 +415,27 @@ func Run(ctx context.Context, opts flag.Options, targetKind TargetKind) (err err
scanFunction, exists := scans[targetKind]
if !exists {
return xerrors.Errorf("unknown target kind: %s", targetKind)
return types.Report{}, xerrors.Errorf("unknown target kind: %s", targetKind)
}
// 1. Scan the artifact
report, err := scanFunction(ctx, opts)
if err != nil {
return xerrors.Errorf("%s scan error: %w", targetKind, err)
return types.Report{}, xerrors.Errorf("%s scan error: %w", targetKind, err)
}
// 2. Filter the results
report, err = r.Filter(ctx, opts, report)
if err != nil {
return xerrors.Errorf("filter error: %w", err)
return types.Report{}, xerrors.Errorf("filter error: %w", err)
}
// 3. Report the results
if err = r.Report(ctx, opts, report); err != nil {
return xerrors.Errorf("report error: %w", err)
return types.Report{}, xerrors.Errorf("report error: %w", err)
}
return operation.Exit(opts, report.Results.Failed(), report.Metadata)
return report, nil
}
func disabledAnalyzers(opts flag.Options) []analyzer.Type {

162
pkg/extension/hook.go Normal file
View File

@@ -0,0 +1,162 @@
package extension
import (
"context"
"sort"
"github.com/samber/lo"
"golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/types"
)
var hooks = make(map[string]Hook)
func RegisterHook(s Hook) {
// Avoid duplication
hooks[s.Name()] = s
}
func DeregisterHook(name string) {
delete(hooks, name)
}
// Hook is an interface that defines the methods for a hook.
type Hook interface {
// Name returns the name of the extension.
Name() string
}
// RunHook is a extension that is called before and after all the processes.
type RunHook interface {
Hook
// PreRun is called before all the processes.
PreRun(ctx context.Context, opts flag.Options) error
// PostRun is called after all the processes.
PostRun(ctx context.Context, opts flag.Options) error
}
// ScanHook is a extension that is called before and after the scan.
type ScanHook interface {
Hook
// PreScan is called before the scan. It can modify the scan target.
// It may be called on the server side in client/server mode.
PreScan(ctx context.Context, target *types.ScanTarget, opts types.ScanOptions) error
// PostScan is called after the scan. It can modify the results.
// It may be called on the server side in client/server mode.
// NOTE: Wasm modules cannot directly modify the passed results,
// so it returns a copy of the results.
PostScan(ctx context.Context, results types.Results) (types.Results, error)
}
// ReportHook is a extension that is called before and after the report is written.
type ReportHook interface {
Hook
// PreReport is called before the report is written.
// It can modify the report. It is called on the client side.
PreReport(ctx context.Context, report *types.Report, opts flag.Options) error
// PostReport is called after the report is written.
// It can modify the report. It is called on the client side.
PostReport(ctx context.Context, report *types.Report, opts flag.Options) error
}
func PreRun(ctx context.Context, opts flag.Options) error {
for _, e := range Hooks() {
h, ok := e.(RunHook)
if !ok {
continue
}
if err := h.PreRun(ctx, opts); err != nil {
return xerrors.Errorf("%s pre run error: %w", e.Name(), err)
}
}
return nil
}
// PostRun is a hook that is called after all the processes.
func PostRun(ctx context.Context, opts flag.Options) error {
for _, e := range Hooks() {
h, ok := e.(RunHook)
if !ok {
continue
}
if err := h.PostRun(ctx, opts); err != nil {
return xerrors.Errorf("%s post run error: %w", e.Name(), err)
}
}
return nil
}
// PreScan is a hook that is called before the scan.
func PreScan(ctx context.Context, target *types.ScanTarget, options types.ScanOptions) error {
for _, e := range Hooks() {
h, ok := e.(ScanHook)
if !ok {
continue
}
if err := h.PreScan(ctx, target, options); err != nil {
return xerrors.Errorf("%s pre scan error: %w", e.Name(), err)
}
}
return nil
}
// PostScan is a hook that is called after the scan.
func PostScan(ctx context.Context, results types.Results) (types.Results, error) {
var err error
for _, e := range Hooks() {
h, ok := e.(ScanHook)
if !ok {
continue
}
results, err = h.PostScan(ctx, results)
if err != nil {
return nil, xerrors.Errorf("%s post scan error: %w", e.Name(), err)
}
}
return results, nil
}
// PreReport is a hook that is called before the report is written.
func PreReport(ctx context.Context, report *types.Report, opts flag.Options) error {
for _, e := range Hooks() {
h, ok := e.(ReportHook)
if !ok {
continue
}
if err := h.PreReport(ctx, report, opts); err != nil {
return xerrors.Errorf("%s pre report error: %w", e.Name(), err)
}
}
return nil
}
// PostReport is a hook that is called after the report is written.
func PostReport(ctx context.Context, report *types.Report, opts flag.Options) error {
for _, e := range Hooks() {
h, ok := e.(ReportHook)
if !ok {
continue
}
if err := h.PostReport(ctx, report, opts); err != nil {
return xerrors.Errorf("%s post report error: %w", e.Name(), err)
}
}
return nil
}
// Hooks returns the list of hooks.
func Hooks() []Hook {
hooks := lo.Values(hooks)
sort.Slice(hooks, func(i, j int) bool {
return hooks[i].Name() < hooks[j].Name()
})
return hooks
}

278
pkg/extension/hook_test.go Normal file
View File

@@ -0,0 +1,278 @@
package extension_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
"github.com/aquasecurity/trivy/internal/hooktest"
"github.com/aquasecurity/trivy/pkg/extension"
"github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/types"
)
func TestPostScan(t *testing.T) {
tests := []struct {
name string
results types.Results
want types.Results
wantErr bool
}{
{
name: "happy path",
results: types.Results{
{
Target: "test",
Vulnerabilities: []types.DetectedVulnerability{
{
VulnerabilityID: "CVE-2022-0001",
PkgName: "musl",
InstalledVersion: "1.2.3",
FixedVersion: "1.2.4",
Vulnerability: dbTypes.Vulnerability{
Severity: "CRITICAL",
},
},
},
},
},
want: types.Results{
{
Target: "test",
Vulnerabilities: []types.DetectedVulnerability{
{
VulnerabilityID: "CVE-2022-0001",
PkgName: "musl",
InstalledVersion: "1.2.3",
FixedVersion: "1.2.4",
Vulnerability: dbTypes.Vulnerability{
Severity: "CRITICAL",
References: []string{
"https://example.com/post-scan",
},
},
},
},
},
},
},
{
name: "sad path",
results: types.Results{
{
Target: "bad",
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize the test hook
hooktest.Init(t)
results, err := extension.PostScan(t.Context(), tt.results)
require.Equal(t, tt.wantErr, err != nil)
assert.Equal(t, tt.want, results)
})
}
}
func TestPreScan(t *testing.T) {
tests := []struct {
name string
target *types.ScanTarget
options types.ScanOptions
wantErr bool
}{
{
name: "happy path",
target: &types.ScanTarget{
Name: "test",
},
wantErr: false,
},
{
name: "sad path",
target: &types.ScanTarget{
Name: "bad-pre",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize the test hook
hooktest.Init(t)
err := extension.PreScan(t.Context(), tt.target, tt.options)
require.Equal(t, tt.wantErr, err != nil)
})
}
}
func TestPreRun(t *testing.T) {
tests := []struct {
name string
opts flag.Options
wantErr bool
}{
{
name: "happy path",
opts: flag.Options{},
wantErr: false,
},
{
name: "sad path",
opts: flag.Options{
GlobalOptions: flag.GlobalOptions{
ConfigFile: "bad-config",
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize the test hook
hooktest.Init(t)
err := extension.PreRun(t.Context(), tt.opts)
require.Equal(t, tt.wantErr, err != nil)
})
}
}
func TestPostRun(t *testing.T) {
tests := []struct {
name string
opts flag.Options
wantErr bool
}{
{
name: "happy path",
opts: flag.Options{},
wantErr: false,
},
{
name: "sad path",
opts: flag.Options{
GlobalOptions: flag.GlobalOptions{
ConfigFile: "bad-config",
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize the test extension
hooktest.Init(t)
err := extension.PostRun(t.Context(), tt.opts)
require.Equal(t, tt.wantErr, err != nil)
})
}
}
func TestPreReport(t *testing.T) {
tests := []struct {
name string
report *types.Report
opts flag.Options
wantTitle string
wantErr bool
}{
{
name: "happy path",
report: &types.Report{
Results: types.Results{
{
Vulnerabilities: []types.DetectedVulnerability{
{
VulnerabilityID: "CVE-2022-0001",
},
},
},
},
},
wantTitle: "Modified by pre-report hook",
wantErr: false,
},
{
name: "sad path",
report: &types.Report{
ArtifactName: "bad-report",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize the test hook
hooktest.Init(t)
err := extension.PreReport(t.Context(), tt.report, tt.opts)
if tt.wantErr {
require.Error(t, err)
return
}
require.Len(t, tt.report.Results, 1)
require.Len(t, tt.report.Results[0].Vulnerabilities, 1)
assert.Equal(t, tt.wantTitle, tt.report.Results[0].Vulnerabilities[0].Title)
})
}
}
func TestPostReport(t *testing.T) {
tests := []struct {
name string
report *types.Report
opts flag.Options
wantDescription string
wantErr bool
}{
{
name: "happy path",
report: &types.Report{
Results: types.Results{
{
Vulnerabilities: []types.DetectedVulnerability{
{
VulnerabilityID: "CVE-2022-0001",
},
},
},
},
},
wantDescription: "Modified by post-report hook",
wantErr: false,
},
{
name: "sad path",
report: &types.Report{
ArtifactName: "bad-report",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize the test hook
hooktest.Init(t)
err := extension.PostReport(t.Context(), tt.report, tt.opts)
if tt.wantErr {
require.Error(t, err)
return
}
require.Len(t, tt.report.Results, 1)
require.Len(t, tt.report.Results[0].Vulnerabilities, 1)
assert.Equal(t, tt.wantDescription, tt.report.Results[0].Vulnerabilities[0].Description)
})
}
}

View File

@@ -1,7 +1,9 @@
package flag
import (
"github.com/aquasecurity/trivy/pkg/module"
"path/filepath"
"github.com/aquasecurity/trivy/pkg/utils/fsutils"
)
// e.g. config yaml
@@ -14,7 +16,7 @@ var (
ModuleDirFlag = Flag[string]{
Name: "module-dir",
ConfigName: "module.dir",
Default: module.DefaultDir,
Default: filepath.Join(fsutils.HomeDir(), ".trivy", "modules"),
Usage: "specify directory to the wasm modules that will be loaded",
Persistent: true,
}

View File

@@ -17,13 +17,12 @@ import (
wasi "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/extension"
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/log"
tapi "github.com/aquasecurity/trivy/pkg/module/api"
"github.com/aquasecurity/trivy/pkg/module/serialize"
"github.com/aquasecurity/trivy/pkg/scan/post"
"github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/utils/fsutils"
)
var (
@@ -33,10 +32,6 @@ var (
"warn": logWarn,
"error": logError,
}
RelativeDir = filepath.Join(".trivy", "modules")
DefaultDir = dir()
)
// logDebug is defined as an api.GoModuleFunc for lower overhead vs reflection.
@@ -172,7 +167,7 @@ func (m *Manager) Register() {
func (m *Manager) Deregister() {
for _, mod := range m.modules {
analyzer.DeregisterAnalyzer(analyzer.Type(mod.Name()))
post.DeregisterPostScanner(mod.Name())
extension.DeregisterHook(mod.Name())
}
}
@@ -262,6 +257,8 @@ func marshal(ctx context.Context, m api.Module, malloc api.Function, v any) (uin
return ptr, size, nil
}
var _ extension.ScanHook = (*wasmModule)(nil)
type wasmModule struct {
mod api.Module
memFS *memFS
@@ -416,7 +413,7 @@ func (m *wasmModule) Register() {
}
if m.isPostScanner {
logger.Debug("Registering custom post scanner")
post.RegisterPostScanner(m)
extension.RegisterHook(m)
}
}
@@ -486,8 +483,11 @@ func (m *wasmModule) Analyze(ctx context.Context, input analyzer.AnalysisInput)
return &result, nil
}
// PostScan performs post scanning
// e.g. Remove a vulnerability, change severity, etc.
func (m *wasmModule) PreScan(ctx context.Context, target *types.ScanTarget, options types.ScanOptions) error {
// TODO: Implement
return nil
}
func (m *wasmModule) PostScan(ctx context.Context, results types.Results) (types.Results, error) {
// Find custom resources
var custom types.Result
@@ -746,10 +746,6 @@ func isType(ctx context.Context, mod api.Module, name string) (bool, error) {
return isRes[0] > 0, nil
}
func dir() string {
return filepath.Join(fsutils.HomeDir(), RelativeDir)
}
func modulePostScanSpec(ctx context.Context, mod api.Module, freeFn api.Function) (serialize.PostScanSpec, error) {
postScanSpecFunc := mod.ExportedFunction("post_scan_spec")
if postScanSpecFunc == nil {

View File

@@ -6,12 +6,13 @@ import (
"runtime"
"testing"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/aquasecurity/trivy/pkg/extension"
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/module"
"github.com/aquasecurity/trivy/pkg/scan/post"
)
func TestManager_Register(t *testing.T) {
@@ -20,12 +21,12 @@ func TestManager_Register(t *testing.T) {
t.Skip("Test satisfied adequately by Linux tests")
}
tests := []struct {
name string
moduleDir string
enabledModules []string
wantAnalyzerVersions analyzer.Versions
wantPostScannerVersions map[string]int
wantErr bool
name string
moduleDir string
enabledModules []string
wantAnalyzerVersions analyzer.Versions
wantExtentions []string
wantErr bool
}{
{
name: "happy path",
@@ -36,8 +37,8 @@ func TestManager_Register(t *testing.T) {
},
PostAnalyzers: make(map[string]int),
},
wantPostScannerVersions: map[string]int{
"happy": 1,
wantExtentions: []string{
"happy",
},
},
{
@@ -49,7 +50,7 @@ func TestManager_Register(t *testing.T) {
},
PostAnalyzers: make(map[string]int),
},
wantPostScannerVersions: make(map[string]int),
wantExtentions: []string{},
},
{
name: "only post scanner",
@@ -58,8 +59,8 @@ func TestManager_Register(t *testing.T) {
Analyzers: make(map[string]int),
PostAnalyzers: make(map[string]int),
},
wantPostScannerVersions: map[string]int{
"scanner": 2,
wantExtentions: []string{
"scanner",
},
},
{
@@ -69,7 +70,7 @@ func TestManager_Register(t *testing.T) {
Analyzers: make(map[string]int),
PostAnalyzers: make(map[string]int),
},
wantPostScannerVersions: make(map[string]int),
wantExtentions: []string{},
},
{
name: "pass enabled modules",
@@ -85,8 +86,8 @@ func TestManager_Register(t *testing.T) {
},
PostAnalyzers: make(map[string]int),
},
wantPostScannerVersions: map[string]int{
"happy": 1,
wantExtentions: []string{
"happy",
},
},
}
@@ -124,9 +125,10 @@ func TestManager_Register(t *testing.T) {
got := a.AnalyzerVersions()
assert.Equal(t, tt.wantAnalyzerVersions, got)
// Confirm the post scanner is registered
gotScannerVersions := post.ScannerVersions()
assert.Equal(t, tt.wantPostScannerVersions, gotScannerVersions)
hookNames := lo.Map(extension.Hooks(), func(hook extension.Hook, _ int) string {
return hook.Name()
})
assert.Equal(t, tt.wantExtentions, hookNames)
})
}
}

View File

@@ -9,6 +9,7 @@ import (
"golang.org/x/xerrors"
cr "github.com/aquasecurity/trivy/pkg/compliance/report"
"github.com/aquasecurity/trivy/pkg/extension"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/log"
@@ -26,6 +27,11 @@ const (
// Write writes the result to output, format as passed in argument
func Write(ctx context.Context, report types.Report, option flag.Options) (err error) {
// Call pre-report hooks
if err := extension.PreReport(ctx, &report, option); err != nil {
return xerrors.Errorf("pre report error: %w", err)
}
output, cleanup, err := option.OutputWriter(ctx)
if err != nil {
return xerrors.Errorf("failed to create a file: %w", err)
@@ -106,6 +112,11 @@ func Write(ctx context.Context, report types.Report, option flag.Options) (err e
return xerrors.Errorf("failed to write results: %w", err)
}
// Call post-report hooks
if err := extension.PostReport(ctx, &report, option); err != nil {
return xerrors.Errorf("post report error: %w", err)
}
return nil
}

View File

@@ -1,10 +1,16 @@
package report_test
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
"github.com/aquasecurity/trivy/internal/hooktest"
"github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/report"
"github.com/aquasecurity/trivy/pkg/types"
)
@@ -82,3 +88,93 @@ func TestResults_Failed(t *testing.T) {
})
}
}
func TestWrite(t *testing.T) {
testReport := types.Report{
SchemaVersion: report.SchemaVersion,
ArtifactName: "test-artifact",
Results: types.Results{
{
Target: "test-target",
Vulnerabilities: []types.DetectedVulnerability{
{
VulnerabilityID: "CVE-2021-0001",
PkgName: "test-pkg",
Vulnerability: dbTypes.Vulnerability{
Title: "Test Vulnerability Title",
Description: "This is a test description of a vulnerability",
},
},
},
},
},
}
testTemplate := "{{ range . }}{{ range .Vulnerabilities }}- {{ .VulnerabilityID }}: {{ .Title }}\n {{ .Description }}\n{{ end }}{{ end }}"
tests := []struct {
name string
setUpHook bool
report types.Report
options flag.Options
wantOutput string
wantTitle string // Expected title after function call
wantDesc string // Expected description after function call
}{
{
name: "template with title and description",
report: testReport,
options: flag.Options{
ReportOptions: flag.ReportOptions{
Format: types.FormatTemplate,
Template: testTemplate,
},
},
wantOutput: "- CVE-2021-0001: Test Vulnerability Title\n This is a test description of a vulnerability\n",
wantTitle: "Test Vulnerability Title", // Should remain unchanged
wantDesc: "This is a test description of a vulnerability", // Should remain unchanged
},
{
name: "report modified by hooks",
setUpHook: true,
report: testReport,
options: flag.Options{
ReportOptions: flag.ReportOptions{
Format: types.FormatTemplate,
Template: testTemplate,
},
},
// The template output only reflects the pre-report hook changes because
// the post-report hook runs AFTER the output is written.
// However, the report object itself is modified by both pre and post hooks.
wantOutput: "- CVE-2021-0001: Modified by pre-report hook\n This is a test description of a vulnerability\n",
wantTitle: "Modified by pre-report hook",
wantDesc: "Modified by post-report hook",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setUpHook {
hooktest.Init(t)
}
// Create a buffer to capture the output
output := new(bytes.Buffer)
tt.options.SetOutputWriter(output)
// Execute the Write function
err := report.Write(t.Context(), tt.report, tt.options)
require.NoError(t, err)
// Verify the output matches the expected template rendering
got := output.String()
assert.Equal(t, tt.wantOutput, got, "Template output does not match wanted value")
// Verify that the title and description in the report match the expected values
require.Len(t, tt.report.Results, 1)
require.Len(t, tt.report.Results[0].Vulnerabilities, 1)
assert.Equal(t, tt.wantTitle, tt.report.Results[0].Vulnerabilities[0].Title)
assert.Equal(t, tt.wantDesc, tt.report.Results[0].Vulnerabilities[0].Description)
})
}
}

View File

@@ -15,6 +15,7 @@ import (
dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
ospkgDetector "github.com/aquasecurity/trivy/pkg/detector/ospkg"
"github.com/aquasecurity/trivy/pkg/extension"
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/fanal/applier"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
@@ -23,7 +24,6 @@ import (
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/scan/langpkg"
"github.com/aquasecurity/trivy/pkg/scan/ospkg"
"github.com/aquasecurity/trivy/pkg/scan/post"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/vulnerability"
@@ -49,7 +49,7 @@ type Service struct {
vulnClient vulnerability.Client
}
// NewService is the factory method for Scanner
// NewService is the factory method for scan service
func NewService(a applier.Applier, osPkgScanner ospkg.Scanner, langPkgScanner langpkg.Scanner,
vulnClient vulnerability.Client) Service {
return Service{
@@ -113,6 +113,11 @@ func (s Service) Scan(ctx context.Context, targetName, artifactKey string, blobK
}
func (s Service) ScanTarget(ctx context.Context, target types.ScanTarget, options types.ScanOptions) (types.Results, ftypes.OS, error) {
// Call pre-scan hooks
if err := extension.PreScan(ctx, &target, options); err != nil {
return nil, ftypes.OS{}, xerrors.Errorf("pre scan error: %w", err)
}
var results types.Results
// Filter packages according to the options
@@ -148,9 +153,8 @@ func (s Service) ScanTarget(ctx context.Context, target types.ScanTarget, option
s.vulnClient.FillInfo(results[i].Vulnerabilities, options.VulnSeveritySources)
}
// Post scanning
results, err = post.Scan(ctx, results)
if err != nil {
// Call post-scan hooks
if results, err = extension.PostScan(ctx, results); err != nil {
return nil, ftypes.OS{}, xerrors.Errorf("post scan error: %w", err)
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/aquasecurity/trivy-db/pkg/db"
dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
"github.com/aquasecurity/trivy/internal/dbtest"
"github.com/aquasecurity/trivy/internal/hooktest"
"github.com/aquasecurity/trivy/pkg/cache"
"github.com/aquasecurity/trivy/pkg/fanal/applier"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
@@ -151,6 +152,7 @@ func TestScanner_Scan(t *testing.T) {
name string
args args
fixtures []string
setUpHook bool
setupCache func(t *testing.T) cache.Cache
wantResults types.Results
wantOS ftypes.OS
@@ -909,6 +911,75 @@ func TestScanner_Scan(t *testing.T) {
Name: "3.11",
},
},
{
name: "happy path with hooks",
args: args{
target: "alpine:latest",
layerIDs: []string{"sha256:5216338b40a7b96416b8b9858974bbe4acc3096ee60acbc4dfb1ee02aecceb10"},
options: types.ScanOptions{
PkgTypes: []string{types.PkgTypeOS},
PkgRelationships: ftypes.Relationships,
Scanners: types.Scanners{types.VulnerabilityScanner},
VulnSeveritySources: []dbTypes.SourceID{"auto"},
},
},
fixtures: []string{"testdata/fixtures/happy.yaml"},
setUpHook: true,
setupCache: func(t *testing.T) cache.Cache {
c := cache.NewMemoryCache()
require.NoError(t, c.PutBlob("sha256:5216338b40a7b96416b8b9858974bbe4acc3096ee60acbc4dfb1ee02aecceb10", ftypes.BlobInfo{
SchemaVersion: ftypes.BlobJSONSchemaVersion,
OS: ftypes.OS{
Family: ftypes.Alpine,
Name: "3.11",
},
PackageInfos: []ftypes.PackageInfo{
{
FilePath: "lib/apk/db/installed",
Packages: []ftypes.Package{muslPkg},
},
},
}))
return c
},
wantResults: types.Results{
{
Target: "alpine:latest (pre-scan) (alpine 3.11)",
Class: types.ClassOSPkg,
Type: ftypes.Alpine,
Packages: ftypes.Packages{
muslPkg,
},
Vulnerabilities: []types.DetectedVulnerability{
{
VulnerabilityID: "CVE-2020-9999",
PkgName: muslPkg.Name,
PkgIdentifier: muslPkg.Identifier,
InstalledVersion: muslPkg.Version,
FixedVersion: "1.2.4",
Status: dbTypes.StatusFixed,
Layer: ftypes.Layer{
DiffID: "sha256:ebf12965380b39889c99a9c02e82ba465f887b45975b6e389d42e9e6a3857888",
},
PrimaryURL: "https://avd.aquasec.com/nvd/cve-2020-9999",
Vulnerability: dbTypes.Vulnerability{
Title: "dos",
Description: "dos vulnerability",
Severity: "HIGH",
References: []string{
"https://example.com/post-scan", // modified by post-scan hook
},
},
},
},
},
},
wantOS: ftypes.OS{
Family: "alpine",
Name: "3.11",
Eosl: true,
},
},
{
name: "happy path with misconfigurations",
args: args{
@@ -1242,6 +1313,10 @@ func TestScanner_Scan(t *testing.T) {
_ = dbtest.InitDB(t, tt.fixtures)
defer db.Close()
if tt.setUpHook {
hooktest.Init(t)
}
c := tt.setupCache(t)
a := applier.NewApplier(c)
s := NewService(a, ospkg.NewScanner(), langpkg.NewScanner(), vulnerability.NewClient(db.Config{}))

View File

@@ -1,45 +0,0 @@
package post
import (
"context"
"golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/types"
)
type Scanner interface {
Name() string
Version() int
PostScan(ctx context.Context, results types.Results) (types.Results, error)
}
func RegisterPostScanner(s Scanner) {
// Avoid duplication
postScanners[s.Name()] = s
}
func DeregisterPostScanner(name string) {
delete(postScanners, name)
}
func ScannerVersions() map[string]int {
versions := make(map[string]int)
for _, s := range postScanners {
versions[s.Name()] = s.Version()
}
return versions
}
var postScanners = make(map[string]Scanner)
func Scan(ctx context.Context, results types.Results) (types.Results, error) {
var err error
for _, s := range postScanners {
results, err = s.PostScan(ctx, results)
if err != nil {
return nil, xerrors.Errorf("%s post scan error: %w", s.Name(), err)
}
}
return results, nil
}

View File

@@ -1,103 +0,0 @@
package post_test
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
"github.com/aquasecurity/trivy/pkg/scan/post"
"github.com/aquasecurity/trivy/pkg/types"
)
type testPostScanner struct{}
func (testPostScanner) Name() string {
return "test"
}
func (testPostScanner) Version() int {
return 1
}
func (testPostScanner) PostScan(ctx context.Context, results types.Results) (types.Results, error) {
for i, r := range results {
if r.Target == "bad" {
return nil, errors.New("bad")
}
for j := range r.Vulnerabilities {
results[i].Vulnerabilities[j].Severity = "LOW"
}
}
return results, nil
}
func TestScan(t *testing.T) {
tests := []struct {
name string
results types.Results
want types.Results
wantErr bool
}{
{
name: "happy path",
results: types.Results{
{
Target: "test",
Vulnerabilities: []types.DetectedVulnerability{
{
VulnerabilityID: "CVE-2022-0001",
PkgName: "musl",
InstalledVersion: "1.2.3",
FixedVersion: "1.2.4",
Vulnerability: dbTypes.Vulnerability{
Severity: "CRITICAL",
},
},
},
},
},
want: types.Results{
{
Target: "test",
Vulnerabilities: []types.DetectedVulnerability{
{
VulnerabilityID: "CVE-2022-0001",
PkgName: "musl",
InstalledVersion: "1.2.3",
FixedVersion: "1.2.4",
Vulnerability: dbTypes.Vulnerability{
Severity: "LOW",
},
},
},
},
},
},
{
name: "sad path",
results: types.Results{
{
Target: "bad",
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := testPostScanner{}
post.RegisterPostScanner(s)
defer func() {
post.DeregisterPostScanner(s.Name())
}()
results, err := post.Scan(t.Context(), tt.results)
require.Equal(t, tt.wantErr, err != nil)
assert.Equal(t, tt.want, results)
})
}
}