refactor: add hook interface for extended functionality (#8585)

This commit is contained in:
Teppei Fukuda
2025-04-08 15:49:16 +04:00
committed by GitHub
parent 9dcd06fda7
commit a0dc3b688e
14 changed files with 795 additions and 198 deletions

162
pkg/extension/hook.go Normal file
View File

@@ -0,0 +1,162 @@
package extension
import (
"context"
"sort"
"github.com/samber/lo"
"golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/types"
)
var hooks = make(map[string]Hook)
func RegisterHook(s Hook) {
// Avoid duplication
hooks[s.Name()] = s
}
func DeregisterHook(name string) {
delete(hooks, name)
}
// Hook is an interface that defines the methods for a hook.
type Hook interface {
// Name returns the name of the extension.
Name() string
}
// RunHook is a extension that is called before and after all the processes.
type RunHook interface {
Hook
// PreRun is called before all the processes.
PreRun(ctx context.Context, opts flag.Options) error
// PostRun is called after all the processes.
PostRun(ctx context.Context, opts flag.Options) error
}
// ScanHook is a extension that is called before and after the scan.
type ScanHook interface {
Hook
// PreScan is called before the scan. It can modify the scan target.
// It may be called on the server side in client/server mode.
PreScan(ctx context.Context, target *types.ScanTarget, opts types.ScanOptions) error
// PostScan is called after the scan. It can modify the results.
// It may be called on the server side in client/server mode.
// NOTE: Wasm modules cannot directly modify the passed results,
// so it returns a copy of the results.
PostScan(ctx context.Context, results types.Results) (types.Results, error)
}
// ReportHook is a extension that is called before and after the report is written.
type ReportHook interface {
Hook
// PreReport is called before the report is written.
// It can modify the report. It is called on the client side.
PreReport(ctx context.Context, report *types.Report, opts flag.Options) error
// PostReport is called after the report is written.
// It can modify the report. It is called on the client side.
PostReport(ctx context.Context, report *types.Report, opts flag.Options) error
}
func PreRun(ctx context.Context, opts flag.Options) error {
for _, e := range Hooks() {
h, ok := e.(RunHook)
if !ok {
continue
}
if err := h.PreRun(ctx, opts); err != nil {
return xerrors.Errorf("%s pre run error: %w", e.Name(), err)
}
}
return nil
}
// PostRun is a hook that is called after all the processes.
func PostRun(ctx context.Context, opts flag.Options) error {
for _, e := range Hooks() {
h, ok := e.(RunHook)
if !ok {
continue
}
if err := h.PostRun(ctx, opts); err != nil {
return xerrors.Errorf("%s post run error: %w", e.Name(), err)
}
}
return nil
}
// PreScan is a hook that is called before the scan.
func PreScan(ctx context.Context, target *types.ScanTarget, options types.ScanOptions) error {
for _, e := range Hooks() {
h, ok := e.(ScanHook)
if !ok {
continue
}
if err := h.PreScan(ctx, target, options); err != nil {
return xerrors.Errorf("%s pre scan error: %w", e.Name(), err)
}
}
return nil
}
// PostScan is a hook that is called after the scan.
func PostScan(ctx context.Context, results types.Results) (types.Results, error) {
var err error
for _, e := range Hooks() {
h, ok := e.(ScanHook)
if !ok {
continue
}
results, err = h.PostScan(ctx, results)
if err != nil {
return nil, xerrors.Errorf("%s post scan error: %w", e.Name(), err)
}
}
return results, nil
}
// PreReport is a hook that is called before the report is written.
func PreReport(ctx context.Context, report *types.Report, opts flag.Options) error {
for _, e := range Hooks() {
h, ok := e.(ReportHook)
if !ok {
continue
}
if err := h.PreReport(ctx, report, opts); err != nil {
return xerrors.Errorf("%s pre report error: %w", e.Name(), err)
}
}
return nil
}
// PostReport is a hook that is called after the report is written.
func PostReport(ctx context.Context, report *types.Report, opts flag.Options) error {
for _, e := range Hooks() {
h, ok := e.(ReportHook)
if !ok {
continue
}
if err := h.PostReport(ctx, report, opts); err != nil {
return xerrors.Errorf("%s post report error: %w", e.Name(), err)
}
}
return nil
}
// Hooks returns the list of hooks.
func Hooks() []Hook {
hooks := lo.Values(hooks)
sort.Slice(hooks, func(i, j int) bool {
return hooks[i].Name() < hooks[j].Name()
})
return hooks
}

278
pkg/extension/hook_test.go Normal file
View File

@@ -0,0 +1,278 @@
package extension_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
"github.com/aquasecurity/trivy/internal/hooktest"
"github.com/aquasecurity/trivy/pkg/extension"
"github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/types"
)
func TestPostScan(t *testing.T) {
tests := []struct {
name string
results types.Results
want types.Results
wantErr bool
}{
{
name: "happy path",
results: types.Results{
{
Target: "test",
Vulnerabilities: []types.DetectedVulnerability{
{
VulnerabilityID: "CVE-2022-0001",
PkgName: "musl",
InstalledVersion: "1.2.3",
FixedVersion: "1.2.4",
Vulnerability: dbTypes.Vulnerability{
Severity: "CRITICAL",
},
},
},
},
},
want: types.Results{
{
Target: "test",
Vulnerabilities: []types.DetectedVulnerability{
{
VulnerabilityID: "CVE-2022-0001",
PkgName: "musl",
InstalledVersion: "1.2.3",
FixedVersion: "1.2.4",
Vulnerability: dbTypes.Vulnerability{
Severity: "CRITICAL",
References: []string{
"https://example.com/post-scan",
},
},
},
},
},
},
},
{
name: "sad path",
results: types.Results{
{
Target: "bad",
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize the test hook
hooktest.Init(t)
results, err := extension.PostScan(t.Context(), tt.results)
require.Equal(t, tt.wantErr, err != nil)
assert.Equal(t, tt.want, results)
})
}
}
func TestPreScan(t *testing.T) {
tests := []struct {
name string
target *types.ScanTarget
options types.ScanOptions
wantErr bool
}{
{
name: "happy path",
target: &types.ScanTarget{
Name: "test",
},
wantErr: false,
},
{
name: "sad path",
target: &types.ScanTarget{
Name: "bad-pre",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize the test hook
hooktest.Init(t)
err := extension.PreScan(t.Context(), tt.target, tt.options)
require.Equal(t, tt.wantErr, err != nil)
})
}
}
func TestPreRun(t *testing.T) {
tests := []struct {
name string
opts flag.Options
wantErr bool
}{
{
name: "happy path",
opts: flag.Options{},
wantErr: false,
},
{
name: "sad path",
opts: flag.Options{
GlobalOptions: flag.GlobalOptions{
ConfigFile: "bad-config",
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize the test hook
hooktest.Init(t)
err := extension.PreRun(t.Context(), tt.opts)
require.Equal(t, tt.wantErr, err != nil)
})
}
}
func TestPostRun(t *testing.T) {
tests := []struct {
name string
opts flag.Options
wantErr bool
}{
{
name: "happy path",
opts: flag.Options{},
wantErr: false,
},
{
name: "sad path",
opts: flag.Options{
GlobalOptions: flag.GlobalOptions{
ConfigFile: "bad-config",
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize the test extension
hooktest.Init(t)
err := extension.PostRun(t.Context(), tt.opts)
require.Equal(t, tt.wantErr, err != nil)
})
}
}
func TestPreReport(t *testing.T) {
tests := []struct {
name string
report *types.Report
opts flag.Options
wantTitle string
wantErr bool
}{
{
name: "happy path",
report: &types.Report{
Results: types.Results{
{
Vulnerabilities: []types.DetectedVulnerability{
{
VulnerabilityID: "CVE-2022-0001",
},
},
},
},
},
wantTitle: "Modified by pre-report hook",
wantErr: false,
},
{
name: "sad path",
report: &types.Report{
ArtifactName: "bad-report",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize the test hook
hooktest.Init(t)
err := extension.PreReport(t.Context(), tt.report, tt.opts)
if tt.wantErr {
require.Error(t, err)
return
}
require.Len(t, tt.report.Results, 1)
require.Len(t, tt.report.Results[0].Vulnerabilities, 1)
assert.Equal(t, tt.wantTitle, tt.report.Results[0].Vulnerabilities[0].Title)
})
}
}
func TestPostReport(t *testing.T) {
tests := []struct {
name string
report *types.Report
opts flag.Options
wantDescription string
wantErr bool
}{
{
name: "happy path",
report: &types.Report{
Results: types.Results{
{
Vulnerabilities: []types.DetectedVulnerability{
{
VulnerabilityID: "CVE-2022-0001",
},
},
},
},
},
wantDescription: "Modified by post-report hook",
wantErr: false,
},
{
name: "sad path",
report: &types.Report{
ArtifactName: "bad-report",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize the test hook
hooktest.Init(t)
err := extension.PostReport(t.Context(), tt.report, tt.opts)
if tt.wantErr {
require.Error(t, err)
return
}
require.Len(t, tt.report.Results, 1)
require.Len(t, tt.report.Results[0].Vulnerabilities, 1)
assert.Equal(t, tt.wantDescription, tt.report.Results[0].Vulnerabilities[0].Description)
})
}
}