fix(cli): inconsistent behavior across CLI flags, environment variables, and config files (#5843)

Signed-off-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
Teppei Fukuda
2024-02-01 07:25:30 +04:00
committed by GitHub
parent 5924c021da
commit 59e54334d1
43 changed files with 1736 additions and 1447 deletions

View File

@@ -8,13 +8,14 @@ import (
"testing"
"time"
awscommands "github.com/aquasecurity/trivy/pkg/cloud/aws/commands"
"github.com/aquasecurity/trivy/pkg/flag"
dockercontainer "github.com/docker/docker/api/types/container"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
testcontainers "github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/localstack"
awscommands "github.com/aquasecurity/trivy/pkg/cloud/aws/commands"
"github.com/aquasecurity/trivy/pkg/flag"
)
func TestAwsCommandRun(t *testing.T) {

View File

@@ -5,6 +5,7 @@ package integration
import (
"context"
"fmt"
"github.com/aquasecurity/trivy/pkg/types"
"os"
"path/filepath"
"strings"
@@ -15,16 +16,15 @@ import (
"github.com/docker/go-connections/nat"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
testcontainers "github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go"
"github.com/aquasecurity/trivy/pkg/report"
"github.com/aquasecurity/trivy/pkg/uuid"
)
type csArgs struct {
Command string
RemoteAddrOption string
Format string
Format types.Format
TemplatePath string
IgnoreUnfixed bool
Severity []string
@@ -265,19 +265,15 @@ func TestClientServer(t *testing.T) {
addr, cacheDir := setup(t, setupOptions{})
for _, c := range tests {
t.Run(c.name, func(t *testing.T) {
osArgs, outputFile := setupClient(t, c.args, addr, cacheDir, c.golden)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
osArgs := setupClient(t, tt.args, addr, cacheDir, tt.golden)
if c.args.secretConfig != "" {
osArgs = append(osArgs, "--secret-config", c.args.secretConfig)
if tt.args.secretConfig != "" {
osArgs = append(osArgs, "--secret-config", tt.args.secretConfig)
}
//
err := execute(osArgs)
require.NoError(t, err)
compareReports(t, c.golden, outputFile, nil)
runTest(t, osArgs, tt.golden, "", types.FormatJSON, runOptions{})
})
}
}
@@ -389,19 +385,9 @@ func TestClientServerWithFormat(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("AWS_REGION", "test-region")
t.Setenv("AWS_ACCOUNT_ID", "123456789012")
osArgs, outputFile := setupClient(t, tt.args, addr, cacheDir, tt.golden)
osArgs := setupClient(t, tt.args, addr, cacheDir, tt.golden)
// Run Trivy client
err := execute(osArgs)
require.NoError(t, err)
want, err := os.ReadFile(tt.golden)
require.NoError(t, err)
got, err := os.ReadFile(outputFile)
require.NoError(t, err)
assert.EqualValues(t, string(want), string(got))
runTest(t, osArgs, tt.golden, "", tt.args.Format, runOptions{})
})
}
}
@@ -425,21 +411,16 @@ func TestClientServerWithCycloneDX(t *testing.T) {
addr, cacheDir := setup(t, setupOptions{})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
uuid.SetFakeUUID(t, "3ff14136-e09f-4df9-80ea-%012d")
osArgs, outputFile := setupClient(t, tt.args, addr, cacheDir, tt.golden)
// Run Trivy client
err := execute(osArgs)
require.NoError(t, err)
compareCycloneDX(t, tt.golden, outputFile)
osArgs := setupClient(t, tt.args, addr, cacheDir, tt.golden)
runTest(t, osArgs, tt.golden, "", types.FormatCycloneDX, runOptions{
fakeUUID: "3ff14136-e09f-4df9-80ea-%012d",
})
})
}
}
func TestClientServerWithToken(t *testing.T) {
cases := []struct {
tests := []struct {
name string
args csArgs
golden string
@@ -481,20 +462,10 @@ func TestClientServerWithToken(t *testing.T) {
tokenHeader: serverTokenHeader,
})
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
osArgs, outputFile := setupClient(t, c.args, addr, cacheDir, c.golden)
// Run Trivy client
err := execute(osArgs)
if c.wantErr != "" {
require.Error(t, err, c.name)
assert.Contains(t, err.Error(), c.wantErr, c.name)
return
}
require.NoError(t, err, c.name)
compareReports(t, c.golden, outputFile, nil)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
osArgs := setupClient(t, tt.args, addr, cacheDir, tt.golden)
runTest(t, osArgs, tt.golden, "", types.FormatJSON, runOptions{wantErr: tt.wantErr})
})
}
}
@@ -517,25 +488,22 @@ func TestClientServerWithRedis(t *testing.T) {
golden := "testdata/alpine-39.json.golden"
t.Run("alpine 3.9", func(t *testing.T) {
osArgs, outputFile := setupClient(t, testArgs, addr, cacheDir, golden)
osArgs := setupClient(t, testArgs, addr, cacheDir, golden)
// Run Trivy client
err := execute(osArgs)
require.NoError(t, err)
compareReports(t, golden, outputFile, nil)
runTest(t, osArgs, golden, "", types.FormatJSON, runOptions{})
})
// Terminate the Redis container
require.NoError(t, redisC.Terminate(ctx))
t.Run("sad path", func(t *testing.T) {
osArgs, _ := setupClient(t, testArgs, addr, cacheDir, golden)
osArgs := setupClient(t, testArgs, addr, cacheDir, golden)
// Run Trivy client
err := execute(osArgs)
require.Error(t, err)
assert.Contains(t, err.Error(), "unable to store cache")
runTest(t, osArgs, "", "", types.FormatJSON, runOptions{
wantErr: "unable to store cache",
})
})
}
@@ -595,7 +563,7 @@ func setupServer(addr, token, tokenHeader, cacheDir, cacheBackend string) []stri
return osArgs
}
func setupClient(t *testing.T, c csArgs, addr string, cacheDir string, golden string) ([]string, string) {
func setupClient(t *testing.T, c csArgs, addr string, cacheDir string, golden string) []string {
if c.Command == "" {
c.Command = "image"
}
@@ -612,7 +580,7 @@ func setupClient(t *testing.T, c csArgs, addr string, cacheDir string, golden st
}
if c.Format != "" {
osArgs = append(osArgs, "--format", c.Format)
osArgs = append(osArgs, "--format", string(c.Format))
if c.TemplatePath != "" {
osArgs = append(osArgs, "--template", c.TemplatePath)
}
@@ -642,19 +610,11 @@ func setupClient(t *testing.T, c csArgs, addr string, cacheDir string, golden st
osArgs = append(osArgs, "--input", c.Input)
}
// Set up the output file
outputFile := filepath.Join(t.TempDir(), "output.json")
if *update {
outputFile = golden
}
osArgs = append(osArgs, "--output", outputFile)
if c.Target != "" {
osArgs = append(osArgs, c.Target)
}
return osArgs, outputFile
return osArgs
}
func setupRedis(t *testing.T, ctx context.Context) (testcontainers.Container, string) {

230
integration/config_test.go Normal file
View File

@@ -0,0 +1,230 @@
//go:build integration
package integration
import (
"fmt"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"github.com/aquasecurity/trivy/pkg/types"
)
// TestConfiguration tests the configuration of the CLI flags, environmental variables, and config file
func TestConfiguration(t *testing.T) {
type args struct {
input string
flags map[string]string
envs map[string]string
configFile string
}
type test struct {
name string
args args
golden string
wantErr string
}
tests := []test{
{
name: "skip files",
args: args{
input: "testdata/fixtures/repo/gomod",
flags: map[string]string{
"scanners": "vuln",
"skip-files": "path/to/dummy,testdata/fixtures/repo/gomod/submod2/go.mod",
},
envs: map[string]string{
"TRIVY_SCANNERS": "vuln",
"TRIVY_SKIP_FILES": "path/to/dummy,testdata/fixtures/repo/gomod/submod2/go.mod",
},
configFile: `---
scan:
scanners:
- vuln
skip-files:
- path/to/dummy
- testdata/fixtures/repo/gomod/submod2/go.mod
`,
},
golden: "testdata/gomod-skip.json.golden",
},
{
name: "dockerfile with custom file pattern",
args: args{
input: "testdata/fixtures/repo/dockerfile_file_pattern",
flags: map[string]string{
"scanners": "misconfig",
"file-patterns": "dockerfile:Customfile",
"namespaces": "testing",
},
envs: map[string]string{
"TRIVY_SCANNERS": "misconfig",
"TRIVY_FILE_PATTERNS": "dockerfile:Customfile",
"TRIVY_NAMESPACES": "testing",
},
configFile: `---
scan:
scanners:
- misconfig
file-patterns:
- dockerfile:Customfile
rego:
skip-policy-update: true
namespaces:
- testing
`,
},
golden: "testdata/dockerfile_file_pattern.json.golden",
},
{
name: "key alias", // "--scanners" vs "--security-checks"
args: args{
input: "testdata/fixtures/repo/gomod",
flags: map[string]string{
"security-checks": "vuln",
},
envs: map[string]string{
"TRIVY_SECURITY_CHECKS": "vuln",
},
configFile: `---
scan:
security-checks:
- vuln
`,
},
golden: "testdata/gomod.json.golden",
},
{
name: "value alias", // "--scanners vuln" vs "--scanners vulnerability"
args: args{
input: "testdata/fixtures/repo/gomod",
flags: map[string]string{
"scanners": "vulnerability",
},
envs: map[string]string{
"TRIVY_SCANNERS": "vulnerability",
},
configFile: `---
scan:
scanners:
- vulnerability
`,
},
golden: "testdata/gomod.json.golden",
},
{
name: "invalid value",
args: args{
input: "testdata/fixtures/repo/gomod",
flags: map[string]string{
"scanners": "vulnerability",
"severity": "CRITICAL,INVALID",
},
envs: map[string]string{
"TRIVY_SCANNERS": "vulnerability",
"TRIVY_SEVERITY": "CRITICAL,INVALID",
},
configFile: `---
scan:
scanners:
- vulnerability
severity:
- CRITICAL
- INVALID
`,
},
wantErr: `invalid argument "[CRITICAL INVALID]" for "--severity" flag`,
},
}
// Set up testing DB
cacheDir := initDB(t)
// Set a temp dir so that modules will not be loaded
t.Setenv("XDG_DATA_HOME", cacheDir)
for _, tt := range tests {
command := "repo"
t.Run(tt.name+" with CLI flags", func(t *testing.T) {
osArgs := []string{
"--format",
"json",
"--cache-dir",
cacheDir,
"--skip-db-update",
"--skip-policy-update",
command,
tt.args.input,
}
for key, value := range tt.args.flags {
osArgs = append(osArgs, "--"+key, value)
}
// Set up the output file
outputFile := filepath.Join(t.TempDir(), "output.json")
osArgs = append(osArgs, "--output", outputFile)
runTest(t, osArgs, tt.golden, outputFile, types.FormatJSON, runOptions{
wantErr: tt.wantErr,
})
})
t.Run(tt.name+" with environmental variables", func(t *testing.T) {
// Set up the output file
outputFile := filepath.Join(t.TempDir(), "output.json")
t.Setenv("TRIVY_OUTPUT", outputFile)
t.Setenv("TRIVY_FORMAT", "json")
t.Setenv("TRIVY_CACHE_DIR", cacheDir)
t.Setenv("TRIVY_SKIP_DB_UPDATE", "true")
t.Setenv("TRIVY_SKIP_POLICY_UPDATE", "true")
for key, value := range tt.args.envs {
t.Setenv(key, value)
}
osArgs := []string{
command,
tt.args.input,
}
runTest(t, osArgs, tt.golden, outputFile, types.FormatJSON, runOptions{
wantErr: tt.wantErr,
})
})
t.Run(tt.name+" with config file", func(t *testing.T) {
// Set up the output file
outputFile := filepath.Join(t.TempDir(), "output.json")
configFile := tt.args.configFile
configFile = configFile + fmt.Sprintf(`
format: json
output: %s
cache:
dir: %s
db:
skip-update: true
`, outputFile, cacheDir)
configPath := filepath.Join(t.TempDir(), "trivy.yaml")
err := os.WriteFile(configPath, []byte(configFile), 0444)
require.NoError(t, err)
osArgs := []string{
command,
"--config",
configPath,
tt.args.input,
}
runTest(t, osArgs, tt.golden, outputFile, types.FormatJSON, runOptions{
wantErr: tt.wantErr,
})
})
}
}

View File

@@ -5,9 +5,9 @@ package integration
import (
"context"
"github.com/aquasecurity/trivy/pkg/types"
"io"
"os"
"path/filepath"
"strings"
"testing"
@@ -40,18 +40,24 @@ func TestDockerEngine(t *testing.T) {
golden: "testdata/alpine-39.json.golden",
},
{
name: "alpine:3.9, with high and critical severity",
severity: []string{"HIGH", "CRITICAL"},
name: "alpine:3.9, with high and critical severity",
severity: []string{
"HIGH",
"CRITICAL",
},
imageTag: "ghcr.io/aquasecurity/trivy-test-images:alpine-39",
input: "testdata/fixtures/images/alpine-39.tar.gz",
golden: "testdata/alpine-39-high-critical.json.golden",
},
{
name: "alpine:3.9, with .trivyignore",
imageTag: "ghcr.io/aquasecurity/trivy-test-images:alpine-39",
ignoreIDs: []string{"CVE-2019-1549", "CVE-2019-14697"},
input: "testdata/fixtures/images/alpine-39.tar.gz",
golden: "testdata/alpine-39-ignore-cveids.json.golden",
name: "alpine:3.9, with .trivyignore",
imageTag: "ghcr.io/aquasecurity/trivy-test-images:alpine-39",
ignoreIDs: []string{
"CVE-2019-1549",
"CVE-2019-14697",
},
input: "testdata/fixtures/images/alpine-39.tar.gz",
golden: "testdata/alpine-39-ignore-cveids.json.golden",
},
{
name: "alpine:3.10",
@@ -244,13 +250,28 @@ func TestDockerEngine(t *testing.T) {
// tag our image to something unique
err = cli.ImageTag(ctx, tt.imageTag, tt.input)
require.NoError(t, err, tt.name)
// cleanup
t.Cleanup(func() {
_, err = cli.ImageRemove(ctx, tt.input, api.ImageRemoveOptions{
Force: true,
PruneChildren: true,
})
_, err = cli.ImageRemove(ctx, tt.imageTag, api.ImageRemoveOptions{
Force: true,
PruneChildren: true,
})
assert.NoError(t, err, tt.name)
})
}
tmpDir := t.TempDir()
output := filepath.Join(tmpDir, "result.json")
osArgs := []string{"--cache-dir", cacheDir, "image",
"--skip-update", "--format=json", "--output", output}
osArgs := []string{
"--cache-dir",
cacheDir,
"image",
"--skip-update",
"--format=json",
}
if tt.ignoreUnfixed {
osArgs = append(osArgs, "--ignore-unfixed")
@@ -258,12 +279,18 @@ func TestDockerEngine(t *testing.T) {
if len(tt.ignoreStatus) != 0 {
osArgs = append(osArgs,
[]string{"--ignore-status", strings.Join(tt.ignoreStatus, ",")}...,
[]string{
"--ignore-status",
strings.Join(tt.ignoreStatus, ","),
}...,
)
}
if len(tt.severity) != 0 {
osArgs = append(osArgs,
[]string{"--severity", strings.Join(tt.severity, ",")}...,
[]string{
"--severity",
strings.Join(tt.severity, ","),
}...,
)
}
if len(tt.ignoreIDs) != 0 {
@@ -275,28 +302,7 @@ func TestDockerEngine(t *testing.T) {
osArgs = append(osArgs, tt.input)
// Run Trivy
err = execute(osArgs)
if tt.wantErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr, tt.name)
return
}
assert.NoError(t, err, tt.name)
// check for vulnerability output info
compareReports(t, tt.golden, output, nil)
// cleanup
_, err = cli.ImageRemove(ctx, tt.input, api.ImageRemoveOptions{
Force: true,
PruneChildren: true,
})
_, err = cli.ImageRemove(ctx, tt.imageTag, api.ImageRemoveOptions{
Force: true,
PruneChildren: true,
})
assert.NoError(t, err, tt.name)
runTest(t, osArgs, tt.golden, "", types.FormatJSON, runOptions{wantErr: tt.wantErr})
})
}
}

View File

@@ -7,7 +7,6 @@ import (
"encoding/json"
"flag"
"fmt"
"github.com/aquasecurity/trivy/pkg/clock"
"io"
"net"
"os"
@@ -22,15 +21,18 @@ import (
spdxjson "github.com/spdx/tools-golang/json"
"github.com/spdx/tools-golang/spdx"
"github.com/spdx/tools-golang/spdxlib"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/xeipuuv/gojsonschema"
"github.com/aquasecurity/trivy-db/pkg/db"
"github.com/aquasecurity/trivy-db/pkg/metadata"
"github.com/aquasecurity/trivy/pkg/clock"
"github.com/aquasecurity/trivy/pkg/commands"
"github.com/aquasecurity/trivy/pkg/dbtest"
"github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/uuid"
_ "modernc.org/sqlite"
)
@@ -190,7 +192,56 @@ func readSpdxJson(t *testing.T, filePath string) *spdx.Document {
return bom
}
type runOptions struct {
wantErr string
override func(want, got *types.Report)
fakeUUID string
}
// runTest runs Trivy with the given args and compares the output with the golden file.
// If outputFile is empty, the output file is created in a temporary directory.
// If update is true, the golden file is updated.
func runTest(t *testing.T, osArgs []string, wantFile, outputFile string, format types.Format, opts runOptions) {
if opts.fakeUUID != "" {
uuid.SetFakeUUID(t, opts.fakeUUID)
}
if outputFile == "" {
// Set up the output file
outputFile = filepath.Join(t.TempDir(), "output.json")
if *update && opts.override == nil {
outputFile = wantFile
}
}
osArgs = append(osArgs, "--output", outputFile)
// Run Trivy
err := execute(osArgs)
if opts.wantErr != "" {
require.ErrorContains(t, err, opts.wantErr)
return
}
require.NoError(t, err)
// Compare want and got
switch format {
case types.FormatCycloneDX:
compareCycloneDX(t, wantFile, outputFile)
case types.FormatSPDXJSON:
compareSPDXJson(t, wantFile, outputFile)
case types.FormatJSON:
compareReports(t, wantFile, outputFile, opts.override)
case types.FormatTemplate, types.FormatSarif, types.FormatGitHub:
compareRawFiles(t, wantFile, outputFile)
default:
require.Fail(t, "invalid format", "format: %s", format)
}
}
func execute(osArgs []string) error {
// viper.XXX() (e.g. viper.ReadInConfig()) affects the global state, so we need to reset it after each test.
defer viper.Reset()
// Set a fake time
ctx := clock.With(context.Background(), time.Date(2021, 8, 25, 12, 20, 30, 5, time.UTC))
@@ -203,11 +254,19 @@ func execute(osArgs []string) error {
return app.ExecuteContext(ctx)
}
func compareReports(t *testing.T, wantFile, gotFile string, override func(*types.Report)) {
func compareRawFiles(t *testing.T, wantFile, gotFile string) {
want, err := os.ReadFile(wantFile)
require.NoError(t, err)
got, err := os.ReadFile(gotFile)
require.NoError(t, err)
assert.EqualValues(t, string(want), string(got))
}
func compareReports(t *testing.T, wantFile, gotFile string, override func(want, got *types.Report)) {
want := readReport(t, wantFile)
got := readReport(t, gotFile)
if override != nil {
override(&want)
override(&want, &got)
}
assert.Equal(t, want, got)
}

View File

@@ -3,11 +3,10 @@
package integration
import (
"github.com/aquasecurity/trivy/pkg/types"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/scanner/post"
)
@@ -51,27 +50,13 @@ func TestModule(t *testing.T) {
tt.input,
}
// Set up the output file
outputFile := filepath.Join(t.TempDir(), "output.json")
if *update {
outputFile = tt.golden
}
osArgs = append(osArgs, []string{
"--output",
outputFile,
}...)
// Run Trivy
err := execute(osArgs)
require.NoError(t, err)
defer func() {
t.Cleanup(func() {
analyzer.DeregisterAnalyzer("spring4shell")
post.DeregisterPostScanner("spring4shell")
}()
})
// Compare want and got
compareReports(t, tt.golden, outputFile, nil)
// Run Trivy
runTest(t, osArgs, tt.golden, "", types.FormatJSON, runOptions{})
})
}
}

View File

@@ -1,5 +1,4 @@
//go:build integration
// +build integration
package integration
@@ -11,6 +10,7 @@ import (
"crypto/x509"
"encoding/json"
"fmt"
"github.com/aquasecurity/trivy/pkg/types"
"io"
"net/http"
"net/url"
@@ -24,9 +24,8 @@ import (
"github.com/google/go-containerregistry/pkg/name"
"github.com/google/go-containerregistry/pkg/v1/remote"
"github.com/google/go-containerregistry/pkg/v1/tarball"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
testcontainers "github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)
@@ -62,7 +61,10 @@ func setupRegistry(ctx context.Context, baseDir string, authURL *url.URL) (testc
HostConfigModifier: func(hostConfig *dockercontainer.HostConfig) {
hostConfig.AutoRemove = true
},
WaitingFor: wait.ForLog("listening on [::]:5443"),
WaitingFor: wait.ForHTTP("v2").WithTLS(true).WithAllowInsecure(true).
WithStatusCodeMatcher(func(status int) bool {
return status == http.StatusUnauthorized
}),
}
registryC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
@@ -191,62 +193,50 @@ func TestRegistry(t *testing.T) {
imageRef, err := name.ParseReference(s)
require.NoError(t, err)
// 1. Load a test image from the tar file, tag it and push to the test registry.
// Load a test image from the tar file, tag it and push to the test registry.
err = replicateImage(imageRef, tc.imageFile, auth)
require.NoError(t, err)
// 2. Scan it
resultFile, err := scan(t, imageRef, baseDir, tc.golden, tc.option)
osArgs, err := scan(t, imageRef, baseDir, tc.golden, tc.option)
if tc.wantErr != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tc.wantErr, err)
return
}
require.NoError(t, err)
// 3. Read want and got
want := readReport(t, tc.golden)
got := readReport(t, resultFile)
// 4 Update some dynamic fields
want.ArtifactName = s
for i := range want.Results {
want.Results[i].Target = fmt.Sprintf("%s (alpine 3.10.2)", s)
}
// 5. Compare want and got
assert.Equal(t, want, got)
// Run Trivy
runTest(t, osArgs, tc.golden, "", types.FormatJSON, runOptions{
wantErr: tc.wantErr,
override: func(_, got *types.Report) {
got.ArtifactName = tc.imageName
for i := range got.Results {
got.Results[i].Target = fmt.Sprintf("%s (alpine 3.10.2)", tc.imageName)
}
},
})
})
}
}
func scan(t *testing.T, imageRef name.Reference, baseDir, goldenFile string, opt registryOption) (string, error) {
func scan(t *testing.T, imageRef name.Reference, baseDir, goldenFile string, opt registryOption) ([]string, error) {
// Set up testing DB
cacheDir := initDB(t)
// Set a temp dir so that modules will not be loaded
t.Setenv("XDG_DATA_HOME", cacheDir)
// Setup the output file
outputFile := filepath.Join(t.TempDir(), "output.json")
if *update {
outputFile = goldenFile
}
// Setup env
if err := setupEnv(t, imageRef, baseDir, opt); err != nil {
return "", err
return nil, err
}
osArgs := []string{"-q", "--cache-dir", cacheDir, "image", "--format", "json", "--skip-update",
"--output", outputFile, imageRef.Name()}
// Run Trivy
if err := execute(osArgs); err != nil {
return "", err
osArgs := []string{
"-q",
"--cache-dir",
cacheDir,
"image",
"--format",
"json",
"--skip-update",
imageRef.Name(),
}
return outputFile, nil
return osArgs, nil
}
func setupEnv(t *testing.T, imageRef name.Reference, baseDir string, opt registryOption) error {

View File

@@ -5,15 +5,12 @@ package integration
import (
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"os"
"path/filepath"
"strings"
"testing"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/uuid"
)
// TestRepository tests `trivy repo` with the local code repositories
@@ -40,7 +37,7 @@ func TestRepository(t *testing.T) {
name string
args args
golden string
override func(*types.Report)
override func(want, got *types.Report)
}{
{
name: "gomod",
@@ -372,8 +369,8 @@ func TestRepository(t *testing.T) {
skipFiles: []string{"testdata/fixtures/repo/gomod/submod2/go.mod"},
},
golden: "testdata/gomod-skip.json.golden",
override: func(report *types.Report) {
report.ArtifactType = ftypes.ArtifactFilesystem
override: func(want, _ *types.Report) {
want.ArtifactType = ftypes.ArtifactFilesystem
},
},
{
@@ -386,8 +383,8 @@ func TestRepository(t *testing.T) {
input: "testdata/fixtures/repo/custom-policy",
},
golden: "testdata/dockerfile-custom-policies.json.golden",
override: func(report *types.Report) {
report.ArtifactType = ftypes.ArtifactFilesystem
override: func(want, got *types.Report) {
want.ArtifactType = ftypes.ArtifactFilesystem
},
},
}
@@ -400,7 +397,6 @@ func TestRepository(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
command := "repo"
if tt.args.command != "" {
command = tt.args.command
@@ -423,6 +419,7 @@ func TestRepository(t *testing.T) {
"--parallel",
fmt.Sprint(tt.args.parallel),
"--offline-scan",
tt.args.input,
}
if tt.args.scanner != "" {
@@ -478,12 +475,6 @@ func TestRepository(t *testing.T) {
}
}
// Setup the output file
outputFile := filepath.Join(t.TempDir(), "output.json")
if *update && tt.override == nil {
outputFile = tt.golden
}
if tt.args.listAllPkgs {
osArgs = append(osArgs, "--list-all-pkgs")
}
@@ -496,26 +487,10 @@ func TestRepository(t *testing.T) {
osArgs = append(osArgs, "--secret-config", tt.args.secretConfig)
}
osArgs = append(osArgs, "--output", outputFile)
osArgs = append(osArgs, tt.args.input)
uuid.SetFakeUUID(t, "3ff14136-e09f-4df9-80ea-%012d")
// Run "trivy repo"
err := execute(osArgs)
require.NoError(t, err)
// Compare want and got
switch format {
case types.FormatCycloneDX:
compareCycloneDX(t, tt.golden, outputFile)
case types.FormatSPDXJSON:
compareSPDXJson(t, tt.golden, outputFile)
case types.FormatJSON:
compareReports(t, tt.golden, outputFile, tt.override)
default:
require.Fail(t, "invalid format", "format: %s", format)
}
runTest(t, osArgs, tt.golden, "", format, runOptions{
fakeUUID: "3ff14136-e09f-4df9-80ea-%012d",
override: tt.override,
})
})
}
}

View File

@@ -3,6 +3,7 @@
package integration
import (
"github.com/aquasecurity/trivy/pkg/types"
"os"
"path/filepath"
"strings"
@@ -17,28 +18,28 @@ func TestTar(t *testing.T) {
IgnoreUnfixed bool
Severity []string
IgnoreIDs []string
Format string
Format types.Format
Input string
SkipDirs []string
SkipFiles []string
}
tests := []struct {
name string
testArgs args
golden string
name string
args args
golden string
}{
{
name: "alpine 3.9",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/alpine-39.tar.gz",
},
golden: "testdata/alpine-39.json.golden",
},
{
name: "alpine 3.9 with skip dirs",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/alpine-39.tar.gz",
SkipDirs: []string{
"/etc",
@@ -48,8 +49,8 @@ func TestTar(t *testing.T) {
},
{
name: "alpine 3.9 with skip files",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/alpine-39.tar.gz",
SkipFiles: []string{
"/etc",
@@ -132,224 +133,224 @@ func TestTar(t *testing.T) {
},
{
name: "alpine 3.9 with high and critical severity",
testArgs: args{
args: args{
IgnoreUnfixed: true,
Severity: []string{
"HIGH",
"CRITICAL",
},
Format: "json",
Format: types.FormatJSON,
Input: "testdata/fixtures/images/alpine-39.tar.gz",
},
golden: "testdata/alpine-39-high-critical.json.golden",
},
{
name: "alpine 3.9 with .trivyignore",
testArgs: args{
args: args{
IgnoreUnfixed: false,
IgnoreIDs: []string{
"CVE-2019-1549",
"CVE-2019-14697",
},
Format: "json",
Format: types.FormatJSON,
Input: "testdata/fixtures/images/alpine-39.tar.gz",
},
golden: "testdata/alpine-39-ignore-cveids.json.golden",
},
{
name: "alpine 3.10",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/alpine-310.tar.gz",
},
golden: "testdata/alpine-310.json.golden",
},
{
name: "alpine distroless",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/alpine-distroless.tar.gz",
},
golden: "testdata/alpine-distroless.json.golden",
},
{
name: "amazon linux 1",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/amazon-1.tar.gz",
},
golden: "testdata/amazon-1.json.golden",
},
{
name: "amazon linux 2",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/amazon-2.tar.gz",
},
golden: "testdata/amazon-2.json.golden",
},
{
name: "debian buster/10",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/debian-buster.tar.gz",
},
golden: "testdata/debian-buster.json.golden",
},
{
name: "debian buster/10 with --ignore-unfixed option",
testArgs: args{
args: args{
IgnoreUnfixed: true,
Format: "json",
Format: types.FormatJSON,
Input: "testdata/fixtures/images/debian-buster.tar.gz",
},
golden: "testdata/debian-buster-ignore-unfixed.json.golden",
},
{
name: "debian stretch/9",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/debian-stretch.tar.gz",
},
golden: "testdata/debian-stretch.json.golden",
},
{
name: "ubuntu 18.04",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/ubuntu-1804.tar.gz",
},
golden: "testdata/ubuntu-1804.json.golden",
},
{
name: "ubuntu 18.04 with --ignore-unfixed option",
testArgs: args{
args: args{
IgnoreUnfixed: true,
Format: "json",
Format: types.FormatJSON,
Input: "testdata/fixtures/images/ubuntu-1804.tar.gz",
},
golden: "testdata/ubuntu-1804-ignore-unfixed.json.golden",
},
{
name: "centos 7",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/centos-7.tar.gz",
},
golden: "testdata/centos-7.json.golden",
},
{
name: "centos 7with --ignore-unfixed option",
testArgs: args{
args: args{
IgnoreUnfixed: true,
Format: "json",
Format: types.FormatJSON,
Input: "testdata/fixtures/images/centos-7.tar.gz",
},
golden: "testdata/centos-7-ignore-unfixed.json.golden",
},
{
name: "centos 7 with medium severity",
testArgs: args{
args: args{
IgnoreUnfixed: true,
Severity: []string{"MEDIUM"},
Format: "json",
Format: types.FormatJSON,
Input: "testdata/fixtures/images/centos-7.tar.gz",
},
golden: "testdata/centos-7-medium.json.golden",
},
{
name: "centos 6",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/centos-6.tar.gz",
},
golden: "testdata/centos-6.json.golden",
},
{
name: "ubi 7",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/ubi-7.tar.gz",
},
golden: "testdata/ubi-7.json.golden",
},
{
name: "almalinux 8",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/almalinux-8.tar.gz",
},
golden: "testdata/almalinux-8.json.golden",
},
{
name: "rocky linux 8",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/rockylinux-8.tar.gz",
},
golden: "testdata/rockylinux-8.json.golden",
},
{
name: "distroless base",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/distroless-base.tar.gz",
},
golden: "testdata/distroless-base.json.golden",
},
{
name: "distroless python27",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/distroless-python27.tar.gz",
},
golden: "testdata/distroless-python27.json.golden",
},
{
name: "oracle linux 8",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/oraclelinux-8.tar.gz",
},
golden: "testdata/oraclelinux-8.json.golden",
},
{
name: "opensuse leap 15.1",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/opensuse-leap-151.tar.gz",
},
golden: "testdata/opensuse-leap-151.json.golden",
},
{
name: "photon 3.0",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/photon-30.tar.gz",
},
golden: "testdata/photon-30.json.golden",
},
{
name: "CBL-Mariner 1.0",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/mariner-1.0.tar.gz",
},
golden: "testdata/mariner-1.0.json.golden",
},
{
name: "busybox with Cargo.lock integration",
testArgs: args{
Format: "json",
args: args{
Format: types.FormatJSON,
Input: "testdata/fixtures/images/busybox-with-lockfile.tar.gz",
},
golden: "testdata/busybox-with-lockfile.json.golden",
},
{
name: "fluentd with RubyGems",
testArgs: args{
args: args{
IgnoreUnfixed: true,
Format: "json",
Format: types.FormatJSON,
Input: "testdata/fixtures/images/fluentd-multiple-lockfiles.tar.gz",
},
golden: "testdata/fluentd-gems.json.golden",
@@ -370,55 +371,40 @@ func TestTar(t *testing.T) {
"image",
"-q",
"--format",
tt.testArgs.Format,
string(tt.args.Format),
"--skip-update",
}
if tt.testArgs.IgnoreUnfixed {
if tt.args.IgnoreUnfixed {
osArgs = append(osArgs, "--ignore-unfixed")
}
if len(tt.testArgs.Severity) != 0 {
osArgs = append(osArgs, "--severity", strings.Join(tt.testArgs.Severity, ","))
if len(tt.args.Severity) != 0 {
osArgs = append(osArgs, "--severity", strings.Join(tt.args.Severity, ","))
}
if len(tt.testArgs.IgnoreIDs) != 0 {
if len(tt.args.IgnoreIDs) != 0 {
trivyIgnore := ".trivyignore"
err := os.WriteFile(trivyIgnore, []byte(strings.Join(tt.testArgs.IgnoreIDs, "\n")), 0444)
err := os.WriteFile(trivyIgnore, []byte(strings.Join(tt.args.IgnoreIDs, "\n")), 0444)
assert.NoError(t, err, "failed to write .trivyignore")
defer os.Remove(trivyIgnore)
}
if tt.testArgs.Input != "" {
osArgs = append(osArgs, "--input", tt.testArgs.Input)
if tt.args.Input != "" {
osArgs = append(osArgs, "--input", tt.args.Input)
}
if len(tt.testArgs.SkipFiles) != 0 {
for _, skipFile := range tt.testArgs.SkipFiles {
if len(tt.args.SkipFiles) != 0 {
for _, skipFile := range tt.args.SkipFiles {
osArgs = append(osArgs, "--skip-files", skipFile)
}
}
if len(tt.testArgs.SkipDirs) != 0 {
for _, skipDir := range tt.testArgs.SkipDirs {
if len(tt.args.SkipDirs) != 0 {
for _, skipDir := range tt.args.SkipDirs {
osArgs = append(osArgs, "--skip-dirs", skipDir)
}
}
// Set up the output file
outputFile := filepath.Join(t.TempDir(), "output.json")
if *update {
outputFile = tt.golden
}
osArgs = append(osArgs, []string{
"--output",
outputFile,
}...)
// Run Trivy
err := execute(osArgs)
require.NoError(t, err)
// Compare want and got
compareReports(t, tt.golden, outputFile, nil)
runTest(t, osArgs, tt.golden, "", tt.args.Format, runOptions{})
})
}
}
@@ -479,8 +465,6 @@ func TestTarWithEnv(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
osArgs := []string{"image"}
t.Setenv("TRIVY_FORMAT", tt.testArgs.Format)
t.Setenv("TRIVY_CACHE_DIR", cacheDir)
t.Setenv("TRIVY_QUIET", "true")
@@ -493,27 +477,15 @@ func TestTarWithEnv(t *testing.T) {
t.Setenv("TRIVY_SEVERITY", strings.Join(tt.testArgs.Severity, ","))
}
if tt.testArgs.Input != "" {
osArgs = append(osArgs, "--input", tt.testArgs.Input)
t.Setenv("TRIVY_INPUT", tt.testArgs.Input)
}
if len(tt.testArgs.SkipDirs) != 0 {
t.Setenv("TRIVY_SKIP_DIRS", strings.Join(tt.testArgs.SkipDirs, ","))
}
// Set up the output file
outputFile := filepath.Join(t.TempDir(), "output.json")
osArgs = append(osArgs, []string{
"--output",
outputFile,
}...)
// Run Trivy
err := execute(osArgs)
require.NoError(t, err)
// Compare want and got
compareReports(t, tt.golden, outputFile, nil)
runTest(t, []string{"image"}, tt.golden, "", types.FormatJSON, runOptions{})
})
}
}
@@ -531,13 +503,13 @@ func TestTarWithConfigFile(t *testing.T) {
configFile: `quiet: true
format: json
severity:
- HIGH
- CRITICAL
- HIGH
- CRITICAL
vulnerability:
type:
- os
type:
- os
cache:
dir: /should/be/overwritten
dir: /should/be/overwritten
`,
golden: "testdata/alpine-39-high-critical.json.golden",
},
@@ -547,9 +519,9 @@ cache:
configFile: `quiet: true
format: json
vulnerability:
ignore-unfixed: true
ignore-unfixed: true
cache:
dir: /should/be/overwritten
dir: /should/be/overwritten
`,
golden: "testdata/debian-buster-ignore-unfixed.json.golden",
},
@@ -563,10 +535,7 @@ cache:
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
outputFile := filepath.Join(tmpDir, "output.json")
configPath := filepath.Join(tmpDir, "trivy.yaml")
configPath := filepath.Join(t.TempDir(), "trivy.yaml")
err := os.WriteFile(configPath, []byte(tt.configFile), 0600)
require.NoError(t, err)
@@ -579,16 +548,10 @@ cache:
configPath,
"--input",
tt.input,
"--output",
outputFile,
}
// Run Trivy
err = execute(osArgs)
require.NoError(t, err)
// Compare want and got
compareReports(t, tt.golden, outputFile, nil)
runTest(t, osArgs, tt.golden, "", types.FormatJSON, runOptions{})
})
}
}

