refactor: reintroduce output writer (#5564)

Signed-off-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
Teppei Fukuda
2023-11-14 21:32:23 +09:00
committed by GitHub
parent 2310f0dd69
commit 950e431f0f
10 changed files with 42 additions and 56 deletions

View File

@@ -1,6 +1,7 @@
package commands package commands
import ( import (
"bytes"
"context" "context"
"os" "os"
"path/filepath" "path/filepath"
@@ -1135,8 +1136,8 @@ Summary Report for compliance: my-custom-spec
}() }()
} }
output := filepath.Join(t.TempDir(), "output") output := bytes.NewBuffer(nil)
test.options.Output = output test.options.SetOutputWriter(output)
test.options.Debug = true test.options.Debug = true
test.options.GlobalOptions.Timeout = time.Minute test.options.GlobalOptions.Timeout = time.Minute
if test.options.Format == "" { if test.options.Format == "" {
@@ -1178,10 +1179,7 @@ Summary Report for compliance: my-custom-spec
return return
} }
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, test.want, output.String())
b, err := os.ReadFile(output)
require.NoError(t, err)
assert.Equal(t, test.want, string(b))
}) })
} }
} }

View File

@@ -59,11 +59,11 @@ func (r *Report) Failed() bool {
// Write writes the results in the give format // Write writes the results in the give format
func Write(rep *Report, opt flag.Options, fromCache bool) error { func Write(rep *Report, opt flag.Options, fromCache bool) error {
output, err := opt.OutputWriter() output, cleanup, err := opt.OutputWriter()
if err != nil { if err != nil {
return xerrors.Errorf("failed to create output file: %w", err) return xerrors.Errorf("failed to create output file: %w", err)
} }
defer output.Close() defer cleanup()
if opt.Compliance.Spec.ID != "" { if opt.Compliance.Spec.ID != "" {
return writeCompliance(rep, opt, output) return writeCompliance(rep, opt, output)
@@ -104,7 +104,7 @@ func Write(rep *Report, opt flag.Options, fromCache bool) error {
// ensure color/formatting is disabled for pipes/non-pty // ensure color/formatting is disabled for pipes/non-pty
var useANSI bool var useANSI bool
if opt.Output == "" { if output == os.Stdout {
if o, err := os.Stdout.Stat(); err == nil { if o, err := os.Stdout.Stat(); err == nil {
useANSI = (o.Mode() & os.ModeCharDevice) == os.ModeCharDevice useANSI = (o.Mode() & os.ModeCharDevice) == os.ModeCharDevice
} }

View File

@@ -1,8 +1,7 @@
package report package report
import ( import (
"os" "bytes"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -110,18 +109,15 @@ No problems detected.
tt.options.AWSOptions.Services, tt.options.AWSOptions.Services,
) )
output := filepath.Join(t.TempDir(), "output") output := bytes.NewBuffer(nil)
tt.options.Output = output tt.options.SetOutputWriter(output)
require.NoError(t, Write(report, tt.options, tt.fromCache)) require.NoError(t, Write(report, tt.options, tt.fromCache))
assert.Equal(t, "AWS", report.Provider) assert.Equal(t, "AWS", report.Provider)
assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID) assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID)
assert.Equal(t, tt.options.AWSOptions.Region, report.Region) assert.Equal(t, tt.options.AWSOptions.Region, report.Region)
assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope) assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope)
assert.Equal(t, tt.expected, output.String())
b, err := os.ReadFile(output)
require.NoError(t, err)
assert.Equal(t, tt.expected, string(b))
}) })
} }
} }

View File

