mirror of
https://github.com/aquasecurity/trivy.git
synced 2025-12-12 15:50:15 -08:00
feat: add graceful shutdown with signal handling (#9242)
Signed-off-by: knqyf263 <knqyf263@gmail.com>
This commit is contained in:
@@ -41,6 +41,10 @@ func run() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set up signal handling for graceful shutdown
|
||||||
|
ctx, stop := commands.NotifyContext(context.Background())
|
||||||
|
defer stop()
|
||||||
|
|
||||||
app := commands.NewApp()
|
app := commands.NewApp()
|
||||||
return app.Execute()
|
return app.ExecuteContext(ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
37
pkg/commands/signal.go
Normal file
37
pkg/commands/signal.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package commands
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/aquasecurity/trivy/pkg/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NotifyContext returns a context that is canceled when SIGINT or SIGTERM is received.
|
||||||
|
// It also ensures cleanup of temporary files when the signal is received.
|
||||||
|
//
|
||||||
|
// When a signal is received, Trivy will attempt to gracefully shut down by canceling
|
||||||
|
// the context and waiting for all operations to complete. If users want to force an
|
||||||
|
// immediate exit, they can send a second SIGINT or SIGTERM signal.
|
||||||
|
func NotifyContext(parent context.Context) (context.Context, context.CancelFunc) {
|
||||||
|
ctx, stop := signal.NotifyContext(parent, os.Interrupt, syscall.SIGTERM)
|
||||||
|
|
||||||
|
// Start a goroutine to handle cleanup when context is done
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
// Log that we're shutting down gracefully
|
||||||
|
log.Info("Received signal, attempting graceful shutdown...")
|
||||||
|
log.Info("Press Ctrl+C again to force exit")
|
||||||
|
|
||||||
|
// TODO: Add any necessary cleanup logic here
|
||||||
|
|
||||||
|
// Clean up signal handling
|
||||||
|
// After calling stop(), a second signal will cause immediate termination
|
||||||
|
stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ctx, stop
|
||||||
|
}
|
||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"github.com/aquasecurity/trivy/pkg/log"
|
"github.com/aquasecurity/trivy/pkg/log"
|
||||||
"github.com/aquasecurity/trivy/pkg/remote"
|
"github.com/aquasecurity/trivy/pkg/remote"
|
||||||
"github.com/aquasecurity/trivy/pkg/version/doc"
|
"github.com/aquasecurity/trivy/pkg/version/doc"
|
||||||
|
xio "github.com/aquasecurity/trivy/pkg/x/io"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -188,7 +189,7 @@ func (a *Artifact) download(ctx context.Context, layer v1.Layer, fileName, dir s
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Download the layer content into a temporal file
|
// Download the layer content into a temporal file
|
||||||
if _, err = io.Copy(f, pr); err != nil {
|
if _, err = xio.Copy(ctx, f, pr); err != nil {
|
||||||
return xerrors.Errorf("copy error: %w", err)
|
return xerrors.Errorf("copy error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -62,20 +63,46 @@ func (s Server) ListenAndServe(ctx context.Context, serverCache cache.Cache, ski
|
|||||||
requestWg := &sync.WaitGroup{}
|
requestWg := &sync.WaitGroup{}
|
||||||
dbUpdateWg := &sync.WaitGroup{}
|
dbUpdateWg := &sync.WaitGroup{}
|
||||||
|
|
||||||
|
server := &http.Server{
|
||||||
|
Addr: s.addr,
|
||||||
|
Handler: s.NewServeMux(ctx, serverCache, dbUpdateWg, requestWg),
|
||||||
|
ReadHeaderTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start DB update worker
|
||||||
go func() {
|
go func() {
|
||||||
worker := newDBWorker(db.NewClient(s.dbDir, true, db.WithDBRepository(s.dbRepositories)))
|
worker := newDBWorker(db.NewClient(s.dbDir, true, db.WithDBRepository(s.dbRepositories)))
|
||||||
|
ticker := time.NewTicker(updateInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
time.Sleep(updateInterval)
|
select {
|
||||||
if err := worker.update(ctx, s.appVersion, s.dbDir, skipDBUpdate, dbUpdateWg, requestWg, s.RegistryOptions); err != nil {
|
case <-ctx.Done():
|
||||||
log.Errorf("%+v\n", err)
|
log.Debug("Server shutting down gracefully...")
|
||||||
|
|
||||||
|
// Give active requests time to complete
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||||
|
log.Errorf("Server shutdown error: %v", err)
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := worker.update(ctx, s.appVersion, s.dbDir, skipDBUpdate, dbUpdateWg, requestWg, s.RegistryOptions); err != nil {
|
||||||
|
log.Errorf("%+v\n", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
mux := s.NewServeMux(ctx, serverCache, dbUpdateWg, requestWg)
|
|
||||||
log.Infof("Listening %s...", s.addr)
|
log.Infof("Listening %s...", s.addr)
|
||||||
|
|
||||||
return http.ListenAndServe(s.addr, mux)
|
// This will block until Shutdown is called
|
||||||
|
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return xerrors.Errorf("server error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s Server) NewServeMux(ctx context.Context, serverCache cache.Cache, dbUpdateWg, requestWg *sync.WaitGroup) *http.ServeMux {
|
func (s Server) NewServeMux(ctx context.Context, serverCache cache.Cache, dbUpdateWg, requestWg *sync.WaitGroup) *http.ServeMux {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package io
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
@@ -71,3 +72,27 @@ type nopCloser struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (nopCloser) Close() error { return nil }
|
func (nopCloser) Close() error { return nil }
|
||||||
|
|
||||||
|
// readerFunc is a function that implements io.Reader
|
||||||
|
type readerFunc func([]byte) (int, error)
|
||||||
|
|
||||||
|
func (f readerFunc) Read(p []byte) (int, error) {
|
||||||
|
return f(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy copies from src to dst until either EOF is reached on src or the context is canceled.
|
||||||
|
// It returns the number of bytes copied and the first error encountered while copying, if any.
|
||||||
|
//
|
||||||
|
// Note: This implementation wraps the reader with a context check, which means it won't
|
||||||
|
// benefit from WriterTo optimization in io.Copy if the source implements it. This is a trade-off
|
||||||
|
// for being able to cancel the operation on context cancellation.
|
||||||
|
func Copy(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
|
||||||
|
return io.Copy(dst, readerFunc(func(p []byte) (int, error) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return 0, ctx.Err()
|
||||||
|
default:
|
||||||
|
return src.Read(p)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|||||||
66
pkg/x/io/io_test.go
Normal file
66
pkg/x/io/io_test.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package io
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCopy(t *testing.T) {
|
||||||
|
t.Run("successful copy", func(t *testing.T) {
|
||||||
|
ctx := t.Context()
|
||||||
|
src := strings.NewReader("hello world")
|
||||||
|
dst := &bytes.Buffer{}
|
||||||
|
|
||||||
|
n, err := Copy(ctx, dst, src)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(11), n)
|
||||||
|
assert.Equal(t, "hello world", dst.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context canceled before read", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
cancel() // Cancel immediately
|
||||||
|
|
||||||
|
src := strings.NewReader("hello world")
|
||||||
|
dst := &bytes.Buffer{}
|
||||||
|
|
||||||
|
n, err := Copy(ctx, dst, src)
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
assert.Equal(t, int64(0), n)
|
||||||
|
assert.Empty(t, dst.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context canceled during read", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
|
||||||
|
// Create a reader that will be canceled after first read
|
||||||
|
reader := &dummyReader{
|
||||||
|
cancel: cancel, // Cancel after first read
|
||||||
|
}
|
||||||
|
dst := &bytes.Buffer{}
|
||||||
|
|
||||||
|
n, err := Copy(ctx, dst, reader)
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
// Should have written first chunk before cancellation
|
||||||
|
assert.Equal(t, int64(5), n)
|
||||||
|
assert.Equal(t, "dummy", dst.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// dummyReader returns the same data on every Read call
|
||||||
|
type dummyReader struct {
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dummyReader) Read(p []byte) (int, error) {
|
||||||
|
n := copy(p, "dummy")
|
||||||
|
if r.cancel != nil {
|
||||||
|
r.cancel() // Simulate cancellation after first read
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user