From 2c05882f45071928c14d8212ef6c4f0f7048245d Mon Sep 17 00:00:00 2001 From: Teppei Fukuda Date: Thu, 24 Jul 2025 15:05:27 +0400 Subject: [PATCH] feat: add graceful shutdown with signal handling (#9242) Signed-off-by: knqyf263 --- cmd/trivy/main.go | 6 +++- pkg/commands/signal.go | 37 ++++++++++++++++++++++ pkg/oci/artifact.go | 3 +- pkg/rpc/server/listen.go | 37 +++++++++++++++++++--- pkg/x/io/io.go | 25 +++++++++++++++ pkg/x/io/io_test.go | 66 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 167 insertions(+), 7 deletions(-) create mode 100644 pkg/commands/signal.go create mode 100644 pkg/x/io/io_test.go diff --git a/cmd/trivy/main.go b/cmd/trivy/main.go index 80d7ab8f71..db01bfa042 100644 --- a/cmd/trivy/main.go +++ b/cmd/trivy/main.go @@ -41,6 +41,10 @@ func run() error { return nil } + // Set up signal handling for graceful shutdown + ctx, stop := commands.NotifyContext(context.Background()) + defer stop() + app := commands.NewApp() - return app.Execute() + return app.ExecuteContext(ctx) } diff --git a/pkg/commands/signal.go b/pkg/commands/signal.go new file mode 100644 index 0000000000..52e0e762fc --- /dev/null +++ b/pkg/commands/signal.go @@ -0,0 +1,37 @@ +package commands + +import ( + "context" + "os" + "os/signal" + "syscall" + + "github.com/aquasecurity/trivy/pkg/log" +) + +// NotifyContext returns a context that is canceled when SIGINT or SIGTERM is received. +// It does not clean up temporary files itself yet; see the TODO in the goroutine below. +// +// When a signal is received, Trivy will attempt to gracefully shut down by canceling +// the context and waiting for all operations to complete. If users want to force an +// immediate exit, they can send a second SIGINT or SIGTERM signal. 
+func NotifyContext(parent context.Context) (context.Context, context.CancelFunc) { + ctx, stop := signal.NotifyContext(parent, os.Interrupt, syscall.SIGTERM) + + // Start a goroutine to handle cleanup when context is done + go func() { + <-ctx.Done() + + // Log that we're shutting down gracefully + log.Info("Received signal, attempting graceful shutdown...") + log.Info("Press Ctrl+C again to force exit") + + // TODO: Add any necessary cleanup logic here + + // Clean up signal handling + // After calling stop(), a second signal will cause immediate termination + stop() + }() + + return ctx, stop +} diff --git a/pkg/oci/artifact.go b/pkg/oci/artifact.go index 1abb2aa99e..09e9ff8d68 100644 --- a/pkg/oci/artifact.go +++ b/pkg/oci/artifact.go @@ -21,6 +21,7 @@ import ( "github.com/aquasecurity/trivy/pkg/log" "github.com/aquasecurity/trivy/pkg/remote" "github.com/aquasecurity/trivy/pkg/version/doc" + xio "github.com/aquasecurity/trivy/pkg/x/io" ) const ( @@ -188,7 +189,7 @@ func (a *Artifact) download(ctx context.Context, layer v1.Layer, fileName, dir s }() // Download the layer content into a temporal file - if _, err = io.Copy(f, pr); err != nil { + if _, err = xio.Copy(ctx, f, pr); err != nil { return xerrors.Errorf("copy error: %w", err) } diff --git a/pkg/rpc/server/listen.go b/pkg/rpc/server/listen.go index 37ac101018..10e6d814ff 100644 --- a/pkg/rpc/server/listen.go +++ b/pkg/rpc/server/listen.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "errors" "net/http" "os" "strings" @@ -62,20 +63,46 @@ func (s Server) ListenAndServe(ctx context.Context, serverCache cache.Cache, ski requestWg := &sync.WaitGroup{} dbUpdateWg := &sync.WaitGroup{} + server := &http.Server{ + Addr: s.addr, + Handler: s.NewServeMux(ctx, serverCache, dbUpdateWg, requestWg), + ReadHeaderTimeout: 10 * time.Second, + } + + // Start DB update worker go func() { worker := newDBWorker(db.NewClient(s.dbDir, true, db.WithDBRepository(s.dbRepositories))) + ticker := 
time.NewTicker(updateInterval) + defer ticker.Stop() + for { - time.Sleep(updateInterval) - if err := worker.update(ctx, s.appVersion, s.dbDir, skipDBUpdate, dbUpdateWg, requestWg, s.RegistryOptions); err != nil { - log.Errorf("%+v\n", err) + select { + case <-ctx.Done(): + log.Debug("Server shutting down gracefully...") + + // Give active requests time to complete + shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + if err := server.Shutdown(shutdownCtx); err != nil { + log.Errorf("Server shutdown error: %v", err) + } + cancel() + return + case <-ticker.C: + if err := worker.update(ctx, s.appVersion, s.dbDir, skipDBUpdate, dbUpdateWg, requestWg, s.RegistryOptions); err != nil { + log.Errorf("%+v\n", err) + } } } }() - mux := s.NewServeMux(ctx, serverCache, dbUpdateWg, requestWg) log.Infof("Listening %s...", s.addr) - return http.ListenAndServe(s.addr, mux) + // This will block until Shutdown is called + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return xerrors.Errorf("server error: %w", err) + } + + return nil } func (s Server) NewServeMux(ctx context.Context, serverCache cache.Cache, dbUpdateWg, requestWg *sync.WaitGroup) *http.ServeMux { diff --git a/pkg/x/io/io.go b/pkg/x/io/io.go index 1fe3d6c0c2..8612e53b8b 100644 --- a/pkg/x/io/io.go +++ b/pkg/x/io/io.go @@ -2,6 +2,7 @@ package io import ( "bytes" + "context" "io" "golang.org/x/xerrors" @@ -71,3 +72,27 @@ type nopCloser struct { } func (nopCloser) Close() error { return nil } + +// readerFunc is a function that implements io.Reader +type readerFunc func([]byte) (int, error) + +func (f readerFunc) Read(p []byte) (int, error) { + return f(p) +} + +// Copy copies from src to dst until either EOF is reached on src or the context is canceled. +// It returns the number of bytes copied and the first error encountered while copying, if any. 
+// +// Note: This implementation wraps the reader with a context check, which means it won't +// benefit from WriterTo optimization in io.Copy if the source implements it. This is a trade-off +// for being able to cancel the operation on context cancellation. +func Copy(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) { + return io.Copy(dst, readerFunc(func(p []byte) (int, error) { + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + return src.Read(p) + } + })) +} diff --git a/pkg/x/io/io_test.go b/pkg/x/io/io_test.go new file mode 100644 index 0000000000..b9c5b7f672 --- /dev/null +++ b/pkg/x/io/io_test.go @@ -0,0 +1,66 @@ +package io + +import ( + "bytes" + "context" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCopy(t *testing.T) { + t.Run("successful copy", func(t *testing.T) { + ctx := t.Context() + src := strings.NewReader("hello world") + dst := &bytes.Buffer{} + + n, err := Copy(ctx, dst, src) + require.NoError(t, err) + assert.Equal(t, int64(11), n) + assert.Equal(t, "hello world", dst.String()) + }) + + t.Run("context canceled before read", func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + cancel() // Cancel immediately + + src := strings.NewReader("hello world") + dst := &bytes.Buffer{} + + n, err := Copy(ctx, dst, src) + require.ErrorIs(t, err, context.Canceled) + assert.Equal(t, int64(0), n) + assert.Empty(t, dst.String()) + }) + + t.Run("context canceled during read", func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + + // Create a reader that will be canceled after first read + reader := &dummyReader{ + cancel: cancel, // Cancel after first read + } + dst := &bytes.Buffer{} + + n, err := Copy(ctx, dst, reader) + require.ErrorIs(t, err, context.Canceled) + // Should have written first chunk before cancellation + assert.Equal(t, int64(5), n) + assert.Equal(t, "dummy", dst.String()) + }) +} + +// 
dummyReader returns the same data on every Read call +type dummyReader struct { + cancel context.CancelFunc +} + +func (r *dummyReader) Read(p []byte) (int, error) { + n := copy(p, "dummy") + if r.cancel != nil { + r.cancel() // Simulate cancellation after first read + } + return n, nil +}