@@ -1,8 +1,7 @@
package report package report
import ( import (
"os" "bytes"
"path/filepath"
"strings" "strings"
"testing" "testing"
@@ -69,18 +68,15 @@ See https://avd.aquasec.com/misconfig/avd-aws-9999
tt.options.AWSOptions.Services, tt.options.AWSOptions.Services,
) )
output := filepath.Join(t.TempDir(), "output") output := bytes.NewBuffer(nil)
tt.options.Output = output tt.options.SetOutputWriter(output)
require.NoError(t, Write(report, tt.options, tt.fromCache)) require.NoError(t, Write(report, tt.options, tt.fromCache))
b, err := os.ReadFile(output)
require.NoError(t, err)
assert.Equal(t, "AWS", report.Provider) assert.Equal(t, "AWS", report.Provider)
assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID) assert.Equal(t, tt.options.AWSOptions.Account, report.AccountID)
assert.Equal(t, tt.options.AWSOptions.Region, report.Region) assert.Equal(t, tt.options.AWSOptions.Region, report.Region)
assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope) assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope)
assert.Equal(t, tt.expected, strings.ReplaceAll(string(b), "\r\n", "\n")) assert.Equal(t, tt.expected, strings.ReplaceAll(output.String(), "\r\n", "\n"))
}) })
} }
} }

View File

@@ -1,8 +1,7 @@
package report package report
import ( import (
"os" "bytes"
"path/filepath"
"testing" "testing"
"github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/aws/aws-sdk-go-v2/aws/arn"
@@ -317,8 +316,8 @@ Scan Overview for AWS Account
tt.options.AWSOptions.Services, tt.options.AWSOptions.Services,
) )
output := filepath.Join(t.TempDir(), "output") output := bytes.NewBuffer(nil)
tt.options.Output = output tt.options.SetOutputWriter(output)
require.NoError(t, Write(report, tt.options, tt.fromCache)) require.NoError(t, Write(report, tt.options, tt.fromCache))
assert.Equal(t, "AWS", report.Provider) assert.Equal(t, "AWS", report.Provider)
@@ -326,13 +325,11 @@ Scan Overview for AWS Account
assert.Equal(t, tt.options.AWSOptions.Region, report.Region) assert.Equal(t, tt.options.AWSOptions.Region, report.Region)
assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope) assert.ElementsMatch(t, tt.options.AWSOptions.Services, report.ServicesInScope)
b, err := os.ReadFile(output)
require.NoError(t, err)
if tt.options.Format == "json" { if tt.options.Format == "json" {
// json output can be formatted/ordered differently - we just care that the data matches // json output can be formatted/ordered differently - we just care that the data matches
assert.JSONEq(t, tt.expected, string(b)) assert.JSONEq(t, tt.expected, output.String())
} else { } else {
assert.Equal(t, tt.expected, string(b)) assert.Equal(t, tt.expected, output.String())
} }
}) })
} }

View File

