Skip to content

Commit 87f6128

Browse files
committed
Graceful shutdown
1 parent 5adf06b commit 87f6128

File tree

5 files changed

+217
-108
lines changed

5 files changed

+217
-108
lines changed

cmd/localstack/main.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"strconv"
1313
"strings"
1414
"sync"
15+
"time"
1516

1617
"github.com/aws/aws-sdk-go-v2/config"
1718
"github.com/localstack/lambda-runtime-init/internal/aws/lambda"
@@ -201,7 +202,7 @@ func main() {
201202
eventsListener := events.NewLocalStackEventsAPI(lsClient)
202203

203204
defaultSupv := supv.NewLocalSupervisor()
204-
wrappedSupv := supervisor.NewLocalStackSupervisor(ctx, defaultSupv, eventsListener, interopServer.InternalState)
205+
localStackSupv := supervisor.NewLocalStackSupervisor(ctx, defaultSupv, eventsListener)
205206

206207
// build sandbox
207208
exitChan := make(chan struct{})
@@ -219,11 +220,8 @@ func main() {
219220
SetLogsEgressAPI(localStackLogsEgressApi).
220221
SetTracer(tracer).
221222
SetInteropServer(interopServer).
222-
SetSupervisor(wrappedSupv).
223+
SetSupervisor(localStackSupv).
223224
SetHandler(handler)
224-
sandbox.AddShutdownFunc(func() {
225-
exitChan <- struct{}{}
226-
})
227225

228226
// Start daemons
229227

@@ -248,7 +246,9 @@ func main() {
248246
interopServer.SetSandboxContext(sandboxContext)
249247
interopServer.SetInternalStateGetter(internalStateFn)
250248

251-
localStackService := server.NewLocalStackService(interopServer, logCollector, lsClient, xrayConfig.Endpoint, lsOpts, functionConf, awsEnvConf)
249+
localStackService := server.NewLocalStackService(
250+
interopServer, logCollector, lsClient, localStackSupv, xrayConfig.Endpoint, lsOpts, functionConf, awsEnvConf,
251+
)
252252

253253
// start runtime init. It is important to start `InitHandler` synchronously because we need to ensure the
254254
// notification channels and status fields are properly initialized before `AwaitInitialized`
@@ -297,4 +297,11 @@ func main() {
297297
case <-exitChan:
298298
}
299299

300+
gracefulCtx, cancel := context.WithTimeout(ctx, time.Millisecond*500)
301+
defer cancel()
302+
303+
if err := localStackService.AwaitCompleted(gracefulCtx); err != nil {
304+
log.Warnf("Did not gracefully complete: %w", err)
305+
}
306+
300307
}

internal/server/handler.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,8 @@ func InvokeHandler(api *LocalStackService) http.HandlerFunc {
4141
log.WithError(err).Error("Failed to decode invoke request")
4242
}
4343

44-
response, err := api.InvokeForward(req)
44+
response, err := api.InvokeForward(r.Context(), req)
4545
switch {
46-
// case errors.Is(err, rapidcore.ErrInvokeDoneFailed) || err == nil:
47-
// we can actually just continue here, error message is sent below
4846
case errors.Is(err, rapidcore.ErrInvokeTimeout):
4947
log.Debugf("Got invoke timeout")
5048
errorResponse := localstack.ErrorResponse{

internal/server/interop.go

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,63 +3,68 @@ package server
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"net/http"
8-
"sync"
99

1010
"github.com/aws/aws-sdk-go/aws"
1111
"github.com/localstack/lambda-runtime-init/internal/localstack"
1212
log "github.com/sirupsen/logrus"
1313
"go.amzn.com/lambda/core/directinvoke"
14+
"go.amzn.com/lambda/core/statejson"
1415
"go.amzn.com/lambda/interop"
1516
"go.amzn.com/lambda/metering"
1617
"go.amzn.com/lambda/rapi/model"
1718
"go.amzn.com/lambda/rapidcore"
1819
"golang.org/x/sync/errgroup"
1920
)
2021

21-
type CustomInteropServer struct {
22+
type LocalStackInteropsServer struct {
2223
*rapidcore.Server
2324
localStackAdapter *localstack.LocalStackClient
24-
mutex *sync.Mutex
2525
}
2626

27-
func NewInteropServer(server *rapidcore.Server, ls *localstack.LocalStackClient) *CustomInteropServer {
28-
return &CustomInteropServer{
27+
func NewInteropServer(server *rapidcore.Server, ls *localstack.LocalStackClient) *LocalStackInteropsServer {
28+
return &LocalStackInteropsServer{
2929
Server: server,
3030
localStackAdapter: ls,
31-
mutex: &sync.Mutex{},
3231
}
3332
}
3433

35-
func (c *CustomInteropServer) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error {
34+
func (c *LocalStackInteropsServer) Execute(ctx context.Context, responseWriter http.ResponseWriter, invoke *interop.Invoke) error {
3635
ctx, cancel := context.WithTimeout(context.Background(), c.Server.GetInvokeTimeout())
3736
defer cancel()
3837

39-
if err := c.reserveForInvoke(ctx, invoke); err != nil {
38+
if err := c.reserve(ctx, invoke); err != nil {
4039
return err
4140
}
4241

43-
return c.executeInvoke(ctx, responseWriter, invoke)
42+
if err := c.executeInvoke(ctx, responseWriter, invoke); err != nil {
43+
return err
44+
}
45+
46+
return nil
4447
}
4548

46-
func (c *CustomInteropServer) executeInvoke(ctx context.Context, responseWriter http.ResponseWriter, invoke *interop.Invoke) error {
49+
func (c *LocalStackInteropsServer) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error {
50+
return c.Execute(context.Background(), responseWriter, invoke)
51+
}
52+
53+
func (c *LocalStackInteropsServer) executeInvoke(ctx context.Context, responseWriter http.ResponseWriter, invoke *interop.Invoke) error {
4754
g, gCtx := errgroup.WithContext(ctx)
4855

4956
g.Go(func() error {
5057
isDirect := directinvoke.MaxDirectResponseSize > interop.MaxPayloadSize
51-
if err := c.Server.FastInvoke(responseWriter, invoke, isDirect); err != nil {
52-
log.Debugf("FastInvoke() error: %s", err)
58+
err := c.Server.FastInvoke(responseWriter, invoke, isDirect)
59+
if err != nil {
60+
log.WithError(err).Debug("FastInvoke() failed")
5361
}
54-
return nil
62+
return err
5563
})
5664

5765
g.Go(func() error {
58-
_, err := c.Server.AwaitRelease()
59-
if err != nil {
60-
return c.handleReleaseError(err)
61-
}
62-
return nil
66+
_, err := c.AwaitRelease()
67+
return err
6368
})
6469

6570
done := make(chan error, 1)
@@ -71,11 +76,17 @@ func (c *CustomInteropServer) executeInvoke(ctx context.Context, responseWriter
7176
case err := <-done:
7277
return err
7378
case <-gCtx.Done():
74-
return c.handleTimeout()
79+
if errors.Is(gCtx.Err(), context.DeadlineExceeded) {
80+
if _, resetErr := c.Server.Reset("Timeout", 2000); resetErr != nil {
81+
log.WithError(resetErr).Errorf("Reset failed")
82+
}
83+
return rapidcore.ErrInvokeTimeout
84+
}
85+
return nil
7586
}
7687
}
7788

78-
func (c *CustomInteropServer) reserveForInvoke(ctx context.Context, invoke *interop.Invoke) error {
89+
func (c *LocalStackInteropsServer) reserve(ctx context.Context, invoke *interop.Invoke) error {
7990
reserveResp, err := c.Server.Reserve(invoke.ID, invoke.TraceID, invoke.LambdaSegmentID)
8091
if err != nil {
8192
return err
@@ -89,12 +100,18 @@ func (c *CustomInteropServer) reserveForInvoke(ctx context.Context, invoke *inte
89100
switch err {
90101
case rapidcore.ErrInitDoneFailed:
91102
if _, resetErr := c.Server.Reset("InitFailed", 2000); resetErr != nil {
92-
log.Errorf("Reset failed: %v", resetErr)
103+
log.WithError(resetErr).Debug("Reset failed")
93104
}
94105

95-
if _, reserveErr := c.Server.Reserve(invoke.ID, invoke.TraceID, invoke.LambdaSegmentID); reserveErr != nil {
96-
return reserveErr
106+
if _, err := c.Server.Reserve(invoke.ID, invoke.TraceID, invoke.LambdaSegmentID); err != nil {
107+
return err
97108
}
109+
110+
// If the original INIT failed, let's do another wait since we've triggered a RESERVE
111+
if err := c.Server.AwaitInitialized(); err != nil {
112+
return err
113+
}
114+
98115
return nil
99116
default:
100117
return err
@@ -104,31 +121,25 @@ func (c *CustomInteropServer) reserveForInvoke(ctx context.Context, invoke *inte
104121
return nil
105122
}
106123

107-
func (c *CustomInteropServer) handleReleaseError(err error) error {
124+
func (c *LocalStackInteropsServer) AwaitRelease() (*statejson.ReleaseResponse, error) {
125+
resp, err := c.Server.AwaitRelease()
108126
switch err {
109-
case rapidcore.ErrReleaseReservationDone:
110-
return nil
127+
case rapidcore.ErrReleaseReservationDone, nil:
128+
return resp, nil
111129
case rapidcore.ErrInitDoneFailed, rapidcore.ErrInvokeDoneFailed:
112130
if _, resetErr := c.Server.Reset("ReleaseFail", 2000); resetErr != nil {
113131
log.Errorf("Reset failed: %v", resetErr)
114132
}
115-
return err
133+
return nil, err
116134
default:
117135
if _, resetErr := c.Server.Reset("UnexpectedError", 2000); resetErr != nil {
118136
log.Errorf("Reset failed: %v", resetErr)
119137
}
120-
return err
121-
}
122-
}
123-
124-
func (c *CustomInteropServer) handleTimeout() error {
125-
if _, resetErr := c.Server.Reset("Timeout", 2000); resetErr != nil {
126-
log.Errorf("Reset failed: %v", resetErr)
138+
return nil, err
127139
}
128-
return rapidcore.ErrInvokeTimeout
129140
}
130141

131-
func (c *CustomInteropServer) SendInitErrorResponse(resp *interop.ErrorInvokeResponse) error {
142+
func (c *LocalStackInteropsServer) SendInitErrorResponse(resp *interop.ErrorInvokeResponse) error {
132143
errResp := &model.ErrorResponse{}
133144
err := json.Unmarshal(resp.Payload, errResp)
134145
if err != nil {

0 commit comments

Comments
 (0)