View File

@@ -1,7 +1,7 @@
{
"SchemaVersion": 2,
"CreatedAt": "2021-08-25T12:20:30.000000005Z",
"ArtifactName": "localhost:53869/alpine:3.10",
"ArtifactName": "alpine:3.10",
"ArtifactType": "container_image",
"Metadata": {
"OS": {
@@ -14,10 +14,10 @@
"sha256:03901b4a2ea88eeaad62dbe59b072b28b6efa00491962b8741081c5df50c65e0"
],
"RepoTags": [
"localhost:53869/alpine:3.10"
"alpine:3.10"
],
"RepoDigests": [
"localhost:53869/alpine@sha256:b1c5a500182b21d0bfa5a584a8526b56d8be316f89e87d951be04abed2446e60"
"alpine@sha256:b1c5a500182b21d0bfa5a584a8526b56d8be316f89e87d951be04abed2446e60"
],
"ImageConfig": {
"architecture": "amd64",
@@ -56,7 +56,7 @@
},
"Results": [
{
"Target": "localhost:53869/alpine:3.10 (alpine 3.10.2)",
"Target": "alpine:3.10 (alpine 3.10.2)",
"Class": "os-pkgs",
"Type": "alpine",
"Vulnerabilities": [

View File

@@ -3,12 +3,10 @@
package integration
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/aquasecurity/trivy/internal/testutil"
"github.com/aquasecurity/trivy/pkg/types"
)
@@ -66,10 +64,6 @@ func TestVM(t *testing.T) {
// Set up testing DB
cacheDir := initDB(t)
// Keep the current working directory
currentDir, err := os.Getwd()
require.NoError(t, err)
const imageFile = "disk.img"
for _, tt := range tests {
@@ -86,34 +80,22 @@ func TestVM(t *testing.T) {
tt.args.format,
}
tmpDir := t.TempDir()
// Set up the output file
outputFile := filepath.Join(tmpDir, "output.json")
if *update {
outputFile = filepath.Join(currentDir, tt.golden)
}
// Get the absolute path of the golden file
goldenFile, err := filepath.Abs(tt.golden)
require.NoError(t, err)
// Decompress the gzipped image file
imagePath := filepath.Join(tmpDir, imageFile)
imagePath := filepath.Join(t.TempDir(), imageFile)
testutil.DecompressSparseGzip(t, tt.args.input, imagePath)
// Change the current working directory so that targets in the result could be the same as golden files.
err = os.Chdir(tmpDir)
require.NoError(t, err)
defer os.Chdir(currentDir)
osArgs = append(osArgs, "--output", outputFile)
osArgs = append(osArgs, imageFile)
osArgs = append(osArgs, imagePath)
// Run "trivy vm"
err = execute(osArgs)
require.NoError(t, err)
compareReports(t, goldenFile, outputFile, nil)
runTest(t, osArgs, tt.golden, "", types.FormatJSON, runOptions{
override: func(_, got *types.Report) {
got.ArtifactName = "disk.img"
for i := range got.Results {
lastIndex := strings.LastIndex(got.Results[i].Target, "/")
got.Results[i].Target = got.Results[i].Target[lastIndex+1:]
}
},
})
})
}
}

View File

@@ -190,7 +190,10 @@ func NewRootCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
return err
}
globalOptions := globalFlags.ToOptions()
globalOptions, err := globalFlags.ToOptions()
if err != nil {
return err
}
// Initialize logger
if err := log.InitLogger(globalOptions.Debug, globalOptions.Quiet); err != nil {
@@ -200,7 +203,11 @@ func NewRootCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
globalOptions := globalFlags.ToOptions()
globalOptions, err := globalFlags.ToOptions()
if err != nil {
return err
}
if globalOptions.ShowVersion {
// Customize version output
return showVersion(globalOptions.CacheDir, versionFormat, cmd.OutOrStdout())
@@ -223,20 +230,21 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
scanFlagGroup.IncludeDevDeps = nil // disable '--include-dev-deps'
reportFlagGroup := flag.NewReportFlagGroup()
report := flag.ReportFormatFlag
report := flag.ReportFormatFlag.Clone()
report.Default = "summary" // override the default value as the summary is preferred for the compliance report
report.Usage = "specify a format for the compliance report." // "--report" works only with "--compliance"
reportFlagGroup.ReportFormat = &report
reportFlagGroup.ReportFormat = report
compliance := flag.ComplianceFlag
compliance := flag.ComplianceFlag.Clone()
compliance.Values = []string{types.ComplianceDockerCIS}
reportFlagGroup.Compliance = &compliance // override usage as the accepted values differ for each subcommand.
reportFlagGroup.Compliance = compliance // override usage as the accepted values differ for each subcommand.
misconfFlagGroup := flag.NewMisconfFlagGroup()
misconfFlagGroup.CloudformationParamVars = nil // disable '--cf-params'
misconfFlagGroup.TerraformTFVars = nil // disable '--tf-vars'
imageFlags := &flag.Flags{
GlobalFlagGroup: globalFlags,
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
ImageFlagGroup: flag.NewImageFlagGroup(), // container image specific
@@ -292,7 +300,7 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
return validateArgs(cmd, args)
},
RunE: func(cmd *cobra.Command, args []string) error {
options, err := imageFlags.ToOptions(args, globalFlags)
options, err := imageFlags.ToOptions(args)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
@@ -311,12 +319,13 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
reportFlagGroup := flag.NewReportFlagGroup()
reportFormat := flag.ReportFormatFlag
reportFormat := flag.ReportFormatFlag.Clone()
reportFormat.Usage = "specify a compliance report format for the output" // @TODO: support --report summary for non compliance reports
reportFlagGroup.ReportFormat = &reportFormat
reportFlagGroup.ReportFormat = reportFormat
reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol'
fsFlags := &flag.Flags{
GlobalFlagGroup: globalFlags,
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
@@ -351,7 +360,7 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := fsFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
options, err := fsFlags.ToOptions(args, globalFlags)
options, err := fsFlags.ToOptions(args)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
@@ -370,6 +379,7 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
rootfsFlags := &flag.Flags{
GlobalFlagGroup: globalFlags,
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
@@ -410,7 +420,7 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := rootfsFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
options, err := rootfsFlags.ToOptions(args, globalFlags)
options, err := rootfsFlags.ToOptions(args)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
@@ -428,6 +438,7 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
repoFlags := &flag.Flags{
GlobalFlagGroup: globalFlags,
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
LicenseFlagGroup: flag.NewLicenseFlagGroup(),
@@ -465,7 +476,7 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := repoFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
options, err := repoFlags.ToOptions(args, globalFlags)
options, err := repoFlags.ToOptions(args)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
@@ -483,6 +494,7 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
func NewConvertCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
convertFlags := &flag.Flags{
GlobalFlagGroup: globalFlags,
ScanFlagGroup: &flag.ScanFlagGroup{},
ReportFlagGroup: flag.NewReportFlagGroup(),
}
@@ -505,7 +517,7 @@ func NewConvertCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := convertFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
opts, err := convertFlags.ToOptions(args, globalFlags)
opts, err := convertFlags.ToOptions(args)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
@@ -525,7 +537,7 @@ func NewConvertCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
// NewClientCommand returns the 'client' subcommand that is deprecated
func NewClientCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
remoteFlags := flag.NewClientFlags()
remoteAddr := flag.Flag{
remoteAddr := flag.Flag[string]{
Name: "remote",
ConfigName: "server.addr",
Shorthand: "",
@@ -535,6 +547,7 @@ func NewClientCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
remoteFlags.ServerAddr = &remoteAddr // disable '--server' and enable '--remote' instead.
clientFlags := &flag.Flags{
GlobalFlagGroup: globalFlags,
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
@@ -562,7 +575,7 @@ func NewClientCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := clientFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
options, err := clientFlags.ToOptions(args, globalFlags)
options, err := clientFlags.ToOptions(args)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
@@ -580,6 +593,7 @@ func NewClientCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
serverFlags := &flag.Flags{
GlobalFlagGroup: globalFlags,
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
@@ -608,7 +622,7 @@ func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := serverFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
options, err := serverFlags.ToOptions(args, globalFlags)
options, err := serverFlags.ToOptions(args)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
@@ -629,18 +643,19 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
reportFlagGroup.DependencyTree = nil // disable '--dependency-tree'
reportFlagGroup.ListAllPkgs = nil // disable '--list-all-pkgs'
reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol'
reportFormat := flag.ReportFormatFlag
reportFormat := flag.ReportFormatFlag.Clone()
reportFormat.Usage = "specify a compliance report format for the output" // @TODO: support --report summary for non compliance reports
reportFlagGroup.ReportFormat = &reportFormat
reportFlagGroup.ReportFormat = reportFormat
scanFlags := &flag.ScanFlagGroup{
// Enable only '--skip-dirs' and '--skip-files' and disable other flags
SkipDirs: &flag.SkipDirsFlag,
SkipFiles: &flag.SkipFilesFlag,
FilePatterns: &flag.FilePatternsFlag,
SkipDirs: flag.SkipDirsFlag.Clone(),
SkipFiles: flag.SkipFilesFlag.Clone(),
FilePatterns: flag.FilePatternsFlag.Clone(),
}
configFlags := &flag.Flags{
GlobalFlagGroup: globalFlags,
CacheFlagGroup: flag.NewCacheFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ModuleFlagGroup: flag.NewModuleFlagGroup(),
@@ -648,7 +663,7 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
RegoFlagGroup: flag.NewRegoFlagGroup(),
K8sFlagGroup: &flag.K8sFlagGroup{
// disable unneeded flags
K8sVersion: &flag.K8sVersionFlag,
K8sVersion: flag.K8sVersionFlag.Clone(),
},
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: scanFlags,
@@ -669,7 +684,7 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := configFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
options, err := configFlags.ToOptions(args, globalFlags)
options, err := configFlags.ToOptions(args)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
@@ -796,6 +811,7 @@ func NewPluginCommand() *cobra.Command {
func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
moduleFlags := &flag.Flags{
GlobalFlagGroup: globalFlags,
ModuleFlagGroup: flag.NewModuleFlagGroup(),
}
@@ -827,7 +843,7 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
repo := args[0]
opts, err := moduleFlags.ToOptions(args, globalFlags)
opts, err := moduleFlags.ToOptions(args)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
@@ -851,7 +867,7 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
repo := args[0]
opts, err := moduleFlags.ToOptions(args, globalFlags)
opts, err := moduleFlags.ToOptions(args)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
@@ -866,7 +882,7 @@ func NewModuleCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
scanFlags := flag.NewScanFlagGroup()
scanners := flag.ScannersFlag
scanners := flag.ScannersFlag.Clone()
// overwrite the default scanners
scanners.Values = xstrings.ToStringSlice(types.Scanners{
types.VulnerabilityScanner,
@@ -875,36 +891,37 @@ func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
types.RBACScanner,
})
scanners.Default = scanners.Values
scanFlags.Scanners = &scanners
scanFlags.Scanners = scanners
scanFlags.IncludeDevDeps = nil // disable '--include-dev-deps'
// required only SourceFlag
imageFlags := &flag.ImageFlagGroup{ImageSources: &flag.SourceFlag}
imageFlags := &flag.ImageFlagGroup{ImageSources: flag.SourceFlag.Clone()}
reportFlagGroup := flag.NewReportFlagGroup()
compliance := flag.ComplianceFlag
compliance := flag.ComplianceFlag.Clone()
compliance.Values = []string{
types.ComplianceK8sNsa,
types.ComplianceK8sCIS,
types.ComplianceK8sPSSBaseline,
types.ComplianceK8sPSSRestricted,
}
reportFlagGroup.Compliance = &compliance // override usage as the accepted values differ for each subcommand.
reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol'
reportFlagGroup.Compliance = compliance // override usage as the accepted values differ for each subcommand.
reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol'
formatFlag := flag.FormatFlag
formatFlag := flag.FormatFlag.Clone()
formatFlag.Values = xstrings.ToStringSlice([]types.Format{
types.FormatTable,
types.FormatJSON,
types.FormatCycloneDX,
})
reportFlagGroup.Format = &formatFlag
reportFlagGroup.Format = formatFlag
misconfFlagGroup := flag.NewMisconfFlagGroup()
misconfFlagGroup.CloudformationParamVars = nil // disable '--cf-params'
misconfFlagGroup.TerraformTFVars = nil // disable '--tf-vars'
k8sFlags := &flag.Flags{
GlobalFlagGroup: globalFlags,
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
ImageFlagGroup: imageFlags,
@@ -945,7 +962,7 @@ func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := k8sFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
opts, err := k8sFlags.ToOptions(args, globalFlags)
opts, err := k8sFlags.ToOptions(args)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
@@ -973,6 +990,7 @@ func NewAWSCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
reportFlagGroup.ExitOnEOL = nil // disable '--exit-on-eol'
awsFlags := &flag.Flags{
GlobalFlagGroup: globalFlags,
AWSFlagGroup: flag.NewAWSFlagGroup(),
CloudFlagGroup: flag.NewCloudFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
@@ -1014,7 +1032,7 @@ The following services are supported:
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
opts, err := awsFlags.ToOptions(args, globalFlags)
opts, err := awsFlags.ToOptions(args)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
@@ -1036,6 +1054,7 @@ The following services are supported:
func NewVMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
vmFlags := &flag.Flags{
GlobalFlagGroup: globalFlags,
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
@@ -1046,10 +1065,9 @@ func NewVMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
SecretFlagGroup: flag.NewSecretFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
AWSFlagGroup: &flag.AWSFlagGroup{
Region: &flag.Flag{
Region: &flag.Flag[string]{
Name: "aws-region",
ConfigName: "aws.region",
Default: "",
Usage: "AWS region to scan",
},
},
@@ -1080,7 +1098,7 @@ func NewVMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := vmFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
options, err := vmFlags.ToOptions(args, globalFlags)
options, err := vmFlags.ToOptions(args)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
@@ -1111,6 +1129,7 @@ func NewSBOMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
scanFlagGroup.Parallel = nil // disable '--parallel'
sbomFlags := &flag.Flags{
GlobalFlagGroup: globalFlags,
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
@@ -1140,7 +1159,7 @@ func NewSBOMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
if err := sbomFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
options, err := sbomFlags.ToOptions(args, globalFlags)
options, err := sbomFlags.ToOptions(args)
if err != nil {
return xerrors.Errorf("flag error: %w", err)
}
@@ -1168,7 +1187,10 @@ func NewVersionCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
GroupID: groupUtility,
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
options := globalFlags.ToOptions()
options, err := globalFlags.ToOptions()
if err != nil {
return err
}
return showVersion(options.CacheDir, versionFormat, cmd.OutOrStdout())
},
SilenceErrors: true,

View File

@@ -250,7 +250,7 @@ func TestFlags(t *testing.T) {
"--format",
"foo",
},
wantErr: `invalid argument "foo" for "-f, --format" flag`,
wantErr: `invalid argument "foo" for "--format" flag`,
},
}
@@ -262,16 +262,21 @@ func TestFlags(t *testing.T) {
rootCmd.SetOut(io.Discard)
flags := &flag.Flags{
GlobalFlagGroup: globalFlags,
ReportFlagGroup: flag.NewReportFlagGroup(),
}
cmd := &cobra.Command{
Use: "test",
RunE: func(cmd *cobra.Command, args []string) error {
// Bind
require.NoError(t, flags.Bind(cmd))
if err := flags.Bind(cmd); err != nil {
return err
}
options, err := flags.ToOptions(args, globalFlags)
require.NoError(t, err)
options, err := flags.ToOptions(args)
if err != nil {
return err
}
assert.Equal(t, tt.want.format, options.Format)
assert.Equal(t, tt.want.severities, options.Severities)

View File

@@ -1,51 +1,45 @@
package flag
var (
awsRegionFlag = Flag{
awsRegionFlag = Flag[string]{
Name: "region",
ConfigName: "cloud.aws.region",
Default: "",
Usage: "AWS Region to scan",
}
awsEndpointFlag = Flag{
awsEndpointFlag = Flag[string]{
Name: "endpoint",
ConfigName: "cloud.aws.endpoint",
Default: "",
Usage: "AWS Endpoint override",
}
awsServiceFlag = Flag{
awsServiceFlag = Flag[[]string]{
Name: "service",
ConfigName: "cloud.aws.service",
Default: []string{},
Usage: "Only scan AWS Service(s) specified with this flag. Can specify multiple services using --service A --service B etc.",
}
awsSkipServicesFlag = Flag{
awsSkipServicesFlag = Flag[[]string]{
Name: "skip-service",
ConfigName: "cloud.aws.skip-service",
Default: []string{},
Usage: "Skip selected AWS Service(s) specified with this flag. Can specify multiple services using --skip-service A --skip-service B etc.",
}
awsAccountFlag = Flag{
awsAccountFlag = Flag[string]{
Name: "account",
ConfigName: "cloud.aws.account",
Default: "",
Usage: "The AWS account to scan. It's useful to specify this when reviewing cached results for multiple accounts.",
}
awsARNFlag = Flag{
awsARNFlag = Flag[string]{
Name: "arn",
ConfigName: "cloud.aws.arn",
Default: "",
Usage: "The AWS ARN to show results for. Useful to filter results once a scan is cached.",
}
)
type AWSFlagGroup struct {
Region *Flag
Endpoint *Flag
Services *Flag
SkipServices *Flag
Account *Flag
ARN *Flag
Region *Flag[string]
Endpoint *Flag[string]
Services *Flag[[]string]
SkipServices *Flag[[]string]
Account *Flag[string]
ARN *Flag[string]
}
type AWSOptions struct {
@@ -59,12 +53,12 @@ type AWSOptions struct {
func NewAWSFlagGroup() *AWSFlagGroup {
return &AWSFlagGroup{
Region: &awsRegionFlag,
Endpoint: &awsEndpointFlag,
Services: &awsServiceFlag,
SkipServices: &awsSkipServicesFlag,
Account: &awsAccountFlag,
ARN: &awsARNFlag,
Region: awsRegionFlag.Clone(),
Endpoint: awsEndpointFlag.Clone(),
Services: awsServiceFlag.Clone(),
SkipServices: awsSkipServicesFlag.Clone(),
Account: awsAccountFlag.Clone(),
ARN: awsARNFlag.Clone(),
}
}
@@ -72,17 +66,27 @@ func (f *AWSFlagGroup) Name() string {
return "AWS"
}
func (f *AWSFlagGroup) Flags() []*Flag {
return []*Flag{f.Region, f.Endpoint, f.Services, f.SkipServices, f.Account, f.ARN}
}
func (f *AWSFlagGroup) ToOptions() AWSOptions {
return AWSOptions{
Region: getString(f.Region),
Endpoint: getString(f.Endpoint),
Services: getStringSlice(f.Services),
SkipServices: getStringSlice(f.SkipServices),
Account: getString(f.Account),
ARN: getString(f.ARN),
func (f *AWSFlagGroup) Flags() []Flagger {
return []Flagger{
f.Region,
f.Endpoint,
f.Services,
f.SkipServices,
f.Account,
f.ARN,
}
}
func (f *AWSFlagGroup) ToOptions() (AWSOptions, error) {
if err := parseFlags(f); err != nil {
return AWSOptions{}, err
}
return AWSOptions{
Region: f.Region.Value(),
Endpoint: f.Endpoint.Value(),
Services: f.Services.Value(),
SkipServices: f.SkipServices.Value(),
Account: f.Account.Value(),
ARN: f.ARN.Value(),
}, nil
}

View File

@@ -19,60 +19,54 @@ import (
// cert: cert.pem
// key: key.pem
var (
ClearCacheFlag = Flag{
ClearCacheFlag = Flag[bool]{
Name: "clear-cache",
ConfigName: "cache.clear",
Default: false,
Usage: "clear image caches without scanning",
}
CacheBackendFlag = Flag{
CacheBackendFlag = Flag[string]{
Name: "cache-backend",
ConfigName: "cache.backend",
Default: "fs",
Usage: "cache backend (e.g. redis://localhost:6379)",
}
CacheTTLFlag = Flag{
CacheTTLFlag = Flag[time.Duration]{
Name: "cache-ttl",
ConfigName: "cache.ttl",
Default: time.Duration(0),
Usage: "cache TTL when using redis as cache backend",
}
RedisTLSFlag = Flag{
RedisTLSFlag = Flag[bool]{
Name: "redis-tls",
ConfigName: "cache.redis.tls",
Default: false,
Usage: "enable redis TLS with public certificates, if using redis as cache backend",
}
RedisCACertFlag = Flag{
RedisCACertFlag = Flag[string]{
Name: "redis-ca",
ConfigName: "cache.redis.ca",
Default: "",
Usage: "redis ca file location, if using redis as cache backend",
}
RedisCertFlag = Flag{
RedisCertFlag = Flag[string]{
Name: "redis-cert",
ConfigName: "cache.redis.cert",
Default: "",
Usage: "redis certificate file location, if using redis as cache backend",
}
RedisKeyFlag = Flag{
RedisKeyFlag = Flag[string]{
Name: "redis-key",
ConfigName: "cache.redis.key",
Default: "",
Usage: "redis key file location, if using redis as cache backend",
}
)
// CacheFlagGroup composes common printer flag structs used for commands requiring cache logic.
type CacheFlagGroup struct {
ClearCache *Flag
CacheBackend *Flag
CacheTTL *Flag
ClearCache *Flag[bool]
CacheBackend *Flag[string]
CacheTTL *Flag[time.Duration]
RedisTLS *Flag
RedisCACert *Flag
RedisCert *Flag
RedisKey *Flag
RedisTLS *Flag[bool]
RedisCACert *Flag[string]
RedisCert *Flag[string]
RedisKey *Flag[string]
}
type CacheOptions struct {
@@ -93,13 +87,13 @@ type RedisOptions struct {
// NewCacheFlagGroup returns a default CacheFlagGroup
func NewCacheFlagGroup() *CacheFlagGroup {
return &CacheFlagGroup{
ClearCache: &ClearCacheFlag,
CacheBackend: &CacheBackendFlag,
CacheTTL: &CacheTTLFlag,
RedisTLS: &RedisTLSFlag,
RedisCACert: &RedisCACertFlag,
RedisCert: &RedisCertFlag,
RedisKey: &RedisKeyFlag,
ClearCache: ClearCacheFlag.Clone(),
CacheBackend: CacheBackendFlag.Clone(),
CacheTTL: CacheTTLFlag.Clone(),
RedisTLS: RedisTLSFlag.Clone(),
RedisCACert: RedisCACertFlag.Clone(),
RedisCert: RedisCertFlag.Clone(),
RedisKey: RedisKeyFlag.Clone(),
}
}
@@ -107,16 +101,28 @@ func (fg *CacheFlagGroup) Name() string {
return "Cache"
}
func (fg *CacheFlagGroup) Flags() []*Flag {
return []*Flag{fg.ClearCache, fg.CacheBackend, fg.CacheTTL, fg.RedisTLS, fg.RedisCACert, fg.RedisCert, fg.RedisKey}
func (fg *CacheFlagGroup) Flags() []Flagger {
return []Flagger{
fg.ClearCache,
fg.CacheBackend,
fg.CacheTTL,
fg.RedisTLS,
fg.RedisCACert,
fg.RedisCert,
fg.RedisKey,
}
}
func (fg *CacheFlagGroup) ToOptions() (CacheOptions, error) {
cacheBackend := getString(fg.CacheBackend)
if err := parseFlags(fg); err != nil {
return CacheOptions{}, err
}
cacheBackend := fg.CacheBackend.Value()
redisOptions := RedisOptions{
RedisCACert: getString(fg.RedisCACert),
RedisCert: getString(fg.RedisCert),
RedisKey: getString(fg.RedisKey),
RedisCACert: fg.RedisCACert.Value(),
RedisCert: fg.RedisCert.Value(),
RedisKey: fg.RedisKey.Value(),
}
// "redis://" or "fs" are allowed for now
@@ -133,10 +139,10 @@ func (fg *CacheFlagGroup) ToOptions() (CacheOptions, error) {
}
return CacheOptions{
ClearCache: getBool(fg.ClearCache),
ClearCache: fg.ClearCache.Value(),
CacheBackend: cacheBackend,
CacheTTL: getDuration(fg.CacheTTL),
RedisTLS: getBool(fg.RedisTLS),
CacheTTL: fg.CacheTTL.Value(),
RedisTLS: fg.RedisTLS.Value(),
RedisOptions: redisOptions,
}, nil
}

View File

@@ -108,13 +108,13 @@ func TestCacheFlagGroup_ToOptions(t *testing.T) {
viper.Set(flag.RedisKeyFlag.ConfigName, tt.fields.RedisKey)
f := &flag.CacheFlagGroup{
ClearCache: &flag.ClearCacheFlag,
CacheBackend: &flag.CacheBackendFlag,
CacheTTL: &flag.CacheTTLFlag,
RedisTLS: &flag.RedisTLSFlag,
RedisCACert: &flag.RedisCACertFlag,
RedisCert: &flag.RedisCertFlag,
RedisKey: &flag.RedisKeyFlag,
ClearCache: flag.ClearCacheFlag.Clone(),
CacheBackend: flag.CacheBackendFlag.Clone(),
CacheTTL: flag.CacheTTLFlag.Clone(),
RedisTLS: flag.RedisTLSFlag.Clone(),
RedisCACert: flag.RedisCACertFlag.Clone(),
RedisCert: flag.RedisCertFlag.Clone(),
RedisKey: flag.RedisKeyFlag.Clone(),
}
got, err := f.ToOptions()

View File

@@ -3,13 +3,12 @@ package flag
import "time"
var (
cloudUpdateCacheFlag = Flag{
cloudUpdateCacheFlag = Flag[bool]{
Name: "update-cache",
ConfigName: "cloud.update-cache",
Default: false,
Usage: "Update the cache for the applicable cloud provider instead of using cached results.",
}
cloudMaxCacheAgeFlag = Flag{
cloudMaxCacheAgeFlag = Flag[time.Duration]{
Name: "max-cache-age",
ConfigName: "cloud.max-cache-age",
Default: time.Hour * 24,
@@ -18,8 +17,8 @@ var (
)
type CloudFlagGroup struct {
UpdateCache *Flag
MaxCacheAge *Flag
UpdateCache *Flag[bool]
MaxCacheAge *Flag[time.Duration]
}
type CloudOptions struct {
@@ -29,8 +28,8 @@ type CloudOptions struct {
func NewCloudFlagGroup() *CloudFlagGroup {
return &CloudFlagGroup{
UpdateCache: &cloudUpdateCacheFlag,
MaxCacheAge: &cloudMaxCacheAgeFlag,
UpdateCache: cloudUpdateCacheFlag.Clone(),
MaxCacheAge: cloudMaxCacheAgeFlag.Clone(),
}
}
@@ -38,13 +37,19 @@ func (f *CloudFlagGroup) Name() string {
return "Cloud"
}
func (f *CloudFlagGroup) Flags() []*Flag {
return []*Flag{f.UpdateCache, f.MaxCacheAge}
}
func (f *CloudFlagGroup) ToOptions() CloudOptions {
return CloudOptions{
UpdateCache: getBool(f.UpdateCache),
MaxCacheAge: getDuration(f.MaxCacheAge),
func (f *CloudFlagGroup) Flags() []Flagger {
return []Flagger{
f.UpdateCache,
f.MaxCacheAge,
}
}
func (f *CloudFlagGroup) ToOptions() (CloudOptions, error) {
if err := parseFlags(f); err != nil {
return CloudOptions{}, err
}
return CloudOptions{
UpdateCache: f.UpdateCache.Value(),
MaxCacheAge: f.MaxCacheAge.Value(),
}, nil
}

View File

@@ -10,22 +10,19 @@ const defaultDBRepository = "ghcr.io/aquasecurity/trivy-db"
const defaultJavaDBRepository = "ghcr.io/aquasecurity/trivy-java-db"
var (
ResetFlag = Flag{
ResetFlag = Flag[bool]{
Name: "reset",
ConfigName: "reset",
Default: false,
Usage: "remove all caches and database",
}
DownloadDBOnlyFlag = Flag{
DownloadDBOnlyFlag = Flag[bool]{
Name: "download-db-only",
ConfigName: "db.download-only",
Default: false,
Usage: "download/update vulnerability database but don't run a scan",
}
SkipDBUpdateFlag = Flag{
SkipDBUpdateFlag = Flag[bool]{
Name: "skip-db-update",
ConfigName: "db.skip-update",
Default: false,
Usage: "skip updating vulnerability database",
Aliases: []Alias{
{
@@ -34,40 +31,36 @@ var (
},
},
}
DownloadJavaDBOnlyFlag = Flag{
DownloadJavaDBOnlyFlag = Flag[bool]{
Name: "download-java-db-only",
ConfigName: "db.download-java-only",
Default: false,
Usage: "download/update Java index database but don't run a scan",
}
SkipJavaDBUpdateFlag = Flag{
SkipJavaDBUpdateFlag = Flag[bool]{
Name: "skip-java-db-update",
ConfigName: "db.java-skip-update",
Default: false,
Usage: "skip updating Java index database",
}
NoProgressFlag = Flag{
NoProgressFlag = Flag[bool]{
Name: "no-progress",
ConfigName: "db.no-progress",
Default: false,
Usage: "suppress progress bar",
}
DBRepositoryFlag = Flag{
DBRepositoryFlag = Flag[string]{
Name: "db-repository",
ConfigName: "db.repository",
Default: defaultDBRepository,
Usage: "OCI repository to retrieve trivy-db from",
}
JavaDBRepositoryFlag = Flag{
JavaDBRepositoryFlag = Flag[string]{
Name: "java-db-repository",
ConfigName: "db.java-repository",
Default: defaultJavaDBRepository,
Usage: "OCI repository to retrieve trivy-java-db from",
}
LightFlag = Flag{
LightFlag = Flag[bool]{
Name: "light",
ConfigName: "db.light",
Default: false,
Usage: "deprecated",
Deprecated: true,
}
@@ -75,15 +68,15 @@ var (
// DBFlagGroup composes common printer flag structs used for commands requiring DB logic.
type DBFlagGroup struct {
Reset *Flag
DownloadDBOnly *Flag
SkipDBUpdate *Flag
DownloadJavaDBOnly *Flag
SkipJavaDBUpdate *Flag
NoProgress *Flag
DBRepository *Flag
JavaDBRepository *Flag
Light *Flag // deprecated
Reset *Flag[bool]
DownloadDBOnly *Flag[bool]
SkipDBUpdate *Flag[bool]
DownloadJavaDBOnly *Flag[bool]
SkipJavaDBUpdate *Flag[bool]
NoProgress *Flag[bool]
DBRepository *Flag[string]
JavaDBRepository *Flag[string]
Light *Flag[bool] // deprecated
}
type DBOptions struct {
@@ -101,15 +94,15 @@ type DBOptions struct {
// NewDBFlagGroup returns a default DBFlagGroup
func NewDBFlagGroup() *DBFlagGroup {
return &DBFlagGroup{
Reset: &ResetFlag,
DownloadDBOnly: &DownloadDBOnlyFlag,
SkipDBUpdate: &SkipDBUpdateFlag,
DownloadJavaDBOnly: &DownloadJavaDBOnlyFlag,
SkipJavaDBUpdate: &SkipJavaDBUpdateFlag,
Light: &LightFlag,
NoProgress: &NoProgressFlag,
DBRepository: &DBRepositoryFlag,
JavaDBRepository: &JavaDBRepositoryFlag,
Reset: ResetFlag.Clone(),
DownloadDBOnly: DownloadDBOnlyFlag.Clone(),
SkipDBUpdate: SkipDBUpdateFlag.Clone(),
DownloadJavaDBOnly: DownloadJavaDBOnlyFlag.Clone(),
SkipJavaDBUpdate: SkipJavaDBUpdateFlag.Clone(),
Light: LightFlag.Clone(),
NoProgress: NoProgressFlag.Clone(),
DBRepository: DBRepositoryFlag.Clone(),
JavaDBRepository: JavaDBRepositoryFlag.Clone(),
}
}
@@ -117,8 +110,8 @@ func (f *DBFlagGroup) Name() string {
return "DB"
}
func (f *DBFlagGroup) Flags() []*Flag {
return []*Flag{
func (f *DBFlagGroup) Flags() []Flagger {
return []Flagger{
f.Reset,
f.DownloadDBOnly,
f.SkipDBUpdate,
@@ -132,11 +125,15 @@ func (f *DBFlagGroup) Flags() []*Flag {
}
func (f *DBFlagGroup) ToOptions() (DBOptions, error) {
skipDBUpdate := getBool(f.SkipDBUpdate)
skipJavaDBUpdate := getBool(f.SkipJavaDBUpdate)
downloadDBOnly := getBool(f.DownloadDBOnly)
downloadJavaDBOnly := getBool(f.DownloadJavaDBOnly)
light := getBool(f.Light)
if err := parseFlags(f); err != nil {
return DBOptions{}, err
}
skipDBUpdate := f.SkipDBUpdate.Value()
skipJavaDBUpdate := f.SkipJavaDBUpdate.Value()
downloadDBOnly := f.DownloadDBOnly.Value()
downloadJavaDBOnly := f.DownloadJavaDBOnly.Value()
light := f.Light.Value()
if downloadDBOnly && skipDBUpdate {
return DBOptions{}, xerrors.New("--skip-db-update and --download-db-only options can not be specified both")
@@ -149,14 +146,14 @@ func (f *DBFlagGroup) ToOptions() (DBOptions, error) {
}
return DBOptions{
Reset: getBool(f.Reset),
Reset: f.Reset.Value(),
DownloadDBOnly: downloadDBOnly,
SkipDBUpdate: skipDBUpdate,
DownloadJavaDBOnly: downloadJavaDBOnly,
SkipJavaDBUpdate: skipJavaDBUpdate,
Light: light,
NoProgress: getBool(f.NoProgress),
DBRepository: getString(f.DBRepository),
JavaDBRepository: getString(f.JavaDBRepository),
NoProgress: f.NoProgress.Value(),
DBRepository: f.DBRepository.Value(),
JavaDBRepository: f.JavaDBRepository.Value(),
}, nil
}

View File

@@ -74,9 +74,9 @@ func TestDBFlagGroup_ToOptions(t *testing.T) {
// Assert options
f := &flag.DBFlagGroup{
DownloadDBOnly: &flag.DownloadDBOnlyFlag,
SkipDBUpdate: &flag.SkipDBUpdateFlag,
Light: &flag.LightFlag,
DownloadDBOnly: flag.DownloadDBOnlyFlag.Clone(),
SkipDBUpdate: flag.SkipDBUpdateFlag.Clone(),
Light: flag.LightFlag.Clone(),
}
got, err := f.ToOptions()
tt.assertion(t, err)

View File

@@ -10,7 +10,7 @@ import (
)
var (
ConfigFileFlag = Flag{
ConfigFileFlag = Flag[string]{
Name: "config",
ConfigName: "config",
Shorthand: "c",
@@ -18,55 +18,50 @@ var (
Usage: "config path",
Persistent: true,
}
ShowVersionFlag = Flag{
ShowVersionFlag = Flag[bool]{
Name: "version",
ConfigName: "version",
Shorthand: "v",
Default: false,
Usage: "show version",
Persistent: true,
}
QuietFlag = Flag{
QuietFlag = Flag[bool]{
Name: "quiet",
ConfigName: "quiet",
Shorthand: "q",
Default: false,
Usage: "suppress progress bar and log output",
Persistent: true,
}
DebugFlag = Flag{
DebugFlag = Flag[bool]{
Name: "debug",
ConfigName: "debug",
Shorthand: "d",
Default: false,
Usage: "debug mode",
Persistent: true,
}
InsecureFlag = Flag{
InsecureFlag = Flag[bool]{
Name: "insecure",
ConfigName: "insecure",
Default: false,
Usage: "allow insecure server connections",
Persistent: true,
}
TimeoutFlag = Flag{
TimeoutFlag = Flag[time.Duration]{
Name: "timeout",
ConfigName: "timeout",
Default: time.Second * 300, // 5 mins
Usage: "timeout",
Persistent: true,
}
CacheDirFlag = Flag{
CacheDirFlag = Flag[string]{
Name: "cache-dir",
ConfigName: "cache.dir",
Default: fsutils.CacheDir(),
Usage: "cache directory",
Persistent: true,
}
GenerateDefaultConfigFlag = Flag{
GenerateDefaultConfigFlag = Flag[bool]{
Name: "generate-default-config",
ConfigName: "generate-default-config",
Default: false,
Usage: "write the default config to trivy-default.yaml",
Persistent: true,
}
@@ -74,14 +69,14 @@ var (
// GlobalFlagGroup composes global flags
type GlobalFlagGroup struct {
ConfigFile *Flag
ShowVersion *Flag // spf13/cobra can't override the logic of version printing like VersionPrinter in urfave/cli. -v needs to be defined ourselves.
Quiet *Flag
Debug *Flag
Insecure *Flag
Timeout *Flag
CacheDir *Flag
GenerateDefaultConfig *Flag
ConfigFile *Flag[string]
ShowVersion *Flag[bool] // spf13/cobra can't override the logic of version printing like VersionPrinter in urfave/cli. -v needs to be defined ourselves.
Quiet *Flag[bool]
Debug *Flag[bool]
Insecure *Flag[bool]
Timeout *Flag[time.Duration]
CacheDir *Flag[string]
GenerateDefaultConfig *Flag[bool]
}
// GlobalOptions defines flags and other configuration parameters for all the subcommands
@@ -98,19 +93,23 @@ type GlobalOptions struct {
func NewGlobalFlagGroup() *GlobalFlagGroup {
return &GlobalFlagGroup{
ConfigFile: &ConfigFileFlag,
ShowVersion: &ShowVersionFlag,
Quiet: &QuietFlag,
Debug: &DebugFlag,
Insecure: &InsecureFlag,
Timeout: &TimeoutFlag,
CacheDir: &CacheDirFlag,
GenerateDefaultConfig: &GenerateDefaultConfigFlag,
ConfigFile: ConfigFileFlag.Clone(),
ShowVersion: ShowVersionFlag.Clone(),
Quiet: QuietFlag.Clone(),
Debug: DebugFlag.Clone(),
Insecure: InsecureFlag.Clone(),
Timeout: TimeoutFlag.Clone(),
CacheDir: CacheDirFlag.Clone(),
GenerateDefaultConfig: GenerateDefaultConfigFlag.Clone(),
}
}
func (f *GlobalFlagGroup) flags() []*Flag {
return []*Flag{
func (f *GlobalFlagGroup) Name() string {
return "global"
}
func (f *GlobalFlagGroup) Flags() []Flagger {
return []Flagger{
f.ConfigFile,
f.ShowVersion,
f.Quiet,
@@ -123,32 +122,36 @@ func (f *GlobalFlagGroup) flags() []*Flag {
}
func (f *GlobalFlagGroup) AddFlags(cmd *cobra.Command) {
for _, flag := range f.flags() {
addFlag(cmd, flag)
for _, flag := range f.Flags() {
flag.Add(cmd)
}
}
func (f *GlobalFlagGroup) Bind(cmd *cobra.Command) error {
for _, flag := range f.flags() {
if err := bind(cmd, flag); err != nil {
for _, flag := range f.Flags() {
if err := flag.Bind(cmd); err != nil {
return err
}
}
return nil
}
func (f *GlobalFlagGroup) ToOptions() GlobalOptions {
func (f *GlobalFlagGroup) ToOptions() (GlobalOptions, error) {
if err := parseFlags(f); err != nil {
return GlobalOptions{}, err
}
// Keep TRIVY_NON_SSL for backward compatibility
insecure := getBool(f.Insecure) || os.Getenv("TRIVY_NON_SSL") != ""
insecure := f.Insecure.Value() || os.Getenv("TRIVY_NON_SSL") != ""
return GlobalOptions{
ConfigFile: getString(f.ConfigFile),
ShowVersion: getBool(f.ShowVersion),
Quiet: getBool(f.Quiet),
Debug: getBool(f.Debug),
ConfigFile: f.ConfigFile.Value(),
ShowVersion: f.ShowVersion.Value(),
Quiet: f.Quiet.Value(),
Debug: f.Debug.Value(),
Insecure: insecure,
Timeout: getDuration(f.Timeout),
CacheDir: getString(f.CacheDir),
GenerateDefaultConfig: getBool(f.GenerateDefaultConfig),
}
Timeout: f.Timeout.Value(),
CacheDir: f.CacheDir.Value(),
GenerateDefaultConfig: f.GenerateDefaultConfig.Value(),
}, nil
}

View File

@@ -15,41 +15,37 @@ import (
// input: "/path/to/alpine"
var (
ImageConfigScannersFlag = Flag{
ImageConfigScannersFlag = Flag[[]string]{
Name: "image-config-scanners",
ConfigName: "image.image-config-scanners",
Default: []string{},
Values: xstrings.ToStringSlice(types.Scanners{
types.MisconfigScanner,
types.SecretScanner,
}),
Usage: "comma-separated list of what security issues to detect on container image configurations",
}
ScanRemovedPkgsFlag = Flag{
ScanRemovedPkgsFlag = Flag[bool]{
Name: "removed-pkgs",
ConfigName: "image.removed-pkgs",
Default: false,
Usage: "detect vulnerabilities of removed packages (only for Alpine)",
}
InputFlag = Flag{
InputFlag = Flag[string]{
Name: "input",
ConfigName: "image.input",
Default: "",
Usage: "input file path instead of image name",
}
PlatformFlag = Flag{
PlatformFlag = Flag[string]{
Name: "platform",
ConfigName: "image.platform",
Default: "",
Usage: "set platform in the form os/arch if image is multi-platform capable",
}
DockerHostFlag = Flag{
DockerHostFlag = Flag[string]{
Name: "docker-host",
ConfigName: "image.docker.host",
Default: "",
Usage: "unix domain socket path to use for docker scanning",
}
SourceFlag = Flag{
SourceFlag = Flag[[]string]{
Name: "image-src",
ConfigName: "image.source",
Default: xstrings.ToStringSlice(ftypes.AllImageSources),
@@ -59,12 +55,12 @@ var (
)
type ImageFlagGroup struct {
Input *Flag // local image archive
ImageConfigScanners *Flag
ScanRemovedPkgs *Flag
Platform *Flag
DockerHost *Flag
ImageSources *Flag
Input *Flag[string] // local image archive
ImageConfigScanners *Flag[[]string]
ScanRemovedPkgs *Flag[bool]
Platform *Flag[string]
DockerHost *Flag[string]
ImageSources *Flag[[]string]
}
type ImageOptions struct {
@@ -78,12 +74,12 @@ type ImageOptions struct {
func NewImageFlagGroup() *ImageFlagGroup {
return &ImageFlagGroup{
Input: &InputFlag,
ImageConfigScanners: &ImageConfigScannersFlag,
ScanRemovedPkgs: &ScanRemovedPkgsFlag,
Platform: &PlatformFlag,
DockerHost: &DockerHostFlag,
ImageSources: &SourceFlag,
Input: InputFlag.Clone(),
ImageConfigScanners: ImageConfigScannersFlag.Clone(),
ScanRemovedPkgs: ScanRemovedPkgsFlag.Clone(),
Platform: PlatformFlag.Clone(),
DockerHost: DockerHostFlag.Clone(),
ImageSources: SourceFlag.Clone(),
}
}
@@ -91,8 +87,8 @@ func (f *ImageFlagGroup) Name() string {
return "Image"
}
func (f *ImageFlagGroup) Flags() []*Flag {
return []*Flag{
func (f *ImageFlagGroup) Flags() []Flagger {
return []Flagger{
f.Input,
f.ImageConfigScanners,
f.ScanRemovedPkgs,
@@ -103,8 +99,12 @@ func (f *ImageFlagGroup) Flags() []*Flag {
}
func (f *ImageFlagGroup) ToOptions() (ImageOptions, error) {
if err := parseFlags(f); err != nil {
return ImageOptions{}, err
}
var platform ftypes.Platform
if p := getString(f.Platform); p != "" {
if p := f.Platform.Value(); p != "" {
pl, err := v1.ParsePlatform(p)
if err != nil {
return ImageOptions{}, xerrors.Errorf("unable to parse platform: %w", err)
@@ -116,11 +116,11 @@ func (f *ImageFlagGroup) ToOptions() (ImageOptions, error) {
}
return ImageOptions{
Input: getString(f.Input),
ImageConfigScanners: getUnderlyingStringSlice[types.Scanner](f.ImageConfigScanners),
ScanRemovedPkgs: getBool(f.ScanRemovedPkgs),
Input: f.Input.Value(),
ImageConfigScanners: xstrings.ToTSlice[types.Scanner](f.ImageConfigScanners.Value()),
ScanRemovedPkgs: f.ScanRemovedPkgs.Value(),
Platform: platform,
DockerHost: getString(f.DockerHost),
ImageSources: getUnderlyingStringSlice[ftypes.ImageSource](f.ImageSources),
DockerHost: f.DockerHost.Value(),
ImageSources: xstrings.ToTSlice[ftypes.ImageSource](f.ImageSources.Value()),
}, nil
}

View File

@@ -10,29 +10,26 @@ import (
)
var (
ClusterContextFlag = Flag{
ClusterContextFlag = Flag[string]{
Name: "context",
ConfigName: "kubernetes.context",
Default: "",
Usage: "specify a context to scan",
Aliases: []Alias{
{Name: "ctx"},
},
}
K8sNamespaceFlag = Flag{
K8sNamespaceFlag = Flag[string]{
Name: "namespace",
ConfigName: "kubernetes.namespace",
Shorthand: "n",
Default: "",
Usage: "specify a namespace to scan",
}
KubeConfigFlag = Flag{
KubeConfigFlag = Flag[string]{
Name: "kubeconfig",
ConfigName: "kubernetes.kubeconfig",
Default: "",
Usage: "specify the kubeconfig file path to use",
}
ComponentsFlag = Flag{
ComponentsFlag = Flag[[]string]{
Name: "components",
ConfigName: "kubernetes.components",
Default: []string{
@@ -45,56 +42,51 @@ var (
},
Usage: "specify which components to scan",
}
K8sVersionFlag = Flag{
K8sVersionFlag = Flag[string]{
Name: "k8s-version",
ConfigName: "kubernetes.k8s.version",
Default: "",
Usage: "specify k8s version to validate outdated api by it (example: 1.21.0)",
}
TolerationsFlag = Flag{
TolerationsFlag = Flag[[]string]{
Name: "tolerations",
ConfigName: "kubernetes.tolerations",
Default: []string{},
Usage: "specify node-collector job tolerations (example: key1=value1:NoExecute,key2=value2:NoSchedule)",
}
AllNamespaces = Flag{
AllNamespaces = Flag[bool]{
Name: "all-namespaces",
ConfigName: "kubernetes.all.namespaces",
Shorthand: "A",
Default: false,
Usage: "fetch resources from all cluster namespaces",
}
NodeCollectorNamespace = Flag{
NodeCollectorNamespace = Flag[string]{
Name: "node-collector-namespace",
ConfigName: "node.collector.namespace",
Default: "trivy-temp",
Usage: "specify the namespace in which the node-collector job should be deployed",
}
ExcludeOwned = Flag{
ExcludeOwned = Flag[bool]{
Name: "exclude-owned",
ConfigName: "kubernetes.exclude.owned",
Default: false,
Usage: "exclude resources that have an owner reference",
}
ExcludeNodes = Flag{
ExcludeNodes = Flag[[]string]{
Name: "exclude-nodes",
ConfigName: "exclude.nodes",
Default: []string{},
ConfigName: "kubernetes.exclude.nodes",
Usage: "indicate the node labels that the node-collector job should exclude from scanning (example: kubernetes.io/arch:arm64,team:dev)",
}
NodeCollectorImageRef = Flag{
NodeCollectorImageRef = Flag[string]{
Name: "node-collector-imageref",
ConfigName: "node.collector.imageref",
ConfigName: "kubernetes.node.collector.imageref",
Default: "ghcr.io/aquasecurity/node-collector:0.0.9",
Usage: "indicate the image reference for the node-collector scan job",
}
QPS = Flag{
QPS = Flag[float64]{
Name: "qps",
ConfigName: "kubernetes.qps",
Default: 5.0,
Usage: "specify the maximum QPS to the master from this client",
}
Burst = Flag{
Burst = Flag[int]{
Name: "burst",
ConfigName: "kubernetes.burst",
Default: 10,
@@ -103,19 +95,19 @@ var (
)
type K8sFlagGroup struct {
ClusterContext *Flag
Namespace *Flag
KubeConfig *Flag
Components *Flag
K8sVersion *Flag
Tolerations *Flag
NodeCollectorImageRef *Flag
AllNamespaces *Flag
NodeCollectorNamespace *Flag
ExcludeOwned *Flag
ExcludeNodes *Flag
QPS *Flag
Burst *Flag
ClusterContext *Flag[string]
Namespace *Flag[string]
KubeConfig *Flag[string]
Components *Flag[[]string]
K8sVersion *Flag[string]
Tolerations *Flag[[]string]
NodeCollectorImageRef *Flag[string]
AllNamespaces *Flag[bool]
NodeCollectorNamespace *Flag[string]
ExcludeOwned *Flag[bool]
ExcludeNodes *Flag[[]string]
QPS *Flag[float64]
Burst *Flag[int]
}
type K8sOptions struct {
@@ -136,19 +128,19 @@ type K8sOptions struct {
func NewK8sFlagGroup() *K8sFlagGroup {
return &K8sFlagGroup{
ClusterContext: &ClusterContextFlag,
Namespace: &K8sNamespaceFlag,
KubeConfig: &KubeConfigFlag,
Components: &ComponentsFlag,
K8sVersion: &K8sVersionFlag,
Tolerations: &TolerationsFlag,
AllNamespaces: &AllNamespaces,
NodeCollectorNamespace: &NodeCollectorNamespace,
ExcludeOwned: &ExcludeOwned,
ExcludeNodes: &ExcludeNodes,
NodeCollectorImageRef: &NodeCollectorImageRef,
QPS: &QPS,
Burst: &Burst,
ClusterContext: ClusterContextFlag.Clone(),
Namespace: K8sNamespaceFlag.Clone(),
KubeConfig: KubeConfigFlag.Clone(),
Components: ComponentsFlag.Clone(),
K8sVersion: K8sVersionFlag.Clone(),
Tolerations: TolerationsFlag.Clone(),
AllNamespaces: AllNamespaces.Clone(),
NodeCollectorNamespace: NodeCollectorNamespace.Clone(),
ExcludeOwned: ExcludeOwned.Clone(),
ExcludeNodes: ExcludeNodes.Clone(),
NodeCollectorImageRef: NodeCollectorImageRef.Clone(),
QPS: QPS.Clone(),
Burst: Burst.Clone(),
}
}
@@ -156,8 +148,8 @@ func (f *K8sFlagGroup) Name() string {
return "Kubernetes"
}
func (f *K8sFlagGroup) Flags() []*Flag {
return []*Flag{
func (f *K8sFlagGroup) Flags() []Flagger {
return []Flagger{
f.ClusterContext,
f.Namespace,
f.KubeConfig,
@@ -175,13 +167,17 @@ func (f *K8sFlagGroup) Flags() []*Flag {
}
func (f *K8sFlagGroup) ToOptions() (K8sOptions, error) {
tolerations, err := optionToTolerations(getStringSlice(f.Tolerations))
if err := parseFlags(f); err != nil {
return K8sOptions{}, err
}
tolerations, err := optionToTolerations(f.Tolerations.Value())
if err != nil {
return K8sOptions{}, err
}
exludeNodeLabels := make(map[string]string)
exludeNodes := getStringSlice(f.ExcludeNodes)
exludeNodes := f.ExcludeNodes.Value()
for _, exludeNodeValue := range exludeNodes {
excludeNodeParts := strings.Split(exludeNodeValue, ":")
if len(excludeNodeParts) != 2 {
@@ -191,17 +187,19 @@ func (f *K8sFlagGroup) ToOptions() (K8sOptions, error) {
}
return K8sOptions{
ClusterContext: getString(f.ClusterContext),
Namespace: getString(f.Namespace),
KubeConfig: getString(f.KubeConfig),
Components: getStringSlice(f.Components),
K8sVersion: getString(f.K8sVersion),
ClusterContext: f.ClusterContext.Value(),
Namespace: f.Namespace.Value(),
KubeConfig: f.KubeConfig.Value(),
Components: f.Components.Value(),
K8sVersion: f.K8sVersion.Value(),
Tolerations: tolerations,
AllNamespaces: getBool(f.AllNamespaces),
NodeCollectorNamespace: getString(f.NodeCollectorNamespace),
ExcludeOwned: getBool(f.ExcludeOwned),
AllNamespaces: f.AllNamespaces.Value(),
NodeCollectorNamespace: f.NodeCollectorNamespace.Value(),
ExcludeOwned: f.ExcludeOwned.Value(),
ExcludeNodes: exludeNodeLabels,
NodeCollectorImageRef: getString(f.NodeCollectorImageRef),
NodeCollectorImageRef: f.NodeCollectorImageRef.Value(),
QPS: float32(f.QPS.Value()),
Burst: f.Burst.Value(),
}, nil
}

View File

@@ -6,19 +6,17 @@ import (
)
var (
LicenseFull = Flag{
LicenseFull = Flag[bool]{
Name: "license-full",
ConfigName: "license.full",
Default: false,
Usage: "eagerly look for licenses in source code headers and license files",
}
IgnoredLicenses = Flag{
IgnoredLicenses = Flag[[]string]{
Name: "ignored-licenses",
ConfigName: "license.ignored",
Default: []string{},
Usage: "specify a list of license to ignore",
}
LicenseConfidenceLevel = Flag{
LicenseConfidenceLevel = Flag[float64]{
Name: "license-confidence-level",
ConfigName: "license.confidenceLevel",
Default: 0.9,
@@ -26,37 +24,37 @@ var (
}
// LicenseForbidden is an option only in a config file
LicenseForbidden = Flag{
LicenseForbidden = Flag[[]string]{
ConfigName: "license.forbidden",
Default: licensing.ForbiddenLicenses,
Usage: "forbidden licenses",
}
// LicenseRestricted is an option only in a config file
LicenseRestricted = Flag{
LicenseRestricted = Flag[[]string]{
ConfigName: "license.restricted",
Default: licensing.RestrictedLicenses,
Usage: "restricted licenses",
}
// LicenseReciprocal is an option only in a config file
LicenseReciprocal = Flag{
LicenseReciprocal = Flag[[]string]{
ConfigName: "license.reciprocal",
Default: licensing.ReciprocalLicenses,
Usage: "reciprocal licenses",
}
// LicenseNotice is an option only in a config file
LicenseNotice = Flag{
LicenseNotice = Flag[[]string]{
ConfigName: "license.notice",
Default: licensing.NoticeLicenses,
Usage: "notice licenses",
}
// LicensePermissive is an option only in a config file
LicensePermissive = Flag{
LicensePermissive = Flag[[]string]{
ConfigName: "license.permissive",
Default: licensing.PermissiveLicenses,
Usage: "permissive licenses",
}
// LicenseUnencumbered is an option only in a config file
LicenseUnencumbered = Flag{
LicenseUnencumbered = Flag[[]string]{
ConfigName: "license.unencumbered",
Default: licensing.UnencumberedLicenses,
Usage: "unencumbered licenses",
@@ -64,17 +62,17 @@ var (
)
type LicenseFlagGroup struct {
LicenseFull *Flag
IgnoredLicenses *Flag
LicenseConfidenceLevel *Flag
LicenseFull *Flag[bool]
IgnoredLicenses *Flag[[]string]
LicenseConfidenceLevel *Flag[float64]
// License Categories
LicenseForbidden *Flag // mapped to CRITICAL
LicenseRestricted *Flag // mapped to HIGH
LicenseReciprocal *Flag // mapped to MEDIUM
LicenseNotice *Flag // mapped to LOW
LicensePermissive *Flag // mapped to LOW
LicenseUnencumbered *Flag // mapped to LOW
LicenseForbidden *Flag[[]string] // mapped to CRITICAL
LicenseRestricted *Flag[[]string] // mapped to HIGH
LicenseReciprocal *Flag[[]string] // mapped to MEDIUM
LicenseNotice *Flag[[]string] // mapped to LOW
LicensePermissive *Flag[[]string] // mapped to LOW
LicenseUnencumbered *Flag[[]string] // mapped to LOW
}
type LicenseOptions struct {
@@ -87,15 +85,15 @@ type LicenseOptions struct {
func NewLicenseFlagGroup() *LicenseFlagGroup {
return &LicenseFlagGroup{
LicenseFull: &LicenseFull,
IgnoredLicenses: &IgnoredLicenses,
LicenseConfidenceLevel: &LicenseConfidenceLevel,
LicenseForbidden: &LicenseForbidden,
LicenseRestricted: &LicenseRestricted,
LicenseReciprocal: &LicenseReciprocal,
LicenseNotice: &LicenseNotice,
LicensePermissive: &LicensePermissive,
LicenseUnencumbered: &LicenseUnencumbered,
LicenseFull: LicenseFull.Clone(),
IgnoredLicenses: IgnoredLicenses.Clone(),
LicenseConfidenceLevel: LicenseConfidenceLevel.Clone(),
LicenseForbidden: LicenseForbidden.Clone(),
LicenseRestricted: LicenseRestricted.Clone(),
LicenseReciprocal: LicenseReciprocal.Clone(),
LicenseNotice: LicenseNotice.Clone(),
LicensePermissive: LicensePermissive.Clone(),
LicenseUnencumbered: LicenseUnencumbered.Clone(),
}
}
@@ -103,24 +101,37 @@ func (f *LicenseFlagGroup) Name() string {
return "License"
}
func (f *LicenseFlagGroup) Flags() []*Flag {
return []*Flag{f.LicenseFull, f.IgnoredLicenses, f.LicenseForbidden, f.LicenseRestricted, f.LicenseReciprocal,
f.LicenseNotice, f.LicensePermissive, f.LicenseUnencumbered, f.LicenseConfidenceLevel}
}
func (f *LicenseFlagGroup) ToOptions() LicenseOptions {
licenseCategories := make(map[types.LicenseCategory][]string)
licenseCategories[types.CategoryForbidden] = getStringSlice(f.LicenseForbidden)
licenseCategories[types.CategoryRestricted] = getStringSlice(f.LicenseRestricted)
licenseCategories[types.CategoryReciprocal] = getStringSlice(f.LicenseReciprocal)
licenseCategories[types.CategoryNotice] = getStringSlice(f.LicenseNotice)
licenseCategories[types.CategoryPermissive] = getStringSlice(f.LicensePermissive)
licenseCategories[types.CategoryUnencumbered] = getStringSlice(f.LicenseUnencumbered)
return LicenseOptions{
LicenseFull: getBool(f.LicenseFull),
IgnoredLicenses: getStringSlice(f.IgnoredLicenses),
LicenseConfidenceLevel: getFloat(f.LicenseConfidenceLevel),
LicenseCategories: licenseCategories,
func (f *LicenseFlagGroup) Flags() []Flagger {
return []Flagger{
f.LicenseFull,
f.IgnoredLicenses,
f.LicenseForbidden,
f.LicenseRestricted,
f.LicenseReciprocal,
f.LicenseNotice,
f.LicensePermissive,
f.LicenseUnencumbered,
f.LicenseConfidenceLevel,
}
}
func (f *LicenseFlagGroup) ToOptions() (LicenseOptions, error) {
if err := parseFlags(f); err != nil {
return LicenseOptions{}, err
}
licenseCategories := make(map[types.LicenseCategory][]string)
licenseCategories[types.CategoryForbidden] = f.LicenseForbidden.Value()
licenseCategories[types.CategoryRestricted] = f.LicenseRestricted.Value()
licenseCategories[types.CategoryReciprocal] = f.LicenseReciprocal.Value()
licenseCategories[types.CategoryNotice] = f.LicenseNotice.Value()
licenseCategories[types.CategoryPermissive] = f.LicensePermissive.Value()
licenseCategories[types.CategoryUnencumbered] = f.LicenseUnencumbered.Value()
return LicenseOptions{
LicenseFull: f.LicenseFull.Value(),
IgnoredLicenses: f.IgnoredLicenses.Value(),
LicenseConfidenceLevel: f.LicenseConfidenceLevel.Value(),
LicenseCategories: licenseCategories,
}, nil
}

View File

@@ -15,67 +15,59 @@ import (
// config-policy: "custom-policy/policy"
// policy-namespaces: "user"
var (
ResetPolicyBundleFlag = Flag{
ResetPolicyBundleFlag = Flag[bool]{
Name: "reset-policy-bundle",
ConfigName: "misconfiguration.reset-policy-bundle",
Default: false,
Usage: "remove policy bundle",
}
IncludeNonFailuresFlag = Flag{
IncludeNonFailuresFlag = Flag[bool]{
Name: "include-non-failures",
ConfigName: "misconfiguration.include-non-failures",
Default: false,
Usage: "include successes and exceptions, available with '--scanners misconfig'",
}
HelmValuesFileFlag = Flag{
HelmValuesFileFlag = Flag[[]string]{
Name: "helm-values",
ConfigName: "misconfiguration.helm.values",
Default: []string{},
Usage: "specify paths to override the Helm values.yaml files",
}
HelmSetFlag = Flag{
HelmSetFlag = Flag[[]string]{
Name: "helm-set",
ConfigName: "misconfiguration.helm.set",
Default: []string{},
Usage: "specify Helm values on the command line (can specify multiple or separate values with commas: key1=val1,key2=val2)",
}
HelmSetFileFlag = Flag{
HelmSetFileFlag = Flag[[]string]{
Name: "helm-set-file",
ConfigName: "misconfiguration.helm.set-file",
Default: []string{},
Usage: "specify Helm values from respective files specified via the command line (can specify multiple or separate values with commas: key1=path1,key2=path2)",
}
HelmSetStringFlag = Flag{
HelmSetStringFlag = Flag[[]string]{
Name: "helm-set-string",
ConfigName: "misconfiguration.helm.set-string",
Default: []string{},
Usage: "specify Helm string values on the command line (can specify multiple or separate values with commas: key1=val1,key2=val2)",
}
TfVarsFlag = Flag{
TfVarsFlag = Flag[[]string]{
Name: "tf-vars",
ConfigName: "misconfiguration.terraform.vars",
Default: []string{},
Usage: "specify paths to override the Terraform tfvars files",
}
CfParamsFlag = Flag{
CfParamsFlag = Flag[[]string]{
Name: "cf-params",
ConfigName: "misconfiguration.cloudformation.params",
Default: []string{},
Usage: "specify paths to override the CloudFormation parameters files",
}
TerraformExcludeDownloaded = Flag{
TerraformExcludeDownloaded = Flag[bool]{
Name: "tf-exclude-downloaded-modules",
ConfigName: "misconfiguration.terraform.exclude-downloaded-modules",
Default: false,
Usage: "exclude misconfigurations for downloaded terraform modules",
}
PolicyBundleRepositoryFlag = Flag{
PolicyBundleRepositoryFlag = Flag[string]{
Name: "policy-bundle-repository",
ConfigName: "misconfiguration.policy-bundle-repository",
Default: fmt.Sprintf("%s:%d", policy.BundleRepository, policy.BundleVersion),
Usage: "OCI registry URL to retrieve policy bundle from",
}
MisconfigScannersFlag = Flag{
MisconfigScannersFlag = Flag[[]string]{
Name: "misconfig-scanners",
ConfigName: "misconfiguration.scanners",
Default: xstrings.ToStringSlice(analyzer.TypeConfigFiles),
@@ -85,19 +77,19 @@ var (
// MisconfFlagGroup composes common printer flag structs used for commands providing misconfiguration scanning.
type MisconfFlagGroup struct {
IncludeNonFailures *Flag
ResetPolicyBundle *Flag
PolicyBundleRepository *Flag
IncludeNonFailures *Flag[bool]
ResetPolicyBundle *Flag[bool]
PolicyBundleRepository *Flag[string]
// Values Files
HelmValues *Flag
HelmValueFiles *Flag
HelmFileValues *Flag
HelmStringValues *Flag
TerraformTFVars *Flag
CloudformationParamVars *Flag
TerraformExcludeDownloaded *Flag
MisconfigScanners *Flag
HelmValues *Flag[[]string]
HelmValueFiles *Flag[[]string]
HelmFileValues *Flag[[]string]
HelmStringValues *Flag[[]string]
TerraformTFVars *Flag[[]string]
CloudformationParamVars *Flag[[]string]
TerraformExcludeDownloaded *Flag[bool]
MisconfigScanners *Flag[[]string]
}
type MisconfOptions struct {
@@ -118,18 +110,18 @@ type MisconfOptions struct {
func NewMisconfFlagGroup() *MisconfFlagGroup {
return &MisconfFlagGroup{
IncludeNonFailures: &IncludeNonFailuresFlag,
ResetPolicyBundle: &ResetPolicyBundleFlag,
PolicyBundleRepository: &PolicyBundleRepositoryFlag,
IncludeNonFailures: IncludeNonFailuresFlag.Clone(),
ResetPolicyBundle: ResetPolicyBundleFlag.Clone(),
PolicyBundleRepository: PolicyBundleRepositoryFlag.Clone(),
HelmValues: &HelmSetFlag,
HelmFileValues: &HelmSetFileFlag,
HelmStringValues: &HelmSetStringFlag,
HelmValueFiles: &HelmValuesFileFlag,
TerraformTFVars: &TfVarsFlag,
CloudformationParamVars: &CfParamsFlag,
TerraformExcludeDownloaded: &TerraformExcludeDownloaded,
MisconfigScanners: &MisconfigScannersFlag,
HelmValues: HelmSetFlag.Clone(),
HelmFileValues: HelmSetFileFlag.Clone(),
HelmStringValues: HelmSetStringFlag.Clone(),
HelmValueFiles: HelmValuesFileFlag.Clone(),
TerraformTFVars: TfVarsFlag.Clone(),
CloudformationParamVars: CfParamsFlag.Clone(),
TerraformExcludeDownloaded: TerraformExcludeDownloaded.Clone(),
MisconfigScanners: MisconfigScannersFlag.Clone(),
}
}
@@ -137,8 +129,8 @@ func (f *MisconfFlagGroup) Name() string {
return "Misconfiguration"
}
func (f *MisconfFlagGroup) Flags() []*Flag {
return []*Flag{
func (f *MisconfFlagGroup) Flags() []Flagger {
return []Flagger{
f.IncludeNonFailures,
f.ResetPolicyBundle,
f.PolicyBundleRepository,
@@ -154,17 +146,21 @@ func (f *MisconfFlagGroup) Flags() []*Flag {
}
func (f *MisconfFlagGroup) ToOptions() (MisconfOptions, error) {
if err := parseFlags(f); err != nil {
return MisconfOptions{}, err
}
return MisconfOptions{
IncludeNonFailures: getBool(f.IncludeNonFailures),
ResetPolicyBundle: getBool(f.ResetPolicyBundle),
PolicyBundleRepository: getString(f.PolicyBundleRepository),
HelmValues: getStringSlice(f.HelmValues),
HelmValueFiles: getStringSlice(f.HelmValueFiles),
HelmFileValues: getStringSlice(f.HelmFileValues),
HelmStringValues: getStringSlice(f.HelmStringValues),
TerraformTFVars: getStringSlice(f.TerraformTFVars),
CloudFormationParamVars: getStringSlice(f.CloudformationParamVars),
TfExcludeDownloaded: getBool(f.TerraformExcludeDownloaded),
MisconfigScanners: getUnderlyingStringSlice[analyzer.Type](f.MisconfigScanners),
IncludeNonFailures: f.IncludeNonFailures.Value(),
ResetPolicyBundle: f.ResetPolicyBundle.Value(),
PolicyBundleRepository: f.PolicyBundleRepository.Value(),
HelmValues: f.HelmValues.Value(),
HelmValueFiles: f.HelmValueFiles.Value(),
HelmFileValues: f.HelmFileValues.Value(),
HelmStringValues: f.HelmStringValues.Value(),
TerraformTFVars: f.TerraformTFVars.Value(),
CloudFormationParamVars: f.CloudformationParamVars.Value(),
TfExcludeDownloaded: f.TerraformExcludeDownloaded.Value(),
MisconfigScanners: xstrings.ToTSlice[analyzer.Type](f.MisconfigScanners.Value()),
}, nil
}

View File

@@ -11,14 +11,14 @@ import (
// - spring4shell
var (
ModuleDirFlag = Flag{
ModuleDirFlag = Flag[string]{
Name: "module-dir",
ConfigName: "module.dir",
Default: module.DefaultDir,
Usage: "specify directory to the wasm modules that will be loaded",
Persistent: true,
}
EnableModulesFlag = Flag{
EnableModulesFlag = Flag[[]string]{
Name: "enable-modules",
ConfigName: "module.enable-modules",
Default: []string{},
@@ -29,8 +29,8 @@ var (
// ModuleFlagGroup defines flags for modules
type ModuleFlagGroup struct {
Dir *Flag
EnabledModules *Flag
Dir *Flag[string]
EnabledModules *Flag[[]string]
}
type ModuleOptions struct {
@@ -40,8 +40,8 @@ type ModuleOptions struct {
func NewModuleFlagGroup() *ModuleFlagGroup {
return &ModuleFlagGroup{
Dir: &ModuleDirFlag,
EnabledModules: &EnableModulesFlag,
Dir: ModuleDirFlag.Clone(),
EnabledModules: EnableModulesFlag.Clone(),
}
}
@@ -49,16 +49,20 @@ func (f *ModuleFlagGroup) Name() string {
return "Module"
}
func (f *ModuleFlagGroup) Flags() []*Flag {
return []*Flag{
func (f *ModuleFlagGroup) Flags() []Flagger {
return []Flagger{
f.Dir,
f.EnabledModules,
}
}
func (f *ModuleFlagGroup) ToOptions() ModuleOptions {
return ModuleOptions{
ModuleDir: getString(f.Dir),
EnabledModules: getStringSlice(f.EnabledModules),
func (f *ModuleFlagGroup) ToOptions() (ModuleOptions, error) {
if err := parseFlags(f); err != nil {
return ModuleOptions{}, err
}
return ModuleOptions{
ModuleDir: f.Dir.Value(),
EnabledModules: f.EnabledModules.Value(),
}, nil
}

View File

@@ -9,10 +9,12 @@ import (
"sync"
"time"
"github.com/samber/lo"
"github.com/spf13/cast"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
@@ -22,10 +24,13 @@ import (
"github.com/aquasecurity/trivy/pkg/result"
"github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/version"
xstrings "github.com/aquasecurity/trivy/pkg/x/strings"
)
type Flag struct {
type FlagType interface {
int | string | []string | bool | time.Duration | float64
}
type Flag[T FlagType] struct {
// Name is for CLI flag and environment variable.
// If this field is empty, it will be available only in config file.
Name string
@@ -36,8 +41,8 @@ type Flag struct {
// Shorthand is a shorthand letter.
Shorthand string
// Default is the default value. It must be filled to determine the flag type.
Default any
// Default is the default value. It should be defined when the value is different from the zero value.
Default T
// Values is a list of allowed values.
// It currently supports string flags and string slice flags only.
@@ -45,7 +50,7 @@ type Flag struct {
// ValueNormalize is a function to normalize the value.
// It can be used for aliases, etc.
ValueNormalize func(string) string
ValueNormalize func(T) T
// Usage explains how to use the flag.
Usage string
@@ -58,6 +63,10 @@ type Flag struct {
// Aliases represents aliases
Aliases []Alias
// value is the value passed through CLI flag, env, or config file.
// It is populated after flag.Parse() is called.
value T
}
type Alias struct {
@@ -66,12 +75,230 @@ type Alias struct {
Deprecated bool
}
func (f *Flag[T]) Clone() *Flag[T] {
var t T
ff := *f
ff.value = t
fff := &ff
return fff
}
func (f *Flag[T]) Parse() error {
if f == nil {
return nil
}
v := f.parse()
if v == nil {
f.value = lo.Empty[T]()
return nil
}
value, ok := f.cast(v).(T)
if !ok {
return xerrors.Errorf("failed to parse flag %s", f.Name)
}
if f.ValueNormalize != nil {
value = f.ValueNormalize(value)
}
if f.isSet() && !f.allowedValue(value) {
return xerrors.Errorf(`invalid argument "%s" for "--%s" flag: must be one of %q`, value, f.Name, f.Values)
}
f.value = value
return nil
}
func (f *Flag[T]) parse() any {
// First, looks for aliases in config file (trivy.yaml).
// Note that viper.RegisterAlias cannot be used for this purpose.
var v any
for _, alias := range f.Aliases {
if alias.ConfigName == "" {
continue
}
v = viper.Get(alias.ConfigName)
if v != nil {
log.Logger.Warnf("'%s' in config file is deprecated. Use '%s' instead.", alias.ConfigName, f.ConfigName)
return v
}
}
return viper.Get(f.ConfigName)
}
// cast converts the value to the type of the flag.
func (f *Flag[T]) cast(val any) any {
switch any(f.Default).(type) {
case bool:
return cast.ToBool(val)
case string:
return cast.ToString(val)
case int:
return cast.ToInt(val)
case float64, float32:
return cast.ToFloat64(val)
case time.Duration:
return cast.ToDuration(val)
case []string:
if s, ok := val.(string); ok && strings.Contains(s, ",") {
// Split environmental variables by comma as it is not done by viper.
// cf. https://github.com/spf13/viper/issues/380
// It is split by spaces only.
// https://github.com/spf13/cast/blob/48ddde5701366ade1d3aba346e09bb58430d37c6/caste.go#L1296-L1297
val = strings.Split(s, ",")
}
return cast.ToStringSlice(val)
}
return val
}
func (f *Flag[T]) isSet() bool {
configNames := lo.FilterMap(f.Aliases, func(alias Alias, _ int) (string, bool) {
return alias.ConfigName, alias.ConfigName != ""
})
configNames = append(configNames, f.ConfigName)
return lo.SomeBy(configNames, viper.IsSet)
}
func (f *Flag[T]) allowedValue(v any) bool {
if len(f.Values) == 0 {
return true
}
switch value := v.(type) {
case string:
return slices.Contains(f.Values, value)
case []string:
for _, v := range value {
if !slices.Contains(f.Values, v) {
return false
}
}
}
return true
}
func (f *Flag[T]) GetName() string {
return f.Name
}
func (f *Flag[T]) GetAliases() []Alias {
return f.Aliases
}
func (f *Flag[T]) Value() (t T) {
if f == nil {
return t
}
return f.value
}
func (f *Flag[T]) Add(cmd *cobra.Command) {
if f == nil || f.Name == "" {
return
}
var flags *pflag.FlagSet
if f.Persistent {
flags = cmd.PersistentFlags()
} else {
flags = cmd.Flags()
}
switch v := any(f.Default).(type) {
case int:
flags.IntP(f.Name, f.Shorthand, v, f.Usage)
case string:
usage := f.Usage
if len(f.Values) > 0 {
usage += fmt.Sprintf(" (%s)", strings.Join(f.Values, ","))
}
flags.StringP(f.Name, f.Shorthand, v, usage)
case []string:
usage := f.Usage
if len(f.Values) > 0 {
usage += fmt.Sprintf(" (%s)", strings.Join(f.Values, ","))
}
flags.StringSliceP(f.Name, f.Shorthand, v, usage)
case bool:
flags.BoolP(f.Name, f.Shorthand, v, f.Usage)
case time.Duration:
flags.DurationP(f.Name, f.Shorthand, v, f.Usage)
case float64:
flags.Float64P(f.Name, f.Shorthand, v, f.Usage)
}
if f.Deprecated {
flags.MarkHidden(f.Name) // nolint: gosec
}
}
func (f *Flag[T]) Bind(cmd *cobra.Command) error {
if f == nil {
return nil
} else if f.Name == "" {
// This flag is available only in trivy.yaml
viper.SetDefault(f.ConfigName, f.Default)
return nil
}
// Bind CLI flags
flag := cmd.Flags().Lookup(f.Name)
if f == nil {
// Lookup local persistent flags
flag = cmd.PersistentFlags().Lookup(f.Name)
}
if err := viper.BindPFlag(f.ConfigName, flag); err != nil {
return xerrors.Errorf("bind flag error: %w", err)
}
// Bind environmental variable
if err := f.BindEnv(); err != nil {
return err
}
return nil
}
func (f *Flag[T]) BindEnv() error {
// We don't use viper.AutomaticEnv, so we need to add a prefix manually here.
envName := strings.ToUpper("trivy_" + strings.ReplaceAll(f.Name, "-", "_"))
if err := viper.BindEnv(f.ConfigName, envName); err != nil {
return xerrors.Errorf("bind env error: %w", err)
}
// Bind env aliases
for _, alias := range f.Aliases {
envAlias := strings.ToUpper("trivy_" + strings.ReplaceAll(alias.Name, "-", "_"))
if err := viper.BindEnv(f.ConfigName, envAlias); err != nil {
return xerrors.Errorf("bind env error: %w", err)
}
if alias.Deprecated {
if _, ok := os.LookupEnv(envAlias); ok {
log.Logger.Warnf("'%s' is deprecated. Use '%s' instead.", envAlias, envName)
}
}
}
return nil
}
type FlagGroup interface {
Name() string
Flags() []*Flag
Flags() []Flagger
}
type Flagger interface {
GetName() string
GetAliases() []Alias
Parse() error
Add(cmd *cobra.Command)
Bind(cmd *cobra.Command) error
}
type Flags struct {
GlobalFlagGroup *GlobalFlagGroup
AWSFlagGroup *AWSFlagGroup
CacheFlagGroup *CacheFlagGroup
CloudFlagGroup *CloudFlagGroup
@@ -217,163 +444,7 @@ func (o *Options) outputPluginWriter(ctx context.Context) (io.Writer, func() err
return pw, cleanup, nil
}
func addFlag(cmd *cobra.Command, flag *Flag) {
if flag == nil || flag.Name == "" {
return
}
var flags *pflag.FlagSet
if flag.Persistent {
flags = cmd.PersistentFlags()
} else {
flags = cmd.Flags()
}
switch v := flag.Default.(type) {
case int:
flags.IntP(flag.Name, flag.Shorthand, v, flag.Usage)
case string:
usage := flag.Usage
if len(flag.Values) > 0 {
usage += fmt.Sprintf(" (%s)", strings.Join(flag.Values, ","))
}
flags.VarP(newCustomStringValue(v, flag.Values, flag.ValueNormalize), flag.Name, flag.Shorthand, usage)
case []string:
usage := flag.Usage
if len(flag.Values) > 0 {
usage += fmt.Sprintf(" (%s)", strings.Join(flag.Values, ","))
}
flags.VarP(newCustomStringSliceValue(v, flag.Values, flag.ValueNormalize), flag.Name, flag.Shorthand, usage)
case bool:
flags.BoolP(flag.Name, flag.Shorthand, v, flag.Usage)
case time.Duration:
flags.DurationP(flag.Name, flag.Shorthand, v, flag.Usage)
case float64:
flags.Float64P(flag.Name, flag.Shorthand, v, flag.Usage)
}
if flag.Deprecated {
flags.MarkHidden(flag.Name) // nolint: gosec
}
}
func bind(cmd *cobra.Command, flag *Flag) error {
if flag == nil {
return nil
} else if flag.Name == "" {
// This flag is available only in trivy.yaml
viper.SetDefault(flag.ConfigName, flag.Default)
return nil
}
// Bind CLI flags
f := cmd.Flags().Lookup(flag.Name)
if f == nil {
// Lookup local persistent flags
f = cmd.PersistentFlags().Lookup(flag.Name)
}
if err := viper.BindPFlag(flag.ConfigName, f); err != nil {
return xerrors.Errorf("bind flag error: %w", err)
}
// Bind environmental variable
if err := bindEnv(flag); err != nil {
return err
}
return nil
}
func bindEnv(flag *Flag) error {
// We don't use viper.AutomaticEnv, so we need to add a prefix manually here.
envName := strings.ToUpper("trivy_" + strings.ReplaceAll(flag.Name, "-", "_"))
if err := viper.BindEnv(flag.ConfigName, envName); err != nil {
return xerrors.Errorf("bind env error: %w", err)
}
// Bind env aliases
for _, alias := range flag.Aliases {
envAlias := strings.ToUpper("trivy_" + strings.ReplaceAll(alias.Name, "-", "_"))
if err := viper.BindEnv(flag.ConfigName, envAlias); err != nil {
return xerrors.Errorf("bind env error: %w", err)
}
if alias.Deprecated {
if _, ok := os.LookupEnv(envAlias); ok {
log.Logger.Warnf("'%s' is deprecated. Use '%s' instead.", envAlias, envName)
}
}
}
return nil
}
func getString(flag *Flag) string {
return cast.ToString(getValue(flag))
}
func getUnderlyingString[T xstrings.String](flag *Flag) T {
s := getString(flag)
return T(s)
}
func getStringSlice(flag *Flag) []string {
// viper always returns a string for ENV
// https://github.com/spf13/viper/blob/419fd86e49ef061d0d33f4d1d56d5e2a480df5bb/viper.go#L545-L553
// and uses strings.Field to separate values (whitespace only)
// we need to separate env values with ','
v := cast.ToStringSlice(getValue(flag))
switch {
case len(v) == 0: // no strings
return nil
case len(v) == 1 && strings.Contains(v[0], ","): // unseparated string
v = strings.Split(v[0], ",")
}
return v
}
func getUnderlyingStringSlice[T xstrings.String](flag *Flag) []T {
ss := getStringSlice(flag)
if len(ss) == 0 {
return nil
}
return xstrings.ToTSlice[T](ss)
}
func getInt(flag *Flag) int {
return cast.ToInt(getValue(flag))
}
func getFloat(flag *Flag) float64 {
return cast.ToFloat64(getValue(flag))
}
func getBool(flag *Flag) bool {
return cast.ToBool(getValue(flag))
}
func getDuration(flag *Flag) time.Duration {
return cast.ToDuration(getValue(flag))
}
func getValue(flag *Flag) any {
if flag == nil {
return nil
}
// First, looks for aliases in config file (trivy.yaml).
// Note that viper.RegisterAlias cannot be used for this purpose.
var v any
for _, alias := range flag.Aliases {
if alias.ConfigName == "" {
continue
}
v = viper.Get(alias.ConfigName)
if v != nil {
log.Logger.Warnf("'%s' in config file is deprecated. Use '%s' instead.", alias.ConfigName, flag.ConfigName)
return v
}
}
return viper.Get(flag.ConfigName)
}
// groups returns all the flag groups other than global flags
func (f *Flags) groups() []FlagGroup {
var groups []FlagGroup
// This order affects the usage message, so they are sorted by frequency of use.
@@ -438,7 +509,11 @@ func (f *Flags) AddFlags(cmd *cobra.Command) {
aliases := make(flagAliases)
for _, group := range f.groups() {
for _, flag := range group.Flags() {
addFlag(cmd, flag)
if lo.IsNil(flag) || flag.GetName() == "" {
continue
}
// Register the CLI flag
flag.Add(cmd)
// Register flag aliases
aliases.Add(flag)
@@ -451,14 +526,13 @@ func (f *Flags) AddFlags(cmd *cobra.Command) {
func (f *Flags) Usages(cmd *cobra.Command) string {
var usages string
for _, group := range f.groups() {
flags := pflag.NewFlagSet(cmd.Name(), pflag.ContinueOnError)
lflags := cmd.LocalFlags()
for _, flag := range group.Flags() {
if flag == nil || flag.Name == "" {
if lo.IsNil(flag) || flag.GetName() == "" {
continue
}
flags.AddFlag(lflags.Lookup(flag.Name))
flags.AddFlag(lflags.Lookup(flag.GetName()))
}
if !flags.HasAvailableFlags() {
continue
@@ -476,7 +550,7 @@ func (f *Flags) Bind(cmd *cobra.Command) error {
continue
}
for _, flag := range group.Flags() {
if err := bind(cmd, flag); err != nil {
if err := flag.Bind(cmd); err != nil {
return xerrors.Errorf("flag groups: %w", err)
}
}
@@ -485,19 +559,31 @@ func (f *Flags) Bind(cmd *cobra.Command) error {
}
// nolint: gocyclo
func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options, error) {
func (f *Flags) ToOptions(args []string) (Options, error) {
var err error
opts := Options{
AppVersion: version.AppVersion(),
GlobalOptions: globalFlags.ToOptions(),
AppVersion: version.AppVersion(),
}
if f.GlobalFlagGroup != nil {
opts.GlobalOptions, err = f.GlobalFlagGroup.ToOptions()
if err != nil {
return Options{}, xerrors.Errorf("global flag error: %w", err)
}
}
if f.AWSFlagGroup != nil {
opts.AWSOptions = f.AWSFlagGroup.ToOptions()
opts.AWSOptions, err = f.AWSFlagGroup.ToOptions()
if err != nil {
return Options{}, xerrors.Errorf("aws flag error: %w", err)
}
}
if f.CloudFlagGroup != nil {
opts.CloudOptions = f.CloudFlagGroup.ToOptions()
opts.CloudOptions, err = f.CloudFlagGroup.ToOptions()
if err != nil {
return Options{}, xerrors.Errorf("cloud flag error: %w", err)
}
}
if f.CacheFlagGroup != nil {
@@ -510,7 +596,7 @@ func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options,
if f.DBFlagGroup != nil {
opts.DBOptions, err = f.DBFlagGroup.ToOptions()
if err != nil {
return Options{}, xerrors.Errorf("flag error: %w", err)
return Options{}, xerrors.Errorf("db flag error: %w", err)
}
}
@@ -529,7 +615,10 @@ func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options,
}
if f.LicenseFlagGroup != nil {
opts.LicenseOptions = f.LicenseFlagGroup.ToOptions()
opts.LicenseOptions, err = f.LicenseFlagGroup.ToOptions()
if err != nil {
return Options{}, xerrors.Errorf("license flag error: %w", err)
}
}
if f.MisconfFlagGroup != nil {
@@ -540,7 +629,10 @@ func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options,
}
if f.ModuleFlagGroup != nil {
opts.ModuleOptions = f.ModuleFlagGroup.ToOptions()
opts.ModuleOptions, err = f.ModuleFlagGroup.ToOptions()
if err != nil {
return Options{}, xerrors.Errorf("module flag error: %w", err)
}
}
if f.RegoFlagGroup != nil {
@@ -551,7 +643,10 @@ func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options,
}
if f.RemoteFlagGroup != nil {
opts.RemoteOptions = f.RemoteFlagGroup.ToOptions()
opts.RemoteOptions, err = f.RemoteFlagGroup.ToOptions()
if err != nil {
return Options{}, xerrors.Errorf("remote flag error: %w", err)
}
}
if f.RegistryFlagGroup != nil {
@@ -562,7 +657,10 @@ func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options,
}
if f.RepoFlagGroup != nil {
opts.RepoOptions = f.RepoFlagGroup.ToOptions()
opts.RepoOptions, err = f.RepoFlagGroup.ToOptions()
if err != nil {
return Options{}, xerrors.Errorf("rego flag error: %w", err)
}
}
if f.ReportFlagGroup != nil {
@@ -587,11 +685,17 @@ func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options,
}
if f.SecretFlagGroup != nil {
opts.SecretOptions = f.SecretFlagGroup.ToOptions()
opts.SecretOptions, err = f.SecretFlagGroup.ToOptions()
if err != nil {
return Options{}, xerrors.Errorf("secret flag error: %w", err)
}
}
if f.VulnerabilityFlagGroup != nil {
opts.VulnerabilityOptions = f.VulnerabilityFlagGroup.ToOptions()
opts.VulnerabilityOptions, err = f.VulnerabilityFlagGroup.ToOptions()
if err != nil {
return Options{}, xerrors.Errorf("vulnerability flag error: %w", err)
}
}
opts.Align()
@@ -599,6 +703,15 @@ func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options,
return opts, nil
}
func parseFlags(fg FlagGroup) error {
for _, flag := range fg.Flags() {
if err := flag.Parse(); err != nil {
return xerrors.Errorf("unable to parse flag: %w", err)
}
}
return nil
}
type flagAlias struct {
formalName string
deprecated bool
@@ -608,13 +721,10 @@ type flagAlias struct {
// flagAliases have aliases for CLI flags
type flagAliases map[string]*flagAlias
func (a flagAliases) Add(flag *Flag) {
if flag == nil {
return
}
for _, alias := range flag.Aliases {
func (a flagAliases) Add(flag Flagger) {
for _, alias := range flag.GetAliases() {
a[alias.Name] = &flagAlias{
formalName: flag.Name,
formalName: flag.GetName(),
deprecated: alias.Deprecated,
}
}

View File

@@ -1,82 +1,127 @@
package flag
package flag_test
import (
"github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/types"
"github.com/samber/lo"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
"testing"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/aquasecurity/trivy/pkg/types"
)
func Test_getStringSlice(t *testing.T) {
type env struct {
func TestFlag_Parse(t *testing.T) {
type kv struct {
key string
value string
value any
}
tests := []struct {
name string
flag *Flag
flagValue interface{}
env env
want []string
name string
flag *kv
env *kv
want []string
wantErr string
}{
{
name: "happy path. Empty value",
flag: &ScannersFlag,
flagValue: "",
want: nil,
},
{
name: "happy path. String value",
flag: &ScannersFlag,
flagValue: "license,vuln",
name: "flag, string slice",
flag: &kv{
key: "scan.scanners",
value: []string{
"vuln",
"misconfig",
},
},
want: []string{
string(types.LicenseScanner),
string(types.VulnerabilityScanner),
},
},
{
name: "happy path. Slice value",
flag: &ScannersFlag,
flagValue: []string{
"license",
"secret",
},
want: []string{
string(types.LicenseScanner),
string(types.SecretScanner),
},
},
{
name: "happy path. Env value",
flag: &ScannersFlag,
env: env{
key: "TRIVY_SECURITY_CHECKS",
value: "rbac,misconfig",
},
want: []string{
string(types.RBACScanner),
string(types.MisconfigScanner),
},
},
{
name: "env, string",
env: &kv{
key: "TRIVY_SCANNERS",
value: "vuln,misconfig",
},
want: []string{
string(types.VulnerabilityScanner),
string(types.MisconfigScanner),
},
},
{
name: "flag, alias",
flag: &kv{
key: "scan.security-checks",
value: "vulnerability,config",
},
want: []string{
string(types.VulnerabilityScanner),
string(types.MisconfigScanner),
},
},
{
name: "env, alias",
env: &kv{
key: "TRIVY_SECURITY_CHECKS",
value: "vulnerability,config",
},
want: []string{
string(types.VulnerabilityScanner),
string(types.MisconfigScanner),
},
},
{
name: "flag, invalid value",
flag: &kv{
key: "scan.scanners",
value: "vuln,invalid",
},
wantErr: `invalid argument "[vuln invalid]" for "--scanners" flag`,
},
{
name: "env, invalid value",
env: &kv{
key: "TRIVY_SCANNERS",
value: "vuln,invalid",
},
wantErr: `invalid argument "[vuln invalid]" for "--scanners" flag`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.env.key == "" {
viper.Set(tt.flag.ConfigName, tt.flagValue)
} else {
err := viper.BindEnv(tt.flag.ConfigName, tt.env.key)
assert.NoError(t, err)
t.Cleanup(viper.Reset)
t.Setenv(tt.env.key, tt.env.value)
if tt.flag != nil {
viper.Set(tt.flag.key, tt.flag.value)
} else {
t.Setenv(tt.env.key, tt.env.value.(string))
}
sl := getStringSlice(tt.flag)
assert.Equal(t, tt.want, sl)
app := &cobra.Command{}
f := flag.ScannersFlag.Clone()
f.Add(app)
require.NoError(t, f.Bind(app))
viper.Reset()
err := f.Parse()
if tt.wantErr != "" {
require.ErrorContains(t, err, tt.wantErr)
return
}
require.NoError(t, err)
require.Equal(t, tt.want, f.Value())
})
}
}
func setValue[T comparable](key string, value T) {
if !lo.IsEmpty(value) {
viper.Set(key, value)
}
}
func setSliceValue[T any](key string, value []T) {
if len(value) > 0 {
viper.Set(key, value)
}
}

View File

@@ -9,30 +9,27 @@ import (
)
var (
UsernameFlag = Flag{
UsernameFlag = Flag[[]string]{
Name: "username",
ConfigName: "registry.username",
Default: []string{},
Usage: "username. Comma-separated usernames allowed.",
}
PasswordFlag = Flag{
PasswordFlag = Flag[[]string]{
Name: "password",
ConfigName: "registry.password",
Default: []string{},
Usage: "password. Comma-separated passwords allowed. TRIVY_PASSWORD should be used for security reasons.",
}
RegistryTokenFlag = Flag{
RegistryTokenFlag = Flag[string]{
Name: "registry-token",
ConfigName: "registry.token",
Default: "",
Usage: "registry token",
}
)
type RegistryFlagGroup struct {
Username *Flag
Password *Flag
RegistryToken *Flag
Username *Flag[[]string]
Password *Flag[[]string]
RegistryToken *Flag[string]
}
type RegistryOptions struct {
@@ -42,9 +39,9 @@ type RegistryOptions struct {
func NewRegistryFlagGroup() *RegistryFlagGroup {
return &RegistryFlagGroup{
Username: &UsernameFlag,
Password: &PasswordFlag,
RegistryToken: &RegistryTokenFlag,
Username: UsernameFlag.Clone(),
Password: PasswordFlag.Clone(),
RegistryToken: RegistryTokenFlag.Clone(),
}
}
@@ -52,8 +49,8 @@ func (f *RegistryFlagGroup) Name() string {
return "Registry"
}
func (f *RegistryFlagGroup) Flags() []*Flag {
return []*Flag{
func (f *RegistryFlagGroup) Flags() []Flagger {
return []Flagger{
f.Username,
f.Password,
f.RegistryToken,
@@ -61,9 +58,13 @@ func (f *RegistryFlagGroup) Flags() []*Flag {
}
func (f *RegistryFlagGroup) ToOptions() (RegistryOptions, error) {
if err := parseFlags(f); err != nil {
return RegistryOptions{}, err
}
var credentials []types.Credential
users := getStringSlice(f.Username)
passwords := getStringSlice(f.Password)
users := f.Username.Value()
passwords := f.Password.Value()
if len(users) != len(passwords) {
return RegistryOptions{}, xerrors.New("the length of usernames and passwords must match")
}
@@ -76,6 +77,6 @@ func (f *RegistryFlagGroup) ToOptions() (RegistryOptions, error) {
return RegistryOptions{
Credentials: credentials,
RegistryToken: getString(f.RegistryToken),
RegistryToken: f.RegistryToken.Value(),
}, nil
}

View File

@@ -7,40 +7,35 @@ package flag
// config-policy: "custom-policy/policy"
// policy-namespaces: "user"
var (
SkipPolicyUpdateFlag = Flag{
SkipPolicyUpdateFlag = Flag[bool]{
Name: "skip-policy-update",
ConfigName: "rego.skip-policy-update",
Default: false,
Usage: "skip fetching rego policy updates",
}
TraceFlag = Flag{
TraceFlag = Flag[bool]{
Name: "trace",
ConfigName: "rego.trace",
Default: false,
Usage: "enable more verbose trace output for custom queries",
}
ConfigPolicyFlag = Flag{
ConfigPolicyFlag = Flag[[]string]{
Name: "config-policy",
ConfigName: "rego.policy",
Default: []string{},
Usage: "specify the paths to the Rego policy files or to the directories containing them, applying config files",
Aliases: []Alias{
{Name: "policy"},
},
}
ConfigDataFlag = Flag{
ConfigDataFlag = Flag[[]string]{
Name: "config-data",
ConfigName: "rego.data",
Default: []string{},
Usage: "specify paths from which data for the Rego policies will be recursively loaded",
Aliases: []Alias{
{Name: "data"},
},
}
PolicyNamespaceFlag = Flag{
PolicyNamespaceFlag = Flag[[]string]{
Name: "policy-namespaces",
ConfigName: "rego.namespaces",
Default: []string{},
Usage: "Rego namespaces",
Aliases: []Alias{
{Name: "namespaces"},
@@ -50,11 +45,11 @@ var (
// RegoFlagGroup composes common printer flag structs used for commands providing misconfinguration scanning.
type RegoFlagGroup struct {
SkipPolicyUpdate *Flag
Trace *Flag
PolicyPaths *Flag
DataPaths *Flag
PolicyNamespaces *Flag
SkipPolicyUpdate *Flag[bool]
Trace *Flag[bool]
PolicyPaths *Flag[[]string]
DataPaths *Flag[[]string]
PolicyNamespaces *Flag[[]string]
}
type RegoOptions struct {
@@ -67,11 +62,11 @@ type RegoOptions struct {
func NewRegoFlagGroup() *RegoFlagGroup {
return &RegoFlagGroup{
SkipPolicyUpdate: &SkipPolicyUpdateFlag,
Trace: &TraceFlag,
PolicyPaths: &ConfigPolicyFlag,
DataPaths: &ConfigDataFlag,
PolicyNamespaces: &PolicyNamespaceFlag,
SkipPolicyUpdate: SkipPolicyUpdateFlag.Clone(),
Trace: TraceFlag.Clone(),
PolicyPaths: ConfigPolicyFlag.Clone(),
DataPaths: ConfigDataFlag.Clone(),
PolicyNamespaces: PolicyNamespaceFlag.Clone(),
}
}
@@ -79,8 +74,8 @@ func (f *RegoFlagGroup) Name() string {
return "Rego"
}
func (f *RegoFlagGroup) Flags() []*Flag {
return []*Flag{
func (f *RegoFlagGroup) Flags() []Flagger {
return []Flagger{
f.SkipPolicyUpdate,
f.Trace,
f.PolicyPaths,
@@ -90,11 +85,15 @@ func (f *RegoFlagGroup) Flags() []*Flag {
}
func (f *RegoFlagGroup) ToOptions() (RegoOptions, error) {
if err := parseFlags(f); err != nil {
return RegoOptions{}, err
}
return RegoOptions{
SkipPolicyUpdate: getBool(f.SkipPolicyUpdate),
Trace: getBool(f.Trace),
PolicyPaths: getStringSlice(f.PolicyPaths),
DataPaths: getStringSlice(f.DataPaths),
PolicyNamespaces: getStringSlice(f.PolicyNamespaces),
SkipPolicyUpdate: f.SkipPolicyUpdate.Value(),
Trace: f.Trace.Value(),
PolicyPaths: f.PolicyPaths.Value(),
DataPaths: f.DataPaths.Value(),
PolicyNamespaces: f.PolicyNamespaces.Value(),
}, nil
}

View File

@@ -12,31 +12,28 @@ const (
)
var (
ServerTokenFlag = Flag{
ServerTokenFlag = Flag[string]{
Name: "token",
ConfigName: "server.token",
Default: "",
Usage: "for authentication in client/server mode",
}
ServerTokenHeaderFlag = Flag{
ServerTokenHeaderFlag = Flag[string]{
Name: "token-header",
ConfigName: "server.token-header",
Default: DefaultTokenHeader,
Usage: "specify a header name for token in client/server mode",
}
ServerAddrFlag = Flag{
ServerAddrFlag = Flag[string]{
Name: "server",
ConfigName: "server.addr",
Default: "",
Usage: "server address in client mode",
}
ServerCustomHeadersFlag = Flag{
ServerCustomHeadersFlag = Flag[[]string]{
Name: "custom-headers",
ConfigName: "server.custom-headers",
Default: []string{},
Usage: "custom headers in client mode",
}
ServerListenFlag = Flag{
ServerListenFlag = Flag[string]{
Name: "listen",
ConfigName: "server.listen",
Default: "localhost:4954",
@@ -48,15 +45,15 @@ var (
// used for commands requiring reporting logic.
type RemoteFlagGroup struct {
// for client/server
Token *Flag
TokenHeader *Flag
Token *Flag[string]
TokenHeader *Flag[string]
// for client
ServerAddr *Flag
CustomHeaders *Flag
ServerAddr *Flag[string]
CustomHeaders *Flag[[]string]
// for server
Listen *Flag
Listen *Flag[string]
}
type RemoteOptions struct {
@@ -70,10 +67,10 @@ type RemoteOptions struct {
func NewClientFlags() *RemoteFlagGroup {
return &RemoteFlagGroup{
Token: &ServerTokenFlag,
TokenHeader: &ServerTokenHeaderFlag,
ServerAddr: &ServerAddrFlag,
CustomHeaders: &ServerCustomHeadersFlag,
Token: ServerTokenFlag.Clone(),
TokenHeader: ServerTokenHeaderFlag.Clone(),
ServerAddr: ServerAddrFlag.Clone(),
CustomHeaders: ServerCustomHeadersFlag.Clone(),
}
}
@@ -89,16 +86,26 @@ func (f *RemoteFlagGroup) Name() string {
return "Client/Server"
}
func (f *RemoteFlagGroup) Flags() []*Flag {
return []*Flag{f.Token, f.TokenHeader, f.ServerAddr, f.CustomHeaders, f.Listen}
func (f *RemoteFlagGroup) Flags() []Flagger {
return []Flagger{
f.Token,
f.TokenHeader,
f.ServerAddr,
f.CustomHeaders,
f.Listen,
}
}
func (f *RemoteFlagGroup) ToOptions() RemoteOptions {
serverAddr := getString(f.ServerAddr)
customHeaders := splitCustomHeaders(getStringSlice(f.CustomHeaders))
listen := getString(f.Listen)
token := getString(f.Token)
tokenHeader := getString(f.TokenHeader)
func (f *RemoteFlagGroup) ToOptions() (RemoteOptions, error) {
if err := parseFlags(f); err != nil {
return RemoteOptions{}, err
}
serverAddr := f.ServerAddr.Value()
customHeaders := splitCustomHeaders(f.CustomHeaders.Value())
listen := f.Listen.Value()
token := f.Token.Value()
tokenHeader := f.TokenHeader.Value()
if serverAddr == "" && listen == "" {
switch {
@@ -125,7 +132,7 @@ func (f *RemoteFlagGroup) ToOptions() RemoteOptions {
ServerAddr: serverAddr,
CustomHeaders: customHeaders,
Listen: listen,
}
}, nil
}
func splitCustomHeaders(headers []string) http.Header {

View File

@@ -1,6 +1,7 @@
package flag_test
import (
"github.com/stretchr/testify/require"
"net/http"
"testing"
@@ -108,12 +109,13 @@ func TestRemoteFlagGroup_ToOptions(t *testing.T) {
// Assert options
f := &flag.RemoteFlagGroup{
ServerAddr: &flag.ServerAddrFlag,
CustomHeaders: &flag.ServerCustomHeadersFlag,
Token: &flag.ServerTokenFlag,
TokenHeader: &flag.ServerTokenHeaderFlag,
ServerAddr: flag.ServerAddrFlag.Clone(),
CustomHeaders: flag.ServerCustomHeadersFlag.Clone(),
Token: flag.ServerTokenFlag.Clone(),
TokenHeader: flag.ServerTokenHeaderFlag.Clone(),
}
got := f.ToOptions()
got, err := f.ToOptions()
require.NoError(t, err)
assert.Equalf(t, tt.want, got, "ToOptions()")
// Assert log messages

View File

@@ -1,30 +1,27 @@
package flag
var (
FetchBranchFlag = Flag{
FetchBranchFlag = Flag[string]{
Name: "branch",
ConfigName: "repository.branch",
Default: "",
Usage: "pass the branch name to be scanned",
}
FetchCommitFlag = Flag{
FetchCommitFlag = Flag[string]{
Name: "commit",
ConfigName: "repository.commit",
Default: "",
Usage: "pass the commit hash to be scanned",
}
FetchTagFlag = Flag{
FetchTagFlag = Flag[string]{
Name: "tag",
ConfigName: "repository.tag",
Default: "",
Usage: "pass the tag name to be scanned",
}
)
type RepoFlagGroup struct {
Branch *Flag
Commit *Flag
Tag *Flag
Branch *Flag[string]
Commit *Flag[string]
Tag *Flag[string]
}
type RepoOptions struct {
@@ -35,9 +32,9 @@ type RepoOptions struct {
func NewRepoFlagGroup() *RepoFlagGroup {
return &RepoFlagGroup{
Branch: &FetchBranchFlag,
Commit: &FetchCommitFlag,
Tag: &FetchTagFlag,
Branch: FetchBranchFlag.Clone(),
Commit: FetchCommitFlag.Clone(),
Tag: FetchTagFlag.Clone(),
}
}
@@ -45,14 +42,22 @@ func (f *RepoFlagGroup) Name() string {
return "Repository"
}
func (f *RepoFlagGroup) Flags() []*Flag {
return []*Flag{f.Branch, f.Commit, f.Tag}
}
func (f *RepoFlagGroup) ToOptions() RepoOptions {
return RepoOptions{
RepoBranch: getString(f.Branch),
RepoCommit: getString(f.Commit),
RepoTag: getString(f.Tag),
func (f *RepoFlagGroup) Flags() []Flagger {
return []Flagger{
f.Branch,
f.Commit,
f.Tag,
}
}
func (f *RepoFlagGroup) ToOptions() (RepoOptions, error) {
if err := parseFlags(f); err != nil {
return RepoOptions{}, err
}
return RepoOptions{
RepoBranch: f.Branch.Value(),
RepoCommit: f.Commit.Value(),
RepoTag: f.Tag.Value(),
}, nil
}

View File

@@ -22,7 +22,7 @@ import (
// dependency-tree: true
// severity: HIGH,CRITICAL
var (
FormatFlag = Flag{
FormatFlag = Flag[string]{
Name: "format",
ConfigName: "format",
Shorthand: "f",
@@ -30,70 +30,65 @@ var (
Values: xstrings.ToStringSlice(types.SupportedFormats),
Usage: "format",
}
ReportFormatFlag = Flag{
ReportFormatFlag = Flag[string]{
Name: "report",
ConfigName: "report",
Default: "all",
Values: []string{"all", "summary"},
Usage: "specify a report format for the output",
Values: []string{
"all",
"summary",
},
Usage: "specify a report format for the output",
}
TemplateFlag = Flag{
TemplateFlag = Flag[string]{
Name: "template",
ConfigName: "template",
Shorthand: "t",
Default: "",
Usage: "output template",
}
DependencyTreeFlag = Flag{
DependencyTreeFlag = Flag[bool]{
Name: "dependency-tree",
ConfigName: "dependency-tree",
Default: false,
Usage: "[EXPERIMENTAL] show dependency origin tree of vulnerable packages",
}
ListAllPkgsFlag = Flag{
ListAllPkgsFlag = Flag[bool]{
Name: "list-all-pkgs",
ConfigName: "list-all-pkgs",
Default: false,
Usage: "enabling the option will output all packages regardless of vulnerability",
}
IgnoreFileFlag = Flag{
IgnoreFileFlag = Flag[string]{
Name: "ignorefile",
ConfigName: "ignorefile",
Default: result.DefaultIgnoreFile,
Usage: "specify .trivyignore file",
}
IgnorePolicyFlag = Flag{
IgnorePolicyFlag = Flag[string]{
Name: "ignore-policy",
ConfigName: "ignore-policy",
Default: "",
Usage: "specify the Rego file path to evaluate each vulnerability",
}
ExitCodeFlag = Flag{
ExitCodeFlag = Flag[int]{
Name: "exit-code",
ConfigName: "exit-code",
Default: 0,
Usage: "specify exit code when any security issues are found",
}
ExitOnEOLFlag = Flag{
ExitOnEOLFlag = Flag[int]{
Name: "exit-on-eol",
ConfigName: "exit-on-eol",
Default: 0,
Usage: "exit with the specified code when the OS reaches end of service/life",
}
OutputFlag = Flag{
OutputFlag = Flag[string]{
Name: "output",
ConfigName: "output",
Shorthand: "o",
Default: "",
Usage: "output file name",
}
OutputPluginArgFlag = Flag{
OutputPluginArgFlag = Flag[string]{
Name: "output-plugin-arg",
ConfigName: "output-plugin-arg",
Default: "",
Usage: "[EXPERIMENTAL] output plugin arguments",
}
SeverityFlag = Flag{
SeverityFlag = Flag[[]string]{
Name: "severity",
ConfigName: "severity",
Shorthand: "s",
@@ -101,10 +96,9 @@ var (
Values: dbTypes.SeverityNames,
Usage: "severities of security issues to be displayed",
}
ComplianceFlag = Flag{
ComplianceFlag = Flag[string]{
Name: "compliance",
ConfigName: "scan.compliance",
Default: "",
Usage: "compliance report to generate",
}
)
@@ -112,19 +106,19 @@ var (
// ReportFlagGroup composes common printer flag structs
// used for commands requiring reporting logic.
type ReportFlagGroup struct {
Format *Flag
ReportFormat *Flag
Template *Flag
DependencyTree *Flag
ListAllPkgs *Flag
IgnoreFile *Flag
IgnorePolicy *Flag
ExitCode *Flag
ExitOnEOL *Flag
Output *Flag
OutputPluginArg *Flag
Severity *Flag
Compliance *Flag
Format *Flag[string]
ReportFormat *Flag[string]
Template *Flag[string]
DependencyTree *Flag[bool]
ListAllPkgs *Flag[bool]
IgnoreFile *Flag[string]
IgnorePolicy *Flag[string]
ExitCode *Flag[int]
ExitOnEOL *Flag[int]
Output *Flag[string]
OutputPluginArg *Flag[string]
Severity *Flag[[]string]
Compliance *Flag[string]
}
type ReportOptions struct {
@@ -145,19 +139,19 @@ type ReportOptions struct {
func NewReportFlagGroup() *ReportFlagGroup {
return &ReportFlagGroup{
Format: &FormatFlag,
ReportFormat: &ReportFormatFlag,
Template: &TemplateFlag,
DependencyTree: &DependencyTreeFlag,
ListAllPkgs: &ListAllPkgsFlag,
IgnoreFile: &IgnoreFileFlag,
IgnorePolicy: &IgnorePolicyFlag,
ExitCode: &ExitCodeFlag,
ExitOnEOL: &ExitOnEOLFlag,
Output: &OutputFlag,
OutputPluginArg: &OutputPluginArgFlag,
Severity: &SeverityFlag,
Compliance: &ComplianceFlag,
Format: FormatFlag.Clone(),
ReportFormat: ReportFormatFlag.Clone(),
Template: TemplateFlag.Clone(),
DependencyTree: DependencyTreeFlag.Clone(),
ListAllPkgs: ListAllPkgsFlag.Clone(),
IgnoreFile: IgnoreFileFlag.Clone(),
IgnorePolicy: IgnorePolicyFlag.Clone(),
ExitCode: ExitCodeFlag.Clone(),
ExitOnEOL: ExitOnEOLFlag.Clone(),
Output: OutputFlag.Clone(),
OutputPluginArg: OutputPluginArgFlag.Clone(),
Severity: SeverityFlag.Clone(),
Compliance: ComplianceFlag.Clone(),
}
}
@@ -165,8 +159,8 @@ func (f *ReportFlagGroup) Name() string {
return "Report"
}
func (f *ReportFlagGroup) Flags() []*Flag {
return []*Flag{
func (f *ReportFlagGroup) Flags() []Flagger {
return []Flagger{
f.Format,
f.ReportFormat,
f.Template,
@@ -184,10 +178,14 @@ func (f *ReportFlagGroup) Flags() []*Flag {
}
func (f *ReportFlagGroup) ToOptions() (ReportOptions, error) {
format := getUnderlyingString[types.Format](f.Format)
template := getString(f.Template)
dependencyTree := getBool(f.DependencyTree)
listAllPkgs := getBool(f.ListAllPkgs)
if err := parseFlags(f); err != nil {
return ReportOptions{}, err
}
format := types.Format(f.Format.Value())
template := f.Template.Value()
dependencyTree := f.DependencyTree.Value()
listAllPkgs := f.ListAllPkgs.Value()
if template != "" {
if format == "" {
@@ -222,13 +220,13 @@ func (f *ReportFlagGroup) ToOptions() (ReportOptions, error) {
listAllPkgs = true
}
cs, err := loadComplianceTypes(getString(f.Compliance))
cs, err := loadComplianceTypes(f.Compliance.Value())
if err != nil {
return ReportOptions{}, xerrors.Errorf("unable to load compliance spec: %w", err)
}
var outputPluginArgs []string
if arg := getString(f.OutputPluginArg); arg != "" {
if arg := f.OutputPluginArg.Value(); arg != "" {
outputPluginArgs, err = shellwords.Parse(arg)
if err != nil {
return ReportOptions{}, xerrors.Errorf("unable to parse output plugin argument: %w", err)
@@ -237,17 +235,17 @@ func (f *ReportFlagGroup) ToOptions() (ReportOptions, error) {
return ReportOptions{
Format: format,
ReportFormat: getString(f.ReportFormat),
ReportFormat: f.ReportFormat.Value(),
Template: template,
DependencyTree: dependencyTree,
ListAllPkgs: listAllPkgs,
IgnoreFile: getString(f.IgnoreFile),
ExitCode: getInt(f.ExitCode),
ExitOnEOL: getInt(f.ExitOnEOL),
IgnorePolicy: getString(f.IgnorePolicy),
Output: getString(f.Output),
IgnoreFile: f.IgnoreFile.Value(),
ExitCode: f.ExitCode.Value(),
ExitOnEOL: f.ExitOnEOL.Value(),
IgnorePolicy: f.IgnorePolicy.Value(),
Output: f.Output.Value(),
OutputPluginArgs: outputPluginArgs,
Severities: toSeverity(getStringSlice(f.Severity)),
Severities: toSeverity(f.Severity.Value()),
Compliance: cs,
}, nil
}

View File

@@ -185,6 +185,8 @@ func TestReportFlagGroup_ToOptions(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Cleanup(viper.Reset)
level := zap.WarnLevel
if tt.fields.debug {
level = zap.DebugLevel
@@ -192,34 +194,34 @@ func TestReportFlagGroup_ToOptions(t *testing.T) {
core, obs := observer.New(level)
log.Logger = zap.New(core).Sugar()
viper.Set(flag.FormatFlag.ConfigName, string(tt.fields.format))
viper.Set(flag.TemplateFlag.ConfigName, tt.fields.template)
viper.Set(flag.DependencyTreeFlag.ConfigName, tt.fields.dependencyTree)
viper.Set(flag.ListAllPkgsFlag.ConfigName, tt.fields.listAllPkgs)
viper.Set(flag.IgnoreFileFlag.ConfigName, tt.fields.ignoreFile)
viper.Set(flag.IgnoreUnfixedFlag.ConfigName, tt.fields.ignoreUnfixed)
viper.Set(flag.IgnorePolicyFlag.ConfigName, tt.fields.ignorePolicy)
viper.Set(flag.ExitCodeFlag.ConfigName, tt.fields.exitCode)
viper.Set(flag.ExitOnEOLFlag.ConfigName, tt.fields.exitOnEOSL)
viper.Set(flag.OutputFlag.ConfigName, tt.fields.output)
viper.Set(flag.OutputPluginArgFlag.ConfigName, tt.fields.outputPluginArgs)
viper.Set(flag.SeverityFlag.ConfigName, tt.fields.severities)
viper.Set(flag.ComplianceFlag.ConfigName, tt.fields.compliance)
setValue(flag.FormatFlag.ConfigName, string(tt.fields.format))
setValue(flag.TemplateFlag.ConfigName, tt.fields.template)
setValue(flag.DependencyTreeFlag.ConfigName, tt.fields.dependencyTree)
setValue(flag.ListAllPkgsFlag.ConfigName, tt.fields.listAllPkgs)
setValue(flag.IgnoreFileFlag.ConfigName, tt.fields.ignoreFile)
setValue(flag.IgnoreUnfixedFlag.ConfigName, tt.fields.ignoreUnfixed)
setValue(flag.IgnorePolicyFlag.ConfigName, tt.fields.ignorePolicy)
setValue(flag.ExitCodeFlag.ConfigName, tt.fields.exitCode)
setValue(flag.ExitOnEOLFlag.ConfigName, tt.fields.exitOnEOSL)
setValue(flag.OutputFlag.ConfigName, tt.fields.output)
setValue(flag.OutputPluginArgFlag.ConfigName, tt.fields.outputPluginArgs)
setValue(flag.SeverityFlag.ConfigName, tt.fields.severities)
setValue(flag.ComplianceFlag.ConfigName, tt.fields.compliance)
// Assert options
f := &flag.ReportFlagGroup{
Format: &flag.FormatFlag,
Template: &flag.TemplateFlag,
DependencyTree: &flag.DependencyTreeFlag,
ListAllPkgs: &flag.ListAllPkgsFlag,
IgnoreFile: &flag.IgnoreFileFlag,
IgnorePolicy: &flag.IgnorePolicyFlag,
ExitCode: &flag.ExitCodeFlag,
ExitOnEOL: &flag.ExitOnEOLFlag,
Output: &flag.OutputFlag,
OutputPluginArg: &flag.OutputPluginArgFlag,
Severity: &flag.SeverityFlag,
Compliance: &flag.ComplianceFlag,
Format: flag.FormatFlag.Clone(),
Template: flag.TemplateFlag.Clone(),
DependencyTree: flag.DependencyTreeFlag.Clone(),
ListAllPkgs: flag.ListAllPkgsFlag.Clone(),
IgnoreFile: flag.IgnoreFileFlag.Clone(),
IgnorePolicy: flag.IgnorePolicyFlag.Clone(),
ExitCode: flag.ExitCodeFlag.Clone(),
ExitOnEOL: flag.ExitOnEOLFlag.Clone(),
Output: flag.OutputFlag.Clone(),
OutputPluginArg: flag.OutputPluginArgFlag.Clone(),
Severity: flag.SeverityFlag.Clone(),
Compliance: flag.ComplianceFlag.Clone(),
}
got, err := f.ToOptions()

View File

@@ -7,25 +7,23 @@ import (
)
var (
ArtifactTypeFlag = Flag{
ArtifactTypeFlag = Flag[string]{
Name: "artifact-type",
ConfigName: "sbom.artifact-type",
Default: "",
Usage: "deprecated",
Deprecated: true,
}
SBOMFormatFlag = Flag{
SBOMFormatFlag = Flag[string]{
Name: "sbom-format",
ConfigName: "sbom.format",
Default: "",
Usage: "deprecated",
Deprecated: true,
}
)
type SBOMFlagGroup struct {
ArtifactType *Flag // deprecated
SBOMFormat *Flag // deprecated
ArtifactType *Flag[string] // deprecated
SBOMFormat *Flag[string] // deprecated
}
type SBOMOptions struct {
@@ -33,8 +31,8 @@ type SBOMOptions struct {
func NewSBOMFlagGroup() *SBOMFlagGroup {
return &SBOMFlagGroup{
ArtifactType: &ArtifactTypeFlag,
SBOMFormat: &SBOMFormatFlag,
ArtifactType: ArtifactTypeFlag.Clone(),
SBOMFormat: SBOMFormatFlag.Clone(),
}
}
@@ -42,16 +40,20 @@ func (f *SBOMFlagGroup) Name() string {
return "SBOM"
}
func (f *SBOMFlagGroup) Flags() []*Flag {
return []*Flag{
func (f *SBOMFlagGroup) Flags() []Flagger {
return []Flagger{
f.ArtifactType,
f.SBOMFormat,
}
}
func (f *SBOMFlagGroup) ToOptions() (SBOMOptions, error) {
artifactType := getString(f.ArtifactType)
sbomFormat := getString(f.SBOMFormat)
if err := parseFlags(f); err != nil {
return SBOMOptions{}, err
}
artifactType := f.ArtifactType.Value()
sbomFormat := f.SBOMFormat.Value()
if artifactType != "" || sbomFormat != "" {
log.Logger.Error("'trivy sbom' is now for scanning SBOM. " +

View File

@@ -3,31 +3,31 @@ package flag
import (
"runtime"
"github.com/samber/lo"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/types"
xstrings "github.com/aquasecurity/trivy/pkg/x/strings"
)
var (
SkipDirsFlag = Flag{
SkipDirsFlag = Flag[[]string]{
Name: "skip-dirs",
ConfigName: "scan.skip-dirs",
Default: []string{},
Usage: "specify the directories or glob patterns to skip",
}
SkipFilesFlag = Flag{
SkipFilesFlag = Flag[[]string]{
Name: "skip-files",
ConfigName: "scan.skip-files",
Default: []string{},
Usage: "specify the files or glob patterns to skip",
}
OfflineScanFlag = Flag{
OfflineScanFlag = Flag[bool]{
Name: "offline-scan",
ConfigName: "scan.offline",
Default: false,
Usage: "do not issue API requests to identify dependencies",
}
ScannersFlag = Flag{
ScannersFlag = Flag[[]string]{
Name: "scanners",
ConfigName: "scan.scanners",
Default: xstrings.ToStringSlice(types.Scanners{
@@ -40,17 +40,19 @@ var (
types.SecretScanner,
types.LicenseScanner,
}),
ValueNormalize: func(s string) string {
switch s {
case "vulnerability":
return string(types.VulnerabilityScanner)
case "misconf", "misconfiguration":
return string(types.MisconfigScanner)
case "config":
log.Logger.Warn("'--scanner config' is deprecated. Use '--scanner misconfig' instead. See https://github.com/aquasecurity/trivy/discussions/5586 for the detail.")
return string(types.MisconfigScanner)
}
return s
ValueNormalize: func(ss []string) []string {
return lo.Map(ss, func(s string, _ int) string {
switch s {
case "vulnerability":
return string(types.VulnerabilityScanner)
case "misconf", "misconfiguration":
return string(types.MisconfigScanner)
case "config":
log.Logger.Warn("'--scanners config' is deprecated. Use '--scanners misconfig' instead. See https://github.com/aquasecurity/trivy/discussions/5586 for the detail.")
return string(types.MisconfigScanner)
}
return s
})
},
Aliases: []Alias{
{
@@ -61,57 +63,57 @@ var (
},
Usage: "comma-separated list of what security issues to detect",
}
FilePatternsFlag = Flag{
FilePatternsFlag = Flag[[]string]{
Name: "file-patterns",
ConfigName: "scan.file-patterns",
Default: []string{},
Usage: "specify config file patterns",
}
SlowFlag = Flag{
SlowFlag = Flag[bool]{
Name: "slow",
ConfigName: "scan.slow",
Default: false,
Usage: "scan over time with lower CPU and memory utilization",
Deprecated: true,
}
ParallelFlag = Flag{
ParallelFlag = Flag[int]{
Name: "parallel",
ConfigName: "scan.parallel",
Default: 5,
Usage: "number of goroutines enabled for parallel scanning, set 0 to auto-detect parallelism",
}
SBOMSourcesFlag = Flag{
SBOMSourcesFlag = Flag[[]string]{
Name: "sbom-sources",
ConfigName: "scan.sbom-sources",
Default: []string{},
Values: []string{"oci", "rekor"},
Usage: "[EXPERIMENTAL] try to retrieve SBOM from the specified sources",
Values: []string{
"oci",
"rekor",
},
Usage: "[EXPERIMENTAL] try to retrieve SBOM from the specified sources",
}
RekorURLFlag = Flag{
RekorURLFlag = Flag[string]{
Name: "rekor-url",
ConfigName: "scan.rekor-url",
Default: "https://rekor.sigstore.dev",
Usage: "[EXPERIMENTAL] address of rekor STL server",
}
IncludeDevDepsFlag = Flag{
IncludeDevDepsFlag = Flag[bool]{
Name: "include-dev-deps",
ConfigName: "include-dev-deps",
Default: false,
Usage: "include development dependencies in the report (supported: npm, yarn)",
}
)
type ScanFlagGroup struct {
SkipDirs *Flag
SkipFiles *Flag
OfflineScan *Flag
Scanners *Flag
FilePatterns *Flag
Slow *Flag // deprecated
Parallel *Flag
SBOMSources *Flag
RekorURL *Flag
IncludeDevDeps *Flag
SkipDirs *Flag[[]string]
SkipFiles *Flag[[]string]
OfflineScan *Flag[bool]
Scanners *Flag[[]string]
FilePatterns *Flag[[]string]
Slow *Flag[bool] // deprecated
Parallel *Flag[int]
SBOMSources *Flag[[]string]
RekorURL *Flag[string]
IncludeDevDeps *Flag[bool]
}
type ScanOptions struct {
@@ -129,16 +131,16 @@ type ScanOptions struct {
func NewScanFlagGroup() *ScanFlagGroup {
return &ScanFlagGroup{
SkipDirs: &SkipDirsFlag,
SkipFiles: &SkipFilesFlag,
OfflineScan: &OfflineScanFlag,
Scanners: &ScannersFlag,
FilePatterns: &FilePatternsFlag,
Parallel: &ParallelFlag,
SBOMSources: &SBOMSourcesFlag,
RekorURL: &RekorURLFlag,
IncludeDevDeps: &IncludeDevDepsFlag,
Slow: &SlowFlag,
SkipDirs: SkipDirsFlag.Clone(),
SkipFiles: SkipFilesFlag.Clone(),
OfflineScan: OfflineScanFlag.Clone(),
Scanners: ScannersFlag.Clone(),
FilePatterns: FilePatternsFlag.Clone(),
Parallel: ParallelFlag.Clone(),
SBOMSources: SBOMSourcesFlag.Clone(),
RekorURL: RekorURLFlag.Clone(),
IncludeDevDeps: IncludeDevDepsFlag.Clone(),
Slow: SlowFlag.Clone(),
}
}
@@ -146,8 +148,8 @@ func (f *ScanFlagGroup) Name() string {
return "Scan"
}
func (f *ScanFlagGroup) Flags() []*Flag {
return []*Flag{
func (f *ScanFlagGroup) Flags() []Flagger {
return []Flagger{
f.SkipDirs,
f.SkipFiles,
f.OfflineScan,
@@ -162,12 +164,16 @@ func (f *ScanFlagGroup) Flags() []*Flag {
}
func (f *ScanFlagGroup) ToOptions(args []string) (ScanOptions, error) {
if err := parseFlags(f); err != nil {
return ScanOptions{}, err
}
var target string
if len(args) == 1 {
target = args[0]
}
parallel := getInt(f.Parallel)
parallel := f.Parallel.Value()
if f.Parallel != nil && parallel == 0 {
log.Logger.Infof("Set '--parallel' to the number of CPUs (%d)", runtime.NumCPU())
parallel = runtime.NumCPU()
@@ -175,14 +181,14 @@ func (f *ScanFlagGroup) ToOptions(args []string) (ScanOptions, error) {
return ScanOptions{
Target: target,
SkipDirs: getStringSlice(f.SkipDirs),
SkipFiles: getStringSlice(f.SkipFiles),
OfflineScan: getBool(f.OfflineScan),
Scanners: getUnderlyingStringSlice[types.Scanner](f.Scanners),
FilePatterns: getStringSlice(f.FilePatterns),
SkipDirs: f.SkipDirs.Value(),
SkipFiles: f.SkipFiles.Value(),
OfflineScan: f.OfflineScan.Value(),
Scanners: xstrings.ToTSlice[types.Scanner](f.Scanners.Value()),
FilePatterns: f.FilePatterns.Value(),
Parallel: parallel,
SBOMSources: getStringSlice(f.SBOMSources),
RekorURL: getString(f.RekorURL),
IncludeDevDeps: getBool(f.IncludeDevDeps),
SBOMSources: f.SBOMSources.Value(),
RekorURL: f.RekorURL.Value(),
IncludeDevDeps: f.IncludeDevDeps.Value(),
}, nil
}

View File

@@ -1,9 +1,9 @@
package flag_test
import (
"github.com/spf13/viper"
"testing"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -109,23 +109,23 @@ func TestScanFlagGroup_ToOptions(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
viper.Set(flag.SkipDirsFlag.ConfigName, tt.fields.skipDirs)
viper.Set(flag.SkipFilesFlag.ConfigName, tt.fields.skipFiles)
viper.Set(flag.OfflineScanFlag.ConfigName, tt.fields.offlineScan)
viper.Set(flag.ScannersFlag.ConfigName, tt.fields.scanners)
t.Cleanup(viper.Reset)
setSliceValue(flag.SkipDirsFlag.ConfigName, tt.fields.skipDirs)
setSliceValue(flag.SkipFilesFlag.ConfigName, tt.fields.skipFiles)
setValue(flag.OfflineScanFlag.ConfigName, tt.fields.offlineScan)
setValue(flag.ScannersFlag.ConfigName, tt.fields.scanners)
// Assert options
f := &flag.ScanFlagGroup{
SkipDirs: &flag.SkipDirsFlag,
SkipFiles: &flag.SkipFilesFlag,
OfflineScan: &flag.OfflineScanFlag,
Scanners: &flag.ScannersFlag,
SkipDirs: flag.SkipDirsFlag.Clone(),
SkipFiles: flag.SkipFilesFlag.Clone(),
OfflineScan: flag.OfflineScanFlag.Clone(),
Scanners: flag.ScannersFlag.Clone(),
}
got, err := f.ToOptions(tt.args)
tt.assertion(t, err)
assert.Equalf(t, tt.want, got, "ToOptions()")
})
}
}

View File

@@ -1,7 +1,7 @@
package flag
var (
SecretConfigFlag = Flag{
SecretConfigFlag = Flag[string]{
Name: "secret-config",
ConfigName: "secret.config",
Default: "trivy-secret.yaml",
@@ -10,7 +10,7 @@ var (
)
type SecretFlagGroup struct {
SecretConfig *Flag
SecretConfig *Flag[string]
}
type SecretOptions struct {
@@ -19,7 +19,7 @@ type SecretOptions struct {
func NewSecretFlagGroup() *SecretFlagGroup {
return &SecretFlagGroup{
SecretConfig: &SecretConfigFlag,
SecretConfig: SecretConfigFlag.Clone(),
}
}
@@ -27,12 +27,16 @@ func (f *SecretFlagGroup) Name() string {
return "Secret"
}
func (f *SecretFlagGroup) Flags() []*Flag {
return []*Flag{f.SecretConfig}
func (f *SecretFlagGroup) Flags() []Flagger {
return []Flagger{f.SecretConfig}
}
func (f *SecretFlagGroup) ToOptions() SecretOptions {
return SecretOptions{
SecretConfigPath: getString(f.SecretConfig),
func (f *SecretFlagGroup) ToOptions() (SecretOptions, error) {
if err := parseFlags(f); err != nil {
return SecretOptions{}, err
}
return SecretOptions{
SecretConfigPath: f.SecretConfig.Value(),
}, nil
}

View File

@@ -1,104 +0,0 @@
package flag
import (
"strings"
"github.com/samber/lo"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
)
type ValueNormalizeFunc func(string) string
// -- string Value
type customStringValue struct {
value *string
allowed []string
normalize ValueNormalizeFunc
}
func newCustomStringValue(val string, allowed []string, fn ValueNormalizeFunc) *customStringValue {
return &customStringValue{
value: &val,
allowed: allowed,
normalize: fn,
}
}
func (s *customStringValue) Set(val string) error {
if s.normalize != nil {
val = s.normalize(val)
}
if len(s.allowed) > 0 && !slices.Contains(s.allowed, val) {
return xerrors.Errorf("must be one of %q", s.allowed)
}
s.value = &val
return nil
}
func (s *customStringValue) Type() string {
return "string"
}
func (s *customStringValue) String() string { return *s.value }
// -- stringSlice Value
type customStringSliceValue struct {
value *[]string
allowed []string
normalize ValueNormalizeFunc
changed bool
}
func newCustomStringSliceValue(val, allowed []string, fn ValueNormalizeFunc) *customStringSliceValue {
return &customStringSliceValue{
value: &val,
allowed: allowed,
normalize: fn,
}
}
func (s *customStringSliceValue) Set(val string) error {
values := strings.Split(val, ",")
if s.normalize != nil {
values = lo.Map(values, func(item string, _ int) string { return s.normalize(item) })
}
for _, v := range values {
if len(s.allowed) > 0 && !slices.Contains(s.allowed, v) {
return xerrors.Errorf("must be one of %q", s.allowed)
}
}
if !s.changed {
*s.value = values
} else {
*s.value = append(*s.value, values...)
}
s.changed = true
return nil
}
func (s *customStringSliceValue) Type() string {
return "stringSlice"
}
func (s *customStringSliceValue) String() string {
if len(*s.value) == 0 {
// "[]" is not recognized as a zero value
// cf. https://github.com/spf13/pflag/blob/d5e0c0615acee7028e1e2740a11102313be88de1/flag.go#L553-L565
return ""
}
return "[" + strings.Join(*s.value, ",") + "]"
}
func (s *customStringSliceValue) Append(val string) error {
s.changed = true
return s.Set(val)
}
func (s *customStringSliceValue) Replace(val []string) error {
*s.value = val
return nil
}
func (s *customStringSliceValue) GetSlice() []string {
return *s.value
}

View File

@@ -9,7 +9,7 @@ import (
)
var (
VulnTypeFlag = Flag{
VulnTypeFlag = Flag[[]string]{
Name: "vuln-type",
ConfigName: "vulnerability.type",
Default: []string{
@@ -22,20 +22,18 @@ var (
},
Usage: "comma-separated list of vulnerability types",
}
IgnoreUnfixedFlag = Flag{
IgnoreUnfixedFlag = Flag[bool]{
Name: "ignore-unfixed",
ConfigName: "vulnerability.ignore-unfixed",
Default: false,
Usage: "display only fixed vulnerabilities",
}
IgnoreStatusFlag = Flag{
IgnoreStatusFlag = Flag[[]string]{
Name: "ignore-status",
ConfigName: "vulnerability.ignore-status",
Default: []string{},
Values: dbTypes.Statuses,
Usage: "comma-separated list of vulnerability status to ignore",
}
VEXFlag = Flag{
VEXFlag = Flag[string]{
Name: "vex",
ConfigName: "vulnerability.vex",
Default: "",
@@ -44,10 +42,10 @@ var (
)
type VulnerabilityFlagGroup struct {
VulnType *Flag
IgnoreUnfixed *Flag
IgnoreStatus *Flag
VEXPath *Flag
VulnType *Flag[[]string]
IgnoreUnfixed *Flag[bool]
IgnoreStatus *Flag[[]string]
VEXPath *Flag[string]
}
type VulnerabilityOptions struct {
@@ -58,10 +56,10 @@ type VulnerabilityOptions struct {
func NewVulnerabilityFlagGroup() *VulnerabilityFlagGroup {
return &VulnerabilityFlagGroup{
VulnType: &VulnTypeFlag,
IgnoreUnfixed: &IgnoreUnfixedFlag,
IgnoreStatus: &IgnoreStatusFlag,
VEXPath: &VEXFlag,
VulnType: VulnTypeFlag.Clone(),
IgnoreUnfixed: IgnoreUnfixedFlag.Clone(),
IgnoreStatus: IgnoreStatusFlag.Clone(),
VEXPath: VEXFlag.Clone(),
}
}
@@ -69,8 +67,8 @@ func (f *VulnerabilityFlagGroup) Name() string {
return "Vulnerability"
}
func (f *VulnerabilityFlagGroup) Flags() []*Flag {
return []*Flag{
func (f *VulnerabilityFlagGroup) Flags() []Flagger {
return []Flagger{
f.VulnType,
f.IgnoreUnfixed,
f.IgnoreStatus,
@@ -78,12 +76,16 @@ func (f *VulnerabilityFlagGroup) Flags() []*Flag {
}
}
func (f *VulnerabilityFlagGroup) ToOptions() VulnerabilityOptions {
func (f *VulnerabilityFlagGroup) ToOptions() (VulnerabilityOptions, error) {
if err := parseFlags(f); err != nil {
return VulnerabilityOptions{}, err
}
// Just convert string to dbTypes.Status as the validated values are passed here.
ignoreStatuses := lo.Map(getStringSlice(f.IgnoreStatus), func(s string, _ int) dbTypes.Status {
ignoreStatuses := lo.Map(f.IgnoreStatus.Value(), func(s string, _ int) dbTypes.Status {
return dbTypes.NewStatus(s)
})
ignoreUnfixed := getBool(f.IgnoreUnfixed)
ignoreUnfixed := f.IgnoreUnfixed.Value()
switch {
case ignoreUnfixed && len(ignoreStatuses) > 0:
@@ -103,8 +105,8 @@ func (f *VulnerabilityFlagGroup) ToOptions() VulnerabilityOptions {
log.Logger.Debugw("Ignore statuses", "statuses", ignoreStatuses)
return VulnerabilityOptions{
VulnType: getStringSlice(f.VulnType),
VulnType: f.VulnType.Value(),
IgnoreStatuses: ignoreStatuses,
VEXPath: getString(f.VEXPath),
}
VEXPath: f.VEXPath.Value(),
}, nil
}

View File

@@ -1,6 +1,7 @@
package flag_test
import (
"github.com/stretchr/testify/require"
"testing"
"github.com/spf13/viper"
@@ -57,10 +58,11 @@ func TestVulnerabilityFlagGroup_ToOptions(t *testing.T) {
// Assert options
f := &flag.VulnerabilityFlagGroup{
VulnType: &flag.VulnTypeFlag,
VulnType: flag.VulnTypeFlag.Clone(),
}
got := f.ToOptions()
got, err := f.ToOptions()
require.NoError(t, err)
assert.Equalf(t, tt.want, got, "ToOptions()")
// Assert log messages

View File

@@ -230,7 +230,7 @@ func (c *CycloneDX) Vulnerabilities(uniq map[string]*cdx.Vulnerability) *[]cdx.V
return *value
})
sort.Slice(vulns, func(i, j int) bool {
return vulns[i].BOMRef < vulns[j].BOMRef
return vulns[i].ID < vulns[j].ID
})
return &vulns
}

View File

@@ -7,12 +7,18 @@ type String interface {
}
func ToStringSlice[T String](ss []T) []string {
if ss == nil {
return nil
}
return lo.Map(ss, func(s T, _ int) string {
return string(s)
})
}
func ToTSlice[T String](ss []string) []T {
if ss == nil {
return nil
}
return lo.Map(ss, func(s string, _ int) T {
return T(s)
})