@@ -20,7 +20,6 @@ import (
"github.com/aquasecurity/trivy/pkg/result" "github.com/aquasecurity/trivy/pkg/result"
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
"github.com/aquasecurity/trivy/pkg/version" "github.com/aquasecurity/trivy/pkg/version"
xio "github.com/aquasecurity/trivy/pkg/x/io"
xstrings "github.com/aquasecurity/trivy/pkg/x/strings" xstrings "github.com/aquasecurity/trivy/pkg/x/strings"
) )
@@ -114,6 +113,10 @@ type Options struct {
// We don't want to allow disabled analyzers to be passed by users, but it is necessary for internal use. // We don't want to allow disabled analyzers to be passed by users, but it is necessary for internal use.
DisabledAnalyzers []analyzer.Type DisabledAnalyzers []analyzer.Type
// outputWriter is not initialized via the CLI.
// It is mainly used for testing purposes or by tools that use Trivy as a library.
outputWriter io.Writer
} }
// Align takes consistency of options // Align takes consistency of options
@@ -159,17 +162,26 @@ func (o *Options) FilterOpts() result.FilterOption {
} }
} }
// SetOutputWriter sets an output writer.
func (o *Options) SetOutputWriter(w io.Writer) {
o.outputWriter = w
}
// OutputWriter returns an output writer. // OutputWriter returns an output writer.
// If the output file is not specified, it returns os.Stdout. // If the output file is not specified, it returns os.Stdout.
func (o *Options) OutputWriter() (io.WriteCloser, error) { func (o *Options) OutputWriter() (io.Writer, func(), error) {
if o.outputWriter != nil {
return o.outputWriter, func() {}, nil
}
if o.Output != "" { if o.Output != "" {
f, err := os.Create(o.Output) f, err := os.Create(o.Output)
if err != nil { if err != nil {
return nil, xerrors.Errorf("failed to create output file: %w", err) return nil, nil, xerrors.Errorf("failed to create output file: %w", err)
} }
return f, nil return f, func() { _ = f.Close() }, nil
} }
return xio.NopCloser(os.Stdout), nil return os.Stdout, func() {}, nil
} }
func addFlag(cmd *cobra.Command, flag *Flag) { func addFlag(cmd *cobra.Command, flag *Flag) {

View File

@@ -95,11 +95,11 @@ func (r *runner) run(ctx context.Context, artifacts []*k8sArtifacts.Artifact) er
return xerrors.Errorf("k8s scan error: %w", err) return xerrors.Errorf("k8s scan error: %w", err)
} }
output, err := r.flagOpts.OutputWriter() output, cleanup, err := r.flagOpts.OutputWriter()
if err != nil { if err != nil {
return xerrors.Errorf("failed to create output file: %w", err) return xerrors.Errorf("failed to create output file: %w", err)
} }
defer output.Close() defer cleanup()
if r.flagOpts.Compliance.Spec.ID != "" { if r.flagOpts.Compliance.Spec.ID != "" {
var scanResults []types.Results var scanResults []types.Results

View File

@@ -15,7 +15,6 @@ import (
"github.com/aquasecurity/tml" "github.com/aquasecurity/tml"
dbTypes "github.com/aquasecurity/trivy-db/pkg/types" dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
"github.com/aquasecurity/trivy/pkg/types" "github.com/aquasecurity/trivy/pkg/types"
xio "github.com/aquasecurity/trivy/pkg/x/io"
) )
var ( var (
@@ -137,7 +136,7 @@ func IsOutputToTerminal(output io.Writer) bool {
return false return false
} }
if output != xio.NopCloser(os.Stdout) { if output != os.Stdout {
return false return false
} }
o, err := os.Stdout.Stat() o, err := os.Stdout.Stat()

View File

@@ -25,11 +25,11 @@ const (
// Write writes the result to output, format as passed in argument // Write writes the result to output, format as passed in argument
func Write(report types.Report, option flag.Options) error { func Write(report types.Report, option flag.Options) error {
output, err := option.OutputWriter() output, cleanup, err := option.OutputWriter()
if err != nil { if err != nil {
return xerrors.Errorf("failed to create a file: %w", err) return xerrors.Errorf("failed to create a file: %w", err)
} }
defer output.Close() defer cleanup()
// Compliance report // Compliance report
if option.Compliance.Spec.ID != "" { if option.Compliance.Spec.ID != "" {

View File

@@ -9,18 +9,6 @@ import (
dio "github.com/aquasecurity/go-dep-parser/pkg/io" dio "github.com/aquasecurity/go-dep-parser/pkg/io"
) )
// NopCloser returns a WriteCloser with a no-op Close method wrapping
// the provided Writer w.
func NopCloser(w io.Writer) io.WriteCloser {
return nopCloser{w}
}
type nopCloser struct {
io.Writer
}
func (nopCloser) Close() error { return nil }
func NewReadSeekerAt(r io.Reader) (dio.ReadSeekerAt, error) { func NewReadSeekerAt(r io.Reader) (dio.ReadSeekerAt, error) {
if rr, ok := r.(dio.ReadSeekerAt); ok { if rr, ok := r.(dio.ReadSeekerAt); ok {
return rr, nil return rr, nil