mirror of
https://github.com/aquasecurity/trivy.git
synced 2025-12-12 15:50:15 -08:00
refactor: add hook interface for extended functionality (#8585)
This commit is contained in:
162
pkg/extension/hook.go
Normal file
162
pkg/extension/hook.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package extension
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/aquasecurity/trivy/pkg/flag"
|
||||
"github.com/aquasecurity/trivy/pkg/types"
|
||||
)
|
||||
|
||||
var hooks = make(map[string]Hook)
|
||||
|
||||
func RegisterHook(s Hook) {
|
||||
// Avoid duplication
|
||||
hooks[s.Name()] = s
|
||||
}
|
||||
|
||||
func DeregisterHook(name string) {
|
||||
delete(hooks, name)
|
||||
}
|
||||
|
||||
// Hook is an interface that defines the methods for a hook.
|
||||
type Hook interface {
|
||||
// Name returns the name of the extension.
|
||||
Name() string
|
||||
}
|
||||
|
||||
// RunHook is a extension that is called before and after all the processes.
|
||||
type RunHook interface {
|
||||
Hook
|
||||
|
||||
// PreRun is called before all the processes.
|
||||
PreRun(ctx context.Context, opts flag.Options) error
|
||||
|
||||
// PostRun is called after all the processes.
|
||||
PostRun(ctx context.Context, opts flag.Options) error
|
||||
}
|
||||
|
||||
// ScanHook is a extension that is called before and after the scan.
|
||||
type ScanHook interface {
|
||||
Hook
|
||||
|
||||
// PreScan is called before the scan. It can modify the scan target.
|
||||
// It may be called on the server side in client/server mode.
|
||||
PreScan(ctx context.Context, target *types.ScanTarget, opts types.ScanOptions) error
|
||||
|
||||
// PostScan is called after the scan. It can modify the results.
|
||||
// It may be called on the server side in client/server mode.
|
||||
// NOTE: Wasm modules cannot directly modify the passed results,
|
||||
// so it returns a copy of the results.
|
||||
PostScan(ctx context.Context, results types.Results) (types.Results, error)
|
||||
}
|
||||
|
||||
// ReportHook is a extension that is called before and after the report is written.
|
||||
type ReportHook interface {
|
||||
Hook
|
||||
|
||||
// PreReport is called before the report is written.
|
||||
// It can modify the report. It is called on the client side.
|
||||
PreReport(ctx context.Context, report *types.Report, opts flag.Options) error
|
||||
|
||||
// PostReport is called after the report is written.
|
||||
// It can modify the report. It is called on the client side.
|
||||
PostReport(ctx context.Context, report *types.Report, opts flag.Options) error
|
||||
}
|
||||
|
||||
func PreRun(ctx context.Context, opts flag.Options) error {
|
||||
for _, e := range Hooks() {
|
||||
h, ok := e.(RunHook)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := h.PreRun(ctx, opts); err != nil {
|
||||
return xerrors.Errorf("%s pre run error: %w", e.Name(), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PostRun is a hook that is called after all the processes.
|
||||
func PostRun(ctx context.Context, opts flag.Options) error {
|
||||
for _, e := range Hooks() {
|
||||
h, ok := e.(RunHook)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := h.PostRun(ctx, opts); err != nil {
|
||||
return xerrors.Errorf("%s post run error: %w", e.Name(), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PreScan is a hook that is called before the scan.
|
||||
func PreScan(ctx context.Context, target *types.ScanTarget, options types.ScanOptions) error {
|
||||
for _, e := range Hooks() {
|
||||
h, ok := e.(ScanHook)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := h.PreScan(ctx, target, options); err != nil {
|
||||
return xerrors.Errorf("%s pre scan error: %w", e.Name(), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PostScan is a hook that is called after the scan.
|
||||
func PostScan(ctx context.Context, results types.Results) (types.Results, error) {
|
||||
var err error
|
||||
for _, e := range Hooks() {
|
||||
h, ok := e.(ScanHook)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
results, err = h.PostScan(ctx, results)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("%s post scan error: %w", e.Name(), err)
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// PreReport is a hook that is called before the report is written.
|
||||
func PreReport(ctx context.Context, report *types.Report, opts flag.Options) error {
|
||||
for _, e := range Hooks() {
|
||||
h, ok := e.(ReportHook)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := h.PreReport(ctx, report, opts); err != nil {
|
||||
return xerrors.Errorf("%s pre report error: %w", e.Name(), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PostReport is a hook that is called after the report is written.
|
||||
func PostReport(ctx context.Context, report *types.Report, opts flag.Options) error {
|
||||
for _, e := range Hooks() {
|
||||
h, ok := e.(ReportHook)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := h.PostReport(ctx, report, opts); err != nil {
|
||||
return xerrors.Errorf("%s post report error: %w", e.Name(), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Hooks returns the list of hooks.
|
||||
func Hooks() []Hook {
|
||||
hooks := lo.Values(hooks)
|
||||
sort.Slice(hooks, func(i, j int) bool {
|
||||
return hooks[i].Name() < hooks[j].Name()
|
||||
})
|
||||
return hooks
|
||||
}
|
||||
278
pkg/extension/hook_test.go
Normal file
278
pkg/extension/hook_test.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package extension_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
|
||||
"github.com/aquasecurity/trivy/internal/hooktest"
|
||||
"github.com/aquasecurity/trivy/pkg/extension"
|
||||
"github.com/aquasecurity/trivy/pkg/flag"
|
||||
"github.com/aquasecurity/trivy/pkg/types"
|
||||
)
|
||||
|
||||
func TestPostScan(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
results types.Results
|
||||
want types.Results
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
results: types.Results{
|
||||
{
|
||||
Target: "test",
|
||||
Vulnerabilities: []types.DetectedVulnerability{
|
||||
{
|
||||
VulnerabilityID: "CVE-2022-0001",
|
||||
PkgName: "musl",
|
||||
InstalledVersion: "1.2.3",
|
||||
FixedVersion: "1.2.4",
|
||||
Vulnerability: dbTypes.Vulnerability{
|
||||
Severity: "CRITICAL",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: types.Results{
|
||||
{
|
||||
Target: "test",
|
||||
Vulnerabilities: []types.DetectedVulnerability{
|
||||
{
|
||||
VulnerabilityID: "CVE-2022-0001",
|
||||
PkgName: "musl",
|
||||
InstalledVersion: "1.2.3",
|
||||
FixedVersion: "1.2.4",
|
||||
Vulnerability: dbTypes.Vulnerability{
|
||||
Severity: "CRITICAL",
|
||||
References: []string{
|
||||
"https://example.com/post-scan",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sad path",
|
||||
results: types.Results{
|
||||
{
|
||||
Target: "bad",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Initialize the test hook
|
||||
hooktest.Init(t)
|
||||
|
||||
results, err := extension.PostScan(t.Context(), tt.results)
|
||||
require.Equal(t, tt.wantErr, err != nil)
|
||||
assert.Equal(t, tt.want, results)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreScan(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
target *types.ScanTarget
|
||||
options types.ScanOptions
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
target: &types.ScanTarget{
|
||||
Name: "test",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "sad path",
|
||||
target: &types.ScanTarget{
|
||||
Name: "bad-pre",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Initialize the test hook
|
||||
hooktest.Init(t)
|
||||
|
||||
err := extension.PreScan(t.Context(), tt.target, tt.options)
|
||||
require.Equal(t, tt.wantErr, err != nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreRun(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opts flag.Options
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
opts: flag.Options{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "sad path",
|
||||
opts: flag.Options{
|
||||
GlobalOptions: flag.GlobalOptions{
|
||||
ConfigFile: "bad-config",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Initialize the test hook
|
||||
hooktest.Init(t)
|
||||
|
||||
err := extension.PreRun(t.Context(), tt.opts)
|
||||
require.Equal(t, tt.wantErr, err != nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostRun(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opts flag.Options
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
opts: flag.Options{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "sad path",
|
||||
opts: flag.Options{
|
||||
GlobalOptions: flag.GlobalOptions{
|
||||
ConfigFile: "bad-config",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Initialize the test extension
|
||||
hooktest.Init(t)
|
||||
|
||||
err := extension.PostRun(t.Context(), tt.opts)
|
||||
require.Equal(t, tt.wantErr, err != nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreReport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
report *types.Report
|
||||
opts flag.Options
|
||||
wantTitle string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
report: &types.Report{
|
||||
Results: types.Results{
|
||||
{
|
||||
Vulnerabilities: []types.DetectedVulnerability{
|
||||
{
|
||||
VulnerabilityID: "CVE-2022-0001",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantTitle: "Modified by pre-report hook",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "sad path",
|
||||
report: &types.Report{
|
||||
ArtifactName: "bad-report",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Initialize the test hook
|
||||
hooktest.Init(t)
|
||||
|
||||
err := extension.PreReport(t.Context(), tt.report, tt.opts)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.Len(t, tt.report.Results, 1)
|
||||
require.Len(t, tt.report.Results[0].Vulnerabilities, 1)
|
||||
assert.Equal(t, tt.wantTitle, tt.report.Results[0].Vulnerabilities[0].Title)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostReport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
report *types.Report
|
||||
opts flag.Options
|
||||
wantDescription string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
report: &types.Report{
|
||||
Results: types.Results{
|
||||
{
|
||||
Vulnerabilities: []types.DetectedVulnerability{
|
||||
{
|
||||
VulnerabilityID: "CVE-2022-0001",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantDescription: "Modified by post-report hook",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "sad path",
|
||||
report: &types.Report{
|
||||
ArtifactName: "bad-report",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Initialize the test hook
|
||||
hooktest.Init(t)
|
||||
|
||||
err := extension.PostReport(t.Context(), tt.report, tt.opts)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.Len(t, tt.report.Results, 1)
|
||||
require.Len(t, tt.report.Results[0].Vulnerabilities, 1)
|
||||
assert.Equal(t, tt.wantDescription, tt.report.Results[0].Vulnerabilities[0].Description)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user