feat: add graceful shutdown with signal handling (#9242)

Signed-off-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
Teppei Fukuda
2025-07-24 15:05:27 +04:00
committed by GitHub
parent b5da1b8d61
commit 2c05882f45
6 changed files with 167 additions and 7 deletions

View File

@@ -41,6 +41,10 @@ func run() error {
return nil
}
// Set up signal handling for graceful shutdown
ctx, stop := commands.NotifyContext(context.Background())
defer stop()
app := commands.NewApp()
return app.Execute()
return app.ExecuteContext(ctx)
}

37
pkg/commands/signal.go Normal file
View File

@@ -0,0 +1,37 @@
package commands
import (
"context"
"os"
"os/signal"
"syscall"
"github.com/aquasecurity/trivy/pkg/log"
)
// NotifyContext returns a context that is canceled when SIGINT or SIGTERM is received.
// It also ensures cleanup of temporary files when the signal is received.
//
// When a signal is received, Trivy will attempt to gracefully shut down by canceling
// the context and waiting for all operations to complete. If users want to force an
// immediate exit, they can send a second SIGINT or SIGTERM signal.
func NotifyContext(parent context.Context) (context.Context, context.CancelFunc) {
ctx, stop := signal.NotifyContext(parent, os.Interrupt, syscall.SIGTERM)
// Start a goroutine to handle cleanup when context is done
go func() {
<-ctx.Done()
// Log that we're shutting down gracefully
log.Info("Received signal, attempting graceful shutdown...")
log.Info("Press Ctrl+C again to force exit")
// TODO: Add any necessary cleanup logic here
// Clean up signal handling
// After calling stop(), a second signal will cause immediate termination
stop()
}()
return ctx, stop
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/remote"
"github.com/aquasecurity/trivy/pkg/version/doc"
xio "github.com/aquasecurity/trivy/pkg/x/io"
)
const (
@@ -188,7 +189,7 @@ func (a *Artifact) download(ctx context.Context, layer v1.Layer, fileName, dir s
}()
// Download the layer content into a temporal file
if _, err = io.Copy(f, pr); err != nil {
if _, err = xio.Copy(ctx, f, pr); err != nil {
return xerrors.Errorf("copy error: %w", err)
}

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
"encoding/json"
"errors"
"net/http"
"os"
"strings"
@@ -62,20 +63,46 @@ func (s Server) ListenAndServe(ctx context.Context, serverCache cache.Cache, ski
requestWg := &sync.WaitGroup{}
dbUpdateWg := &sync.WaitGroup{}
server := &http.Server{
Addr: s.addr,
Handler: s.NewServeMux(ctx, serverCache, dbUpdateWg, requestWg),
ReadHeaderTimeout: 10 * time.Second,
}
// Start DB update worker
go func() {
worker := newDBWorker(db.NewClient(s.dbDir, true, db.WithDBRepository(s.dbRepositories)))
ticker := time.NewTicker(updateInterval)
defer ticker.Stop()
for {
time.Sleep(updateInterval)
if err := worker.update(ctx, s.appVersion, s.dbDir, skipDBUpdate, dbUpdateWg, requestWg, s.RegistryOptions); err != nil {
log.Errorf("%+v\n", err)
select {
case <-ctx.Done():
log.Debug("Server shutting down gracefully...")
// Give active requests time to complete
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
if err := server.Shutdown(shutdownCtx); err != nil {
log.Errorf("Server shutdown error: %v", err)
}
cancel()
return
case <-ticker.C:
if err := worker.update(ctx, s.appVersion, s.dbDir, skipDBUpdate, dbUpdateWg, requestWg, s.RegistryOptions); err != nil {
log.Errorf("%+v\n", err)
}
}
}
}()
mux := s.NewServeMux(ctx, serverCache, dbUpdateWg, requestWg)
log.Infof("Listening %s...", s.addr)
return http.ListenAndServe(s.addr, mux)
// This will block until Shutdown is called
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return xerrors.Errorf("server error: %w", err)
}
return nil
}
func (s Server) NewServeMux(ctx context.Context, serverCache cache.Cache, dbUpdateWg, requestWg *sync.WaitGroup) *http.ServeMux {

View File

@@ -2,6 +2,7 @@ package io
import (
"bytes"
"context"
"io"
"golang.org/x/xerrors"
@@ -71,3 +72,27 @@ type nopCloser struct {
}
func (nopCloser) Close() error { return nil }
// readerFunc is a function that implements io.Reader
type readerFunc func([]byte) (int, error)
func (f readerFunc) Read(p []byte) (int, error) {
return f(p)
}
// Copy copies from src to dst until either EOF is reached on src or the context is canceled.
// It returns the number of bytes copied and the first error encountered while copying, if any.
//
// Note: This implementation wraps the reader with a context check, which means it won't
// benefit from WriterTo optimization in io.Copy if the source implements it. This is a trade-off
// for being able to cancel the operation on context cancellation.
func Copy(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
return io.Copy(dst, readerFunc(func(p []byte) (int, error) {
select {
case <-ctx.Done():
return 0, ctx.Err()
default:
return src.Read(p)
}
}))
}

66
pkg/x/io/io_test.go Normal file
View File

@@ -0,0 +1,66 @@
package io
import (
"bytes"
"context"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCopy(t *testing.T) {
t.Run("successful copy", func(t *testing.T) {
ctx := t.Context()
src := strings.NewReader("hello world")
dst := &bytes.Buffer{}
n, err := Copy(ctx, dst, src)
require.NoError(t, err)
assert.Equal(t, int64(11), n)
assert.Equal(t, "hello world", dst.String())
})
t.Run("context canceled before read", func(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())
cancel() // Cancel immediately
src := strings.NewReader("hello world")
dst := &bytes.Buffer{}
n, err := Copy(ctx, dst, src)
require.ErrorIs(t, err, context.Canceled)
assert.Equal(t, int64(0), n)
assert.Empty(t, dst.String())
})
t.Run("context canceled during read", func(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())
// Create a reader that will be canceled after first read
reader := &dummyReader{
cancel: cancel, // Cancel after first read
}
dst := &bytes.Buffer{}
n, err := Copy(ctx, dst, reader)
require.ErrorIs(t, err, context.Canceled)
// Should have written first chunk before cancellation
assert.Equal(t, int64(5), n)
assert.Equal(t, "dummy", dst.String())
})
}
// dummyReader returns the same data on every Read call
type dummyReader struct {
cancel context.CancelFunc
}
func (r *dummyReader) Read(p []byte) (int, error) {
n := copy(p, "dummy")
if r.cancel != nil {
r.cancel() // Simulate cancellation after first read
}
return n, nil
}