diff --git a/cmd/aws-lambda-rie/handlers.go b/cmd/aws-lambda-rie/handlers.go deleted file mode 100644 index 2cca12d..0000000 --- a/cmd/aws-lambda-rie/handlers.go +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package main - -import ( - "bytes" - "encoding/base64" - "fmt" - "io/ioutil" - "math" - "net/http" - "os" - "strconv" - "strings" - "time" - - "go.amzn.com/lambda/core/statejson" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore" - "go.amzn.com/lambda/rapidcore/env" - - "github.com/google/uuid" - - log "github.com/sirupsen/logrus" -) - -type Sandbox interface { - Init(i *interop.Init, invokeTimeoutMs int64) - Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error -} - -type InteropServer interface { - Init(i *interop.Init, invokeTimeoutMs int64) error - AwaitInitialized() error - FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error - Reserve(id string, traceID, lambdaSegmentID string) (*rapidcore.ReserveResponse, error) - Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) - AwaitRelease() (*statejson.InternalStateDescription, error) - Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription - InternalState() (*statejson.InternalStateDescription, error) - CurrentToken() *interop.Token - Restore(restore *interop.Restore) error -} - -var initDone bool - -func GetenvWithDefault(key string, defaultValue string) string { - envValue := os.Getenv(key) - - if envValue == "" { - return defaultValue - } - - return envValue -} - -func printEndReports(invokeId string, initDuration string, memorySize string, invokeStart time.Time, timeoutDuration time.Duration) { - // Calcuation invoke duration - invokeDuration := math.Min(float64(time.Now().Sub(invokeStart).Nanoseconds()), - float64(timeoutDuration.Nanoseconds())) / float64(time.Millisecond) - - fmt.Println("END RequestId: " + invokeId) - // We set the Max Memory Used and Memory Size to be the same (whatever it is set to) since there is - // not a clean way to get this information from rapidcore - fmt.Printf( - "REPORT RequestId: %s\t"+ - initDuration+ - "Duration: %.2f ms\t"+ - "Billed Duration: %.f ms\t"+ - "Memory Size: %s MB\t"+ - "Max Memory Used: %s MB\t\n", - invokeId, invokeDuration, math.Ceil(invokeDuration), memorySize, memorySize) -} - -func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox, bs interop.Bootstrap) { - log.Debugf("invoke: -> %s %s %v", r.Method, r.URL, r.Header) - bodyBytes, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Errorf("Failed to read invoke body: %s", err) - w.WriteHeader(500) - return - } - - rawClientContext, err := base64.StdEncoding.DecodeString(r.Header.Get("X-Amz-Client-Context")) - if err != nil { - log.Errorf("Failed to decode X-Amz-Client-Context: %s", err) - w.WriteHeader(500) - return - } - - initDuration := "" - inv := GetenvWithDefault("AWS_LAMBDA_FUNCTION_TIMEOUT", "300") - timeoutDuration, _ := time.ParseDuration(inv + "s") - // Default - timeout, err := strconv.ParseInt(inv, 10, 64) - if err != nil { - panic(err) - } - - functionVersion := GetenvWithDefault("AWS_LAMBDA_FUNCTION_VERSION", "$LATEST") - memorySize := GetenvWithDefault("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "3008") - - if !initDone { - - initStart, initEnd := InitHandler(sandbox, functionVersion, timeout, bs) - - // Calculate InitDuration - initTimeMS := math.Min(float64(initEnd.Sub(initStart).Nanoseconds()), - float64(timeoutDuration.Nanoseconds())) / float64(time.Millisecond) - - initDuration = fmt.Sprintf("Init Duration: %.2f ms\t", initTimeMS) - - // Set initDone so next invokes do not try to Init the function again - initDone = true - } - - invokeStart := time.Now() - invokePayload := &interop.Invoke{ - ID: uuid.New().String(), - InvokedFunctionArn: fmt.Sprintf("arn:aws:lambda:us-east-1:012345678912:function:%s", GetenvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function")), - TraceID: r.Header.Get("X-Amzn-Trace-Id"), - LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - Payload: bytes.NewReader(bodyBytes), - ClientContext: string(rawClientContext), - } - fmt.Println("START RequestId: " + invokePayload.ID + " Version: " + functionVersion) - - // If we write to 'w' directly and waitUntilRelease fails, we won't be able to propagate error anymore - invokeResp := &ResponseWriterProxy{} - if err := sandbox.Invoke(invokeResp, invokePayload); err != nil { - switch err { - - // Reserve errors: - case rapidcore.ErrAlreadyReserved: - log.Errorf("Failed to reserve: %s", err) - w.WriteHeader(http.StatusBadRequest) - return - case rapidcore.ErrInternalServerError: - w.WriteHeader(http.StatusInternalServerError) - return - case rapidcore.ErrInitDoneFailed: - w.WriteHeader(http.StatusBadGateway) - w.Write(invokeResp.Body) - return - case rapidcore.ErrReserveReservationDone: - // TODO use http.StatusBadGateway - w.WriteHeader(http.StatusGatewayTimeout) - return - - // Invoke errors: - case rapidcore.ErrNotReserved: - case rapidcore.ErrAlreadyReplied: - case rapidcore.ErrAlreadyInvocating: - log.Errorf("Failed to set reply stream: %s", err) - w.WriteHeader(http.StatusBadRequest) - return - case rapidcore.ErrInvokeReservationDone: - // TODO use http.StatusBadGateway - w.WriteHeader(http.StatusGatewayTimeout) - return - case rapidcore.ErrInvokeResponseAlreadyWritten: - return - // AwaitRelease errors: - case rapidcore.ErrInvokeDoneFailed: - w.WriteHeader(http.StatusBadGateway) - w.Write(invokeResp.Body) - return - case rapidcore.ErrReleaseReservationDone: - // TODO return sandbox status when we implement async reset handling - // TODO use http.StatusOK - w.WriteHeader(http.StatusGatewayTimeout) - return - case rapidcore.ErrInvokeTimeout: - printEndReports(invokePayload.ID, initDuration, memorySize, invokeStart, timeoutDuration) - - w.Write([]byte(fmt.Sprintf("Task timed out after %d.00 seconds", timeout))) - time.Sleep(100 * time.Millisecond) - //initDone = false - return - } - } - - printEndReports(invokePayload.ID, initDuration, memorySize, invokeStart, timeoutDuration) - - if invokeResp.StatusCode != 0 { - w.WriteHeader(invokeResp.StatusCode) - } - w.Write(invokeResp.Body) -} - -func InitHandler(sandbox Sandbox, functionVersion string, timeout int64, bs interop.Bootstrap) (time.Time, time.Time) { - additionalFunctionEnvironmentVariables := map[string]string{} - - // Add default Env Vars if they were not defined. This is a required otherwise 1p Python2.7, Python3.6, and - // possibly others pre runtime API runtimes will fail. This will be overwritten if they are defined on the system. - additionalFunctionEnvironmentVariables["AWS_LAMBDA_LOG_GROUP_NAME"] = "/aws/lambda/Functions" - additionalFunctionEnvironmentVariables["AWS_LAMBDA_LOG_STREAM_NAME"] = "$LATEST" - additionalFunctionEnvironmentVariables["AWS_LAMBDA_FUNCTION_VERSION"] = "$LATEST" - additionalFunctionEnvironmentVariables["AWS_LAMBDA_FUNCTION_MEMORY_SIZE"] = "3008" - additionalFunctionEnvironmentVariables["AWS_LAMBDA_FUNCTION_NAME"] = "test_function" - - // Forward Env Vars from the running system (container) to what the function can view. Without this, Env Vars will - // not be viewable when the function runs. - for _, env := range os.Environ() { - // Split the env into by the first "=". This will account for if the env var's value has a '=' in it - envVar := strings.SplitN(env, "=", 2) - additionalFunctionEnvironmentVariables[envVar[0]] = envVar[1] - } - - initStart := time.Now() - // pass to rapid - sandbox.Init(&interop.Init{ - Handler: GetenvWithDefault("AWS_LAMBDA_FUNCTION_HANDLER", os.Getenv("_HANDLER")), - AwsKey: os.Getenv("AWS_ACCESS_KEY_ID"), - AwsSecret: os.Getenv("AWS_SECRET_ACCESS_KEY"), - AwsSession: os.Getenv("AWS_SESSION_TOKEN"), - XRayDaemonAddress: "0.0.0.0:0", // TODO - FunctionName: GetenvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function"), - FunctionVersion: functionVersion, - RuntimeInfo: interop.RuntimeInfo{ - ImageJSON: "{}", - Arn: "", - Version: ""}, - CustomerEnvironmentVariables: additionalFunctionEnvironmentVariables, - SandboxType: interop.SandboxClassic, - Bootstrap: bs, - EnvironmentVariables: env.NewEnvironment(), - }, timeout*1000) - initEnd := time.Now() - return initStart, initEnd -} diff --git a/cmd/aws-lambda-rie/http.go b/cmd/aws-lambda-rie/http.go deleted file mode 100644 index 88bd39b..0000000 --- a/cmd/aws-lambda-rie/http.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package main - -import ( - "net/http" - - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore" -) - -func startHTTPServer(ipport string, sandbox *rapidcore.SandboxBuilder, bs interop.Bootstrap) { - srv := &http.Server{ - Addr: ipport, - } - - // Pass a channel - http.HandleFunc("/2015-03-31/functions/function/invocations", func(w http.ResponseWriter, r *http.Request) { - InvokeHandler(w, r, sandbox.LambdaInvokeAPI(), bs) - }) - - // go routine (main thread waits) - if err := srv.ListenAndServe(); err != nil { - log.Panic(err) - } - - log.Warnf("Listening on %s", ipport) -} diff --git a/cmd/aws-lambda-rie/main.go b/cmd/aws-lambda-rie/main.go deleted file mode 100644 index bd15402..0000000 --- a/cmd/aws-lambda-rie/main.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package main - -import ( - "context" - "fmt" - "net" - "os" - "runtime/debug" - - "github.com/jessevdk/go-flags" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore" - - log "github.com/sirupsen/logrus" -) - -const ( - optBootstrap = "/opt/bootstrap" - runtimeBootstrap = "/var/runtime/bootstrap" -) - -type options struct { - LogLevel string `long:"log-level" description:"The level of AWS Lambda Runtime Interface Emulator logs to display. Can also be set by the environment variable 'LOG_LEVEL'. Defaults to the value 'info'."` - InitCachingEnabled bool `long:"enable-init-caching" description:"Enable support for Init Caching"` - // Do not have a default value so we do not need to keep it in sync with the default value in lambda/rapidcore/sandbox_builder.go - RuntimeAPIAddress string `long:"runtime-api-address" description:"The address of the AWS Lambda Runtime API to communicate with the Lambda execution environment."` - RuntimeInterfaceEmulatorAddress string `long:"runtime-interface-emulator-address" default:"0.0.0.0:8080" description:"The address for the AWS Lambda Runtime Interface Emulator to accept HTTP request upon."` -} - -func main() { - // More frequent GC reduces the tail latencies, equivalent to export GOGC=33 - debug.SetGCPercent(33) - - opts, args := getCLIArgs() - - logLevel := "info" - - // If you specify an option by using a parameter on the CLI command line, it overrides any value from either the corresponding environment variable. - // - // https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html - if opts.LogLevel != "" { - logLevel = opts.LogLevel - } else if envLogLevel, envLogLevelSet := os.LookupEnv("LOG_LEVEL"); envLogLevelSet { - logLevel = envLogLevel - } - - rapidcore.SetLogLevel(logLevel) - - if opts.RuntimeAPIAddress != "" { - _, _, err := net.SplitHostPort(opts.RuntimeAPIAddress) - - if err != nil { - log.WithError(err).Fatalf("The command line value for \"--runtime-api-address\" is not a valid network address %q.", opts.RuntimeAPIAddress) - } - } - - _, _, err := net.SplitHostPort(opts.RuntimeInterfaceEmulatorAddress) - - if err != nil { - log.WithError(err).Fatalf("The command line value for \"--runtime-interface-emulator-address\" is not a valid network address %q.", opts.RuntimeInterfaceEmulatorAddress) - } - - bootstrap, handler := getBootstrap(args, opts) - sandbox := rapidcore. - NewSandboxBuilder(). - AddShutdownFunc(context.CancelFunc(func() { os.Exit(0) })). - SetExtensionsFlag(true). - SetInitCachingFlag(opts.InitCachingEnabled) - - if len(handler) > 0 { - sandbox.SetHandler(handler) - } - - if opts.RuntimeAPIAddress != "" { - sandbox.SetRuntimeAPIAddress(opts.RuntimeAPIAddress) - } - - sandboxContext, internalStateFn := sandbox.Create() - // Since we have not specified a custom interop server for standalone, we can - // directly reference the default interop server, which is a concrete type - sandbox.DefaultInteropServer().SetSandboxContext(sandboxContext) - sandbox.DefaultInteropServer().SetInternalStateGetter(internalStateFn) - - startHTTPServer(opts.RuntimeInterfaceEmulatorAddress, sandbox, bootstrap) -} - -func getCLIArgs() (options, []string) { - var opts options - parser := flags.NewParser(&opts, flags.IgnoreUnknown) - args, err := parser.ParseArgs(os.Args) - - if err != nil { - log.WithError(err).Fatal("Failed to parse command line arguments:", os.Args) - } - - return opts, args -} - -func isBootstrapFileExist(filePath string) bool { - file, err := os.Stat(filePath) - return !os.IsNotExist(err) && !file.IsDir() -} - -func getBootstrap(args []string, opts options) (interop.Bootstrap, string) { - var bootstrapLookupCmd []string - var handler string - currentWorkingDir := "/var/task" // default value - - if len(args) <= 1 { - // set default value to /var/task/bootstrap, but switch to the other options if it doesn't exist - bootstrapLookupCmd = []string{ - fmt.Sprintf("%s/bootstrap", currentWorkingDir), - } - - if !isBootstrapFileExist(bootstrapLookupCmd[0]) { - var bootstrapCmdCandidates = []string{ - optBootstrap, - runtimeBootstrap, - } - - for i, bootstrapCandidate := range bootstrapCmdCandidates { - if isBootstrapFileExist(bootstrapCandidate) { - bootstrapLookupCmd = []string{bootstrapCmdCandidates[i]} - break - } - } - } - - // handler is used later to set an env var for Lambda Image support - handler = "" - } else if len(args) > 1 { - - bootstrapLookupCmd = args[1:] - - if cwd, err := os.Getwd(); err == nil { - currentWorkingDir = cwd - } - - if len(args) > 2 { - // Assume last arg is the handler - handler = args[len(args)-1] - } - - log.Infof("exec '%s' (cwd=%s, handler=%s)", args[1], currentWorkingDir, handler) - - } else { - log.Panic("insufficient arguments: bootstrap not provided") - } - - return NewSimpleBootstrap(bootstrapLookupCmd, currentWorkingDir), handler -} diff --git a/cmd/aws-lambda-rie/simple_bootstrap.go b/cmd/aws-lambda-rie/simple_bootstrap.go deleted file mode 100644 index c9111a2..0000000 --- a/cmd/aws-lambda-rie/simple_bootstrap.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package main - -import ( - "fmt" - "os" - "path/filepath" - - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore/env" -) - -// the type implement a simpler version of the Bootstrap -// this is useful in the Standalone Core implementation. -type simpleBootstrap struct { - cmd []string - workingDir string -} - -func NewSimpleBootstrap(cmd []string, currentWorkingDir string) interop.Bootstrap { - if currentWorkingDir == "" { - // use the root directory as the default working directory - currentWorkingDir = "/" - } - - // a single candidate command makes it automatically valid - return &simpleBootstrap{ - cmd: cmd, - workingDir: currentWorkingDir, - } -} - -func (b *simpleBootstrap) Cmd() ([]string, error) { - return b.cmd, nil -} - -// Cwd returns the working directory of the bootstrap process -// The path is validated against the chroot identified by `root` -func (b *simpleBootstrap) Cwd() (string, error) { - if !filepath.IsAbs(b.workingDir) { - return "", fmt.Errorf("the working directory '%s' is invalid, it needs to be an absolute path", b.workingDir) - } - - // evaluate the path relatively to the domain's mnt namespace root - if _, err := os.Stat(b.workingDir); os.IsNotExist(err) { - return "", fmt.Errorf("the working directory doesn't exist: %s", b.workingDir) - } - - return b.workingDir, nil -} - -// Env returns the environment variables available to -// the bootstrap process -func (b *simpleBootstrap) Env(e *env.Environment) map[string]string { - return e.RuntimeExecEnv() -} - -// ExtraFiles returns the extra file descriptors apart from 1 & 2 to be passed to runtime -func (b *simpleBootstrap) ExtraFiles() []*os.File { - return make([]*os.File, 0) -} - -func (b *simpleBootstrap) CachedFatalError(err error) (fatalerror.ErrorType, string, bool) { - // not implemented as it is not needed in Core but we need to fullfil the interface anyway - return fatalerror.ErrorType(""), "", false -} diff --git a/cmd/aws-lambda-rie/simple_bootstrap_test.go b/cmd/aws-lambda-rie/simple_bootstrap_test.go deleted file mode 100644 index de00ee2..0000000 --- a/cmd/aws-lambda-rie/simple_bootstrap_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package main - -import ( - "os" - "reflect" - "testing" - - "go.amzn.com/lambda/rapidcore/env" - - "github.com/stretchr/testify/assert" -) - -func TestSimpleBootstrap(t *testing.T) { - tmpFile, err := os.CreateTemp("", "oci-test-bootstrap") - assert.NoError(t, err) - defer os.Remove(tmpFile.Name()) - - // Setup single cmd candidate - file := []string{tmpFile.Name(), "--arg1 s", "foo"} - cmdCandidate := file - - // Setup working dir - cwd, err := os.Getwd() - assert.NoError(t, err) - - // Setup environment - environment := env.NewEnvironment() - environment.StoreRuntimeAPIEnvironmentVariable("host:port") - environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") - - // Test - b := NewSimpleBootstrap(cmdCandidate, cwd) - bCwd, err := b.Cwd() - assert.NoError(t, err) - assert.Equal(t, cwd, bCwd) - assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) - - cmd, err := b.Cmd() - assert.NoError(t, err) - assert.Equal(t, file, cmd) -} - -func TestSimpleBootstrapCmdNonExistingCandidate(t *testing.T) { - // Setup inexistent single cmd candidate - file := []string{"/foo/bar", "--arg1 s", "foo"} - cmdCandidate := file - - // Setup working dir - cwd, err := os.Getwd() - assert.NoError(t, err) - - // Setup environment - environment := env.NewEnvironment() - environment.StoreRuntimeAPIEnvironmentVariable("host:port") - environment.StoreEnvironmentVariablesFromInit(map[string]string{}, "", "", "", "", "", "") - - // Test - b := NewSimpleBootstrap(cmdCandidate, cwd) - bCwd, err := b.Cwd() - assert.NoError(t, err) - assert.Equal(t, cwd, bCwd) - assert.True(t, reflect.DeepEqual(environment.RuntimeExecEnv(), b.Env(environment))) - - // No validations run against single candidates - cmd, err := b.Cmd() - assert.NoError(t, err) - assert.Equal(t, file, cmd) -} - -func TestSimpleBootstrapCmdDefaultWorkingDir(t *testing.T) { - b := NewSimpleBootstrap([]string{}, "") - bCwd, err := b.Cwd() - assert.NoError(t, err) - assert.Equal(t, "/", bCwd) -} diff --git a/cmd/aws-lambda-rie/util.go b/cmd/aws-lambda-rie/util.go deleted file mode 100644 index 1d539ec..0000000 --- a/cmd/aws-lambda-rie/util.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package main - -import ( - "fmt" - "net/http") - -type ErrorType int - -const ( - ClientInvalidRequest ErrorType = iota -) - -func (t ErrorType) String() string { - switch t { - case ClientInvalidRequest: - return "Client.InvalidRequest" - } - return fmt.Sprintf("Cannot stringify standalone.ErrorType.%d", int(t)) -} - -type ResponseWriterProxy struct { - Body []byte - StatusCode int -} - -func (w *ResponseWriterProxy) Header() http.Header { - return http.Header{} -} - -func (w *ResponseWriterProxy) Write(b []byte) (int, error) { - w.Body = b - return 0, nil -} - -func (w *ResponseWriterProxy) WriteHeader(statusCode int) { - w.StatusCode = statusCode -} - -func (w *ResponseWriterProxy) IsError() bool { - return w.StatusCode != 0 && w.StatusCode/100 != 2 -} diff --git a/cmd/localstack/awsutil.go b/cmd/localstack/awsutil.go deleted file mode 100644 index c7fcbc4..0000000 --- a/cmd/localstack/awsutil.go +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -// LOCALSTACK CHANGES 2022-03-10: modified/collected file from /cmd/aws-lambda-rie/* into this util -// LOCALSTACK CHANGES 2022-03-10: minor refactoring of PrintEndReports -// LOCALSTACK CHANGES 2023-10-06: reflect getBootstrap and InitHandler API updates - -package main - -import ( - "context" - "fmt" - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore/env" - "golang.org/x/sys/unix" - "io" - "io/fs" - "math" - "net/http" - "os" - "path/filepath" - "strings" - "time" -) - -const ( - optBootstrap = "/opt/bootstrap" - runtimeBootstrap = "/var/runtime/bootstrap" -) - -func isBootstrapFileExist(filePath string) bool { - file, err := os.Stat(filePath) - return !os.IsNotExist(err) && !file.IsDir() -} - -func getBootstrap(args []string) (interop.Bootstrap, string) { - var bootstrapLookupCmd []string - var handler string - currentWorkingDir := "/var/task" // default value - - if len(args) <= 1 { - // set default value to /var/task/bootstrap, but switch to the other options if it doesn't exist - bootstrapLookupCmd = []string{ - fmt.Sprintf("%s/bootstrap", currentWorkingDir), - } - - if !isBootstrapFileExist(bootstrapLookupCmd[0]) { - var bootstrapCmdCandidates = []string{ - optBootstrap, - runtimeBootstrap, - } - - for i, bootstrapCandidate := range bootstrapCmdCandidates { - if isBootstrapFileExist(bootstrapCandidate) { - bootstrapLookupCmd = []string{bootstrapCmdCandidates[i]} - break - } - } - } - - // handler is used later to set an env var for Lambda Image support - handler = "" - } else if len(args) > 1 { - - bootstrapLookupCmd = args[1:] - - if cwd, err := os.Getwd(); err == nil { - currentWorkingDir = cwd - } - - if len(args) > 2 { - // Assume last arg is the handler - handler = args[len(args)-1] - } - - log.Infof("exec '%s' (cwd=%s, handler=%s)", args[1], currentWorkingDir, handler) - - } else { - log.Panic("insufficient arguments: bootstrap not provided") - } - - err := unix.Access(bootstrapLookupCmd[0], unix.X_OK) - if err != nil { - log.Debug("Bootstrap not executable, setting permissions to 0755...", bootstrapLookupCmd[0]) - err = os.Chmod(bootstrapLookupCmd[0], 0755) - if err != nil { - log.Warn("Error setting bootstrap to 0755 permissions: ", bootstrapLookupCmd[0], err) - } - } - - return NewSimpleBootstrap(bootstrapLookupCmd, currentWorkingDir), handler -} - -func PrintEndReports(invokeId string, initDuration string, memorySize string, invokeStart time.Time, timeoutDuration time.Duration, w io.Writer) { - // Calculate invoke duration - invokeDuration := math.Min(float64(time.Now().Sub(invokeStart).Nanoseconds()), - float64(timeoutDuration.Nanoseconds())) / float64(time.Millisecond) - - _, _ = fmt.Fprintln(w, "END RequestId: "+invokeId) - // We set the Max Memory Used and Memory Size to be the same (whatever it is set to) since there is - // not a clean way to get this information from rapidcore - _, _ = fmt.Fprintf(w, - "REPORT RequestId: %s\t"+ - initDuration+ - "Duration: %.2f ms\t"+ - "Billed Duration: %.f ms\t"+ - "Memory Size: %s MB\t"+ - "Max Memory Used: %s MB\t\n", - invokeId, invokeDuration, math.Ceil(invokeDuration), memorySize, memorySize) -} - -type Sandbox interface { - Init(i *interop.Init, invokeTimeoutMs int64) - Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error -} - -// GetenvWithDefault returns the value of the environment variable key or the defaultValue if key is not set -func GetenvWithDefault(key string, defaultValue string) string { - envValue, ok := os.LookupEnv(key) - if !ok { - return defaultValue - } - - return envValue -} - -func resetListener(changeChannel <-chan bool, server *CustomInteropServer) { - for { - _, more := <-changeChannel - if !more { - return - } - log.Println("Resetting environment...") - _, err := server.Reset("HotReload", 2000) - if err != nil { - log.Warnln("Error resetting server: ", err) - } - } - -} - -func RunHotReloadingListener(server *CustomInteropServer, targetPaths []string, ctx context.Context, fileWatcherStrategy string) { - if len(targetPaths) == 1 && targetPaths[0] == "" { - log.Debugln("Hot reloading disabled.") - return - } - defaultDebouncingDuration := 500 * time.Millisecond - log.Infoln("Hot reloading enabled, starting filewatcher.", targetPaths) - changeListener, err := NewChangeListener(defaultDebouncingDuration, fileWatcherStrategy) - if err != nil { - log.Errorln("Hot reloading disabled due to change listener error.", err) - return - } - defer changeListener.Close() - go changeListener.Start() - changeListener.AddTargetPaths(targetPaths) - go resetListener(changeListener.debouncedChannel, server) - - <-ctx.Done() - log.Infoln("Closing down filewatcher.") - -} - -func getSubFolders(dirPath string) []string { - var subfolders []string - err := filepath.WalkDir(dirPath, func(path string, d fs.DirEntry, err error) error { - if err == nil && d.IsDir() { - subfolders = append(subfolders, path) - } - return err - }) - if err != nil { - log.Errorln("Error listing directory contents: ", err) - return subfolders - } - return subfolders -} - -func getSubFoldersInList(prefix string, pathList []string) (oldFolders []string, newFolders []string) { - for _, pathItem := range pathList { - if strings.HasPrefix(pathItem, prefix) { - oldFolders = append(oldFolders, pathItem) - } else { - newFolders = append(newFolders, pathItem) - } - } - return -} - -func InitHandler(sandbox Sandbox, functionVersion string, timeout int64, bs interop.Bootstrap, accountId string) (time.Time, time.Time) { - additionalFunctionEnvironmentVariables := map[string]string{} - - // Add default Env Vars if they were not defined. This is a required otherwise 1p Python2.7, Python3.6, and - // possibly others pre runtime API runtimes will fail. This will be overwritten if they are defined on the system. - additionalFunctionEnvironmentVariables["AWS_LAMBDA_LOG_GROUP_NAME"] = "/aws/lambda/Functions" - additionalFunctionEnvironmentVariables["AWS_LAMBDA_LOG_STREAM_NAME"] = "$LATEST" - additionalFunctionEnvironmentVariables["AWS_LAMBDA_FUNCTION_VERSION"] = "$LATEST" - additionalFunctionEnvironmentVariables["AWS_LAMBDA_FUNCTION_MEMORY_SIZE"] = "3008" - additionalFunctionEnvironmentVariables["AWS_LAMBDA_FUNCTION_NAME"] = "test_function" - - // Forward Env Vars from the running system (container) to what the function can view. Without this, Env Vars will - // not be viewable when the function runs. - for _, env := range os.Environ() { - // Split the env into by the first "=". This will account for if the env var's value has a '=' in it - envVar := strings.SplitN(env, "=", 2) - additionalFunctionEnvironmentVariables[envVar[0]] = envVar[1] - } - - initStart := time.Now() - // pass to rapid - sandbox.Init(&interop.Init{ - Handler: GetenvWithDefault("AWS_LAMBDA_FUNCTION_HANDLER", os.Getenv("_HANDLER")), - AwsKey: os.Getenv("AWS_ACCESS_KEY_ID"), - AwsSecret: os.Getenv("AWS_SECRET_ACCESS_KEY"), - AwsSession: os.Getenv("AWS_SESSION_TOKEN"), - AccountID: accountId, - XRayDaemonAddress: GetenvWithDefault("AWS_XRAY_DAEMON_ADDRESS", "127.0.0.1:2000"), - FunctionName: GetenvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function"), - FunctionVersion: functionVersion, - - // TODO: Implement runtime management controls - // https://aws.amazon.com/blogs/compute/introducing-aws-lambda-runtime-management-controls/ - RuntimeInfo: interop.RuntimeInfo{ - ImageJSON: "{}", - Arn: "", - Version: ""}, - CustomerEnvironmentVariables: additionalFunctionEnvironmentVariables, - SandboxType: interop.SandboxClassic, - Bootstrap: bs, - EnvironmentVariables: env.NewEnvironment(), - }, timeout*1000) - initEnd := time.Now() - return initStart, initEnd -} diff --git a/cmd/localstack/main.go b/cmd/localstack/main.go index 064a174..94eb46a 100644 --- a/cmd/localstack/main.go +++ b/cmd/localstack/main.go @@ -4,63 +4,44 @@ package main import ( "context" - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore" "os" "runtime/debug" "strconv" "strings" -) - -type LsOpts struct { - InteropPort string - RuntimeEndpoint string - RuntimeId string - AccountId string - InitTracingPort string - User string - CodeArchives string - HotReloadingPaths []string - FileWatcherStrategy string - ChmodPaths string - LocalstackIP string - InitLogLevel string - EdgePort string - EnableXRayTelemetry string - PostInvokeWaitMS string - MaxPayloadSize string -} -func GetEnvOrDie(env string) string { - result, found := os.LookupEnv(env) - if !found { - panic("Could not find environment variable for: " + env) - } - return result -} + "github.com/localstack/lambda-runtime-init/internal/aws/xray" + "github.com/localstack/lambda-runtime-init/internal/bootstrap" + "github.com/localstack/lambda-runtime-init/internal/hotreloading" + "github.com/localstack/lambda-runtime-init/internal/logging" + "github.com/localstack/lambda-runtime-init/internal/server" + "github.com/localstack/lambda-runtime-init/internal/tracing" + "github.com/localstack/lambda-runtime-init/internal/utils" + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/core/directinvoke" + "go.amzn.com/lambda/rapidcore" +) -func InitLsOpts() *LsOpts { - return &LsOpts{ +func InitLsOpts() *server.LsOpts { + return &server.LsOpts{ // required - RuntimeEndpoint: GetEnvOrDie("LOCALSTACK_RUNTIME_ENDPOINT"), - RuntimeId: GetEnvOrDie("LOCALSTACK_RUNTIME_ID"), - AccountId: GetenvWithDefault("LOCALSTACK_FUNCTION_ACCOUNT_ID", "000000000000"), + RuntimeEndpoint: utils.GetEnvOrDie("LOCALSTACK_RUNTIME_ENDPOINT"), + RuntimeId: utils.GetEnvOrDie("LOCALSTACK_RUNTIME_ID"), + AccountId: utils.GetEnvWithDefault("LOCALSTACK_FUNCTION_ACCOUNT_ID", "000000000000"), // optional with default - InteropPort: GetenvWithDefault("LOCALSTACK_INTEROP_PORT", "9563"), - InitTracingPort: GetenvWithDefault("LOCALSTACK_RUNTIME_TRACING_PORT", "9564"), - User: GetenvWithDefault("LOCALSTACK_USER", "sbx_user1051"), - InitLogLevel: GetenvWithDefault("LOCALSTACK_INIT_LOG_LEVEL", "warn"), - EdgePort: GetenvWithDefault("EDGE_PORT", "4566"), - MaxPayloadSize: GetenvWithDefault("LOCALSTACK_MAX_PAYLOAD_SIZE", "6291556"), + InteropPort: utils.GetEnvWithDefault("LOCALSTACK_INTEROP_PORT", "9563"), + InitTracingPort: utils.GetEnvWithDefault("LOCALSTACK_RUNTIME_TRACING_PORT", "9564"), + User: utils.GetEnvWithDefault("LOCALSTACK_USER", "sbx_user1051"), + InitLogLevel: utils.GetEnvWithDefault("LOCALSTACK_INIT_LOG_LEVEL", "warn"), + EdgePort: utils.GetEnvWithDefault("EDGE_PORT", "4566"), + MaxPayloadSize: utils.GetEnvWithDefault("LOCALSTACK_MAX_PAYLOAD_SIZE", "6291556"), // optional or empty CodeArchives: os.Getenv("LOCALSTACK_CODE_ARCHIVES"), - HotReloadingPaths: strings.Split(GetenvWithDefault("LOCALSTACK_HOT_RELOADING_PATHS", ""), ","), + HotReloadingPaths: strings.Split(utils.GetEnvWithDefault("LOCALSTACK_HOT_RELOADING_PATHS", ""), ","), FileWatcherStrategy: os.Getenv("LOCALSTACK_FILE_WATCHER_STRATEGY"), EnableXRayTelemetry: os.Getenv("LOCALSTACK_ENABLE_XRAY_TELEMETRY"), LocalstackIP: os.Getenv("LOCALSTACK_HOSTNAME"), PostInvokeWaitMS: os.Getenv("LOCALSTACK_POST_INVOKE_WAIT_MS"), - ChmodPaths: GetenvWithDefault("LOCALSTACK_CHMOD_PATHS", "[]"), + ChmodPaths: utils.GetEnvWithDefault("LOCALSTACK_CHMOD_PATHS", "[]"), } } @@ -137,43 +118,43 @@ func main() { if err != nil { log.Panicln("Please specify a number for LOCALSTACK_MAX_PAYLOAD_SIZE") } - interop.MaxPayloadSize = payloadSize + directinvoke.MaxDirectResponseSize = int64(payloadSize) // download code archive if env variable is set - if err := DownloadCodeArchives(lsOpts.CodeArchives); err != nil { + if err := utils.DownloadCodeArchives(lsOpts.CodeArchives); err != nil { log.Fatal("Failed to download code archives: " + err.Error()) } - if err := AdaptFilesystemPermissions(lsOpts.ChmodPaths); err != nil { + if err := utils.AdaptFilesystemPermissions(lsOpts.ChmodPaths); err != nil { log.Warnln("Could not change file mode of code directories:", err) } // parse CLI args - bootstrap, handler := getBootstrap(os.Args) + bootstrap, handler := bootstrap.GetBootstrap(os.Args) // Switch to non-root user and drop root privileges - if IsRootUser() && lsOpts.User != "" && lsOpts.User != "root" { + if utils.IsRootUser() && lsOpts.User != "" && lsOpts.User != "root" { uid := 993 gid := 990 - AddUser(lsOpts.User, uid, gid) + utils.AddUser(lsOpts.User, uid, gid) if err := os.Chown("/tmp", uid, gid); err != nil { log.Warnln("Could not change owner of directory /tmp:", err) } - UserLogger().Debugln("Process running as root user.") - err := DropPrivileges(lsOpts.User) + utils.UserLogger().Debugln("Process running as root user.") + err := utils.DropPrivileges(lsOpts.User) if err != nil { log.Warnln("Could not drop root privileges.", err) } else { - UserLogger().Debugln("Process running as non-root user.") + utils.UserLogger().Debugln("Process running as non-root user.") } } // file watcher for hot-reloading fileWatcherContext, cancelFileWatcher := context.WithCancel(context.Background()) - logCollector := NewLogCollector() - localStackLogsEgressApi := NewLocalStackLogsEgressAPI(logCollector) - tracer := NewLocalStackTracer() + logCollector := logging.NewLogCollector() + localStackLogsEgressApi := logging.NewLocalStackLogsEgressAPI(logCollector) + tracer := tracing.NewLocalStackTracer() // build sandbox sandbox := rapidcore. @@ -190,18 +171,18 @@ func main() { // xray daemon endpoint := "http://" + lsOpts.LocalstackIP + ":" + lsOpts.EdgePort - xrayConfig := initConfig(endpoint, xRayLogLevel) - d := initDaemon(xrayConfig, lsOpts.EnableXRayTelemetry == "1") + xrayConfig := xray.NewConfig(endpoint, xRayLogLevel) + d := xray.NewDaemon(xrayConfig, lsOpts.EnableXRayTelemetry == "1") sandbox.AddShutdownFunc(func() { log.Debugln("Shutting down xray daemon") - d.stop() + d.Stop() log.Debugln("Flushing segments in xray daemon") - d.close() + d.Close() }) - runDaemon(d) // async + d.Run() // async defaultInterop := sandbox.DefaultInteropServer() - interopServer := NewCustomInteropServer(lsOpts, defaultInterop, logCollector) + interopServer := server.NewCustomInteropServer(lsOpts, defaultInterop, logCollector) sandbox.SetInteropServer(interopServer) if len(handler) > 0 { sandbox.SetHandler(handler) @@ -218,20 +199,20 @@ func main() { interopServer.SetInternalStateGetter(internalStateFn) // get timeout - invokeTimeoutEnv := GetEnvOrDie("AWS_LAMBDA_FUNCTION_TIMEOUT") // TODO: collect all AWS_* env parsing + invokeTimeoutEnv := utils.GetEnvOrDie("AWS_LAMBDA_FUNCTION_TIMEOUT") // TODO: collect all AWS_* env parsing invokeTimeoutSeconds, err := strconv.Atoi(invokeTimeoutEnv) if err != nil { log.Fatalln(err) } - go RunHotReloadingListener(interopServer, lsOpts.HotReloadingPaths, fileWatcherContext, lsOpts.FileWatcherStrategy) + go hotreloading.RunHotReloadingListener(interopServer, lsOpts.HotReloadingPaths, fileWatcherContext, lsOpts.FileWatcherStrategy) // start runtime init. It is important to start `InitHandler` synchronously because we need to ensure the // notification channels and status fields are properly initialized before `AwaitInitialized` log.Debugln("Starting runtime init.") - InitHandler(sandbox.LambdaInvokeAPI(), GetEnvOrDie("AWS_LAMBDA_FUNCTION_VERSION"), int64(invokeTimeoutSeconds), bootstrap, lsOpts.AccountId) // TODO: replace this with a custom init + server.InitHandler(sandbox.LambdaInvokeAPI(), utils.GetEnvOrDie("AWS_LAMBDA_FUNCTION_VERSION"), int64(invokeTimeoutSeconds), bootstrap, lsOpts.AccountId) // TODO: replace this with a custom init log.Debugln("Awaiting initialization of runtime init.") - if err := interopServer.delegate.AwaitInitialized(); err != nil { + if err := interopServer.AwaitInitialized(); err != nil { // Error cases: ErrInitDoneFailed or ErrInitResetReceived log.Errorln("Runtime init failed to initialize: " + err.Error() + ". Exiting.") // NOTE: Sending the error status to LocalStack is handled beforehand in the custom_interop.go through the @@ -240,7 +221,7 @@ func main() { } log.Debugln("Completed initialization of runtime init. Sending status ready to LocalStack.") - if err := interopServer.localStackAdapter.SendStatus(Ready, []byte{}); err != nil { + if err := interopServer.SendStatus(server.Ready, []byte{}); err != nil { log.Fatalln("Failed to send status ready to LocalStack " + err.Error() + ". Exiting.") } diff --git a/debugging/Makefile b/debugging/Makefile index fe3a68f..fd04f70 100644 --- a/debugging/Makefile +++ b/debugging/Makefile @@ -1,5 +1,5 @@ # Golang EOL overview: https://endoflife.date/go -DOCKER_GOLANG_IMAGE ?= golang:1.20-bullseye +DOCKER_GOLANG_IMAGE ?= golang:1.24-bullseye # On ARM hosts, use: make ARCH=arm64 build-init # Check host architecture: uname -m diff --git a/go.mod b/go.mod index 02da1b1..c65604a 100644 --- a/go.mod +++ b/go.mod @@ -1,32 +1,27 @@ -module go.amzn.com +module github.com/localstack/lambda-runtime-init go 1.24 require ( - github.com/aws/aws-lambda-go v1.46.0 github.com/aws/aws-sdk-go v1.44.298 github.com/aws/aws-xray-daemon v0.0.0-20250212175715-5defe1b8d61b github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575 github.com/fsnotify/fsnotify v1.6.0 github.com/go-chi/chi v1.5.5 - github.com/google/uuid v1.6.0 - github.com/jessevdk/go-flags v1.5.0 github.com/shirou/gopsutil v2.19.10+incompatible github.com/sirupsen/logrus v1.9.3 - github.com/stretchr/testify v1.9.0 - golang.org/x/sync v0.12.0 + go.amzn.com v0.0.0-00010101000000-000000000000 golang.org/x/sys v0.31.0 ) require ( github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-ole/go-ole v1.2.4 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/objx v0.5.2 // indirect golang.org/x/net v0.38.0 // indirect golang.org/x/text v0.23.0 // indirect gopkg.in/yaml.v2 v2.2.8 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace go.amzn.com => github.com/aws/aws-lambda-runtime-interface-emulator v0.0.0-20250423173140-3a0772eae98d diff --git a/go.sum b/go.sum index 3f2f234..7a2908d 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d h1:G0m3OIz70MZUW github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= github.com/aws/aws-lambda-go v1.46.0 h1:UWVnvh2h2gecOlFhHQfIPQcD8pL/f7pVCutmFl+oXU8= github.com/aws/aws-lambda-go v1.46.0/go.mod h1:dpMpZgvWx5vuQJfBt0zqBha60q7Dd7RfgJv23DymV8A= +github.com/aws/aws-lambda-runtime-interface-emulator v0.0.0-20250423173140-3a0772eae98d h1:M21cruQn5Kdpwiz8aekZaRwXNpEUCOHUwaCMqdx7bl0= +github.com/aws/aws-lambda-runtime-interface-emulator v0.0.0-20250423173140-3a0772eae98d/go.mod h1:J6YoZd6buk2amNl7zFx1jMXq/3qh9zqsFkzNvbQLkDY= github.com/aws/aws-sdk-go v1.44.298 h1:5qTxdubgV7PptZJmp/2qDwD2JL187ePL7VOxsSh1i3g= github.com/aws/aws-sdk-go v1.44.298/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/aws/aws-xray-daemon v0.0.0-20250212175715-5defe1b8d61b h1:hiV1SQDGCUECdYdKRvfBmIZnoCWggTDauTintGTkIFU= @@ -19,8 +21,6 @@ github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI= github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= -github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= @@ -54,7 +54,6 @@ golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/cmd/localstack/xraydaemon.go b/internal/aws/xray/daemon.go similarity index 94% rename from cmd/localstack/xraydaemon.go rename to internal/aws/xray/daemon.go index 4af83e8..467993f 100644 --- a/cmd/localstack/xraydaemon.go +++ b/internal/aws/xray/daemon.go @@ -2,7 +2,7 @@ // It has been adapted for the use as a library and not as a separate executable. // The config is set directly in code instead of loading it from a config file -package main +package xray import ( "encoding/json" @@ -25,6 +25,7 @@ import ( "github.com/aws/aws-xray-daemon/pkg/telemetry" "github.com/aws/aws-xray-daemon/pkg/tracesegment" "github.com/aws/aws-xray-daemon/pkg/util" + "github.com/localstack/lambda-runtime-init/internal/utils" "github.com/aws/aws-sdk-go/aws" log "github.com/cihub/seelog" @@ -72,14 +73,14 @@ type Daemon struct { } // https://docs.aws.amazon.com/xray/latest/devguide/xray-daemon-configuration.html -func initConfig(endpoint string, logLevel string) *cfg.Config { +func NewConfig(endpoint string, logLevel string) *cfg.Config { xrayConfig := cfg.DefaultConfig() xrayConfig.Socket.UDPAddress = "127.0.0.1:2000" xrayConfig.Socket.TCPAddress = "127.0.0.1:2000" xrayConfig.Endpoint = endpoint xrayConfig.NoVerifySSL = util.Bool(true) // obvious xrayConfig.LocalMode = util.Bool(true) // skip EC2 metadata check - xrayConfig.Region = GetEnvOrDie("AWS_REGION") + xrayConfig.Region = utils.GetEnvOrDie("AWS_REGION") xrayConfig.Logging.LogLevel = logLevel //xrayConfig.TotalBufferSizeMB //xrayConfig.RoleARN = roleARN @@ -87,7 +88,7 @@ func initConfig(endpoint string, logLevel string) *cfg.Config { return xrayConfig } -func initDaemon(config *cfg.Config, enableTelemetry bool) *Daemon { +func NewDaemon(config *cfg.Config, enableTelemetry bool) *Daemon { if logFile != "" { var fileWriter io.Writer if *config.Logging.LogRotation { @@ -172,16 +173,16 @@ func initDaemon(config *cfg.Config, enableTelemetry bool) *Daemon { return daemon } -func runDaemon(daemon *Daemon) { +func (d *Daemon) Run() { // Start http server for proxying requests to xray - go daemon.server.Serve() + go d.server.Serve() - for i := 0; i < daemon.receiverCount; i++ { - go daemon.poll() + for i := 0; i < d.receiverCount; i++ { + go d.Poll() } } -func (d *Daemon) close() { +func (d *Daemon) Close() { for i := 0; i < d.receiverCount; i++ { <-d.done } @@ -201,13 +202,13 @@ func (d *Daemon) close() { log.Debugf("Shutdown finished. Current epoch in nanoseconds: %v", time.Now().UnixNano()) } -func (d *Daemon) stop() { +func (d *Daemon) Stop() { d.sock.Close() d.server.Close() } // Returns number of bytes read from socket connection. -func (d *Daemon) read(buf *[]byte) int { +func (d *Daemon) Read(buf *[]byte) int { bufVal := *buf rlen, err := d.sock.Read(bufVal) switch err := err.(type) { @@ -225,7 +226,7 @@ func (d *Daemon) read(buf *[]byte) int { return rlen } -func (d *Daemon) poll() { +func (d *Daemon) Poll() { separator := []byte(protocolSeparator) fallBackBuffer := make([]byte, d.receiveBufferSize) splitBuf := make([][]byte, 2) @@ -238,7 +239,7 @@ func (d *Daemon) poll() { bufPointer = &fallBackBuffer fallbackPointerUsed = true } - rlen := d.read(bufPointer) + rlen := d.Read(bufPointer) if rlen > 0 && d.enableTelemetry { telemetry.T.SegmentReceived(1) } diff --git a/cmd/localstack/simple_bootstrap.go b/internal/bootstrap/simple.go similarity index 99% rename from cmd/localstack/simple_bootstrap.go rename to internal/bootstrap/simple.go index c9111a2..4a65cfb 100644 --- a/cmd/localstack/simple_bootstrap.go +++ b/internal/bootstrap/simple.go @@ -1,7 +1,7 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package main +package bootstrap import ( "fmt" diff --git a/internal/bootstrap/util.go b/internal/bootstrap/util.go new file mode 100644 index 0000000..06de122 --- /dev/null +++ b/internal/bootstrap/util.go @@ -0,0 +1,78 @@ +package bootstrap + +import ( + "fmt" + "os" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/interop" + "golang.org/x/sys/unix" +) + +const ( + optBootstrap = "/opt/bootstrap" + runtimeBootstrap = "/var/runtime/bootstrap" +) + +func IsBootstrapFileExist(filePath string) bool { + file, err := os.Stat(filePath) + return !os.IsNotExist(err) && !file.IsDir() +} + +func GetBootstrap(args []string) (interop.Bootstrap, string) { + var bootstrapLookupCmd []string + var handler string + currentWorkingDir := "/var/task" // default value + + if len(args) <= 1 { + // set default value to /var/task/bootstrap, but switch to the other options if it doesn't exist + bootstrapLookupCmd = []string{ + fmt.Sprintf("%s/bootstrap", currentWorkingDir), + } + + if !IsBootstrapFileExist(bootstrapLookupCmd[0]) { + var bootstrapCmdCandidates = []string{ + optBootstrap, + runtimeBootstrap, + } + + for i, bootstrapCandidate := range bootstrapCmdCandidates { + if IsBootstrapFileExist(bootstrapCandidate) { + bootstrapLookupCmd = []string{bootstrapCmdCandidates[i]} + break + } + } + } + + // handler is used later to set an env var for Lambda Image support + handler = "" + } else if len(args) > 1 { + + bootstrapLookupCmd = args[1:] + + if cwd, err := os.Getwd(); err == nil { + currentWorkingDir = cwd + } + + if len(args) > 2 { + // Assume last arg is the handler + handler = args[len(args)-1] + } + + log.Infof("exec '%s' (cwd=%s, handler=%s)", args[1], currentWorkingDir, handler) + + } else { + log.Panic("insufficient arguments: bootstrap not provided") + } + + err := unix.Access(bootstrapLookupCmd[0], unix.X_OK) + if err != nil { + log.Debug("Bootstrap not executable, setting permissions to 0755...", bootstrapLookupCmd[0]) + err = os.Chmod(bootstrapLookupCmd[0], 0755) + if err != nil { + log.Warn("Error setting bootstrap to 0755 permissions: ", bootstrapLookupCmd[0], err) + } + } + + return NewSimpleBootstrap(bootstrapLookupCmd, currentWorkingDir), handler +} diff --git a/cmd/localstack/filenotify/filenotify.go b/internal/filenotify/filenotify.go similarity index 100% rename from cmd/localstack/filenotify/filenotify.go rename to internal/filenotify/filenotify.go diff --git a/cmd/localstack/filenotify/fsnotify.go b/internal/filenotify/fsnotify.go similarity index 100% rename from cmd/localstack/filenotify/fsnotify.go rename to internal/filenotify/fsnotify.go diff --git a/cmd/localstack/filenotify/poller.go b/internal/filenotify/poller.go similarity index 100% rename from cmd/localstack/filenotify/poller.go rename to internal/filenotify/poller.go diff --git a/cmd/localstack/hotreloading.go b/internal/hotreloading/listener.go similarity index 81% rename from cmd/localstack/hotreloading.go rename to internal/hotreloading/listener.go index a0c4467..3a3083e 100644 --- a/cmd/localstack/hotreloading.go +++ b/internal/hotreloading/listener.go @@ -1,11 +1,14 @@ -package main +package hotreloading import ( - "github.com/fsnotify/fsnotify" - log "github.com/sirupsen/logrus" - "go.amzn.com/cmd/localstack/filenotify" "os" "time" + + "github.com/fsnotify/fsnotify" + "github.com/localstack/lambda-runtime-init/internal/filenotify" + "github.com/localstack/lambda-runtime-init/internal/utils" + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/rapidcore/standalone" ) type ChangeListener struct { @@ -50,7 +53,7 @@ func (c *ChangeListener) Watch() { if err != nil { log.Errorln("Error stating event file: ", event.Name, err) } else if stat.IsDir() { - subfolders := getSubFolders(event.Name) + subfolders := utils.GetSubFolders(event.Name) for _, folder := range subfolders { err = c.watcher.Add(folder) c.watchedFolders = append(c.watchedFolders, folder) @@ -62,7 +65,7 @@ func (c *ChangeListener) Watch() { // remove in case of remove / rename (rename within the folder will trigger a separate create event) } else if event.Has(fsnotify.Remove) || event.Has(fsnotify.Rename) { // remove all file watchers if it is in our folders list - toBeRemovedDirs, newWatchedFolders := getSubFoldersInList(event.Name, c.watchedFolders) + toBeRemovedDirs, newWatchedFolders := utils.GetSubFoldersInList(event.Name, c.watchedFolders) c.watchedFolders = newWatchedFolders for _, dir := range toBeRemovedDirs { err := c.watcher.Remove(dir) @@ -85,7 +88,7 @@ func (c *ChangeListener) Watch() { func (c *ChangeListener) AddTargetPaths(targetPaths []string) { // Add all target paths and subfolders for _, targetPath := range targetPaths { - subfolders := getSubFolders(targetPath) + subfolders := utils.GetSubFolders(targetPath) log.Infoln("Subfolders: ", subfolders) for _, target := range subfolders { err := c.watcher.Add(target) @@ -124,3 +127,18 @@ func (c *ChangeListener) debounceChannel() { func (c *ChangeListener) Close() error { return c.watcher.Close() } + +func ResetListener(changeChannel <-chan bool, server standalone.InteropServer) { + for { + _, more := <-changeChannel + if !more { + return + } + log.Println("Resetting environment...") + _, err := server.Reset("HotReload", 2000) + if err != nil { + log.Warnln("Error resetting server: ", err) + } + } + +} diff --git a/internal/hotreloading/reloader.go b/internal/hotreloading/reloader.go new file mode 100644 index 0000000..06a2f35 --- /dev/null +++ b/internal/hotreloading/reloader.go @@ -0,0 +1,31 @@ +package hotreloading + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/rapidcore/standalone" +) + +func RunHotReloadingListener(server standalone.InteropServer, targetPaths []string, ctx context.Context, fileWatcherStrategy string) { + if len(targetPaths) == 1 && targetPaths[0] == "" { + log.Debugln("Hot reloading disabled.") + return + } + defaultDebouncingDuration := 500 * time.Millisecond + log.Infoln("Hot reloading enabled, starting filewatcher.", targetPaths) + changeListener, err := NewChangeListener(defaultDebouncingDuration, fileWatcherStrategy) + if err != nil { + log.Errorln("Hot reloading disabled due to change listener error.", err) + return + } + defer changeListener.Close() + go changeListener.Start() + changeListener.AddTargetPaths(targetPaths) + go ResetListener(changeListener.debouncedChannel, server) + + <-ctx.Done() + log.Infoln("Closing down filewatcher.") + +} diff --git a/cmd/localstack/logs.go b/internal/logging/collector.go similarity index 92% rename from cmd/localstack/logs.go rename to internal/logging/collector.go index 0a9a5a7..7231510 100644 --- a/cmd/localstack/logs.go +++ b/internal/logging/collector.go @@ -1,4 +1,4 @@ -package main +package logging import ( "strings" @@ -37,7 +37,7 @@ func (lc *LogCollector) reset() { lc.RuntimeLogs = []string{} } -func (lc *LogCollector) getLogs() LogResponse { +func (lc *LogCollector) GetLogs() LogResponse { lc.mutex.Lock() defer lc.mutex.Unlock() response := LogResponse{ diff --git a/cmd/localstack/logs_egress_api.go b/internal/logging/egress_api.go similarity index 98% rename from cmd/localstack/logs_egress_api.go rename to internal/logging/egress_api.go index ec567d0..b63b92f 100644 --- a/cmd/localstack/logs_egress_api.go +++ b/internal/logging/egress_api.go @@ -1,4 +1,4 @@ -package main +package logging import ( "io" diff --git a/internal/server/handler.go b/internal/server/handler.go new file mode 100644 index 0000000..3437ec7 --- /dev/null +++ b/internal/server/handler.go @@ -0,0 +1,58 @@ +package server + +import ( + "os" + "strings" + "time" + + "github.com/localstack/lambda-runtime-init/internal/utils" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapidcore" + "go.amzn.com/lambda/rapidcore/env" +) + +func InitHandler(sandbox rapidcore.LambdaInvokeAPI, functionVersion string, timeout int64, bs interop.Bootstrap, accountId string) (time.Time, time.Time) { + additionalFunctionEnvironmentVariables := map[string]string{} + + // Add default Env Vars if they were not defined. This is a required otherwise 1p Python2.7, Python3.6, and + // possibly others pre runtime API runtimes will fail. This will be overwritten if they are defined on the system. + additionalFunctionEnvironmentVariables["AWS_LAMBDA_LOG_GROUP_NAME"] = "/aws/lambda/Functions" + additionalFunctionEnvironmentVariables["AWS_LAMBDA_LOG_STREAM_NAME"] = "$LATEST" + additionalFunctionEnvironmentVariables["AWS_LAMBDA_FUNCTION_VERSION"] = "$LATEST" + additionalFunctionEnvironmentVariables["AWS_LAMBDA_FUNCTION_MEMORY_SIZE"] = "3008" + additionalFunctionEnvironmentVariables["AWS_LAMBDA_FUNCTION_NAME"] = "test_function" + + // Forward Env Vars from the running system (container) to what the function can view. Without this, Env Vars will + // not be viewable when the function runs. + for _, env := range os.Environ() { + // Split the env into by the first "=". This will account for if the env var's value has a '=' in it + envVar := strings.SplitN(env, "=", 2) + additionalFunctionEnvironmentVariables[envVar[0]] = envVar[1] + } + + initStart := time.Now() + // pass to rapid + sandbox.Init(&interop.Init{ + Handler: utils.GetEnvWithDefault("AWS_LAMBDA_FUNCTION_HANDLER", os.Getenv("_HANDLER")), + AwsKey: os.Getenv("AWS_ACCESS_KEY_ID"), + AwsSecret: os.Getenv("AWS_SECRET_ACCESS_KEY"), + AwsSession: os.Getenv("AWS_SESSION_TOKEN"), + AccountID: accountId, + XRayDaemonAddress: utils.GetEnvWithDefault("AWS_XRAY_DAEMON_ADDRESS", "127.0.0.1:2000"), + FunctionName: utils.GetEnvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function"), + FunctionVersion: functionVersion, + + // TODO: Implement runtime management controls + // https://aws.amazon.com/blogs/compute/introducing-aws-lambda-runtime-management-controls/ + RuntimeInfo: interop.RuntimeInfo{ + ImageJSON: "{}", + Arn: "", + Version: ""}, + CustomerEnvironmentVariables: additionalFunctionEnvironmentVariables, + SandboxType: interop.SandboxClassic, + Bootstrap: bs, + EnvironmentVariables: env.NewEnvironment(), + }, timeout*1000) + initEnd := time.Now() + return initStart, initEnd +} diff --git a/cmd/localstack/custom_interop.go b/internal/server/interop.go similarity index 66% rename from cmd/localstack/custom_interop.go rename to internal/server/interop.go index 1941668..9dea9c7 100644 --- a/cmd/localstack/custom_interop.go +++ b/internal/server/interop.go @@ -1,4 +1,4 @@ -package main +package server // Original implementation: lambda/rapidcore/server.go includes Server struct with state // Server interface between Runtime API and this init: lambda/interop/model.go:Server @@ -8,29 +8,52 @@ import ( "encoding/json" "errors" "fmt" - "github.com/go-chi/chi" - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/core/statejson" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore" - "go.amzn.com/lambda/rapidcore/standalone" "io" "net/http" "strconv" "strings" "time" + + "github.com/go-chi/chi" + "github.com/localstack/lambda-runtime-init/internal/logging" + "github.com/localstack/lambda-runtime-init/internal/utils" + log "github.com/sirupsen/logrus" + "go.amzn.com/lambda/interop" + "go.amzn.com/lambda/rapidcore" + "go.amzn.com/lambda/rapidcore/standalone" ) +type LsOpts struct { + InteropPort string + RuntimeEndpoint string + RuntimeId string + AccountId string + InitTracingPort string + User string + CodeArchives string + HotReloadingPaths []string + FileWatcherStrategy string + ChmodPaths string + LocalstackIP string + InitLogLevel string + EdgePort string + EnableXRayTelemetry string + PostInvokeWaitMS string + MaxPayloadSize string +} + +// Create a type-alias to allow the rapidcore.Server to be easier embedded +// into the current implementation. +// TODO: Rename this from delegate. +type delegate = rapidcore.Server + type CustomInteropServer struct { - delegate *rapidcore.Server - localStackAdapter *LocalStackAdapter - port string - upstreamEndpoint string -} + // Embed the rapidcore.Server in the custom implementation. + *delegate -type LocalStackAdapter struct { - UpstreamEndpoint string - RuntimeId string + port string + upstreamEndpoint string + runtimeId string } type LocalStackStatus string @@ -40,15 +63,6 @@ const ( Error LocalStackStatus = "error" ) -func (l *LocalStackAdapter) SendStatus(status LocalStackStatus, payload []byte) error { - statusUrl := fmt.Sprintf("%s/status/%s/%s", l.UpstreamEndpoint, l.RuntimeId, status) - _, err := http.Post(statusUrl, "application/json", bytes.NewReader(payload)) - if err != nil { - return err - } - return nil -} - // The InvokeRequest is sent by LocalStack to trigger an invocation type InvokeRequest struct { InvokeId string `json:"invoke-id"` @@ -65,15 +79,12 @@ type ErrorResponse struct { StackTrace []string `json:"stackTrace,omitempty"` } -func NewCustomInteropServer(lsOpts *LsOpts, delegate interop.Server, logCollector *LogCollector) (server *CustomInteropServer) { +func NewCustomInteropServer(lsOpts *LsOpts, delegate interop.Server, logCollector *logging.LogCollector) (server *CustomInteropServer) { server = &CustomInteropServer{ delegate: delegate.(*rapidcore.Server), port: lsOpts.InteropPort, upstreamEndpoint: lsOpts.RuntimeEndpoint, - localStackAdapter: &LocalStackAdapter{ - UpstreamEndpoint: lsOpts.RuntimeEndpoint, - RuntimeId: lsOpts.RuntimeId, - }, + runtimeId: lsOpts.RuntimeId, } // TODO: extract this @@ -93,7 +104,7 @@ func NewCustomInteropServer(lsOpts *LsOpts, delegate interop.Server, logCollecto } invokeResp := &standalone.ResponseWriterProxy{} - functionVersion := GetEnvOrDie("AWS_LAMBDA_FUNCTION_VERSION") // default $LATEST + functionVersion := utils.GetEnvOrDie("AWS_LAMBDA_FUNCTION_VERSION") // default $LATEST _, _ = fmt.Fprintf(logCollector, "START RequestId: %s Version: %s\n", invokeR.InvokeId, functionVersion) invokeStart := time.Now() @@ -153,10 +164,10 @@ func NewCustomInteropServer(lsOpts *LsOpts, delegate interop.Server, logCollecto time.Sleep(time.Duration(waitMS) * time.Millisecond) } timeoutDuration := time.Duration(timeout) * time.Second - memorySize := GetEnvOrDie("AWS_LAMBDA_FUNCTION_MEMORY_SIZE") + memorySize := utils.GetEnvOrDie("AWS_LAMBDA_FUNCTION_MEMORY_SIZE") PrintEndReports(invokeR.InvokeId, "", memorySize, invokeStart, timeoutDuration, logCollector) - serializedLogs, err2 := json.Marshal(logCollector.getLogs()) + serializedLogs, err2 := json.Marshal(logCollector.GetLogs()) if err2 == nil { _, err2 = http.Post(server.upstreamEndpoint+"/invocations/"+invokeR.InvokeId+"/logs", "application/json", bytes.NewReader(serializedLogs)) // TODO: handle err @@ -197,81 +208,35 @@ func NewCustomInteropServer(lsOpts *LsOpts, delegate interop.Server, logCollecto return server } -func (c *CustomInteropServer) SendResponse(invokeID string, resp *interop.StreamableInvokeResponse) error { - log.Traceln("SendResponse called") - return c.delegate.SendResponse(invokeID, resp) -} - -func (c *CustomInteropServer) SendErrorResponse(invokeID string, resp *interop.ErrorInvokeResponse) error { - log.Traceln("SendErrorResponse called") - return c.delegate.SendErrorResponse(invokeID, resp) +func (c *CustomInteropServer) SendStatus(status LocalStackStatus, payload []byte) error { + statusUrl := fmt.Sprintf("%s/status/%s/%s", c.upstreamEndpoint, c.runtimeId, status) + _, err := http.Post(statusUrl, "application/json", bytes.NewReader(payload)) + if err != nil { + return err + } + return nil } // SendInitErrorResponse writes error response during init to a shared memory and sends GIRD FAULT. func (c *CustomInteropServer) SendInitErrorResponse(resp *interop.ErrorInvokeResponse) error { - log.Traceln("SendInitErrorResponse called") - if err := c.localStackAdapter.SendStatus(Error, resp.Payload); err != nil { + log.Debug("Forwarding SendInitErrorResponse status to LocalStack at %s.", c.upstreamEndpoint) + if err := c.SendStatus(Error, resp.Payload); err != nil { log.Fatalln("Failed to send init error to LocalStack " + err.Error() + ". Exiting.") } return c.delegate.SendInitErrorResponse(resp) } -func (c *CustomInteropServer) GetCurrentInvokeID() string { - log.Traceln("GetCurrentInvokeID called") - return c.delegate.GetCurrentInvokeID() -} - -func (c *CustomInteropServer) SendRuntimeReady() error { - log.Traceln("SendRuntimeReady called") - return c.delegate.SendRuntimeReady() -} - -func (c *CustomInteropServer) Init(i *interop.Init, invokeTimeoutMs int64) error { - log.Traceln("Init called") - return c.delegate.Init(i, invokeTimeoutMs) -} - -func (c *CustomInteropServer) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error { - log.Traceln("Invoke called") - return c.delegate.Invoke(responseWriter, invoke) -} - -func (c *CustomInteropServer) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error { - log.Traceln("FastInvoke called") - return c.delegate.FastInvoke(w, i, direct) -} - -func (c *CustomInteropServer) Reserve(id string, traceID, lambdaSegmentID string) (*rapidcore.ReserveResponse, error) { - log.Traceln("Reserve called") - return c.delegate.Reserve(id, traceID, lambdaSegmentID) -} - -func (c *CustomInteropServer) Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) { - log.Traceln("Reset called") - return c.delegate.Reset(reason, timeoutMs) -} - -func (c *CustomInteropServer) AwaitRelease() (*statejson.ReleaseResponse, error) { - log.Traceln("AwaitRelease called") - return c.delegate.AwaitRelease() -} - -func (c *CustomInteropServer) InternalState() (*statejson.InternalStateDescription, error) { - log.Traceln("InternalState called") - return c.delegate.InternalState() -} - -func (c *CustomInteropServer) CurrentToken() *interop.Token { - log.Traceln("CurrentToken called") - return c.delegate.CurrentToken() +func (c *CustomInteropServer) SendResponse(invokeID string, resp *interop.StreamableInvokeResponse) error { + // c.ToggleDirectInvoke() + return c.delegate.SendResponse(invokeID, resp) } -func (c *CustomInteropServer) SetSandboxContext(sbCtx interop.SandboxContext) { - log.Traceln("SetSandboxContext called") - c.delegate.SetSandboxContext(sbCtx) +func (c *CustomInteropServer) SendErrorResponse(invokeID string, resp *interop.ErrorInvokeResponse) error { + // c.ToggleDirectInvoke() + return c.delegate.SendErrorResponse(invokeID, resp) } -func (c *CustomInteropServer) SetInternalStateGetter(cb interop.InternalStateGetter) { - log.Traceln("SetInternalStateGetter called") - c.delegate.InternalStateGetter = cb +func (c *CustomInteropServer) ToggleDirectInvoke() { + ctx := c.delegate.GetInvokeContext() + ctx.Direct = true } diff --git a/internal/server/reporter.go b/internal/server/reporter.go new file mode 100644 index 0000000..7addbde --- /dev/null +++ b/internal/server/reporter.go @@ -0,0 +1,28 @@ +package server + +import ( + "fmt" + "io" + "math" + "time" +) + +// TODO: This should be creating a Report struct which, in turn, should be Stringified into the "REPORT RequestId: ..." syntax + +func PrintEndReports(invokeId string, initDuration string, memorySize string, invokeStart time.Time, timeoutDuration time.Duration, w io.Writer) { + // Calculate invoke duration + invokeDuration := math.Min(float64(time.Now().Sub(invokeStart).Nanoseconds()), + float64(timeoutDuration.Nanoseconds())) / float64(time.Millisecond) + + _, _ = fmt.Fprintln(w, "END RequestId: "+invokeId) + // We set the Max Memory Used and Memory Size to be the same (whatever it is set to) since there is + // not a clean way to get this information from rapidcore + _, _ = fmt.Fprintf(w, + "REPORT RequestId: %s\t"+ + initDuration+ + "Duration: %.2f ms\t"+ + "Billed Duration: %.f ms\t"+ + "Memory Size: %s MB\t"+ + "Max Memory Used: %s MB\t\n", + invokeId, invokeDuration, math.Ceil(invokeDuration), memorySize, memorySize) +} diff --git a/cmd/localstack/tracer.go b/internal/tracing/tracer.go similarity index 99% rename from cmd/localstack/tracer.go rename to internal/tracing/tracer.go index 8506a9a..37bba58 100644 --- a/cmd/localstack/tracer.go +++ b/internal/tracing/tracer.go @@ -1,8 +1,9 @@ -package main +package tracing import ( "context" "encoding/json" + "go.amzn.com/lambda/appctx" "go.amzn.com/lambda/interop" ) diff --git a/cmd/localstack/codearchive.go b/internal/utils/codearchive.go similarity index 99% rename from cmd/localstack/codearchive.go rename to internal/utils/codearchive.go index ceb6dee..a056874 100644 --- a/cmd/localstack/codearchive.go +++ b/internal/utils/codearchive.go @@ -1,14 +1,15 @@ -package main +package utils import ( "archive/zip" "encoding/json" - log "github.com/sirupsen/logrus" "io" "net/http" "os" "path" "path/filepath" + + log "github.com/sirupsen/logrus" ) type ArchiveDownload struct { diff --git a/internal/utils/env.go b/internal/utils/env.go new file mode 100644 index 0000000..10049c3 --- /dev/null +++ b/internal/utils/env.go @@ -0,0 +1,22 @@ +package utils + +import "os" + +// GetEnvWithDefault returns the value of the environment variable key or the defaultValue if key is not set +func GetEnvWithDefault(key string, defaultValue string) string { + envValue, ok := os.LookupEnv(key) + if !ok { + return defaultValue + } + + return envValue +} + +// GetEnvOrDie returns the value of the environment variable key or panics if the key is not set +func GetEnvOrDie(env string) string { + result, found := os.LookupEnv(env) + if !found { + panic("Could not find environment variable for: " + env) + } + return result +} diff --git a/cmd/localstack/file_utils.go b/internal/utils/file.go similarity index 69% rename from cmd/localstack/file_utils.go rename to internal/utils/file.go index 0de9519..e110cad 100644 --- a/cmd/localstack/file_utils.go +++ b/internal/utils/file.go @@ -1,12 +1,15 @@ -package main +package utils import ( "encoding/json" - log "github.com/sirupsen/logrus" "io" + "io/fs" "os" "path/filepath" "strconv" + "strings" + + log "github.com/sirupsen/logrus" ) type Chmod struct { @@ -66,3 +69,34 @@ func IsDirEmpty(name string) (bool, error) { } return false, err // Either not empty or error, suits both cases } + +func IsFileExist(filePath string) bool { + file, err := os.Stat(filePath) + return !os.IsNotExist(err) && !file.IsDir() +} + +func GetSubFolders(dirPath string) []string { + var subfolders []string + err := filepath.WalkDir(dirPath, func(path string, d fs.DirEntry, err error) error { + if err == nil && d.IsDir() { + subfolders = append(subfolders, path) + } + return err + }) + if err != nil { + log.Errorln("Error listing directory contents: ", err) + return subfolders + } + return subfolders +} + +func GetSubFoldersInList(prefix string, pathList []string) (oldFolders []string, newFolders []string) { + for _, pathItem := range pathList { + if strings.HasPrefix(pathItem, prefix) { + oldFolders = append(oldFolders, pathItem) + } else { + newFolders = append(newFolders, pathItem) + } + } + return +} diff --git a/cmd/localstack/user.go b/internal/utils/user.go similarity index 99% rename from cmd/localstack/user.go rename to internal/utils/user.go index 3e6da42..ed4d56a 100644 --- a/cmd/localstack/user.go +++ b/internal/utils/user.go @@ -1,14 +1,15 @@ // User utilities to create UNIX users and drop root privileges -package main +package utils import ( "fmt" - log "github.com/sirupsen/logrus" "os" "os/user" "strconv" "strings" "syscall" + + log "github.com/sirupsen/logrus" ) // AddUser adds a UNIX user (e.g., sbx_user1051) to the passwd and shadow files if not already present diff --git a/lambda/agents/agent.go b/lambda/agents/agent.go deleted file mode 100644 index cabe1fa..0000000 --- a/lambda/agents/agent.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package agents - -import ( - "os" - "path" - "path/filepath" - - log "github.com/sirupsen/logrus" -) - -// ListExternalAgentPaths return a list of external agents found in a given directory -func ListExternalAgentPaths(dir string, root string) []string { - var agentPaths []string - if !isCanonical(dir) || !isCanonical(root) { - log.Warningf("Agents base paths are not absolute and in canonical form: %s, %s", dir, root) - return agentPaths - } - fullDir := path.Join(root, dir) - files, err := os.ReadDir(fullDir) - - if err != nil { - if os.IsNotExist(err) { - log.Infof("The extension's directory %q does not exist, assuming no extensions to be loaded.", fullDir) - } else { - // TODO - Should this return an error rather than ignore failing to load? - log.WithError(err).Error("Cannot list external agents") - } - - return agentPaths - } - - for _, file := range files { - if !file.IsDir() { - // The returned path is absolute wrt to `root`. This allows - // to exec the agents in their own mount namespace - p := path.Join("/", dir, file.Name()) - agentPaths = append(agentPaths, p) - } - } - return agentPaths -} - -func isCanonical(path string) bool { - absPath, err := filepath.Abs(path) - return err == nil && absPath == path -} diff --git a/lambda/agents/agent_test.go b/lambda/agents/agent_test.go deleted file mode 100644 index e6732ff..0000000 --- a/lambda/agents/agent_test.go +++ /dev/null @@ -1,233 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package agents - -import ( - "os" - "path" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// - Test utilities - -// a small struct to hold file metadata -type fileInfo struct { - name string - mode os.FileMode - size int64 - target string // for symlinks -} - -func mkFile(name string, size int64, perm os.FileMode) fileInfo { - return fileInfo{ - name: name, - mode: perm, - size: size, - target: "", - } -} - -func mkDir(name string, perm os.FileMode) fileInfo { - return fileInfo{ - name: name, - mode: perm | os.ModeDir, - size: 0, - target: "", - } -} - -func mkLink(name, target string) fileInfo { - return fileInfo{ - name: name, - mode: os.ModeSymlink, - size: 0, - target: target, - } -} - -// populate a directory with a list of files and directories -func createFileTree(root string, fs []fileInfo) error { - - for _, info := range fs { - filename := info.name - dir := path.Join(root, path.Dir(filename)) - name := path.Base(filename) - err := os.MkdirAll(dir, 0775) - if err != nil && !os.IsExist(err) { - return err - } - if os.ModeDir == info.mode&os.ModeDir { - err := os.Mkdir(path.Join(dir, name), info.mode&os.ModePerm) - if err != nil { - return err - } - } else if os.ModeSymlink == info.mode&os.ModeSymlink { - target := path.Join(root, info.target) - _, err = os.Stat(target) - if err != nil { - return err - } - err := os.Symlink(target, path.Join(dir, name)) - if err != nil { - return err - } - } else { - file, err := os.OpenFile(path.Join(dir, name), os.O_RDWR|os.O_CREATE, info.mode&os.ModePerm) - if err != nil { - return err - } - file.Truncate(info.size) - file.Close() - } - } - - return nil -} - -// - Actual tests - -// If the agents folder is empty it is not an error -func TestBaseEmpty(t *testing.T) { - - assert := assert.New(t) - - fs := []fileInfo{ - mkDir("/opt/extensions", 0777), - } - - tmpDir, err := os.MkdirTemp("", "ext-") - require.NoError(t, err) - - createFileTree(tmpDir, fs) - defer os.RemoveAll(tmpDir) - - agents := ListExternalAgentPaths(path.Join(tmpDir, "/opt/extensions"), "/") - assert.Equal(0, len(agents)) -} - -// Test that non-existant /opt/extensions is treated as if no agents were found -func TestBaseNotExist(t *testing.T) { - - assert := assert.New(t) - - agents := ListExternalAgentPaths("/path/which/does/not/exist", "/") - assert.Equal(0, len(agents)) -} - -// Test that non-existant root dir is teaded as if no agents were found -func TestChrootNotExist(t *testing.T) { - - assert := assert.New(t) - - agents := ListExternalAgentPaths("/bin", "/does/not/exist") - assert.Equal(0, len(agents)) -} - -// Test that non-directory /opt/extensions is treated as if no agents were found -func TestBaseNotDir(t *testing.T) { - - assert := assert.New(t) - - fs := []fileInfo{ - mkFile("/opt/extensions", 1, 0777), - } - tmpDir, err := os.MkdirTemp("", "ext-") - require.NoError(t, err) - - createFileTree(tmpDir, fs) - defer os.RemoveAll(tmpDir) - - path := path.Join(tmpDir, "/opt/extensions") - agents := ListExternalAgentPaths(path, "/") - assert.Equal(0, len(agents)) -} - -// Test our ability to find agent bootstraps in the FS and return them sorted. -// Even if not all files are valid as executable agents, -// ListExternalAgentPaths() is expected to return all of them. -func TestFindAgentMixed(t *testing.T) { - - assert := assert.New(t) - - listed := []fileInfo{ - mkFile("/opt/extensions/ok2", 1, 0777), // this is ok - mkFile("/opt/extensions/ok1", 1, 0777), // this is ok as well - mkFile("/opt/extensions/not_exec", 1, 0666), // this is not executable - mkFile("/opt/extensions/not_read", 1, 0333), // this is not readable - mkFile("/opt/extensions/empty_file", 0, 0777), // this is empty - mkLink("/opt/extensions/link", "/opt/extensions/ok1"), // symlink must be ignored - } - - unlisted := []fileInfo{ - mkDir("/opt/extensions/empty_dir", 0777), // this is an empty directory - mkDir("/opt/extensions/nonempty_dir", 0777), // subdirs should not be listed - mkFile("/opt/extensions/nonempty_dir/notok", 1, 0777), // files in subdirs should not be listed - } - - fs := append([]fileInfo{}, listed...) - fs = append(fs, unlisted...) - - tmpDir, err := os.MkdirTemp("", "ext-") - require.NoError(t, err) - - createFileTree(tmpDir, fs) - defer os.RemoveAll(tmpDir) - - path := path.Join(tmpDir, "/opt/extensions") - agentPaths := ListExternalAgentPaths(path, "/") - assert.Equal(len(listed), len(agentPaths)) - last := "" - for index := range listed { - if len(last) > 0 { - assert.GreaterOrEqual(agentPaths[index], last) - } - last = agentPaths[index] - } -} - -// Test our ability to find agent bootstraps in the FS and return them sorted, -// when using a different mount namespace root for the extensiosn domain. -// Even if not all files are valid as executable agents, -// ListExternalAgentPaths() is expected to return all of them. -func TestFindAgentMixedInChroot(t *testing.T) { - - assert := assert.New(t) - - listed := []fileInfo{ - mkFile("/opt/extensions/ok2", 1, 0777), // this is ok - mkFile("/opt/extensions/ok1", 1, 0777), // this is ok as well - mkFile("/opt/extensions/not_exec", 1, 0666), // this is not executable - mkFile("/opt/extensions/not_read", 1, 0333), // this is not readable - mkFile("/opt/extensions/empty_file", 0, 0777), // this is empty - mkLink("/opt/extensions/link", "/opt/extensions/ok1"), // symlink must be ignored - } - - unlisted := []fileInfo{ - mkDir("/opt/extensions/empty_dir", 0777), // this is an empty directory - mkDir("/opt/extensions/nonempty_dir", 0777), // subdirs should not be listed - mkFile("/opt/extensions/nonempty_dir/notok", 1, 0777), // files in subdirs should not be listed - } - - fs := append([]fileInfo{}, listed...) - fs = append(fs, unlisted...) - - rootDir, err := os.MkdirTemp("", "rootfs") - require.NoError(t, err) - - createFileTree(rootDir, fs) - defer os.RemoveAll(rootDir) - - agentPaths := ListExternalAgentPaths("/opt/extensions", rootDir) - assert.Equal(len(listed), len(agentPaths)) - last := "" - for index := range listed { - if len(last) > 0 { - assert.GreaterOrEqual(agentPaths[index], last) - } - last = agentPaths[index] - } -} diff --git a/lambda/appctx/appctx.go b/lambda/appctx/appctx.go deleted file mode 100644 index 931a2ec..0000000 --- a/lambda/appctx/appctx.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package appctx - -import ( - "sync" -) - -// A Key type is used as a key for storing values in the application context. -type Key int - -type InitType int - -const ( - // AppCtxInvokeErrorTraceDataKey is used for storing deferred invoke error cause header value. - // Only used by xray. TODO refactor xray interface so it doesn't use appctx - AppCtxInvokeErrorTraceDataKey Key = iota - - // AppCtxRuntimeReleaseKey is used for storing runtime release information (parsed from User_Agent Http header string). - AppCtxRuntimeReleaseKey - - // AppCtxInteropServerKey is used to store a reference to the interop server. - AppCtxInteropServerKey - - // AppCtxResponseSenderKey is used to store a reference to the response sender - AppCtxResponseSenderKey - - // AppCtxFirstFatalErrorKey is used to store first unrecoverable error message encountered to propagate it to slicer with DONE(errortype) or DONEFAIL(errortype) - AppCtxFirstFatalErrorKey - - // AppCtxInitType is used to store the init type (init caching or plain INIT) - AppCtxInitType - - // AppCtxSandbox type is used to store the sandbox type (SandboxClassic or SandboxPreWarmed) - AppCtxSandboxType -) - -// Possible values for AppCtxInitType key -const ( - Init InitType = iota - InitCaching -) - -// ApplicationContext is an application scope context. -type ApplicationContext interface { - Store(key Key, value interface{}) - Load(key Key) (value interface{}, ok bool) - Delete(key Key) - GetOrDefault(key Key, defaultValue interface{}) interface{} - StoreIfNotExists(key Key, value interface{}) interface{} -} - -type applicationContext struct { - mux *sync.Mutex - m map[Key]interface{} -} - -func (appCtx *applicationContext) Store(key Key, value interface{}) { - appCtx.mux.Lock() - defer appCtx.mux.Unlock() - appCtx.m[key] = value -} - -func (appCtx *applicationContext) StoreIfNotExists(key Key, value interface{}) interface{} { - appCtx.mux.Lock() - defer appCtx.mux.Unlock() - existing, found := appCtx.m[key] - if found { - return existing - } - appCtx.m[key] = value - return nil -} - -func (appCtx *applicationContext) Load(key Key) (value interface{}, ok bool) { - appCtx.mux.Lock() - defer appCtx.mux.Unlock() - value, ok = appCtx.m[key] - return -} - -func (appCtx *applicationContext) Delete(key Key) { - appCtx.mux.Lock() - defer appCtx.mux.Unlock() - delete(appCtx.m, key) -} - -func (appCtx *applicationContext) GetOrDefault(key Key, defaultValue interface{}) interface{} { - if value, ok := appCtx.Load(key); ok { - return value - } - return defaultValue -} - -// NewApplicationContext returns a new instance of application context. -func NewApplicationContext() ApplicationContext { - return &applicationContext{ - mux: &sync.Mutex{}, - m: make(map[Key]interface{}), - } -} diff --git a/lambda/appctx/appctxutil.go b/lambda/appctx/appctxutil.go deleted file mode 100644 index cd6e6d3..0000000 --- a/lambda/appctx/appctxutil.go +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package appctx - -import ( - "context" - "net/http" - "strings" - - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - - log "github.com/sirupsen/logrus" -) - -// This package contains a set of utility methods for accessing application -// context and application context data. - -// A ReqCtxKey type is used as a key for storing values in the request context. -type ReqCtxKey int - -// ReqCtxApplicationContextKey is used for injecting application -// context object into request context. -const ReqCtxApplicationContextKey ReqCtxKey = iota - -// MaxRuntimeReleaseLength Max length for user agent string. -const MaxRuntimeReleaseLength = 128 - -// FromRequest retrieves application context from the request context. -func FromRequest(request *http.Request) ApplicationContext { - return request.Context().Value(ReqCtxApplicationContextKey).(ApplicationContext) -} - -// RequestWithAppCtx places application context into request context. -func RequestWithAppCtx(request *http.Request, appCtx ApplicationContext) *http.Request { - return request.WithContext(context.WithValue(request.Context(), ReqCtxApplicationContextKey, appCtx)) -} - -// GetRuntimeRelease returns runtime_release str extracted from app context. -func GetRuntimeRelease(appCtx ApplicationContext) string { - return appCtx.GetOrDefault(AppCtxRuntimeReleaseKey, "").(string) -} - -// GetUserAgentFromRequest Returns the first token -seperated by a space- -// from request header 'User-Agent'. -func GetUserAgentFromRequest(request *http.Request) string { - runtimeRelease := "" - userAgent := request.Header.Get("User-Agent") - // Split around spaces and use only the first token. - if fields := strings.Fields(userAgent); len(fields) > 0 && len(fields[0]) > 0 { - runtimeRelease = fields[0] - } - return runtimeRelease -} - -// CreateRuntimeReleaseFromRequest Gets runtime features from request header -// 'Lambda-Runtime-Features', and append it to the given runtime release. -func CreateRuntimeReleaseFromRequest(request *http.Request, runtimeRelease string) string { - lambdaRuntimeFeaturesHeader := request.Header.Get("Lambda-Runtime-Features") - - // "(", ")" are not valid token characters, and potentially could invalidate runtime_release - lambdaRuntimeFeaturesHeader = strings.ReplaceAll(lambdaRuntimeFeaturesHeader, "(", "") - lambdaRuntimeFeaturesHeader = strings.ReplaceAll(lambdaRuntimeFeaturesHeader, ")", "") - - numberOfAppendedFeatures := 0 - // Available length is a maximum length available for runtime features (including delimiters). From maximal runtime - // release length we subtract what we already have plus 3 additional bytes for a space and a pair of brackets for - // list of runtime features that is added later. - runtimeReleaseLength := len(runtimeRelease) - if runtimeReleaseLength == 0 { - runtimeReleaseLength = len("Unknown") - } - availableLength := MaxRuntimeReleaseLength - runtimeReleaseLength - 3 - var lambdaRuntimeFeatures []string - - for _, feature := range strings.Fields(lambdaRuntimeFeaturesHeader) { - featureLength := len(feature) - // If featureLength <= availableLength - numberOfAppendedFeatures - // (where numberOfAppendedFeatures is equal to number of delimiters needed). - if featureLength <= availableLength-numberOfAppendedFeatures { - availableLength -= featureLength - lambdaRuntimeFeatures = append(lambdaRuntimeFeatures, feature) - numberOfAppendedFeatures++ - } - } - // Append valid features to runtime release. - if len(lambdaRuntimeFeatures) > 0 { - if runtimeRelease == "" { - runtimeRelease = "Unknown" - } - runtimeRelease += " (" + strings.Join(lambdaRuntimeFeatures, " ") + ")" - } - - return runtimeRelease -} - -// UpdateAppCtxWithRuntimeRelease extracts runtime release info from user agent & lambda runtime features -// headers and update it into appCtx. -// Sample UA: -// Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:47.0) Gecko/20100101 Firefox/47.0 -func UpdateAppCtxWithRuntimeRelease(request *http.Request, appCtx ApplicationContext) bool { - // If appCtx has runtime release value already, just append the runtime features. - if appCtxRuntimeRelease := GetRuntimeRelease(appCtx); len(appCtxRuntimeRelease) > 0 { - // if the runtime features are not appended before append them, otherwise ignore - if runtimeReleaseWithFeatures := CreateRuntimeReleaseFromRequest(request, appCtxRuntimeRelease); len(runtimeReleaseWithFeatures) > len(appCtxRuntimeRelease) && - appCtxRuntimeRelease[len(appCtxRuntimeRelease)-1] != ')' { - appCtx.Store(AppCtxRuntimeReleaseKey, runtimeReleaseWithFeatures) - return true - } - return false - } - // If appCtx doesn't have runtime release value, update it with user agent and runtime features. - if runtimeReleaseWithFeatures := CreateRuntimeReleaseFromRequest(request, - GetUserAgentFromRequest(request)); runtimeReleaseWithFeatures != "" { - appCtx.Store(AppCtxRuntimeReleaseKey, runtimeReleaseWithFeatures) - return true - } - return false -} - -// StoreInvokeErrorTraceData stores invocation error x-ray cause header in the applicaton context. -func StoreInvokeErrorTraceData(appCtx ApplicationContext, invokeError *interop.InvokeErrorTraceData) { - appCtx.Store(AppCtxInvokeErrorTraceDataKey, invokeError) -} - -// LoadInvokeErrorTraceData retrieves invocation error x-ray cause header from the application context. -func LoadInvokeErrorTraceData(appCtx ApplicationContext) *interop.InvokeErrorTraceData { - v, ok := appCtx.Load(AppCtxInvokeErrorTraceDataKey) - if ok { - return v.(*interop.InvokeErrorTraceData) - } - return nil -} - -// StoreInteropServer stores a reference to the interop server. -func StoreInteropServer(appCtx ApplicationContext, server interop.Server) { - appCtx.Store(AppCtxInteropServerKey, server) -} - -// LoadInteropServer retrieves the interop server. -func LoadInteropServer(appCtx ApplicationContext) interop.Server { - v, ok := appCtx.Load(AppCtxInteropServerKey) - if ok { - return v.(interop.Server) - } - return nil -} - -// StoreResponseSender stores a reference to the response sender -func StoreResponseSender(appCtx ApplicationContext, server interop.InvokeResponseSender) { - appCtx.Store(AppCtxResponseSenderKey, server) -} - -// LoadResponseSender retrieves the response sender -func LoadResponseSender(appCtx ApplicationContext) interop.InvokeResponseSender { - v, ok := appCtx.Load(AppCtxResponseSenderKey) - if ok { - return v.(interop.InvokeResponseSender) - } - return nil -} - -// StoreFirstFatalError stores unrecoverable error code in appctx once. This error is considered to be the rootcause of failure -func StoreFirstFatalError(appCtx ApplicationContext, err fatalerror.ErrorType) { - if existing := appCtx.StoreIfNotExists(AppCtxFirstFatalErrorKey, err); existing != nil { - log.Warnf("Omitting fatal error %s: %s already stored", err, existing.(fatalerror.ErrorType)) - return - } - - log.Warnf("First fatal error stored in appctx: %s", err) -} - -// LoadFirstFatalError returns stored error if found -func LoadFirstFatalError(appCtx ApplicationContext) (errorType fatalerror.ErrorType, found bool) { - v, found := appCtx.Load(AppCtxFirstFatalErrorKey) - if !found { - return "", false - } - return v.(fatalerror.ErrorType), true -} - -func StoreInitType(appCtx ApplicationContext, initCachingEnabled bool) { - if initCachingEnabled { - appCtx.Store(AppCtxInitType, InitCaching) - } else { - appCtx.Store(AppCtxInitType, Init) - } -} - -// Default Init Type is Init unless it's explicitly stored in ApplicationContext -func LoadInitType(appCtx ApplicationContext) InitType { - return appCtx.GetOrDefault(AppCtxInitType, Init).(InitType) -} - -func StoreSandboxType(appCtx ApplicationContext, sandboxType interop.SandboxType) { - appCtx.Store(AppCtxSandboxType, sandboxType) -} - -func LoadSandboxType(appCtx ApplicationContext) interop.SandboxType { - return appCtx.GetOrDefault(AppCtxSandboxType, interop.SandboxClassic).(interop.SandboxType) -} diff --git a/lambda/appctx/appctxutil_test.go b/lambda/appctx/appctxutil_test.go deleted file mode 100644 index b6df9aa..0000000 --- a/lambda/appctx/appctxutil_test.go +++ /dev/null @@ -1,227 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package appctx - -import ( - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.amzn.com/lambda/fatalerror" - - "go.amzn.com/lambda/interop" -) - -func runTestRequestWithUserAgent(t *testing.T, userAgent string, expectedRuntimeRelease string) { - // Simple User_Agent passed. - // GIVEN - req := httptest.NewRequest("", "/", nil) - req.Header.Set("User-Agent", userAgent) - request := RequestWithAppCtx(req, NewApplicationContext()) - appCtx := request.Context().Value(ReqCtxApplicationContextKey).(ApplicationContext) - - // DO - ok := UpdateAppCtxWithRuntimeRelease(request, appCtx) - - //ASSERT - assert.True(t, ok) - ctxRuntimeRelease, ok := appCtx.Load(AppCtxRuntimeReleaseKey) - assert.True(t, ok) - assert.Equal(t, expectedRuntimeRelease, ctxRuntimeRelease, "failed to extract runtime_release token") -} - -func TestCreateRuntimeReleaseFromRequest(t *testing.T) { - tests := map[string]struct { - userAgentHeader string - lambdaRuntimeFeaturesHeader string - expectedRuntimeRelease string - }{ - "No User-Agent header": { - userAgentHeader: "", - lambdaRuntimeFeaturesHeader: "httpcl/2.0 execwr", - expectedRuntimeRelease: "Unknown (httpcl/2.0 execwr)", - }, - "No Lambda-Runtime-Features header": { - userAgentHeader: "Node.js/14.16.0", - lambdaRuntimeFeaturesHeader: "", - expectedRuntimeRelease: "Node.js/14.16.0", - }, - "Lambda-Runtime-Features header with additional spaces": { - userAgentHeader: "Node.js/14.16.0", - lambdaRuntimeFeaturesHeader: "httpcl/2.0 execwr", - expectedRuntimeRelease: "Node.js/14.16.0 (httpcl/2.0 execwr)", - }, - "Lambda-Runtime-Features header with special characters": { - userAgentHeader: "Node.js/14.16.0", - lambdaRuntimeFeaturesHeader: "httpcl/2.0@execwr-1 abcd?efg nodewr/(4.33)) nodewr/4.3", - expectedRuntimeRelease: "Node.js/14.16.0 (httpcl/2.0@execwr-1 abcd?efg nodewr/4.33 nodewr/4.3)", - }, - "Lambda-Runtime-Features header with long Lambda-Runtime-Features header": { - userAgentHeader: "Node.js/14.16.0", - lambdaRuntimeFeaturesHeader: strings.Repeat("abcdef ", MaxRuntimeReleaseLength/7), - expectedRuntimeRelease: "Node.js/14.16.0 (" + strings.Repeat("abcdef ", (MaxRuntimeReleaseLength-18-6)/7) + "abcdef)", - }, - "Lambda-Runtime-Features header with long Lambda-Runtime-Features header with UTF-8 characters": { - userAgentHeader: "Node.js/14.16.0", - lambdaRuntimeFeaturesHeader: strings.Repeat("我爱亚马逊 ", MaxRuntimeReleaseLength/16), - expectedRuntimeRelease: "Node.js/14.16.0 (" + strings.Repeat("我爱亚马逊 ", (MaxRuntimeReleaseLength-18-15)/16) + "我爱亚马逊)", - }, - } - - for _, tc := range tests { - req := httptest.NewRequest("", "/", nil) - if tc.userAgentHeader != "" { - req.Header.Set("User-Agent", tc.userAgentHeader) - } - if tc.lambdaRuntimeFeaturesHeader != "" { - req.Header.Set("Lambda-Runtime-Features", tc.lambdaRuntimeFeaturesHeader) - } - appCtx := NewApplicationContext() - request := RequestWithAppCtx(req, appCtx) - - UpdateAppCtxWithRuntimeRelease(request, appCtx) - runtimeRelease := GetRuntimeRelease(appCtx) - - assert.LessOrEqual(t, len(runtimeRelease), MaxRuntimeReleaseLength) - assert.Equal(t, tc.expectedRuntimeRelease, runtimeRelease) - } -} - -func TestUpdateAppCtxWithRuntimeRelease(t *testing.T) { - type pair struct { - in, wanted string - } - pairs := []pair{ - {"Mozilla/5.0", "Mozilla/5.0"}, - {"Mozilla/6.0 (Windows NT 6.1; Win64; x64; rv:47.0) Gecko/20100101 Firefox/47.0", "Mozilla/6.0"}, - } - for _, p := range pairs { - runTestRequestWithUserAgent(t, p.in, p.wanted) - } -} - -func TestUpdateAppCtxWithRuntimeReleaseWithoutUserAgent(t *testing.T) { - // GIVEN - // No User_Agent passed. - request := RequestWithAppCtx(httptest.NewRequest("", "/", nil), NewApplicationContext()) - appCtx := request.Context().Value(ReqCtxApplicationContextKey).(ApplicationContext) - - // DO - ok := UpdateAppCtxWithRuntimeRelease(request, appCtx) - - // ASSERT - assert.False(t, ok) - _, ok = appCtx.Load(AppCtxRuntimeReleaseKey) - assert.False(t, ok) -} - -func TestUpdateAppCtxWithRuntimeReleaseWithBlankUserAgent(t *testing.T) { - // GIVEN - req := httptest.NewRequest("", "/", nil) - req.Header.Set("User-Agent", " ") - request := RequestWithAppCtx(req, NewApplicationContext()) - appCtx := request.Context().Value(ReqCtxApplicationContextKey).(ApplicationContext) - - // DO - ok := UpdateAppCtxWithRuntimeRelease(request, appCtx) - - // ASSERT - assert.False(t, ok) - _, ok = appCtx.Load(AppCtxRuntimeReleaseKey) - assert.False(t, ok) -} - -func TestUpdateAppCtxWithRuntimeReleaseWithLambdaRuntimeFeatures(t *testing.T) { - // GIVEN - // Simple LambdaRuntimeFeatures passed. - req := httptest.NewRequest("", "/", nil) - req.Header.Set("User-Agent", "Node.js/14.16.0") - req.Header.Set("Lambda-Runtime-Features", "httpcl/2.0 execwr nodewr/4.3") - request := RequestWithAppCtx(req, NewApplicationContext()) - appCtx := request.Context().Value(ReqCtxApplicationContextKey).(ApplicationContext) - - // DO - ok := UpdateAppCtxWithRuntimeRelease(request, appCtx) - - //ASSERT - assert.True(t, ok, "runtime_release updated based only on User-Agent and valid features") - ctxRuntimeRelease, ok := appCtx.Load(AppCtxRuntimeReleaseKey) - assert.True(t, ok) - assert.Equal(t, "Node.js/14.16.0 (httpcl/2.0 execwr nodewr/4.3)", ctxRuntimeRelease) -} - -// Test that RAPID allows updating runtime_release only once -func TestUpdateAppCtxWithRuntimeReleaseMultipleTimes(t *testing.T) { - // GIVEN - firstValue := "Value1" - secondValue := "Value2" - - req := httptest.NewRequest("", "/", nil) - req.Header.Set("User-Agent", firstValue) - request := RequestWithAppCtx(req, NewApplicationContext()) - appCtx := request.Context().Value(ReqCtxApplicationContextKey).(ApplicationContext) - - // DO - ok := UpdateAppCtxWithRuntimeRelease(request, appCtx) - - // ASSERT - assert.True(t, ok) - ctxRuntimeRelease, ok := appCtx.Load(AppCtxRuntimeReleaseKey) - assert.True(t, ok) - assert.Equal(t, firstValue, ctxRuntimeRelease) - - // GIVEN - req.Header.Set("User-Agent", secondValue) - - // DO - ok = UpdateAppCtxWithRuntimeRelease(request, appCtx) - - // ASSERT - assert.False(t, ok, "failed to prevent second update of runtime_release") - ctxRuntimeRelease, ok = appCtx.Load(AppCtxRuntimeReleaseKey) - assert.True(t, ok) - assert.Equal(t, firstValue, ctxRuntimeRelease, "failed to prevent second update of runtime_release") -} - -func TestFirstFatalError(t *testing.T) { - appCtx := NewApplicationContext() - - _, found := LoadFirstFatalError(appCtx) - require.False(t, found) - - StoreFirstFatalError(appCtx, fatalerror.AgentCrash) - v, found := LoadFirstFatalError(appCtx) - require.True(t, found) - require.Equal(t, fatalerror.AgentCrash, v) - - StoreFirstFatalError(appCtx, fatalerror.AgentExitError) - v, found = LoadFirstFatalError(appCtx) - require.True(t, found) - require.Equal(t, fatalerror.AgentCrash, v) -} - -func TestStoreLoadInitType(t *testing.T) { - appCtx := NewApplicationContext() - - initType := LoadInitType(appCtx) - assert.Equal(t, Init, initType) - - StoreInitType(appCtx, true) - initType = LoadInitType(appCtx) - assert.Equal(t, InitCaching, initType) -} - -func TestStoreLoadSandboxType(t *testing.T) { - appCtx := NewApplicationContext() - - sandboxType := LoadSandboxType(appCtx) - assert.Equal(t, interop.SandboxClassic, sandboxType) - - StoreSandboxType(appCtx, interop.SandboxPreWarmed) - - sandboxType = LoadSandboxType(appCtx) - assert.Equal(t, interop.SandboxPreWarmed, sandboxType) -} diff --git a/lambda/core/agent_state_names.go b/lambda/core/agent_state_names.go deleted file mode 100644 index a5e61f3..0000000 --- a/lambda/core/agent_state_names.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -// String values of possibles agent states -const ( - AgentStartedStateName = "Started" - AgentRegisteredStateName = "Registered" - AgentReadyStateName = "Ready" - AgentRunningStateName = "Running" - AgentInitErrorStateName = "InitError" - AgentExitErrorStateName = "ExitError" - AgentShutdownFailedStateName = "ShutdownFailed" - AgentExitedStateName = "Exited" - AgentLaunchErrorName = "LaunchError" -) diff --git a/lambda/core/agentsmap.go b/lambda/core/agentsmap.go deleted file mode 100644 index 4ade8a9..0000000 --- a/lambda/core/agentsmap.go +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "errors" - "github.com/google/uuid" -) - -// ErrAgentNameCollision means that agent with the same name already exists in AgentsMap -var ErrAgentNameCollision = errors.New("ErrAgentNameCollision") - -// ErrAgentIDCollision means that agent with the same ID already exists in AgentsMap -var ErrAgentIDCollision = errors.New("ErrAgentIDCollision") - -// ExternalAgentsMap stores Agents indexed by Name and ID -type ExternalAgentsMap struct { - byName map[string]*ExternalAgent - byID map[string]*ExternalAgent -} - -// NewExternalAgentsMap creates empty ExternalAgentsMap -func NewExternalAgentsMap() ExternalAgentsMap { - return ExternalAgentsMap{ - byName: make(map[string]*ExternalAgent), - byID: make(map[string]*ExternalAgent), - } -} - -// Insert places agent into ExternalAgentsMap. Error is returned if agent with this ID or name already exists -func (m *ExternalAgentsMap) Insert(a *ExternalAgent) error { - if _, nameCollision := m.FindByName(a.Name); nameCollision { - return ErrAgentNameCollision - } - - if _, idCollision := m.FindByID(a.ID); idCollision { - return ErrAgentIDCollision - } - - m.byName[a.Name] = a - m.byID[a.ID.String()] = a - - return nil -} - -// FindByName finds agent by name -func (m *ExternalAgentsMap) FindByName(name string) (agent *ExternalAgent, found bool) { - agent, found = m.byName[name] - return -} - -// FindByID finds agent by ID -func (m *ExternalAgentsMap) FindByID(id uuid.UUID) (agent *ExternalAgent, found bool) { - agent, found = m.byID[id.String()] - return -} - -// Visit iterates through agents, calling cb for each of them -func (m *ExternalAgentsMap) Visit(cb func(*ExternalAgent)) { - for _, a := range m.byName { - cb(a) - } -} - -// Size returns the number of agents contained in the datastructure -func (m *ExternalAgentsMap) Size() int { - return len(m.byName) -} - -// AsArray returns shallow copy of all agents as a single array. The order of agents is unspecified. -func (m *ExternalAgentsMap) AsArray() []*ExternalAgent { - agents := make([]*ExternalAgent, 0, len(m.byName)) - - m.Visit(func(a *ExternalAgent) { - agents = append(agents, a) - }) - - return agents -} - -func (m *ExternalAgentsMap) Clear() { - m.byName = make(map[string]*ExternalAgent) - m.byID = make(map[string]*ExternalAgent) -} - -// InternalAgentsMap stores Agents indexed by Name and ID -type InternalAgentsMap struct { - byName map[string]*InternalAgent - byID map[string]*InternalAgent -} - -// NewInternalAgentsMap creates empty InternalAgentsMap -func NewInternalAgentsMap() InternalAgentsMap { - return InternalAgentsMap{ - byName: make(map[string]*InternalAgent), - byID: make(map[string]*InternalAgent), - } -} - -// Insert places agent into InternalAgentsMap. Error is returned if agent with this ID or name already exists -func (m *InternalAgentsMap) Insert(a *InternalAgent) error { - if _, nameCollision := m.FindByName(a.Name); nameCollision { - return ErrAgentNameCollision - } - - if _, idCollision := m.FindByID(a.ID); idCollision { - return ErrAgentIDCollision - } - - m.byName[a.Name] = a - m.byID[a.ID.String()] = a - - return nil -} - -// FindByName finds agent by name -func (m *InternalAgentsMap) FindByName(name string) (agent *InternalAgent, found bool) { - agent, found = m.byName[name] - return -} - -// FindByID finds agent by ID -func (m *InternalAgentsMap) FindByID(id uuid.UUID) (agent *InternalAgent, found bool) { - agent, found = m.byID[id.String()] - return -} - -// Visit iterates through agents, calling cb for each of them -func (m *InternalAgentsMap) Visit(cb func(*InternalAgent)) { - for _, a := range m.byName { - cb(a) - } -} - -// Size returns the number of agents contained in the datastructure -func (m *InternalAgentsMap) Size() int { - return len(m.byName) -} - -// AsArray returns shallow copy of all agents as a single array. The order of agents is unspecified. -func (m *InternalAgentsMap) AsArray() []*InternalAgent { - agents := make([]*InternalAgent, 0, len(m.byName)) - - m.Visit(func(a *InternalAgent) { - agents = append(agents, a) - }) - - return agents -} - -func (m *InternalAgentsMap) Clear() { - m.byName = make(map[string]*InternalAgent) - m.byID = make(map[string]*InternalAgent) -} diff --git a/lambda/core/agentsmap_test.go b/lambda/core/agentsmap_test.go deleted file mode 100644 index 0c4b9e2..0000000 --- a/lambda/core/agentsmap_test.go +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "github.com/google/uuid" - "github.com/stretchr/testify/require" - "testing" -) - -func TestExternalAgentsMapLookupByName(t *testing.T) { - m := NewExternalAgentsMap() - - err := m.Insert(&ExternalAgent{Name: "a", ID: uuid.New()}) - require.NoError(t, err) - agentIn := &ExternalAgent{Name: "b", ID: uuid.New()} - err = m.Insert(agentIn) - require.NoError(t, err) - err = m.Insert(&ExternalAgent{Name: "c", ID: uuid.New()}) - require.NoError(t, err) - - agentOut, found := m.FindByName(agentIn.Name) - require.True(t, found) - require.Equal(t, agentIn, agentOut) -} - -func TestExternalAgentsMapLookupByID(t *testing.T) { - m := NewExternalAgentsMap() - - err := m.Insert(&ExternalAgent{Name: "a", ID: uuid.New()}) - require.NoError(t, err) - agentIn := &ExternalAgent{Name: "b", ID: uuid.New()} - err = m.Insert(agentIn) - require.NoError(t, err) - err = m.Insert(&ExternalAgent{Name: "c", ID: uuid.New()}) - require.NoError(t, err) - - agentOut, found := m.FindByID(agentIn.ID) - require.True(t, found) - require.Equal(t, agentIn, agentOut) -} - -func TestExternalAgentsMapInsertNameCollision(t *testing.T) { - m := NewExternalAgentsMap() - - err := m.Insert(&ExternalAgent{Name: "a", ID: uuid.New()}) - require.NoError(t, err) - - err = m.Insert(&ExternalAgent{Name: "a", ID: uuid.New()}) - require.Equal(t, err, ErrAgentNameCollision) -} - -func TestExternalAgentsMapInsertIDCollision(t *testing.T) { - m := NewExternalAgentsMap() - - id := uuid.New() - - err := m.Insert(&ExternalAgent{Name: "a", ID: id}) - require.NoError(t, err) - - err = m.Insert(&ExternalAgent{Name: "b", ID: id}) - require.Equal(t, err, ErrAgentIDCollision) -} - -func TestInternalAgentsMapLookupByName(t *testing.T) { - m := NewInternalAgentsMap() - - err := m.Insert(&InternalAgent{Name: "a", ID: uuid.New()}) - require.NoError(t, err) - agentIn := &InternalAgent{Name: "b", ID: uuid.New()} - err = m.Insert(agentIn) - require.NoError(t, err) - err = m.Insert(&InternalAgent{Name: "c", ID: uuid.New()}) - require.NoError(t, err) - - agentOut, found := m.FindByName(agentIn.Name) - require.True(t, found) - require.Equal(t, agentIn, agentOut) -} - -func TestInternalAgentsMapLookupByID(t *testing.T) { - m := NewInternalAgentsMap() - - err := m.Insert(&InternalAgent{Name: "a", ID: uuid.New()}) - require.NoError(t, err) - agentIn := &InternalAgent{Name: "b", ID: uuid.New()} - err = m.Insert(agentIn) - require.NoError(t, err) - err = m.Insert(&InternalAgent{Name: "c", ID: uuid.New()}) - require.NoError(t, err) - - agentOut, found := m.FindByID(agentIn.ID) - require.True(t, found) - require.Equal(t, agentIn, agentOut) -} - -func TestInternalAgentsMapInsertNameCollision(t *testing.T) { - m := NewInternalAgentsMap() - - err := m.Insert(&InternalAgent{Name: "a", ID: uuid.New()}) - require.NoError(t, err) - - err = m.Insert(&InternalAgent{Name: "a", ID: uuid.New()}) - require.Equal(t, err, ErrAgentNameCollision) -} - -func TestInternalAgentsMapInsertIDCollision(t *testing.T) { - m := NewInternalAgentsMap() - - id := uuid.New() - - err := m.Insert(&InternalAgent{Name: "a", ID: id}) - require.NoError(t, err) - - err = m.Insert(&InternalAgent{Name: "b", ID: id}) - require.Equal(t, err, ErrAgentIDCollision) -} diff --git a/lambda/core/agentutil.go b/lambda/core/agentutil.go deleted file mode 100644 index 6024209..0000000 --- a/lambda/core/agentutil.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "errors" -) - -var errInvalidEventType = errors.New("ErrorInvalidEventType") -var errEventNotSupportedForInternalAgent = errors.New("ShutdownEventNotSupportedForInternalExtension") - -type disallowEverything struct { -} - -// Register -func (s *disallowEverything) Register(events []Event) error { return ErrNotAllowed } - -// Ready -func (s *disallowEverything) Ready() error { return ErrNotAllowed } - -// InitError -func (s *disallowEverything) InitError(errorType string) error { return ErrNotAllowed } - -// ExitError -func (s *disallowEverything) ExitError(errorType string) error { return ErrNotAllowed } - -// ShutdownFailed -func (s *disallowEverything) ShutdownFailed() error { return ErrNotAllowed } - -// Exited -func (s *disallowEverything) Exited() error { return ErrNotAllowed } - -// LaunchError -func (s *disallowEverything) LaunchError(error) error { return ErrNotAllowed } diff --git a/lambda/core/bandwidthlimiter/bandwidthlimiter.go b/lambda/core/bandwidthlimiter/bandwidthlimiter.go deleted file mode 100644 index 05c600a..0000000 --- a/lambda/core/bandwidthlimiter/bandwidthlimiter.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package bandwidthlimiter - -import ( - "io" - - "go.amzn.com/lambda/interop" -) - -func BandwidthLimitingCopy(dst *BandwidthLimitingWriter, src io.Reader) (written int64, err error) { - written, err = io.Copy(dst, src) - _ = dst.Close() - return -} - -func NewBandwidthLimitingWriter(w io.Writer, bucket *Bucket) (*BandwidthLimitingWriter, error) { - throttler, err := NewThrottler(bucket) - if err != nil { - return nil, err - } - return &BandwidthLimitingWriter{w: w, th: throttler}, nil -} - -type BandwidthLimitingWriter struct { - w io.Writer - th *Throttler -} - -func (w *BandwidthLimitingWriter) ChunkedWrite(p []byte) (n int, err error) { - i := NewChunkIterator(p, int(w.th.b.capacity)) - for { - buf := i.Next() - if buf == nil { - return - } - written, writeErr := w.th.bandwidthLimitingWrite(w.w, buf) - n += written - if writeErr != nil { - return n, writeErr - } - } -} - -func (w *BandwidthLimitingWriter) Write(p []byte) (n int, err error) { - w.th.start() - if int64(len(p)) > w.th.b.capacity { - return w.ChunkedWrite(p) - } - return w.th.bandwidthLimitingWrite(w.w, p) -} - -func (w *BandwidthLimitingWriter) Close() (err error) { - w.th.stop() - return -} - -func (w *BandwidthLimitingWriter) GetMetrics() (metrics *interop.InvokeResponseMetrics) { - return w.th.metrics -} diff --git a/lambda/core/bandwidthlimiter/bandwidthlimiter_test.go b/lambda/core/bandwidthlimiter/bandwidthlimiter_test.go deleted file mode 100644 index 7ede24b..0000000 --- a/lambda/core/bandwidthlimiter/bandwidthlimiter_test.go +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package bandwidthlimiter - -import ( - "bytes" - "io" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestBandwidthLimitingCopy(t *testing.T) { - var size10mb int64 = 10 * 1024 * 1024 - - inputBuffer := []byte(strings.Repeat("a", int(size10mb))) - reader := bytes.NewReader(inputBuffer) - - bucket, err := NewBucket(size10mb/2, size10mb/4, size10mb/2, time.Millisecond/2) - assert.NoError(t, err) - - internalWriter := bytes.NewBuffer(make([]byte, 0, size10mb)) - writer, err := NewBandwidthLimitingWriter(internalWriter, bucket) - assert.NoError(t, err) - - n, err := BandwidthLimitingCopy(writer, reader) - assert.Equal(t, size10mb, n) - assert.Equal(t, nil, err) - assert.Equal(t, inputBuffer, internalWriter.Bytes()) -} - -type ErrorBufferWriter struct { - w ByteBufferWriter - failAfter int -} - -func (w *ErrorBufferWriter) Write(p []byte) (n int, err error) { - if w.failAfter >= 1 { - w.failAfter-- - } - n, err = w.w.Write(p) - if w.failAfter == 0 { - return n, io.ErrUnexpectedEOF - } - return n, err -} - -func (w *ErrorBufferWriter) Bytes() []byte { - return w.w.Bytes() -} - -func TestNewBandwidthLimitingWriter(t *testing.T) { - type testCase struct { - refillNumber int64 - internalWriter ByteBufferWriter - inputBuffer []byte - expectedN int - expectedError error - } - testCases := []testCase{ - { - refillNumber: 2, - internalWriter: bytes.NewBuffer(make([]byte, 0, 36)), // buffer size greater than bucket size - inputBuffer: []byte(strings.Repeat("a", 36)), - expectedN: 36, - expectedError: nil, - }, - { - refillNumber: 2, - internalWriter: bytes.NewBuffer(make([]byte, 0, 12)), // buffer size lesser than bucket size - inputBuffer: []byte(strings.Repeat("a", 12)), - expectedN: 12, - expectedError: nil, - }, - { - // buffer size greater than bucket size and error after two Write() invocations - refillNumber: 2, - internalWriter: &ErrorBufferWriter{w: bytes.NewBuffer(make([]byte, 0, 36)), failAfter: 2}, - inputBuffer: []byte(strings.Repeat("a", 36)), - expectedN: 32, - expectedError: io.ErrUnexpectedEOF, - }, - } - - for _, test := range testCases { - bucket, err := NewBucket(16, 8, test.refillNumber, 100*time.Millisecond) - assert.NoError(t, err) - - writer, err := NewBandwidthLimitingWriter(test.internalWriter, bucket) - assert.NoError(t, err) - assert.False(t, writer.th.running) - - n, err := writer.Write(test.inputBuffer) - assert.True(t, writer.th.running) - assert.Equal(t, test.expectedN, n) - assert.Equal(t, test.expectedError, err) - assert.Equal(t, test.inputBuffer[:n], test.internalWriter.Bytes()) - - err = writer.Close() - assert.Nil(t, err) - assert.False(t, writer.th.running) - } -} diff --git a/lambda/core/bandwidthlimiter/throttler.go b/lambda/core/bandwidthlimiter/throttler.go deleted file mode 100644 index b3b57dd..0000000 --- a/lambda/core/bandwidthlimiter/throttler.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package bandwidthlimiter - -import ( - "errors" - "fmt" - "io" - "sync" - "time" - - log "github.com/sirupsen/logrus" - - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" -) - -var ErrBufferSizeTooLarge = errors.New("buffer size cannot be greater than bucket size") - -func NewBucket(capacity int64, initialTokenCount int64, refillNumber int64, refillInterval time.Duration) (*Bucket, error) { - if capacity <= 0 || initialTokenCount < 0 || refillNumber <= 0 || refillInterval <= 0 || - capacity < initialTokenCount { - errorMsg := fmt.Sprintf("invalid bucket parameters (capacity: %d, initialTokenCount: %d, refillNumber: %d,"+ - "refillInterval: %d)", capacity, initialTokenCount, refillInterval, refillInterval) - log.Error(errorMsg) - return nil, errors.New(errorMsg) - } - return &Bucket{ - capacity: capacity, - tokenCount: initialTokenCount, - refillNumber: refillNumber, - refillInterval: refillInterval, - mutex: sync.Mutex{}, - }, nil -} - -type Bucket struct { - capacity int64 - tokenCount int64 - refillNumber int64 - refillInterval time.Duration - mutex sync.Mutex -} - -func (b *Bucket) produceTokens() { - b.mutex.Lock() - defer b.mutex.Unlock() - if b.tokenCount < b.capacity { - b.tokenCount = min64(b.tokenCount+b.refillNumber, b.capacity) - } -} - -func (b *Bucket) consumeTokens(n int64) bool { - b.mutex.Lock() - defer b.mutex.Unlock() - if n <= b.tokenCount { - b.tokenCount -= n - return true - } - return false -} - -func (b *Bucket) getTokenCount() int64 { - b.mutex.Lock() - defer b.mutex.Unlock() - return b.tokenCount -} - -func NewThrottler(bucket *Bucket) (*Throttler, error) { - if bucket == nil { - errorMsg := "cannot create a throttler with nil bucket" - log.Error(errorMsg) - return nil, errors.New(errorMsg) - } - return &Throttler{ - b: bucket, - running: false, - produced: make(chan int64), - done: make(chan struct{}), - // FIXME: - // The runtime tells whether the function response mode is streaming or not. - // Ideally, we would want to use that value here. Since I'm just rebasing, I will leave - // as-is, but we should use that instead of relying on our memory to set this here - // because we "know" it's a streaming code path. - metrics: &interop.InvokeResponseMetrics{FunctionResponseMode: interop.FunctionResponseModeStreaming}, - }, nil -} - -type Throttler struct { - b *Bucket - running bool - produced chan int64 - done chan struct{} - metrics *interop.InvokeResponseMetrics -} - -func (th *Throttler) start() { - if th.running { - return - } - th.running = true - th.metrics.StartReadingResponseMonoTimeMs = metering.Monotime() - go func() { - ticker := time.NewTicker(th.b.refillInterval) - for { - select { - case <-ticker.C: - th.b.produceTokens() - select { - case th.produced <- metering.Monotime(): - default: - } - case <-th.done: - ticker.Stop() - return - } - } - }() -} - -func (th *Throttler) stop() { - if !th.running { - return - } - th.running = false - th.metrics.FinishReadingResponseMonoTimeMs = metering.Monotime() - durationMs := (th.metrics.FinishReadingResponseMonoTimeMs - th.metrics.StartReadingResponseMonoTimeMs) / int64(time.Millisecond) - if durationMs > 0 { - th.metrics.OutboundThroughputBps = (th.metrics.ProducedBytes / durationMs) * int64(time.Second/time.Millisecond) - } else { - th.metrics.OutboundThroughputBps = -1 - } - th.done <- struct{}{} -} - -func (th *Throttler) bandwidthLimitingWrite(w io.Writer, p []byte) (written int, err error) { - n := int64(len(p)) - if n > th.b.capacity { - return 0, ErrBufferSizeTooLarge - } - for { - if th.b.consumeTokens(n) { - written, err = w.Write(p) - th.metrics.ProducedBytes += int64(written) - return - } - waitStart := metering.Monotime() - elapsed := <-th.produced - waitStart - if elapsed > 0 { - th.metrics.TimeShapedNs += elapsed - } - } -} diff --git a/lambda/core/bandwidthlimiter/throttler_test.go b/lambda/core/bandwidthlimiter/throttler_test.go deleted file mode 100644 index a88a14d..0000000 --- a/lambda/core/bandwidthlimiter/throttler_test.go +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package bandwidthlimiter - -import ( - "bytes" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestNewBucket(t *testing.T) { - type testCase struct { - capacity int64 - initialTokenCount int64 - refillNumber int64 - refillInterval time.Duration - bucketCreated bool - } - testCases := []testCase{ - {capacity: 8, initialTokenCount: 6, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: true}, - {capacity: 8, initialTokenCount: 6, refillNumber: 2, refillInterval: -100 * time.Millisecond, bucketCreated: false}, - {capacity: 8, initialTokenCount: 6, refillNumber: -5, refillInterval: 100 * time.Millisecond, bucketCreated: false}, - {capacity: 8, initialTokenCount: -2, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, - {capacity: -2, initialTokenCount: 6, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, - {capacity: 8, initialTokenCount: 10, refillNumber: 2, refillInterval: 100 * time.Millisecond, bucketCreated: false}, - } - - for _, test := range testCases { - bucket, err := NewBucket(test.capacity, test.initialTokenCount, test.refillNumber, test.refillInterval) - if test.bucketCreated { - assert.NoError(t, err) - assert.NotNil(t, bucket) - } else { - assert.Error(t, err) - assert.Nil(t, bucket) - } - } -} - -func TestBucket_produceTokens_consumeTokens(t *testing.T) { - var consumed bool - bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) - assert.NoError(t, err) - assert.Equal(t, int64(8), bucket.getTokenCount()) - - consumed = bucket.consumeTokens(5) - assert.Equal(t, int64(3), bucket.getTokenCount()) - assert.True(t, consumed) - - bucket.produceTokens() - assert.Equal(t, int64(9), bucket.getTokenCount()) - - bucket.produceTokens() - assert.Equal(t, int64(15), bucket.getTokenCount()) - - bucket.produceTokens() - assert.Equal(t, int64(16), bucket.getTokenCount()) - - bucket.produceTokens() - assert.Equal(t, int64(16), bucket.getTokenCount()) - - consumed = bucket.consumeTokens(18) - assert.Equal(t, int64(16), bucket.getTokenCount()) - assert.False(t, consumed) - - consumed = bucket.consumeTokens(16) - assert.Equal(t, int64(0), bucket.getTokenCount()) - assert.True(t, consumed) -} - -func TestNewThrottler(t *testing.T) { - bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) - assert.NoError(t, err) - - throttler, err := NewThrottler(bucket) - assert.NoError(t, err) - assert.NotNil(t, throttler) - - throttler, err = NewThrottler(nil) - assert.Error(t, err) - assert.Nil(t, throttler) -} - -func TestNewThrottler_start_stop(t *testing.T) { - bucket, err := NewBucket(16, 8, 6, 100*time.Millisecond) - assert.NoError(t, err) - - throttler, err := NewThrottler(bucket) - assert.NoError(t, err) - - assert.False(t, throttler.running) - - throttler.start() - assert.True(t, throttler.running) - - <-time.Tick(2 * throttler.b.refillInterval) - assert.LessOrEqual(t, int64(14), throttler.b.getTokenCount()) - assert.True(t, throttler.running) - - throttler.start() - assert.True(t, throttler.running) - <-time.Tick(2 * throttler.b.refillInterval) - assert.Equal(t, int64(16), throttler.b.getTokenCount()) - assert.True(t, throttler.running) - - throttler.stop() - assert.False(t, throttler.running) - - throttler.stop() - assert.False(t, throttler.running) - - throttler.start() - assert.True(t, throttler.running) - - throttler.stop() - assert.False(t, throttler.running) -} - -type ByteBufferWriter interface { - Write(p []byte) (n int, err error) - Bytes() []byte -} - -type FixedSizeBufferWriter struct { - buf []byte -} - -func (w *FixedSizeBufferWriter) Write(p []byte) (n int, err error) { - n = copy(w.buf, p) - return -} - -func (w *FixedSizeBufferWriter) Bytes() []byte { - return w.buf -} - -func TestNewThrottler_bandwidthLimitingWrite(t *testing.T) { - var size10mb int64 = 10 * 1024 * 1024 - - type testCase struct { - capacity int64 - initialTokenCount int64 - writer ByteBufferWriter - inputBuffer []byte - expectedN int - expectedError error - } - testCases := []testCase{ - { - capacity: 16, - initialTokenCount: 8, - writer: bytes.NewBuffer(make([]byte, 0, 14)), - inputBuffer: []byte(strings.Repeat("a", 12)), - expectedN: 12, - expectedError: nil, - }, - { - capacity: 16, - initialTokenCount: 8, - writer: bytes.NewBuffer(make([]byte, 0, 12)), - inputBuffer: []byte(strings.Repeat("a", 14)), - expectedN: 14, - expectedError: nil, - }, - { - capacity: size10mb, - initialTokenCount: size10mb, - writer: bytes.NewBuffer(make([]byte, 0, size10mb)), - inputBuffer: []byte(strings.Repeat("a", int(size10mb))), - expectedN: int(size10mb), - expectedError: nil, - }, - { - capacity: 16, - initialTokenCount: 8, - writer: bytes.NewBuffer(make([]byte, 0, 18)), - inputBuffer: []byte(strings.Repeat("a", 18)), - expectedN: 0, - expectedError: ErrBufferSizeTooLarge, - }, - { - capacity: 16, - initialTokenCount: 8, - writer: &FixedSizeBufferWriter{buf: make([]byte, 12)}, - inputBuffer: []byte(strings.Repeat("a", 14)), - expectedN: 12, - expectedError: nil, - }, - } - - for _, test := range testCases { - bucket, err := NewBucket(test.capacity, test.initialTokenCount, 2, 100*time.Millisecond) - assert.NoError(t, err) - - throttler, err := NewThrottler(bucket) - assert.NoError(t, err) - - writer := test.writer - throttler.start() - n, err := throttler.bandwidthLimitingWrite(writer, test.inputBuffer) - assert.Equal(t, test.expectedN, n) - assert.Equal(t, test.expectedError, err) - - if test.expectedError == nil { - assert.Equal(t, test.inputBuffer[:n], test.writer.Bytes()) - } else { - assert.Equal(t, []byte{}, test.writer.Bytes()) - } - throttler.stop() - } -} diff --git a/lambda/core/bandwidthlimiter/util.go b/lambda/core/bandwidthlimiter/util.go deleted file mode 100644 index 7078d5d..0000000 --- a/lambda/core/bandwidthlimiter/util.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package bandwidthlimiter - -func min(a, b int) int { - if a < b { - return a - } - return b -} - -func min64(a, b int64) int64 { - if a < b { - return a - } - return b -} - -func NewChunkIterator(buf []byte, chunkSize int) *ChunkIterator { - if buf == nil { - return nil - } - return &ChunkIterator{ - buf: buf, - chunkSize: chunkSize, - offset: 0, - } -} - -type ChunkIterator struct { - buf []byte - chunkSize int - offset int -} - -func (i *ChunkIterator) Next() []byte { - begin := i.offset - end := min(i.offset+i.chunkSize, len(i.buf)) - i.offset = end - - if begin == end { - return nil - } - return i.buf[begin:end] -} diff --git a/lambda/core/bandwidthlimiter/util_test.go b/lambda/core/bandwidthlimiter/util_test.go deleted file mode 100644 index ed93c77..0000000 --- a/lambda/core/bandwidthlimiter/util_test.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package bandwidthlimiter - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestNewChunkIterator(t *testing.T) { - buf := []byte("abcdefghijk") - - type testCase struct { - buf []byte - chunkSize int - expectedResult [][]byte - } - testCases := []testCase{ - {buf: nil, chunkSize: 0, expectedResult: [][]byte{}}, - {buf: nil, chunkSize: 1, expectedResult: [][]byte{}}, - {buf: buf, chunkSize: 0, expectedResult: [][]byte{}}, - {buf: buf, chunkSize: 1, expectedResult: [][]byte{ - []byte("a"), []byte("b"), []byte("c"), []byte("d"), []byte("e"), []byte("f"), []byte("g"), []byte("h"), - []byte("i"), []byte("j"), []byte("k"), - }}, - {buf: buf, chunkSize: 4, expectedResult: [][]byte{[]byte("abcd"), []byte("efgh"), []byte("ijk")}}, - {buf: buf, chunkSize: 5, expectedResult: [][]byte{[]byte("abcde"), []byte("fghij"), []byte("k")}}, - {buf: buf, chunkSize: 11, expectedResult: [][]byte{[]byte("abcdefghijk")}}, - {buf: buf, chunkSize: 12, expectedResult: [][]byte{[]byte("abcdefghijk")}}, - } - - for _, test := range testCases { - iterator := NewChunkIterator(test.buf, test.chunkSize) - if test.buf == nil { - assert.Nil(t, iterator) - } else { - for _, expectedChunk := range test.expectedResult { - assert.Equal(t, expectedChunk, iterator.Next()) - } - assert.Nil(t, iterator.Next()) - } - } -} diff --git a/lambda/core/credentials.go b/lambda/core/credentials.go deleted file mode 100644 index ad152d0..0000000 --- a/lambda/core/credentials.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "fmt" - "sync" - "time" -) - -const ( - UNBLOCKED = iota - BLOCKED -) - -var ErrCredentialsNotFound = fmt.Errorf("credentials not found for the provided token") - -type Credentials struct { - AwsKey string `json:"AccessKeyId"` - AwsSecret string `json:"SecretAccessKey"` - AwsSession string `json:"Token"` - Expiration time.Time `json:"Expiration"` -} - -type CredentialsService interface { - SetCredentials(token, awsKey, awsSecret, awsSession string, expiration time.Time) - GetCredentials(token string) (*Credentials, error) - UpdateCredentials(awsKey, awsSecret, awsSession string, expiration time.Time) error -} - -type credentialsServiceImpl struct { - credentials map[string]Credentials - contentMutex *sync.Mutex - serviceMutex *sync.Mutex - currentState int -} - -func NewCredentialsService() CredentialsService { - credentialsService := &credentialsServiceImpl{ - credentials: make(map[string]Credentials), - contentMutex: &sync.Mutex{}, - serviceMutex: &sync.Mutex{}, - currentState: UNBLOCKED, - } - - return credentialsService -} - -func (c *credentialsServiceImpl) SetCredentials(token, awsKey, awsSecret, awsSession string, expiration time.Time) { - c.contentMutex.Lock() - defer c.contentMutex.Unlock() - - c.credentials[token] = Credentials{ - AwsKey: awsKey, - AwsSecret: awsSecret, - AwsSession: awsSession, - Expiration: expiration, - } -} - -func (c *credentialsServiceImpl) GetCredentials(token string) (*Credentials, error) { - c.serviceMutex.Lock() - defer c.serviceMutex.Unlock() - - c.contentMutex.Lock() - defer c.contentMutex.Unlock() - - if credentials, ok := c.credentials[token]; ok { - return &credentials, nil - } - - return nil, ErrCredentialsNotFound -} - -func (c *credentialsServiceImpl) UpdateCredentials(awsKey, awsSecret, awsSession string, expiration time.Time) error { - mapSize := len(c.credentials) - if mapSize != 1 { - return fmt.Errorf("there are %d set of credentials", mapSize) - } - - var token string - for key := range c.credentials { - token = key - } - - c.SetCredentials(token, awsKey, awsSecret, awsSession, expiration) - return nil -} diff --git a/lambda/core/credentials_test.go b/lambda/core/credentials_test.go deleted file mode 100644 index 625ab8e..0000000 --- a/lambda/core/credentials_test.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "github.com/stretchr/testify/assert" - "testing" - "time" -) - -const ( - Token string = "sampleToken" - AwsKey string = "sampleKey" - AwsSecret string = "sampleSecret" - AwsSession string = "sampleSession" -) - -func TestGetSetCredentialsHappy(t *testing.T) { - credentialsService := NewCredentialsService() - - credentialsExpiration := time.Now().Add(15 * time.Minute) - credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession, credentialsExpiration) - - credentials, err := credentialsService.GetCredentials(Token) - - assert.NoError(t, err) - assert.Equal(t, AwsKey, credentials.AwsKey) - assert.Equal(t, AwsSecret, credentials.AwsSecret) - assert.Equal(t, AwsSession, credentials.AwsSession) -} - -func TestGetCredentialsFail(t *testing.T) { - credentialsService := NewCredentialsService() - - _, err := credentialsService.GetCredentials("unknownToken") - - assert.Error(t, err) -} - -func TestUpdateCredentialsHappy(t *testing.T) { - credentialsService := NewCredentialsService() - - credentialsExpiration := time.Now().Add(15 * time.Minute) - credentialsService.SetCredentials(Token, AwsKey, AwsSecret, AwsSession, credentialsExpiration) - - restoreCredentialsExpiration := time.Now().Add(10 * time.Hour) - - err := credentialsService.UpdateCredentials("sampleKey1", "sampleSecret1", "sampleSession1", restoreCredentialsExpiration) - assert.NoError(t, err) - - credentials, err := credentialsService.GetCredentials(Token) - - assert.NoError(t, err) - assert.Equal(t, "sampleKey1", credentials.AwsKey) - assert.Equal(t, "sampleSecret1", credentials.AwsSecret) - assert.Equal(t, "sampleSession1", credentials.AwsSession) - - nineHoursLater := time.Now().Add(9 * time.Hour) - - assert.True(t, nineHoursLater.Before(credentials.Expiration)) -} - -func TestUpdateCredentialsFail(t *testing.T) { - credentialsService := NewCredentialsService() - - err := credentialsService.UpdateCredentials("unknownKey", "unknownSecret", "unknownSession", time.Now()) - - assert.Error(t, err) -} diff --git a/lambda/core/directinvoke/customerheaders.go b/lambda/core/directinvoke/customerheaders.go deleted file mode 100644 index fd0e4ad..0000000 --- a/lambda/core/directinvoke/customerheaders.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package directinvoke - -import ( - "bytes" - "encoding/base64" - "encoding/json" -) - -type CustomerHeaders struct { - CognitoIdentityID string `json:"Cognito-Identity-Id"` - CognitoIdentityPoolID string `json:"Cognito-Identity-Pool-Id"` - ClientContext string `json:"Client-Context"` -} - -func (s CustomerHeaders) Dump() string { - if (s == CustomerHeaders{}) { - return "" - } - - custHeadersJSON, err := json.Marshal(&s) - if err != nil { - panic(err) - } - - return base64.StdEncoding.EncodeToString(custHeadersJSON) -} - -func (s *CustomerHeaders) Load(in string) error { - *s = CustomerHeaders{} - - if in == "" { - return nil - } - - base64Decoder := base64.NewDecoder(base64.StdEncoding, bytes.NewReader([]byte(in))) - - return json.NewDecoder(base64Decoder).Decode(s) -} diff --git a/lambda/core/directinvoke/customerheaders_test.go b/lambda/core/directinvoke/customerheaders_test.go deleted file mode 100644 index d81cbf4..0000000 --- a/lambda/core/directinvoke/customerheaders_test.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package directinvoke - -import ( - "github.com/stretchr/testify/require" - "testing" -) - -func TestCustomerHeadersEmpty(t *testing.T) { - in := CustomerHeaders{} - out := CustomerHeaders{} - - require.NoError(t, out.Load(in.Dump())) - require.Equal(t, in, out) -} - -func TestCustomerHeaders(t *testing.T) { - in := CustomerHeaders{CognitoIdentityID: "asd"} - out := CustomerHeaders{} - - require.NoError(t, out.Load(in.Dump())) - require.Equal(t, in, out) -} diff --git a/lambda/core/directinvoke/directinvoke.go b/lambda/core/directinvoke/directinvoke.go deleted file mode 100644 index 396bd39..0000000 --- a/lambda/core/directinvoke/directinvoke.go +++ /dev/null @@ -1,487 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -// LOCALSTACK CHANGES 2024-02-13: casting of MaxPayloadSize - -package directinvoke - -import ( - "context" - "fmt" - "io" - "net/http" - "strconv" - "strings" - - "github.com/go-chi/chi" - "go.amzn.com/lambda/core/bandwidthlimiter" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" - - log "github.com/sirupsen/logrus" -) - -const ( - InvokeIDHeader = "Invoke-Id" - InvokedFunctionArnHeader = "Invoked-Function-Arn" - VersionIDHeader = "Invoked-Function-Version" - ReservationTokenHeader = "Reservation-Token" - CustomerHeadersHeader = "Customer-Headers" - ContentTypeHeader = "Content-Type" - MaxPayloadSizeHeader = "MaxPayloadSize" - InvokeResponseModeHeader = "InvokeResponseMode" - ResponseBandwidthRateHeader = "ResponseBandwidthRate" - ResponseBandwidthBurstSizeHeader = "ResponseBandwidthBurstSize" - FunctionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" - - ErrorTypeHeader = "Error-Type" - - EndOfResponseTrailer = "End-Of-Response" - FunctionErrorTypeTrailer = "Lambda-Runtime-Function-Error-Type" - FunctionErrorBodyTrailer = "Lambda-Runtime-Function-Error-Body" -) - -const ( - EndOfResponseComplete = "Complete" - EndOfResponseTruncated = "Truncated" - EndOfResponseOversized = "Oversized" -) - -var ResetReasonMap = map[string]fatalerror.ErrorType{ - "failure": fatalerror.SandboxFailure, - "timeout": fatalerror.SandboxTimeout, -} - -var MaxDirectResponseSize = int64(interop.MaxPayloadSize) // this is intentionally not a constant so we can configure it via CLI -var ResponseBandwidthRate int64 = interop.ResponseBandwidthRate -var ResponseBandwidthBurstSize int64 = interop.ResponseBandwidthBurstSize - -// InvokeResponseMode controls the context in which the invoke is. Since this was introduced -// in Streaming invokes, we default it to Buffered. -var InvokeResponseMode interop.InvokeResponseMode = interop.InvokeResponseModeBuffered - -func renderBadRequest(w http.ResponseWriter, r *http.Request, errorType string) { - w.Header().Set(ErrorTypeHeader, errorType) - w.WriteHeader(http.StatusBadRequest) - w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) -} - -func renderInternalServerError(w http.ResponseWriter, errorType string) { - w.Header().Set(ErrorTypeHeader, errorType) - w.WriteHeader(http.StatusInternalServerError) - w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) -} - -// convertToInvokeResponseMode converts the given string to a InvokeResponseMode -// It is case insensitive and if there is no match, an error is thrown. -func convertToInvokeResponseMode(value string) (interop.InvokeResponseMode, error) { - // buffered - if strings.EqualFold(value, string(interop.InvokeResponseModeBuffered)) { - return interop.InvokeResponseModeBuffered, nil - } - - // streaming - if strings.EqualFold(value, string(interop.InvokeResponseModeStreaming)) { - return interop.InvokeResponseModeStreaming, nil - } - - // unknown - allowedValues := strings.Join(interop.AllInvokeResponseModes, ", ") - log.Errorf("Unable to map %s to %s.", value, allowedValues) - return "", interop.ErrInvalidInvokeResponseMode -} - -// ReceiveDirectInvoke parses invoke and verifies it against Token message. Uses deadline provided by Token -// Renders BadRequest in case of error -func ReceiveDirectInvoke(w http.ResponseWriter, r *http.Request, token interop.Token) (*interop.Invoke, error) { - log.Infof("Received Invoke(invokeID: %s) Request", token.InvokeID) - w.Header().Set("Trailer", EndOfResponseTrailer) - - custHeaders := CustomerHeaders{} - if err := custHeaders.Load(r.Header.Get(CustomerHeadersHeader)); err != nil { - renderBadRequest(w, r, interop.ErrMalformedCustomerHeaders.Error()) - return nil, interop.ErrMalformedCustomerHeaders - } - - now := metering.Monotime() - - MaxDirectResponseSize = int64(interop.MaxPayloadSize) - if maxPayloadSize := r.Header.Get(MaxPayloadSizeHeader); maxPayloadSize != "" { - if n, err := strconv.ParseInt(maxPayloadSize, 10, 64); err == nil && n >= -1 { - MaxDirectResponseSize = n - } else { - log.Error("MaxPayloadSize header is not a valid number") - renderBadRequest(w, r, interop.ErrInvalidMaxPayloadSize.Error()) - return nil, interop.ErrInvalidMaxPayloadSize - } - } - - if valueFromHeader := r.Header.Get(InvokeResponseModeHeader); valueFromHeader != "" { - invokeResponseMode, err := convertToInvokeResponseMode(valueFromHeader) - if err != nil { - log.Errorf( - "InvokeResponseMode header is not a valid string. Was: %#v, Allowed: %#v.", - valueFromHeader, - strings.Join(interop.AllInvokeResponseModes, ", "), - ) - renderBadRequest(w, r, err.Error()) - return nil, err - } - InvokeResponseMode = invokeResponseMode - } - - // TODO: stop using `MaxDirectResponseSize` - if isStreamingInvoke(int(MaxDirectResponseSize), InvokeResponseMode) { - w.Header().Add("Trailer", FunctionErrorTypeTrailer) - w.Header().Add("Trailer", FunctionErrorBodyTrailer) - - // FIXME - // Until WorkerProxy stops sending MaxDirectResponseSize == -1 to identify streaming - // invokes, we need to override InvokeResponseMode to avoid setting InvokeResponseMode to buffered (default) for a streaming invoke (MaxDirectResponseSize == -1). - InvokeResponseMode = interop.InvokeResponseModeStreaming - - ResponseBandwidthRate = interop.ResponseBandwidthRate - if responseBandwidthRate := r.Header.Get(ResponseBandwidthRateHeader); responseBandwidthRate != "" { - if n, err := strconv.ParseInt(responseBandwidthRate, 10, 64); err == nil && - interop.MinResponseBandwidthRate <= n && n <= interop.MaxResponseBandwidthRate { - ResponseBandwidthRate = n - } else { - log.Error("ResponseBandwidthRate header is not a valid number or is out of the allowed range") - renderBadRequest(w, r, interop.ErrInvalidResponseBandwidthRate.Error()) - return nil, interop.ErrInvalidResponseBandwidthRate - } - } - - ResponseBandwidthBurstSize = interop.ResponseBandwidthBurstSize - if responseBandwidthBurstSize := r.Header.Get(ResponseBandwidthBurstSizeHeader); responseBandwidthBurstSize != "" { - if n, err := strconv.ParseInt(responseBandwidthBurstSize, 10, 64); err == nil && - interop.MinResponseBandwidthBurstSize <= n && n <= interop.MaxResponseBandwidthBurstSize { - ResponseBandwidthBurstSize = n - } else { - log.Error("ResponseBandwidthBurstSize header is not a valid number or is out of the allowed range") - renderBadRequest(w, r, interop.ErrInvalidResponseBandwidthBurstSize.Error()) - return nil, interop.ErrInvalidResponseBandwidthBurstSize - } - } - } - - inv := &interop.Invoke{ - ID: r.Header.Get(InvokeIDHeader), - ReservationToken: chi.URLParam(r, "reservationtoken"), - InvokedFunctionArn: r.Header.Get(InvokedFunctionArnHeader), - VersionID: r.Header.Get(VersionIDHeader), - ContentType: r.Header.Get(ContentTypeHeader), - CognitoIdentityID: custHeaders.CognitoIdentityID, - CognitoIdentityPoolID: custHeaders.CognitoIdentityPoolID, - TraceID: token.TraceID, - LambdaSegmentID: token.LambdaSegmentID, - ClientContext: custHeaders.ClientContext, - Payload: r.Body, - DeadlineNs: fmt.Sprintf("%d", now+token.FunctionTimeout.Nanoseconds()), - NeedDebugLogs: token.NeedDebugLogs, - InvokeReceivedTime: now, - InvokeResponseMode: InvokeResponseMode, - RestoreDurationNs: token.RestoreDurationNs, - RestoreStartTimeMonotime: token.RestoreStartTimeMonotime, - } - - if inv.ID != token.InvokeID { - renderBadRequest(w, r, interop.ErrInvalidInvokeID.Error()) - return nil, interop.ErrInvalidInvokeID - } - - if inv.ReservationToken != token.ReservationToken { - renderBadRequest(w, r, interop.ErrInvalidReservationToken.Error()) - return nil, interop.ErrInvalidReservationToken - } - - if inv.VersionID != token.VersionID { - renderBadRequest(w, r, interop.ErrInvalidFunctionVersion.Error()) - return nil, interop.ErrInvalidFunctionVersion - } - - if now > token.InvackDeadlineNs { - renderBadRequest(w, r, interop.ErrReservationExpired.Error()) - return nil, interop.ErrReservationExpired - } - - w.Header().Set(VersionIDHeader, token.VersionID) - w.Header().Set(ReservationTokenHeader, token.ReservationToken) - w.Header().Set(InvokeIDHeader, token.InvokeID) - - return inv, nil -} - -type CopyDoneResult struct { - Metrics *interop.InvokeResponseMetrics - Error error -} - -func getErrorTypeFromResetReason(resetReason string) fatalerror.ErrorType { - errorTypeTrailer, ok := ResetReasonMap[resetReason] - if !ok { - errorTypeTrailer = fatalerror.SandboxFailure - } - return errorTypeTrailer -} - -func isErrorResponse(additionalHeaders map[string]string) (isErrorResponse bool) { - _, isErrorResponse = additionalHeaders[ErrorTypeHeader] - return -} - -// isStreamingInvoke checks whether the invoke mode is streaming or not. -// `maxDirectResponseSize == -1` is used as it was the first check we did when we released -// streaming invokes. -func isStreamingInvoke(maxDirectResponseSize int, invokeResponseMode interop.InvokeResponseMode) bool { - return maxDirectResponseSize == -1 || invokeResponseMode == interop.InvokeResponseModeStreaming -} - -func asyncPayloadCopy(w http.ResponseWriter, payload io.Reader) (copyDone chan CopyDoneResult, cancel context.CancelFunc, err error) { - copyDone = make(chan CopyDoneResult) - streamedResponseWriter, cancel, err := NewStreamedResponseWriter(w) - if err != nil { - return nil, nil, &interop.ErrInternalPlatformError{} - } - - go func() { // copy payload in a separate go routine - // -1 size indicates the payload size is unlimited. - isPayloadsSizeRestricted := MaxDirectResponseSize != -1 - - if isPayloadsSizeRestricted { - // Setting the limit to MaxDirectResponseSize + 1 so we can do - // readBytes > MaxDirectResponseSize to check if the response is oversized. - // As the response is allowed to be of the size MaxDirectResponseSize but not larger than it. - payload = io.LimitReader(payload, MaxDirectResponseSize+1) - } - - // FIXME: inject bandwidthlimiter as a dependency, so that we can mock it in tests - copiedBytes, copyError := bandwidthlimiter.BandwidthLimitingCopy(streamedResponseWriter, payload) - - isPayloadsSizeOversized := copiedBytes > MaxDirectResponseSize - - if copyError != nil { - w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated) - copyError = &interop.ErrTruncatedResponse{} - } else if isPayloadsSizeRestricted && isPayloadsSizeOversized { - w.Header().Set(EndOfResponseTrailer, EndOfResponseOversized) - copyError = &interop.ErrorResponseTooLargeDI{ - ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ - ResponseSize: int(copiedBytes), - MaxResponseSize: int(MaxDirectResponseSize), - }, - } - } else { - w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) - } - copyDoneResult := CopyDoneResult{ - Metrics: streamedResponseWriter.GetMetrics(), - Error: copyError, - } - copyDone <- copyDoneResult - cancel() // free resources - }() - return -} - -func sendStreamingInvokeResponse(payload io.Reader, trailers http.Header, w http.ResponseWriter, - interruptedResponseChan chan *interop.Reset, sendResponseChan chan *interop.InvokeResponseMetrics, - request *interop.CancellableRequest, runtimeCalledResponse bool) (err error) { - /* In case of /response, we copy the payload and, once copied, we attach: - * 1) 'Lambda-Runtime-Function-Error-Type' - * 2) 'Lambda-Runtime-Function-Error-Body' - * trailers. */ - copyDone, cancel, err := asyncPayloadCopy(w, payload) - if err != nil { - renderInternalServerError(w, err.Error()) - return err - } - - var errorTypeTrailer string - var errorBodyTrailer string - var copyDoneResult CopyDoneResult - select { - case copyDoneResult = <-copyDone: // copy finished - errorTypeTrailer = trailers.Get(FunctionErrorTypeTrailer) - errorBodyTrailer = trailers.Get(FunctionErrorBodyTrailer) - if copyDoneResult.Error != nil && errorTypeTrailer == "" { - errorTypeTrailer = string(mapCopyDoneResultErrorToErrorType(copyDoneResult.Error)) - } - case reset := <-interruptedResponseChan: // reset initiated - cancel() - if request != nil { - // In case of reset: - // * to interrupt copying when runtime called /response (a potential stuck on Body.Read() operation), - // we close the underlying connection using .Close() method on the request object - // * for /error case, the whole body is already read in /error handler, so we don't need additional handling - // when sending streaming invoke error response - connErr := request.Cancel() - if connErr != nil { - log.Warnf("Failed to close underlying connection: %s", connErr) - } - } else { - log.Warn("Cannot close underlying connection. Request object is nil") - } - copyDoneResult = <-copyDone - reset.InvokeResponseMetrics = copyDoneResult.Metrics - reset.InvokeResponseMode = InvokeResponseMode - interruptedResponseChan <- nil - errorTypeTrailer = string(getErrorTypeFromResetReason(reset.Reason)) - } - w.Header().Set(FunctionErrorTypeTrailer, errorTypeTrailer) - w.Header().Set(FunctionErrorBodyTrailer, errorBodyTrailer) - - copyDoneResult.Metrics.RuntimeCalledResponse = runtimeCalledResponse - sendResponseChan <- copyDoneResult.Metrics - - if copyDoneResult.Error != nil { - log.Errorf("Error while streaming response payload: %s", copyDoneResult.Error) - err = copyDoneResult.Error - } - return -} - -// mapCopyDoneResultErrorToErrorType map a copyDoneResult error into a fatalerror -func mapCopyDoneResultErrorToErrorType(err interface{}) fatalerror.ErrorType { - switch err.(type) { - case *interop.ErrTruncatedResponse: - return fatalerror.TruncatedResponse - case *interop.ErrorResponseTooLargeDI: - return fatalerror.FunctionOversizedResponse - default: - return fatalerror.SandboxFailure - } -} - -func sendStreamingInvokeErrorResponse(payload io.Reader, w http.ResponseWriter, - interruptedResponseChan chan *interop.Reset, sendResponseChan chan *interop.InvokeResponseMetrics, runtimeCalledResponse bool) (err error) { - - copyDone, cancel, err := asyncPayloadCopy(w, payload) - if err != nil { - renderInternalServerError(w, err.Error()) - return err - } - - var copyDoneResult CopyDoneResult - select { - case copyDoneResult = <-copyDone: // copy finished - case reset := <-interruptedResponseChan: // reset initiated - cancel() - copyDoneResult = <-copyDone - reset.InvokeResponseMetrics = copyDoneResult.Metrics - reset.InvokeResponseMode = InvokeResponseMode - interruptedResponseChan <- nil - } - - copyDoneResult.Metrics.RuntimeCalledResponse = runtimeCalledResponse - sendResponseChan <- copyDoneResult.Metrics - - if copyDoneResult.Error != nil { - log.Errorf("Error while streaming error response payload: %s", copyDoneResult.Error) - err = copyDoneResult.Error - } - - return -} - -// parseFunctionResponseMode fetches the mode from the header -// If the header is absent, it returns buffered mode. -func parseFunctionResponseMode(w http.ResponseWriter) (interop.FunctionResponseMode, error) { - headerValue := w.Header().Get(FunctionResponseModeHeader) - // the header is optional, so it's ok to be absent - if headerValue == "" { - return interop.FunctionResponseModeBuffered, nil - } - - return interop.ConvertToFunctionResponseMode(headerValue) -} - -func sendPayloadLimitedResponse(payload io.Reader, trailers http.Header, w http.ResponseWriter, sendResponseChan chan *interop.InvokeResponseMetrics, runtimeCalledResponse bool) (err error) { - functionResponseMode, err := parseFunctionResponseMode(w) - if err != nil { - return err - } - - // non-streaming invoke request but runtime is streaming: predefine Trailer headers - if functionResponseMode == interop.FunctionResponseModeStreaming { - w.Header().Add("Trailer", FunctionErrorTypeTrailer) - w.Header().Add("Trailer", FunctionErrorBodyTrailer) - } - - startReadingResponseMonoTimeMs := metering.Monotime() - // Setting the limit to MaxDirectResponseSize + 1 so we can do - // readBytes > MaxDirectResponseSize to check if the response is oversized. - // As the response is allowed to be of the size MaxDirectResponseSize but not larger than it. - written, err := io.Copy(w, io.LimitReader(payload, MaxDirectResponseSize+1)) - - // non-streaming invoke request but runtime is streaming: set response trailers - if functionResponseMode == interop.FunctionResponseModeStreaming { - w.Header().Set(FunctionErrorTypeTrailer, trailers.Get(FunctionErrorTypeTrailer)) - w.Header().Set(FunctionErrorBodyTrailer, trailers.Get(FunctionErrorBodyTrailer)) - } - - isNotStreamingInvoke := InvokeResponseMode != interop.InvokeResponseModeStreaming - - if err != nil { - w.Header().Set(EndOfResponseTrailer, EndOfResponseTruncated) - err = &interop.ErrTruncatedResponse{} - } else if isNotStreamingInvoke && written == MaxDirectResponseSize+1 { - w.Header().Set(EndOfResponseTrailer, EndOfResponseOversized) - err = &interop.ErrorResponseTooLargeDI{ - ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ - ResponseSize: int(written), - MaxResponseSize: int(MaxDirectResponseSize), - }, - } - } else { - w.Header().Set(EndOfResponseTrailer, EndOfResponseComplete) - } - - sendResponseChan <- &interop.InvokeResponseMetrics{ - ProducedBytes: int64(written), - StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, - FinishReadingResponseMonoTimeMs: metering.Monotime(), - TimeShapedNs: int64(-1), - OutboundThroughputBps: int64(-1), - // FIXME: - // We should use InvokeResponseMode here, because only when it's streaming we're interested - // on it. If the invoke is buffered, we don't generate streaming metrics, even if the - // function response mode is streaming. - FunctionResponseMode: interop.FunctionResponseModeBuffered, - RuntimeCalledResponse: runtimeCalledResponse, - } - return -} - -func SendDirectInvokeResponse(additionalHeaders map[string]string, payload io.Reader, trailers http.Header, - w http.ResponseWriter, interruptedResponseChan chan *interop.Reset, - sendResponseChan chan *interop.InvokeResponseMetrics, request *interop.CancellableRequest, runtimeCalledResponse bool, invokeID string) error { - - for k, v := range additionalHeaders { - w.Header().Add(k, v) - } - - var err error - log.Infof("Started sending response (mode: %s, requestID: %s)", InvokeResponseMode, invokeID) - if InvokeResponseMode == interop.InvokeResponseModeStreaming { - // send streamed error response when runtime called /error - if isErrorResponse(additionalHeaders) { - err = sendStreamingInvokeErrorResponse(payload, w, interruptedResponseChan, sendResponseChan, runtimeCalledResponse) - if err != nil { - log.Infof("Error in sending error response (mode: %s, requestID: %s, error: %v)", InvokeResponseMode, invokeID, err) - } - return err - } - // send streamed response when runtime called /response - err = sendStreamingInvokeResponse(payload, trailers, w, interruptedResponseChan, sendResponseChan, request, runtimeCalledResponse) - } else { - err = sendPayloadLimitedResponse(payload, trailers, w, sendResponseChan, runtimeCalledResponse) - } - - if err != nil { - log.Infof("Error in sending response (mode: %s, requestID: %s, error: %v)", InvokeResponseMode, invokeID, err) - } else { - log.Infof("Completed sending response (mode: %s, requestID: %s)", InvokeResponseMode, invokeID) - } - return err -} diff --git a/lambda/core/directinvoke/directinvoke_test.go b/lambda/core/directinvoke/directinvoke_test.go deleted file mode 100644 index 94b6323..0000000 --- a/lambda/core/directinvoke/directinvoke_test.go +++ /dev/null @@ -1,736 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package directinvoke - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "math" - "net/http" - "net/http/httptest" - "strconv" - "strings" - "testing" - "time" - - "github.com/go-chi/chi" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" -) - -func NewResponseWriterWithoutFlushMethod() *ResponseWriterWithoutFlushMethod { - return &ResponseWriterWithoutFlushMethod{} -} - -type ResponseWriterWithoutFlushMethod struct{} - -func (*ResponseWriterWithoutFlushMethod) Header() http.Header { return http.Header{} } -func (*ResponseWriterWithoutFlushMethod) Write([]byte) (n int, err error) { return } -func (*ResponseWriterWithoutFlushMethod) WriteHeader(_ int) {} - -func NewSimpleResponseWriter() *SimpleResponseWriter { - return &SimpleResponseWriter{ - buffer: bytes.NewBuffer(nil), - trailers: make(http.Header), - } -} - -type SimpleResponseWriter struct { - buffer *bytes.Buffer - trailers http.Header -} - -func (w *SimpleResponseWriter) Header() http.Header { return w.trailers } -func (w *SimpleResponseWriter) Write(p []byte) (n int, err error) { return w.buffer.Write(p) } -func (*SimpleResponseWriter) WriteHeader(_ int) {} -func (*SimpleResponseWriter) Flush() {} - -func NewInterruptableResponseWriter(interruptAfter int) (*InterruptableResponseWriter, chan struct{}) { - interruptedTestWriterChan := make(chan struct{}) - return &InterruptableResponseWriter{ - buffer: bytes.NewBuffer(nil), - trailers: make(http.Header), - interruptAfter: interruptAfter, - interruptedTestWriterChan: interruptedTestWriterChan, - }, interruptedTestWriterChan -} - -type InterruptableResponseWriter struct { - buffer *bytes.Buffer - trailers http.Header - interruptAfter int // expect Writer to be interrupted after 'interruptAfter' number of writes - interruptedTestWriterChan chan struct{} -} - -func (w *InterruptableResponseWriter) Header() http.Header { return w.trailers } -func (w *InterruptableResponseWriter) Write(p []byte) (n int, err error) { - if w.interruptAfter >= 1 { - w.interruptAfter-- - } else if w.interruptAfter == 0 { - w.interruptedTestWriterChan <- struct{}{} // ready to be interrupted - <-w.interruptedTestWriterChan // wait until interrupted - } - n, err = w.buffer.Write(p) - return -} -func (*InterruptableResponseWriter) WriteHeader(_ int) {} -func (*InterruptableResponseWriter) Flush() {} - -// This is a simple reader implementing io.Reader interface. It's based on strings.Reader, but it doesn't have extra -// methods that allow faster copying such as .WriteTo() method. -func NewReader(s string) *Reader { return &Reader{s, 0, -1} } - -type Reader struct { - s string - i int64 // current reading index - prevRune int // index of previous rune; or < 0 -} - -func (r *Reader) Read(b []byte) (n int, err error) { - if r.i >= int64(len(r.s)) { - return 0, io.EOF - } - r.prevRune = -1 - n = copy(b, r.s[r.i:]) - r.i += int64(n) - return -} - -func TestAsyncPayloadCopyWhenPayloadSizeBelowMaxAllowed(t *testing.T) { - MaxDirectResponseSize = 2 - payloadSize := int(MaxDirectResponseSize - 1) - payloadString := strings.Repeat("a", payloadSize) - writer := NewSimpleResponseWriter() - - copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) - require.Nil(t, err) - - copyDoneResult := <-copyDone - require.Nil(t, copyDoneResult.Error) - - require.Equal(t, payloadString, writer.buffer.String()) - require.Equal(t, EndOfResponseComplete, writer.Header().Get(EndOfResponseTrailer)) - - // reset it to its original value - MaxDirectResponseSize = interop.MaxPayloadSize -} - -func TestAsyncPayloadCopyWhenPayloadSizeEqualMaxAllowed(t *testing.T) { - MaxDirectResponseSize = 2 - payloadSize := int(MaxDirectResponseSize) - payloadString := strings.Repeat("a", payloadSize) - writer := NewSimpleResponseWriter() - - copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) - require.Nil(t, err) - - copyDoneResult := <-copyDone - require.Nil(t, copyDoneResult.Error) - - require.Equal(t, payloadString, writer.buffer.String()) - require.Equal(t, EndOfResponseComplete, writer.Header().Get(EndOfResponseTrailer)) - - // reset it to its original value - MaxDirectResponseSize = interop.MaxPayloadSize -} - -func TestAsyncPayloadCopyWhenPayloadSizeAboveMaxAllowed(t *testing.T) { - MaxDirectResponseSize = 2 - payloadSize := int(MaxDirectResponseSize) + 1 - payloadString := strings.Repeat("a", payloadSize) - writer := NewSimpleResponseWriter() - expectedCopyDoneResultError := &interop.ErrorResponseTooLargeDI{ - ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ - ResponseSize: payloadSize, - MaxResponseSize: int(MaxDirectResponseSize), - }, - } - - copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) - require.Nil(t, err) - - copyDoneResult := <-copyDone - require.Equal(t, expectedCopyDoneResultError, copyDoneResult.Error) - - require.Equal(t, payloadString, writer.buffer.String()) - require.Equal(t, EndOfResponseOversized, writer.Header().Get(EndOfResponseTrailer)) - - // reset it to its original value - MaxDirectResponseSize = interop.MaxPayloadSize -} - -// This is only allowed in streaming mode, currently. -func TestAsyncPayloadCopyWhenUnlimitedPayloadSizeAllowed(t *testing.T) { - MaxDirectResponseSize = -1 - payloadSize := int(interop.MaxPayloadSize + 1) - payloadString := strings.Repeat("a", payloadSize) - writer := NewSimpleResponseWriter() - - copyDone, _, err := asyncPayloadCopy(writer, NewReader(payloadString)) - require.Nil(t, err) - - copyDoneResult := <-copyDone - require.Nil(t, copyDoneResult.Error) - - require.Equal(t, payloadString, writer.buffer.String()) - require.Equal(t, EndOfResponseComplete, writer.Header().Get(EndOfResponseTrailer)) - - // reset it to its original value - MaxDirectResponseSize = interop.MaxPayloadSize -} - -// We use an interruptable response writer which informs on a channel that it's ready to be interrupted after -// 'interruptAfter' number of writes, then it waits for interruption completion to resume the current write operation. -// For this test, after initiating copying, we wait for one chunk of 32 KiB to be copied. Then, we use cancel() to -// interrupt copying. At this point, only ongoing .Write() operations can be performed. We inform the writer about -// interruption completion, and the writer resumes the current .Write() operation, which gives us another 32 KiB chunk -// that is copied. After that, copying returns, and we receive a signal on <-copyDone channel. -func TestAsyncPayloadCopySuccessAfterCancel(t *testing.T) { - payloadString := strings.Repeat("a", 10*1024*1024) // 10 MiB - writer, interruptedTestWriterChan := NewInterruptableResponseWriter(1) - - expectedPayloadString := strings.Repeat("a", 64*1024) // 64 KiB (2 chunks) - - copyDone, cancel, err := asyncPayloadCopy(writer, NewReader(payloadString)) - require.Nil(t, err) - - <-interruptedTestWriterChan // wait for writing 'interruptAfter' number of chunks - cancel() // interrupt copying - interruptedTestWriterChan <- struct{}{} // inform test writer about interruption - - <-copyDone - require.Equal(t, expectedPayloadString, writer.buffer.String()) -} -func TestAsyncPayloadCopyWithIncompatibleResponseWriter(t *testing.T) { - copyDone, cancel, err := asyncPayloadCopy(&ResponseWriterWithoutFlushMethod{}, nil) - require.Nil(t, copyDone) - require.Nil(t, cancel) - require.Error(t, err) - require.Equal(t, "ErrInternalPlatformError", err.Error()) -} - -// TODO: in order to implement this test we need bandwidthlimiter to be received by asyncPayloadCopy -// as an argument. Otherwise, this test will need to know how to force bandwidthlimiter to fail, -// which isn't a good practice. -func TestAsyncPayloadCopyWhenResponseIsTruncated(t *testing.T) { - t.Skip("Pending injection of bandwidthlimiter as a dependency of asyncPayloadCopy.") -} - -func TestSendStreamingInvokeResponseSuccess(t *testing.T) { - payloadString := strings.Repeat("a", 128*1024) // 128 KiB - payload := NewReader(payloadString) - trailers := http.Header{} - writer := NewSimpleResponseWriter() - interruptedResponseChan := make(chan *interop.Reset) - sendResponseChan := make(chan *interop.InvokeResponseMetrics) - testFinished := make(chan struct{}) - - expectedPayloadString := payloadString - - go func() { - err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, false) - require.Nil(t, err) - testFinished <- struct{}{} - }() - - <-sendResponseChan - require.Equal(t, expectedPayloadString, writer.buffer.String()) - require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) - require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) - require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) - <-testFinished -} - -func TestSendPayloadLimitedResponseWithinThresholdWithStreamingFunction(t *testing.T) { - payloadSize := 1 - payloadString := strings.Repeat("a", payloadSize) - payload := NewReader(payloadString) - trailers := http.Header{} - writer := NewSimpleResponseWriter() - writer.Header().Set("Lambda-Runtime-Function-Response-Mode", "streaming") - sendResponseChan := make(chan *interop.InvokeResponseMetrics) - testFinished := make(chan struct{}) - - MaxDirectResponseSize = int64(payloadSize + 1) - - go func() { - err := sendPayloadLimitedResponse(payload, trailers, writer, sendResponseChan, true) - require.Nil(t, err) - testFinished <- struct{}{} - }() - - metrics := <-sendResponseChan - require.Equal(t, interop.FunctionResponseModeBuffered, metrics.FunctionResponseMode) - require.Equal(t, len(payloadString), len(writer.buffer.String())) - require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) - require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) - require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) - <-testFinished - - // Reset to its default value, just in case other tests use them - MaxDirectResponseSize = interop.MaxPayloadSize -} - -func TestSendPayloadLimitedResponseAboveThresholdWithStreamingFunction(t *testing.T) { - payloadSize := 2 - payloadString := strings.Repeat("a", payloadSize) - payload := NewReader(payloadString) - trailers := http.Header{} - writer := NewSimpleResponseWriter() - writer.Header().Set("Lambda-Runtime-Function-Response-Mode", "streaming") - sendResponseChan := make(chan *interop.InvokeResponseMetrics) - testFinished := make(chan struct{}) - MaxDirectResponseSize = int64(payloadSize - 1) - expectedError := &interop.ErrorResponseTooLargeDI{ - ErrorResponseTooLarge: interop.ErrorResponseTooLarge{ - MaxResponseSize: int(MaxDirectResponseSize), - ResponseSize: payloadSize, - }, - } - - go func() { - err := sendPayloadLimitedResponse(payload, trailers, writer, sendResponseChan, true) - require.Equal(t, expectedError, err) - testFinished <- struct{}{} - }() - - metrics := <-sendResponseChan - require.Equal(t, interop.FunctionResponseModeBuffered, metrics.FunctionResponseMode) - require.Equal(t, len(payloadString), len(writer.buffer.String())) - require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) - require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) - require.Equal(t, "Oversized", writer.Header().Get("End-Of-Response")) - <-testFinished - - // Reset to its default value, just in case other tests use them - MaxDirectResponseSize = interop.MaxPayloadSize -} - -func TestSendStreamingInvokeResponseSuccessWithTrailers(t *testing.T) { - payloadString := strings.Repeat("a", 128*1024) // 128 KiB - payload := NewReader(payloadString) - trailers := http.Header{ - "Lambda-Runtime-Function-Error-Type": []string{"ErrorType"}, - "Lambda-Runtime-Function-Error-Body": []string{"ErrorBody"}, - } - writer := NewSimpleResponseWriter() - interruptedResponseChan := make(chan *interop.Reset) - sendResponseChan := make(chan *interop.InvokeResponseMetrics) - testFinished := make(chan struct{}) - - expectedPayloadString := payloadString - - go func() { - err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, false) - require.Nil(t, err) - testFinished <- struct{}{} - }() - - <-sendResponseChan - require.Equal(t, expectedPayloadString, writer.buffer.String()) - require.Equal(t, "ErrorType", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) - require.Equal(t, "ErrorBody", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) - require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) - <-testFinished -} - -func TestSendStreamingInvokeResponseReset(t *testing.T) { // Reset initiated after writing two chunks of 32 KiB - payloadString := strings.Repeat("a", 128*1024) // 128 KiB - payload := NewReader(payloadString) - trailers := http.Header{} - writer, interruptedTestWriterChan := NewInterruptableResponseWriter(1) - interruptedResponseChan := make(chan *interop.Reset) - sendResponseChan := make(chan *interop.InvokeResponseMetrics) - testFinished := make(chan struct{}) - - expectedPayloadString := strings.Repeat("a", 64*1024) // 64 KiB - - go func() { - err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, true) - require.Error(t, err) - require.Equal(t, "ErrTruncatedResponse", err.Error()) - testFinished <- struct{}{} - }() - - reset := &interop.Reset{Reason: "timeout"} - require.Nil(t, reset.InvokeResponseMetrics) - - <-interruptedTestWriterChan // wait for writing 'interruptAfter' number of chunks - interruptedResponseChan <- reset // send reset - time.Sleep(10 * time.Millisecond) // wait for cancel() being called (first instruction after getting reset) - interruptedTestWriterChan <- struct{}{} // inform test writer about interruption - <-interruptedResponseChan // wait for copy done after interruption - require.NotNil(t, reset.InvokeResponseMetrics) - require.Equal(t, interop.InvokeResponseMode("Buffered"), reset.InvokeResponseMode) - - <-sendResponseChan - require.Equal(t, expectedPayloadString, writer.buffer.String()) - require.Equal(t, "Sandbox.Timeout", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) - require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) - require.Equal(t, "Truncated", writer.Header().Get("End-Of-Response")) - <-testFinished -} - -// TODO: mock asyncPayloadCopy and force it to return Oversized in copyDone -func TestSendStreamingInvokeResponseOversizedRuntimesWithTrailers(t *testing.T) { - oversizedPayloadString := strings.Repeat("a", int(MaxDirectResponseSize)+1) - payload := NewReader(oversizedPayloadString) - trailers := http.Header{ - FunctionErrorTypeTrailer: []string{"RuntimesErrorType"}, - FunctionErrorBodyTrailer: []string{"RuntimesBody"}, - } - writer := NewSimpleResponseWriter() - interruptedResponseChan := make(chan *interop.Reset) - sendResponseChan := make(chan *interop.InvokeResponseMetrics) - testFinished := make(chan struct{}) - - go func() { - err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, false) - require.Error(t, err) - require.IsType(t, &interop.ErrorResponseTooLargeDI{}, err) - testFinished <- struct{}{} - }() - - <-sendResponseChan - require.Equal(t, trailers.Get(FunctionErrorTypeTrailer), writer.Header().Get(FunctionErrorTypeTrailer)) - require.Equal(t, trailers.Get(FunctionErrorBodyTrailer), writer.Header().Get(FunctionErrorBodyTrailer)) - require.Equal(t, EndOfResponseOversized, writer.Header().Get(EndOfResponseTrailer)) - <-testFinished -} - -// TODO: mock asyncPayloadCopy and force it to return Oversized in copyDone -func TestSendStreamingInvokeResponseOversizedRuntimesWithoutErrorTypeTrailer(t *testing.T) { - oversizedPayloadString := strings.Repeat("a", int(MaxDirectResponseSize)+1) - payload := NewReader(oversizedPayloadString) - trailers := http.Header{ - FunctionErrorTypeTrailer: []string{""}, - FunctionErrorBodyTrailer: []string{"RuntimesErrorBody"}, - } - writer := NewSimpleResponseWriter() - interruptedResponseChan := make(chan *interop.Reset) - sendResponseChan := make(chan *interop.InvokeResponseMetrics) - testFinished := make(chan struct{}) - - go func() { - err := sendStreamingInvokeResponse(payload, trailers, writer, interruptedResponseChan, sendResponseChan, nil, false) - require.Error(t, err) - require.IsType(t, &interop.ErrorResponseTooLargeDI{}, err) - testFinished <- struct{}{} - }() - - <-sendResponseChan - require.Equal(t, "Function.ResponseSizeTooLarge", writer.Header().Get(FunctionErrorTypeTrailer)) - require.Equal(t, trailers.Get(FunctionErrorBodyTrailer), writer.Header().Get(FunctionErrorBodyTrailer)) - require.Equal(t, EndOfResponseOversized, writer.Header().Get(EndOfResponseTrailer)) - <-testFinished -} - -func TestSendStreamingInvokeErrorResponseSuccess(t *testing.T) { - payloadString := strings.Repeat("a", 128*1024) // 128 KiB - payload := NewReader(payloadString) - writer := NewSimpleResponseWriter() - interruptedResponseChan := make(chan *interop.Reset) - sendResponseChan := make(chan *interop.InvokeResponseMetrics) - testFinished := make(chan struct{}) - - expectedPayloadString := payloadString - - go func() { - err := sendStreamingInvokeErrorResponse(payload, writer, interruptedResponseChan, sendResponseChan, false) - require.Nil(t, err) - testFinished <- struct{}{} - }() - - <-sendResponseChan - require.Equal(t, expectedPayloadString, writer.buffer.String()) - require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) - require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) - require.Equal(t, "Complete", writer.Header().Get("End-Of-Response")) - <-testFinished -} - -func TestSendStreamingInvokeErrorResponseReset(t *testing.T) { // Reset initiated after writing two chunks of 32 KiB - payloadString := strings.Repeat("a", 128*1024) // 128 KiB - payload := NewReader(payloadString) - writer, interruptedTestWriterChan := NewInterruptableResponseWriter(1) - interruptedResponseChan := make(chan *interop.Reset) - sendResponseChan := make(chan *interop.InvokeResponseMetrics) - testFinished := make(chan struct{}) - - expectedPayloadString := strings.Repeat("a", 64*1024) // 64 KiB - - go func() { - err := sendStreamingInvokeErrorResponse(payload, writer, interruptedResponseChan, sendResponseChan, true) - require.Error(t, err) - require.Equal(t, "ErrTruncatedResponse", err.Error()) - testFinished <- struct{}{} - }() - - reset := &interop.Reset{Reason: "timeout"} - require.Nil(t, reset.InvokeResponseMetrics) - - <-interruptedTestWriterChan // wait for writing 'interruptAfter' number of chunks - interruptedResponseChan <- reset // send reset - time.Sleep(10 * time.Millisecond) // wait for cancel() being called (first instruction after getting reset) - interruptedTestWriterChan <- struct{}{} // inform test writer about interruption - <-interruptedResponseChan // wait for copy done after interruption - require.NotNil(t, reset.InvokeResponseMetrics) - - <-sendResponseChan - require.Equal(t, expectedPayloadString, writer.buffer.String()) - require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Type")) - require.Equal(t, "", writer.Header().Get("Lambda-Runtime-Function-Error-Body")) - require.Equal(t, "Truncated", writer.Header().Get("End-Of-Response")) - <-testFinished -} - -func TestIsStreamingInvokeTrue(t *testing.T) { - fallbackFlag := -1 - reponseForFallback := isStreamingInvoke(fallbackFlag, interop.InvokeResponseModeBuffered) - - require.True(t, reponseForFallback) - - nonFallbackFlag := 1 - reponseForResponseMode := isStreamingInvoke(nonFallbackFlag, interop.InvokeResponseModeStreaming) - - require.True(t, reponseForResponseMode) -} - -func TestIsStreamingInvokeFalse(t *testing.T) { - nonFallbackFlag := 1 - response := isStreamingInvoke(nonFallbackFlag, interop.InvokeResponseModeBuffered) - - require.False(t, response) -} - -func TestMapCopyDoneResultErrorToErrorType(t *testing.T) { - require.Equal(t, fatalerror.TruncatedResponse, mapCopyDoneResultErrorToErrorType(&interop.ErrTruncatedResponse{})) - require.Equal(t, fatalerror.FunctionOversizedResponse, mapCopyDoneResultErrorToErrorType(&interop.ErrorResponseTooLargeDI{})) - require.Equal(t, fatalerror.SandboxFailure, mapCopyDoneResultErrorToErrorType(errors.New(""))) -} - -func TestConvertToInvokeResponseMode(t *testing.T) { - response, err := convertToInvokeResponseMode("buffered") - require.Equal(t, interop.InvokeResponseModeBuffered, response) - require.Nil(t, err) - - response, err = convertToInvokeResponseMode("streaming") - require.Equal(t, interop.InvokeResponseModeStreaming, response) - require.Nil(t, err) - - response, err = convertToInvokeResponseMode("foo-bar") - require.Equal(t, interop.InvokeResponseMode(""), response) - require.Equal(t, interop.ErrInvalidInvokeResponseMode, err) -} - -func FuzzReceiveDirectInvoke(f *testing.F) { - testCustHeaders := CustomerHeaders{ - CognitoIdentityID: "id1", - CognitoIdentityPoolID: "id2", - ClientContext: "clientcontext1", - } - custHeadersJSON := testCustHeaders.Dump() - - f.Add([]byte{'a'}, "res-token", "invokeid", "functionarn", "versionid", "contenttype", - custHeadersJSON, "1000", - "Streaming", fmt.Sprint(interop.MinResponseBandwidthRate), fmt.Sprint(interop.MinResponseBandwidthBurstSize)) - f.Add([]byte{'b'}, "res-token", "invokeid", "functionarn", "versionid", "contenttype", - custHeadersJSON, "2000", "Buffered", - "0", "0") - f.Add([]byte{'0'}, "0", "0", "0", "0", "0", - "", "", "0", - "0", "0") - - f.Fuzz(func( - t *testing.T, - payload []byte, - reservationToken string, - invokeID string, - invokedFunctionArn string, - versionID string, - contentType string, - custHeadersStr string, - maxPayloadSizeStr string, - invokeResponseModeStr string, - responseBandwidthRateStr string, - responseBandwidthBurstSizeStr string, - ) { - request := makeDirectInvokeRequest(payload, reservationToken, invokeID, - invokedFunctionArn, versionID, contentType, custHeadersStr, maxPayloadSizeStr, - invokeResponseModeStr, responseBandwidthRateStr, responseBandwidthBurstSizeStr) - - token := createDummyToken() - responseRecorder := httptest.NewRecorder() - - receivedInvoke, err := ReceiveDirectInvoke(responseRecorder, request, token) - - // default values used if header values are empty - responseMode := interop.InvokeResponseModeBuffered - maxDirectResponseSize := interop.MaxPayloadSize - - custHeaders := CustomerHeaders{} - - if err != nil { - if err = custHeaders.Load(custHeadersStr); err != nil { - assertBadRequestErrorType(t, responseRecorder, interop.ErrMalformedCustomerHeaders) - return - } - - if !isValidMaxPayloadSize(maxPayloadSizeStr) { - assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidMaxPayloadSize) - return - } - - n, _ := strconv.ParseInt(maxPayloadSizeStr, 10, 64) - maxDirectResponseSize = int(n) - - if invokeResponseModeStr != "" { - if responseMode, err = convertToInvokeResponseMode(invokeResponseModeStr); err != nil { - assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidInvokeResponseMode) - return - } - } - - if isStreamingInvoke(maxDirectResponseSize, responseMode) { - if !isValidResponseBandwidthRate(responseBandwidthRateStr) { - assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidResponseBandwidthRate) - return - } - - if !isValidResponseBandwidthBurstSize(responseBandwidthBurstSizeStr) { - assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidResponseBandwidthBurstSize) - return - } - } - - } else { - if isStreamingInvoke(maxDirectResponseSize, responseMode) { - // FIXME - // Until WorkerProxy stops sending MaxDirectResponseSize == -1 to identify streaming - // invokes, the ReceiveDirectInvoke() implementation overrides InvokeResponseMode - // to avoid setting InvokeResponseMode to buffered (default) for a streaming invoke (MaxDirectResponseSize == -1). - responseMode = interop.InvokeResponseModeStreaming - - assert.Equal(t, responseRecorder.Header().Values("Trailer"), []string{FunctionErrorTypeTrailer, FunctionErrorBodyTrailer}) - } - - if receivedInvoke.ID != token.InvokeID { - assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidInvokeID) - return - } - - if receivedInvoke.ReservationToken != token.ReservationToken { - assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidReservationToken) - return - } - - if receivedInvoke.VersionID != token.VersionID { - assertBadRequestErrorType(t, responseRecorder, interop.ErrInvalidFunctionVersion) - return - } - - if now := metering.Monotime(); now > token.InvackDeadlineNs { - assertBadRequestErrorType(t, responseRecorder, interop.ErrReservationExpired) - return - } - - assert.Equal(t, responseRecorder.Header().Get(VersionIDHeader), token.VersionID) - assert.Equal(t, responseRecorder.Header().Get(ReservationTokenHeader), token.ReservationToken) - assert.Equal(t, responseRecorder.Header().Get(InvokeIDHeader), token.InvokeID) - - expectedInvoke := &interop.Invoke{ - ID: invokeID, - ReservationToken: reservationToken, - InvokedFunctionArn: invokedFunctionArn, - VersionID: versionID, - ContentType: contentType, - CognitoIdentityID: custHeaders.CognitoIdentityID, - CognitoIdentityPoolID: custHeaders.CognitoIdentityPoolID, - TraceID: token.TraceID, - LambdaSegmentID: token.LambdaSegmentID, - ClientContext: custHeaders.ClientContext, - Payload: request.Body, - DeadlineNs: receivedInvoke.DeadlineNs, - NeedDebugLogs: token.NeedDebugLogs, - InvokeReceivedTime: receivedInvoke.InvokeReceivedTime, - InvokeResponseMode: responseMode, - RestoreDurationNs: token.RestoreDurationNs, - RestoreStartTimeMonotime: token.RestoreStartTimeMonotime, - } - - assert.Equal(t, expectedInvoke, receivedInvoke) - } - }) -} - -func createDummyToken() interop.Token { - return interop.Token{ - ReservationToken: "reservation_token", - TraceID: "trace_id", - InvokeID: "invoke_id", - InvackDeadlineNs: math.MaxInt64, - VersionID: "version_id", - } -} - -func assertBadRequestErrorType(t *testing.T, responseRecorder *httptest.ResponseRecorder, expectedErrType error) { - assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) - - assert.Equal(t, expectedErrType.Error(), responseRecorder.Header().Get(ErrorTypeHeader)) - assert.Equal(t, EndOfResponseComplete, responseRecorder.Header().Get(EndOfResponseTrailer)) -} - -func isValidResponseBandwidthBurstSize(sizeStr string) bool { - size, err := strconv.ParseInt(sizeStr, 10, 64) - return err == nil && - interop.MinResponseBandwidthBurstSize <= size && size <= interop.MaxResponseBandwidthBurstSize -} - -func isValidResponseBandwidthRate(rateStr string) bool { - rate, err := strconv.ParseInt(rateStr, 10, 64) - return err == nil && - interop.MinResponseBandwidthRate <= rate && rate <= interop.MaxResponseBandwidthRate -} - -func isValidMaxPayloadSize(maxPayloadSizeStr string) bool { - if maxPayloadSizeStr != "" { - maxPayloadSize, err := strconv.ParseInt(maxPayloadSizeStr, 10, 64) - return err == nil && maxPayloadSize >= -1 - } - - return true -} - -func makeDirectInvokeRequest( - payload []byte, reservationToken string, invokeID string, invokedFunctionArn string, - versionID string, contentType string, custHeadersStr string, maxPayloadSize string, - invokeResponseModeStr string, responseBandwidthRate string, responseBandwidthBurstSize string, -) *http.Request { - request := httptest.NewRequest("POST", "http://example.com/", bytes.NewReader(payload)) - request = addReservationToken(request, reservationToken) - - request.Header.Set(InvokeIDHeader, invokeID) - request.Header.Set(InvokedFunctionArnHeader, invokedFunctionArn) - request.Header.Set(VersionIDHeader, versionID) - request.Header.Set(ContentTypeHeader, contentType) - request.Header.Set(CustomerHeadersHeader, custHeadersStr) - request.Header.Set(MaxPayloadSizeHeader, maxPayloadSize) - request.Header.Set(InvokeResponseModeHeader, invokeResponseModeStr) - request.Header.Set(ResponseBandwidthRateHeader, responseBandwidthRate) - request.Header.Set(ResponseBandwidthBurstSizeHeader, responseBandwidthBurstSize) - - return request -} - -func addReservationToken(r *http.Request, reservationToken string) *http.Request { - rctx := chi.NewRouteContext() - rctx.URLParams.Add("reservationtoken", reservationToken) - return r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) -} diff --git a/lambda/core/directinvoke/util.go b/lambda/core/directinvoke/util.go deleted file mode 100644 index 511d656..0000000 --- a/lambda/core/directinvoke/util.go +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package directinvoke - -import ( - "context" - "errors" - "go.amzn.com/lambda/core/bandwidthlimiter" - "io" - "net/http" - "time" - - log "github.com/sirupsen/logrus" -) - -const DefaultRefillIntervalMs = 125 // default refill interval in milliseconds - -func NewStreamedResponseWriter(w http.ResponseWriter) (*bandwidthlimiter.BandwidthLimitingWriter, context.CancelFunc, error) { - flushingWriter, err := NewFlushingWriter(w) // after writing a chunk we have to flush it to avoid additional buffering by ResponseWriter - if err != nil { - return nil, nil, err - } - cancellableWriter, cancel := NewCancellableWriter(flushingWriter) // cancelling prevents next calls to Write() from happening - - refillNumber := ResponseBandwidthRate * DefaultRefillIntervalMs / 1000 // refillNumber is calculated based on 'ResponseBandwidthRate' and bucket refill interval - refillInterval := DefaultRefillIntervalMs * time.Millisecond - - // Initial bucket for token bucket algorithm allows for a burst of up to 6 MiB, and an average transmission rate of 2 MiB/s - bucket, err := bandwidthlimiter.NewBucket(ResponseBandwidthBurstSize, ResponseBandwidthBurstSize, refillNumber, refillInterval) - if err != nil { - cancel() // free resources - return nil, nil, err - } - - bandwidthLimitingWriter, err := bandwidthlimiter.NewBandwidthLimitingWriter(cancellableWriter, bucket) - if err != nil { - cancel() // free resources - return nil, nil, err - } - - return bandwidthLimitingWriter, cancel, nil -} - -func NewFlushingWriter(w io.Writer) (*FlushingWriter, error) { - flusher, ok := w.(http.Flusher) - if !ok { - errorMsg := "expected http.ResponseWriter to be an http.Flusher" - log.Error(errorMsg) - return nil, errors.New(errorMsg) - } - return &FlushingWriter{ - w: w, - flusher: flusher, - }, nil -} - -type FlushingWriter struct { - w io.Writer - flusher http.Flusher -} - -func (w *FlushingWriter) Write(p []byte) (n int, err error) { - n, err = w.w.Write(p) - w.flusher.Flush() - return -} - -func NewCancellableWriter(w io.Writer) (*CancellableWriter, context.CancelFunc) { - ctx, cancel := context.WithCancel(context.Background()) - return &CancellableWriter{w: w, ctx: ctx}, cancel -} - -type CancellableWriter struct { - w io.Writer - ctx context.Context -} - -func (w *CancellableWriter) Write(p []byte) (int, error) { - if err := w.ctx.Err(); err != nil { - return 0, err - } - return w.w.Write(p) -} diff --git a/lambda/core/doc.go b/lambda/core/doc.go deleted file mode 100644 index 4a7157f..0000000 --- a/lambda/core/doc.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -/* -Package core provides state objects and synchronization primitives for -managing data flow in the system. - -# States - -Runtime and Agent implement state object design pattern. - -Runtime state interface: - - type RuntimeState interface { - InitError() error - Ready() error - InvocationResponse() error - InvocationErrorResponse() error - } - -# Gates - -Gates provide synchornization primitives for managing data flow in the system. - -Gate is a synchronization aid that allows one or more threads to wait until a -set of operations being performed in other threads completes. - -To better understand gates, consider two examples below: - -Example 1: main thread is awaiting registered threads to walk through the gate, - - and after the last registered thread walked through the gate, gate - condition will be satisfied and main thread will proceed: - -[main] // register threads with the gate and start threads ... -[main] g.AwaitGateCondition() -[main] // blocked until gate condition is satisfied - -[thread] g.WalkThrough() -[thread] // not blocked - -Example 2: main thread is awaiting registered threads to arrive at the gate, - - and after the last registered thread arrives at the gate, gate - condition will be satisfied and main thread, along with registered - threads will proceed: - -[main] // register threads with the gate and start threads ... -[main] g.AwaitGateCondition() -[main] // blocked until gate condition is satisfied - -# Flow - -Flow wraps a set of specific gates required to implement specific data flow in the system. - -Example flows would be INIT, INVOKE and RESET. - -# Registrations - -Registration service manages registrations, it maintains the mapping between registered -parties are events they are registered. Parties not registered in the system will not -be issued events. -*/ -package core diff --git a/lambda/core/externalagent.go b/lambda/core/externalagent.go deleted file mode 100644 index cd367d2..0000000 --- a/lambda/core/externalagent.go +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "fmt" - "time" - - "go.amzn.com/lambda/core/statejson" - - "github.com/google/uuid" -) - -// ExternalAgent represents external agent -type ExternalAgent struct { - Name string - ID uuid.UUID - events map[Event]struct{} - - ManagedThread Suspendable - - currentState ExternalAgentState - stateLastModified time.Time - - StartedState ExternalAgentState - RegisteredState ExternalAgentState - ReadyState ExternalAgentState - RunningState ExternalAgentState - InitErrorState ExternalAgentState - ExitErrorState ExternalAgentState - ShutdownFailedState ExternalAgentState - ExitedState ExternalAgentState - LaunchErrorState ExternalAgentState - - errorType string -} - -// NewExternalAgent returns new instance of a named agent -func NewExternalAgent(name string, initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchronization) *ExternalAgent { - agent := &ExternalAgent{ - Name: name, - ID: uuid.New(), - ManagedThread: NewManagedThread(), - events: make(map[Event]struct{}), - } - - agent.StartedState = &ExternalAgentStartedState{agent: agent, initFlow: initFlow} - agent.RegisteredState = &ExternalAgentRegisteredState{agent: agent, initFlow: initFlow} - agent.ReadyState = &ExternalAgentReadyState{agent: agent} - agent.RunningState = &ExternalAgentRunningState{agent: agent, invokeFlow: invokeFlow} - agent.InitErrorState = &ExternalAgentInitErrorState{} - agent.ExitErrorState = &ExternalAgentExitErrorState{} - agent.ShutdownFailedState = &ExternalAgentShutdownFailedState{} - agent.ExitedState = &ExternalAgentExitedState{} - agent.LaunchErrorState = &ExternalAgentLaunchErrorState{} - - agent.setStateUnsafe(agent.StartedState) - - return agent -} - -func (s *ExternalAgent) String() string { - return fmt.Sprintf("%s (%s)", s.Name, s.ID) -} - -// SuspendUnsafe the current running thread -func (s *ExternalAgent) SuspendUnsafe() { - s.ManagedThread.SuspendUnsafe() -} - -// Release will resume a suspended thread -func (s *ExternalAgent) Release() { - s.ManagedThread.Release() -} - -// SetState using the lock -func (s *ExternalAgent) SetState(state ExternalAgentState) { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - s.setStateUnsafe(state) -} - -func (s *ExternalAgent) setStateUnsafe(state ExternalAgentState) { - s.currentState = state - s.stateLastModified = time.Now() -} - -func ValidateExternalAgentEvent(e Event) error { - switch e { - case InvokeEvent: - return nil - case ShutdownEvent: - return nil - } - return errInvalidEventType -} - -func (s *ExternalAgent) subscribeUnsafe(e Event) error { - if err := ValidateExternalAgentEvent(e); err != nil { - return err - } - s.events[e] = struct{}{} - return nil -} - -// IsSubscribed checks whether agent is subscribed the Event -func (s *ExternalAgent) IsSubscribed(e Event) bool { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - _, found := s.events[e] - return found -} - -// SubscribedEvents returns events to which the agent is subscribed -func (s *ExternalAgent) SubscribedEvents() []string { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - - events := []string{} - for event := range s.events { - events = append(events, string(event)) - } - return events -} - -// GetState returns agent's current state -func (s *ExternalAgent) GetState() ExternalAgentState { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState -} - -// Register an agent with the platform -func (s *ExternalAgent) Register(events []Event) error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.Register(events) -} - -// Ready - mark an agent as ready -func (s *ExternalAgent) Ready() error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.Ready() -} - -// InitError - agent registered but failed to initialize -func (s *ExternalAgent) InitError(errorType string) error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.InitError(errorType) -} - -// ExitError - agent reported unrecoverable error -func (s *ExternalAgent) ExitError(errorType string) error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.ExitError(errorType) -} - -// ShutdownFailed - terminal state, agent didn't exit gracefully -func (s *ExternalAgent) ShutdownFailed() error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.ShutdownFailed() -} - -// Exited - agent shut down successfully -func (s *ExternalAgent) Exited() error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.Exited() -} - -// ErrorType returns error type reported during init or exit -func (s *ExternalAgent) ErrorType() string { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.errorType -} - -// Exited - agent shut down successfully -func (s *ExternalAgent) LaunchError(err error) error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.LaunchError(err) -} - -// GetAgentDescription returns agent description object for debugging purposes -func (s *ExternalAgent) GetAgentDescription() statejson.ExtensionDescription { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return statejson.ExtensionDescription{ - Name: s.Name, - ID: s.ID.String(), - State: statejson.StateDescription{ - Name: s.currentState.Name(), - LastModified: s.stateLastModified.UnixNano() / int64(time.Millisecond), - }, - ErrorType: s.errorType, - } -} diff --git a/lambda/core/externalagent_states.go b/lambda/core/externalagent_states.go deleted file mode 100644 index 23de333..0000000 --- a/lambda/core/externalagent_states.go +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -// ExternalAgentState is external agent state interface -type ExternalAgentState interface { - Register([]Event) error - Ready() error - InitError(errorType string) error - ExitError(errorType string) error - ShutdownFailed() error - Exited() error - LaunchError(error) error - Name() string -} - -// ExternalAgentStartedState is the initial state of an external agent -type ExternalAgentStartedState struct { - disallowEverything - agent *ExternalAgent - initFlow InitFlowSynchronization -} - -// Register an agent with the platform when agent is in started state -func (s *ExternalAgentStartedState) Register(events []Event) error { - for _, e := range events { - if err := s.agent.subscribeUnsafe(e); err != nil { - return err - } - } - s.agent.setStateUnsafe(s.agent.RegisteredState) - s.initFlow.ExternalAgentRegistered() - return nil -} - -// LaunchError signals that agent could not launch (non-exec/permission denied) -func (s *ExternalAgentStartedState) LaunchError(err error) error { - s.agent.setStateUnsafe(s.agent.LaunchErrorState) - s.agent.errorType = string(MapErrorToAgentInfoErrorType(err)) - return nil -} - -// Name return state's human friendly name -func (s *ExternalAgentStartedState) Name() string { - return AgentStartedStateName -} - -// ExternalAgentRegisteredState is the state of an agent that registered with the platform but has not reported ready (next) -type ExternalAgentRegisteredState struct { - disallowEverything - agent *ExternalAgent - initFlow InitFlowSynchronization -} - -// Ready - agent has called next and is now successfully initialized -func (s *ExternalAgentRegisteredState) Ready() error { - s.agent.setStateUnsafe(s.agent.ReadyState) - s.initFlow.AgentReady() - s.agent.ManagedThread.SuspendUnsafe() - - if s.agent.currentState != s.agent.ReadyState { - return ErrConcurrentStateModification - } - s.agent.setStateUnsafe(s.agent.RunningState) - - return nil -} - -// InitError - agent can transitions to InitErrorState if it failed to initialize -func (s *ExternalAgentRegisteredState) InitError(errorType string) error { - s.agent.setStateUnsafe(s.agent.InitErrorState) - s.agent.errorType = errorType - return nil -} - -// ExitError - agent called /exit/error -func (s *ExternalAgentRegisteredState) ExitError(errorType string) error { - s.agent.setStateUnsafe(s.agent.ExitErrorState) - s.agent.errorType = errorType - return nil -} - -// Name return state's human friendly name -func (s *ExternalAgentRegisteredState) Name() string { - return AgentRegisteredStateName -} - -// ExternalAgentReadyState is the state of an agent that reported ready to the platform -type ExternalAgentReadyState struct { - disallowEverything - agent *ExternalAgent -} - -// ExitError signals that agent provided unrecoverable error description -func (s *ExternalAgentReadyState) ExitError(errorType string) error { - s.agent.setStateUnsafe(s.agent.ExitErrorState) - s.agent.errorType = errorType - return nil -} - -// Name return state's human friendly name -func (s *ExternalAgentReadyState) Name() string { - return AgentReadyStateName -} - -// ExternalAgentRunningState is the state of an agent that has received an invoke event and is currently processing it -type ExternalAgentRunningState struct { - disallowEverything - agent *ExternalAgent - invokeFlow InvokeFlowSynchronization -} - -// Ready - agent transitions to Ready and the calling thread gets suspended. Upon release the agent transitions to Running -func (s *ExternalAgentRunningState) Ready() error { - s.agent.setStateUnsafe(s.agent.ReadyState) - s.invokeFlow.AgentReady() - s.agent.ManagedThread.SuspendUnsafe() - - if s.agent.currentState != s.agent.ReadyState { - return ErrConcurrentStateModification - } - s.agent.setStateUnsafe(s.agent.RunningState) - - return nil -} - -// ExitError signals that agent provided unrecoverable error description -func (s *ExternalAgentRunningState) ExitError(errorType string) error { - s.agent.setStateUnsafe(s.agent.ExitErrorState) - s.agent.errorType = errorType - return nil -} - -// ShutdownFailed transitions agent into the ShutdownFailed terminal state -func (s *ExternalAgentRunningState) ShutdownFailed() error { - s.agent.setStateUnsafe(s.agent.ShutdownFailedState) - return nil -} - -// Exited - agent process has exited -func (s *ExternalAgentRunningState) Exited() error { - s.agent.setStateUnsafe(s.agent.ExitedState) - return nil -} - -// Name return state's human friendly name -func (s *ExternalAgentRunningState) Name() string { - return AgentRunningStateName -} - -// ExternalAgentInitErrorState is a terminal state where agent has reported /init/error -type ExternalAgentInitErrorState struct { - disallowEverything -} - -// Name return state's human friendly name -func (s *ExternalAgentInitErrorState) Name() string { - return AgentInitErrorStateName -} - -// InitError - multiple calls are allowed, but only the first submitted error is accepted -func (s *ExternalAgentInitErrorState) InitError(errorType string) error { - // no-op - return nil -} - -// ExternalAgentExitErrorState is a terminal state where agent has reported /exit/error -type ExternalAgentExitErrorState struct { - disallowEverything -} - -// Name return state's human friendly name -func (s *ExternalAgentExitErrorState) Name() string { - return AgentExitErrorStateName -} - -// ExitError - multiple calls are allowed, but only the first submitted error is accepted -func (s *ExternalAgentExitErrorState) ExitError(errorType string) error { - // no-op - return nil -} - -type ExternalAgentShutdownFailedState struct { - disallowEverything -} - -// Name return state's human friendly name -func (s *ExternalAgentShutdownFailedState) Name() string { - return AgentShutdownFailedStateName -} - -type ExternalAgentExitedState struct { - disallowEverything -} - -// Name return state's human friendly name -func (s *ExternalAgentExitedState) Name() string { - return AgentExitedStateName -} - -type ExternalAgentLaunchErrorState struct { - disallowEverything -} - -// Name return state's human friendly name -func (s *ExternalAgentLaunchErrorState) Name() string { - return AgentLaunchErrorName -} diff --git a/lambda/core/externalagent_states_test.go b/lambda/core/externalagent_states_test.go deleted file mode 100644 index 5d38c80..0000000 --- a/lambda/core/externalagent_states_test.go +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "errors" - "github.com/stretchr/testify/require" - "go.amzn.com/lambda/testdata/mockthread" - "testing" -) - -func TestExternalAgentStateUnknownEventType(t *testing.T) { - agent := NewExternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - require.Equal(t, agent.StartedState, agent.GetState()) - require.Equal(t, errInvalidEventType, agent.Register([]Event{"foo"})) - require.Equal(t, agent.StartedState, agent.GetState()) -} - -func TestExternalAgentStateTransitionsFromStartedState(t *testing.T) { - agent := NewExternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - // Initial agent state is Start - require.Equal(t, agent.StartedState, agent.GetState()) - - require.NoError(t, agent.Register([]Event{})) - require.Equal(t, agent.RegisteredState, agent.GetState()) - agent.SetState(agent.StartedState) - - require.NoError(t, agent.LaunchError(errors.New("someerror"))) - require.Equal(t, agent.LaunchErrorState, agent.GetState()) - agent.SetState(agent.StartedState) - - require.Equal(t, ErrNotAllowed, agent.Ready()) - require.Equal(t, agent.StartedState, agent.GetState()) - - require.Equal(t, ErrNotAllowed, agent.InitError("Extension.TestError")) - require.Equal(t, agent.StartedState, agent.GetState()) - - require.Equal(t, ErrNotAllowed, agent.ExitError("Extension.TestError")) - require.Equal(t, agent.StartedState, agent.GetState()) - - require.Equal(t, ErrNotAllowed, agent.ShutdownFailed()) - require.Equal(t, agent.StartedState, agent.GetState()) - - require.Equal(t, ErrNotAllowed, agent.Exited()) - require.Equal(t, agent.StartedState, agent.GetState()) -} - -func TestExternalAgentStateTransitionsFromRegisteredState(t *testing.T) { - agent := NewExternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - agent.SetState(agent.RegisteredState) - - require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) - require.Equal(t, agent.RegisteredState, agent.GetState()) - - require.NoError(t, agent.Ready()) - require.Equal(t, agent.RunningState, agent.GetState()) - - agent.SetState(agent.RegisteredState) - require.NoError(t, agent.InitError("Extension.TestError")) - require.Equal(t, agent.InitErrorState, agent.GetState()) - require.Equal(t, "Extension.TestError", agent.errorType) - - agent.SetState(agent.RegisteredState) - require.NoError(t, agent.ExitError("Extension.TestError")) - require.Equal(t, agent.ExitErrorState, agent.GetState()) - require.Equal(t, "Extension.TestError", agent.errorType) -} - -func TestExternalAgentStateTransitionsFromReadyState(t *testing.T) { - agent := NewExternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - agent.SetState(agent.ReadyState) - - require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) - require.Equal(t, agent.ReadyState, agent.GetState()) - - require.Equal(t, ErrNotAllowed, agent.Ready()) - require.Equal(t, agent.ReadyState, agent.GetState()) - - require.Equal(t, ErrNotAllowed, agent.InitError("Extension.TestError")) - require.Equal(t, agent.ReadyState, agent.GetState()) - - agent.SetState(agent.ReadyState) - require.NoError(t, agent.ExitError("Extension.TestError")) - require.Equal(t, agent.ExitErrorState, agent.GetState()) - require.Equal(t, "Extension.TestError", agent.errorType) - - agent.SetState(agent.ReadyState) - require.Equal(t, ErrNotAllowed, agent.Exited()) - require.Equal(t, agent.ReadyState, agent.GetState()) - - require.Equal(t, ErrNotAllowed, agent.ShutdownFailed()) - require.Equal(t, agent.ReadyState, agent.GetState()) -} - -func assertAgentIsInFinalState(t *testing.T, agent *ExternalAgent) { - initialState := agent.GetState() - require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) - require.Equal(t, initialState, agent.GetState()) - require.Equal(t, ErrNotAllowed, agent.Ready()) - require.Equal(t, initialState, agent.GetState()) - require.Equal(t, ErrNotAllowed, agent.ShutdownFailed()) - require.Equal(t, initialState, agent.GetState()) - require.Equal(t, ErrNotAllowed, agent.Exited()) - require.Equal(t, initialState, agent.GetState()) - require.Equal(t, ErrNotAllowed, agent.LaunchError(errors.New("someerror"))) - require.Equal(t, initialState, agent.GetState()) - - // InitError state can be re-entered from InitError state - if agent.InitErrorState == initialState { - require.Equal(t, nil, agent.InitError("Extension.TestError")) - } else { - require.Equal(t, ErrNotAllowed, agent.InitError("Extension.TestError")) - } - - require.Equal(t, initialState, agent.GetState()) - - // ExitError state can be re-entered from ExitError state - if agent.ExitErrorState == initialState { - require.Equal(t, nil, agent.ExitError("Extension.TestError")) - } else { - require.Equal(t, ErrNotAllowed, agent.ExitError("Extension.TestError")) - } - - require.Equal(t, initialState, agent.GetState()) -} - -func TestExternalAgentStateTransitionsFromInitErrorState(t *testing.T) { - agent := NewExternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - agent.SetState(agent.InitErrorState) - assertAgentIsInFinalState(t, agent) -} - -func TestExternalAgentStateTransitionsFromExitErrorState(t *testing.T) { - agent := NewExternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - agent.SetState(agent.ExitErrorState) - assertAgentIsInFinalState(t, agent) -} - -func TestExternalAgentStateTransitionsFromShutdownFailedState(t *testing.T) { - agent := NewExternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - agent.SetState(agent.ShutdownFailedState) - assertAgentIsInFinalState(t, agent) -} - -func TestExternalAgentStateTransitionsFromExitedState(t *testing.T) { - agent := NewExternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - agent.SetState(agent.ExitedState) - assertAgentIsInFinalState(t, agent) -} - -func TestExternalAgentStateTransitionsFromRunningState(t *testing.T) { - agent := NewExternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - agent.SetState(agent.RunningState) - require.Equal(t, agent.RunningState, agent.GetState()) - - require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) - require.Equal(t, agent.RunningState, agent.GetState()) - - require.Equal(t, ErrNotAllowed, agent.InitError("Extension.TestError")) - require.Equal(t, agent.RunningState, agent.GetState()) - - require.NoError(t, agent.ShutdownFailed()) - require.Equal(t, agent.ShutdownFailedState, agent.GetState()) - - agent.SetState(agent.RunningState) - require.NoError(t, agent.Exited()) - require.Equal(t, agent.ExitedState, agent.GetState()) - - agent.SetState(agent.RunningState) - require.NoError(t, agent.Ready()) - require.Equal(t, agent.RunningState, agent.GetState()) -} - -func TestExternalAgentStateTransitionsFromLaunchErrorState(t *testing.T) { - agent := NewExternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - agent.SetState(agent.LaunchErrorState) - assertAgentIsInFinalState(t, agent) -} diff --git a/lambda/core/flow.go b/lambda/core/flow.go deleted file mode 100644 index 08d5e4b..0000000 --- a/lambda/core/flow.go +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "context" - - "go.amzn.com/lambda/interop" -) - -// InitFlowSynchronization wraps init flow barriers. -type InitFlowSynchronization interface { - SetExternalAgentsRegisterCount(uint16) error - SetAgentsReadyCount(uint16) error - - ExternalAgentRegistered() error - AwaitExternalAgentsRegistered() error - - RuntimeReady() error - AwaitRuntimeReady() error - AwaitRuntimeReadyWithDeadline(context.Context) error - - AgentReady() error - AwaitAgentsReady() error - - CancelWithError(error) - - RuntimeRestoreReady() error - AwaitRuntimeRestoreReady() error - - Clear() -} - -type initFlowSynchronizationImpl struct { - externalAgentsRegisteredGate Gate - runtimeReadyGate Gate - agentReadyGate Gate - runtimeRestoreReadyGate Gate -} - -// SetExternalAgentsRegisterCount notifies init flow that N /extension/register calls should be done in future by external agents -func (s *initFlowSynchronizationImpl) SetExternalAgentsRegisterCount(externalAgentsNumber uint16) error { - return s.externalAgentsRegisteredGate.SetCount(externalAgentsNumber) -} - -// SetAgentsReadyCount sets the number of agents we expect at the gate once we know it -func (s *initFlowSynchronizationImpl) SetAgentsReadyCount(agentCount uint16) error { - return s.agentReadyGate.SetCount(agentCount) -} - -// AwaitRuntimeReady awaits runtime ready state -func (s *initFlowSynchronizationImpl) AwaitRuntimeReady() error { - return s.runtimeReadyGate.AwaitGateCondition() -} - -func (s *initFlowSynchronizationImpl) AwaitRuntimeReadyWithDeadline(ctx context.Context) error { - var err error - errorChan := make(chan error) - - go func() { - errorChan <- s.runtimeReadyGate.AwaitGateCondition() - }() - - select { - case err = <-errorChan: - break - case <-ctx.Done(): - err = interop.ErrRestoreHookTimeout - s.CancelWithError(err) - break - } - - return err -} - -// AwaitRuntimeRestoreReady awaits runtime restore ready state (/restore/next is called by runtime) -func (s *initFlowSynchronizationImpl) AwaitRuntimeRestoreReady() error { - return s.runtimeRestoreReadyGate.AwaitGateCondition() -} - -// AwaitExternalAgentsRegistered awaits for all subscribed agents to report registered -func (s *initFlowSynchronizationImpl) AwaitExternalAgentsRegistered() error { - return s.externalAgentsRegisteredGate.AwaitGateCondition() -} - -// AwaitAgentReady awaits for registered extensions to report ready -func (s *initFlowSynchronizationImpl) AwaitAgentsReady() error { - return s.agentReadyGate.AwaitGateCondition() -} - -// Ready called by runtime when initialized -func (s *initFlowSynchronizationImpl) RuntimeReady() error { - return s.runtimeReadyGate.WalkThrough() -} - -// Ready called by runtime when restore is completed (i.e. /next is called after /restore/next) -func (s *initFlowSynchronizationImpl) RuntimeRestoreReady() error { - return s.runtimeRestoreReadyGate.WalkThrough() -} - -// Ready called by agent when initialized -func (s *initFlowSynchronizationImpl) AgentReady() error { - return s.agentReadyGate.WalkThrough() -} - -// ExternalAgentRegistered called by agent as part of /register request -func (s *initFlowSynchronizationImpl) ExternalAgentRegistered() error { - return s.externalAgentsRegisteredGate.WalkThrough() -} - -// Cancel cancels gates with error. -func (s *initFlowSynchronizationImpl) CancelWithError(err error) { - s.externalAgentsRegisteredGate.CancelWithError(err) - s.runtimeReadyGate.CancelWithError(err) - s.agentReadyGate.CancelWithError(err) - s.runtimeRestoreReadyGate.CancelWithError(err) -} - -// Clear gates state -func (s *initFlowSynchronizationImpl) Clear() { - s.externalAgentsRegisteredGate.Clear() - s.runtimeReadyGate.Clear() - s.agentReadyGate.Clear() - s.runtimeRestoreReadyGate.Clear() -} - -// NewInitFlowSynchronization returns new InitFlowSynchronization instance. -func NewInitFlowSynchronization() InitFlowSynchronization { - initFlow := &initFlowSynchronizationImpl{ - runtimeReadyGate: NewGate(1), - externalAgentsRegisteredGate: NewGate(0), - agentReadyGate: NewGate(maxAgentsLimit), - runtimeRestoreReadyGate: NewGate(1), - } - return initFlow -} - -// InvokeFlowSynchronization wraps invoke flow barriers. -type InvokeFlowSynchronization interface { - InitializeBarriers() error - AwaitRuntimeResponse() error - AwaitRuntimeReady() error - RuntimeResponse(runtime *Runtime) error - RuntimeReady(runtime *Runtime) error - SetAgentsReadyCount(agentCount uint16) error - AgentReady() error - AwaitAgentsReady() error - CancelWithError(error) - Clear() -} - -type invokeFlowSynchronizationImpl struct { - runtimeReadyGate Gate - runtimeResponseGate Gate - agentReadyGate Gate -} - -// InitializeBarriers ... -func (s *invokeFlowSynchronizationImpl) InitializeBarriers() error { - s.runtimeReadyGate.Reset() - s.runtimeResponseGate.Reset() - s.agentReadyGate.Reset() - return nil -} - -// Clear gates state -func (s *invokeFlowSynchronizationImpl) Clear() { - s.runtimeReadyGate.Clear() - s.runtimeResponseGate.Clear() - s.agentReadyGate.Clear() -} - -// AwaitRuntimeResponse awaits runtime to send response to the platform. -func (s *invokeFlowSynchronizationImpl) AwaitRuntimeResponse() error { - return s.runtimeResponseGate.AwaitGateCondition() -} - -// AwaitRuntimeReady awaits runtime ready state. -func (s *invokeFlowSynchronizationImpl) AwaitRuntimeReady() error { - return s.runtimeReadyGate.AwaitGateCondition() -} - -// RuntimeResponse called by runtime when runtime response is made available to the platform. -func (s *invokeFlowSynchronizationImpl) RuntimeResponse(a *Runtime) error { - return s.runtimeResponseGate.WalkThrough() -} - -// RuntimeReady called by runtime when runtime ready. -func (s *invokeFlowSynchronizationImpl) RuntimeReady(a *Runtime) error { - return s.runtimeReadyGate.WalkThrough() -} - -// Cancel cancels gates. -func (s *invokeFlowSynchronizationImpl) CancelWithError(err error) { - s.runtimeResponseGate.CancelWithError(err) - s.runtimeReadyGate.CancelWithError(err) - s.agentReadyGate.CancelWithError(err) -} - -// SetAgentsReadyCount sets the number of agents we expect at the gate once we know it -func (s *invokeFlowSynchronizationImpl) SetAgentsReadyCount(agentCount uint16) error { - return s.agentReadyGate.SetCount(agentCount) -} - -// Ready called by agent when initialized -func (s *invokeFlowSynchronizationImpl) AgentReady() error { - return s.agentReadyGate.WalkThrough() -} - -// AwaitAgentReady awaits for registered extensions to report ready -func (s *invokeFlowSynchronizationImpl) AwaitAgentsReady() error { - return s.agentReadyGate.AwaitGateCondition() -} - -// NewInvokeFlowSynchronization returns new InvokeFlowSynchronization instance. -func NewInvokeFlowSynchronization() InvokeFlowSynchronization { - return &invokeFlowSynchronizationImpl{ - runtimeReadyGate: NewGate(1), - runtimeResponseGate: NewGate(1), - agentReadyGate: NewGate(maxAgentsLimit), - } -} diff --git a/lambda/core/gates.go b/lambda/core/gates.go deleted file mode 100644 index 3b5a9f1..0000000 --- a/lambda/core/gates.go +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "errors" - "math" - "sync" -) - -const maxAgentsLimit uint16 = math.MaxUint16 - -// Gate ... -type Gate interface { - Register(count uint16) - Reset() - SetCount(uint16) error - WalkThrough() error - AwaitGateCondition() error - CancelWithError(error) - Clear() -} - -type gateImpl struct { - count uint16 - arrived uint16 - gateCondition *sync.Cond - canceled bool - err error -} - -func (g *gateImpl) Register(count uint16) { - g.gateCondition.L.Lock() - defer g.gateCondition.L.Unlock() - g.count += count -} - -// SetCount sets the expected number of arrivals on the gate -func (g *gateImpl) SetCount(count uint16) error { - g.gateCondition.L.Lock() - defer g.gateCondition.L.Unlock() - // you can't set count larger than limit if limit is max uint but leaving it here for correctness in case limit changes - if count > maxAgentsLimit || count < g.arrived { - return ErrGateIntegrity - } - g.count = count - return nil -} - -func (g *gateImpl) Reset() { - g.gateCondition.L.Lock() - defer g.gateCondition.L.Unlock() - if !g.canceled { - g.arrived = 0 - } -} - -// ErrGateIntegrity ... -var ErrGateIntegrity = errors.New("ErrGateIntegrity") - -// ErrGateCanceled ... -var ErrGateCanceled = errors.New("ErrGateCanceled") - -// WalkThrough walks through this gate without awaiting others. -func (g *gateImpl) WalkThrough() error { - g.gateCondition.L.Lock() - defer g.gateCondition.L.Unlock() - - if g.arrived == g.count { - return ErrGateIntegrity - } - - g.arrived++ - - if g.arrived == g.count { - g.gateCondition.Broadcast() - } - - return nil -} - -// AwaitGateCondition suspends thread execution until gate condition -// is met or await is canceled via Cancel method. -func (g *gateImpl) AwaitGateCondition() error { - g.gateCondition.L.Lock() - defer g.gateCondition.L.Unlock() - - for g.arrived != g.count && !g.canceled { - g.gateCondition.Wait() - } - - if g.canceled { - if g.err != nil { - return g.err - } - return ErrGateCanceled - } - - return nil -} - -// CancelWithError cancels gate condition with error and awakes suspended threads. -func (g *gateImpl) CancelWithError(err error) { - g.gateCondition.L.Lock() - defer g.gateCondition.L.Unlock() - g.canceled = true - g.err = err - g.gateCondition.Broadcast() -} - -// Clear gate state -func (g *gateImpl) Clear() { - g.gateCondition.L.Lock() - defer g.gateCondition.L.Unlock() - - g.canceled = false - g.arrived = 0 - g.err = nil -} - -// NewGate returns new gate instance. -func NewGate(count uint16) Gate { - return &gateImpl{ - count: count, - gateCondition: sync.NewCond(&sync.Mutex{}), - } -} diff --git a/lambda/core/gates_test.go b/lambda/core/gates_test.go deleted file mode 100644 index 156e085..0000000 --- a/lambda/core/gates_test.go +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "errors" - "github.com/stretchr/testify/assert" - "golang.org/x/sync/errgroup" - "testing" -) - -func TestWalkThrough(t *testing.T) { - g := NewGate(1) - assert.NoError(t, g.WalkThrough()) -} - -func TestWalkThroughTwice(t *testing.T) { - g := NewGate(1) - assert.NoError(t, g.WalkThrough()) - assert.Equal(t, ErrGateIntegrity, g.WalkThrough()) -} - -func TestSetCount(t *testing.T) { - g := NewGate(2) - assert.NoError(t, g.WalkThrough()) - assert.NoError(t, g.WalkThrough()) // arrived is now 2 - assert.Equal(t, ErrGateIntegrity, g.SetCount(1)) - assert.NoError(t, g.SetCount(2)) // set to 2 - assert.Equal(t, ErrGateIntegrity, g.WalkThrough()) // can't go to 3 - assert.NoError(t, g.SetCount(3)) // set to 3 - assert.NoError(t, g.WalkThrough()) -} - -func TestReset(t *testing.T) { - g := NewGate(1) - assert.NoError(t, g.WalkThrough()) - g.Reset() - assert.NoError(t, g.WalkThrough()) -} - -func TestCancel(t *testing.T) { - g := NewGate(1) - - var errg errgroup.Group - errg.Go(g.AwaitGateCondition) - g.CancelWithError(nil) - - assert.Equal(t, ErrGateCanceled, errg.Wait()) -} - -func TestCancelWithError(t *testing.T) { - g := NewGate(1) - - var errg errgroup.Group - errg.Go(g.AwaitGateCondition) - - err := errors.New("MyErr") - g.CancelWithError(err) - - assert.Equal(t, err, errg.Wait()) -} - -func TestUseAfterCancel(t *testing.T) { - g := NewGate(1) - err := errors.New("MyErr") - g.CancelWithError(err) - assert.Equal(t, err, g.AwaitGateCondition()) - g.Reset() - assert.Equal(t, err, g.AwaitGateCondition()) -} - -func BenchmarkAwaitGateCondition(b *testing.B) { - g := NewGate(1) - - for n := 0; n < b.N; n++ { - go func() { g.WalkThrough() }() - if err := g.AwaitGateCondition(); err != nil { - panic(err) - } - g.Reset() - } -} - -// go test -run=XXX -bench=. -benchtime 10000000x -cpu 1 -blockprofile /tmp/pprof/block3.out src/go.amzn.com/lambda/core/* -// goos: linux -// goarch: amd64 -// BenchmarkAwaitGateCondition 10000000 1834 ns/op -// PASS -// ok command-line-arguments 18.449s diff --git a/lambda/core/internalagent.go b/lambda/core/internalagent.go deleted file mode 100644 index e3a83db..0000000 --- a/lambda/core/internalagent.go +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "fmt" - "time" - - "go.amzn.com/lambda/core/statejson" - - "github.com/google/uuid" -) - -// InternalAgent represents internal agent -type InternalAgent struct { - Name string - ID uuid.UUID - events map[Event]struct{} - - ManagedThread Suspendable - - currentState InternalAgentState - stateLastModified time.Time - - StartedState InternalAgentState - RegisteredState InternalAgentState - RunningState InternalAgentState - ReadyState InternalAgentState - InitErrorState InternalAgentState - ExitErrorState InternalAgentState - - errorType string -} - -// NewInternalAgent returns new instance of a named agent -func NewInternalAgent(name string, initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchronization) *InternalAgent { - agent := &InternalAgent{ - Name: name, - ID: uuid.New(), - ManagedThread: NewManagedThread(), - events: make(map[Event]struct{}), - } - - agent.StartedState = &InternalAgentStartedState{agent: agent} - agent.RegisteredState = &InternalAgentRegisteredState{agent: agent, initFlow: initFlow} - agent.RunningState = &InternalAgentRunningState{agent: agent, invokeFlow: invokeFlow} - agent.ReadyState = &InternalAgentReadyState{agent: agent} - agent.InitErrorState = &InternalAgentInitErrorState{} - agent.ExitErrorState = &InternalAgentExitErrorState{} - - agent.setStateUnsafe(agent.StartedState) - - return agent -} - -func (s *InternalAgent) String() string { - return fmt.Sprintf("%s (%s)", s.Name, s.ID) -} - -// SuspendUnsafe the current running thread -func (s *InternalAgent) SuspendUnsafe() { - s.ManagedThread.SuspendUnsafe() -} - -// Release will resume a suspended thread -func (s *InternalAgent) Release() { - s.ManagedThread.Release() -} - -// SetState using the lock -func (s *InternalAgent) SetState(state InternalAgentState) { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - s.setStateUnsafe(state) -} - -func (s *InternalAgent) setStateUnsafe(state InternalAgentState) { - s.currentState = state - s.stateLastModified = time.Now() -} - -func ValidateInternalAgentEvent(e Event) error { - switch e { - case InvokeEvent: - return nil - case ShutdownEvent: - return errEventNotSupportedForInternalAgent - } - return errInvalidEventType -} - -func (s *InternalAgent) subscribeUnsafe(e Event) error { - if err := ValidateInternalAgentEvent(e); err != nil { - return err - } - s.events[e] = struct{}{} - return nil -} - -// IsSubscribed checks whether agent is subscribed the Event -func (s *InternalAgent) IsSubscribed(e Event) bool { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - _, found := s.events[e] - return found -} - -// SubscribedEvents returns events to which the agent is subscribed -func (s *InternalAgent) SubscribedEvents() []string { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - - events := []string{} - for event := range s.events { - events = append(events, string(event)) - } - return events -} - -// GetState returns agent's current state -func (s *InternalAgent) GetState() InternalAgentState { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState -} - -// Register an agent with the platform -func (s *InternalAgent) Register(events []Event) error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.Register(events) -} - -// Ready - mark an agent as ready -func (s *InternalAgent) Ready() error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.Ready() -} - -// InitError - agent registered but failed to initialize -func (s *InternalAgent) InitError(errorType string) error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.InitError(errorType) -} - -// ExitError - agent registered but failed to initialize -func (s *InternalAgent) ExitError(errorType string) error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.ExitError(errorType) -} - -// ErrorType returns error type reported during init or exit -func (s *InternalAgent) ErrorType() string { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.errorType -} - -// GetAgentDescription returns agent description object for debugging purposes -func (s *InternalAgent) GetAgentDescription() statejson.ExtensionDescription { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return statejson.ExtensionDescription{ - Name: s.Name, - ID: s.ID.String(), - State: statejson.StateDescription{ - Name: s.currentState.Name(), - LastModified: s.stateLastModified.UnixNano() / int64(time.Millisecond), - }, - ErrorType: s.errorType, - } -} diff --git a/lambda/core/internalagent_states.go b/lambda/core/internalagent_states.go deleted file mode 100644 index da7fb80..0000000 --- a/lambda/core/internalagent_states.go +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -// InternalAgentState is internal agent state interface -type InternalAgentState interface { - Register([]Event) error - Ready() error - InitError(errorType string) error - ExitError(errorType string) error - Name() string -} - -// InternalAgentStartedState is the initial state of an internal agent -type InternalAgentStartedState struct { - disallowEverything - agent *InternalAgent -} - -// Register an agent with the platform when agent is in started state -func (s *InternalAgentStartedState) Register(events []Event) error { - for _, e := range events { - if err := s.agent.subscribeUnsafe(e); err != nil { - return err - } - } - s.agent.setStateUnsafe(s.agent.RegisteredState) - return nil -} - -// Name return state's human friendly name -func (s *InternalAgentStartedState) Name() string { - return AgentStartedStateName -} - -// InternalAgentRegisteredState is the state of an agent that registered with the platform but has not reported ready (next) -type InternalAgentRegisteredState struct { - disallowEverything - agent *InternalAgent - initFlow InitFlowSynchronization -} - -// Ready - agent has called next and is now successfully initialized -func (s *InternalAgentRegisteredState) Ready() error { - s.agent.setStateUnsafe(s.agent.ReadyState) - s.initFlow.AgentReady() - s.agent.ManagedThread.SuspendUnsafe() - - if s.agent.currentState != s.agent.ReadyState { - return ErrConcurrentStateModification - } - s.agent.setStateUnsafe(s.agent.RunningState) - - return nil -} - -// InitError - agent can transitions to InitErrorState if it failed to initialize -func (s *InternalAgentRegisteredState) InitError(errorType string) error { - s.agent.setStateUnsafe(s.agent.InitErrorState) - s.agent.errorType = errorType - return nil -} - -// ExitError - agent called /exit/error -func (s *InternalAgentRegisteredState) ExitError(errorType string) error { - s.agent.setStateUnsafe(s.agent.ExitErrorState) - s.agent.errorType = errorType - return nil -} - -// Name return state's human friendly name -func (s *InternalAgentRegisteredState) Name() string { - return AgentRegisteredStateName -} - -// InternalAgentReadyState is the state of an agent that reported ready to the platform -type InternalAgentReadyState struct { - disallowEverything - agent *InternalAgent -} - -// ExitError - agent called /exit/error -func (s *InternalAgentReadyState) ExitError(errorType string) error { - s.agent.setStateUnsafe(s.agent.ExitErrorState) - s.agent.errorType = errorType - return nil -} - -// Name return state's human friendly name -func (s *InternalAgentReadyState) Name() string { - return AgentReadyStateName -} - -// InternalAgentRunningState is the state of an agent that is currently processing an event -type InternalAgentRunningState struct { - disallowEverything - agent *InternalAgent - invokeFlow InvokeFlowSynchronization -} - -// Ready - agent can transition from ready to ready -func (s *InternalAgentRunningState) Ready() error { - s.agent.setStateUnsafe(s.agent.ReadyState) - s.invokeFlow.AgentReady() - s.agent.ManagedThread.SuspendUnsafe() - - if s.agent.currentState != s.agent.ReadyState { - return ErrConcurrentStateModification - } - s.agent.setStateUnsafe(s.agent.RunningState) - - return nil -} - -// ExitError - agent called /exit/error -func (s *InternalAgentRunningState) ExitError(errorType string) error { - s.agent.setStateUnsafe(s.agent.ExitErrorState) - s.agent.errorType = errorType - return nil -} - -// Name return state's human friendly name -func (s *InternalAgentRunningState) Name() string { - return AgentRunningStateName -} - -// InternalAgentInitErrorState is a terminal state where agent has reported /init/error -type InternalAgentInitErrorState struct { - disallowEverything -} - -// Name return state's human friendly name -func (s *InternalAgentInitErrorState) Name() string { - return AgentInitErrorStateName -} - -// InitError - multiple calls are allowed, but only the first submitted error is accepted -func (s *InternalAgentInitErrorState) InitError(errorType string) error { - // no-op - return nil -} - -// InternalAgentExitErrorState is a terminal state where agent has reported /exit/error -type InternalAgentExitErrorState struct { - disallowEverything -} - -// Name return state's human friendly name -func (s *InternalAgentExitErrorState) Name() string { - return AgentExitErrorStateName -} - -// ExitError - multiple calls are allowed, but only the first submitted error is accepted -func (s *InternalAgentExitErrorState) ExitError(errorType string) error { - // no-op - return nil -} diff --git a/lambda/core/internalagent_states_test.go b/lambda/core/internalagent_states_test.go deleted file mode 100644 index 9cf1f0e..0000000 --- a/lambda/core/internalagent_states_test.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "github.com/stretchr/testify/require" - "go.amzn.com/lambda/testdata/mockthread" - "testing" -) - -func TestInternalAgentStateUnknownEventType(t *testing.T) { - agent := NewInternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - require.Equal(t, agent.StartedState, agent.GetState()) - require.Equal(t, errInvalidEventType, agent.Register([]Event{"foo"})) - require.Equal(t, agent.StartedState, agent.GetState()) -} - -func TestInternalAgentStateInvalidEventType(t *testing.T) { - agent := NewInternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - require.Equal(t, agent.StartedState, agent.GetState()) - require.Equal(t, errEventNotSupportedForInternalAgent, agent.Register([]Event{ShutdownEvent})) - require.Equal(t, agent.StartedState, agent.GetState()) -} - -func TestInternalAgentStateTransitionsFromStartedState(t *testing.T) { - agent := NewInternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - // Initial agent state is Start - require.Equal(t, agent.StartedState, agent.GetState()) - require.NoError(t, agent.Register([]Event{})) - require.Equal(t, agent.RegisteredState, agent.GetState()) - - agent.SetState(agent.StartedState) - require.Equal(t, ErrNotAllowed, agent.Ready()) - require.Equal(t, agent.StartedState, agent.GetState()) - - require.Equal(t, ErrNotAllowed, agent.InitError("Extension.TestError")) - require.Equal(t, agent.StartedState, agent.GetState()) - - require.Equal(t, ErrNotAllowed, agent.ExitError("Extension.TestError")) - require.Equal(t, agent.StartedState, agent.GetState()) -} - -func TestInternalAgentStateTransitionsFromRegisteredState(t *testing.T) { - agent := NewInternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - agent.SetState(agent.RegisteredState) - - require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) - require.Equal(t, agent.RegisteredState, agent.GetState()) - - require.NoError(t, agent.Ready()) - require.Equal(t, agent.RunningState, agent.GetState()) - - agent.SetState(agent.RegisteredState) - require.NoError(t, agent.InitError("Extension.TestError")) - require.Equal(t, agent.InitErrorState, agent.GetState()) - require.Equal(t, "Extension.TestError", agent.errorType) - - agent.SetState(agent.RegisteredState) - require.NoError(t, agent.ExitError("Extension.TestError")) - require.Equal(t, agent.ExitErrorState, agent.GetState()) - require.Equal(t, "Extension.TestError", agent.errorType) -} - -func TestInternalAgentStateTransitionsFromReadyState(t *testing.T) { - agent := NewInternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - agent.SetState(agent.ReadyState) - - require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) - require.Equal(t, agent.ReadyState, agent.GetState()) - - require.Equal(t, ErrNotAllowed, agent.InitError("Extension.TestError")) - require.Equal(t, agent.ReadyState, agent.GetState()) - - agent.SetState(agent.ReadyState) - require.NoError(t, agent.ExitError("Extension.TestError")) - require.Equal(t, agent.ExitErrorState, agent.GetState()) - require.Equal(t, "Extension.TestError", agent.errorType) - - agent.SetState(agent.ReadyState) - require.Equal(t, ErrNotAllowed, agent.Ready()) - require.Equal(t, agent.ReadyState, agent.GetState()) -} - -func TestInternalAgentStateTransitionsFromInitErrorState(t *testing.T) { - agent := NewInternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - agent.SetState(agent.InitErrorState) - - require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) - require.Equal(t, agent.InitErrorState, agent.GetState()) - require.Equal(t, nil, agent.InitError("Extension.TestError")) // InitError -> InitError reentrancy - require.Equal(t, agent.InitErrorState, agent.GetState()) - require.Equal(t, ErrNotAllowed, agent.ExitError("Extension.TestError")) - require.Equal(t, agent.InitErrorState, agent.GetState()) - require.Equal(t, ErrNotAllowed, agent.Ready()) - require.Equal(t, agent.InitErrorState, agent.GetState()) -} - -func TestInternalAgentStateTransitionsFromExitErrorState(t *testing.T) { - agent := NewInternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - agent.SetState(agent.ExitErrorState) - - require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) - require.Equal(t, agent.ExitErrorState, agent.GetState()) - require.Equal(t, nil, agent.ExitError("Extension.TestError")) // ExitError -> ExitError reentrancy - require.Equal(t, agent.ExitErrorState, agent.GetState()) - require.Equal(t, ErrNotAllowed, agent.InitError("Extension.TestError")) - require.Equal(t, agent.ExitErrorState, agent.GetState()) - require.Equal(t, ErrNotAllowed, agent.Ready()) - require.Equal(t, agent.ExitErrorState, agent.GetState()) -} - -func TestInternalAgentStateTransitionsFromRunningState(t *testing.T) { - agent := NewInternalAgent("name", &mockInitFlowSynchronization{}, &mockInvokeFlowSynchronization{}) - agent.ManagedThread = &mockthread.MockManagedThread{} - agent.SetState(agent.RunningState) - require.Equal(t, agent.RunningState, agent.GetState()) - - require.Equal(t, ErrNotAllowed, agent.Register([]Event{})) - require.Equal(t, agent.RunningState, agent.GetState()) - - agent.SetState(agent.RunningState) - require.NoError(t, agent.Ready()) - require.Equal(t, agent.RunningState, agent.GetState()) -} diff --git a/lambda/core/registrations.go b/lambda/core/registrations.go deleted file mode 100644 index 26f6f2f..0000000 --- a/lambda/core/registrations.go +++ /dev/null @@ -1,409 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "errors" - "os" - "sync" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/core/statejson" - "go.amzn.com/lambda/interop" - - "github.com/google/uuid" - - log "github.com/sirupsen/logrus" -) - -type registrationServiceState int - -const ( - registrationServiceOn registrationServiceState = iota - registrationServiceOff -) - -const MaxAgentsAllowed = 10 - -// Event represents a platform event which agent can subscribe to -type Event string - -const ( - // InvokeEvent is dispatched when INVOKE happens - InvokeEvent Event = "INVOKE" - // ShutdownEvent is dispatched when SHUTDOWN or RESET happen - ShutdownEvent Event = "SHUTDOWN" -) - -// ErrRegistrationServiceOff returned on attempt to register after registration service has been turned off. -var ErrRegistrationServiceOff = errors.New("ErrRegistrationServiceOff") - -// ErrTooManyExtensions means MaxAgentsAllowed limit is exceeded -var ErrTooManyExtensions = errors.New("ErrTooManyExtensions") - -type AgentInfoErrorType string - -const ( - PermissionDenied AgentInfoErrorType = "PermissionDenied" - TooManyExtensions AgentInfoErrorType = "TooManyExtensions" - UnknownError AgentInfoErrorType = "UnknownError" -) - -func MapErrorToAgentInfoErrorType(err error) AgentInfoErrorType { - if os.IsPermission(err) { - return PermissionDenied - } - if err == ErrTooManyExtensions { - return TooManyExtensions - } - return UnknownError -} - -// AgentInfo holds information about an agent renderable in customer logs -type AgentInfo struct { - Name string - State string - Subscriptions []string - ErrorType string -} - -// FunctionMetadata holds static information regarding the function (Name, Version, Handler) -type FunctionMetadata struct { - AccountID string - FunctionName string - FunctionVersion string - InstanceMaxMemory uint64 - Handler string - RuntimeInfo interop.RuntimeInfo -} - -// RegistrationService keeps track of registered parties, including external agents, threads, and runtime. -type RegistrationService interface { - CreateExternalAgent(agentName string) (*ExternalAgent, error) - CreateInternalAgent(agentName string) (*InternalAgent, error) - PreregisterRuntime(r *Runtime) error - SetFunctionMetadata(metadata FunctionMetadata) - GetFunctionMetadata() FunctionMetadata - GetRuntime() *Runtime - GetRegisteredAgentsSize() uint16 - FindExternalAgentByName(agentName string) (*ExternalAgent, bool) - FindInternalAgentByName(agentName string) (*InternalAgent, bool) - FindExternalAgentByID(agentID uuid.UUID) (*ExternalAgent, bool) - FindInternalAgentByID(agentID uuid.UUID) (*InternalAgent, bool) - TurnOff() - InitFlow() InitFlowSynchronization - GetInternalStateDescriptor(appCtx appctx.ApplicationContext) func() statejson.InternalStateDescription - GetExternalAgents() []*ExternalAgent - GetSubscribedExternalAgents(eventType Event) []*ExternalAgent - GetSubscribedInternalAgents(eventType Event) []*InternalAgent - CountAgents() int - Clear() - AgentsInfo() []AgentInfo - CancelFlows(err error) -} - -type registrationServiceImpl struct { - runtime *Runtime - internalAgents InternalAgentsMap - externalAgents ExternalAgentsMap - state registrationServiceState - mutex *sync.Mutex - initFlow InitFlowSynchronization - invokeFlow InvokeFlowSynchronization - functionMetadata FunctionMetadata - cancelOnce sync.Once -} - -func (s *registrationServiceImpl) Clear() { - s.mutex.Lock() - defer s.mutex.Unlock() - - s.runtime = nil - s.internalAgents.Clear() - s.externalAgents.Clear() - s.state = registrationServiceOn - s.cancelOnce = sync.Once{} -} - -func (s *registrationServiceImpl) InitFlow() InitFlowSynchronization { - return s.initFlow -} - -// GetInternalStateDescriptor returns function that returns internal state description for debugging purposes -func (s *registrationServiceImpl) GetInternalStateDescriptor(appCtx appctx.ApplicationContext) func() statejson.InternalStateDescription { - return func() statejson.InternalStateDescription { - return s.getInternalStateDescription(appCtx) - } -} - -func (s *registrationServiceImpl) getInternalStateDescription(appCtx appctx.ApplicationContext) statejson.InternalStateDescription { - isd := statejson.InternalStateDescription{ - Extensions: []statejson.ExtensionDescription{}, - } - - if s.runtime != nil { - // we use pointer here so that 'runtime' json field is nil if runtime is not set (as opposed to filled with default values) - rtdesc := s.runtime.GetRuntimeDescription() - isd.Runtime = &rtdesc - } - - s.mutex.Lock() - defer s.mutex.Unlock() - - s.internalAgents.Visit(func(agent *InternalAgent) { - isd.Extensions = append(isd.Extensions, agent.GetAgentDescription()) - }) - - s.externalAgents.Visit(func(agent *ExternalAgent) { - isd.Extensions = append(isd.Extensions, agent.GetAgentDescription()) - }) - - if fatalerror, found := appctx.LoadFirstFatalError(appCtx); found { - isd.FirstFatalError = string(fatalerror) - } - - return isd -} - -func (s *registrationServiceImpl) CountAgents() int { - s.mutex.Lock() - defer s.mutex.Unlock() - - return s.countAgentsUnsafe() -} - -func (s *registrationServiceImpl) countAgentsUnsafe() int { - res := 0 - s.externalAgents.Visit(func(a *ExternalAgent) { - res++ - }) - s.internalAgents.Visit(func(a *InternalAgent) { - res++ - }) - return res -} - -func (s *registrationServiceImpl) GetExternalAgents() []*ExternalAgent { - agents := []*ExternalAgent{} - s.externalAgents.Visit(func(a *ExternalAgent) { - agents = append(agents, a) - }) - return agents -} - -func (s *registrationServiceImpl) GetInternalAgents() []*InternalAgent { - agents := []*InternalAgent{} - s.internalAgents.Visit(func(a *InternalAgent) { - agents = append(agents, a) - }) - return agents -} - -func (s *registrationServiceImpl) GetSubscribedExternalAgents(eventType Event) []*ExternalAgent { - agents := []*ExternalAgent{} - s.externalAgents.Visit(func(a *ExternalAgent) { - if a.IsSubscribed(eventType) { - agents = append(agents, a) - } - }) - return agents -} - -func (s *registrationServiceImpl) GetSubscribedInternalAgents(eventType Event) []*InternalAgent { - agents := []*InternalAgent{} - s.internalAgents.Visit(func(a *InternalAgent) { - if a.IsSubscribed(eventType) { - agents = append(agents, a) - } - }) - return agents -} - -// CreateExternalAgent creates agent in agent collection -func (s *registrationServiceImpl) CreateExternalAgent(agentName string) (*ExternalAgent, error) { - agent := NewExternalAgent(agentName, s.initFlow, s.invokeFlow) - - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state != registrationServiceOn { - return nil, ErrRegistrationServiceOff - } - - if _, found := s.internalAgents.FindByName(agentName); found { - return nil, ErrAgentNameCollision - } - - if err := s.externalAgents.Insert(agent); err != nil { - return nil, err - } - - return agent, nil -} - -// CreateInternalAgent creates agent in agent collection -func (s *registrationServiceImpl) CreateInternalAgent(agentName string) (*InternalAgent, error) { - agent := NewInternalAgent(agentName, s.initFlow, s.invokeFlow) - - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state != registrationServiceOn { - return nil, ErrRegistrationServiceOff - } - - if s.countAgentsUnsafe() >= MaxAgentsAllowed { - return nil, ErrTooManyExtensions - } - - if _, found := s.externalAgents.FindByName(agentName); found { - return nil, ErrAgentNameCollision - } - - if err := s.internalAgents.Insert(agent); err != nil { - return nil, err - } - - return agent, nil -} - -// PreregisterRuntime allows to preregister a runtime. -func (s *registrationServiceImpl) PreregisterRuntime(r *Runtime) error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.state != registrationServiceOn { - return ErrRegistrationServiceOff - } - - s.runtime = r - - // RUNTIME IS NOT PART OF SUBSCRIPTIONS - - return nil -} - -// SetFunctionMetadata sets the static function metadata object -func (s *registrationServiceImpl) SetFunctionMetadata(metadata FunctionMetadata) { - s.functionMetadata = metadata -} - -// GetFunctionMetadata returns the static function metadata object -func (s *registrationServiceImpl) GetFunctionMetadata() FunctionMetadata { - return s.functionMetadata -} - -// GetRuntime retrieves runtime. -func (s *registrationServiceImpl) GetRuntime() *Runtime { - s.mutex.Lock() - defer s.mutex.Unlock() - return s.runtime -} - -// GetRegisteredAgentsSize retrieves the number of agents registered with the platform. -func (s *registrationServiceImpl) GetRegisteredAgentsSize() uint16 { - s.mutex.Lock() - defer s.mutex.Unlock() - return uint16(s.externalAgents.Size()) + uint16(s.internalAgents.Size()) -} - -// FindExternalAgentByName -func (s *registrationServiceImpl) FindExternalAgentByName(name string) (agent *ExternalAgent, found bool) { - s.mutex.Lock() - defer s.mutex.Unlock() - if agent, found = s.externalAgents.FindByName(name); found { - return - } - return -} - -// FindInternalAgentByName -func (s *registrationServiceImpl) FindInternalAgentByName(name string) (agent *InternalAgent, found bool) { - s.mutex.Lock() - defer s.mutex.Unlock() - if agent, found = s.internalAgents.FindByName(name); found { - return - } - return -} - -// FindExternalAgentByID -func (s *registrationServiceImpl) FindExternalAgentByID(agentID uuid.UUID) (agent *ExternalAgent, found bool) { - s.mutex.Lock() - defer s.mutex.Unlock() - if agent, found = s.externalAgents.FindByID(agentID); found { - return - } - return -} - -// FindInternalAgentByID -func (s *registrationServiceImpl) FindInternalAgentByID(agentID uuid.UUID) (agent *InternalAgent, found bool) { - s.mutex.Lock() - defer s.mutex.Unlock() - if agent, found = s.internalAgents.FindByID(agentID); found { - return - } - return -} - -// ReportAgentsInfo returns information about all agents -func (s *registrationServiceImpl) AgentsInfo() []AgentInfo { - s.mutex.Lock() - defer s.mutex.Unlock() - - agentsInfo := []AgentInfo{} - for _, agent := range s.GetExternalAgents() { - agentsInfo = append(agentsInfo, AgentInfo{ - agent.Name, - agent.GetState().Name(), - agent.SubscribedEvents(), - agent.ErrorType(), - }) - } - - for _, agent := range s.GetInternalAgents() { - agentsInfo = append(agentsInfo, AgentInfo{ - agent.Name, - agent.GetState().Name(), - agent.SubscribedEvents(), - agent.ErrorType(), - }) - } - - return agentsInfo -} - -// TurnOff turns off registration service. -func (s *registrationServiceImpl) TurnOff() { - s.mutex.Lock() - defer s.mutex.Unlock() - s.state = registrationServiceOff -} - -// CancelFlows cancels init and invoke flows with error. -func (s *registrationServiceImpl) CancelFlows(err error) { - s.mutex.Lock() - defer s.mutex.Unlock() - // The following block protects us from overwriting the error - // which was first used to cancel flows. - s.cancelOnce.Do(func() { - log.Debugf("Canceling flows: %s", err) - s.initFlow.CancelWithError(err) - s.invokeFlow.CancelWithError(err) - }) -} - -// NewRegistrationService returns new RegistrationService instance. -func NewRegistrationService(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchronization) RegistrationService { - return ®istrationServiceImpl{ - mutex: &sync.Mutex{}, - state: registrationServiceOn, - internalAgents: NewInternalAgentsMap(), - externalAgents: NewExternalAgentsMap(), - initFlow: initFlow, - invokeFlow: invokeFlow, - cancelOnce: sync.Once{}, - } -} diff --git a/lambda/core/registrations_test.go b/lambda/core/registrations_test.go deleted file mode 100644 index 5956ac3..0000000 --- a/lambda/core/registrations_test.go +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestRegistrationServiceHappyPathDuringInit(t *testing.T) { - // Setup - initFlow, invokeFlow := NewInitFlowSynchronization(), NewInvokeFlowSynchronization() - registrationService := NewRegistrationService(initFlow, invokeFlow) - - registrationService.SetFunctionMetadata(FunctionMetadata{ - FunctionName: "AWS_LAMBDA_FUNCTION_NAME", - FunctionVersion: "AWS_LAMBDA_FUNCTION_VERSION", - Handler: "_HANDLER", - }) - - // Extension INIT (external) - extAgentNames := []string{"agentName1", "agentName2"} - assert.NoError(t, initFlow.SetExternalAgentsRegisterCount(uint16(len(extAgentNames)))) - - extAgent1, err := registrationService.CreateExternalAgent(extAgentNames[0]) - assert.NoError(t, err) - - extAgent2, err := registrationService.CreateExternalAgent(extAgentNames[1]) - assert.NoError(t, err) - - go func() { - for _, agentName := range extAgentNames { - agent, found := registrationService.FindExternalAgentByName(agentName) - assert.True(t, found) - - assert.NoError(t, agent.Register([]Event{InvokeEvent, ShutdownEvent})) - } - }() - - assert.NoError(t, initFlow.AwaitExternalAgentsRegistered()) - - // Runtime INIT (+ internal extensions) - runtime := NewRuntime(initFlow, invokeFlow) - assert.NoError(t, registrationService.PreregisterRuntime(runtime)) - - intAgentNames := []string{"intAgentName1", "intAgentName2"} - - intAgent1, err := registrationService.CreateInternalAgent(intAgentNames[0]) - assert.NoError(t, err) - - intAgent2, err := registrationService.CreateInternalAgent(intAgentNames[1]) - assert.NoError(t, err) - - go func() { - for _, agentName := range intAgentNames { - agent, found := registrationService.FindInternalAgentByName(agentName) - assert.True(t, found) - - assert.NoError(t, agent.Register([]Event{InvokeEvent})) - } - assert.NoError(t, runtime.Ready()) - }() - - assert.NoError(t, initFlow.AwaitRuntimeRestoreReady()) - registrationService.TurnOff() - - // Agents Ready - - assert.NoError(t, initFlow.SetAgentsReadyCount(registrationService.GetRegisteredAgentsSize())) - go func() { - for _, agentName := range intAgentNames { - agent, found := registrationService.FindInternalAgentByName(agentName) - assert.True(t, found) - go func() { assert.NoError(t, agent.Ready()) }() - } - - for _, agentName := range extAgentNames { - agent, found := registrationService.FindExternalAgentByName(agentName) - assert.True(t, found) - go func() { assert.NoError(t, agent.Ready()) }() - } - }() - - assert.NoError(t, initFlow.AwaitAgentsReady()) - - // Assertions - expectedAgents := []AgentInfo{ - AgentInfo{extAgent1.Name, "Ready", []string{"INVOKE", "SHUTDOWN"}, ""}, - AgentInfo{extAgent2.Name, "Ready", []string{"INVOKE", "SHUTDOWN"}, ""}, - AgentInfo{intAgent1.Name, "Ready", []string{"INVOKE"}, ""}, - AgentInfo{intAgent2.Name, "Ready", []string{"INVOKE"}, ""}, - } - - assert.Len(t, registrationService.AgentsInfo(), len(expectedAgents)) - - actualAgents := map[string]AgentInfo{} - for _, agentInfo := range registrationService.AgentsInfo() { - actualAgents[agentInfo.Name] = agentInfo - } - - for _, agentInfo := range expectedAgents { - assert.Contains(t, actualAgents, agentInfo.Name) - assert.Equal(t, actualAgents[agentInfo.Name].Name, agentInfo.Name) - assert.Equal(t, actualAgents[agentInfo.Name].State, agentInfo.State) - for _, event := range agentInfo.Subscriptions { - assert.Contains(t, actualAgents[agentInfo.Name].Subscriptions, event) - } - } -} diff --git a/lambda/core/runtime_state_names.go b/lambda/core/runtime_state_names.go deleted file mode 100644 index 4a2184d..0000000 --- a/lambda/core/runtime_state_names.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -// String values of possibles runtime states -const ( - RuntimeStartedStateName = "Started" - RuntimeInitErrorStateName = "InitError" - RuntimeReadyStateName = "Ready" - RuntimeRunningStateName = "Running" - // RuntimeStartedState -> RuntimeRestoreReadyState - RuntimeRestoreReadyStateName = "RestoreReady" - // RuntimeRestoreReadyState -> RuntimeRestoringState - RuntimeRestoringStateName = "Restoring" - RuntimeInvocationResponseStateName = "InvocationResponse" - RuntimeInvocationErrorResponseStateName = "InvocationErrorResponse" - RuntimeResponseSentStateName = "RuntimeResponseSentState" - RuntimeRestoreErrorStateName = "RuntimeRestoreErrorState" -) diff --git a/lambda/core/statejson/description.go b/lambda/core/statejson/description.go deleted file mode 100644 index a614d20..0000000 --- a/lambda/core/statejson/description.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package statejson - -import ( - "encoding/json" - - log "github.com/sirupsen/logrus" -) - -// ResponseMode are top-level constants used in combination with the various types of -// modes we have for responses, such as invoke's response mode and function's response mode. -// In the future we might have invoke's request mode or similar, so these help set the ground -// for consistency. -type ResponseMode string - -const ResponseModeBuffered = "Buffered" -const ResponseModeStreaming = "Streaming" - -type InvokeResponseMode string - -const InvokeResponseModeBuffered InvokeResponseMode = ResponseModeBuffered -const InvokeResponseModeStreaming InvokeResponseMode = ResponseModeStreaming - -// StateDescription ... -type StateDescription struct { - Name string `json:"name"` - LastModified int64 `json:"lastModified"` - ResponseTimeNs int64 `json:"responseTimeNs"` -} - -// RuntimeDescription ... -type RuntimeDescription struct { - State StateDescription `json:"state"` -} - -// ExtensionDescription ... -type ExtensionDescription struct { - Name string `json:"name"` - ID string - State StateDescription `json:"state"` - ErrorType string `json:"errorType"` -} - -// InternalStateDescription describes internal state of runtime and extensions for debugging purposes -type InternalStateDescription struct { - Runtime *RuntimeDescription `json:"runtime"` - Extensions []ExtensionDescription `json:"extensions"` - FirstFatalError string `json:"firstFatalError"` -} - -type ResponseMetricsDimensions struct { - InvokeResponseMode InvokeResponseMode `json:"invokeResponseMode"` -} - -type ResponseMetrics struct { - RuntimeResponseLatencyMs float64 `json:"runtimeResponseLatencyMs"` - Dimensions ResponseMetricsDimensions `json:"dimensions"` -} - -type ReleaseResponse struct { - *InternalStateDescription - ResponseMetrics ResponseMetrics `json:"responseMetrics"` -} - -// ResetDescription describes fields of the response to an INVOKE API request -type ResetDescription struct { - ExtensionsResetMs int64 `json:"extensionsResetMs"` - ResponseMetrics ResponseMetrics `json:"responseMetrics"` -} - -func (s *InternalStateDescription) AsJSON() []byte { - bytes, err := json.Marshal(s) - if err != nil { - log.Panicf("Failed to marshall internal states: %s", err) - } - return bytes -} - -func (s *ResetDescription) AsJSON() []byte { - bytes, err := json.Marshal(s) - if err != nil { - log.Panicf("Failed to marshall reset description: %s", err) - } - return bytes -} - -func (s *ReleaseResponse) AsJSON() []byte { - bytes, err := json.Marshal(s) - if err != nil { - log.Panicf("Failed to marshall release response: %s", err) - } - return bytes -} diff --git a/lambda/core/states.go b/lambda/core/states.go deleted file mode 100644 index 0de88ec..0000000 --- a/lambda/core/states.go +++ /dev/null @@ -1,461 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "errors" - "sync" - "time" - - "go.amzn.com/lambda/core/statejson" - "go.amzn.com/lambda/interop" -) - -// Suspendable on operator condition. -type Suspendable interface { - SuspendUnsafe() - Release() - Lock() - Unlock() -} - -// ManagedThread is suspendable on operator condition. -type ManagedThread struct { - operatorCondition *sync.Cond - operatorConditionValue bool -} - -// SuspendUnsafe suspends ManagedThread on operator condition. This allows thread -// to be suspended and then resumed from the main thread. -// It's marked Unsafe because ManagedThread should be locked before SuspendUnsafe is called -func (s *ManagedThread) SuspendUnsafe() { - for !s.operatorConditionValue { - s.operatorCondition.Wait() - } - s.operatorConditionValue = false // reset back to false -} - -// Release releases operator condition. This allows thread -// to be suspended and then resumed from the main thread. -func (s *ManagedThread) Release() { - s.operatorCondition.L.Lock() - defer s.operatorCondition.L.Unlock() - s.operatorConditionValue = true - s.operatorCondition.Signal() -} - -// Lock ManagedThread condvar mutex -func (s *ManagedThread) Lock() { - s.operatorCondition.L.Lock() -} - -// Unlock ManagedThread condvar mutex -func (s *ManagedThread) Unlock() { - s.operatorCondition.L.Unlock() -} - -// NewManagedThread returns new ManagedThread instance. -func NewManagedThread() *ManagedThread { - return &ManagedThread{ - operatorCondition: sync.NewCond(&sync.Mutex{}), - operatorConditionValue: false, - } -} - -// ErrNotAllowed returned on illegal state transition -var ErrNotAllowed = errors.New("State transition is not allowed") - -// ErrConcurrentStateModification returned when we've detected an invalid state transision caused by concurrent modification -var ErrConcurrentStateModification = errors.New("Concurrent state modification") - -// RuntimeState is runtime state machine interface. -type RuntimeState interface { - InitError() error - Ready() error - RestoreReady() error - InvocationResponse() error - InvocationErrorResponse() error - ResponseSent() error - RestoreError(interop.FunctionError) error - Name() string -} - -type disallowEveryTransitionByDefault struct{} - -func (s *disallowEveryTransitionByDefault) InitError() error { return ErrNotAllowed } -func (s *disallowEveryTransitionByDefault) Ready() error { return ErrNotAllowed } -func (s *disallowEveryTransitionByDefault) RestoreReady() error { return ErrNotAllowed } -func (s *disallowEveryTransitionByDefault) InvocationResponse() error { return ErrNotAllowed } -func (s *disallowEveryTransitionByDefault) InvocationErrorResponse() error { return ErrNotAllowed } -func (s *disallowEveryTransitionByDefault) ResponseSent() error { return ErrNotAllowed } -func (s *disallowEveryTransitionByDefault) RestoreError(interop.FunctionError) error { - return ErrNotAllowed -} - -// Runtime is runtime object. -type Runtime struct { - ManagedThread Suspendable - - currentState RuntimeState - stateLastModified time.Time - responseTime time.Time - - RuntimeStartedState RuntimeState - RuntimeInitErrorState RuntimeState - RuntimeReadyState RuntimeState - RuntimeRunningState RuntimeState - RuntimeRestoreReadyState RuntimeState - RuntimeRestoringState RuntimeState - RuntimeInvocationResponseState RuntimeState - RuntimeInvocationErrorResponseState RuntimeState - RuntimeResponseSentState RuntimeState - RuntimeRestoreErrorState RuntimeState -} - -// Release ... -func (s *Runtime) Release() { - s.ManagedThread.Release() -} - -// SetState ... -func (s *Runtime) SetState(state RuntimeState) { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - s.setStateUnsafe(state) -} - -func (s *Runtime) setStateUnsafe(state RuntimeState) { - s.currentState = state - s.stateLastModified = time.Now() -} - -// GetState ... -func (s *Runtime) GetState() RuntimeState { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState -} - -// Ready delegates to state implementation. -func (s *Runtime) Ready() error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.Ready() -} - -func (s *Runtime) RestoreReady() error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.RestoreReady() -} - -// InvocationResponse delegates to state implementation. -func (s *Runtime) InvocationResponse() error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.InvocationResponse() -} - -// InvocationErrorResponse delegates to state implementation. -func (s *Runtime) InvocationErrorResponse() error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.InvocationErrorResponse() -} - -// InitError delegates to state implementation. -func (s *Runtime) InitError() error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.InitError() -} - -// ResponseSent delegates to state implementation. -func (s *Runtime) ResponseSent() error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - err := s.currentState.ResponseSent() - if err == nil { - s.responseTime = time.Now() - } - return err -} - -func (s *Runtime) RestoreError(UserError interop.FunctionError) error { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - return s.currentState.RestoreError(UserError) -} - -// GetRuntimeDescription returns runtime description object for debugging purposes -func (s *Runtime) GetRuntimeDescription() statejson.RuntimeDescription { - s.ManagedThread.Lock() - defer s.ManagedThread.Unlock() - res := statejson.RuntimeDescription{ - State: statejson.StateDescription{ - Name: s.currentState.Name(), - LastModified: s.stateLastModified.UnixNano() / int64(time.Millisecond), - }, - } - if !s.responseTime.IsZero() { - res.State.ResponseTimeNs = s.responseTime.UnixNano() - } - return res -} - -// NewRuntime returns new Runtime instance. -func NewRuntime(initFlow InitFlowSynchronization, invokeFlow InvokeFlowSynchronization) *Runtime { - runtime := &Runtime{ - ManagedThread: NewManagedThread(), - } - - runtime.RuntimeStartedState = &RuntimeStartedState{runtime: runtime, initFlow: initFlow} - runtime.RuntimeInitErrorState = &RuntimeInitErrorState{runtime: runtime, initFlow: initFlow} - runtime.RuntimeReadyState = &RuntimeReadyState{runtime: runtime} - runtime.RuntimeRunningState = &RuntimeRunningState{runtime: runtime, invokeFlow: invokeFlow} - runtime.RuntimeInvocationResponseState = &RuntimeInvocationResponseState{runtime: runtime, invokeFlow: invokeFlow} - runtime.RuntimeInvocationErrorResponseState = &RuntimeInvocationErrorResponseState{runtime: runtime, invokeFlow: invokeFlow} - runtime.RuntimeResponseSentState = &RuntimeResponseSentState{runtime: runtime, invokeFlow: invokeFlow} - runtime.RuntimeRestoreReadyState = &RuntimeRestoreReadyState{} - runtime.RuntimeRestoringState = &RuntimeRestoringState{runtime: runtime, initFlow: initFlow} - runtime.RuntimeRestoreErrorState = &RuntimeRestoreErrorState{runtime: runtime, initFlow: initFlow} - - runtime.setStateUnsafe(runtime.RuntimeStartedState) - return runtime -} - -// RuntimeStartedState runtime started state. -type RuntimeStartedState struct { - disallowEveryTransitionByDefault - runtime *Runtime - initFlow InitFlowSynchronization -} - -// Ready call when runtime init done. -func (s *RuntimeStartedState) Ready() error { - s.runtime.setStateUnsafe(s.runtime.RuntimeReadyState) - // runtime called /next without calling /restore/next - // that means it's not interested in restore phase - err := s.initFlow.RuntimeRestoreReady() - if err != nil { - return err - } - - err = s.initFlow.RuntimeReady() - if err != nil { - return err - } - - s.runtime.ManagedThread.SuspendUnsafe() - if s.runtime.currentState != s.runtime.RuntimeReadyState && s.runtime.currentState != s.runtime.RuntimeRunningState { - return ErrConcurrentStateModification - } - - s.runtime.setStateUnsafe(s.runtime.RuntimeRunningState) - return nil -} - -func (s *RuntimeStartedState) RestoreReady() error { - s.runtime.setStateUnsafe(s.runtime.RuntimeRestoreReadyState) - err := s.initFlow.RuntimeRestoreReady() - if err != nil { - return err - } - - s.runtime.ManagedThread.SuspendUnsafe() - if s.runtime.currentState != s.runtime.RuntimeRestoreReadyState && s.runtime.currentState != s.runtime.RuntimeRestoringState { - return ErrConcurrentStateModification - } - - s.runtime.setStateUnsafe(s.runtime.RuntimeRestoringState) - return nil -} - -// InitError move runtime to init error state. -func (s *RuntimeStartedState) InitError() error { - s.runtime.setStateUnsafe(s.runtime.RuntimeInitErrorState) - return nil -} - -// Name ... -func (s *RuntimeStartedState) Name() string { - return RuntimeStartedStateName -} - -type RuntimeRestoringState struct { - disallowEveryTransitionByDefault - runtime *Runtime - initFlow InitFlowSynchronization -} - -// Runtime is healthy after restore and called /next -func (s *RuntimeRestoringState) Ready() error { - s.runtime.setStateUnsafe(s.runtime.RuntimeReadyState) - err := s.initFlow.RuntimeReady() - if err != nil { - return err - } - s.runtime.ManagedThread.SuspendUnsafe() - if s.runtime.currentState != s.runtime.RuntimeReadyState && s.runtime.currentState != s.runtime.RuntimeRunningState { - return ErrConcurrentStateModification - } - - s.runtime.setStateUnsafe(s.runtime.RuntimeRunningState) - return nil -} - -func (s *RuntimeRestoringState) RestoreError(userError interop.FunctionError) error { - s.runtime.setStateUnsafe(s.runtime.RuntimeRestoreErrorState) - s.initFlow.CancelWithError(interop.ErrRestoreHookUserError{UserError: userError}) - return nil -} - -func (s *RuntimeRestoringState) Name() string { - return RuntimeRestoringStateName -} - -// RuntimeInitErrorState runtime started state. -type RuntimeInitErrorState struct { - disallowEveryTransitionByDefault - runtime *Runtime - initFlow InitFlowSynchronization -} - -// Name ... -func (s *RuntimeInitErrorState) Name() string { - return RuntimeInitErrorStateName -} - -// RuntimeReadyState runtime ready state. -type RuntimeReadyState struct { - disallowEveryTransitionByDefault - runtime *Runtime -} - -func (s *RuntimeReadyState) Ready() error { - s.runtime.ManagedThread.SuspendUnsafe() - if s.runtime.currentState != s.runtime.RuntimeReadyState && s.runtime.currentState != s.runtime.RuntimeRunningState { - return ErrConcurrentStateModification - } - - s.runtime.setStateUnsafe(s.runtime.RuntimeRunningState) - return nil -} - -// Name ... -func (s *RuntimeReadyState) Name() string { - return RuntimeReadyStateName -} - -// RuntimeRunningState runtime ready state. -type RuntimeRunningState struct { - disallowEveryTransitionByDefault - runtime *Runtime - invokeFlow InvokeFlowSynchronization -} - -func (s *RuntimeRunningState) Ready() error { - return nil -} - -// InvocationResponse call when runtime response is available. -func (s *RuntimeRunningState) InvocationResponse() error { - s.runtime.setStateUnsafe(s.runtime.RuntimeInvocationResponseState) - return nil -} - -// InvocationErrorResponse call when runtime error response is available. -func (s *RuntimeRunningState) InvocationErrorResponse() error { - s.runtime.setStateUnsafe(s.runtime.RuntimeInvocationErrorResponseState) - return nil -} - -// Name ... -func (s *RuntimeRunningState) Name() string { - return RuntimeRunningStateName -} - -type RuntimeRestoreReadyState struct { - disallowEveryTransitionByDefault -} - -func (s *RuntimeRestoreReadyState) Name() string { - return RuntimeRestoreReadyStateName -} - -// RuntimeInvocationResponseState runtime response is available. -// Start state for runtime response submission. -type RuntimeInvocationResponseState struct { - disallowEveryTransitionByDefault - runtime *Runtime - invokeFlow InvokeFlowSynchronization -} - -// ResponseSent completes RuntimeInvocationResponseState. -func (s *RuntimeInvocationResponseState) ResponseSent() error { - s.runtime.setStateUnsafe(s.runtime.RuntimeResponseSentState) - return s.invokeFlow.RuntimeResponse(s.runtime) -} - -// Name ... -func (s *RuntimeInvocationResponseState) Name() string { - return RuntimeInvocationResponseStateName -} - -// RuntimeInvocationErrorResponseState runtime response is available. -// Start state for runtime error response submission. -type RuntimeInvocationErrorResponseState struct { - disallowEveryTransitionByDefault - runtime *Runtime - invokeFlow InvokeFlowSynchronization -} - -// ResponseSent completes RuntimeInvocationErrorResponseState. -func (s *RuntimeInvocationErrorResponseState) ResponseSent() error { - s.runtime.setStateUnsafe(s.runtime.RuntimeResponseSentState) - return s.invokeFlow.RuntimeResponse(s.runtime) -} - -// Name ... -func (s *RuntimeInvocationErrorResponseState) Name() string { - return RuntimeInvocationErrorResponseStateName -} - -// RuntimeResponseSentState ends started runtime response or runtime error response submission. -type RuntimeResponseSentState struct { - disallowEveryTransitionByDefault - runtime *Runtime - invokeFlow InvokeFlowSynchronization -} - -// Ready call when runtime ready. -func (s *RuntimeResponseSentState) Ready() error { - s.runtime.setStateUnsafe(s.runtime.RuntimeReadyState) - if err := s.invokeFlow.RuntimeReady(s.runtime); err != nil { - return err - } - - s.runtime.ManagedThread.SuspendUnsafe() - if s.runtime.currentState != s.runtime.RuntimeReadyState && s.runtime.currentState != s.runtime.RuntimeRunningState { - return ErrConcurrentStateModification - } - - s.runtime.setStateUnsafe(s.runtime.RuntimeRunningState) - return nil -} - -// Name ... -func (s *RuntimeResponseSentState) Name() string { - return RuntimeResponseSentStateName -} - -type RuntimeRestoreErrorState struct { - disallowEveryTransitionByDefault - runtime *Runtime - initFlow InitFlowSynchronization -} - -func (s *RuntimeRestoreErrorState) Name() string { - return RuntimeRestoreErrorStateName -} diff --git a/lambda/core/states_test.go b/lambda/core/states_test.go deleted file mode 100644 index b6d2955..0000000 --- a/lambda/core/states_test.go +++ /dev/null @@ -1,428 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package core - -import ( - "context" - "sync" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/testdata/mockthread" -) - -func TestRuntimeInitErrorAfterReady(t *testing.T) { - initFlow := &mockInitFlowSynchronization{} - initFlow.ReadyCond = sync.NewCond(&sync.Mutex{}) - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - - readyChan := make(chan struct{}) - runtime.SetState(runtime.RuntimeStartedState) - go func() { - assert.NoError(t, runtime.Ready()) - readyChan <- struct{}{} - }() - - initFlow.ReadyCond.L.Lock() - for !initFlow.ReadyCalled { - initFlow.ReadyCond.Wait() - } - initFlow.ReadyCond.L.Unlock() - assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) - - assert.Equal(t, ErrNotAllowed, runtime.InitError()) - runtime.Release() - <-readyChan - assert.Equal(t, ErrNotAllowed, runtime.InitError()) - assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) -} - -func TestRuntimeStateTransitionsFromStartedState(t *testing.T) { - runtime := newRuntime() - // Started - assert.Equal(t, runtime.RuntimeStartedState, runtime.GetState()) - // Started -> InitError - runtime.SetState(runtime.RuntimeStartedState) - assert.NoError(t, runtime.InitError()) - assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) - // Started -> Ready - runtime.SetState(runtime.RuntimeStartedState) - assert.NoError(t, runtime.Ready()) - assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) - // Started -> RestoreReady - runtime.SetState(runtime.RuntimeStartedState) - assert.NoError(t, runtime.RestoreReady()) - assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) - // Started -> ResponseSent - runtime.SetState(runtime.RuntimeStartedState) - assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) - assert.Equal(t, runtime.RuntimeStartedState, runtime.GetState()) - // Started -> InvocationResponse - runtime.SetState(runtime.RuntimeStartedState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) - assert.Equal(t, runtime.RuntimeStartedState, runtime.GetState()) - // Started -> InvocationErrorResponse - runtime.SetState(runtime.RuntimeStartedState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) - assert.Equal(t, runtime.RuntimeStartedState, runtime.GetState()) -} - -func TestRuntimeStateTransitionsFromInitErrorState(t *testing.T) { - runtime := newRuntime() - // InitError -> InitError - runtime.SetState(runtime.RuntimeInitErrorState) - assert.Equal(t, ErrNotAllowed, runtime.InitError()) - assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) - // InitError -> Ready - runtime.SetState(runtime.RuntimeInitErrorState) - assert.Equal(t, ErrNotAllowed, runtime.Ready()) - assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) - // InitError -> RestoreReady - runtime.SetState(runtime.RuntimeInitErrorState) - assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) - assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) - // InitError -> ResponseSent - runtime.SetState(runtime.RuntimeInitErrorState) - assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) - assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) - // InitError -> InvocationResponse - runtime.SetState(runtime.RuntimeInitErrorState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) - assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) - // InitError -> InvocationErrorResponse - runtime.SetState(runtime.RuntimeInitErrorState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) - assert.Equal(t, runtime.RuntimeInitErrorState, runtime.GetState()) -} - -func TestRuntimeStateTransitionsFromRestoreErrorState(t *testing.T) { - runtime := newRuntime() - // RestoreError -> InitError - runtime.SetState(runtime.RuntimeRestoreErrorState) - assert.Equal(t, ErrNotAllowed, runtime.InitError()) - assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) - // RestoreError -> Ready - runtime.SetState(runtime.RuntimeRestoreErrorState) - assert.Equal(t, ErrNotAllowed, runtime.Ready()) - assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) - // RestoreError -> RestoreReady - runtime.SetState(runtime.RuntimeRestoreErrorState) - assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) - assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) - // RestoreError -> ResponseSent - runtime.SetState(runtime.RuntimeRestoreErrorState) - assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) - assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) - // RestoreError -> InvocationResponse - runtime.SetState(runtime.RuntimeRestoreErrorState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) - assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) - // RestoreError -> InvocationErrorResponse - runtime.SetState(runtime.RuntimeRestoreErrorState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) - assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) -} - -func TestRuntimeStateTransitionsFromReadyState(t *testing.T) { - runtime := newRuntime() - // Ready -> InitError - runtime.SetState(runtime.RuntimeReadyState) - assert.Equal(t, ErrNotAllowed, runtime.InitError()) - assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) - // Ready -> Ready - runtime.SetState(runtime.RuntimeReadyState) - assert.NoError(t, runtime.Ready()) - assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) - // Ready -> RestoreReady - runtime.SetState(runtime.RuntimeReadyState) - assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) - assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) - // Ready -> ResponseSent - runtime.SetState(runtime.RuntimeReadyState) - assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) - assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) - // Ready -> InvocationResponse - runtime.SetState(runtime.RuntimeReadyState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) - assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) - // Ready -> InvocationErrorResponse - runtime.SetState(runtime.RuntimeReadyState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) - assert.Equal(t, runtime.RuntimeReadyState, runtime.GetState()) -} - -func TestRuntimeStateTransitionsFromRunningState(t *testing.T) { - runtime := newRuntime() - // Running -> InitError - runtime.SetState(runtime.RuntimeRunningState) - assert.Equal(t, ErrNotAllowed, runtime.InitError()) - assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) - // Running -> Ready - runtime.SetState(runtime.RuntimeRunningState) - assert.NoError(t, runtime.Ready()) - assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) - // Running -> RestoreReady - runtime.SetState(runtime.RuntimeRunningState) - assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) - assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) - // Running -> ResponseSent - runtime.SetState(runtime.RuntimeRunningState) - assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) - assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) - // Running -> InvocationResponse - runtime.SetState(runtime.RuntimeRunningState) - assert.NoError(t, runtime.InvocationResponse()) - assert.Equal(t, runtime.RuntimeInvocationResponseState, runtime.GetState()) - // Running -> InvocationErrorResponse - runtime.SetState(runtime.RuntimeRunningState) - assert.NoError(t, runtime.InvocationErrorResponse()) - assert.Equal(t, runtime.RuntimeInvocationErrorResponseState, runtime.GetState()) -} - -func TestRuntimeStateTransitionsFromInvocationResponseState(t *testing.T) { - runtime := newRuntime() - // InvocationResponse -> InitError - runtime.SetState(runtime.RuntimeInvocationResponseState) - assert.Equal(t, ErrNotAllowed, runtime.InitError()) - assert.Equal(t, runtime.RuntimeInvocationResponseState, runtime.GetState()) - // InvocationResponse -> Ready - runtime.SetState(runtime.RuntimeInvocationResponseState) - assert.Equal(t, ErrNotAllowed, runtime.Ready()) - assert.Equal(t, runtime.RuntimeInvocationResponseState, runtime.GetState()) - // InvocationResponse -> RestoreReady - runtime.SetState(runtime.RuntimeInvocationResponseState) - assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) - assert.Equal(t, runtime.RuntimeInvocationResponseState, runtime.GetState()) - // InvocationResponse -> ResponseSent - runtime.SetState(runtime.RuntimeInvocationResponseState) - assert.NoError(t, runtime.ResponseSent()) - assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) - assert.NotEqual(t, 0, runtime.GetRuntimeDescription().State.ResponseTimeNs) - // InvocationResponse-> InvocationResponse - runtime.SetState(runtime.RuntimeInvocationResponseState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) - assert.Equal(t, runtime.RuntimeInvocationResponseState, runtime.GetState()) - // InvocationResponse -> InvocationErrorResponse - runtime.SetState(runtime.RuntimeInvocationResponseState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) - assert.Equal(t, runtime.RuntimeInvocationResponseState, runtime.GetState()) -} - -func TestRuntimeStateTransitionsFromInvocationErrorResponseState(t *testing.T) { - runtime := newRuntime() - // InvocationErrorResponse -> InitError - runtime.SetState(runtime.RuntimeInvocationErrorResponseState) - assert.Equal(t, ErrNotAllowed, runtime.InitError()) - assert.Equal(t, runtime.RuntimeInvocationErrorResponseState, runtime.GetState()) - // InvocationErrorResponse -> Ready - runtime.SetState(runtime.RuntimeInvocationErrorResponseState) - assert.Equal(t, ErrNotAllowed, runtime.Ready()) - assert.Equal(t, runtime.RuntimeInvocationErrorResponseState, runtime.GetState()) - // InvocationErrorResponse -> RestoreReady - runtime.SetState(runtime.RuntimeInvocationErrorResponseState) - assert.Equal(t, ErrNotAllowed, runtime.Ready()) - assert.Equal(t, runtime.RuntimeInvocationErrorResponseState, runtime.GetState()) - // InvocationErrorResponse -> ResponseSent - runtime.SetState(runtime.RuntimeInvocationErrorResponseState) - assert.NoError(t, runtime.ResponseSent()) - assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) - // InvocationErrorResponse -> InvocationResponse - runtime.SetState(runtime.RuntimeInvocationErrorResponseState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) - assert.Equal(t, runtime.RuntimeInvocationErrorResponseState, runtime.GetState()) - // InvocationErrorResponse -> InvocationErrorResponse - runtime.SetState(runtime.RuntimeInvocationErrorResponseState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) - assert.Equal(t, runtime.RuntimeInvocationErrorResponseState, runtime.GetState()) -} - -func TestRuntimeStateTransitionsFromResponseSentState(t *testing.T) { - runtime := newRuntime() - // ResponseSent -> InitError - runtime.SetState(runtime.RuntimeResponseSentState) - assert.Equal(t, ErrNotAllowed, runtime.InitError()) - assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) - // ResponseSent -> Ready - runtime.SetState(runtime.RuntimeResponseSentState) - assert.NoError(t, runtime.Ready()) - assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) - // ResponseSent -> RestoreReady - runtime.SetState(runtime.RuntimeResponseSentState) - assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) - assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) - // ResponseSent -> ResponseSent - runtime.SetState(runtime.RuntimeResponseSentState) - assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) - assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) - // ResponseSent -> InvocationResponse - runtime.SetState(runtime.RuntimeResponseSentState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) - assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) - // ResponseSent -> InvocationErrorResponse - runtime.SetState(runtime.RuntimeResponseSentState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) - assert.Equal(t, runtime.RuntimeResponseSentState, runtime.GetState()) -} - -func TestRuntimeStateTransitionsFromRestoreReadyState(t *testing.T) { - runtime := newRuntime() - // RestoreReady -> InitError - runtime.SetState(runtime.RuntimeRestoreReadyState) - assert.Equal(t, ErrNotAllowed, runtime.InitError()) - assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) - // RestoreReady -> Ready - runtime.SetState(runtime.RuntimeRestoreReadyState) - assert.Equal(t, ErrNotAllowed, runtime.Ready()) - assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) - // RestoreReady -> RestoreReady() - runtime.SetState(runtime.RuntimeRestoreReadyState) - assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) - assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) - // RestoreReady -> ResponseSent - runtime.SetState(runtime.RuntimeRestoreReadyState) - assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) - assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) - // RestoreReady -> InvocationResponse - runtime.SetState(runtime.RuntimeRestoreReadyState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) - assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) - // RestoreReady -> InvocationErrorResponse - runtime.SetState(runtime.RuntimeRestoreReadyState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) - assert.Equal(t, runtime.RuntimeRestoreReadyState, runtime.GetState()) -} - -func TestRuntimeStateTransitionsFromRestoringState(t *testing.T) { - runtime, mockInitFlow, _ := newRuntimeGetMockFlows() - runtime.SetState(runtime.RuntimeRestoringState) - mockInitFlow.On("CancelWithError", interop.ErrRestoreHookUserError{UserError: interop.FunctionError{}}).Return() - // RestoreRunning -> Ready - runtime.SetState(runtime.RuntimeRestoringState) - assert.NoError(t, runtime.Ready()) - assert.Equal(t, runtime.RuntimeRunningState, runtime.GetState()) - // RestoreRunning -> RestoreReady - runtime.SetState(runtime.RuntimeRestoringState) - assert.Equal(t, ErrNotAllowed, runtime.RestoreReady()) - assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) - // RestoreRunning -> ResponseSent - runtime.SetState(runtime.RuntimeRestoringState) - assert.Equal(t, ErrNotAllowed, runtime.ResponseSent()) - assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) - // RestoreRunning -> InvocationResponse - runtime.SetState(runtime.RuntimeRestoringState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationResponse()) - assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) - // RestoreRunning -> InvocationErrorResponse - runtime.SetState(runtime.RuntimeRestoringState) - assert.Equal(t, ErrNotAllowed, runtime.InvocationErrorResponse()) - assert.Equal(t, runtime.RuntimeRestoringState, runtime.GetState()) - // RestoreRunning -> RestoreError - runtime.SetState(runtime.RuntimeRestoringState) - assert.NoError(t, runtime.RestoreError(interop.FunctionError{})) - assert.Equal(t, runtime.RuntimeRestoreErrorState, runtime.GetState()) -} - -func newRuntime() *Runtime { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} - - return runtime -} - -func newRuntimeGetMockFlows() (*Runtime, *mockInitFlowSynchronization, *mockInvokeFlowSynchronization) { - initFlow := &mockInitFlowSynchronization{} - invokeFlow := &mockInvokeFlowSynchronization{} - runtime := NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} - - return runtime, initFlow, invokeFlow -} - -type mockInitFlowSynchronization struct { - mock.Mock - ReadyCond *sync.Cond - ReadyCalled bool -} - -func (s *mockInitFlowSynchronization) SetExternalAgentsRegisterCount(agentCount uint16) error { - return nil -} - -func (s *mockInitFlowSynchronization) SetAgentsReadyCount(agentCount uint16) error { - return nil -} - -func (s *mockInitFlowSynchronization) AwaitExternalAgentsRegistered() error { - return nil -} -func (s *mockInitFlowSynchronization) ExternalAgentRegistered() error { - return nil -} -func (s *mockInitFlowSynchronization) AwaitRuntimeReady() error { - return nil -} -func (s *mockInitFlowSynchronization) AwaitRuntimeReadyWithDeadline(ctx context.Context) error { - return nil -} -func (s *mockInitFlowSynchronization) AwaitAgentsReady() error { - return nil -} -func (s *mockInitFlowSynchronization) RuntimeReady() error { - if s.ReadyCond != nil { - s.ReadyCond.L.Lock() - defer s.ReadyCond.L.Unlock() - s.ReadyCalled = true - s.ReadyCond.Signal() - } - return nil -} -func (s *mockInitFlowSynchronization) AgentReady() error { - return nil -} -func (s *mockInitFlowSynchronization) CancelWithError(err error) { - s.Called(err) -} -func (s *mockInitFlowSynchronization) Clear() {} -func (s *mockInitFlowSynchronization) RuntimeRestoreReady() error { - return nil -} -func (s *mockInitFlowSynchronization) AwaitRuntimeRestoreReady() error { - return nil -} - -type mockInvokeFlowSynchronization struct{ mock.Mock } - -func (s *mockInvokeFlowSynchronization) InitializeBarriers() error { - return nil -} -func (s *mockInvokeFlowSynchronization) AwaitRuntimeResponse() error { - return nil -} -func (s *mockInvokeFlowSynchronization) AwaitRuntimeReady() error { - return nil -} -func (s *mockInvokeFlowSynchronization) RuntimeResponse(runtime *Runtime) error { - return nil -} -func (s *mockInvokeFlowSynchronization) RuntimeReady(runtime *Runtime) error { - return nil -} -func (s *mockInvokeFlowSynchronization) SetAgentsReadyCount(agentCount uint16) error { - return nil -} -func (s *mockInvokeFlowSynchronization) AwaitAgentsReady() error { - return nil -} -func (s *mockInvokeFlowSynchronization) AgentReady() error { - return nil -} -func (s *mockInvokeFlowSynchronization) CancelWithError(err error) { - s.Called(err) -} -func (s *mockInvokeFlowSynchronization) Clear() {} diff --git a/lambda/extensions/extensions.go b/lambda/extensions/extensions.go deleted file mode 100644 index abe0c87..0000000 --- a/lambda/extensions/extensions.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package extensions - -import ( - "os" - "sync/atomic" - - log "github.com/sirupsen/logrus" -) - -const ( - disableExtensionsFile = "/opt/disable-extensions-jwigqn8j" -) - -var enabled atomic.Value - -// Enable or disable extensions -func Enable() { - enabled.Store(true) -} - -func Disable() { - enabled.Store(false) -} - -// AreEnabled returns true if extensions are enabled, false otherwise -// If it was never set defaults to false -func AreEnabled() bool { - val := enabled.Load() - if nil == val { - return false - } - return val.(bool) -} - -func DisableViaMagicLayer() { - _, err := os.Stat(disableExtensionsFile) - if err == nil { - log.Infof("Extensions disabled by attached layer (%s)", disableExtensionsFile) - Disable() - } -} diff --git a/lambda/fatalerror/fatalerror.go b/lambda/fatalerror/fatalerror.go deleted file mode 100644 index 665627d..0000000 --- a/lambda/fatalerror/fatalerror.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package fatalerror - -import ( - "regexp" - "strings" -) - -// This package defines constant error types returned to slicer with DONE(failure), and also sandbox errors -// Separate package for namespacing - -// ErrorType is returned to slicer inside DONE -type ErrorType string - -// TODO: Find another name than "fatalerror" -// TODO: Rename all const so that they always begin with Agent/Runtime/Sandbox/Function -// TODO: Add filtering for extensions as well -const ( - // Extension errors - AgentInitError ErrorType = "Extension.InitError" // agent exited after calling /extension/init/error - AgentExitError ErrorType = "Extension.ExitError" // agent exited after calling /extension/exit/error - AgentCrash ErrorType = "Extension.Crash" // agent crashed unexpectedly - AgentLaunchError ErrorType = "Extension.LaunchError" // agent could not be launched - - // Runtime errors - RuntimeExit ErrorType = "Runtime.ExitError" - InvalidEntrypoint ErrorType = "Runtime.InvalidEntrypoint" - InvalidWorkingDir ErrorType = "Runtime.InvalidWorkingDir" - InvalidTaskConfig ErrorType = "Runtime.InvalidTaskConfig" - TruncatedResponse ErrorType = "Runtime.TruncatedResponse" - RuntimeInvalidResponseModeHeader ErrorType = "Runtime.InvalidResponseModeHeader" - RuntimeUnknown ErrorType = "Runtime.Unknown" - - // Function errors - FunctionOversizedResponse ErrorType = "Function.ResponseSizeTooLarge" - FunctionUnknown ErrorType = "Function.Unknown" - - // Sandbox errors - SandboxFailure ErrorType = "Sandbox.Failure" - SandboxTimeout ErrorType = "Sandbox.Timeout" -) - -var validRuntimeAndFunctionErrors = map[ErrorType]struct{}{ - // Runtime errors - RuntimeExit: {}, - InvalidEntrypoint: {}, - InvalidWorkingDir: {}, - InvalidTaskConfig: {}, - TruncatedResponse: {}, - RuntimeInvalidResponseModeHeader: {}, - RuntimeUnknown: {}, - - // Function errors - FunctionOversizedResponse: {}, - FunctionUnknown: {}, -} - -func GetValidRuntimeOrFunctionErrorType(errorType string) ErrorType { - match, _ := regexp.MatchString("(Runtime|Function)\\.[A-Z][a-zA-Z]+", errorType) - if match { - return ErrorType(errorType) - } - - if strings.HasPrefix(errorType, "Function.") { - return FunctionUnknown - } - - return RuntimeUnknown -} diff --git a/lambda/fatalerror/fatalerror_test.go b/lambda/fatalerror/fatalerror_test.go deleted file mode 100644 index 72c34aa..0000000 --- a/lambda/fatalerror/fatalerror_test.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package fatalerror - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestValidRuntimeAndFunctionErrors(t *testing.T) { - type test struct { - input string - expected ErrorType - } - - var tests = []test{} - for validError := range validRuntimeAndFunctionErrors { - tests = append(tests, test{input: string(validError), expected: validError}) - } - - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - assert.Equal(t, GetValidRuntimeOrFunctionErrorType(tt.input), tt.expected) - }) - } -} - -func TestGetValidRuntimeOrFunctionErrorType(t *testing.T) { - type test struct { - input string - expected ErrorType - } - - var tests = []test{ - {"", RuntimeUnknown}, - {"MyCustomError", RuntimeUnknown}, - {"MyCustomError.Error", RuntimeUnknown}, - {"Runtime.MyCustomErrorTypeHere", ErrorType("Runtime.MyCustomErrorTypeHere")}, - {"Function.MyCustomErrorTypeHere", ErrorType("Function.MyCustomErrorTypeHere")}, - } - - for _, tt := range tests { - testname := fmt.Sprintf("TestGetValidRuntimeOrFunctionErrorType with %s", tt.input) - t.Run(testname, func(t *testing.T) { - assert.Equal(t, GetValidRuntimeOrFunctionErrorType(tt.input), tt.expected) - }) - } -} diff --git a/lambda/interop/bootstrap.go b/lambda/interop/bootstrap.go deleted file mode 100644 index d3f4500..0000000 --- a/lambda/interop/bootstrap.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package interop - -import ( - "os" - - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/rapidcore/env" -) - -type Bootstrap interface { - Cmd() ([]string, error) // returns the args of bootstrap, where args[0] is the path to executable - Env(e *env.Environment) map[string]string // returns the environment variables to be passed to the bootstrapped process - Cwd() (string, error) // returns the working directory of the bootstrap process - ExtraFiles() []*os.File // returns the extra file descriptors apart from 1 & 2 to be passed to runtime - CachedFatalError(err error) (fatalerror.ErrorType, string, bool) -} diff --git a/lambda/interop/cancellable_request.go b/lambda/interop/cancellable_request.go deleted file mode 100644 index 7e8fca5..0000000 --- a/lambda/interop/cancellable_request.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package interop - -import ( - "net" - "net/http" -) - -type key int - -const ( - HTTPConnKey key = iota -) - -func GetConn(r *http.Request) net.Conn { - return r.Context().Value(HTTPConnKey).(net.Conn) -} - -type CancellableRequest struct { - Request *http.Request -} - -func (c *CancellableRequest) Cancel() error { - return GetConn(c.Request).Close() -} diff --git a/lambda/interop/events_api.go b/lambda/interop/events_api.go deleted file mode 100644 index a0e9967..0000000 --- a/lambda/interop/events_api.go +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package interop - -import ( - "fmt" - - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/rapi/model" -) - -type InitPhase string - -// InitializationType describes possible types of INIT phase -type InitType string - -type InitStartData struct { - InitializationType InitType `json:"initializationType"` - RuntimeVersion string `json:"runtimeVersion"` - RuntimeVersionArn string `json:"runtimeVersionArn"` - FunctionName string `json:"functionName"` - FunctionArn string `json:"functionArn"` - FunctionVersion string `json:"functionVersion"` - InstanceID string `json:"instanceId"` - InstanceMaxMemory uint64 `json:"instanceMaxMemory"` - Phase InitPhase `json:"phase"` - Tracing *TracingCtx `json:"tracing,omitempty"` -} - -func (d *InitStartData) String() string { - return fmt.Sprintf("INIT START(type: %s, phase: %s)", d.InitializationType, d.Phase) -} - -type InitRuntimeDoneData struct { - InitializationType InitType `json:"initializationType"` - Status string `json:"status"` - Phase InitPhase `json:"phase"` - ErrorType *string `json:"errorType,omitempty"` - Tracing *TracingCtx `json:"tracing,omitempty"` -} - -func (d *InitRuntimeDoneData) String() string { - return fmt.Sprintf("INIT RTDONE(status: %s)", d.Status) -} - -type InitReportMetrics struct { - DurationMs float64 `json:"durationMs"` -} - -type InitReportData struct { - InitializationType InitType `json:"initializationType"` - Metrics InitReportMetrics `json:"metrics"` - Phase InitPhase `json:"phase"` - Tracing *TracingCtx `json:"tracing,omitempty"` -} - -func (d *InitReportData) String() string { - return fmt.Sprintf("INIT REPORT(durationMs: %f)", d.Metrics.DurationMs) -} - -type RestoreRuntimeDoneData struct { - Status string `json:"status"` - ErrorType *string `json:"errorType,omitempty"` - Tracing *TracingCtx `json:"tracing,omitempty"` -} - -func (d *RestoreRuntimeDoneData) String() string { - return fmt.Sprintf("RESTORE RTDONE(status: %s)", d.Status) -} - -type TracingCtx struct { - SpanID string `json:"spanId,omitempty"` - Type model.TracingType `json:"type"` - Value string `json:"value"` -} - -type InvokeStartData struct { - RequestID string `json:"requestId"` - Version string `json:"version,omitempty"` - Tracing *TracingCtx `json:"tracing,omitempty"` -} - -func (d *InvokeStartData) String() string { - return fmt.Sprintf("INVOKE START(requestId: %s)", d.RequestID) -} - -type RuntimeDoneInvokeMetrics struct { - ProducedBytes int64 `json:"producedBytes"` - DurationMs float64 `json:"durationMs"` -} - -type Span struct { - Name string `json:"name"` - Start string `json:"start"` - DurationMs float64 `json:"durationMs"` -} - -func (s *Span) String() string { - return fmt.Sprintf("SPAN(name: %s)", s.Name) -} - -type InvokeRuntimeDoneData struct { - RequestID RequestID `json:"requestId"` - Status string `json:"status"` - Metrics *RuntimeDoneInvokeMetrics `json:"metrics,omitempty"` - Tracing *TracingCtx `json:"tracing,omitempty"` - Spans []Span `json:"spans,omitempty"` - ErrorType *string `json:"errorType,omitempty"` - InternalMetrics *InvokeResponseMetrics `json:"-"` -} - -func (d *InvokeRuntimeDoneData) String() string { - return fmt.Sprintf("INVOKE RTDONE(status: %s, produced bytes: %d, duration: %fms)", d.Status, d.Metrics.ProducedBytes, d.Metrics.DurationMs) -} - -type ExtensionInitData struct { - AgentName string `json:"name"` - State string `json:"state"` - Subscriptions []string `json:"events"` - ErrorType string `json:"errorType,omitempty"` -} - -func (d *ExtensionInitData) String() string { - return fmt.Sprintf("EXTENSION INIT(agent name: %s, state: %s, error type: %s)", d.AgentName, d.State, d.ErrorType) -} - -type ReportMetrics struct { - DurationMs float64 `json:"durationMs"` - BilledDurationMs float64 `json:"billedDurationMs"` - MemorySizeMB uint64 `json:"memorySizeMB"` - MaxMemoryUsedMB uint64 `json:"maxMemoryUsedMB"` - InitDurationMs float64 `json:"initDurationMs,omitempty"` -} - -type ReportData struct { - RequestID RequestID `json:"requestId"` - Status string `json:"status"` - Metrics ReportMetrics `json:"metrics"` - Tracing *TracingCtx `json:"tracing,omitempty"` - Spans []Span `json:"spans,omitempty"` - ErrorType *string `json:"errorType,omitempty"` -} - -func (d *ReportData) String() string { - return fmt.Sprintf("REPORT(status: %s, durationMs: %f)", d.Status, d.Metrics.DurationMs) -} - -type EndData struct { - RequestID RequestID `json:"requestId"` -} - -func (d *EndData) String() string { - return "END" -} - -type RequestID string - -type FaultData struct { - RequestID RequestID - ErrorMessage error - ErrorType fatalerror.ErrorType -} - -func (d *FaultData) String() string { - return fmt.Sprintf("RequestId: %s Error: %s\n%s\n", d.RequestID, d.ErrorMessage, d.ErrorType) -} - -type ImageErrorLogData string - -type EventsAPI interface { - SetCurrentRequestID(RequestID) - SendInitStart(InitStartData) error - SendInitRuntimeDone(InitRuntimeDoneData) error - SendInitReport(InitReportData) error - SendRestoreRuntimeDone(RestoreRuntimeDoneData) error - SendInvokeStart(InvokeStartData) error - SendInvokeRuntimeDone(InvokeRuntimeDoneData) error - SendExtensionInit(ExtensionInitData) error - SendReportSpan(Span) error - SendReport(ReportData) error - SendEnd(EndData) error - SendFault(FaultData) error - SendImageErrorLog(ImageErrorLogData) - - FetchTailLogs(string) (string, error) - GetRuntimeDoneSpans( - runtimeStartedTime int64, - invokeResponseMetrics *InvokeResponseMetrics, - runtimeOverheadStartedTime int64, - runtimeReadyTime int64, - ) []Span -} diff --git a/lambda/interop/events_api_test.go b/lambda/interop/events_api_test.go deleted file mode 100644 index d3a7dc1..0000000 --- a/lambda/interop/events_api_test.go +++ /dev/null @@ -1,656 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package interop - -import ( - "encoding/json" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.amzn.com/lambda/rapi/model" -) - -const requestID RequestID = "REQUEST_ID" - -func TestJsonMarshalInvokeRuntimeDone(t *testing.T) { - data := InvokeRuntimeDoneData{ - RequestID: requestID, - Status: "success", - Metrics: &RuntimeDoneInvokeMetrics{ - ProducedBytes: int64(100), - DurationMs: float64(52.56), - }, - Spans: []Span{ - { - Name: "responseLatency", - Start: "2022-04-11T15:01:28.543Z", - DurationMs: float64(23.02), - }, - { - Name: "responseDuration", - Start: "2022-04-11T15:00:00.000Z", - DurationMs: float64(20), - }, - }, - Tracing: &TracingCtx{ - SpanID: "spanid", - Type: model.XRayTracingType, - Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", - }, - } - - expected := ` - { - "requestId": "REQUEST_ID", - "status": "success", - "tracing": { - "spanId": "spanid", - "type": "X-Amzn-Trace-Id", - "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" - }, - "spans": [ - { - "name": "responseLatency", - "start": "2022-04-11T15:01:28.543Z", - "durationMs": 23.02 - }, - { - "name": "responseDuration", - "start": "2022-04-11T15:00:00.000Z", - "durationMs": 20 - } - ], - "metrics": { - "producedBytes": 100, - "durationMs": 52.56 - } - } - ` - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalInvokeRuntimeDoneNoTracing(t *testing.T) { - data := InvokeRuntimeDoneData{ - RequestID: requestID, - Status: "success", - Metrics: &RuntimeDoneInvokeMetrics{ - ProducedBytes: int64(100), - DurationMs: float64(52.56), - }, - Spans: []Span{ - { - Name: "responseLatency", - Start: "2022-04-11T15:01:28.543Z", - DurationMs: float64(23.02), - }, - { - Name: "responseDuration", - Start: "2022-04-11T15:00:00.000Z", - DurationMs: float64(20), - }, - }, - } - - expected := ` - { - "requestId": "REQUEST_ID", - "status": "success", - "spans": [ - { - "name": "responseLatency", - "start": "2022-04-11T15:01:28.543Z", - "durationMs": 23.02 - }, - { - "name": "responseDuration", - "start": "2022-04-11T15:00:00.000Z", - "durationMs": 20 - } - ], - "metrics": { - "producedBytes": 100, - "durationMs": 52.56 - } - } - ` - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalInvokeRuntimeDoneNoMetrics(t *testing.T) { - data := InvokeRuntimeDoneData{ - RequestID: requestID, - Status: "success", - Spans: []Span{ - { - Name: "responseLatency", - Start: "2022-04-11T15:01:28.543Z", - DurationMs: float64(23.02), - }, - { - Name: "responseDuration", - Start: "2022-04-11T15:00:00.000Z", - DurationMs: float64(20), - }, - }, - Tracing: &TracingCtx{ - SpanID: "spanid", - Type: model.XRayTracingType, - Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", - }, - } - - expected := ` - { - "requestId": "REQUEST_ID", - "status": "success", - "tracing": { - "spanId": "spanid", - "type": "X-Amzn-Trace-Id", - "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" - }, - "spans": [ - { - "name": "responseLatency", - "start": "2022-04-11T15:01:28.543Z", - "durationMs": 23.02 - }, - { - "name": "responseDuration", - "start": "2022-04-11T15:00:00.000Z", - "durationMs": 20 - } - ] - } - ` - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalInvokeRuntimeDoneWithProducedBytesEqualToZero(t *testing.T) { - data := InvokeRuntimeDoneData{ - RequestID: requestID, - Status: "success", - Metrics: &RuntimeDoneInvokeMetrics{ - ProducedBytes: int64(0), - DurationMs: float64(52.56), - }, - Spans: []Span{ - { - Name: "responseLatency", - Start: "2022-04-11T15:01:28.543Z", - DurationMs: float64(23.02), - }, - { - Name: "responseDuration", - Start: "2022-04-11T15:00:00.000Z", - DurationMs: float64(20), - }, - }, - Tracing: &TracingCtx{ - SpanID: "spanid", - Type: model.XRayTracingType, - Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", - }, - } - - expected := ` - { - "requestId": "REQUEST_ID", - "status": "success", - "tracing": { - "spanId": "spanid", - "type": "X-Amzn-Trace-Id", - "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" - }, - "spans": [ - { - "name": "responseLatency", - "start": "2022-04-11T15:01:28.543Z", - "durationMs": 23.02 - }, - { - "name": "responseDuration", - "start": "2022-04-11T15:00:00.000Z", - "durationMs": 20 - } - ], - "metrics": { - "producedBytes": 0, - "durationMs": 52.56 - } - } - ` - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalInvokeRuntimeDoneWithNoSpans(t *testing.T) { - data := InvokeRuntimeDoneData{ - RequestID: requestID, - Status: "success", - Metrics: &RuntimeDoneInvokeMetrics{ - ProducedBytes: int64(100), - DurationMs: float64(52.56), - }, - Spans: []Span{}, - Tracing: &TracingCtx{ - SpanID: "spanid", - Type: model.XRayTracingType, - Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", - }, - } - - expected := ` - { - "requestId": "REQUEST_ID", - "status": "success", - "tracing": { - "spanId": "spanid", - "type": "X-Amzn-Trace-Id", - "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" - }, - "metrics": { - "producedBytes": 100, - "durationMs": 52.56 - } - } - ` - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalInvokeRuntimeDoneTimeout(t *testing.T) { - data := InvokeRuntimeDoneData{ - RequestID: requestID, - Status: "timeout", - Metrics: &RuntimeDoneInvokeMetrics{ - DurationMs: float64(52.56), - }, - Spans: []Span{}, - Tracing: &TracingCtx{ - SpanID: "spanid", - Type: model.XRayTracingType, - Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", - }, - } - - expected := ` - { - "requestId": "REQUEST_ID", - "status": "timeout", - "tracing": { - "spanId": "spanid", - "type": "X-Amzn-Trace-Id", - "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" - }, - "metrics": { - "producedBytes": 0, - "durationMs": 52.56 - } - } - ` - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalInvokeRuntimeDoneFailure(t *testing.T) { - errorType := "Runtime.ExitError" - data := InvokeRuntimeDoneData{ - RequestID: requestID, - Status: "failure", - ErrorType: &errorType, - Metrics: &RuntimeDoneInvokeMetrics{ - DurationMs: float64(52.56), - }, - Spans: []Span{}, - Tracing: &TracingCtx{ - SpanID: "spanid", - Type: model.XRayTracingType, - Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", - }, - } - - expected := ` - { - "requestId": "REQUEST_ID", - "status": "failure", - "tracing": { - "spanId": "spanid", - "type": "X-Amzn-Trace-Id", - "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" - }, - "metrics": { - "producedBytes": 0, - "durationMs": 52.56 - }, - "errorType": "Runtime.ExitError" - } - ` - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalInvokeRuntimeDoneWithEmptyErrorType(t *testing.T) { - errorType := "" - data := InvokeRuntimeDoneData{ - RequestID: requestID, - Status: "failure", - ErrorType: &errorType, - Metrics: &RuntimeDoneInvokeMetrics{ - DurationMs: float64(52.56), - }, - Spans: []Span{}, - Tracing: &TracingCtx{ - SpanID: "spanid", - Type: model.XRayTracingType, - Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", - }, - } - - expected := ` - { - "requestId": "REQUEST_ID", - "status": "failure", - "tracing": { - "spanId": "spanid", - "type": "X-Amzn-Trace-Id", - "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" - }, - "metrics": { - "producedBytes": 0, - "durationMs": 52.56 - }, - "errorType": "" - } - ` - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalInitRuntimeDoneSuccess(t *testing.T) { - var errorType *string - data := InitRuntimeDoneData{ - InitializationType: "snap-start", - Phase: "init", - Status: "success", - ErrorType: errorType, - } - - expected := ` - { - "initializationType": "snap-start", - "phase": "init", - "status": "success" - } - ` - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalInitRuntimeDoneError(t *testing.T) { - errorType := "Runtime.ExitError" - data := InitRuntimeDoneData{ - InitializationType: "snap-start", - Phase: "init", - Status: "error", - ErrorType: &errorType, - } - - expected := ` - { - "initializationType": "snap-start", - "phase": "init", - "status": "error", - "errorType": "Runtime.ExitError" - } - ` - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalInitRuntimeDoneFailureWithEmptyErrorType(t *testing.T) { - errorType := "" - data := InitRuntimeDoneData{ - InitializationType: "snap-start", - Phase: "init", - Status: "error", - ErrorType: &errorType, - } - - expected := ` - { - "initializationType": "snap-start", - "phase": "init", - "status": "error", - "errorType": "" - } - ` - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalRestoreRuntimeDoneSuccess(t *testing.T) { - var errorType *string - data := RestoreRuntimeDoneData{ - Status: "success", - ErrorType: errorType, - } - - expected := ` - { - "status": "success" - } - ` - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalRestoreRuntimeDoneError(t *testing.T) { - errorType := "Runtime.ExitError" - data := RestoreRuntimeDoneData{ - Status: "error", - ErrorType: &errorType, - } - - expected := ` - { - "status": "error", - "errorType": "Runtime.ExitError" - } - ` - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalRestoreRuntimeDoneErrorWithEmptyErrorType(t *testing.T) { - errorType := "" - data := RestoreRuntimeDoneData{ - Status: "error", - ErrorType: &errorType, - } - - expected := ` - { - "status": "error", - "errorType": "" - } - ` - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalExtensionInit(t *testing.T) { - data := ExtensionInitData{ - AgentName: "agentName", - State: "Registered", - ErrorType: "", - Subscriptions: []string{"INVOKE", "SHUTDOWN"}, - } - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, `{"name":"agentName","state":"Registered","events":["INVOKE","SHUTDOWN"]}`, string(actual)) -} - -func TestJsonMarshalExtensionInitWithError(t *testing.T) { - data := ExtensionInitData{ - AgentName: "agentName", - State: "Registered", - ErrorType: "Extension.FooBar", - Subscriptions: []string{"INVOKE", "SHUTDOWN"}, - } - - actual, err := json.Marshal(data) - assert.NoError(t, err) - assert.JSONEq(t, `{"name":"agentName","state":"Registered","events":["INVOKE","SHUTDOWN"],"errorType":"Extension.FooBar"}`, string(actual)) -} - -func TestJsonMarshalExtensionInitEmptyEvents(t *testing.T) { - data := ExtensionInitData{ - AgentName: "agentName", - State: "Registered", - ErrorType: "Extension.FooBar", - Subscriptions: []string{}, - } - - actual, err := json.Marshal(data) - require.NoError(t, err) - require.JSONEq(t, `{"name":"agentName","state":"Registered","events":[],"errorType":"Extension.FooBar"}`, string(actual)) -} - -func TestJsonMarshalReportWithTracing(t *testing.T) { - errorType := "Runtime.ExitError" - data := ReportData{ - RequestID: requestID, - Status: "error", - ErrorType: &errorType, - Metrics: ReportMetrics{ - DurationMs: float64(52.56), - BilledDurationMs: float64(52.40), - MemorySizeMB: uint64(1024), - MaxMemoryUsedMB: uint64(512), - }, - Tracing: &TracingCtx{ - SpanID: "spanid", - Type: model.XRayTracingType, - Value: "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", - }, - } - - expected := ` - { - "requestId": "REQUEST_ID", - "status": "error", - "errorType": "Runtime.ExitError", - "metrics": { - "durationMs": 52.56, - "billedDurationMs": 52.40, - "memorySizeMB": 1024, - "maxMemoryUsedMB": 512 - }, - "tracing": { - "spanId": "spanid", - "type": "X-Amzn-Trace-Id", - "value": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" - } - } - ` - - actual, err := json.Marshal(data) - require.NoError(t, err) - require.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalReportWithoutErrorSpansAndTracing(t *testing.T) { - data := ReportData{ - RequestID: requestID, - Status: "timeout", - Metrics: ReportMetrics{ - DurationMs: float64(52.56), - BilledDurationMs: float64(52.40), - MemorySizeMB: uint64(1024), - MaxMemoryUsedMB: uint64(512), - }, - } - - expected := ` - { - "requestId": "REQUEST_ID", - "status": "timeout", - "metrics": { - "durationMs": 52.56, - "billedDurationMs": 52.40, - "memorySizeMB": 1024, - "maxMemoryUsedMB": 512 - } - } - ` - - actual, err := json.Marshal(data) - require.NoError(t, err) - require.JSONEq(t, expected, string(actual)) -} - -func TestJsonMarshalReportWithInit(t *testing.T) { - data := ReportData{ - RequestID: requestID, - Status: "success", - Metrics: ReportMetrics{ - DurationMs: float64(52.56), - BilledDurationMs: float64(52.40), - MemorySizeMB: uint64(1024), - MaxMemoryUsedMB: uint64(512), - InitDurationMs: float64(3.15), - }, - } - - expected := ` - { - "requestId": "REQUEST_ID", - "status": "success", - "metrics": { - "durationMs": 52.56, - "billedDurationMs": 52.40, - "memorySizeMB": 1024, - "maxMemoryUsedMB": 512, - "initDurationMs": 3.15 - } - } - ` - - actual, err := json.Marshal(data) - require.NoError(t, err) - require.JSONEq(t, expected, string(actual)) -} diff --git a/lambda/interop/messages.go b/lambda/interop/messages.go deleted file mode 100644 index ee1c783..0000000 --- a/lambda/interop/messages.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package interop - -// conversion from internal data structure into well defined messages - -func DoneFromInvokeSuccess(successMsg InvokeSuccess) *Done { - return &Done{ - Meta: DoneMetadata{ - RuntimeRelease: successMsg.RuntimeRelease, - NumActiveExtensions: successMsg.NumActiveExtensions, - ExtensionNames: successMsg.ExtensionNames, - InvokeRequestReadTimeNs: successMsg.InvokeMetrics.InvokeRequestReadTimeNs, - InvokeRequestSizeBytes: successMsg.InvokeMetrics.InvokeRequestSizeBytes, - RuntimeReadyTime: successMsg.InvokeMetrics.RuntimeReadyTime, - - InvokeCompletionTimeNs: successMsg.InvokeCompletionTimeNs, - InvokeReceivedTime: successMsg.InvokeReceivedTime, - RuntimeResponseLatencyMs: successMsg.ResponseMetrics.RuntimeResponseLatencyMs, - RuntimeTimeThrottledMs: successMsg.ResponseMetrics.RuntimeTimeThrottledMs, - RuntimeProducedBytes: successMsg.ResponseMetrics.RuntimeProducedBytes, - RuntimeOutboundThroughputBps: successMsg.ResponseMetrics.RuntimeOutboundThroughputBps, - LogsAPIMetrics: successMsg.LogsAPIMetrics, - MetricsDimensions: DoneMetadataMetricsDimensions{ - InvokeResponseMode: successMsg.InvokeResponseMode, - }, - }, - } -} - -func DoneFailFromInvokeFailure(failureMsg *InvokeFailure) *DoneFail { - return &DoneFail{ - ErrorType: failureMsg.ErrorType, - Meta: DoneMetadata{ - RuntimeRelease: failureMsg.RuntimeRelease, - NumActiveExtensions: failureMsg.NumActiveExtensions, - InvokeReceivedTime: failureMsg.InvokeReceivedTime, - - RuntimeResponseLatencyMs: failureMsg.ResponseMetrics.RuntimeResponseLatencyMs, - RuntimeTimeThrottledMs: failureMsg.ResponseMetrics.RuntimeTimeThrottledMs, - RuntimeProducedBytes: failureMsg.ResponseMetrics.RuntimeProducedBytes, - RuntimeOutboundThroughputBps: failureMsg.ResponseMetrics.RuntimeOutboundThroughputBps, - - InvokeRequestReadTimeNs: failureMsg.InvokeMetrics.InvokeRequestReadTimeNs, - InvokeRequestSizeBytes: failureMsg.InvokeMetrics.InvokeRequestSizeBytes, - RuntimeReadyTime: failureMsg.InvokeMetrics.RuntimeReadyTime, - - ExtensionNames: failureMsg.ExtensionNames, - LogsAPIMetrics: failureMsg.LogsAPIMetrics, - - MetricsDimensions: DoneMetadataMetricsDimensions{ - InvokeResponseMode: failureMsg.InvokeResponseMode, - }, - }, - } -} - -func DoneFailFromInitFailure(initFailure *InitFailure) *DoneFail { - return &DoneFail{ - ErrorType: initFailure.ErrorType, - Meta: DoneMetadata{ - RuntimeRelease: initFailure.RuntimeRelease, - NumActiveExtensions: initFailure.NumActiveExtensions, - LogsAPIMetrics: initFailure.LogsAPIMetrics, - }, - } -} diff --git a/lambda/interop/model.go b/lambda/interop/model.go deleted file mode 100644 index ee7bb2a..0000000 --- a/lambda/interop/model.go +++ /dev/null @@ -1,431 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -// LOCALSTACK CHANGES 2024-02-13: adjust error message for ErrorResponseTooLarge to be in parity with what AWS returns; make MaxPayloadSize adjustable - -package interop - -import ( - "encoding/json" - "errors" - "fmt" - "io" - "strings" - "time" - - "go.amzn.com/lambda/core/statejson" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/supervisor/model" - - log "github.com/sirupsen/logrus" -) - -var MaxPayloadSize int = 6*1024*1024 + 100 // 6 MiB + 100 bytes - -// MaxPayloadSize max event body size declared as LAMBDA_EVENT_BODY_SIZE -const ( - ResponseBandwidthRate = 2 * 1024 * 1024 // default average rate of 2 MiB/s - ResponseBandwidthBurstSize = 6 * 1024 * 1024 // default burst size of 6 MiB - - MinResponseBandwidthRate = 32 * 1024 // 32 KiB/s - MaxResponseBandwidthRate = 64 * 1024 * 1024 // 64 MiB/s - - MinResponseBandwidthBurstSize = 32 * 1024 // 32 KiB - MaxResponseBandwidthBurstSize = 64 * 1024 * 1024 // 64 MiB -) - -// ResponseMode are top-level constants used in combination with the various types of -// modes we have for responses, such as invoke's response mode and function's response mode. -// In the future we might have invoke's request mode or similar, so these help set the ground -// for consistency. -type ResponseMode string - -const ResponseModeBuffered = "Buffered" -const ResponseModeStreaming = "Streaming" - -type InvokeResponseMode string - -const InvokeResponseModeBuffered InvokeResponseMode = ResponseModeBuffered -const InvokeResponseModeStreaming InvokeResponseMode = ResponseModeStreaming - -var AllInvokeResponseModes = []string{ - string(InvokeResponseModeBuffered), string(InvokeResponseModeStreaming), -} - -// FunctionResponseMode is passed by Runtime to tell whether the response should be -// streamed or not. -type FunctionResponseMode string - -const FunctionResponseModeBuffered FunctionResponseMode = ResponseModeBuffered -const FunctionResponseModeStreaming FunctionResponseMode = ResponseModeStreaming - -var AllFunctionResponseModes = []string{ - string(FunctionResponseModeBuffered), string(FunctionResponseModeStreaming), -} - -// TODO: move to directinvoke.go as we're trying to deprecate interop.* package -// ConvertToFunctionResponseMode converts the given string to a FunctionResponseMode -// It is case insensitive and if there is no match, an error is thrown. -func ConvertToFunctionResponseMode(value string) (FunctionResponseMode, error) { - // buffered - if strings.EqualFold(value, string(FunctionResponseModeBuffered)) { - return FunctionResponseModeBuffered, nil - } - - // streaming - if strings.EqualFold(value, string(FunctionResponseModeStreaming)) { - return FunctionResponseModeStreaming, nil - } - - // unknown - allowedValues := strings.Join(AllFunctionResponseModes, ", ") - log.Errorf("Unlable to map %s to %s.", value, allowedValues) - return "", ErrInvalidFunctionResponseMode -} - -// Message is a generic interop message. -type Message interface{} - -// Invoke is an invocation request received from the slicer. -type Invoke struct { - // Tracing header. - // https://docs.aws.amazon.com/xray/latest/devguide/xray-concepts.html#xray-concepts-tracingheader - TraceID string - LambdaSegmentID string - ID string - InvokedFunctionArn string - CognitoIdentityID string - CognitoIdentityPoolID string - DeadlineNs string - ClientContext string - ContentType string - Payload io.Reader - NeedDebugLogs bool - ReservationToken string - VersionID string - InvokeReceivedTime int64 - InvokeResponseMetrics *InvokeResponseMetrics - InvokeResponseMode InvokeResponseMode - RestoreDurationNs int64 // equals 0 for non-snapstart functions - RestoreStartTimeMonotime int64 // equals 0 for non-snapstart functions -} - -type Token struct { - ReservationToken string - InvokeID string - VersionID string - FunctionTimeout time.Duration - InvackDeadlineNs int64 - TraceID string - LambdaSegmentID string - InvokeMetadata string - NeedDebugLogs bool - RestoreDurationNs int64 - RestoreStartTimeMonotime int64 -} - -// InvokeErrorTraceData is used by the tracer to mark segments as being invocation error -type InvokeErrorTraceData struct { - // Attached to invoke segment - ErrorCause json.RawMessage `json:"ErrorCause,omitempty"` -} - -func GetErrorResponseWithFormattedErrorMessage(errorType fatalerror.ErrorType, err error, invokeRequestID string) *ErrorInvokeResponse { - var errorMessage string - if invokeRequestID != "" { - errorMessage = fmt.Sprintf("RequestId: %s Error: %v", invokeRequestID, err) - } else { - errorMessage = fmt.Sprintf("Error: %v", err) - } - - jsonPayload, err := json.Marshal(FunctionError{ - Type: errorType, - Message: errorMessage, - }) - - if err != nil { - return &ErrorInvokeResponse{ - Headers: InvokeResponseHeaders{}, - FunctionError: FunctionError{ - Type: fatalerror.SandboxFailure, - Message: errorMessage, - }, - Payload: []byte{}, - } - } - - headers := InvokeResponseHeaders{} - functionError := FunctionError{ - Type: errorType, - Message: errorMessage, - } - - return &ErrorInvokeResponse{Headers: headers, FunctionError: functionError, Payload: jsonPayload} -} - -// SandboxType identifies sandbox type (PreWarmed vs Classic) -type SandboxType string - -const SandboxPreWarmed SandboxType = "PreWarmed" -const SandboxClassic SandboxType = "Classic" - -// RuntimeInfo contains metadata about the runtime used by the Sandbox -type RuntimeInfo struct { - ImageJSON string // image config, e.g {\"layers\":[]} - Arn string // runtime ARN, e.g. arn:awstest:lambda:us-west-2::runtime:python3.8::alpha - Version string // human-readable runtime arn equivalent, e.g. python3.8.v999 -} - -// Captures configuration of the operator and runtime domain -// that are only known after INIT is received -type DynamicDomainConfig struct { - // extra hooks to execute at domain start. Currently used for filesystem and network hooks. - // It can be empty. - AdditionalStartHooks []model.Hook - Mounts []model.Mount - //TODO: other dynamic configurations for the domain go here -} - -// Reset message is sent to rapid to initiate reset sequence -type Reset struct { - Reason string - DeadlineNs int64 - InvokeResponseMetrics *InvokeResponseMetrics - TraceID string - LambdaSegmentID string - InvokeResponseMode InvokeResponseMode -} - -// Restore message is sent to rapid to restore runtime to make it ready for consecutive invokes -type Restore struct { - AwsKey string - AwsSecret string - AwsSession string - CredentialsExpiry time.Time - RestoreHookTimeoutMs int64 - LogStreamName string -} - -type Resync struct { -} - -// Shutdown message is sent to rapid to initiate graceful shutdown -type Shutdown struct { - DeadlineNs int64 -} - -// Metrics for response status of LogsAPI/TelemetryAPI `/subscribe` calls -type TelemetrySubscriptionMetrics map[string]int - -func MergeSubscriptionMetrics(logsAPIMetrics TelemetrySubscriptionMetrics, telemetryAPIMetrics TelemetrySubscriptionMetrics) TelemetrySubscriptionMetrics { - metrics := make(map[string]int) - for metric, value := range logsAPIMetrics { - metrics[metric] = value - } - - for metric, value := range telemetryAPIMetrics { - metrics[metric] += value - } - return metrics -} - -// InvokeResponseMetrics are produced while sending streaming invoke response to WP -type InvokeResponseMetrics struct { - // FIXME: this assumes a value in nanoseconds, let's rename it - // to StartReadingResponseMonoTimeNs - StartReadingResponseMonoTimeMs int64 - // Same as the one above - FinishReadingResponseMonoTimeMs int64 - TimeShapedNs int64 - ProducedBytes int64 - OutboundThroughputBps int64 // in bytes per second - FunctionResponseMode FunctionResponseMode - RuntimeCalledResponse bool -} - -func IsResponseStreamingMetrics(metrics *InvokeResponseMetrics) bool { - if metrics == nil { - return false - } - return metrics.FunctionResponseMode == FunctionResponseModeStreaming -} - -type DoneMetadataMetricsDimensions struct { - InvokeResponseMode InvokeResponseMode -} - -func (dimensions DoneMetadataMetricsDimensions) String() string { - var stringDimensions []string - - if dimensions.InvokeResponseMode != "" { - dimension := string("invoke_response_mode=" + dimensions.InvokeResponseMode) - stringDimensions = append(stringDimensions, dimension) - } - return strings.ToLower( - strings.Join(stringDimensions, ","), - ) -} - -type DoneMetadata struct { - NumActiveExtensions int - ExtensionsResetMs int64 - ExtensionNames string - RuntimeRelease string - // Metrics for response status of LogsAPI `/subscribe` calls - LogsAPIMetrics TelemetrySubscriptionMetrics - InvokeRequestReadTimeNs int64 - InvokeRequestSizeBytes int64 - InvokeCompletionTimeNs int64 - InvokeReceivedTime int64 - RuntimeReadyTime int64 - RuntimeResponseLatencyMs float64 - RuntimeTimeThrottledMs int64 - RuntimeProducedBytes int64 - RuntimeOutboundThroughputBps int64 - MetricsDimensions DoneMetadataMetricsDimensions -} - -type Done struct { - WaitForExit bool - ErrorType fatalerror.ErrorType - Meta DoneMetadata -} - -type DoneFail struct { - ErrorType fatalerror.ErrorType - Meta DoneMetadata -} - -// ErrInvalidInvokeID is returned when invokeID provided in Invoke2 does not match one provided in Token -var ErrInvalidInvokeID = fmt.Errorf("ErrInvalidInvokeID") - -// ErrInvalidReservationToken is returned when reservationToken provided in Invoke2 does not match one provided in Token -var ErrInvalidReservationToken = fmt.Errorf("ErrInvalidReservationToken") - -// ErrInvalidFunctionVersion is returned when functionVersion provided in Invoke2 does not match one provided in Token -var ErrInvalidFunctionVersion = fmt.Errorf("ErrInvalidFunctionVersion") - -// ErrInvalidFunctionResponseMode is returned when the value sent by runtime during Invoke2 -// is not a constant of type interop.FunctionResponseMode -var ErrInvalidFunctionResponseMode = fmt.Errorf("ErrInvalidFunctionResponseMode") - -// ErrInvalidInvokeResponseMode is returned when optional InvokeResponseMode header provided in Invoke2 is not a constant of type interop.InvokeResponseMode -var ErrInvalidInvokeResponseMode = fmt.Errorf("ErrInvalidInvokeResponseMode") - -// ErrInvalidMaxPayloadSize is returned when optional MaxPayloadSize header provided in Invoke2 is invalid -var ErrInvalidMaxPayloadSize = fmt.Errorf("ErrInvalidMaxPayloadSize") - -// ErrInvalidResponseBandwidthRate is returned when optional ResponseBandwidthRate header provided in Invoke2 is invalid -var ErrInvalidResponseBandwidthRate = fmt.Errorf("ErrInvalidResponseBandwidthRate") - -// ErrInvalidResponseBandwidthBurstSize is returned when optional ResponseBandwidthBurstSize header provided in Invoke2 is invalid -var ErrInvalidResponseBandwidthBurstSize = fmt.Errorf("ErrInvalidResponseBandwidthBurstSize") - -// ErrMalformedCustomerHeaders is returned when customer headers format is invalid -var ErrMalformedCustomerHeaders = fmt.Errorf("ErrMalformedCustomerHeaders") - -// ErrResponseSent is returned when response with given invokeID was already sent. -var ErrResponseSent = fmt.Errorf("ErrResponseSent") - -// ErrReservationExpired is returned when invoke arrived after InvackDeadline -var ErrReservationExpired = fmt.Errorf("ErrReservationExpired") - -// ErrInternalPlatformError is returned when internal platform error occurred -type ErrInternalPlatformError struct{} - -func (s *ErrInternalPlatformError) Error() string { - return "ErrInternalPlatformError" -} - -// ErrTruncatedResponse is returned when response is truncated -type ErrTruncatedResponse struct{} - -func (s *ErrTruncatedResponse) Error() string { - return "ErrTruncatedResponse" -} - -// ErrorResponseTooLarge is returned when response Payload exceeds shared memory buffer size -type ErrorResponseTooLarge struct { - MaxResponseSize int - ResponseSize int -} - -// ErrorResponseTooLargeDI is used to reproduce ErrorResponseTooLarge behavior for Direct Invoke mode -type ErrorResponseTooLargeDI struct { - ErrorResponseTooLarge -} - -// ErrorResponseTooLarge is returned when response provided by Runtime does not fit into shared memory buffer -func (s *ErrorResponseTooLarge) Error() string { - return fmt.Sprintf("Response payload size exceeded maximum allowed payload size (%d bytes).", s.MaxResponseSize) -} - -// AsErrorResponse generates ErrorInvokeResponse from ErrorResponseTooLarge -func (s *ErrorResponseTooLarge) AsErrorResponse() *ErrorInvokeResponse { - functionError := FunctionError{ - Type: fatalerror.FunctionOversizedResponse, - Message: s.Error(), - } - jsonPayload, err := json.Marshal(functionError) - if err != nil { - panic("Failed to marshal interop.FunctionError") - } - headers := InvokeResponseHeaders{ContentType: "application/json"} - return &ErrorInvokeResponse{Headers: headers, FunctionError: functionError, Payload: jsonPayload} -} - -// Server used for sending messages and sharing data between the Runtime API handlers and the -// internal platform facing servers. For example, -// -// responseCtx.SendResponse(...) -// -// will send the response payload and metadata provided by the runtime to the platform, through the internal -// protocol used by the specific implementation -// TODO: rename this to InvokeResponseContext, used to send responses from handlers to platform-facing server -type Server interface { - // GetCurrentInvokeID returns current invokeID. - // NOTE, in case of INIT, when invokeID is not known in advance (e.g. provisioned concurrency), - // returned invokeID will contain empty value. - GetCurrentInvokeID() string - - // SendRuntimeReady sends a message indicating the runtime has called /invocation/next. - // The checkpoint allows us to compute the overhead due to Extensions by substracting it - // from the time when all extensions have called /next. - // TODO: this method is a lifecycle event used only for metrics, and doesn't belong here - SendRuntimeReady() error - - // SendInitErrorResponse does two separate things when init/error is called: - // a) sends the init error response if called during invoke, and - // b) notifies platform of a user fault if called, during both init or invoke - // TODO: - // separate the two concerns & unify with SendErrorResponse in response sender - SendInitErrorResponse(response *ErrorInvokeResponse) error -} - -type InternalStateGetter func() statejson.InternalStateDescription - -// ErrRestoreHookTimeout is returned as a response to `RESTORE` message -// when function's restore hook takes more time to execute thatn -// the timeout value. -var ErrRestoreHookTimeout = errors.New("Runtime.RestoreHookUserTimeout") - -// ErrRestoreHookUserError is returned as a response to `RESTORE` message -// when function's restore hook faces with an error on throws an exception. -// UserError contains the error type that the runtime encountered. -type ErrRestoreHookUserError struct { - UserError FunctionError -} - -func (err ErrRestoreHookUserError) Error() string { - return "errRestoreHookUserError" -} - -// ErrRestoreUpdateCredentials is returned as a response to `RESTORE` message -// if RAPID cannot update the credentials served by credentials API -// during the RESTORE phase. -var ErrRestoreUpdateCredentials = errors.New("errRestoreUpdateCredentials") - -var ErrCannotParseCredentialsExpiry = errors.New("errCannotParseCredentialsExpiry") - -var ErrCannotParseRestoreHookTimeoutMs = errors.New("errCannotParseRestoreHookTimeoutMs") - -var ErrMissingRestoreCredentials = errors.New("errMissingRestoreCredentials") diff --git a/lambda/interop/model_test.go b/lambda/interop/model_test.go deleted file mode 100644 index d9ba36a..0000000 --- a/lambda/interop/model_test.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package interop - -import ( - "fmt" - "testing" - - "go.amzn.com/lambda/fatalerror" - - "github.com/stretchr/testify/assert" -) - -func TestMergeSubscriptionMetrics(t *testing.T) { - logsAPIMetrics := map[string]int{ - "server_error": 1, - "client_error": 2, - } - - telemetryAPIMetrics := map[string]int{ - "server_error": 1, - "success": 5, - } - - metrics := MergeSubscriptionMetrics(logsAPIMetrics, telemetryAPIMetrics) - assert.Equal(t, 5, metrics["success"]) - assert.Equal(t, 2, metrics["server_error"]) - assert.Equal(t, 2, metrics["client_error"]) -} - -func TestGetErrorResponseWithFormattedErrorMessageWithoutInvokeRequestId(t *testing.T) { - errorType := fatalerror.RuntimeExit - errorMessage := fmt.Errorf("Divided by 0") - expectedMsg := fmt.Sprintf(`Error: %s`, errorMessage) - expectedJSON := fmt.Sprintf(`{"errorType": "%s", "errorMessage": "%s"}`, string(errorType), expectedMsg) - - actual := GetErrorResponseWithFormattedErrorMessage(errorType, errorMessage, "") - assert.Equal(t, errorType, actual.FunctionError.Type) - assert.Equal(t, expectedMsg, actual.FunctionError.Message) - assert.JSONEq(t, expectedJSON, string(actual.Payload)) -} - -func TestGetErrorResponseWithFormattedErrorMessageWithInvokeRequestId(t *testing.T) { - errorType := fatalerror.RuntimeExit - errorMessage := fmt.Errorf("Divided by 0") - invokeID := "invoke-id" - expectedMsg := fmt.Sprintf(`RequestId: %s Error: %s`, invokeID, errorMessage) - expectedJSON := fmt.Sprintf(`{"errorType": "%s", "errorMessage": "%s"}`, string(errorType), expectedMsg) - - actual := GetErrorResponseWithFormattedErrorMessage(errorType, errorMessage, invokeID) - assert.Equal(t, errorType, actual.FunctionError.Type) - assert.Equal(t, expectedMsg, actual.FunctionError.Message) - assert.JSONEq(t, expectedJSON, string(actual.Payload)) -} - -func TestDoneMetadataMetricsDimensionsStringWhenInvokeResponseModeIsPresent(t *testing.T) { - dimensions := DoneMetadataMetricsDimensions{ - InvokeResponseMode: InvokeResponseModeStreaming, - } - assert.Equal(t, "invoke_response_mode=streaming", dimensions.String()) -} -func TestDoneMetadataMetricsDimensionsStringWhenEmpty(t *testing.T) { - dimensions := DoneMetadataMetricsDimensions{} - assert.Equal(t, "", dimensions.String()) -} diff --git a/lambda/interop/sandbox_model.go b/lambda/interop/sandbox_model.go deleted file mode 100644 index 3011c48..0000000 --- a/lambda/interop/sandbox_model.go +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package interop - -import ( - "bytes" - "io" - "net/http" - "time" - - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/rapidcore/env" -) - -// Init represents an init message -// In Rapid Shim, this is a START GirD message -// In Rapid Daemon, this is an INIT GirP message -type Init struct { - InvokeID string - Handler string - AccountID string - AwsKey string - AwsSecret string - AwsSession string - CredentialsExpiry time.Time - SuppressInit bool - InvokeTimeoutMs int64 // timeout duration of whole invoke - InitTimeoutMs int64 // timeout duration for init only - XRayDaemonAddress string // only in standalone - FunctionName string // only in standalone - FunctionVersion string // only in standalone - // In standalone mode, these env vars come from test/init but from environment otherwise. - CustomerEnvironmentVariables map[string]string - SandboxType SandboxType - LogStreamName string - InstanceMaxMemory uint64 - OperatorDomainExtraConfig DynamicDomainConfig - RuntimeDomainExtraConfig DynamicDomainConfig - RuntimeInfo RuntimeInfo - Bootstrap Bootstrap - EnvironmentVariables *env.Environment // contains env vars for agents and runtime procs -} - -// InitSuccess indicates that runtime/extensions initialization completed successfully -// In Rapid Shim, this translates to a DONE GirD message to Slicer -// In Rapid Daemon, this is followed by a DONEDONE GirP message to MM -type InitSuccess struct { - NumActiveExtensions int // indicates number of active extensions - ExtensionNames string // file names of extensions in /opt/extensions - RuntimeRelease string - LogsAPIMetrics TelemetrySubscriptionMetrics // used if telemetry API enabled - Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent -} - -// InitFailure indicates that runtime/extensions initialization failed due to process exit or /error calls -// In Rapid Shim, this translates to either a DONE or a DONEFAIL GirD message to Slicer (depending on extensions mode) -// However, even on failure, the next invoke is expected to work with a suppressed init - i.e. we init again as aprt of the invoke -type InitFailure struct { - ResetReceived bool // indicates if failure happened due to a reset received - RequestReset bool // Indicates whether reset should be requested on init failure - ErrorType fatalerror.ErrorType - ErrorMessage error - NumActiveExtensions int - RuntimeRelease string // value of the User Agent HTTP header provided by runtime - LogsAPIMetrics TelemetrySubscriptionMetrics - Ack chan struct{} // used by the sending goroutine to wait until ipc message has been sent -} - -// ErrorInvokeResponse represents a buffered response received via Runtime API -// for error responses. When body (Payload) is not provided, e.g. -// not retrievable, error type and error message headers will be -// used by the platform to construct a response json, e.g: -// -// default error response produced by the Slicer: -// '{"errorMessage":"Unknown application error occurred"}', -// -// when error type is provided, error response becomes: -// '{"errorMessage":"Unknown application error occurred","errorType":"ErrorType"}' -type ErrorInvokeResponse struct { - Headers InvokeResponseHeaders - Payload []byte - FunctionError FunctionError -} - -// StreamableInvokeResponse represents a response received via Runtime API that can be streamed -type StreamableInvokeResponse struct { - Headers map[string]string - Payload io.Reader - Trailers http.Header - Request *CancellableRequest // streaming request may need to gracefully terminate request streams -} - -// InvokeResponseHeaders contains the headers received via Runtime API /invocation/response -type InvokeResponseHeaders struct { - ContentType string - FunctionResponseMode string -} - -// FunctionError represents information about function errors or 'user errors' -// These are not platform errors and hence are returned as 200 by Lambda -// In the absence of a response payload, the Function Error is serialized and sent -type FunctionError struct { - // Type of error is derived from the Lambda-Runtime-Function-Error-Type set by the Runtime - // This is customer data, so RAPID scrubs this error type to contain only allowlisted values - Type fatalerror.ErrorType `json:"errorType,omitempty"` - // ErrorMessage is generated by RAPID and can never be specified by runtime - Message string `json:"errorMessage,omitempty"` -} - -type InvokeResponseSender interface { - // SendResponse sends invocation response received from Runtime to platform - // This is response may be streamed based on function and invoke response mode - SendResponse(invokeID string, response *StreamableInvokeResponse) error - // SendErrorResponse sends error response in the case of function errors, which are always buffered - SendErrorResponse(invokeID string, response *ErrorInvokeResponse) error -} - -// ResponseMetrics groups metrics related to the response stream -type ResponseMetrics struct { - RuntimeOutboundThroughputBps int64 - RuntimeProducedBytes int64 - RuntimeResponseLatencyMs float64 - RuntimeTimeThrottledMs int64 -} - -// InvokeMetrics groups metrics related to the invoke phase -type InvokeMetrics struct { - InvokeRequestReadTimeNs int64 - InvokeRequestSizeBytes int64 - RuntimeReadyTime int64 -} - -// InvokeSuccess is the success response to invoke phase end -type InvokeSuccess struct { - RuntimeRelease string // value of the User Agent HTTP header provided by runtime - NumActiveExtensions int - ExtensionNames string - InvokeCompletionTimeNs int64 - InvokeReceivedTime int64 - LogsAPIMetrics TelemetrySubscriptionMetrics - ResponseMetrics ResponseMetrics - InvokeMetrics InvokeMetrics - InvokeResponseMode InvokeResponseMode -} - -// InvokeFailure is the failure response to invoke phase end -type InvokeFailure struct { - ResetReceived bool // indicates if failure happened due to a reset received - RequestReset bool // indicates if reset must be requested after the failure - ErrorType fatalerror.ErrorType - ErrorMessage error - RuntimeRelease string // value of the User Agent HTTP header provided by runtime - NumActiveExtensions int - InvokeReceivedTime int64 - LogsAPIMetrics TelemetrySubscriptionMetrics - ResponseMetrics ResponseMetrics - InvokeMetrics InvokeMetrics - ExtensionNames string - DefaultErrorResponse *ErrorInvokeResponse // error resp constructed by platform during fn errors - InvokeResponseMode InvokeResponseMode -} - -// ResetSuccess is the success response to reset request -type ResetSuccess struct { - ExtensionsResetMs int64 - ErrorType fatalerror.ErrorType - ResponseMetrics ResponseMetrics - InvokeResponseMode InvokeResponseMode -} - -// ResetFailure is the failure response to reset request -type ResetFailure struct { - ExtensionsResetMs int64 - ErrorType fatalerror.ErrorType - ResponseMetrics ResponseMetrics - InvokeResponseMode InvokeResponseMode -} - -// ShutdownSuccess is the response to a shutdown request -type ShutdownSuccess struct { - ErrorType fatalerror.ErrorType -} - -// SandboxInfoFromInit captures data from init request that -// is required during invoke (e.g. for suppressed init) -type SandboxInfoFromInit struct { - EnvironmentVariables *env.Environment // contains agent env vars (creds, customer, platform) - SandboxType SandboxType // indicating Pre-Warmed, On-Demand etc - RuntimeBootstrap Bootstrap // contains the runtime bootstrap binary path, Cwd, Args, Env, Cmd -} - -// RestoreResult represents the result of `HandleRestore` function -// in RapidCore -type RestoreResult struct { - RestoreMs int64 -} - -// RapidContext expose methods for functionality of the Rapid Core library -type RapidContext interface { - HandleInit(i *Init, success chan<- InitSuccess, failure chan<- InitFailure) - HandleInvoke(i *Invoke, sbMetadata SandboxInfoFromInit, requestBuf *bytes.Buffer, responseSender InvokeResponseSender) (InvokeSuccess, *InvokeFailure) - HandleReset(reset *Reset) (ResetSuccess, *ResetFailure) - HandleShutdown(shutdown *Shutdown) ShutdownSuccess - HandleRestore(restore *Restore) (RestoreResult, error) - Clear() - - SetRuntimeStartedTime(runtimeStartedTime int64) - SetInvokeResponseMetrics(metrics *InvokeResponseMetrics) - - SetEventsAPI(eventsAPI EventsAPI) -} - -// SandboxContext represents the sandbox lifecycle context -type SandboxContext interface { - Init(i *Init, timeoutMs int64) InitContext - Reset(reset *Reset) (ResetSuccess, *ResetFailure) - Shutdown(shutdown *Shutdown) ShutdownSuccess - Restore(restore *Restore) (RestoreResult, error) - - // TODO: refactor this - // runtimeStartedTime and InvokeResponseMetrics are needed to compute the runtimeDone metrics - // in case of a Reset during an invoke (reset.reason=failure or reset.reason=timeout). - // Ideally: - // - the InvokeContext will have a Reset method to deal with Reset during an invoke and will hold runtimeStartedTime and InvokeResponseMetrics - // - the SandboxContext will have its own Reset/Spindown method - SetRuntimeStartedTime(invokeReceivedTime int64) - SetInvokeResponseMetrics(metrics *InvokeResponseMetrics) -} - -// InitContext represents the lifecycle of a sandbox initialization -type InitContext interface { - Wait() (InitSuccess, *InitFailure) - Reserve() InvokeContext -} - -// InvokeContext represents the lifecycle of a sandbox reservation -type InvokeContext interface { - SendRequest(i *Invoke, r InvokeResponseSender) - Wait() (InvokeSuccess, *InvokeFailure) -} - -// LifecyclePhase represents enum for possible Sandbox lifecycle phases, like init, invoke, etc. -type LifecyclePhase int - -const ( - LifecyclePhaseInit LifecyclePhase = iota + 1 - LifecyclePhaseInvoke -) diff --git a/lambda/logging/doc.go b/lambda/logging/doc.go deleted file mode 100644 index a1f7e95..0000000 --- a/lambda/logging/doc.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -/* -RAPID emits or proxies the following sources of logging: - - 1. Internal logs: RAPID's own application logs into stderr for operational use, visible only internally - 2. Function stream-based logs: Runtime's stdout and stderr, read as newline separated lines - 3. Function message-based logs: Stock runtimes communicate using a custom TLV protocol over a Unix pipe - 4. Extension stream-based logs: Extension's stdout and stderr, read as newline separated lines - 5. Platform logs: Logs that RAPID generates, but is visible either in customer's logs or via Logs API - (e.g. EXTENSION, RUNTIME, RUNTIMEDONE, IMAGE) -*/ -package logging diff --git a/lambda/logging/internal_log.go b/lambda/logging/internal_log.go deleted file mode 100644 index 018b2c7..0000000 --- a/lambda/logging/internal_log.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package logging - -import ( - "bytes" - "fmt" - "github.com/sirupsen/logrus" - "io" - "log" - "strings" -) - -// SetOutput configures logging output for standard loggers. -func SetOutput(w io.Writer) { - log.SetOutput(w) - logrus.SetOutput(w) -} - -type InternalFormatter struct{} - -// format RAPID's internal log like the rest of the sandbox log -func (f *InternalFormatter) Format(entry *logrus.Entry) ([]byte, error) { - b := &bytes.Buffer{} - - // time with comma separator for fraction of second - time := entry.Time.Format("02 Jan 2006 15:04:05.000") - time = strings.Replace(time, ".", ",", 1) - fmt.Fprint(b, time) - - // level - level := strings.ToUpper(entry.Level.String()) - fmt.Fprintf(b, " [%s]", level) - - // label - fmt.Fprint(b, " (rapid)") - - // message - fmt.Fprintf(b, " %s", entry.Message) - - // from WithField and WithError - for field, value := range entry.Data { - fmt.Fprintf(b, " %s=%s", field, value) - } - - fmt.Fprintf(b, "\n") - return b.Bytes(), nil -} diff --git a/lambda/logging/internal_log_test.go b/lambda/logging/internal_log_test.go deleted file mode 100644 index 3ec537f..0000000 --- a/lambda/logging/internal_log_test.go +++ /dev/null @@ -1,167 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package logging - -import ( - "bytes" - "fmt" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - "io" - "log" - "testing" -) - -func TestLogPrint(t *testing.T) { - buf := new(bytes.Buffer) - SetOutput(buf) - log.Print("hello log") - assert.Contains(t, buf.String(), "hello log") -} - -func TestLogrusPrint(t *testing.T) { - buf := new(bytes.Buffer) - SetOutput(buf) - logrus.Print("hello logrus") - assert.Contains(t, buf.String(), "hello logrus") -} - -func TestInternalFormatter(t *testing.T) { - pattern := `^([0-9]{2}\s[A-Za-z]{3}\s[0-9]{4}\s[0-9]{2}:[0-9]{2}:[0-9]{2}(?:,[0-9]{3})?)\s(?:\s\{sandbox:([0-9]+)\}\s)?\[([A-Za-z]+)\]\s(\(([^\)]+)\)(?:\s\[Logging Metrics\]\sSBLOG:([a-zA-Z:]+) ([0-9]+))?\s?.*)` - - buf := new(bytes.Buffer) - SetOutput(buf) - logrus.SetFormatter(&InternalFormatter{}) - - logrus.Print("hello logrus") - assert.Regexp(t, pattern, buf.String()) - - buf.Reset() - err := fmt.Errorf("error message") - logrus.WithError(err).Warning("hello logrus") - assert.Regexp(t, pattern, buf.String()) - - buf.Reset() - logrus.WithFields(logrus.Fields{ - "field1": "val1", - "field2": "val2", - "field3": "val3", - }).Info("hello logrus") - assert.Regexp(t, pattern, buf.String()) - - // no caller logged - buf.Reset() - logrus.WithFields(logrus.Fields{ - "field1": "val1", - "field2": "val2", - "field3": "val3", - }).Info("hello logrus") - assert.Regexp(t, pattern, buf.String()) - - // invalid format without InternalFormatter - buf.Reset() - logrus.SetFormatter(&logrus.TextFormatter{}) - logrus.Print("hello logrus") - assert.NotRegexp(t, pattern, buf.String()) -} - -func BenchmarkLogPrint(b *testing.B) { - SetOutput(io.Discard) - for n := 0; n < b.N; n++ { - log.Print(1, "two", true) - } -} - -func BenchmarkLogrusPrint(b *testing.B) { - SetOutput(io.Discard) - for n := 0; n < b.N; n++ { - logrus.Print(1, "two", true) - } -} - -func BenchmarkLogrusPrintInternalFormatter(b *testing.B) { - var l = logrus.New() - l.SetFormatter(&InternalFormatter{}) - l.SetOutput(io.Discard) - for n := 0; n < b.N; n++ { - l.Print(1, "two", true) - } -} - -func BenchmarkLogPrintf(b *testing.B) { - SetOutput(io.Discard) - for n := 0; n < b.N; n++ { - log.Printf("field:%v,field:%v,field:%v", 1, "two", true) - } -} - -func BenchmarkLogrusPrintf(b *testing.B) { - SetOutput(io.Discard) - for n := 0; n < b.N; n++ { - logrus.Printf("field:%v,field:%v,field:%v", 1, "two", true) - } -} - -func BenchmarkLogrusPrintfInternalFormatter(b *testing.B) { - var l = logrus.New() - l.SetFormatter(&InternalFormatter{}) - l.SetOutput(io.Discard) - for n := 0; n < b.N; n++ { - l.Printf("field:%v,field:%v,field:%v", 1, "two", true) - } -} - -func BenchmarkLogrusDebugLogLevelDisabled(b *testing.B) { - SetOutput(io.Discard) - logrus.SetLevel(logrus.InfoLevel) - for n := 0; n < b.N; n++ { - logrus.Debug(1, "two", true) - } -} - -func BenchmarkLogrusDebugLogLevelDisabledInternalFormatter(b *testing.B) { - var l = logrus.New() - l.SetOutput(io.Discard) - l.SetLevel(logrus.InfoLevel) - for n := 0; n < b.N; n++ { - l.Debug(1, "two", true) - } -} - -func BenchmarkLogrusDebugLogLevelEnabled(b *testing.B) { - SetOutput(io.Discard) - logrus.SetLevel(logrus.DebugLevel) - for n := 0; n < b.N; n++ { - logrus.Debug(1, "two", true) - } -} - -func BenchmarkLogrusDebugLogLevelEnabledInternalFormatter(b *testing.B) { - var l = logrus.New() - l.SetFormatter(&InternalFormatter{}) - l.SetOutput(io.Discard) - l.SetLevel(logrus.DebugLevel) - for n := 0; n < b.N; n++ { - l.Debug(1, "two", true) - } -} - -func BenchmarkLogrusDebugWithFieldLogLevelDisabled(b *testing.B) { - SetOutput(io.Discard) - logrus.SetLevel(logrus.InfoLevel) - for n := 0; n < b.N; n++ { - logrus.WithField("field", "value").Debug(1, "two", true) - } -} - -func BenchmarkLogrusDebugWithFieldLogLevelDisabledInternalFormatter(b *testing.B) { - var l = logrus.New() - l.SetFormatter(&InternalFormatter{}) - l.SetOutput(io.Discard) - l.SetLevel(logrus.InfoLevel) - for n := 0; n < b.N; n++ { - l.WithField("field", "value").Debug(1, "two", true) - } -} - diff --git a/lambda/metering/time.go b/lambda/metering/time.go deleted file mode 100644 index 9e0fa01..0000000 --- a/lambda/metering/time.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package metering - -import ( - _ "runtime" //for nanotime() and walltime() - "time" - _ "unsafe" //for go:linkname -) - -//go:linkname Monotime runtime.nanotime -func Monotime() int64 - -// MonoToEpoch converts monotonic time nanos to unix epoch time nanos. -func MonoToEpoch(t int64) int64 { - monoNsec := Monotime() - wallNsec := time.Now().UnixNano() - clockOffset := wallNsec - monoNsec - return t + clockOffset -} - -func TimeToMono(t time.Time) int64 { - durNs := time.Since(t).Nanoseconds() - return Monotime() - durNs -} - -type ExtensionsResetDurationProfiler struct { - NumAgentsRegisteredForShutdown int - AvailableNs int64 - extensionsResetStartTimeNs int64 - extensionsResetEndTimeNs int64 -} - -func (p *ExtensionsResetDurationProfiler) Start() { - p.extensionsResetStartTimeNs = Monotime() -} - -func (p *ExtensionsResetDurationProfiler) Stop() { - p.extensionsResetEndTimeNs = Monotime() -} - -func (p *ExtensionsResetDurationProfiler) CalculateExtensionsResetMs() (int64, bool) { - var extensionsResetDurationNs = p.extensionsResetEndTimeNs - p.extensionsResetStartTimeNs - var extensionsResetMs int64 - timedOut := false - - if p.NumAgentsRegisteredForShutdown == 0 || p.AvailableNs < 0 || extensionsResetDurationNs < 0 { - extensionsResetMs = 0 - } else if extensionsResetDurationNs > p.AvailableNs { - extensionsResetMs = p.AvailableNs / time.Millisecond.Nanoseconds() - timedOut = true - } else { - extensionsResetMs = extensionsResetDurationNs / time.Millisecond.Nanoseconds() - } - - return extensionsResetMs, timedOut -} diff --git a/lambda/metering/time_test.go b/lambda/metering/time_test.go deleted file mode 100644 index 5c37a87..0000000 --- a/lambda/metering/time_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package metering - -import ( - "math" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestMonoToEpochPrecision(t *testing.T) { - a := time.Now().UnixNano() - b := MonoToEpoch(Monotime()) - - // Conversion error is less than a millisecond. - assert.True(t, math.Abs(float64(a-b)) < float64(time.Millisecond)) -} - -func TestEpochToMonoPrecision(t *testing.T) { - a := Monotime() - b := TimeToMono(time.Now()) - - // Conversion error is less than a millisecond. - assert.Less(t, math.Abs(float64(b-a)), float64(1*time.Millisecond)) -} - -func TestExtensionsResetDurationProfilerForExtensionsResetWithNoExtensions(t *testing.T) { - mono := Monotime() - profiler := ExtensionsResetDurationProfiler{} - - profiler.extensionsResetStartTimeNs = mono - profiler.extensionsResetEndTimeNs = mono + time.Second.Nanoseconds() - profiler.AvailableNs = 3 * time.Second.Nanoseconds() - profiler.NumAgentsRegisteredForShutdown = 0 - extensionsResetMs, resetTimeout := profiler.CalculateExtensionsResetMs() - - assert.Equal(t, int64(0), extensionsResetMs) - assert.Equal(t, false, resetTimeout) -} - -func TestExtensionsResetDurationProfilerForExtensionsResetWithinDeadline(t *testing.T) { - mono := Monotime() - profiler := ExtensionsResetDurationProfiler{} - - profiler.extensionsResetStartTimeNs = mono - profiler.extensionsResetEndTimeNs = mono + time.Second.Nanoseconds() - profiler.AvailableNs = 3 * time.Second.Nanoseconds() - profiler.NumAgentsRegisteredForShutdown = 1 - extensionsResetMs, resetTimeout := profiler.CalculateExtensionsResetMs() - - assert.Equal(t, time.Second.Milliseconds(), extensionsResetMs) - assert.Equal(t, false, resetTimeout) -} - -func TestExtensionsResetDurationProfilerForExtensionsResetTimeout(t *testing.T) { - mono := Monotime() - profiler := ExtensionsResetDurationProfiler{} - - profiler.extensionsResetStartTimeNs = mono - profiler.extensionsResetEndTimeNs = mono + 3*time.Second.Nanoseconds() - profiler.AvailableNs = time.Second.Nanoseconds() - profiler.NumAgentsRegisteredForShutdown = 1 - extensionsResetMs, resetTimeout := profiler.CalculateExtensionsResetMs() - - assert.Equal(t, time.Second.Milliseconds(), extensionsResetMs) - assert.Equal(t, true, resetTimeout) -} - -func TestExtensionsResetDurationProfilerEndToEnd(t *testing.T) { - profiler := ExtensionsResetDurationProfiler{} - - profiler.Start() - time.Sleep(time.Second) - profiler.Stop() - - profiler.AvailableNs = 2 * time.Second.Nanoseconds() - profiler.NumAgentsRegisteredForShutdown = 1 - extensionsResetMs, _ := profiler.CalculateExtensionsResetMs() - - assert.GreaterOrEqual(t, 2*time.Second.Milliseconds(), extensionsResetMs) - assert.LessOrEqual(t, time.Second.Milliseconds(), extensionsResetMs) -} diff --git a/lambda/rapi/extensions_fuzz_test.go b/lambda/rapi/extensions_fuzz_test.go deleted file mode 100644 index c223859..0000000 --- a/lambda/rapi/extensions_fuzz_test.go +++ /dev/null @@ -1,344 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapi - -import ( - "bytes" - "context" - "encoding/json" - "io" - "log" - "net/http" - "net/http/httptest" - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/extensions" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/rapi/handler" - "go.amzn.com/lambda/rapi/model" - "go.amzn.com/lambda/rapi/rendering" - "go.amzn.com/lambda/telemetry" - "go.amzn.com/lambda/testdata" -) - -func FuzzAgentRegisterHandler(f *testing.F) { - extensions.Enable() - defer extensions.Disable() - - registerReq := handler.RegisterRequest{ - Events: []core.Event{core.InvokeEvent, core.ShutdownEvent}, - } - regReqBytes, err := json.Marshal(®isterReq) - if err != nil { - f.Errorf("failed to marshal register request: %v", err) - } - f.Add("agent", "accountId", true, regReqBytes) - f.Add("agent", "accountId", false, regReqBytes) - - f.Fuzz(func(t *testing.T, - agentName string, - featuresHeader string, - external bool, - payload []byte, - ) { - flowTest := testdata.NewFlowTest() - - if external { - flowTest.RegistrationService.CreateExternalAgent(agentName) - } - - functionMetadata := createDummyFunctionMetadata() - flowTest.RegistrationService.SetFunctionMetadata(functionMetadata) - - rapiServer := makeRapiServer(flowTest) - - target := makeTargetURL("/extension/register", version20200101) - request := httptest.NewRequest("POST", target, bytes.NewReader(payload)) - request.Header.Add(handler.LambdaAgentName, agentName) - request.Header.Add("Lambda-Extension-Accept-Feature", featuresHeader) - - responseRecorder := serveTestRequest(rapiServer, request) - - if agentName == "" { - assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidExtensionName") - return - } - - regReqStruct := struct { - handler.RegisterRequest - ConfigurationKeys []string `json:"configurationKeys"` - }{} - if err := json.Unmarshal(payload, ®ReqStruct); err != nil { - assertForbiddenErrorType(t, responseRecorder, "InvalidRequestFormat") - return - } - - if containsInvalidEvent(external, regReqStruct.Events) { - assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidEventType") - return - } - - assert.Equal(t, http.StatusOK, responseRecorder.Code) - - respBody, err := io.ReadAll(responseRecorder.Body) - assert.NoError(t, err) - - expectedResponse := map[string]interface{}{ - "functionName": functionMetadata.FunctionName, - "functionVersion": functionMetadata.FunctionVersion, - "handler": functionMetadata.Handler, - } - if featuresHeader == "accountId" && functionMetadata.AccountID != "" { - expectedResponse["accountId"] = functionMetadata.AccountID - } - - expectedRespBytes, err := json.Marshal(expectedResponse) - assert.NoError(t, err) - assert.JSONEq(t, string(expectedRespBytes), string(respBody)) - - if external { - agent, found := flowTest.RegistrationService.FindExternalAgentByName(agentName) - assert.True(t, found) - assert.Equal(t, agent.RegisteredState, agent.GetState()) - } else { - agent, found := flowTest.RegistrationService.FindInternalAgentByName(agentName) - assert.True(t, found) - assert.Equal(t, agent.RegisteredState, agent.GetState()) - } - }) -} - -func FuzzAgentNextHandler(f *testing.F) { - extensions.Enable() - defer extensions.Disable() - - regService := core.NewRegistrationService(core.NewInitFlowSynchronization(), core.NewInvokeFlowSynchronization()) - testAgent := makeExternalAgent(regService) - f.Add(testAgent.ID.String(), true, true) - f.Add(testAgent.ID.String(), true, false) - - f.Fuzz(func(t *testing.T, - agentIdentifierHeader string, - registered bool, - isInvokeEvent bool, - ) { - flowTest := testdata.NewFlowTest() - agent := makeExternalAgent(flowTest.RegistrationService) - - if registered { - agent.SetState(agent.RegisteredState) - agent.Release() - } - - configureRendererForEvent(flowTest, isInvokeEvent) - - rapiServer := makeRapiServer(flowTest) - - target := makeTargetURL("/extension/event/next", version20200101) - request := httptest.NewRequest("GET", target, nil) - request.Header.Set(handler.LambdaAgentIdentifier, agentIdentifierHeader) - - responseRecorder := serveTestRequest(rapiServer, request) - - if agentIdentifierHeader == "" { - assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierMissing) - return - } - if _, err := uuid.Parse(agentIdentifierHeader); err != nil { - assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierInvalid) - return - } - if agentIdentifierHeader != agent.ID.String() { - assertForbiddenErrorType(t, responseRecorder, "Extension.UnknownExtensionIdentifier") - return - } - if !registered { - assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidExtensionState") - return - } - - assert.Equal(t, http.StatusOK, responseRecorder.Code) - - assertResponseEventType(t, isInvokeEvent, responseRecorder) - - assert.Equal(t, agent.RunningState, agent.GetState()) - }) -} - -func FuzzAgentInitErrorHandler(f *testing.F) { - fuzzErrorHandler(f, "/extension/init/error", fatalerror.AgentInitError) -} - -func FuzzAgentExitErrorHandler(f *testing.F) { - fuzzErrorHandler(f, "/extension/exit/error", fatalerror.AgentExitError) -} - -func fuzzErrorHandler(f *testing.F, handlerPath string, fatalErrorType fatalerror.ErrorType) { - extensions.Enable() - defer extensions.Disable() - - regService := core.NewRegistrationService(core.NewInitFlowSynchronization(), core.NewInvokeFlowSynchronization()) - testAgent := makeExternalAgent(regService) - f.Add(true, testAgent.ID.String(), "Extension.SomeError") - f.Add(false, testAgent.ID.String(), "Extension.SomeError") - - f.Fuzz(func(t *testing.T, - agentRegistered bool, - agentIdentifierHeader string, - errorType string, - ) { - flowTest := testdata.NewFlowTest() - - agent := makeExternalAgent(flowTest.RegistrationService) - - if agentRegistered { - agent.SetState(agent.RegisteredState) - } - - rapiServer := makeRapiServer(flowTest) - - target := makeTargetURL(handlerPath, version20200101) - - request := httptest.NewRequest("POST", target, nil) - request = appctx.RequestWithAppCtx(request, flowTest.AppCtx) - request.Header.Set(handler.LambdaAgentIdentifier, agentIdentifierHeader) - request.Header.Set(handler.LambdaAgentFunctionErrorType, errorType) - - responseRecorder := serveTestRequest(rapiServer, request) - - if agentIdentifierHeader == "" { - assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierMissing) - return - } - - if _, e := uuid.Parse(agentIdentifierHeader); e != nil { - assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierInvalid) - return - } - - if errorType == "" { - assertForbiddenErrorType(t, responseRecorder, "Extension.MissingHeader") - return - } - if agentIdentifierHeader != agent.ID.String() { - assertForbiddenErrorType(t, responseRecorder, "Extension.UnknownExtensionIdentifier") - return - } - if !agentRegistered { - assertForbiddenErrorType(t, responseRecorder, "Extension.InvalidExtensionState") - } else { - assertErrorAgentRegistered(t, responseRecorder, flowTest, fatalErrorType) - } - }) -} - -func assertErrorAgentRegistered(t *testing.T, responseRecorder *httptest.ResponseRecorder, flowTest *testdata.FlowTest, expectedErrType fatalerror.ErrorType) { - var response model.StatusResponse - - respBody, _ := io.ReadAll(responseRecorder.Body) - err := json.Unmarshal(respBody, &response) - assert.NoError(t, err) - - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - assert.Equal(t, "OK", response.Status) - - v, found := appctx.LoadFirstFatalError(flowTest.AppCtx) - assert.True(t, found) - assert.Equal(t, expectedErrType, v) -} - -func assertForbiddenErrorType(t *testing.T, responseRecorder *httptest.ResponseRecorder, errType string) { - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - - var errorResponse model.ErrorResponse - - respBody, _ := io.ReadAll(responseRecorder.Body) - err := json.Unmarshal(respBody, &errorResponse) - assert.NoError(t, err) - - assert.Equal(t, errType, errorResponse.ErrorType) -} - -func createDummyFunctionMetadata() core.FunctionMetadata { - return core.FunctionMetadata{ - AccountID: "accID", - FunctionName: "myFunc", - FunctionVersion: "1.0", - Handler: "myHandler", - } -} - -func makeExternalAgent(registrationService core.RegistrationService) *core.ExternalAgent { - agent, err := registrationService.CreateExternalAgent("agent") - if err != nil { - log.Fatalf("failed to create external agent: %v", err) - return nil - } - - return agent -} - -func configureRendererForEvent(flowTest *testdata.FlowTest, isInvokeEvent bool) { - if isInvokeEvent { - invoke := createDummyInvoke() - - var buf bytes.Buffer - flowTest.RenderingService.SetRenderer( - rendering.NewInvokeRenderer( - context.Background(), - invoke, - &buf, - telemetry.NewNoOpTracer().BuildTracingHeader(), - )) - } else { - flowTest.RenderingService.SetRenderer( - &rendering.ShutdownRenderer{ - AgentEvent: model.AgentShutdownEvent{ - AgentEvent: &model.AgentEvent{ - EventType: "SHUTDOWN", - DeadlineMs: int64(10000), - }, - ShutdownReason: "spindown", - }, - }) - } -} - -func assertResponseEventType(t *testing.T, isInvokeEvent bool, responseRecorder *httptest.ResponseRecorder) { - if isInvokeEvent { - var response model.AgentInvokeEvent - - respBody, _ := io.ReadAll(responseRecorder.Body) - err := json.Unmarshal(respBody, &response) - assert.NoError(t, err) - - assert.Equal(t, "INVOKE", response.AgentEvent.EventType) - } else { - var response model.AgentShutdownEvent - - respBody, _ := io.ReadAll(responseRecorder.Body) - err := json.Unmarshal(respBody, &response) - assert.NoError(t, err) - - assert.Equal(t, "SHUTDOWN", response.AgentEvent.EventType) - } -} - -func containsInvalidEvent(external bool, events []core.Event) bool { - for _, e := range events { - if external { - if err := core.ValidateExternalAgentEvent(e); err != nil { - return true - } - } else if err := core.ValidateInternalAgentEvent(e); err != nil { - return true - } - } - - return false -} diff --git a/lambda/rapi/handler/agentexiterror.go b/lambda/rapi/handler/agentexiterror.go deleted file mode 100644 index 245e155..0000000 --- a/lambda/rapi/handler/agentexiterror.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "net/http" - - "github.com/google/uuid" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/rapi/rendering" - - log "github.com/sirupsen/logrus" -) - -type agentExitErrorHandler struct { - registrationService core.RegistrationService -} - -func (h *agentExitErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - agentID, ok := request.Context().Value(AgentIDCtxKey).(uuid.UUID) - if !ok { - rendering.RenderInternalServerError(writer, request) - return - } - - var errorType string - if errorType = request.Header.Get(LambdaAgentFunctionErrorType); errorType == "" { - log.Warnf("Invalid /extension/exit/error: missing %s header, agentID: %s", LambdaAgentFunctionErrorType, agentID) - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentMissingHeader, "%s not found", LambdaAgentFunctionErrorType) - return - } - - if externalAgent, found := h.registrationService.FindExternalAgentByID(agentID); found { - if err := externalAgent.ExitError(errorType); err != nil { - log.Warnf("Failed to transition agent %s to ExitError state: %s, current state %s", externalAgent.String(), err, externalAgent.GetState().Name()) - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, - externalAgent.GetState().Name(), core.AgentExitedStateName, agentID.String(), err) - return - } - } else if internalAgent, found := h.registrationService.FindInternalAgentByID(agentID); found { - if err := internalAgent.ExitError(errorType); err != nil { - log.Warnf("Failed to transition agent %s to ExitError state: %s, current state %s", internalAgent.String(), err, internalAgent.GetState().Name()) - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, - internalAgent.GetState().Name(), core.AgentExitedStateName, agentID.String(), err) - return - } - } else { - log.Warnf("Unknown agent %s tried to call /extension/exit/error", agentID) - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown "+LambdaAgentIdentifier) - return - } - - appctx.StoreFirstFatalError(appctx.FromRequest(request), fatalerror.AgentExitError) - rendering.RenderAccepted(writer, request) -} - -// NewAgentExitErrorHandler returns a new instance of http handler for serving /extension/exit/error -func NewAgentExitErrorHandler(registrationService core.RegistrationService) http.Handler { - return &agentExitErrorHandler{ - registrationService: registrationService, - } -} diff --git a/lambda/rapi/handler/agentiniterror.go b/lambda/rapi/handler/agentiniterror.go deleted file mode 100644 index 0c18622..0000000 --- a/lambda/rapi/handler/agentiniterror.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "net/http" - - "github.com/google/uuid" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/rapi/rendering" - - log "github.com/sirupsen/logrus" -) - -type agentInitErrorHandler struct { - registrationService core.RegistrationService -} - -func (h *agentInitErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - agentID, ok := request.Context().Value(AgentIDCtxKey).(uuid.UUID) - if !ok { - rendering.RenderInternalServerError(writer, request) - return - } - - var errorType string - if errorType = request.Header.Get(LambdaAgentFunctionErrorType); errorType == "" { - log.Warnf("Invalid /extension/init/error: missing %s header, agentID: %s", LambdaAgentFunctionErrorType, agentID) - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentMissingHeader, "%s not found", LambdaAgentFunctionErrorType) - return - } - - if externalAgent, found := h.registrationService.FindExternalAgentByID(agentID); found { - if err := externalAgent.InitError(errorType); err != nil { - log.Warnf("InitError() failed for %s: %s, state is %s", externalAgent.String(), err, externalAgent.GetState().Name()) - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, - externalAgent.GetState().Name(), core.AgentInitErrorStateName, agentID.String(), err) - return - } - } else if internalAgent, found := h.registrationService.FindInternalAgentByID(agentID); found { - if err := internalAgent.InitError(errorType); err != nil { - log.Warnf("InitError() failed for %s: %s, state is %s", internalAgent.String(), err, internalAgent.GetState().Name()) - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, - internalAgent.GetState().Name(), core.AgentInitErrorStateName, agentID.String(), err) - return - } - } else { - log.Warnf("Unknown agent %s tried to call /extension/init/error", LambdaAgentIdentifier) - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown "+LambdaAgentIdentifier) - return - } - - appctx.StoreFirstFatalError(appctx.FromRequest(request), fatalerror.AgentInitError) - rendering.RenderAccepted(writer, request) -} - -// NewAgentInitErrorHandler returns a new instance of http handler for serving /extension/init/error -func NewAgentInitErrorHandler(registrationService core.RegistrationService) http.Handler { - return &agentInitErrorHandler{ - registrationService: registrationService, - } -} diff --git a/lambda/rapi/handler/agentiniterror_test.go b/lambda/rapi/handler/agentiniterror_test.go deleted file mode 100644 index 50b9143..0000000 --- a/lambda/rapi/handler/agentiniterror_test.go +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/rapi/model" -) - -func newRequest(appCtx appctx.ApplicationContext, agentID uuid.UUID) *http.Request { - request := httptest.NewRequest("POST", "/", nil) - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agentID)) - request = appctx.RequestWithAppCtx(request, appCtx) - request.Header.Set(LambdaAgentFunctionErrorType, "Extension.TestError") - return request -} - -func TestAgentInitErrorInternalError(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - handler := NewAgentInitErrorHandler(registrationService) - request := httptest.NewRequest("POST", "/", nil) - // request is missing agent's UUID context. This should not happen since the middleware validation should have failed - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) -} - -func TestAgentInitErrorMissingErrorHeader(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - - appCtx := appctx.NewApplicationContext() - agent, err := registrationService.CreateExternalAgent("dummyName") - agent.SetState(agent.RegisteredState) - assert.NoError(t, err) - handler := NewAgentInitErrorHandler(registrationService) - responseRecorder := httptest.NewRecorder() - - req := newRequest(appCtx, uuid.New()) - req.Header.Del(LambdaAgentFunctionErrorType) - handler.ServeHTTP(responseRecorder, req) - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - - var errorResponse model.ErrorResponse - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &errorResponse) - assert.Equal(t, errAgentMissingHeader, errorResponse.ErrorType) -} - -func TestAgentInitErrorUnknownAgent(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - handler := NewAgentInitErrorHandler(registrationService) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, newRequest(appctx.NewApplicationContext(), uuid.New())) - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - - var errorResponse model.ErrorResponse - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &errorResponse) - assert.Equal(t, errAgentIdentifierUnknown, errorResponse.ErrorType) -} - -func TestAgentInitErrorAgentInvalidState(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - // agent is in started state, it is not allowed to transition to init error - agent, err := registrationService.CreateExternalAgent("dummyName") - assert.NoError(t, err) - handler := NewAgentInitErrorHandler(registrationService) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, newRequest(appctx.NewApplicationContext(), agent.ID)) - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - - var errorResponse model.ErrorResponse - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &errorResponse) - assert.Equal(t, errAgentInvalidState, errorResponse.ErrorType) -} - -func TestAgentInitErrorRequestAccepted(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - appCtx := appctx.NewApplicationContext() - agent, err := registrationService.CreateExternalAgent("dummyName") - agent.SetState(agent.RegisteredState) - assert.NoError(t, err) - handler := NewAgentInitErrorHandler(registrationService) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, newRequest(appCtx, agent.ID)) - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - - var response model.StatusResponse - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &response) - assert.Equal(t, "OK", response.Status) - - v, found := appctx.LoadFirstFatalError(appCtx) - assert.True(t, found) - assert.Equal(t, fatalerror.AgentInitError, v) -} diff --git a/lambda/rapi/handler/agentnext.go b/lambda/rapi/handler/agentnext.go deleted file mode 100644 index ffdd61d..0000000 --- a/lambda/rapi/handler/agentnext.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "net/http" - - "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/rapi/rendering" -) - -// A CtxKey type is used as a key for storing values in the request context. -type CtxKey int - -// AgentIDCtxKey is the context key for fetching agent's UUID -const ( - AgentIDCtxKey CtxKey = iota -) - -type agentNextHandler struct { - registrationService core.RegistrationService - renderingService *rendering.EventRenderingService -} - -func (h *agentNextHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - agentID, ok := request.Context().Value(AgentIDCtxKey).(uuid.UUID) - if !ok { - rendering.RenderInternalServerError(writer, request) - return - } - - if externalAgent, found := h.registrationService.FindExternalAgentByID(agentID); found { - if err := externalAgent.Ready(); err != nil { - log.Warnf("Ready() failed for %s: %s, state is %s", externalAgent.String(), err, externalAgent.GetState().Name()) - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, - externalAgent.GetState().Name(), core.AgentReadyStateName, agentID.String(), err) - return - } - } else if internalAgent, found := h.registrationService.FindInternalAgentByID(agentID); found { - if err := internalAgent.Ready(); err != nil { - log.Warnf("Ready() failed for %s: %s, state is %s", internalAgent.String(), err, internalAgent.GetState().Name()) - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, - internalAgent.GetState().Name(), core.AgentReadyStateName, agentID.String(), err) - return - } - } else { - log.Warnf("Unknown agent %s tried to call /next", agentID.String()) - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown extension %s", agentID.String()) - return - } - - if err := h.renderingService.RenderAgentEvent(writer, request); err != nil { - log.Error(err) - rendering.RenderInternalServerError(writer, request) - return - } -} - -// NewAgentNextHandler returns a new instance of http handler for serving /extension/event/next -func NewAgentNextHandler(registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { - return &agentNextHandler{ - registrationService: registrationService, - renderingService: renderingService, - } -} diff --git a/lambda/rapi/handler/agentnext_test.go b/lambda/rapi/handler/agentnext_test.go deleted file mode 100644 index 417633e..0000000 --- a/lambda/rapi/handler/agentnext_test.go +++ /dev/null @@ -1,308 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapi/model" - "go.amzn.com/lambda/rapi/rendering" - "go.amzn.com/lambda/telemetry" -) - -func TestRenderAgentInternalError(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - handler := NewAgentNextHandler(registrationService, rendering.NewRenderingService()) - request := httptest.NewRequest("GET", "/", nil) - // request is missing agent's UUID context. This should not happen since the middleware validation should have failed - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) -} - -func TestRenderAgentInvokeUnknownAgent(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - request := httptest.NewRequest("GET", "/", nil) - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, uuid.New())) - responseRecorder := httptest.NewRecorder() - - handler := NewAgentNextHandler(registrationService, rendering.NewRenderingService()) - handler.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - - var errorResponse model.ErrorResponse - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &errorResponse) - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - assert.Equal(t, errAgentIdentifierUnknown, errorResponse.ErrorType) -} - -func TestRenderAgentInvokeInvalidAgentState(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - - agent, err := registrationService.CreateExternalAgent("dummyName") - assert.NoError(t, err) - handler := NewAgentNextHandler(registrationService, rendering.NewRenderingService()) - request := httptest.NewRequest("GET", "/", nil) - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - - var errorResponse model.ErrorResponse - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &errorResponse) - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - assert.Equal(t, errAgentInvalidState, errorResponse.ErrorType) -} - -func TestRenderAgentInvokeNextHappy(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - agent, err := registrationService.CreateExternalAgent("dummyName") - assert.NoError(t, err) - - agent.SetState(agent.RegisteredState) - agent.Release() // sets operator condition to true so that the thread doesn't suspend waiting for invoke request - - deadlineNs := metering.Monotime() + int64(100*time.Millisecond) - requestID, functionArn := "ID", "InvokedFunctionArn" - traceID := "Root=RootID;Parent=LambdaFrontend;Sampled=1" - invoke := &interop.Invoke{ - TraceID: traceID, - ID: requestID, - InvokedFunctionArn: functionArn, - CognitoIdentityID: "CognitoIdentityId1", - CognitoIdentityPoolID: "CognitoIdentityPoolId1", - ClientContext: "ClientContext", - DeadlineNs: fmt.Sprintf("%d", deadlineNs), - ContentType: "image/png", - Payload: strings.NewReader("Payload"), - } - - renderingService := rendering.NewRenderingService() - var buf bytes.Buffer - renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, &buf, telemetry.NewNoOpTracer().BuildTracingHeader())) - - handler := NewAgentNextHandler(registrationService, renderingService) - request := httptest.NewRequest("GET", "/", nil) - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - var response model.AgentInvokeEvent - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &response) - - assert.Equal(t, agent.RunningState, agent.GetState()) - assert.Equal(t, "INVOKE", response.AgentEvent.EventType) - assert.InDelta(t, time.Now().Add(100*time.Millisecond).UnixNano()/int64(time.Millisecond), response.AgentEvent.DeadlineMs, 5) - assert.Equal(t, requestID, response.RequestID) - assert.Equal(t, functionArn, response.InvokedFunctionArn) - assert.Equal(t, model.XRayTracingType, response.Tracing.Type) - assert.Equal(t, traceID, response.Tracing.Value) -} - -func TestRenderAgentInternalInvokeNextHappy(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - agent, err := registrationService.CreateInternalAgent("dummyName") - assert.NoError(t, err) - - agent.SetState(agent.RegisteredState) - agent.Release() // sets operator condition to true so that the thread doesn't suspend waiting for invoke request - - deadlineNs := metering.Monotime() + int64(100*time.Millisecond) - requestID, functionArn := "ID", "InvokedFunctionArn" - traceID := "Root=RootID;Parent=LambdaFrontend;Sampled=1" - invoke := &interop.Invoke{ - TraceID: traceID, - ID: requestID, - InvokedFunctionArn: functionArn, - CognitoIdentityID: "CognitoIdentityId1", - CognitoIdentityPoolID: "CognitoIdentityPoolId1", - ClientContext: "ClientContext", - DeadlineNs: fmt.Sprintf("%d", deadlineNs), - ContentType: "image/png", - Payload: strings.NewReader("Payload"), - } - - renderingService := rendering.NewRenderingService() - var buf bytes.Buffer - renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, &buf, telemetry.NewNoOpTracer().BuildTracingHeader())) - - handler := NewAgentNextHandler(registrationService, renderingService) - request := httptest.NewRequest("GET", "/", nil) - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - var response model.AgentInvokeEvent - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &response) - - assert.Equal(t, agent.RunningState, agent.GetState()) - assert.Equal(t, "INVOKE", response.AgentEvent.EventType) - assert.InDelta(t, time.Now().Add(100*time.Millisecond).UnixNano()/int64(time.Millisecond), response.AgentEvent.DeadlineMs, 5) - assert.Equal(t, requestID, response.RequestID) - assert.Equal(t, functionArn, response.InvokedFunctionArn) - assert.Equal(t, model.XRayTracingType, response.Tracing.Type) - assert.Equal(t, traceID, response.Tracing.Value) -} - -func TestRenderAgentInternalShutdownEvent(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - agent, err := registrationService.CreateInternalAgent("dummyName") - assert.NoError(t, err) - - agent.SetState(agent.RegisteredState) - agent.Release() - - renderingService := rendering.NewRenderingService() - deadlineMs := time.Now().UnixNano() / (1000 * 1000) - shutdownReason := "spindown" - renderingService.SetRenderer( - &rendering.ShutdownRenderer{ - AgentEvent: model.AgentShutdownEvent{ - AgentEvent: &model.AgentEvent{ - EventType: "SHUTDOWN", - DeadlineMs: int64(deadlineMs), - }, - ShutdownReason: shutdownReason, - }, - }) - - handler := NewAgentNextHandler(registrationService, renderingService) - request := httptest.NewRequest("GET", "/", nil) - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - var response model.AgentShutdownEvent - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &response) - - assert.Equal(t, agent.RunningState, agent.GetState()) - assert.Equal(t, "SHUTDOWN", response.AgentEvent.EventType) - assert.Equal(t, int64(deadlineMs), response.AgentEvent.DeadlineMs) - assert.Equal(t, shutdownReason, response.ShutdownReason) -} - -func TestRenderAgentExternalShutdownEvent(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - agent, err := registrationService.CreateExternalAgent("dummyName") - assert.NoError(t, err) - - agent.SetState(agent.RegisteredState) - agent.Release() - - renderingService := rendering.NewRenderingService() - deadlineMs := time.Now().UnixNano() / (1000 * 1000) - shutdownReason := "spindown" - renderingService.SetRenderer( - &rendering.ShutdownRenderer{ - AgentEvent: model.AgentShutdownEvent{ - AgentEvent: &model.AgentEvent{ - EventType: "SHUTDOWN", - DeadlineMs: int64(deadlineMs), - }, - ShutdownReason: shutdownReason, - }, - }) - - handler := NewAgentNextHandler(registrationService, renderingService) - request := httptest.NewRequest("GET", "/", nil) - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - var response model.AgentShutdownEvent - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &response) - - assert.Equal(t, agent.RunningState, agent.GetState()) - assert.Equal(t, "SHUTDOWN", response.AgentEvent.EventType) - assert.Equal(t, int64(deadlineMs), response.AgentEvent.DeadlineMs) - assert.Equal(t, shutdownReason, response.ShutdownReason) -} - -func TestRenderAgentInvokeNextHappyEmptyTraceID(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - agent, err := registrationService.CreateExternalAgent("dummyName") - assert.NoError(t, err) - - agent.SetState(agent.RegisteredState) - agent.Release() // sets operator condition to true so that the thread doesn't suspend waiting for invoke request - - deadlineNs := metering.Monotime() + int64(100*time.Millisecond) - requestID, functionArn := "ID", "InvokedFunctionArn" - traceID := "" - invoke := &interop.Invoke{ - TraceID: traceID, - ID: requestID, - InvokedFunctionArn: functionArn, - DeadlineNs: fmt.Sprintf("%d", deadlineNs), - ContentType: "image/png", - Payload: strings.NewReader("Payload"), - } - - renderingService := rendering.NewRenderingService() - var buf bytes.Buffer - renderingService.SetRenderer(rendering.NewInvokeRenderer(context.Background(), invoke, &buf, telemetry.NewNoOpTracer().BuildTracingHeader())) - - handler := NewAgentNextHandler(registrationService, renderingService) - request := httptest.NewRequest("GET", "/", nil) - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - var response model.AgentInvokeEvent - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &response) - - assert.Nil(t, response.Tracing) -} diff --git a/lambda/rapi/handler/agentregister.go b/lambda/rapi/handler/agentregister.go deleted file mode 100644 index 867ad9d..0000000 --- a/lambda/rapi/handler/agentregister.go +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "encoding/json" - "errors" - "io" - "net/http" - "strings" - - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/rapi/model" - "go.amzn.com/lambda/rapi/rendering" -) - -type agentRegisterHandler struct { - registrationService core.RegistrationService -} - -// RegisterRequest represent /extension/register JSON body -type RegisterRequest struct { - Events []core.Event `json:"events"` -} - -const featuresHeader = "Lambda-Extension-Accept-Feature" - -type registrationFeature int - -const ( - accountFeature registrationFeature = iota + 1 -) - -var allowedFeatures = map[string]registrationFeature{ - "accountId": accountFeature, -} - -type responseModifier func(*model.ExtensionRegisterResponse) - -func parseRegister(request *http.Request) (*RegisterRequest, error) { - body, err := io.ReadAll(request.Body) - if err != nil { - return nil, err - } - - req := struct { - RegisterRequest - ConfigurationKeys []string `json:"configurationKeys"` - }{} - - if err := json.Unmarshal(body, &req); err != nil { - return nil, err - } - - if len(req.ConfigurationKeys) != 0 { - return nil, errors.New("configurationKeys are deprecated; use environment variables instead") - } - - return &req.RegisterRequest, nil -} - -func (h *agentRegisterHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - agentName := request.Header.Get(LambdaAgentName) - if agentName == "" { - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentNameInvalid, "Empty extension name") - return - } - - var responseModifiers []responseModifier - for _, f := range parseRegistrationFeatures(request) { - if f == accountFeature { - responseModifiers = append(responseModifiers, h.respondWithAccountID()) - } - } - - registerRequest, err := parseRegister(request) - if err != nil { - rendering.RenderForbiddenWithTypeMsg(writer, request, errInvalidRequestFormat, "%s", err.Error()) - return - } - - agent, found := h.registrationService.FindExternalAgentByName(agentName) - if found { - h.registerExternalAgent(agent, registerRequest, writer, request, responseModifiers...) - } else { - h.registerInternalAgent(agentName, registerRequest, writer, request, responseModifiers...) - } -} - -func (h *agentRegisterHandler) respondWithAccountID() responseModifier { - return func(resp *model.ExtensionRegisterResponse) { - resp.AccountID = h.registrationService.GetFunctionMetadata().AccountID - } -} - -func parseRegistrationFeatures(request *http.Request) []registrationFeature { - rawFeatures := strings.Split(request.Header.Get(featuresHeader), ",") - - var features []registrationFeature - for _, feature := range rawFeatures { - feature = strings.TrimSpace(feature) - if v, found := allowedFeatures[feature]; found { - features = append(features, v) - } - } - - return features -} - -func (h *agentRegisterHandler) renderResponse( - agentID string, - writer http.ResponseWriter, - request *http.Request, - respModifiers ...responseModifier, -) { - writer.Header().Set(LambdaAgentIdentifier, agentID) - - metadata := h.registrationService.GetFunctionMetadata() - resp := &model.ExtensionRegisterResponse{ - FunctionVersion: metadata.FunctionVersion, - FunctionName: metadata.FunctionName, - Handler: metadata.Handler, - } - - for _, mod := range respModifiers { - mod(resp) - } - - if err := rendering.RenderJSON(http.StatusOK, writer, request, resp); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(writer, err.Error(), http.StatusInternalServerError) - } -} - -func (h *agentRegisterHandler) registerExternalAgent( - agent *core.ExternalAgent, - registerRequest *RegisterRequest, - writer http.ResponseWriter, - request *http.Request, - respModifiers ...responseModifier, -) { - for _, e := range registerRequest.Events { - if err := core.ValidateExternalAgentEvent(e); err != nil { - log.Warnf("Failed to register %s: event %s: %s", agent.Name, e, err) - rendering.RenderForbiddenWithTypeMsg(writer, request, errInvalidEventType, "%s: %s", e, err) - return - } - } - - if err := agent.Register(registerRequest.Events); err != nil { - log.Warnf("Failed to register %s: %s", agent.String(), err) - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, - agent.GetState().Name(), core.AgentRegisteredStateName, agent.Name, err) - return - } - - h.renderResponse(agent.ID.String(), writer, request, respModifiers...) - log.Infof("External agent %s registered, subscribed to %v", agent.String(), registerRequest.Events) -} - -func (h *agentRegisterHandler) registerInternalAgent( - agentName string, - registerRequest *RegisterRequest, - writer http.ResponseWriter, - request *http.Request, - respModifiers ...responseModifier, -) { - for _, e := range registerRequest.Events { - if err := core.ValidateInternalAgentEvent(e); err != nil { - log.Warnf("Failed to register %s: event %s: %s", agentName, e, err) - rendering.RenderForbiddenWithTypeMsg(writer, request, errInvalidEventType, "%s: %s", e, err) - return - } - } - - agent, err := h.registrationService.CreateInternalAgent(agentName) - if err != nil { - log.Warnf("Failed to create internal agent %s: %s", agentName, err) - - switch err { - case core.ErrRegistrationServiceOff: - rendering.RenderForbiddenWithTypeMsg(writer, request, - errAgentRegistrationClosed, "Extension registration closed already") - case core.ErrAgentNameCollision: - rendering.RenderForbiddenWithTypeMsg(writer, request, - errAgentInvalidState, "Extension with this name already registered") - case core.ErrTooManyExtensions: - rendering.RenderForbiddenWithTypeMsg(writer, request, - errTooManyExtensions, "Extension limit (%d) reached", core.MaxAgentsAllowed) - default: - rendering.RenderInternalServerError(writer, request) - } - - return - } - - if err := agent.Register(registerRequest.Events); err != nil { - log.Warnf("Failed to register %s: %s", agent.String(), err) - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentInvalidState, StateTransitionFailedForExtensionMessageFormat, - agent.GetState().Name(), core.AgentRegisteredStateName, agent.Name, err) - return - } - - h.renderResponse(agent.ID.String(), writer, request, respModifiers...) - log.Infof("Internal agent %s registered, subscribed to %v", agent.String(), registerRequest.Events) -} - -// NewAgentRegisterHandler returns a new instance of http handler for serving /extension/register -func NewAgentRegisterHandler(registrationService core.RegistrationService) http.Handler { - return &agentRegisterHandler{ - registrationService: registrationService, - } -} diff --git a/lambda/rapi/handler/agentregister_test.go b/lambda/rapi/handler/agentregister_test.go deleted file mode 100644 index 7370c42..0000000 --- a/lambda/rapi/handler/agentregister_test.go +++ /dev/null @@ -1,397 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "bytes" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/stretchr/testify/require" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/rapi/model" -) - -func registerRequestReader(req RegisterRequest) io.Reader { - body, err := json.Marshal(req) - if err != nil { - panic(err) - } - return bytes.NewReader(body) -} - -func TestRenderAgentRegisterInvalidAgentName(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - - handler := NewAgentRegisterHandler(registrationService) - request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(RegisterRequest{})) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - require.Equal(t, http.StatusForbidden, responseRecorder.Code) - - var errorResponse model.ErrorResponse - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &errorResponse) - require.Equal(t, http.StatusForbidden, responseRecorder.Code) - require.Equal(t, errAgentNameInvalid, errorResponse.ErrorType) -} - -func TestRenderAgentRegisterRegistrationClosed(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - registrationService.TurnOff() - - handler := NewAgentRegisterHandler(registrationService) - request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(RegisterRequest{})) - request.Header.Add(LambdaAgentName, "dummyName") - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - require.Equal(t, http.StatusForbidden, responseRecorder.Code) - - var errorResponse model.ErrorResponse - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &errorResponse) - require.Equal(t, http.StatusForbidden, responseRecorder.Code) - require.Equal(t, errAgentRegistrationClosed, errorResponse.ErrorType) -} - -func TestRenderAgentRegisterInvalidAgentState(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - - agent, err := registrationService.CreateExternalAgent("dummyName") - require.NoError(t, err) - agent.SetState(agent.RegisteredState) - - handler := NewAgentRegisterHandler(registrationService) - request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(RegisterRequest{})) - request.Header.Add(LambdaAgentName, "dummyName") - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - require.Equal(t, http.StatusForbidden, responseRecorder.Code) - - var errorResponse model.ErrorResponse - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &errorResponse) - require.Equal(t, http.StatusForbidden, responseRecorder.Code) - require.Equal(t, errAgentInvalidState, errorResponse.ErrorType) -} - -func registerAgent(t *testing.T, agentName string, events []core.Event, registerHandler http.Handler) { - request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(RegisterRequest{Events: events})) - request.Header.Add(LambdaAgentName, agentName) - responseRecorder := httptest.NewRecorder() - registerHandler.ServeHTTP(responseRecorder, request) - require.Equal(t, http.StatusOK, responseRecorder.Code) -} - -func TestGetSubscribedExternalAgents(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - registrationService.CreateExternalAgent("externalInvokeAgent") - registrationService.CreateExternalAgent("externalShutdownAgent") - - handler := NewAgentRegisterHandler(registrationService) - - registerAgent(t, "externalInvokeAgent", []core.Event{core.InvokeEvent}, handler) - registerAgent(t, "externalShutdownAgent", []core.Event{core.ShutdownEvent}, handler) - registerAgent(t, "internalInvokeAgent", []core.Event{core.InvokeEvent}, handler) - - subscribers := registrationService.GetSubscribedExternalAgents(core.InvokeEvent) - require.Equal(t, 1, len(subscribers)) - require.Equal(t, "externalInvokeAgent", subscribers[0].Name) -} - -func TestInternalAgentShutdownSubscription(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - - request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(RegisterRequest{Events: []core.Event{core.ShutdownEvent}})) - agentName := "internalShutdownAgent" - request.Header.Add(LambdaAgentName, agentName) - - responseRecorder := httptest.NewRecorder() - NewAgentRegisterHandler(registrationService).ServeHTTP(responseRecorder, request) - require.Equal(t, http.StatusForbidden, responseRecorder.Code) - - response := model.ErrorResponse{} - json.Unmarshal(responseRecorder.Body.Bytes(), &response) - require.Equal(t, errInvalidEventType, response.ErrorType) - require.Contains(t, response.ErrorMessage, string(core.ShutdownEvent)) - - _, found := registrationService.FindInternalAgentByName(agentName) - require.False(t, found) - - require.Equal(t, 0, registrationService.CountAgents()) -} - -func TestInternalAgentInvalidEventType(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - - for i := 0; i < 2; i++ { // make the request twice to make sure invalid /register request doesn't change agent state - request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(RegisterRequest{Events: []core.Event{"abcdef"}})) - agentName := "internalShutdownAgent" - request.Header.Add(LambdaAgentName, agentName) - - responseRecorder := httptest.NewRecorder() - NewAgentRegisterHandler(registrationService).ServeHTTP(responseRecorder, request) - require.Equal(t, http.StatusForbidden, responseRecorder.Code) - - response := model.ErrorResponse{} - json.Unmarshal(responseRecorder.Body.Bytes(), &response) - require.Equal(t, errInvalidEventType, response.ErrorType) - require.Contains(t, response.ErrorMessage, "abcdef") - - _, found := registrationService.FindInternalAgentByName(agentName) - require.False(t, found) - - require.Equal(t, 0, registrationService.CountAgents()) - } -} - -func TestExternalAgentInvalidEventType(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - agentName := "ABC" - registrationService.CreateExternalAgent(agentName) - - for i := 0; i < 2; i++ { // make the request twice to make sure invalid /register request doesn't change agent state - request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(RegisterRequest{Events: []core.Event{"abcdef"}})) - request.Header.Add(LambdaAgentName, agentName) - - responseRecorder := httptest.NewRecorder() - NewAgentRegisterHandler(registrationService).ServeHTTP(responseRecorder, request) - require.Equal(t, http.StatusForbidden, responseRecorder.Code) - - response := model.ErrorResponse{} - json.Unmarshal(responseRecorder.Body.Bytes(), &response) - require.Equal(t, errInvalidEventType, response.ErrorType) - require.Contains(t, response.ErrorMessage, "abcdef") - - _, found := registrationService.FindExternalAgentByName(agentName) - require.True(t, found) - - shutdownSubscribers := registrationService.GetSubscribedExternalAgents(core.ShutdownEvent) - require.Equal(t, 0, len(shutdownSubscribers)) - - invokeSubscribers := registrationService.GetSubscribedExternalAgents(core.InvokeEvent) - require.Equal(t, 0, len(invokeSubscribers)) - - require.Equal(t, 1, registrationService.CountAgents()) - } -} - -func TestGetSubscribedInternalAgents(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - registrationService.CreateExternalAgent("externalInvokeAgent") - registrationService.CreateExternalAgent("externalShutdownAgent") - - handler := NewAgentRegisterHandler(registrationService) - - registerAgent(t, "externalInvokeAgent", []core.Event{core.InvokeEvent}, handler) - registerAgent(t, "externalShutdownAgent", []core.Event{core.ShutdownEvent}, handler) - registerAgent(t, "internalInvokeAgent", []core.Event{core.InvokeEvent}, handler) - - subscribers := registrationService.GetSubscribedInternalAgents(core.InvokeEvent) - require.Equal(t, 1, len(subscribers)) - require.Equal(t, "internalInvokeAgent", subscribers[0].Name) -} - -type ExtensionRegisterResponseWithConfig struct { - model.ExtensionRegisterResponse - Configuration map[string]string `json:"configuration"` -} - -func TestRenderAgentResponse(t *testing.T) { - defaultFunctionMetadata := core.FunctionMetadata{ - FunctionVersion: "$LATEST", - FunctionName: "my-func", - Handler: "lambda_handler", - } - - happyPathTests := map[string]struct { - agentName string - external bool - registrationRequest RegisterRequest - featuresHeader string - functionMetadata core.FunctionMetadata - expectedResponse string - }{ - "no-config-internal": { - agentName: "internal", - external: false, - functionMetadata: defaultFunctionMetadata, - registrationRequest: RegisterRequest{}, - expectedResponse: `{ - "functionName": "my-func", - "functionVersion": "$LATEST", - "handler": "lambda_handler" - }`, - }, - "no-config-external": { - agentName: "external", - external: true, - functionMetadata: defaultFunctionMetadata, - registrationRequest: RegisterRequest{}, - expectedResponse: `{ - "functionName": "my-func", - "functionVersion": "$LATEST", - "handler": "lambda_handler" - }`, - }, - "function-md-override": { - agentName: "external", - external: true, - functionMetadata: core.FunctionMetadata{FunctionName: "function-name", FunctionVersion: "1", Handler: "myHandler"}, - registrationRequest: RegisterRequest{}, - expectedResponse: `{ - "functionName": "function-name", - "functionVersion": "1", - "handler": "myHandler" - }`, - }, - "internal with account id feature": { - agentName: "internal", - external: false, - functionMetadata: core.FunctionMetadata{ - FunctionName: "function-name", - FunctionVersion: "1", - Handler: "myHandler", - AccountID: "0123", - }, - featuresHeader: "accountId", - registrationRequest: RegisterRequest{}, - expectedResponse: `{ - "functionName": "function-name", - "functionVersion": "1", - "handler": "myHandler", - "accountId": "0123" - }`, - }, - "external with account id feature": { - agentName: "external", - external: true, - functionMetadata: core.FunctionMetadata{ - FunctionName: "function-name", - FunctionVersion: "1", - Handler: "myHandler", - AccountID: "0123", - }, - featuresHeader: "accountId", - registrationRequest: RegisterRequest{}, - expectedResponse: `{ - "functionName": "function-name", - "functionVersion": "1", - "handler": "myHandler", - "accountId": "0123" - }`, - }, - "with non-existing accept feature": { - agentName: "external", - external: true, - featuresHeader: "some_non_existing_feature,", - functionMetadata: defaultFunctionMetadata, - registrationRequest: RegisterRequest{}, - expectedResponse: `{ - "functionName": "my-func", - "functionVersion": "$LATEST", - "handler": "lambda_handler" - }`, - }, - "account id feature and some non-existing feature": { - agentName: "external", - external: true, - featuresHeader: "some_non_existing_feature,accountId,", - functionMetadata: core.FunctionMetadata{ - FunctionName: "function-name", - FunctionVersion: "1", - Handler: "myHandler", - AccountID: "0123", - }, - registrationRequest: RegisterRequest{}, - expectedResponse: `{ - "functionName": "function-name", - "functionVersion": "1", - "handler": "myHandler", - "accountId": "0123" - }`, - }, - "with empty account id data": { - agentName: "external", - external: true, - featuresHeader: "accountId", - functionMetadata: defaultFunctionMetadata, - registrationRequest: RegisterRequest{}, - expectedResponse: `{ - "functionName": "my-func", - "functionVersion": "$LATEST", - "handler": "lambda_handler" - }`, - }, - } - - for name, tt := range happyPathTests { - t.Run(name, func(t *testing.T) { - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - registrationService.CreateExternalAgent("external") // external agent has to be pre-registered - registrationService.SetFunctionMetadata(tt.functionMetadata) - - handler := NewAgentRegisterHandler(registrationService) - - request := httptest.NewRequest("POST", "/extension/register", registerRequestReader(tt.registrationRequest)) - request.Header.Add(LambdaAgentName, tt.agentName) - if tt.featuresHeader != "" { - request.Header.Add(featuresHeader, tt.featuresHeader) - } - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - - respBody, err := io.ReadAll(responseRecorder.Body) - require.NoError(t, err) - assert.JSONEq(t, tt.expectedResponse, string(respBody)) - - if tt.external { - agent, found := registrationService.FindExternalAgentByName(tt.agentName) - assert.True(t, found) - assert.Equal(t, agent.RegisteredState, agent.GetState()) - } else { - agent, found := registrationService.FindInternalAgentByName(tt.agentName) - assert.True(t, found) - assert.Equal(t, agent.RegisteredState, agent.GetState()) - } - }) - } -} diff --git a/lambda/rapi/handler/constants.go b/lambda/rapi/handler/constants.go deleted file mode 100644 index 5912d71..0000000 --- a/lambda/rapi/handler/constants.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -const ( - // LambdaAgentIdentifier is the header key for passing agent's id - LambdaAgentIdentifier string = "Lambda-Extension-Identifier" - LambdaAgentFunctionErrorType string = "Lambda-Extension-Function-Error-Type" - // LambdaAgentName is agent name, provided by user (internal agents) or equal to executable basename (external agents) - LambdaAgentName string = "Lambda-Extension-Name" - - ErrAgentIdentifierMissing string = "Extension.MissingExtensionIdentifier" - ErrAgentIdentifierInvalid string = "Extension.InvalidExtensionIdentifier" - - errAgentNameInvalid string = "Extension.InvalidExtensionName" - errAgentRegistrationClosed string = "Extension.RegistrationClosed" - errAgentIdentifierUnknown string = "Extension.UnknownExtensionIdentifier" - errAgentInvalidState string = "Extension.InvalidExtensionState" - errAgentMissingHeader string = "Extension.MissingHeader" - errTooManyExtensions string = "Extension.TooManyExtensions" - errInvalidEventType string = "Extension.InvalidEventType" - errInvalidRequestFormat string = "InvalidRequestFormat" - - StateTransitionFailedForExtensionMessageFormat string = "State transition from %s to %s failed for extension %s. Error: %s" - StateTransitionFailedForRuntimeMessageFormat string = "State transition from %s to %s failed for runtime. Error: %s" -) diff --git a/lambda/rapi/handler/credentials.go b/lambda/rapi/handler/credentials.go deleted file mode 100644 index f1536c4..0000000 --- a/lambda/rapi/handler/credentials.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "encoding/json" - "fmt" - "net/http" - - log "github.com/sirupsen/logrus" - - "go.amzn.com/lambda/core" -) - -type credentialsHandler struct { - credentialsService core.CredentialsService -} - -func (h *credentialsHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - token := request.Header.Get("Authorization") - - credentials, err := h.credentialsService.GetCredentials(token) - - if err != nil { - errorMsg := "cannot get credentials for the provided token" - log.WithError(err).Error(errorMsg) - http.Error(writer, errorMsg, http.StatusNotFound) - return - } - - jsonResponse, _ := json.Marshal(*credentials) - fmt.Fprint(writer, string(jsonResponse)) -} - -func NewCredentialsHandler(credentialsService core.CredentialsService) http.Handler { - return &credentialsHandler{ - credentialsService: credentialsService, - } -} diff --git a/lambda/rapi/handler/credentials_test.go b/lambda/rapi/handler/credentials_test.go deleted file mode 100644 index d5a1090..0000000 --- a/lambda/rapi/handler/credentials_test.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "encoding/json" - "log" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/testdata" -) - -const InitCachingToken = "sampleInitCachingToken" -const InitCachingAwsKey = "sampleAwsKey" -const InitCachingAwsSecret = "sampleAwsSecret" -const InitCachingAwsSessionToken = "sampleAwsSessionToken" - -func getRequestContext() (http.Handler, *http.Request, *httptest.ResponseRecorder) { - flowTest := testdata.NewFlowTest() - - flowTest.ConfigureForInitCaching(InitCachingToken, InitCachingAwsKey, InitCachingAwsSecret, InitCachingAwsSessionToken) - - handler := NewCredentialsHandler(flowTest.CredentialsService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) - - return handler, request, responseRecorder -} - -func TestEmptyAuthorizationHeader(t *testing.T) { - handler, request, responseRecorder := getRequestContext() - - handler.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusNotFound, responseRecorder.Code) -} - -func TestArbitraryAuthorizationHeader(t *testing.T) { - handler, request, responseRecorder := getRequestContext() - request.Header.Set("Authorization", "randomAuthToken") - - handler.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusNotFound, responseRecorder.Code) -} - -func TestSuccessfulGet(t *testing.T) { - handler, request, responseRecorder := getRequestContext() - request.Header.Set("Authorization", InitCachingToken) - - handler.ServeHTTP(responseRecorder, request) - - var responseMap map[string]string - json.Unmarshal(responseRecorder.Body.Bytes(), &responseMap) - assert.Equal(t, InitCachingAwsKey, responseMap["AccessKeyId"]) - assert.Equal(t, InitCachingAwsSecret, responseMap["SecretAccessKey"]) - assert.Equal(t, InitCachingAwsSessionToken, responseMap["Token"]) - - expirationTime, err := time.Parse(time.RFC3339, responseMap["Expiration"]) - assert.NoError(t, err) - durationUntilExpiration := time.Until(expirationTime) - assert.True(t, durationUntilExpiration.Minutes() <= 30 && durationUntilExpiration.Minutes() > 29 && durationUntilExpiration.Hours() < 1) - log.Println(responseRecorder.Body.String()) -} diff --git a/lambda/rapi/handler/initerror.go b/lambda/rapi/handler/initerror.go deleted file mode 100644 index 79daa1f..0000000 --- a/lambda/rapi/handler/initerror.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "encoding/json" - "io" - "net/http" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/rapi/rendering" - - log "github.com/sirupsen/logrus" -) - -type initErrorHandler struct { - registrationService core.RegistrationService -} - -func (h *initErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - appCtx := appctx.FromRequest(request) - interopServer := appctx.LoadInteropServer(appCtx) - if interopServer == nil { - log.Panic("Invalid state, cannot access interop server") - } - - errorType := fatalerror.GetValidRuntimeOrFunctionErrorType(request.Header.Get("Lambda-Runtime-Function-Error-Type")) - fnError := interop.FunctionError{Type: errorType} - errorBody, err := io.ReadAll(request.Body) - if err != nil { - log.WithError(err).Warn("Failed to read error body") - } - headers := interop.InvokeResponseHeaders{ContentType: determineJSONContentType(errorBody)} - response := &interop.ErrorInvokeResponse{Headers: headers, FunctionError: fnError, Payload: errorBody} - - runtime := h.registrationService.GetRuntime() - - // remove once Languages team change the endpoint to /restore/error - // when an exception is throw while executing the restore hooks - if runtime.GetState() == runtime.RuntimeRestoringState { - if err := runtime.RestoreError(fnError); err != nil { - log.Warn(err) - rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, - runtime.GetState().Name(), core.RuntimeRestoreErrorStateName, err) - return - } - - appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{}) - rendering.RenderAccepted(writer, request) - return - } - - if err := runtime.InitError(); err != nil { - log.Warn(err) - rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, - runtime.GetState().Name(), core.RuntimeInitErrorStateName, err) - return - } - - if err := interopServer.SendInitErrorResponse(response); err != nil { - rendering.RenderInteropError(writer, request, err) - return - } - - appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{}) - rendering.RenderAccepted(writer, request) -} - -// NewInitErrorHandler returns a new instance of http handler -// for serving /runtime/init/error. -func NewInitErrorHandler(registrationService core.RegistrationService) http.Handler { - return &initErrorHandler{registrationService: registrationService} -} - -func determineJSONContentType(body []byte) string { - if json.Valid(body) { - return "application/json" - } - return "application/octet-stream" -} diff --git a/lambda/rapi/handler/initerror_test.go b/lambda/rapi/handler/initerror_test.go deleted file mode 100644 index a9c4b94..0000000 --- a/lambda/rapi/handler/initerror_test.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "bytes" - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/require" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/testdata" -) - -// TestInitErrorHandler tests that API handler for -// initialization-time errors receives and passes -// information through to the Slicer unmodified. -func TestInitErrorHandler(t *testing.T) { - t.Run("GA", func(t *testing.T) { runTestInitErrorHandler(t) }) -} - -func runTestInitErrorHandler(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - - handler := NewInitErrorHandler(flowTest.RegistrationService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - // Error request, as submitted by custom runtime. - errorBody := []byte("My byte array is yours") - errorType := "ErrorType" - errorContentType := "application/MyBinaryType" - - request := appctx.RequestWithAppCtx(httptest.NewRequest("POST", "/", bytes.NewReader(errorBody)), appCtx) - request.Header.Set("Content-Type", errorContentType) - request.Header.Set("Lambda-runtime-functioN-erroR-typE", errorType) // Headers are case-insensitive anyway ! - - // Submit ! - handler.ServeHTTP(responseRecorder, request) - - // Assertions - - // Validate response sent to the runtime. - require.Equal(t, http.StatusAccepted, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", - responseRecorder.Code, http.StatusAccepted) - require.JSONEq(t, fmt.Sprintf("{\"status\":\"%s\"}\n", "OK"), responseRecorder.Body.String()) - require.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) - - // Validate init error persisted in the application context. - errorResponse := flowTest.InteropServer.ErrorResponse - require.NotNil(t, errorResponse) - require.Nil(t, flowTest.InteropServer.Response) - - // Slicer falls back to using ErrorMessage when error - // payload is not provided. This fallback is not part - // of the RAPID API spec and is not available to - // customers. - require.Equal(t, "", errorResponse.FunctionError.Message) - - // Slicer falls back to using ErrorType when error - // payload is not provided. Customers can set error - // type via header to use this fallback. - require.Equal(t, fatalerror.RuntimeUnknown, errorResponse.FunctionError.Type) - - // Payload is arbitrary data that customers submit - it's error response body. - require.Equal(t, errorBody, errorResponse.Payload) -} diff --git a/lambda/rapi/handler/invocationerror.go b/lambda/rapi/handler/invocationerror.go deleted file mode 100644 index d434461..0000000 --- a/lambda/rapi/handler/invocationerror.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi/model" - - "go.amzn.com/lambda/appctx" - - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/rapi/rendering" - - "github.com/go-chi/chi" - log "github.com/sirupsen/logrus" -) - -const errorWithCauseContentType = "application/vnd.aws.lambda.error.cause+json" -const xrayErrorCauseHeaderName = "Lambda-Runtime-Function-XRay-Error-Cause" -const invalidErrorBodyMessage = "Invalid error body" - -const ( - contentTypeHeader = "Content-Type" - functionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" -) - -type invocationErrorHandler struct { - registrationService core.RegistrationService -} - -func (h *invocationErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - appCtx := appctx.FromRequest(request) - - server := appctx.LoadResponseSender(appCtx) - if server == nil { - log.Panic("Invalid state, cannot access interop server") - } - - runtime := h.registrationService.GetRuntime() - if err := runtime.InvocationErrorResponse(); err != nil { - log.Warn(err) - rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, - runtime.GetState().Name(), core.RuntimeInvocationErrorResponseStateName, err) - return - } - - errorType := fatalerror.GetValidRuntimeOrFunctionErrorType(h.getErrorType(request.Header)) - - var errorCause json.RawMessage - var errorBody []byte - var contentType string - var err error - - switch request.Header.Get(contentTypeHeader) { - case errorWithCauseContentType: - errorBody, errorCause, err = h.getErrorBodyForErrorCauseContentType(request) - contentType = "application/json" - if err != nil { - contentType = "application/octet-stream" - } - default: - errorBody, err = h.getErrorBody(request) - errorCause = h.getValidatedErrorCause(request.Header) - contentType = request.Header.Get(contentTypeHeader) - } - functionResponseMode := request.Header.Get(functionResponseModeHeader) - - if err != nil { - log.WithError(err).Warn("Failed to parse error body") - } - - headers := interop.InvokeResponseHeaders{ - ContentType: contentType, - FunctionResponseMode: functionResponseMode, - } - - response := &interop.ErrorInvokeResponse{ - Headers: headers, - FunctionError: interop.FunctionError{Type: errorType}, - Payload: errorBody, - } - - if err := server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), response); err != nil { - rendering.RenderInteropError(writer, request, err) - return - } - - appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{ErrorCause: errorCause}) - - if err := runtime.ResponseSent(); err != nil { - log.Panic(err) - } - - rendering.RenderAccepted(writer, request) -} - -func (h *invocationErrorHandler) getErrorType(headers http.Header) string { - return headers.Get("Lambda-Runtime-Function-Error-Type") -} - -func (h *invocationErrorHandler) getErrorBody(request *http.Request) ([]byte, error) { - errorBody, err := io.ReadAll(request.Body) - if err != nil { - return nil, fmt.Errorf("error reading request body: %s", err) - } - return errorBody, nil -} - -func (h *invocationErrorHandler) getValidatedErrorCause(headers http.Header) json.RawMessage { - errorCauseHeader := headers.Get(xrayErrorCauseHeaderName) - if len(errorCauseHeader) == 0 { - return nil - } - - errorCauseJSON := json.RawMessage(errorCauseHeader) - - validErrorCauseJSON, err := model.ValidatedErrorCauseJSON(errorCauseJSON) - if err != nil { - log.WithError(err).Error("errorCause validation error") - return nil - } - - return validErrorCauseJSON -} - -func (h *invocationErrorHandler) getErrorBodyForErrorCauseContentType(request *http.Request) ([]byte, json.RawMessage, error) { - errorBody, err := io.ReadAll(request.Body) - if err != nil { - return nil, nil, fmt.Errorf("error reading request body: %s", err) - } - - parsedError, err := newErrorWithCauseRequest(errorBody) - if err != nil { - errResponse, _ := json.Marshal(invalidErrorBodyMessage) - return errResponse, nil, fmt.Errorf("error parsing request body: %s, request.Body: %s", err, errorBody) - } - - filteredError, err := parsedError.getInvokeErrorResponse() - - return filteredError, parsedError.getValidatedXRayCause(), err -} - -// NewInvocationErrorHandler returns a new instance of http handler -// for serving /runtime/invocation/{awsrequestid}/error. -func NewInvocationErrorHandler(registrationService core.RegistrationService) http.Handler { - return &invocationErrorHandler{ - registrationService: registrationService, - } -} diff --git a/lambda/rapi/handler/invocationerror_test.go b/lambda/rapi/handler/invocationerror_test.go deleted file mode 100644 index 72e6719..0000000 --- a/lambda/rapi/handler/invocationerror_test.go +++ /dev/null @@ -1,425 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi/model" - "go.amzn.com/lambda/testdata" - - "github.com/go-chi/chi" - "github.com/stretchr/testify/assert" -) - -// TestInvocationErrorHandler tests that API handler for -// invocation-time errors receives and passes information -// through to the Slicer unmodified. -func TestInvocationErrorHandler(t *testing.T) { - t.Run("GA", func(t *testing.T) { runTestInvocationErrorHandler(t) }) -} - -func addInvocationID(r *http.Request, invokeID string) *http.Request { - rctx := chi.NewRouteContext() - rctx.URLParams.Add("awsrequestid", invokeID) - return r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) -} - -func runTestInvocationErrorHandler(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - handler := NewInvocationErrorHandler(flowTest.RegistrationService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - // Invoke that we are sending response for must be placed into appCtx. - invoke := &interop.Invoke{ - ID: "InvocationID1", - InvokedFunctionArn: "arn::dummy1", - CognitoIdentityID: "CognitoidentityID1", - CognitoIdentityPoolID: "CognitoidentityPollID1", - DeadlineNs: "deadlinens1", - ClientContext: "clientcontext1", - ContentType: "image/png", - Payload: strings.NewReader("Payload1"), - } - - flowTest.ConfigureForInvoke(context.Background(), invoke) - - // Error request, as submitted by custom runtime. - errorBody := []byte("My byte array is yours") - errorType := "ErrorType" - errorContentType := "application/MyBinaryType" - - request := appctx.RequestWithAppCtx(httptest.NewRequest("POST", "/", bytes.NewReader(errorBody)), appCtx) - request = addInvocationID(request, invoke.ID) - - request.Header.Set("Content-Type", errorContentType) - request.Header.Set("Lambda-runtime-functioN-erroR-typE", errorType) // Headers are case-insensitive anyway ! - - // Submit ! - handler.ServeHTTP(responseRecorder, request) - - // Assertions - - // Validate response sent to the runtime. - assert.Equal(t, http.StatusAccepted, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", - responseRecorder.Code, http.StatusAccepted) - assert.JSONEq(t, fmt.Sprintf("{\"status\":\"%s\"}\n", "OK"), responseRecorder.Body.String()) - assert.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) - - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) - assert.Nil(t, flowTest.InteropServer.Response) - - // Slicer falls back to using ErrorMessage when error - // payload is not provided. This fallback is not part - // of the RAPID API spec and is not available to - // customers. - assert.Equal(t, "", errorResponse.FunctionError.Message) - - // Slicer falls back to using ErrorType when error - // payload is not provided. Customers can set error - // type header to use this fallback. - assert.Equal(t, fatalerror.RuntimeUnknown, errorResponse.FunctionError.Type) - - // Payload is arbitrary data that customers submit - it's error response body. - assert.Equal(t, errorBody, errorResponse.Payload) -} - -func TestInvocationErrorHandlerRemovesErrorCauseFromResponse(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - handler := NewInvocationErrorHandler(flowTest.RegistrationService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - invoke := &interop.Invoke{ID: "InvocationID1"} - - flowTest.ConfigureForInvoke(context.Background(), invoke) - - // Error request, as submitted by custom runtime. - errMsg, errType := "foo", "foo" - - errorCause := json.RawMessage(`{"paths":[],"working_directory":[],"exceptions":[]}`) - errorWithCause := errorWithCauseRequest{ - ErrorMessage: errMsg, - ErrorType: errType, - ErrorCause: errorCause, - } - - requestBody, err := json.Marshal(errorWithCause) - assert.NoError(t, err, "error while creating test request") - - errorContentType := errorWithCauseContentType - - request := appctx.RequestWithAppCtx(httptest.NewRequest("POST", "/", bytes.NewReader(requestBody)), appCtx) - request = addInvocationID(request, invoke.ID) - request.Header.Set("Content-Type", errorContentType) - - handler.ServeHTTP(responseRecorder, request) - - expectedResponsePayload := []byte(fmt.Sprintf(`{"errorMessage":"%s","errorType":"%s"}`, errMsg, errType)) - - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) - assert.Nil(t, flowTest.InteropServer.Response) - - // Payload is arbitrary data that customers submit - it's error response body. - assert.JSONEq(t, string(expectedResponsePayload), string(errorResponse.Payload)) -} - -////////////////////////////////////////////// -///// Tests for error.cause Content-Type ///// -////////////////////////////////////////////// - -func TestInvocationErrorHandlerSendsErrorCauseToXRayForContentTypeErrorCause(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - handler := NewInvocationErrorHandler(flowTest.RegistrationService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - errorCause := json.RawMessage(`{"paths":[],"working_directory":"/foo/bar/baz","exceptions":[]}`) - errorWithCause := errorWithCauseRequest{ - ErrorMessage: "foo", - ErrorType: "bar", - ErrorCause: errorCause, - } - - requestBody, err := json.Marshal(errorWithCause) - assert.NoError(t, err, "error while creating test request") - - invoke := &interop.Invoke{TraceID: "Root=TraceID;Parent=ParentID;Sampled=1", ID: "InvokeID"} - - request := httptest.NewRequest("POST", "/", bytes.NewReader(requestBody)) - request = addInvocationID(request, invoke.ID) - request.Header.Set("Content-Type", errorWithCauseContentType) - - // Corresponding invoke must be placed into appCtx. - flowTest.ConfigureForInvoke(context.Background(), invoke) - - // Run - handler.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) - - // Assert error response contains error cause - invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) - assert.NotNil(t, invokeErrorTraceData) - assert.Nil(t, flowTest.InteropServer.Response) - assert.JSONEq(t, string(errorCause), string(invokeErrorTraceData.ErrorCause)) -} - -func TestInvocationErrorHandlerSendsNullErrorCauseWhenErrorCauseFormatIsInvalidOrEmptyForContentTypeErrorCause(t *testing.T) { - causes := []json.RawMessage{ - json.RawMessage(`{"foobar":"baz"}`), - json.RawMessage(`""`), - } - for _, cause := range causes { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - appCtx := flowTest.AppCtx - - errorWithCause := errorWithCauseRequest{ - ErrorMessage: "foo", - ErrorType: "bar", - ErrorCause: json.RawMessage(cause), - } - - requestBody, err := json.Marshal(errorWithCause) - assert.NoError(t, err, "error while creating test request") - - invoke := &interop.Invoke{TraceID: "Root=TraceID;Parent=ParentID;Sampled=1", ID: "InvokeID"} - request := httptest.NewRequest("POST", "/", bytes.NewReader(requestBody)) - request = addInvocationID(request, invoke.ID) - request.Header.Set("Content-Type", errorWithCauseContentType) - - // Corresponding invoke must be placed into appCtx. - flowTest.ConfigureForInvoke(context.Background(), invoke) - - // Run - NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - - invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) - assert.NotNil(t, invokeErrorTraceData) - assert.Nil(t, flowTest.InteropServer.Response) - assert.Equal(t, json.RawMessage(nil), invokeErrorTraceData.ErrorCause) - } -} - -func TestInvocationErrorHandlerSendsCompactedErrorCauseWhenErrorCauseIsTooLargeForContentTypeErrorCause(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - appCtx := flowTest.AppCtx - - cause := json.RawMessage(`{"working_directory": "` + strings.Repeat(`a`, model.MaxErrorCauseSizeBytes+1) + `"}`) - - errorWithCause := errorWithCauseRequest{ - ErrorMessage: "foo", - ErrorType: "bar", - ErrorCause: json.RawMessage(cause), - } - - requestBody, err := json.Marshal(errorWithCause) - assert.NoError(t, err, "error while creating test request") - - invoke := &interop.Invoke{TraceID: "Root=TraceID;Parent=ParentID;Sampled=1", ID: "InvokeID"} - request := httptest.NewRequest("POST", "/", bytes.NewReader(requestBody)) - request = addInvocationID(request, invoke.ID) - request.Header.Set("Content-Type", errorWithCauseContentType) - - // Corresponding invoke must be placed into appCtx. - flowTest.ConfigureForInvoke(context.Background(), invoke) - - // Run - NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - - invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) - assert.NotNil(t, invokeErrorTraceData) - assert.Nil(t, flowTest.InteropServer.Response) - - errorCauseJSON, err := model.ValidatedErrorCauseJSON(invokeErrorTraceData.ErrorCause) - assert.NoError(t, err, "expected cause sent x-ray to be valid") - assert.True(t, len(errorCauseJSON) < model.MaxErrorCauseSizeBytes, "expected cause to be compacted to size") -} - -func TestInvocationResponsePayloadIsDefaultErrorMessageWhenRequestParsingFailsForContentTypeErrorCause(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - appCtx := flowTest.AppCtx - - invalidRequestBody := json.RawMessage(`{"foo":bar}`) - - invoke := &interop.Invoke{TraceID: "Root=TraceID;Parent=ParentID;Sampled=1", ID: "InvokeID"} - request := httptest.NewRequest("POST", "/", bytes.NewReader(invalidRequestBody)) - request = addInvocationID(request, invoke.ID) - request.Header.Set(contentTypeHeader, errorWithCauseContentType) - request.Header.Set(functionResponseModeHeader, "function-response-mode") - - // Corresponding invoke must be placed into appCtx. - flowTest.ConfigureForInvoke(context.Background(), invoke) - - // Run - NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - - invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) - assert.NotNil(t, invokeErrorTraceData) - assert.Nil(t, flowTest.InteropServer.Response) - assert.Equal(t, "application/octet-stream", flowTest.InteropServer.ResponseContentType) - assert.Equal(t, "function-response-mode", flowTest.InteropServer.FunctionResponseMode) - - errorResponse := flowTest.InteropServer.ErrorResponse - invokeResponsePayload := errorResponse.Payload - - expectedResponse, _ := json.Marshal(invalidErrorBodyMessage) - assert.Equal(t, invokeResponsePayload, expectedResponse) -} - -////////////////////////////////////////////// -///// Tests for X-Ray Error-Cause header ///// -////////////////////////////////////////////// - -func TestInvocationErrorHandlerSendsErrorCauseToXRayWhenXRayErrorCauseHeaderIsSet(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - appCtx := flowTest.AppCtx - - invoke := &interop.Invoke{TraceID: "Root=TraceID;Parent=ParentID;Sampled=1", ID: "InvokeID"} - request := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`foo doesn't matter`))) - request = addInvocationID(request, invoke.ID) - errorCause := json.RawMessage(`{"paths":[],"working_directory":"/foo/bar/baz","exceptions":[]}`) - request.Header.Set(xrayErrorCauseHeaderName, string(errorCause)) - - // Corresponding invoke must be placed into appCtx. - flowTest.ConfigureForInvoke(context.Background(), invoke) - - // Run - NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - - invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) - assert.NotNil(t, invokeErrorTraceData) - assert.Nil(t, flowTest.InteropServer.Response) - assert.JSONEq(t, string(errorCause), string(invokeErrorTraceData.ErrorCause)) -} - -func TestInvocationErrorHandlerSendsNilCauseToXRayWhenXRayErrorCauseHeaderContainsInvalidCause(t *testing.T) { - invalidCauses := []json.RawMessage{ - json.RawMessage(`{invalid:json}`), - json.RawMessage(``), - } - - for _, errorCause := range invalidCauses { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - appCtx := flowTest.AppCtx - - invoke := &interop.Invoke{TraceID: "Root=TraceID;Parent=ParentID;Sampled=1", ID: "InvokeID"} - request := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`foo doesn't matter`))) - request = addInvocationID(request, invoke.ID) - request.Header.Set(xrayErrorCauseHeaderName, string(errorCause)) - - // Corresponding invoke must be placed into appCtx. - flowTest.ConfigureForInvoke(context.Background(), invoke) - - // Run - NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - - invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) - assert.NotNil(t, invokeErrorTraceData) - assert.Nil(t, flowTest.InteropServer.Response) - assert.Equal(t, json.RawMessage(nil), invokeErrorTraceData.ErrorCause) - } -} - -func TestInvocationErrorHandlerSendsCompactedErrorCauseToXRayWhenXRayErrorCauseInHeaderIsTooLarge(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - appCtx := flowTest.AppCtx - - invoke := &interop.Invoke{TraceID: "Root=TraceID;Parent=ParentID;Sampled=1", ID: "InvokeID"} - request := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`foo doesn't matter`))) - request = addInvocationID(request, invoke.ID) - - errorCause := json.RawMessage(`{"working_directory": "` + strings.Repeat(`a`, model.MaxErrorCauseSizeBytes+1) + `"}`) - request.Header.Set(xrayErrorCauseHeaderName, string(errorCause)) - - // Corresponding invoke must be placed into appCtx. - flowTest.ConfigureForInvoke(context.Background(), invoke) - - // Run - NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - - invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) - assert.NotNil(t, invokeErrorTraceData) - assert.Nil(t, flowTest.InteropServer.Response) - - errorCauseJSON, err := model.ValidatedErrorCauseJSON(invokeErrorTraceData.ErrorCause) - assert.NoError(t, err, "expected cause sent x-ray to be valid") - assert.True(t, len(errorCauseJSON) < model.MaxErrorCauseSizeBytes, "expected cause to be compacted to size") -} - -func TestInvocationErrorHandlerSendsNilToXRayWhenXRayErrorCauseHeaderIsNotSet(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - appCtx := flowTest.AppCtx - - invoke := &interop.Invoke{TraceID: "Root=TraceID;Parent=ParentID;Sampled=1", ID: "InvokeID"} - request := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`foo doesn't matter`))) - request = addInvocationID(request, invoke.ID) - - // Corresponding invoke must be placed into appCtx. - flowTest.ConfigureForInvoke(context.Background(), invoke) - - // Run - NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - - invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) - assert.NotNil(t, invokeErrorTraceData) - assert.Nil(t, flowTest.InteropServer.Response) - assert.Nil(t, invokeErrorTraceData.ErrorCause) -} - -func TestInvocationErrorHandlerSendsErrorCauseToXRayWhenXRayErrorCauseContainsUTF8Characters(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - appCtx := flowTest.AppCtx - - invoke := &interop.Invoke{TraceID: "Root=TraceID;Parent=ParentID;Sampled=1", ID: "InvokeID"} - request := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`foo doesn't matter`))) - request = addInvocationID(request, invoke.ID) - - errorCause := json.RawMessage(`{"exceptions":[],"working_directory":"κόσμε","paths":[]}`) - request.Header.Set(xrayErrorCauseHeaderName, string(errorCause)) - - // Corresponding invoke must be placed into appCtx. - flowTest.ConfigureForInvoke(context.Background(), invoke) - - // Run - NewInvocationErrorHandler(flowTest.RegistrationService).ServeHTTP(httptest.NewRecorder(), appctx.RequestWithAppCtx(request, appCtx)) - - invokeErrorTraceData := appctx.LoadInvokeErrorTraceData(appCtx) - assert.NotNil(t, invokeErrorTraceData) - assert.Nil(t, flowTest.InteropServer.Response) - assert.JSONEq(t, string(errorCause), string(invokeErrorTraceData.ErrorCause)) -} diff --git a/lambda/rapi/handler/invocationnext.go b/lambda/rapi/handler/invocationnext.go deleted file mode 100644 index f8fc89c..0000000 --- a/lambda/rapi/handler/invocationnext.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "net/http" - - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/rapi/rendering" - - log "github.com/sirupsen/logrus" -) - -type invocationNextHandler struct { - registrationService core.RegistrationService - renderingService *rendering.EventRenderingService -} - -func (h *invocationNextHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - runtime := h.registrationService.GetRuntime() - err := runtime.Ready() - if err != nil { - log.Warn(err) - rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, - runtime.GetState().Name(), core.RuntimeReadyStateName, err) - return - } - err = h.renderingService.RenderRuntimeEvent(writer, request) - if err != nil { - log.Error(err) - rendering.RenderInternalServerError(writer, request) - return - } -} - -// NewInvocationNextHandler returns a new instance of http handler -// for serving /runtime/invocation/next. -func NewInvocationNextHandler(registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { - return &invocationNextHandler{ - registrationService: registrationService, - renderingService: renderingService, - } -} diff --git a/lambda/rapi/handler/invocationnext_test.go b/lambda/rapi/handler/invocationnext_test.go deleted file mode 100644 index 64ae057..0000000 --- a/lambda/rapi/handler/invocationnext_test.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "bytes" - "context" - "errors" - "fmt" - "log" - "math" - "net/http" - "net/http/httptest" - "os" - "runtime" - "strconv" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/telemetry" - "go.amzn.com/lambda/testdata" -) - -// TestRenderInvokeEmptyHeaders tests that headers -// are not rendered when not set. -func TestRenderInvokeEmptyHeaders(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - handler := NewInvocationNextHandler(flowTest.RegistrationService, flowTest.RenderingService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - flowTest.ConfigureForInvoke(context.Background(), &interop.Invoke{}) - request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) - handler.ServeHTTP(responseRecorder, request) - - headers := responseRecorder.Header() - assert.Equal(t, "application/json", headers.Get("Content-Type")) - assert.Len(t, headers, 1) - assert.Equal(t, http.StatusOK, responseRecorder.Code) -} - -func TestRenderInvokeHappy(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - appCtx := flowTest.AppCtx - - deadlineNs := 12345 - invoke := &interop.Invoke{ - TraceID: "Root=RootID;Parent=LambdaFrontend;Sampled=1", - ID: "", // updated in loop - InvokedFunctionArn: "InvokedFunctionArn", - CognitoIdentityID: "CognitoIdentityId1", - CognitoIdentityPoolID: "CognitoIdentityPoolId1", - ClientContext: "ClientContext", - DeadlineNs: strconv.Itoa(deadlineNs), - ContentType: "image/png", - Payload: strings.NewReader(""), // updated in loop - } - - ctx := telemetry.NewTraceContext(context.Background(), "RootID", "InvocationSubegmentID") - var requestBuffer bytes.Buffer - for i := 0; i < 6; i++ { - handler := NewInvocationNextHandler(flowTest.RegistrationService, flowTest.RenderingService) - responseRecorder := httptest.NewRecorder() - invoke.ID = fmt.Sprintf("ID-%d", i) - invokePayload := string(bytes.Repeat([]byte("a"), (i%3)*128*1024)) // vary payload size up and down across invokes - invoke.Payload = strings.NewReader(invokePayload) - - flowTest.ConfigureForInvoke(ctx, invoke) - flowTest.ConfigureInvokeRenderer(ctx, invoke, &requestBuffer) // reuse request buffer on each invoke - request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) - handler.ServeHTTP(responseRecorder, request) - - headers := responseRecorder.Header() - assert.Equal(t, invoke.InvokedFunctionArn, headers.Get("Lambda-Runtime-Invoked-Function-Arn")) - assert.Equal(t, invoke.ID, headers.Get("Lambda-Runtime-Aws-Request-Id")) - assert.Equal(t, invoke.ClientContext, headers.Get("Lambda-Runtime-Client-Context")) - expectedCognitoIdentityHeader := fmt.Sprintf("{\"cognitoIdentityId\":\"%s\",\"cognitoIdentityPoolId\":\"%s\"}", invoke.CognitoIdentityID, invoke.CognitoIdentityPoolID) - assert.JSONEq(t, expectedCognitoIdentityHeader, headers.Get("Lambda-Runtime-Cognito-Identity")) - assert.Equal(t, "Root=RootID;Parent=InvocationSubegmentID;Sampled=1", headers.Get("Lambda-Runtime-Trace-Id")) - - // Assert deadline precision. E.g. 1999 ns and 2001 ns having diff of 2 ns - // would result in 1ms and 2ms deadline correspondingly. - expectedDeadline := metering.MonoToEpoch(int64(deadlineNs)) / int64(time.Millisecond) - receivedDeadline, _ := strconv.ParseInt(headers.Get("Lambda-Runtime-Deadline-Ms"), 10, 64) - assert.True(t, math.Abs(float64(expectedDeadline-receivedDeadline)) <= float64(1), - fmt.Sprintf("Expected: %v, received: %v", expectedDeadline, receivedDeadline)) - - assert.Equal(t, "image/png", headers.Get("Content-Type")) - assert.Len(t, headers, 7) - responsePayload := responseRecorder.Body.String() - require.Equalf(t, len(invokePayload), len(responsePayload), "Unexpected payload for request %d", i) - assert.Equal(t, invokePayload, responsePayload) - } -} - -// Cgo calls removed due to crashes while spawning threads under memory pressure. -func TestRenderInvokeDoesNotCallCgo(t *testing.T) { - cgoCallsBefore := runtime.NumCgoCall() - TestRenderInvokeHappy(t) - cgoCallsAfter := runtime.NumCgoCall() - assert.Equal(t, cgoCallsBefore, cgoCallsAfter) -} - -func BenchmarkRenderInvoke(b *testing.B) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - handler := NewInvocationNextHandler(flowTest.RegistrationService, flowTest.RenderingService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - deadlineNs := 12345 - invoke := &interop.Invoke{ - TraceID: "Root=RootID;Parent=LambdaFrontend;Sampled=1", - ID: "ID", - InvokedFunctionArn: "InvokedFunctionArn", - CognitoIdentityID: "CognitoIdentityId1", - CognitoIdentityPoolID: "CognitoIdentityPoolId1", - ClientContext: "ClientContext", - DeadlineNs: strconv.Itoa(deadlineNs), - ContentType: "image/png", - Payload: strings.NewReader("Payload"), - } - - ctx := telemetry.NewTraceContext(context.Background(), "RootID", "InvocationSubegmentID") - flowTest.ConfigureForInvoke(ctx, invoke) - - request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) - - for i := 0; i < b.N; i++ { - handler.ServeHTTP(responseRecorder, request) - } -} - -type mockBrokenRenderer struct{} - -// RenderAgentEvent renders shutdown event for agent. -func (s *mockBrokenRenderer) RenderAgentEvent(w http.ResponseWriter, r *http.Request) error { - return errors.New("Broken") -} - -// RenderRuntimeEvent renders shutdown event for runtime. -func (s *mockBrokenRenderer) RenderRuntimeEvent(w http.ResponseWriter, r *http.Request) error { - return errors.New("Broken") -} - -func TestRender500AndExitOnInteropFailureDuringFirstInvoke(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - handler := NewInvocationNextHandler(flowTest.RegistrationService, flowTest.RenderingService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - flowTest.InvokeFlow.InitializeBarriers() - flowTest.RenderingService.SetRenderer(&mockBrokenRenderer{}) - - request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) - handler.ServeHTTP(responseRecorder, request) - - assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) - assert.JSONEq(t, `{"errorMessage":"Internal Server Error","errorType":"InternalServerError"}`, responseRecorder.Body.String()) -} - -func TestMain(m *testing.M) { - if err := runtime.StartTrace(); err != nil { - log.Fatalf("Failed to start Golang tracer: %s", err) - os.Exit(1) - } - defer runtime.StopTrace() - - os.Exit(m.Run()) -} diff --git a/lambda/rapi/handler/invocationresponse.go b/lambda/rapi/handler/invocationresponse.go deleted file mode 100644 index d267775..0000000 --- a/lambda/rapi/handler/invocationresponse.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "net/http" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi/rendering" - - "github.com/go-chi/chi" - log "github.com/sirupsen/logrus" -) - -const ( - StreamingFunctionResponseMode = "streaming" -) - -type invocationResponseHandler struct { - registrationService core.RegistrationService -} - -func (h *invocationResponseHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - appCtx := appctx.FromRequest(request) - - server := appctx.LoadResponseSender(appCtx) - if server == nil { - log.Panic("Invalid state, cannot access interop server") - } - - runtime := h.registrationService.GetRuntime() - if err := runtime.InvocationResponse(); err != nil { - log.Warn(err) - rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, - runtime.GetState().Name(), core.RuntimeInvocationResponseStateName, err) - return - } - - invokeID := chi.URLParam(request, "awsrequestid") - - headers := map[string]string{contentTypeHeader: request.Header.Get(contentTypeHeader)} - if functionResponseMode := request.Header.Get(functionResponseModeHeader); functionResponseMode != "" { - switch functionResponseMode { - case StreamingFunctionResponseMode: - headers[functionResponseModeHeader] = functionResponseMode - default: - errHeaders := interop.InvokeResponseHeaders{ - ContentType: request.Header.Get(contentTypeHeader), - } - fnError := interop.FunctionError{Type: fatalerror.RuntimeInvalidResponseModeHeader} - response := &interop.ErrorInvokeResponse{ - Headers: errHeaders, - FunctionError: fnError, - Payload: []byte{}, - } - - _ = server.SendErrorResponse(chi.URLParam(request, "awsrequestid"), response) - rendering.RenderInvalidFunctionResponseMode(writer, request) - return - } - } - - response := &interop.StreamableInvokeResponse{ - Headers: headers, - Payload: request.Body, - Trailers: request.Trailer, - Request: &interop.CancellableRequest{Request: request}, - } - - if err := server.SendResponse(invokeID, response); err != nil { - switch err := err.(type) { - case *interop.ErrorResponseTooLarge: - if server.SendErrorResponse(invokeID, err.AsErrorResponse()) != nil { - rendering.RenderInteropError(writer, request, err) - return - } - - appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{}) - - if err := runtime.ResponseSent(); err != nil { - log.Panic(err) - } - - rendering.RenderRequestEntityTooLarge(writer, request) - return - - case *interop.ErrorResponseTooLargeDI: - // in DirectInvoke case, the (truncated) response is already sent back to the caller - if err := runtime.ResponseSent(); err != nil { - log.Panic(err) - } - - rendering.RenderRequestEntityTooLarge(writer, request) - return - - case *interop.ErrTruncatedResponse: - if err := runtime.ResponseSent(); err != nil { - log.Panic(err) - } - - rendering.RenderTruncatedHTTPRequestError(writer, request) - return - - case *interop.ErrInternalPlatformError: - rendering.RenderInternalServerError(writer, request) - return - - default: - rendering.RenderInteropError(writer, request, err) - return - } - } - - if err := runtime.ResponseSent(); err != nil { - log.Panic(err) - } - - rendering.RenderAccepted(writer, request) -} - -// NewInvocationResponseHandler returns a new instance of http handler -// for serving /runtime/invocation/{awsrequestid}/response. -func NewInvocationResponseHandler(registrationService core.RegistrationService) http.Handler { - return &invocationResponseHandler{ - registrationService: registrationService, - } -} diff --git a/lambda/rapi/handler/invocationresponse_test.go b/lambda/rapi/handler/invocationresponse_test.go deleted file mode 100644 index dc29c10..0000000 --- a/lambda/rapi/handler/invocationresponse_test.go +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/aws/aws-lambda-go/events/test" - "github.com/stretchr/testify/assert" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/testdata" -) - -func TestResponseTooLarge(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - handler := NewInvocationResponseHandler(flowTest.RegistrationService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - // Invoke that we are sending response for must be placed into appCtx. - invoke := &interop.Invoke{ - ID: "InvocationID1", - InvokedFunctionArn: "arn::dummy1", - CognitoIdentityID: "CognitoidentityID1", - CognitoIdentityPoolID: "CognitoidentityPollID1", - DeadlineNs: "deadlinens1", - ClientContext: "clientcontext1", - ContentType: "application/json", - Payload: strings.NewReader(`{"message": "hello"}`), - } - - flowTest.ConfigureForInvoke(context.Background(), invoke) - - // Invocation response submitted by runtime. - var responseBody = make([]byte, interop.MaxPayloadSize+1) - request := httptest.NewRequest("", "/", bytes.NewReader(responseBody)) - request = addInvocationID(request, invoke.ID) - handler.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) - - // Assertions - - assert.Equal(t, http.StatusRequestEntityTooLarge, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", - responseRecorder.Code, http.StatusRequestEntityTooLarge) - - expectedAPIResponse := fmt.Sprintf("{\"errorMessage\":\"Exceeded maximum allowed payload size (%d bytes).\",\"errorType\":\"RequestEntityTooLarge\"}\n", interop.MaxPayloadSize) - body, err := io.ReadAll(responseRecorder.Body) - assert.NoError(t, err) - test.AssertJsonsEqual(t, []byte(expectedAPIResponse), body) - - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) - assert.Nil(t, flowTest.InteropServer.Response) - assert.Equal(t, fatalerror.FunctionOversizedResponse, errorResponse.FunctionError.Type) - assert.Equal(t, "Response payload size (6291557 bytes) exceeded maximum allowed payload size (6291556 bytes).", errorResponse.FunctionError.Message) - - var errorPayload map[string]interface{} - assert.NoError(t, json.Unmarshal(errorResponse.Payload, &errorPayload)) - assert.Equal(t, string(errorResponse.FunctionError.Type), errorPayload["errorType"]) - assert.Equal(t, errorResponse.FunctionError.Message, errorPayload["errorMessage"]) -} - -func TestResponseAccepted(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - handler := NewInvocationResponseHandler(flowTest.RegistrationService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - // Invoke that we are sending response for must be placed into appCtx. - invoke := &interop.Invoke{ - ID: "InvocationID1", - InvokedFunctionArn: "arn::dummy1", - CognitoIdentityID: "CognitoidentityID1", - CognitoIdentityPoolID: "CognitoidentityPollID1", - DeadlineNs: "deadlinens1", - ClientContext: "clientcontext1", - ContentType: "application/json", - Payload: strings.NewReader(`{"message": "hello"}`), - } - - flowTest.ConfigureForInvoke(context.Background(), invoke) - - // Invocation response submitted by runtime. - var responseBody = []byte("{'foo': 'bar'}") - - request := httptest.NewRequest("", "/", bytes.NewReader(responseBody)) - request = addInvocationID(request, invoke.ID) - request.Header.Set(contentTypeHeader, "application/json") - request.Header.Set(functionResponseModeHeader, "streaming") - handler.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) - - // Assertions - assert.Equal(t, http.StatusAccepted, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", - responseRecorder.Code, http.StatusAccepted) - - expectedAPIResponse := "{\"status\":\"OK\"}\n" - body, err := io.ReadAll(responseRecorder.Body) - assert.NoError(t, err) - test.AssertJsonsEqual(t, []byte(expectedAPIResponse), body) - - response := flowTest.InteropServer.Response - assert.NotNil(t, response) - assert.Nil(t, flowTest.InteropServer.ErrorResponse) - assert.Equal(t, "application/json", flowTest.InteropServer.ResponseContentType) - assert.Equal(t, "streaming", flowTest.InteropServer.FunctionResponseMode) - assert.Equal(t, responseBody, response, - "Persisted response data in app context must match the submitted.") -} - -func TestResponseWithDifferentFunctionResponseModes(t *testing.T) { - type testCase struct { - providedFunctionResponseMode string - expectedFunctionResponseMode string - expectedAPIResponse string - expectedStatusCode int - expectedErrorResponse bool - } - testCases := []testCase{ - { - providedFunctionResponseMode: "", - expectedFunctionResponseMode: "", - expectedAPIResponse: "{\"status\":\"OK\"}\n", - expectedStatusCode: http.StatusAccepted, - expectedErrorResponse: false, - }, - { - providedFunctionResponseMode: "streaming", - expectedFunctionResponseMode: "streaming", - expectedAPIResponse: "{\"status\":\"OK\"}\n", - expectedStatusCode: http.StatusAccepted, - expectedErrorResponse: false, - }, - { - providedFunctionResponseMode: "invalid-mode", - expectedFunctionResponseMode: "", - expectedAPIResponse: "{\"errorMessage\":\"Invalid function response mode\", \"errorType\":\"InvalidFunctionResponseMode\"}\n", - expectedStatusCode: http.StatusBadRequest, - expectedErrorResponse: true, - }, - } - - for _, testCase := range testCases { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - handler := NewInvocationResponseHandler(flowTest.RegistrationService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - // Invoke that we are sending response for must be placed into appCtx. - invoke := &interop.Invoke{ - ID: "InvocationID1", - InvokedFunctionArn: "arn::dummy1", - CognitoIdentityID: "CognitoidentityID1", - CognitoIdentityPoolID: "CognitoidentityPollID1", - DeadlineNs: "deadlinens1", - ClientContext: "clientcontext1", - ContentType: "application/json", - Payload: strings.NewReader(`{"message": "hello"}`), - } - - flowTest.ConfigureForInvoke(context.Background(), invoke) - - // Invocation response submitted by runtime. - var responseBody = []byte("{'foo': 'bar'}") - - request := httptest.NewRequest("", "/", bytes.NewReader(responseBody)) - request = addInvocationID(request, invoke.ID) - request.Header.Set(functionResponseModeHeader, testCase.providedFunctionResponseMode) - handler.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) - - // Assertions - assert.Equal(t, testCase.expectedStatusCode, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", - responseRecorder.Code, testCase.expectedStatusCode) - - body, err := io.ReadAll(responseRecorder.Body) - assert.NoError(t, err) - test.AssertJsonsEqual(t, []byte(testCase.expectedAPIResponse), body) - - if testCase.expectedErrorResponse { - assert.NotNil(t, flowTest.InteropServer.ErrorResponse) - assert.Nil(t, flowTest.InteropServer.Response) - assert.Equal(t, fatalerror.RuntimeInvalidResponseModeHeader, flowTest.InteropServer.ErrorResponse.FunctionError.Type) - } else { - assert.NotNil(t, flowTest.InteropServer.Response) - assert.Nil(t, flowTest.InteropServer.ErrorResponse) - assert.Equal(t, responseBody, flowTest.InteropServer.Response, - "Persisted response data in app context must match the submitted.") - } - - assert.Equal(t, testCase.expectedFunctionResponseMode, flowTest.InteropServer.FunctionResponseMode) - } -} diff --git a/lambda/rapi/handler/mime_type_error_cause_json.go b/lambda/rapi/handler/mime_type_error_cause_json.go deleted file mode 100644 index d66ab88..0000000 --- a/lambda/rapi/handler/mime_type_error_cause_json.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "encoding/json" - "fmt" - - "go.amzn.com/lambda/rapi/model" - - log "github.com/sirupsen/logrus" -) - -// Validation, serialization & deserialization for -// MIME type: application/vnd.aws.lambda.error.cause+json" -type errorWithCauseRequest struct { - ErrorMessage string `json:"errorMessage"` - ErrorType string `json:"errorType"` - StackTrace []string `json:"stackTrace"` - ErrorCause json.RawMessage `json:"errorCause"` -} - -func newErrorWithCauseRequest(requestBody []byte) (*errorWithCauseRequest, error) { - var parsedRequest errorWithCauseRequest - if err := json.Unmarshal(requestBody, &parsedRequest); err != nil { - return nil, fmt.Errorf("error unmarshalling request body with error cause: %s", err) - } - - return &parsedRequest, nil -} - -func (r *errorWithCauseRequest) getInvokeErrorResponse() ([]byte, error) { - responseBody := model.ErrorResponse{ - ErrorMessage: r.ErrorMessage, - ErrorType: r.ErrorType, - StackTrace: r.StackTrace, - } - - filteredResponseBody, err := json.Marshal(responseBody) - if err != nil { - return nil, fmt.Errorf("error marshalling invocation/error response body: %s", err) - } - - return filteredResponseBody, nil -} - -func (r *errorWithCauseRequest) getValidatedXRayCause() json.RawMessage { - if len(r.ErrorCause) == 0 { - return nil - } - - validErrorCauseJSON, err := model.ValidatedErrorCauseJSON(r.ErrorCause) - if err != nil { - log.WithError(err).Errorf("errorCause validation error, Content-Type: %s", errorWithCauseContentType) - return nil - } - - return validErrorCauseJSON -} diff --git a/lambda/rapi/handler/ping.go b/lambda/rapi/handler/ping.go deleted file mode 100644 index e93627b..0000000 --- a/lambda/rapi/handler/ping.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "net/http" - - log "github.com/sirupsen/logrus" -) - -type pingHandler struct { - // -} - -func (h *pingHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - if _, err := writer.Write([]byte("pong")); err != nil { - log.WithError(err).Fatal("Failed to write 'pong' response") - } -} - -// NewPingHandler returns a new instance of http handler -// for serving /ping. -func NewPingHandler() http.Handler { - return &pingHandler{} -} diff --git a/lambda/rapi/handler/restoreerror.go b/lambda/rapi/handler/restoreerror.go deleted file mode 100644 index eed97b2..0000000 --- a/lambda/rapi/handler/restoreerror.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "net/http" - - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi/rendering" -) - -type restoreErrorHandler struct { - registrationService core.RegistrationService -} - -func (h *restoreErrorHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - appCtx := appctx.FromRequest(request) - server := appctx.LoadInteropServer(appCtx) - if server == nil { - log.Panic("Invalid state, cannot access interop server") - } - - errorType := fatalerror.GetValidRuntimeOrFunctionErrorType(request.Header.Get("Lambda-Runtime-Function-Error-Type")) - fnError := interop.FunctionError{Type: errorType} - - runtime := h.registrationService.GetRuntime() - - if err := runtime.RestoreError(fnError); err != nil { - log.Warn(err) - rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, - runtime.GetState().Name(), core.RuntimeRestoreErrorStateName, err) - return - } - - appctx.StoreInvokeErrorTraceData(appCtx, &interop.InvokeErrorTraceData{}) - - rendering.RenderAccepted(writer, request) -} - -func NewRestoreErrorHandler(registrationService core.RegistrationService) http.Handler { - return &restoreErrorHandler{registrationService: registrationService} -} diff --git a/lambda/rapi/handler/restoreerror_test.go b/lambda/rapi/handler/restoreerror_test.go deleted file mode 100644 index 57226fa..0000000 --- a/lambda/rapi/handler/restoreerror_test.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "bytes" - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/require" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/testdata" -) - -func TestRestoreErrorHandler(t *testing.T) { - t.Run("GA", func(t *testing.T) { runTestRestoreErrorHandler(t) }) -} - -func runTestRestoreErrorHandler(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForRestoring() - - handler := NewRestoreErrorHandler(flowTest.RegistrationService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - errorBody := []byte("My byte array is yours") - errorType := "ErrorType" - errorContentType := "application/MyBinaryType" - - request := appctx.RequestWithAppCtx(httptest.NewRequest("POST", "/", bytes.NewReader(errorBody)), appCtx) - - request.Header.Set("Content-Type", errorContentType) - request.Header.Set("Lambda-Runtime-Function-Error-Type", errorType) - - handler.ServeHTTP(responseRecorder, request) - - require.Equal(t, http.StatusAccepted, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", responseRecorder.Code, http.StatusAccepted) - require.JSONEq(t, fmt.Sprintf("{\"status\":\"%s\"}\n", "OK"), responseRecorder.Body.String()) - require.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) -} diff --git a/lambda/rapi/handler/restorenext.go b/lambda/rapi/handler/restorenext.go deleted file mode 100644 index ecff059..0000000 --- a/lambda/rapi/handler/restorenext.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "net/http" - - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/rapi/rendering" -) - -type restoreNextHandler struct { - registrationService core.RegistrationService - renderingService *rendering.EventRenderingService -} - -func (h *restoreNextHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - runtime := h.registrationService.GetRuntime() - err := runtime.RestoreReady() - if err != nil { - log.Warn(err) - rendering.RenderForbiddenWithTypeMsg(writer, request, rendering.ErrorTypeInvalidStateTransition, StateTransitionFailedForRuntimeMessageFormat, runtime.GetState().Name(), core.RuntimeReadyStateName, err) - return - } - err = h.renderingService.RenderRuntimeEvent(writer, request) - if err != nil { - log.Error(err) - rendering.RenderInternalServerError(writer, request) - return - } -} - -func NewRestoreNextHandler(registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { - return &restoreNextHandler{ - registrationService: registrationService, - renderingService: renderingService, - } -} diff --git a/lambda/rapi/handler/restorenext_test.go b/lambda/rapi/handler/restorenext_test.go deleted file mode 100644 index 7018d98..0000000 --- a/lambda/rapi/handler/restorenext_test.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "context" - "net/http" - "net/http/httptest" - "strconv" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/telemetry" - "go.amzn.com/lambda/testdata" -) - -func TestRenderRestoreNext(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - handler := NewRestoreNextHandler(flowTest.RegistrationService, flowTest.RenderingService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - flowTest.ConfigureForRestore() - request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) - handler.ServeHTTP(responseRecorder, request) - - assert.Equal(t, http.StatusOK, responseRecorder.Code) -} - -func TestBrokenRenderer(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - handler := NewRestoreNextHandler(flowTest.RegistrationService, flowTest.RenderingService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - flowTest.ConfigureForRestore() - flowTest.RenderingService.SetRenderer(&mockBrokenRenderer{}) - request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) - handler.ServeHTTP(responseRecorder, request) - - assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) - - assert.JSONEq(t, `{"errorMessage":"Internal Server Error","errorType":"InternalServerError"}`, responseRecorder.Body.String()) -} - -func TestRenderRestoreAfterInvoke(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - handler := NewInvocationNextHandler(flowTest.RegistrationService, flowTest.RenderingService) - responseRecorder := httptest.NewRecorder() - appCtx := flowTest.AppCtx - - deadlineNs := 12345 - invokePayload := "Payload" - invoke := &interop.Invoke{ - TraceID: "Root=RootID;Parent=LambdaFrontend;Sampled=1", - ID: "ID", - InvokedFunctionArn: "InvokedFunctionArn", - CognitoIdentityID: "CognitoIdentityId1", - CognitoIdentityPoolID: "CognitoIdentityPoolId1", - ClientContext: "ClientContext", - DeadlineNs: strconv.Itoa(deadlineNs), - ContentType: "image/png", - Payload: strings.NewReader(invokePayload), - } - - ctx := telemetry.NewTraceContext(context.Background(), "RootID", "InvocationSubegmentID") - flowTest.ConfigureForInvoke(ctx, invoke) - - request := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) - handler.ServeHTTP(responseRecorder, request) - - assert.Equal(t, http.StatusOK, responseRecorder.Code) - - restoreHandler := NewRestoreNextHandler(flowTest.RegistrationService, flowTest.RenderingService) - restoreRequest := appctx.RequestWithAppCtx(httptest.NewRequest("", "/", nil), appCtx) - responseRecorder = httptest.NewRecorder() - restoreHandler.ServeHTTP(responseRecorder, restoreRequest) - - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) -} diff --git a/lambda/rapi/handler/runtimelogs.go b/lambda/rapi/handler/runtimelogs.go deleted file mode 100644 index 4fd534e..0000000 --- a/lambda/rapi/handler/runtimelogs.go +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "bytes" - "errors" - "fmt" - "io" - "net/http" - "strings" - - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/rapi/rendering" - "go.amzn.com/lambda/telemetry" - - "github.com/google/uuid" - log "github.com/sirupsen/logrus" -) - -type runtimeLogsHandler struct { - registrationService core.RegistrationService - telemetrySubscription telemetry.SubscriptionAPI -} - -func (h *runtimeLogsHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - agentName, err := h.verifyAgentID(writer, request) - if err != nil { - log.Errorf("Agent Verification Error: %s", err) - switch err := err.(type) { - case *ErrAgentIdentifierUnknown: - rendering.RenderForbiddenWithTypeMsg(writer, request, errAgentIdentifierUnknown, "Unknown extension %s", err.agentID.String()) - h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeClientErr, 1) - default: - rendering.RenderInternalServerError(writer, request) - h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) - } - return - } - - delete(request.Header, LambdaAgentIdentifier) - - body, err := h.getBody(writer, request) - if err != nil { - log.Error(err) - rendering.RenderInternalServerError(writer, request) - h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) - return - } - - respBody, status, headers, err := h.telemetrySubscription.Subscribe(agentName, bytes.NewReader(body), request.Header, request.RemoteAddr) - if err != nil { - log.Errorf("Telemetry API error: %s", err) - switch err { - case telemetry.ErrTelemetryServiceOff: - rendering.RenderForbiddenWithTypeMsg(writer, request, - h.telemetrySubscription.GetServiceClosedErrorType(), "%s", h.telemetrySubscription.GetServiceClosedErrorMessage()) - h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeClientErr, 1) - default: - rendering.RenderInternalServerError(writer, request) - h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) - } - return - } - - rendering.RenderRuntimeLogsResponse(writer, respBody, status, headers) - switch status / 100 { - case 2: // 2xx - if strings.Contains(string(respBody), "OK") { - h.telemetrySubscription.RecordCounterMetric(telemetry.NumSubscribers, 1) - } - h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeSuccess, 1) - case 4: // 4xx - h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeClientErr, 1) - case 5: // 5xx - h.telemetrySubscription.RecordCounterMetric(telemetry.SubscribeServerErr, 1) - } -} - -type ErrAgentIdentifierUnknown struct { - agentID uuid.UUID -} - -func NewErrAgentIdentifierUnknown(agentID uuid.UUID) *ErrAgentIdentifierUnknown { - return &ErrAgentIdentifierUnknown{ - agentID: agentID, - } -} - -func (e *ErrAgentIdentifierUnknown) Error() string { - return fmt.Sprintf("Unknown agent %s tried to call /runtime/logs", e.agentID.String()) -} - -func (h *runtimeLogsHandler) verifyAgentID(writer http.ResponseWriter, request *http.Request) (string, error) { - agentID, ok := request.Context().Value(AgentIDCtxKey).(uuid.UUID) - if !ok { - return "", errors.New("internal error: agent ID not set in context") - } - - agentName, found := h.getAgentName(agentID) - if !found { - return "", NewErrAgentIdentifierUnknown(agentID) - } - - return agentName, nil -} - -func (h *runtimeLogsHandler) getAgentName(agentID uuid.UUID) (string, bool) { - if agent, found := h.registrationService.FindExternalAgentByID(agentID); found { - return agent.Name, true - } else if agent, found := h.registrationService.FindInternalAgentByID(agentID); found { - return agent.Name, true - } else { - return "", false - } -} - -func (h *runtimeLogsHandler) getBody(writer http.ResponseWriter, request *http.Request) ([]byte, error) { - body, err := io.ReadAll(request.Body) - if err != nil { - return nil, fmt.Errorf("Failed to read error body: %s", err) - } - - return body, nil -} - -// NewRuntimeTelemetrySubscriptionHandler returns a new instance of http handler -// for serving /runtime/logs -func NewRuntimeTelemetrySubscriptionHandler(registrationService core.RegistrationService, telemetrySubscription telemetry.SubscriptionAPI) http.Handler { - return &runtimeLogsHandler{ - registrationService: registrationService, - telemetrySubscription: telemetrySubscription, - } -} diff --git a/lambda/rapi/handler/runtimelogs_stub.go b/lambda/rapi/handler/runtimelogs_stub.go deleted file mode 100644 index f540e9b..0000000 --- a/lambda/rapi/handler/runtimelogs_stub.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "net/http" - - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/rapi/model" - "go.amzn.com/lambda/rapi/rendering" -) - -const ( - logsAPIDisabledErrorType = "Logs.NotSupported" - telemetryAPIDisabledErrorType = "Telemetry.NotSupported" -) - -type runtimeLogsStubAPIHandler struct{} - -func (h *runtimeLogsStubAPIHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - if err := rendering.RenderJSON(http.StatusAccepted, writer, request, &model.ErrorResponse{ - ErrorType: logsAPIDisabledErrorType, - ErrorMessage: "Logs API is not supported", - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(writer, err.Error(), http.StatusInternalServerError) - } -} - -// NewRuntimeLogsAPIStubHandler returns a new instance of http handler -// for serving /runtime/logs when a telemetry service implementation is absent -func NewRuntimeLogsAPIStubHandler() http.Handler { - return &runtimeLogsStubAPIHandler{} -} - -type runtimeTelemetryAPIStubHandler struct{} - -func (h *runtimeTelemetryAPIStubHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - if err := rendering.RenderJSON(http.StatusAccepted, writer, request, &model.ErrorResponse{ - ErrorType: telemetryAPIDisabledErrorType, - ErrorMessage: "Telemetry API is not supported", - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(writer, err.Error(), http.StatusInternalServerError) - } -} - -// NewRuntimeTelemetryAPIStubHandler returns a new instance of http handler -// for serving /runtime/logs when a telemetry service implementation is absent -func NewRuntimeTelemetryAPIStubHandler() http.Handler { - return &runtimeTelemetryAPIStubHandler{} -} diff --git a/lambda/rapi/handler/runtimelogs_stub_test.go b/lambda/rapi/handler/runtimelogs_stub_test.go deleted file mode 100644 index 5b27983..0000000 --- a/lambda/rapi/handler/runtimelogs_stub_test.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "bytes" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestSuccessfulRuntimeLogsAPIStub202Response(t *testing.T) { - handler := NewRuntimeLogsAPIStubHandler() - requestBody := []byte(`foobar`) - request := httptest.NewRequest("PUT", "/logs", bytes.NewBuffer(requestBody)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - assert.JSONEq(t, `{"errorMessage":"Logs API is not supported","errorType":"Logs.NotSupported"}`, responseRecorder.Body.String()) -} - -func TestSuccessfulRuntimeTelemetryAPIStub202Response(t *testing.T) { - handler := NewRuntimeTelemetryAPIStubHandler() - requestBody := []byte(`foobar`) - request := httptest.NewRequest("PUT", "/telemetry", bytes.NewBuffer(requestBody)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - assert.JSONEq(t, `{"errorMessage":"Telemetry API is not supported","errorType":"Telemetry.NotSupported"}`, responseRecorder.Body.String()) -} diff --git a/lambda/rapi/handler/runtimelogs_test.go b/lambda/rapi/handler/runtimelogs_test.go deleted file mode 100644 index cbb8b0b..0000000 --- a/lambda/rapi/handler/runtimelogs_test.go +++ /dev/null @@ -1,364 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "net" - "net/http" - "net/http/httptest" - "testing" - - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/telemetry" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -type mockSubscriptionAPI struct{ mock.Mock } - -func (s *mockSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) ([]byte, int, map[string][]string, error) { - args := s.Called(agentName, body, headers, remoteAddr) - return args.Get(0).([]byte), args.Int(1), args.Get(2).(map[string][]string), args.Error(3) -} - -func (s *mockSubscriptionAPI) RecordCounterMetric(metricName string, count int) { - s.Called(metricName, count) -} - -func (s *mockSubscriptionAPI) FlushMetrics() interop.TelemetrySubscriptionMetrics { - args := s.Called() - return args.Get(0).(interop.TelemetrySubscriptionMetrics) -} - -func (s *mockSubscriptionAPI) Clear() { - s.Called() -} - -func (s *mockSubscriptionAPI) TurnOff() { - s.Called() -} - -func (s *mockSubscriptionAPI) GetEndpointURL() string { - args := s.Called() - return args.Get(0).(string) -} - -func (s *mockSubscriptionAPI) GetServiceClosedErrorMessage() string { - args := s.Called() - return args.Get(0).(string) -} - -func (s *mockSubscriptionAPI) GetServiceClosedErrorType() string { - args := s.Called() - return args.Get(0).(string) -} - -func validIPPort(addr string) bool { - ip, _, err := net.SplitHostPort(addr) - return err == nil && net.ParseIP(ip) != nil -} - -func TestSuccessfulRuntimeLogsResponseProxy(t *testing.T) { - agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} - respBody, respStatus, respHeaders := []byte(`barbaz`), http.StatusNotFound, map[string][]string{"K": []string{"V1", "V2"}} - clientErrMetric := telemetry.SubscribeClientErr - - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - - agent, err := registrationService.CreateExternalAgent(agentName) - assert.NoError(t, err) - - telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return(respBody, respStatus, respHeaders, nil) - telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) - - handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) - request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) - for k, vals := range reqHeaders { - for _, v := range vals { - request.Header.Add(k, v) - } - } - - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - - telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)) - telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) - - recordedBody, err := io.ReadAll(responseRecorder.Body) - assert.NoError(t, err) - - assert.Equal(t, respStatus, responseRecorder.Code) - assert.Equal(t, respBody, recordedBody) - assert.Equal(t, http.Header(respHeaders), responseRecorder.Header()) -} - -func TestSuccessfulTelemetryAPIPutRequest(t *testing.T) { - agentName, reqBody, reqHeaders := "extensionName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} - respBody, respStatus, respHeaders := []byte(`"OK"`), http.StatusOK, map[string][]string{"K": []string{"V1", "V2"}} - numSubscribersMetric := telemetry.NumSubscribers - subscribeSuccessMetric := telemetry.SubscribeSuccess - - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - - agent, err := registrationService.CreateExternalAgent(agentName) - assert.NoError(t, err) - - telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return(respBody, respStatus, respHeaders, nil) - telemetrySubscription.On("RecordCounterMetric", numSubscribersMetric, 1) - telemetrySubscription.On("RecordCounterMetric", subscribeSuccessMetric, 1) - - handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) - request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) - for k, vals := range reqHeaders { - for _, v := range vals { - request.Header.Add(k, v) - } - } - - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - - telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)) - telemetrySubscription.AssertCalled(t, "RecordCounterMetric", numSubscribersMetric, 1) - telemetrySubscription.AssertCalled(t, "RecordCounterMetric", subscribeSuccessMetric, 1) - - recordedBody, err := io.ReadAll(responseRecorder.Body) - assert.NoError(t, err) - - assert.Equal(t, respStatus, responseRecorder.Code) - assert.Equal(t, respBody, recordedBody) - assert.Equal(t, http.Header(respHeaders), responseRecorder.Header()) -} - -func TestNumberOfSubscribersWhenAnExtensionIsAlreadySubscribed(t *testing.T) { - agentName, reqBody, reqHeaders := "extensionName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} - respBody, respStatus, respHeaders := []byte(`"AlreadySubcribed"`), http.StatusOK, map[string][]string{"K": []string{"V1", "V2"}} - numSubscribersMetric := telemetry.NumSubscribers - subscribeSuccessMetric := telemetry.SubscribeSuccess - - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - - agent, err := registrationService.CreateExternalAgent(agentName) - assert.NoError(t, err) - - telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return(respBody, respStatus, respHeaders, nil) - telemetrySubscription.On("RecordCounterMetric", subscribeSuccessMetric, 1) - - handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) - request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) - for k, vals := range reqHeaders { - for _, v := range vals { - request.Header.Add(k, v) - } - } - - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - - telemetrySubscription.AssertCalled(t, "Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)) - telemetrySubscription.AssertCalled(t, "RecordCounterMetric", subscribeSuccessMetric, 1) - telemetrySubscription.AssertNotCalled(t, "RecordCounterMetric", numSubscribersMetric, mock.Anything) - - recordedBody, err := io.ReadAll(responseRecorder.Body) - assert.NoError(t, err) - - assert.Equal(t, respStatus, responseRecorder.Code) - assert.Equal(t, respBody, recordedBody) - assert.Equal(t, http.Header(respHeaders), responseRecorder.Header()) -} - -func TestErrorUnregisteredAgentID(t *testing.T) { - invalidAgentID := uuid.New() - reqBody, reqHeaders := []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} - clientErrMetric := telemetry.SubscribeClientErr - - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - - telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) - - handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) - request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) - for k, vals := range reqHeaders { - for _, v := range vals { - request.Header.Add(k, v) - } - } - - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, invalidAgentID)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - - recordedBody, err := io.ReadAll(responseRecorder.Body) - assert.NoError(t, err) - - expectedErrorBody := fmt.Sprintf(`{"errorMessage":"Unknown extension %s","errorType":"Extension.UnknownExtensionIdentifier"}`+"\n", invalidAgentID.String()) - expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) - - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - assert.Equal(t, expectedErrorBody, string(recordedBody)) - assert.Equal(t, expectedHeaders, responseRecorder.Header()) - telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) -} - -func TestErrorTelemetryAPICallFailure(t *testing.T) { - agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} - apiError := errors.New("Error calling Telemetry API: connection refused") - serverErrMetric := telemetry.SubscribeServerErr - - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - - agent, err := registrationService.CreateExternalAgent(agentName) - assert.NoError(t, err) - - telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) - telemetrySubscription.On("RecordCounterMetric", serverErrMetric, 1) - - handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) - request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) - for k, vals := range reqHeaders { - for _, v := range vals { - request.Header.Add(k, v) - } - } - - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - - recordedBody, err := io.ReadAll(responseRecorder.Body) - assert.NoError(t, err) - - expectedErrorBody := `{"errorMessage":"Internal Server Error","errorType":"InternalServerError"}` + "\n" - expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) - - assert.Equal(t, http.StatusInternalServerError, responseRecorder.Code) - assert.Equal(t, expectedErrorBody, string(recordedBody)) - assert.Equal(t, expectedHeaders, responseRecorder.Header()) - telemetrySubscription.AssertCalled(t, "RecordCounterMetric", serverErrMetric, 1) -} - -func TestRenderLogsSubscriptionClosed(t *testing.T) { - agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} - apiError := telemetry.ErrTelemetryServiceOff - clientErrMetric := telemetry.SubscribeClientErr - - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - - agent, err := registrationService.CreateExternalAgent(agentName) - assert.NoError(t, err) - - telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) - telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) - telemetrySubscription.On("GetServiceClosedErrorMessage").Return("Logs API subscription is closed already") - telemetrySubscription.On("GetServiceClosedErrorType").Return("Logs.SubscriptionClosed") - - handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) - request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) - for k, vals := range reqHeaders { - for _, v := range vals { - request.Header.Add(k, v) - } - } - - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - - recordedBody, err := io.ReadAll(responseRecorder.Body) - assert.NoError(t, err) - - expectedErrorBody := `{"errorMessage":"Logs API subscription is closed already","errorType":"Logs.SubscriptionClosed"}` + "\n" - expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) - - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - assert.Equal(t, expectedErrorBody, string(recordedBody)) - assert.Equal(t, expectedHeaders, responseRecorder.Header()) - telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) -} - -func TestRenderTelemetrySubscriptionClosed(t *testing.T) { - agentName, reqBody, reqHeaders := "dummyName", []byte(`foobar`), map[string][]string{"Key": []string{"VAL1", "VAL2"}} - apiError := telemetry.ErrTelemetryServiceOff - clientErrMetric := telemetry.SubscribeClientErr - - registrationService := core.NewRegistrationService( - core.NewInitFlowSynchronization(), - core.NewInvokeFlowSynchronization(), - ) - - agent, err := registrationService.CreateExternalAgent(agentName) - assert.NoError(t, err) - - telemetrySubscription := &mockSubscriptionAPI{} - telemetrySubscription.On("Subscribe", agentName, bytes.NewReader(reqBody), reqHeaders, mock.MatchedBy(validIPPort)).Return([]byte(``), http.StatusOK, map[string][]string{}, apiError) - telemetrySubscription.On("RecordCounterMetric", clientErrMetric, 1) - telemetrySubscription.On("GetServiceClosedErrorMessage").Return("Telemetry API subscription is closed already") - telemetrySubscription.On("GetServiceClosedErrorType").Return("Telemetry.SubscriptionClosed") - - handler := NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscription) - request := httptest.NewRequest("PUT", "/", bytes.NewBuffer(reqBody)) - for k, vals := range reqHeaders { - for _, v := range vals { - request.Header.Add(k, v) - } - } - - request = request.WithContext(context.WithValue(context.Background(), AgentIDCtxKey, agent.ID)) - responseRecorder := httptest.NewRecorder() - - handler.ServeHTTP(responseRecorder, request) - - recordedBody, err := io.ReadAll(responseRecorder.Body) - assert.NoError(t, err) - - expectedErrorBody := `{"errorMessage":"Telemetry API subscription is closed already","errorType":"Telemetry.SubscriptionClosed"}` + "\n" - expectedHeaders := http.Header(map[string][]string{"Content-Type": []string{"application/json"}}) - - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - assert.Equal(t, expectedErrorBody, string(recordedBody)) - assert.Equal(t, expectedHeaders, responseRecorder.Header()) - telemetrySubscription.AssertCalled(t, "RecordCounterMetric", clientErrMetric, 1) -} diff --git a/lambda/rapi/middleware/middleware.go b/lambda/rapi/middleware/middleware.go deleted file mode 100644 index d45798b..0000000 --- a/lambda/rapi/middleware/middleware.go +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package middleware - -import ( - "context" - "net/http" - - "github.com/google/uuid" - "go.amzn.com/lambda/extensions" - "go.amzn.com/lambda/rapi/handler" - "go.amzn.com/lambda/rapi/rendering" - - "github.com/go-chi/chi" - "go.amzn.com/lambda/appctx" - - log "github.com/sirupsen/logrus" -) - -// AwsRequestIDValidator validates that {awsrequestid} parameter -// is present in the URL and matches to the currently active id. -func AwsRequestIDValidator(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - appCtx := appctx.FromRequest(r) - interopServer := appctx.LoadInteropServer(appCtx) - - if interopServer == nil { - log.Panic("Invalid state, cannot access interop server") - } - - invokeID := chi.URLParam(r, "awsrequestid") - if invokeID == "" || invokeID != interopServer.GetCurrentInvokeID() { - rendering.RenderInvalidRequestID(w, r) - return - } - - next.ServeHTTP(w, r) - }) -} - -// AgentUniqueIdentifierHeaderValidator validates that the request contains a valid agent unique identifier in the headers -func AgentUniqueIdentifierHeaderValidator(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - agentIdentifier := r.Header.Get(handler.LambdaAgentIdentifier) - if len(agentIdentifier) == 0 { - rendering.RenderForbiddenWithTypeMsg(w, r, handler.ErrAgentIdentifierMissing, "Missing Lambda-Extension-Identifier header") - return - } - agentID, e := uuid.Parse(agentIdentifier) - if e != nil { - rendering.RenderForbiddenWithTypeMsg(w, r, handler.ErrAgentIdentifierInvalid, "Invalid Lambda-Extension-Identifier") - return - } - - r = r.WithContext(context.WithValue(r.Context(), handler.AgentIDCtxKey, agentID)) - next.ServeHTTP(w, r) - }) -} - -// AppCtxMiddleware injects application context into request context. -func AppCtxMiddleware(appCtx appctx.ApplicationContext) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - r = appctx.RequestWithAppCtx(r, appCtx) - next.ServeHTTP(w, r) - } - return http.HandlerFunc(fn) - } -} - -// AccessLogMiddleware writes api access log. -func AccessLogMiddleware() func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - log.Debug("API request - ", r.Method, " ", r.URL, ", Headers:", r.Header) - next.ServeHTTP(w, r) - } - return http.HandlerFunc(fn) - } -} - -func AllowIfExtensionsEnabled(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !extensions.AreEnabled() { - w.WriteHeader(http.StatusNotFound) - return - } - next.ServeHTTP(w, r) - }) -} - -// RuntimeReleaseMiddleware places runtime_release into app context. -func RuntimeReleaseMiddleware() func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - appCtx := appctx.FromRequest(r) - // Place runtime_release into app context. - appctx.UpdateAppCtxWithRuntimeRelease(r, appCtx) - next.ServeHTTP(w, r) - } - return http.HandlerFunc(fn) - } -} diff --git a/lambda/rapi/middleware/middleware_test.go b/lambda/rapi/middleware/middleware_test.go deleted file mode 100644 index a0b9134..0000000 --- a/lambda/rapi/middleware/middleware_test.go +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package middleware - -import ( - "bytes" - "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "testing" - - "github.com/go-chi/chi" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/extensions" - "go.amzn.com/lambda/rapi/handler" - "go.amzn.com/lambda/rapi/model" -) - -type mockHandler struct{} - -func (h *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {} - -func TestRuntimeReleaseMiddleware(t *testing.T) { - appCtx := appctx.NewApplicationContext() - router := chi.NewRouter() - handler := &mockHandler{} - router.Use(RuntimeReleaseMiddleware()) - router.Get("/", handler.ServeHTTP) - - userAgent := "foobar" - - responseRecorder := httptest.NewRecorder() - responseBody := make([]byte, 100) - request := httptest.NewRequest("GET", "/", bytes.NewReader(responseBody)) - request.Header.Set("User-Agent", userAgent) - router.ServeHTTP(responseRecorder, appctx.RequestWithAppCtx(request, appCtx)) - - assert.Equal(t, http.StatusOK, responseRecorder.Code) - ctxRuntimeRelease, ok := appCtx.Load(appctx.AppCtxRuntimeReleaseKey) - assert.True(t, ok) - assert.Equal(t, userAgent, ctxRuntimeRelease) -} - -func TestAgentUniqueIdentifierHeaderValidatorForbidden(t *testing.T) { - router := chi.NewRouter() - mockHandler := &mockHandler{} - router.Get("/", AgentUniqueIdentifierHeaderValidator(mockHandler).ServeHTTP) - responseBody := make([]byte, 100) - var errorResponse model.ErrorResponse - - request := httptest.NewRequest("GET", "/", bytes.NewReader(responseBody)) - - responseRecorder := httptest.NewRecorder() - router.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - respBody, _ := io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &errorResponse) - assert.Equal(t, handler.ErrAgentIdentifierMissing, errorResponse.ErrorType) - - responseRecorder = httptest.NewRecorder() - request.Header.Set(handler.LambdaAgentIdentifier, "invalid-unique-identifier") - router.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - respBody, _ = io.ReadAll(responseRecorder.Body) - json.Unmarshal(respBody, &errorResponse) - assert.Equal(t, handler.ErrAgentIdentifierInvalid, errorResponse.ErrorType) -} - -func TestAgentUniqueIdentifierHeaderValidatorSuccess(t *testing.T) { - router := chi.NewRouter() - mockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - val, ok := r.Context().Value(handler.AgentIDCtxKey).(uuid.UUID) - if !ok { - assert.FailNow(t, "expected key not in request context") - } - assert.Equal(t, "85083764-ff1e-476f-ada1-d51f26e4f6be", val.String()) - }) - router.Get("/", AgentUniqueIdentifierHeaderValidator(mockHandler).ServeHTTP) - responseBody := make([]byte, 100) - request := httptest.NewRequest("GET", "/", bytes.NewReader(responseBody)) - ctx := context.Background() - request = request.WithContext(ctx) - - responseRecorder := httptest.NewRecorder() - responseRecorder.Code = http.StatusOK - request.Header.Set(handler.LambdaAgentIdentifier, "85083764-ff1e-476f-ada1-d51f26e4f6be") - router.ServeHTTP(responseRecorder, request) - assert.Equal(t, http.StatusOK, responseRecorder.Code) -} - -func TestAllowIfExtensionsEnabledPositive(t *testing.T) { - router := chi.NewRouter() - handler := &mockHandler{} - router.Use(AllowIfExtensionsEnabled) - router.Get("/", handler.ServeHTTP) - - responseRecorder := httptest.NewRecorder() - responseBody := make([]byte, 100) - - extensions.Enable() - defer extensions.Disable() - - router.ServeHTTP(responseRecorder, httptest.NewRequest("GET", "/", bytes.NewReader(responseBody))) - - assert.Equal(t, http.StatusOK, responseRecorder.Code) -} - -func TestAllowIfExtensionsEnabledNegative(t *testing.T) { - router := chi.NewRouter() - handler := &mockHandler{} - router.Use(AllowIfExtensionsEnabled) - router.Get("/", handler.ServeHTTP) - - responseRecorder := httptest.NewRecorder() - responseBody := make([]byte, 100) - router.ServeHTTP(responseRecorder, httptest.NewRequest("GET", "/", bytes.NewReader(responseBody))) - - assert.Equal(t, http.StatusNotFound, responseRecorder.Code) -} diff --git a/lambda/rapi/model/agentevent.go b/lambda/rapi/model/agentevent.go deleted file mode 100644 index 5c0cc73..0000000 --- a/lambda/rapi/model/agentevent.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package model - -// AgentEvent is one of INVOKE, SHUTDOWN agent events -type AgentEvent struct { - EventType string `json:"eventType"` - DeadlineMs int64 `json:"deadlineMs"` -} - -// AgentInvokeEvent is the response to agent's get next request -type AgentInvokeEvent struct { - *AgentEvent - RequestID string `json:"requestId"` - InvokedFunctionArn string `json:"invokedFunctionArn"` - Tracing *Tracing `json:"tracing,omitempty"` -} - -// AgentShutdownEvent is the response to agent's get next request -type AgentShutdownEvent struct { - *AgentEvent - ShutdownReason string `json:"shutdownReason"` -} diff --git a/lambda/rapi/model/agentregisterresponse.go b/lambda/rapi/model/agentregisterresponse.go deleted file mode 100644 index fb9cacc..0000000 --- a/lambda/rapi/model/agentregisterresponse.go +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package model - -// ExtensionRegisterResponse is a response returned by the API server on extension/register post request -type ExtensionRegisterResponse struct { - AccountID string `json:"accountId,omitempty"` - FunctionName string `json:"functionName"` - FunctionVersion string `json:"functionVersion"` - Handler string `json:"handler"` -} diff --git a/lambda/rapi/model/cognitoidentity.go b/lambda/rapi/model/cognitoidentity.go deleted file mode 100644 index 68025a7..0000000 --- a/lambda/rapi/model/cognitoidentity.go +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package model - -// CognitoIdentity is returned by the API server in a response headers, -// providing information about client's Cognito identity. -type CognitoIdentity struct { - CognitoIdentityID string `json:"cognitoIdentityId"` - CognitoIdentityPoolID string `json:"cognitoIdentityPoolId"` -} diff --git a/lambda/rapi/model/error_cause.go b/lambda/rapi/model/error_cause.go deleted file mode 100644 index a7706d5..0000000 --- a/lambda/rapi/model/error_cause.go +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package model - -import ( - "encoding/json" - "fmt" -) - -// MaxErrorCauseSizeBytes limits the size of a cause, -// since the max X-Ray document size of 64kB -const MaxErrorCauseSizeBytes = 64 << 10 - -type exceptionStackFrame struct { - Path string `json:"path,omitempty"` - Line int `json:"line,omitempty"` - Label string `json:"label,omitempty"` -} - -type exception struct { - Message string `json:"message,omitempty"` - Type string `json:"type,omitempty"` - Stack []exceptionStackFrame `json:"stack,omitempty"` -} - -// ErrorCause represents the cause of an error reported by -// the runtime, and may contain stack traces and exceptions -type ErrorCause struct { - Exceptions []exception `json:"exceptions"` - WorkingDir string `json:"working_directory"` - Paths []string `json:"paths"` - Message string `json:"message,omitempty"` -} - -// newErrorCause unmarshals JSON into an ErrorCause -func newErrorCause(errorCauseJSON []byte) (*ErrorCause, error) { - var ec ErrorCause - - if err := json.Unmarshal(errorCauseJSON, &ec); err != nil { - return nil, fmt.Errorf("failed to parse error cause JSON: %s", err) - } - - return &ec, nil -} - -// isValid validates the ErrorCause format -func (ec *ErrorCause) isValid() bool { - if len(ec.WorkingDir) == 0 && len(ec.Paths) == 0 && len(ec.Exceptions) == 0 && len(ec.Message) == 0 { - // X-Ray public docs imply WorkingDir, Paths & Exceptions are not optional, - // but we are less restrictive. - - // Message is not a valid field as per the latest X-Ray docs, but native runtimes - // use it via LambdaSandboxRuntime and the X-Ray Data Plane frontend supports it. - return false - } - - return true -} - -func (ec *ErrorCause) croppedJSON() []byte { - // Iteratively crop the error cause by a factor of its size - // until it is within the size limit - - truncationFactors := []float64{0.8, 0.6, 0.4, 0.2} - for _, factor := range truncationFactors { - compactor := newErrorCauseCompactor(*ec) - compactor.crop(factor) - validErrorCauseJSON, err := json.Marshal(compactor.cause()) - if err != nil { - break - } - - if len(validErrorCauseJSON) <= MaxErrorCauseSizeBytes { - return validErrorCauseJSON - } - } - - // If compaction failed, drop Exceptions & Path, and truncate - // Message & WorkingDir, using smallest possible factor - compactor := newErrorCauseCompactor(*ec) - compactor.crop(0) - - validErrorCauseJSON, err := json.Marshal(compactor.cause()) - if err != nil { - return nil - } - - return validErrorCauseJSON -} - -// ValidatedErrorCauseJSON returns an error if -// the ErrorCause JSON has an invalid format -func ValidatedErrorCauseJSON(errorCauseJSON []byte) ([]byte, error) { - errorCause, err := newErrorCause(errorCauseJSON) - if err != nil { - return nil, err - } - - if !errorCause.isValid() { - return nil, fmt.Errorf("error cause body has invalid format: %s", errorCauseJSON) - } - - validErrorCauseJSON, err := json.Marshal(errorCause) - if err != nil { - return nil, err - } - - if len(validErrorCauseJSON) > MaxErrorCauseSizeBytes { - return errorCause.croppedJSON(), nil - } - - return validErrorCauseJSON, nil -} diff --git a/lambda/rapi/model/error_cause_compactor.go b/lambda/rapi/model/error_cause_compactor.go deleted file mode 100644 index 8d1204a..0000000 --- a/lambda/rapi/model/error_cause_compactor.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package model - -import ( - "math" -) - -const paddingForFieldNames = 4096 - -type errorCauseCompactor struct { - ec ErrorCause -} - -func newErrorCauseCompactor(errorCause ErrorCause) *errorCauseCompactor { - ec := errorCause - return &errorCauseCompactor{ec} -} - -// cropStackTraces crops Exceptions & Paths of an error cause -// by a factor in [0,1]. e.g. 0 removes all elements, 1 removes -// nothing, 0.5 removes half the array elements -func (c *errorCauseCompactor) cropStackTraces(factor float64) { - if factor > 0 { - factor = math.Min(factor, 1) // number in (0,1] - exceptionsLen := float64(len(c.ec.Exceptions)) * factor - pathLen := float64(len(c.ec.Paths)) * factor - - c.ec.Exceptions = c.ec.Exceptions[:int(exceptionsLen)] - c.ec.Paths = c.ec.Paths[:int(pathLen)] - - return - } - - c.ec.Exceptions = nil - c.ec.Paths = nil -} - -// cropMessage crops Message of an error cause to be half the -// max size when the factor is 0 (worst-case) -func (c *errorCauseCompactor) cropMessage(factor float64) { - if factor > 0 { - return - } - - length := ((MaxErrorCauseSizeBytes - paddingForFieldNames) / 2) - c.ec.Message = cropString(c.ec.Message, length) -} - -// cropWorkingDir crops WorkingDir of an error cause to be half -// the max size when the factor is 0 (worst-case) -func (c *errorCauseCompactor) cropWorkingDir(factor float64) { - if factor > 0 { - return - } - - length := ((MaxErrorCauseSizeBytes - paddingForFieldNames) / 2) - c.ec.WorkingDir = cropString(c.ec.WorkingDir, length) -} - -func (c *errorCauseCompactor) crop(factor float64) { - c.cropStackTraces(factor) - c.cropMessage(factor) - c.cropWorkingDir(factor) -} - -func (c *errorCauseCompactor) cause() *ErrorCause { - return &c.ec -} - -func cropString(str string, length int) string { - if len(str) <= length { - return str - } - - truncationIndicator := `...` - length = length - len(truncationIndicator) - return str[:length] + truncationIndicator -} diff --git a/lambda/rapi/model/error_cause_compactor_test.go b/lambda/rapi/model/error_cause_compactor_test.go deleted file mode 100644 index 2322c48..0000000 --- a/lambda/rapi/model/error_cause_compactor_test.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package model - -import ( - "fmt" - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestErrorCauseCropMessageAndWorkingDir(t *testing.T) { - largeString := strings.Repeat("a", 4*MaxErrorCauseSizeBytes) - factorsAndExpectedLengths := map[float64]int{ - 1.5: len(largeString), - 1.0: len(largeString), - 0.5: len(largeString), - 0.0: (MaxErrorCauseSizeBytes - paddingForFieldNames) / 2, - } - - for factor, length := range factorsAndExpectedLengths { - cause := ErrorCause{ - Message: largeString, - WorkingDir: largeString, - } - - compactor := newErrorCauseCompactor(cause) - compactor.crop(factor) - - failureMsg := fmt.Sprintf("factor: %f, length: expected=%d, actual=%d", factor, length, len(compactor.ec.Message)) - assert.Len(t, compactor.ec.Message, length, "Message: "+failureMsg) - assert.Len(t, compactor.ec.WorkingDir, length, "WorkingDir: "+failureMsg) - } -} - -func TestErrorCauseCropStackTraces(t *testing.T) { - noOfElements := 3 * MaxErrorCauseSizeBytes - largeExceptions := make([]exception, noOfElements) - for i := range largeExceptions { - largeExceptions[i] = exception{Message: "a"} - } - - largePaths := make([]string, noOfElements) - for i := range largePaths { - largePaths[i] = "a" - } - - factorsAndExpectedLengths := map[float64]int{ - 1.5: noOfElements, - 1.0: noOfElements, - 0.5: int(noOfElements / 2), - 0.0: 0, - } - - for factor, length := range factorsAndExpectedLengths { - cause := ErrorCause{ - Exceptions: largeExceptions, - Paths: largePaths, - } - - compactor := newErrorCauseCompactor(cause) - compactor.crop(factor) - - failureMsg := fmt.Sprintf("factor: %f, length: expected=%d, actual=%d", factor, length, len(compactor.ec.WorkingDir)) - assert.Len(t, compactor.ec.Exceptions, length, "Exceptions: "+failureMsg) - assert.Len(t, compactor.ec.Paths, length, "Paths: "+failureMsg) - } -} - -func TestCropString(t *testing.T) { - maxLen := 5 - stringsAndExpectedCrops := map[string]string{ - "abcde": "abcde", - "abcdef": "ab...", - "": "", - } - - for str, expectedStr := range stringsAndExpectedCrops { - assert.Equal(t, expectedStr, cropString(str, maxLen)) - } -} diff --git a/lambda/rapi/model/error_cause_test.go b/lambda/rapi/model/error_cause_test.go deleted file mode 100644 index f001bb9..0000000 --- a/lambda/rapi/model/error_cause_test.go +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package model - -import ( - "fmt" - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestErrorCauseValidationWhenCauseIsValid(t *testing.T) { - validCauses := [][]byte{ - []byte(`{"paths":[],"working_directory":"/foo/bar/baz","exceptions":[]}`), - []byte(`{"paths":["foo", "bar"]}`), - []byte(`{"working_directory":"/foo/bar/baz"}`), - []byte(`{"exceptions":[{"message": "foo"}, {"message": "bar"}]}`), - []byte(`{"exceptions":[{}]}`), - []byte(`{"exceptions":[{}], "arbitrary":"field"}`), - []byte(`{"message":"foo error"}`), - } - - for _, c := range validCauses { - _, err := ValidatedErrorCauseJSON(c) - assert.Nil(t, err, "validation failed for valid cause") - } -} - -func TestWorkingDirCropping(t *testing.T) { - -} - -func TestErrorCauseMarshallingWhenCauseIsValid(t *testing.T) { - causesAndExpectations := map[string]string{ - `{"paths":[],"working_directory":"/","exceptions":[]}`: `{"paths":[],"working_directory":"/","exceptions":[]}`, - `{"paths":["f"]}`: `{"paths":["f"],"working_directory":"","exceptions":null}`, - `{"working_directory":"/foo"}`: `{"paths":null,"working_directory":"/foo","exceptions":null}`, - `{"exceptions":[{}], "arbitrary":"field"}`: `{"paths":null,"working_directory":"","exceptions":[{}]}`, - `{"message":"foo"}`: `{"paths":null,"working_directory":"","exceptions":null,"message":"foo"}`, - } - - for causeJSON, expectedJSON := range causesAndExpectations { - validCauseJSON, err := ValidatedErrorCauseJSON([]byte(causeJSON)) - assert.Nil(t, err, "validation failed for valid cause") - assert.JSONEq(t, string(expectedJSON), string(validCauseJSON)) - } -} - -func TestErrorCauseValidationWhenCauseIsInvalid(t *testing.T) { - invalidCauses := [][]byte{ - []byte(`{"paths":[],"working_directory":"","exceptions":[]}`), - []byte(`{"paths":"","working_directory":"","exceptions":[]}`), - []byte(`{"paths":"","exceptions":[]}`), - []byte(`{foo: invalid}`), - []byte(`{}`), - []byte(`{"arbitrary":"field"}`), - } - - for _, c := range invalidCauses { - causeJSON, err := ValidatedErrorCauseJSON(c) - assert.Error(t, err, "validation didn't return an error") - assert.Nil(t, causeJSON) - } -} - -func TestErrorCauseCroppedJSONForEmptyCause(t *testing.T) { - emptyCauseJSON := `{"exceptions":null, "paths":null, "working_directory":""}` - cause := ErrorCause{} - - causeJSON := cause.croppedJSON() - - assert.JSONEq(t, emptyCauseJSON, string(causeJSON)) -} - -func TestErrorCauseCroppedJSONForLargeCause(t *testing.T) { - noOfElements := MaxErrorCauseSizeBytes - largeExceptions := make([]exception, noOfElements) - for i := range largeExceptions { - largeExceptions[i] = exception{Message: "a"} - } - - largePaths := make([]string, noOfElements) - for i := range largePaths { - largePaths[i] = "a" - } - - largeCause := ErrorCause{ - Message: strings.Repeat("a", noOfElements), - WorkingDir: strings.Repeat("a", noOfElements), - Exceptions: largeExceptions, - Paths: largePaths, - } - expectedStringFieldsLen := (MaxErrorCauseSizeBytes - paddingForFieldNames) / 2 - - causeJSON := largeCause.croppedJSON() - assert.True(t, len(causeJSON) <= MaxErrorCauseSizeBytes, fmt.Sprintf("cropped JSON too long: len=%d", len(causeJSON))) - - parsedCause, err := newErrorCause(causeJSON) - assert.NoError(t, err, "failed to parse constructed JSON") - assert.Len(t, parsedCause.Message, expectedStringFieldsLen, "Message length incorrect") - assert.Len(t, parsedCause.WorkingDir, expectedStringFieldsLen, "WorkingDir length incorrect") - assert.Len(t, parsedCause.Exceptions, 0, "Exceptions length incorrect") - assert.Len(t, parsedCause.Paths, 0, "Paths length incorrect") -} - -func TestErrorCauseCroppedJSONForLargeCauseWithOnlyExceptionsAndPaths(t *testing.T) { - elementsAndExpectedLengthFactors := map[int]float64{ - 100: 0.8, - 5000: 0.6, - 8000: 0.4, - 10000: 0.2, - MaxErrorCauseSizeBytes / 4: 0.0, - } - - for noOfElements, factor := range elementsAndExpectedLengthFactors { - largeExceptions := make([]exception, noOfElements) - for i := range largeExceptions { - largeExceptions[i] = exception{Message: "a"} - } - - largePaths := make([]string, noOfElements) - for i := range largePaths { - largePaths[i] = "a" - } - - largeCause := ErrorCause{ - Exceptions: largeExceptions, - Paths: largePaths, - } - - causeJSON := largeCause.croppedJSON() - assert.True(t, len(causeJSON) <= MaxErrorCauseSizeBytes, fmt.Sprintf("cropped JSON too long: len=%d", len(causeJSON))) - - parsedCause, err := newErrorCause(causeJSON) - assert.NoError(t, err, "failed to parse constructed JSON") - assert.Len(t, parsedCause.Message, 0, "Message length incorrect") - assert.Len(t, parsedCause.WorkingDir, 0, "WorkingDir length incorrect") - assert.Len(t, parsedCause.Exceptions, int(float64(noOfElements)*factor), "Exceptions length incorrect") - assert.Len(t, parsedCause.Paths, int(float64(noOfElements)*factor), "Paths length incorrect") - } -} diff --git a/lambda/rapi/model/errorresponse.go b/lambda/rapi/model/errorresponse.go deleted file mode 100644 index 4c95e6c..0000000 --- a/lambda/rapi/model/errorresponse.go +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package model - -// ErrorResponse is a standard invoke error response, -// providing information about the error. -type ErrorResponse struct { - ErrorMessage string `json:"errorMessage"` - ErrorType string `json:"errorType"` - StackTrace []string `json:"stackTrace,omitempty"` -} diff --git a/lambda/rapi/model/statusresponse.go b/lambda/rapi/model/statusresponse.go deleted file mode 100644 index 7694e94..0000000 --- a/lambda/rapi/model/statusresponse.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package model - -// StatusResponse is a response returned by the API server, -// providing status information. -type StatusResponse struct { - Status string `json:"status"` -} diff --git a/lambda/rapi/model/tracing.go b/lambda/rapi/model/tracing.go deleted file mode 100644 index 83f97e8..0000000 --- a/lambda/rapi/model/tracing.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package model - -type TracingType string - -const ( - // XRayTracingType represents an X-Ray Tracing object type - XRayTracingType TracingType = "X-Amzn-Trace-Id" -) - -const ( - XRaySampled = "1" - XRayNonSampled = "0" -) - -// Tracing object returned as part of agent Invoke event -type Tracing struct { - Type TracingType `json:"type"` - XRayTracing -} - -// XRayTracing is a type of Tracing object -type XRayTracing struct { - Value string `json:"value"` -} - -// NewXRayTracing returns a new XRayTracing object with specified value -func NewXRayTracing(value string) *Tracing { - if len(value) == 0 { - return nil - } - - return &Tracing{ - XRayTracingType, - XRayTracing{value}, - } -} diff --git a/lambda/rapi/rapi_fuzz_test.go b/lambda/rapi/rapi_fuzz_test.go deleted file mode 100644 index f1df47f..0000000 --- a/lambda/rapi/rapi_fuzz_test.go +++ /dev/null @@ -1,391 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapi - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "net/http" - "net/http/httptest" - "net/url" - "os" - "regexp" - "strings" - "testing" - "unicode" - - "github.com/stretchr/testify/assert" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/extensions" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/telemetry" - "go.amzn.com/lambda/testdata" -) - -type runtimeFunctionErrStruct struct { - ErrorMessage string - ErrorType string - StackTrace []string -} - -func FuzzRuntimeAPIRouter(f *testing.F) { - extensions.Enable() - defer extensions.Disable() - - addSeedCorpusURLTargets(f) - - f.Fuzz(func(t *testing.T, rawPath string, payload []byte, isGetMethod bool) { - u, err := parseToURLStruct(rawPath) - if err != nil { - t.Skipf("error parsing url: %v. Skipping test.", err) - } - - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - - invoke := createDummyInvoke() - flowTest.ConfigureForInvoke(context.Background(), invoke) - - appctx.StoreInitType(flowTest.AppCtx, true) - - rapiServer := makeRapiServer(flowTest) - - method := "GET" - if !isGetMethod { - method = "POST" - } - - request := httptest.NewRequest(method, rawPath, bytes.NewReader(payload)) - responseRecorder := serveTestRequest(rapiServer, request) - - if isExpectedPath(u.Path, invoke.ID, isGetMethod) { - assertExpectedPathResponseCode(t, responseRecorder.Code, rawPath) - } else { - assertUnexpectedPathResponseCode(t, responseRecorder.Code, rawPath) - } - }) -} - -func FuzzInitErrorHandler(f *testing.F) { - addRuntimeFunctionErrorJSONCorpus(f) - - f.Fuzz(func(t *testing.T, errorBody []byte, errTypeHeader []byte) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - - rapiServer := makeRapiServer(flowTest) - - target := makeTargetURL("/runtime/init/error", version20180601) - request := httptest.NewRequest("POST", target, bytes.NewReader(errorBody)) - request = appctx.RequestWithAppCtx(request, flowTest.AppCtx) - request.Header.Set("Lambda-Runtime-Function-Error-Type", string(errTypeHeader)) - - responseRecorder := serveTestRequest(rapiServer, request) - - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - assert.JSONEq(t, "{\"status\":\"OK\"}\n", responseRecorder.Body.String()) - assert.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) - - assertErrorResponsePersists(t, errorBody, errTypeHeader, flowTest) - }) -} - -func FuzzInvocationResponseHandler(f *testing.F) { - f.Add([]byte("SUCCESS"), []byte("application/json"), []byte("streaming")) - f.Add([]byte(strings.Repeat("a", interop.MaxPayloadSize+1)), []byte("application/json"), []byte("streaming")) - - f.Fuzz(func(t *testing.T, responseBody []byte, contentType []byte, responseMode []byte) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - - invoke := createDummyInvoke() - flowTest.ConfigureForInvoke(context.Background(), invoke) - - rapiServer := makeRapiServer(flowTest) - - target := makeTargetURL(fmt.Sprintf("/runtime/invocation/%s/response", invoke.ID), version20180601) - request := httptest.NewRequest("POST", target, bytes.NewReader(responseBody)) - request.Header.Set("Content-Type", string(contentType)) - request.Header.Set("Lambda-Runtime-Function-Response-Mode", string(responseMode)) - - request = appctx.RequestWithAppCtx(request, flowTest.AppCtx) - - responseRecorder := serveTestRequest(rapiServer, request) - - if !isValidResponseMode(responseMode) { - assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) - return - } - - if len(responseBody) > interop.MaxPayloadSize { - assertInvocationResponseTooLarge(t, responseRecorder, flowTest, responseBody) - } else { - assertInvocationResponseAccepted(t, responseRecorder, flowTest, responseBody, contentType) - } - }) -} - -func FuzzInvocationErrorHandler(f *testing.F) { - addRuntimeFunctionErrorJSONCorpus(f) - - f.Fuzz(func(t *testing.T, errorBody []byte, errTypeHeader []byte) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.Runtime.Ready() - appCtx := flowTest.AppCtx - - invoke := createDummyInvoke() - flowTest.ConfigureForInvoke(context.Background(), invoke) - - rapiServer := makeRapiServer(flowTest) - - target := makeTargetURL(fmt.Sprintf("/runtime/invocation/%s/error", invoke.ID), version20180601) - request := httptest.NewRequest("POST", target, bytes.NewReader(errorBody)) - request = appctx.RequestWithAppCtx(request, appCtx) - - request.Header.Set("Lambda-Runtime-Function-Error-Type", string(errTypeHeader)) - - responseRecorder := serveTestRequest(rapiServer, request) - - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - assert.JSONEq(t, "{\"status\":\"OK\"}\n", responseRecorder.Body.String()) - assert.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) - - assertErrorResponsePersists(t, errorBody, errTypeHeader, flowTest) - }) -} - -func FuzzRestoreErrorHandler(f *testing.F) { - f.Fuzz(func(t *testing.T, errorBody []byte, errTypeHeader []byte) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForRestoring() - - appctx.StoreInitType(flowTest.AppCtx, true) - - rapiServer := makeRapiServer(flowTest) - - target := makeTargetURL("/runtime/restore/error", version20180601) - request := httptest.NewRequest("POST", target, bytes.NewReader(errorBody)) - request = appctx.RequestWithAppCtx(request, flowTest.AppCtx) - - request.Header.Set("Lambda-Runtime-Function-Error-Type", string(errTypeHeader)) - - responseRecorder := serveTestRequest(rapiServer, request) - - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - assert.JSONEq(t, "{\"status\":\"OK\"}\n", responseRecorder.Body.String()) - assert.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type")) - }) -} - -func makeRapiServer(flowTest *testdata.FlowTest) *Server { - return NewServer( - "127.0.0.1", - 0, - flowTest.AppCtx, - flowTest.RegistrationService, - flowTest.RenderingService, - true, - &telemetry.NoOpSubscriptionAPI{}, - flowTest.TelemetrySubscription, - flowTest.CredentialsService, - ) -} - -func createDummyInvoke() *interop.Invoke { - return &interop.Invoke{ - ID: "InvocationID1", - Payload: strings.NewReader("Payload1"), - } -} - -func makeTargetURL(path string, apiVersion string) string { - protocol := "http" - endpoint := os.Getenv("AWS_LAMBDA_RUNTIME_API") - baseurl := fmt.Sprintf("%s://%s%s", protocol, endpoint, apiVersion) - - return fmt.Sprintf("%s%s", baseurl, path) -} - -func serveTestRequest(rapiServer *Server, request *http.Request) *httptest.ResponseRecorder { - responseRecorder := httptest.NewRecorder() - rapiServer.server.Handler.ServeHTTP(responseRecorder, request) - log.Printf("test(%v) = %v", request.URL, responseRecorder.Code) - - return responseRecorder -} - -func addSeedCorpusURLTargets(f *testing.F) { - invoke := createDummyInvoke() - errStruct := runtimeFunctionErrStruct{ - ErrorMessage: "error occurred", - ErrorType: "Runtime.UnknownReason", - StackTrace: []string{}, - } - errJSON, _ := json.Marshal(errStruct) - f.Add(makeTargetURL("/runtime/init/error", version20180601), errJSON, false) - f.Add(makeTargetURL("/runtime/invocation/next", version20180601), []byte{}, true) - f.Add(makeTargetURL(fmt.Sprintf("/runtime/invocation/%s/response", invoke.ID), version20180601), []byte("SUCCESS"), false) - f.Add(makeTargetURL(fmt.Sprintf("/runtime/invocation/%s/error", invoke.ID), version20180601), errJSON, false) - f.Add(makeTargetURL("/runtime/restore/next", version20180601), []byte{}, true) - f.Add(makeTargetURL("/runtime/restore/error", version20180601), errJSON, false) - - f.Add(makeTargetURL("/extension/register", version20200101), []byte("register"), false) - f.Add(makeTargetURL("/extension/event/next", version20200101), []byte("next"), true) - f.Add(makeTargetURL("/extension/init/error", version20200101), []byte("init error"), false) - f.Add(makeTargetURL("/extension/exit/error", version20200101), []byte("exit error"), false) -} - -func addRuntimeFunctionErrorJSONCorpus(f *testing.F) { - runtimeFuncErr := runtimeFunctionErrStruct{ - ErrorMessage: "error", - ErrorType: "Runtime.Unknown", - StackTrace: []string{}, - } - data, _ := json.Marshal(runtimeFuncErr) - - f.Add(data, []byte("Runtime.Unknown")) -} - -func isExpectedPath(path string, invokeID string, isGetMethod bool) bool { - expectedPaths := make(map[string]bool) - - expectedPaths[fmt.Sprintf("%s/runtime/init/error", version20180601)] = false - expectedPaths[fmt.Sprintf("%s/runtime/invocation/next", version20180601)] = true - expectedPaths[fmt.Sprintf("%s/runtime/invocation/%s/response", version20180601, invokeID)] = false - expectedPaths[fmt.Sprintf("%s/runtime/invocation/%s/error", version20180601, invokeID)] = false - expectedPaths[fmt.Sprintf("%s/runtime/restore/next", version20180601)] = true - expectedPaths[fmt.Sprintf("%s/runtime/restore/error", version20180601)] = false - - expectedPaths[fmt.Sprintf("%s/extension/register", version20200101)] = false - expectedPaths[fmt.Sprintf("%s/extension/event/next", version20200101)] = true - expectedPaths[fmt.Sprintf("%s/extension/init/error", version20200101)] = false - expectedPaths[fmt.Sprintf("%s/extension/exit/error", version20200101)] = false - - val, found := expectedPaths[path] - return found && (val == isGetMethod) -} - -func parseToURLStruct(rawPath string) (*url.URL, error) { - invalidChars := regexp.MustCompile(`[ %]+`) - if invalidChars.MatchString(rawPath) { - return nil, errors.New("url must not contain spaces or %") - } - - for _, r := range rawPath { - if !unicode.IsGraphic(r) { - return nil, errors.New("url contains non-graphic runes") - } - } - - if _, err := url.ParseRequestURI(rawPath); err != nil { - return nil, err - } - - u, err := url.Parse(rawPath) - if err != nil { - return nil, err - } - - if u.Scheme == "" { - return nil, errors.New("blank url scheme") - } - - return u, nil -} - -func assertInvocationResponseAccepted(t *testing.T, responseRecorder *httptest.ResponseRecorder, - flowTest *testdata.FlowTest, responseBody []byte, contentType []byte) { - assert.Equal(t, http.StatusAccepted, responseRecorder.Code, - "Handler returned wrong status code: got %v expected %v", - responseRecorder.Code, http.StatusAccepted) - - expectedAPIResponse := "{\"status\":\"OK\"}\n" - body, err := io.ReadAll(responseRecorder.Body) - assert.NoError(t, err) - assert.JSONEq(t, expectedAPIResponse, string(body)) - - response := flowTest.InteropServer.Response - assert.NotNil(t, response) - assert.Nil(t, flowTest.InteropServer.ErrorResponse) - - assert.Equal(t, string(contentType), flowTest.InteropServer.ResponseContentType) - - assert.Equal(t, responseBody, response, - "Persisted response data in app context must match the submitted.") -} - -func assertInvocationResponseTooLarge(t *testing.T, responseRecorder *httptest.ResponseRecorder, flowTest *testdata.FlowTest, responseBody []byte) { - assert.Equal(t, http.StatusRequestEntityTooLarge, responseRecorder.Code, "Handler returned wrong status code: got %v expected %v", - responseRecorder.Code, http.StatusRequestEntityTooLarge) - - expectedAPIResponse := fmt.Sprintf("{\"errorMessage\":\"Exceeded maximum allowed payload size (%d bytes).\",\"errorType\":\"RequestEntityTooLarge\"}\n", interop.MaxPayloadSize) - body, err := io.ReadAll(responseRecorder.Body) - assert.NoError(t, err) - assert.JSONEq(t, expectedAPIResponse, string(body)) - - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) - assert.Nil(t, flowTest.InteropServer.Response) - assert.Equal(t, fatalerror.FunctionOversizedResponse, errorResponse.FunctionError.Type) - assert.Equal(t, fmt.Sprintf("Response payload size (%v bytes) exceeded maximum allowed payload size (6291556 bytes).", len(responseBody)), errorResponse.FunctionError.Message) - - var errorPayload map[string]interface{} - assert.NoError(t, json.Unmarshal(errorResponse.Payload, &errorPayload)) - assert.Equal(t, string(errorResponse.FunctionError.Type), errorPayload["errorType"]) - assert.Equal(t, errorResponse.FunctionError.Message, errorPayload["errorMessage"]) -} - -func assertErrorResponsePersists(t *testing.T, errorBody []byte, errTypeHeader []byte, flowTest *testdata.FlowTest) { - errorResponse := flowTest.InteropServer.ErrorResponse - assert.NotNil(t, errorResponse) - assert.Nil(t, flowTest.InteropServer.Response) - - var runtimeFunctionErr runtimeFunctionErrStruct - var expectedErrMsg string - - // If input payload is a valid function error json object, - // assert that the error message persisted in the response - err := json.Unmarshal(errorBody, &runtimeFunctionErr) - if err != nil { - expectedErrMsg = runtimeFunctionErr.ErrorMessage - } - assert.Equal(t, expectedErrMsg, errorResponse.FunctionError.Message) - - // If input error type is valid (within the allow-listed value, - // assert that the error type persisted in the response - expectedErrType := fatalerror.GetValidRuntimeOrFunctionErrorType(string(errTypeHeader)) - assert.Equal(t, expectedErrType, errorResponse.FunctionError.Type) - - assert.Equal(t, errorBody, errorResponse.Payload) -} - -func isValidResponseMode(responseMode []byte) bool { - responseModeStr := string(responseMode) - return responseModeStr == "streaming" || - responseModeStr == "" -} - -func assertExpectedPathResponseCode(t *testing.T, code int, target string) { - if !(code == http.StatusOK || - code == http.StatusAccepted || - code == http.StatusForbidden) { - t.Errorf("Unexpected status code (%v) for target (%v)", code, target) - } -} - -func assertUnexpectedPathResponseCode(t *testing.T, code int, target string) { - if !(code == http.StatusNotFound || - code == http.StatusMethodNotAllowed || - code == http.StatusBadRequest) { - t.Errorf("Unexpected status code (%v) for target (%v)", code, target) - } -} diff --git a/lambda/rapi/rendering/doc.go b/lambda/rapi/rendering/doc.go deleted file mode 100644 index bc359a1..0000000 --- a/lambda/rapi/rendering/doc.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -/* -Package rendering provides stateful event rendering service. - -State of the rendering service should be set from the main event dispatch thread -prior to releasing threads that are registered for the event. - -Example of INVOKE event: - -[thread] // suspended in READY state - -[main] // set renderer for INVOKE event -[main] renderingService.SetRenderer(rendering.NewInvokeRenderer()) -[main] // release threads registered for INVOKE event - -[thread] // receives INVOKE event -*/ -package rendering diff --git a/lambda/rapi/rendering/render_error.go b/lambda/rapi/rendering/render_error.go deleted file mode 100644 index 151e606..0000000 --- a/lambda/rapi/rendering/render_error.go +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rendering - -import ( - "fmt" - "net/http" - - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi/model" -) - -// RenderForbiddenWithTypeMsg method for rendering error response -func RenderForbiddenWithTypeMsg(w http.ResponseWriter, r *http.Request, errorType string, format string, args ...interface{}) { - if err := RenderJSON(http.StatusForbidden, w, r, &model.ErrorResponse{ - ErrorType: errorType, - ErrorMessage: fmt.Sprintf(format, args...), - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderInternalServerError method for rendering error response -func RenderInternalServerError(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusInternalServerError, w, r, &model.ErrorResponse{ - ErrorMessage: "Internal Server Error", - ErrorType: ErrorTypeInternalServerError, - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderRequestEntityTooLarge method for rendering error response -func RenderRequestEntityTooLarge(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusRequestEntityTooLarge, w, r, &model.ErrorResponse{ - ErrorMessage: fmt.Sprintf("Exceeded maximum allowed payload size (%d bytes).", interop.MaxPayloadSize), - ErrorType: ErrorTypeRequestEntityTooLarge, - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderTruncatedHTTPRequestError method for rendering error response -func RenderTruncatedHTTPRequestError(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ - ErrorMessage: "HTTP request detected as truncated", - ErrorType: ErrorTypeTruncatedHTTPRequest, - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderInvalidRequestID renders invalid request ID error response -func RenderInvalidRequestID(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ - ErrorMessage: "Invalid request ID", - ErrorType: "InvalidRequestID", - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderInvalidFunctionResponseMode renders invalid function response mode response -func RenderInvalidFunctionResponseMode(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusBadRequest, w, r, &model.ErrorResponse{ - ErrorMessage: "Invalid function response mode", - ErrorType: "InvalidFunctionResponseMode", - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -// RenderInteropError is a convenience method for interpreting interop errors -func RenderInteropError(writer http.ResponseWriter, request *http.Request, err error) { - if err == interop.ErrInvalidInvokeID || err == interop.ErrResponseSent { - RenderInvalidRequestID(writer, request) - } else { - log.Panic(err) - } -} diff --git a/lambda/rapi/rendering/render_json.go b/lambda/rapi/rendering/render_json.go deleted file mode 100644 index 1afbfe8..0000000 --- a/lambda/rapi/rendering/render_json.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rendering - -import ( - "bytes" - "encoding/json" - "net/http" - - log "github.com/sirupsen/logrus" -) - -// RenderJSON: -// - marshals 'v' to JSON, automatically escaping HTML -// - sets the Content-Type as application/json -// - sets the HTTP response status code -// - returns an error if it occurred before writing to response -// TODO: r *http.Request is not used, remove it -func RenderJSON(status int, w http.ResponseWriter, r *http.Request, v interface{}) error { - buf := &bytes.Buffer{} - enc := json.NewEncoder(buf) - enc.SetEscapeHTML(true) - if err := enc.Encode(v); err != nil { - return err - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - if _, err := w.Write(buf.Bytes()); err != nil { - log.WithError(err).Warn("Error while writing response body") - } - - return nil -} diff --git a/lambda/rapi/rendering/rendering.go b/lambda/rapi/rendering/rendering.go deleted file mode 100644 index 08de1e3..0000000 --- a/lambda/rapi/rendering/rendering.go +++ /dev/null @@ -1,302 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -// LOCALSTACK CHANGES 2024-02-13: casting of MaxPayloadSize - -package rendering - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "io" - "net/http" - "strconv" - "sync" - "time" - - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapi/model" - - "github.com/google/uuid" - log "github.com/sirupsen/logrus" -) - -const ( - // ErrorTypeInternalServerError error type for internal server error - ErrorTypeInternalServerError = "InternalServerError" - // ErrorTypeInvalidStateTransition error type for invalid state transition - ErrorTypeInvalidStateTransition = "InvalidStateTransition" - // ErrorTypeInvalidRequestID error type for invalid request ID error - ErrorTypeInvalidRequestID = "InvalidRequestID" - // ErrorTypeRequestEntityTooLarge error type for payload too large - ErrorTypeRequestEntityTooLarge = "RequestEntityTooLarge" - // ErrorTypeTruncatedHTTPRequest error type for truncated HTTP request - ErrorTypeTruncatedHTTPRequest = "TruncatedHTTPRequest" -) - -// ErrRenderingServiceStateNotSet returned when state not set -var ErrRenderingServiceStateNotSet = errors.New("EventRenderingService state not set") - -// RendererState is renderer object state. -type RendererState interface { - RenderAgentEvent(w http.ResponseWriter, r *http.Request) error - RenderRuntimeEvent(w http.ResponseWriter, r *http.Request) error -} - -// EventRenderingService is a state machine for rendering runtime and agent API responses. -type EventRenderingService struct { - mutex *sync.RWMutex - currentState RendererState -} - -// NewRenderingService returns new EventRenderingService. -func NewRenderingService() *EventRenderingService { - return &EventRenderingService{ - mutex: &sync.RWMutex{}, - } -} - -// SetRenderer set current state -func (s *EventRenderingService) SetRenderer(state RendererState) { - s.mutex.Lock() - defer s.mutex.Unlock() - s.currentState = state -} - -// RenderAgentEvent delegates to state implementation. -func (s *EventRenderingService) RenderAgentEvent(w http.ResponseWriter, r *http.Request) error { - s.mutex.RLock() - defer s.mutex.RUnlock() - if s.currentState == nil { - return ErrRenderingServiceStateNotSet - } - return s.currentState.RenderAgentEvent(w, r) -} - -// RenderRuntimeEvent delegates to state implementation. -func (s *EventRenderingService) RenderRuntimeEvent(w http.ResponseWriter, r *http.Request) error { - s.mutex.RLock() - defer s.mutex.RUnlock() - if s.currentState == nil { - return ErrRenderingServiceStateNotSet - } - return s.currentState.RenderRuntimeEvent(w, r) -} - -type RestoreRenderer struct{} - -func NewRestoreRenderer() *RestoreRenderer { - return &RestoreRenderer{} -} - -func (s *RestoreRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request *http.Request) error { - writer.WriteHeader(http.StatusOK) - return nil -} - -func (s *RestoreRenderer) RenderAgentEvent(writer http.ResponseWriter, request *http.Request) error { - return nil -} - -// InvokeRendererMetrics contains metrics of invoke request -type InvokeRendererMetrics struct { - ReadTime time.Duration - SizeBytes int -} - -// InvokeRenderer knows how to render invoke event. -type InvokeRenderer struct { - ctx context.Context - invoke *interop.Invoke - tracingHeaderParser func(context.Context) string - requestBuffer *bytes.Buffer - requestMutex sync.Mutex - metrics InvokeRendererMetrics -} - -// NewInvokeRenderer returns new invoke event renderer -func NewInvokeRenderer(ctx context.Context, invoke *interop.Invoke, requestBuffer *bytes.Buffer, traceParser func(context.Context) string) *InvokeRenderer { - requestBuffer.Reset() // clear request buffer, since this can be reused across invokes - return &InvokeRenderer{ - invoke: invoke, - ctx: ctx, - tracingHeaderParser: traceParser, - requestBuffer: requestBuffer, - requestMutex: sync.Mutex{}, - } -} - -// newAgentInvokeEvent forms a new AgentInvokeEvent from INVOKE request -func newAgentInvokeEvent(req *interop.Invoke) (*model.AgentInvokeEvent, error) { - deadlineMono, err := strconv.ParseInt(req.DeadlineNs, 10, 64) - if err != nil { - return nil, err - } - deadline := metering.MonoToEpoch(deadlineMono) / int64(time.Millisecond) - return &model.AgentInvokeEvent{ - AgentEvent: &model.AgentEvent{ - EventType: "INVOKE", - DeadlineMs: deadline, - }, - RequestID: req.ID, - InvokedFunctionArn: req.InvokedFunctionArn, - Tracing: model.NewXRayTracing(req.TraceID), - }, nil -} - -// RenderAgentEvent renders invoke event json for agent. -func (s *InvokeRenderer) RenderAgentEvent(writer http.ResponseWriter, request *http.Request) error { - event, err := newAgentInvokeEvent(s.invoke) - if err != nil { - return err - } - - bytes, err := json.Marshal(event) - if err != nil { - return err - } - - eventID := uuid.New() - headers := writer.Header() - headers.Set("Lambda-Extension-Event-Identifier", eventID.String()) - headers.Set("Content-Type", "application/json") - writer.WriteHeader(http.StatusOK) - - if _, err := writer.Write(bytes); err != nil { - return err - } - return nil -} - -func (s *InvokeRenderer) bufferInvokeRequest() error { - s.requestMutex.Lock() - defer s.requestMutex.Unlock() - var err error = nil - if s.requestBuffer.Len() == 0 { - reader := io.LimitReader(s.invoke.Payload, int64(interop.MaxPayloadSize)) - start := time.Now() - _, err = s.requestBuffer.ReadFrom(reader) - s.metrics = InvokeRendererMetrics{ - ReadTime: time.Since(start), - SizeBytes: s.requestBuffer.Len(), - } - } - return err -} - -// RenderRuntimeEvent renders invoke payload for runtime. -func (s *InvokeRenderer) RenderRuntimeEvent(writer http.ResponseWriter, request *http.Request) error { - invoke := s.invoke - customerTraceID := s.tracingHeaderParser(s.ctx) - - cognitoIdentityJSON := "" - if len(invoke.CognitoIdentityID) != 0 || len(invoke.CognitoIdentityPoolID) != 0 { - cognitoJSON, err := json.Marshal(model.CognitoIdentity{ - CognitoIdentityID: invoke.CognitoIdentityID, - CognitoIdentityPoolID: invoke.CognitoIdentityPoolID, - }) - if err != nil { - return err - } - - cognitoIdentityJSON = string(cognitoJSON) - } - - var deadlineHeader string - if t, err := strconv.ParseInt(invoke.DeadlineNs, 10, 64); err == nil { - deadlineHeader = strconv.FormatInt(metering.MonoToEpoch(t)/int64(time.Millisecond), 10) - } else { - log.WithError(err).Warn("Failed to compute deadline header") - } - - renderInvokeHeaders(writer, invoke.ID, customerTraceID, invoke.ClientContext, - cognitoIdentityJSON, invoke.InvokedFunctionArn, deadlineHeader, invoke.ContentType) - - if invoke.Payload != nil { - if err := s.bufferInvokeRequest(); err != nil { - return err - } - _, err := writer.Write(s.requestBuffer.Bytes()) - return err - } - - return nil -} - -func (s *InvokeRenderer) GetMetrics() InvokeRendererMetrics { - s.requestMutex.Lock() - defer s.requestMutex.Unlock() - return s.metrics -} - -// ShutdownRenderer renderer for shutdown event. -type ShutdownRenderer struct { - AgentEvent model.AgentShutdownEvent -} - -// RenderAgentEvent renders shutdown event for agent. -func (s *ShutdownRenderer) RenderAgentEvent(w http.ResponseWriter, r *http.Request) error { - bytes, err := json.Marshal(s.AgentEvent) - if err != nil { - return err - } - if _, err := w.Write(bytes); err != nil { - return err - } - return nil -} - -// RenderRuntimeEvent renders shutdown event for runtime. -func (s *ShutdownRenderer) RenderRuntimeEvent(w http.ResponseWriter, r *http.Request) error { - panic("We should SIGTERM runtime") -} - -func renderInvokeHeaders(writer http.ResponseWriter, invokeID string, customerTraceID string, clientContext string, - cognitoIdentity string, invokedFunctionArn string, deadlineMs string, contentType string) { - - setHeaderIfNotEmpty := func(headers http.Header, key string, value string) { - if value != "" { - headers.Set(key, value) - } - } - - headers := writer.Header() - setHeaderIfNotEmpty(headers, "Lambda-Runtime-Aws-Request-Id", invokeID) - setHeaderIfNotEmpty(headers, "Lambda-Runtime-Trace-Id", customerTraceID) - setHeaderIfNotEmpty(headers, "Lambda-Runtime-Client-Context", clientContext) - setHeaderIfNotEmpty(headers, "Lambda-Runtime-Cognito-Identity", cognitoIdentity) - setHeaderIfNotEmpty(headers, "Lambda-Runtime-Invoked-Function-Arn", invokedFunctionArn) - setHeaderIfNotEmpty(headers, "Lambda-Runtime-Deadline-Ms", deadlineMs) - if contentType == "" { - contentType = "application/json" - } - headers.Set("Content-Type", contentType) - writer.WriteHeader(http.StatusOK) -} - -// RenderRuntimeLogsResponse renders response from Telemetry API -func RenderRuntimeLogsResponse(w http.ResponseWriter, respBody []byte, status int, headers map[string][]string) error { - respHeaders := w.Header() - for k, vals := range headers { - for _, v := range vals { - respHeaders.Add(k, v) - } - } - - w.WriteHeader(status) - - _, err := w.Write(respBody) - return err -} - -// RenderAccepted method for rendering accepted status response -func RenderAccepted(w http.ResponseWriter, r *http.Request) { - if err := RenderJSON(http.StatusAccepted, w, r, &model.StatusResponse{ - Status: "OK", - }); err != nil { - log.WithError(err).Warn("Error while rendering response") - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} diff --git a/lambda/rapi/router.go b/lambda/rapi/router.go deleted file mode 100644 index dc036bc..0000000 --- a/lambda/rapi/router.go +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapi - -import ( - "net/http" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/rapi/handler" - "go.amzn.com/lambda/rapi/middleware" - "go.amzn.com/lambda/telemetry" - - "github.com/go-chi/chi" - - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/rapi/rendering" -) - -// NewRouter returns a new instance of chi router implementing -// Runtime API specification. -func NewRouter(appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { - - router := chi.NewRouter() - router.Use(middleware.AppCtxMiddleware(appCtx)) - router.Use(middleware.AccessLogMiddleware()) - router.Use(middleware.RuntimeReleaseMiddleware()) - - // To respect Hyrum's Law, keeping /ping API even though - // we no longer use it ourselves. - // http://www.hyrumslaw.com/ - router.Get("/ping", handler.NewPingHandler().ServeHTTP) - - router.Get("/runtime/invocation/next", - handler.NewInvocationNextHandler(registrationService, renderingService).ServeHTTP) - - // Note, request validation must happen before state - // transition. State machine transitions are irreversible - // at the moment. - router.Post("/runtime/invocation/{awsrequestid}/response", - middleware.AwsRequestIDValidator( - handler.NewInvocationResponseHandler(registrationService)).ServeHTTP) - - router.Post("/runtime/invocation/{awsrequestid}/error", - middleware.AwsRequestIDValidator( - handler.NewInvocationErrorHandler(registrationService)).ServeHTTP) - - router.Post("/runtime/init/error", handler.NewInitErrorHandler(registrationService).ServeHTTP) - - if appctx.LoadInitType(appCtx) == appctx.InitCaching { - router.Get("/runtime/restore/next", handler.NewRestoreNextHandler(registrationService, renderingService).ServeHTTP) - router.Post("/runtime/restore/error", handler.NewRestoreErrorHandler(registrationService).ServeHTTP) - } - - return router -} - -// ExtensionsRouter returns a new instance of chi router implementing -// Extensions Runtime API specification. -func ExtensionsRouter(appCtx appctx.ApplicationContext, registrationService core.RegistrationService, renderingService *rendering.EventRenderingService) http.Handler { - router := chi.NewRouter() - router.Use(middleware.AccessLogMiddleware()) - router.Use(middleware.AllowIfExtensionsEnabled) - router.Use(middleware.AppCtxMiddleware(appCtx)) - - registerHandler := handler.NewAgentRegisterHandler(registrationService) - router.Post("/extension/register", - registerHandler.ServeHTTP) - - router.Get("/extension/event/next", - middleware.AgentUniqueIdentifierHeaderValidator( - handler.NewAgentNextHandler(registrationService, renderingService)).ServeHTTP) - - router.Post("/extension/init/error", - middleware.AgentUniqueIdentifierHeaderValidator( - handler.NewAgentInitErrorHandler(registrationService)).ServeHTTP) - - router.Post("/extension/exit/error", - middleware.AgentUniqueIdentifierHeaderValidator( - handler.NewAgentExitErrorHandler(registrationService)).ServeHTTP) - - return router -} - -// LogsAPIRouter returns a new instance of chi router implementing -// Logs API specification. -func LogsAPIRouter(registrationService core.RegistrationService, logsSubscriptionAPI telemetry.SubscriptionAPI) http.Handler { - router := chi.NewRouter() - router.Use(middleware.AccessLogMiddleware()) - router.Use(middleware.AllowIfExtensionsEnabled) - - router.Put("/logs", - middleware.AgentUniqueIdentifierHeaderValidator( - handler.NewRuntimeTelemetrySubscriptionHandler(registrationService, logsSubscriptionAPI)).ServeHTTP) - - return router -} - -// LogsAPIStubRouter returns a new instance of chi router implementing -// a stub of Logs API that always returns a non-committal response to -// prevent customer code from crashing when Logs API is disabled locally -func LogsAPIStubRouter() http.Handler { - router := chi.NewRouter() - - router.Put("/logs", handler.NewRuntimeLogsAPIStubHandler().ServeHTTP) - - return router -} - -// TelemetryRouter returns a new instance of chi router implementing -// Telemetry API specification. -func TelemetryAPIRouter(registrationService core.RegistrationService, telemetrySubscriptionAPI telemetry.SubscriptionAPI) http.Handler { - router := chi.NewRouter() - router.Use(middleware.AccessLogMiddleware()) - router.Use(middleware.AllowIfExtensionsEnabled) - - router.Put("/telemetry", - middleware.AgentUniqueIdentifierHeaderValidator( - handler.NewRuntimeTelemetrySubscriptionHandler(registrationService, telemetrySubscriptionAPI)).ServeHTTP) - - return router -} - -// TelemetryStubRouter returns a new instance of chi router implementing -// a stub of Telemetry API that always returns a non-committal response to -// prevent customer code from crashing when Telemetry API is disabled locally -func TelemetryAPIStubRouter() http.Handler { - router := chi.NewRouter() - - router.Put("/telemetry", handler.NewRuntimeTelemetryAPIStubHandler().ServeHTTP) - - return router -} - -func CredentialsAPIRouter(credentialsService core.CredentialsService) http.Handler { - router := chi.NewRouter() - - router.Get("/credentials", handler.NewCredentialsHandler(credentialsService).ServeHTTP) - - return router -} diff --git a/lambda/rapi/router_test.go b/lambda/rapi/router_test.go deleted file mode 100644 index 276fa53..0000000 --- a/lambda/rapi/router_test.go +++ /dev/null @@ -1,353 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapi - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi/model" - - "go.amzn.com/lambda/testdata" -) - -func createInvoke(id string) *interop.Invoke { - return &interop.Invoke{ - ID: id, - InvokedFunctionArn: "arn::dummy:Function", - Payload: strings.NewReader("{\"invoke\":\"" + id + "\"}"), - DeadlineNs: "123456", - } -} - -// Make a test request -func makeTestRequest(t *testing.T, router http.Handler, request *http.Request) *httptest.ResponseRecorder { - responseRecorder := httptest.NewRecorder() - router.ServeHTTP(responseRecorder, request) - t.Logf("test(%v) = %v", request.URL, responseRecorder.Code) - return responseRecorder -} - -// Make a test request in a benchmark -func makeBenchRequest(b *testing.B, router http.Handler, request *http.Request) *httptest.ResponseRecorder { - responseRecorder := httptest.NewRecorder() - b.StartTimer() - router.ServeHTTP(responseRecorder, request) - b.StopTimer() - return responseRecorder -} - -// Verify response error type -func assertResponseErrorType(t *testing.T, expectedErrorType string, response *httptest.ResponseRecorder) { - errResp := model.ErrorResponse{} - err := json.Unmarshal(response.Body.Bytes(), &errResp) - assert.Nil(t, err) - assert.Equal(t, expectedErrorType, errResp.ErrorType) -} - -// TestAcceptXML tests that server response is always -// rendered as JSON, regardless of the value provided -// in "Accept" header. -// -// When using render.Render(...), rendering function -// would attempt to render response using content type -// specified in the "Accept" header. -// -// The purpose of this test is to confirm that RAPID -// renders all server responses as application/json. -func TestAcceptXML(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - responseRecorder := httptest.NewRecorder() - request := httptest.NewRequest("POST", "/runtime/invocation/x-y-z/error", bytes.NewReader([]byte(""))) - // Tell server that client side accepts "application/xml". - request.Header.Add("Accept", "application/xml") - // Serve request. - router.ServeHTTP(responseRecorder, request) - v := &model.ErrorResponse{} - // Expected response is 403 transition is not allowed, rendered as JSON. - err := json.Unmarshal(responseRecorder.Body.Bytes(), v) - if err != nil { - t.Error("Expected JSON document, received: ", responseRecorder.Body.String()) - } - assert.Equal(t, "InvalidRequestID", v.ErrorType) - assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) -} - -// Verify that unsupported methods return 404 -func Test404PageNotFound(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/unsupported", bytes.NewReader([]byte("")))) - assert.Equal(t, http.StatusNotFound, responseRecorder.Code) - assert.Equal(t, "404 page not found\n", responseRecorder.Body.String()) -} - -func Test405MethodNotAllowed(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("DELETE", "/runtime/invocation/ABC/error", bytes.NewReader([]byte("")))) - assert.Equal(t, http.StatusMethodNotAllowed, responseRecorder.Code) -} - -func TestInitErrorAccepted(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/init/error", bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) -} - -func TestInitErrorForbidden(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/init/error", bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) -} - -func TestInvokeResponseAccepted(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/InvokeA/response", bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) -} - -func TestInvokeErrorResponseAccepted(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/InvokeA/error", bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) -} - -func TestInvokeNextTwice(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - assert.Equal(t, http.StatusOK, responseRecorder.Code) -} - -func TestInvokeResponseInvalidRequestID(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/XYZ/response", bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) - assertResponseErrorType(t, "InvalidRequestID", responseRecorder) -} - -func TestInvokeErrorResponseInvalidRequestID(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/XYZ/error", bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) - assertResponseErrorType(t, "InvalidRequestID", responseRecorder) -} - -func TestInvokeResponseTwice(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/ABC/response", bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/ABC/response", bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - assertResponseErrorType(t, "InvalidStateTransition", responseRecorder) -} - -func TestInvokeErrorResponseTwice(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/ABC/error", bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/ABC/error", bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - assertResponseErrorType(t, "InvalidStateTransition", responseRecorder) -} - -func TestInvokeResponseAfterErrorResponse(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/ABC/error", bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/ABC/response", bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - assertResponseErrorType(t, "InvalidStateTransition", responseRecorder) -} - -func TestInvokeErrorResponseAfterResponse(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - flowTest.ConfigureForInvoke(context.Background(), createInvoke("ABC")) - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/ABC/response", bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/ABC/error", bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - assertResponseErrorType(t, "InvalidStateTransition", responseRecorder) -} - -func TestMoreThanOneInvoke(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - var responseRecorder *httptest.ResponseRecorder - for _, id := range []string{"A", "B", "C"} { - flowTest.ConfigureForInvoke(context.Background(), createInvoke(id)) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - assert.Equal(t, http.StatusOK, responseRecorder.Code) - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", fmt.Sprintf("/runtime/invocation/%s/response", id), bytes.NewReader([]byte("{}")))) - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - } -} - -func TestInitCachingAPIDisabledForPlainInit(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - var responseRecorder *httptest.ResponseRecorder - - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/restore/next", nil)) - assert.Equal(t, http.StatusNotFound, responseRecorder.Code) - - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/credentials", nil)) - assert.Equal(t, http.StatusNotFound, responseRecorder.Code) -} - -func benchmarkInvokeResponse(b *testing.B, payload []byte) { - b.StopTimer() - b.ResetTimer() // does not restart timer, only resets state - b.ReportAllocs() - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - for i := 0; i < b.N; i++ { - id := uuid.New().String() - flowTest.ConfigureForInvoke(context.Background(), createInvoke(id)) - makeBenchRequest(b, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - makeBenchRequest(b, router, httptest.NewRequest("POST", fmt.Sprintf("/runtime/invocation/%s/response", id), bytes.NewReader(payload))) - } -} - -func BenchmarkInvokeResponseWithEmptyPayload(b *testing.B) { - benchmarkInvokeResponse(b, []byte("")) -} - -func BenchmarkInvokeResponseWith4KBPayload(b *testing.B) { - benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 4*1024)) -} - -func BenchmarkInvokeResponseWith512KBPayload(b *testing.B) { - benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 512*1024)) -} - -func BenchmarkInvokeResponseWith1MBPayload(b *testing.B) { - benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 1*1024*1024)) -} - -func BenchmarkInvokeResponseWith2MBPayload(b *testing.B) { - benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 2*1024*1024)) -} - -func BenchmarkInvokeResponseWith4MBPayload(b *testing.B) { - benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 4*1024*1024)) -} - -func BenchmarkInvokeResponseWith6MBPayload(b *testing.B) { - benchmarkInvokeResponse(b, bytes.Repeat([]byte("a"), 6*1024*1024)) -} - -func benchmarkInvokeRequest(b *testing.B, payload []byte) { - b.StopTimer() - b.ResetTimer() // does not restart timer, only resets state - b.ReportAllocs() - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - var requestBuf bytes.Buffer - for i := 0; i < b.N; i++ { - id := uuid.New().String() - ctx, invoke := context.Background(), createInvoke(id) - flowTest.ConfigureForInvoke(ctx, invoke) // set invoke ID and initialize barriers - flowTest.ConfigureInvokeRenderer(ctx, invoke, &requestBuf) // override invoke renderer to reuse buffer - makeBenchRequest(b, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - makeBenchRequest(b, router, httptest.NewRequest("POST", fmt.Sprintf("/runtime/invocation/%s/response", id), bytes.NewReader(payload))) - } -} - -func BenchmarkInvokeRequestWithEmptyPayload(b *testing.B) { - benchmarkInvokeRequest(b, []byte("")) -} - -func BenchmarkInvokeRequestWith4KBPayload(b *testing.B) { - benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 4*1024)) -} - -func BenchmarkInvokeRequestWith512KBPayload(b *testing.B) { - benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 512*1024)) -} - -func BenchmarkInvokeRequestWith1MBPayload(b *testing.B) { - benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 1*1024*1024)) -} - -func BenchmarkInvokeRequestWith2MBPayload(b *testing.B) { - benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 2*1024*1024)) -} - -func BenchmarkInvokeRequestWith4MBPayload(b *testing.B) { - benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 4*1024*1024)) -} - -func BenchmarkInvokeRequestWith6MBPayload(b *testing.B) { - benchmarkInvokeRequest(b, bytes.Repeat([]byte("a"), 6*1024*1024)) -} diff --git a/lambda/rapi/security_test.go b/lambda/rapi/security_test.go deleted file mode 100644 index 3f869d5..0000000 --- a/lambda/rapi/security_test.go +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapi - -import ( - "bytes" - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - - "go.amzn.com/lambda/testdata" -) - -// Verify that state machine will accept response and error with valid invoke id -func TestInvokeValidId(t *testing.T) { - - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - - flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) - - // Send /next to start Invoke A - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - - assert.Equal(t, http.StatusOK, responseRecorder.Code) - - // Send invocation response with correct Invoke Id - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/InvokeA/response", bytes.NewReader([]byte("{}")))) - - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - - flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeB")) - - // Send /next to start Invoke B - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - - assert.Equal(t, http.StatusOK, responseRecorder.Code) - - // Send invocation error with correct Invoke id - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/InvokeB/error", bytes.NewReader([]byte("{}")))) - - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) -} - -// All invoke responses must be validated to ensure they use the active Invoke request-id -// This is a Security requirement -func TestSecurityInvokeResponseBadRequestId(t *testing.T) { - - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - - flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) - - // Try to use the invoke id before next - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/InvokeA/response", bytes.NewReader([]byte("{}")))) - - // NOTE: InvalidStateTransition 403 - forbidden by the state machine - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - assertResponseErrorType(t, "InvalidStateTransition", responseRecorder) - - // Send /next to start Invoke A - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - - assert.Equal(t, http.StatusOK, responseRecorder.Code) - - // Send invocation response with invalid invoke id - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/InvokeZ/response", bytes.NewReader([]byte("{}")))) - - assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) - assertResponseErrorType(t, "InvalidRequestID", responseRecorder) - - // Send invocation response with correct Invoke Id - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/InvokeA/response", bytes.NewReader([]byte("{}")))) - - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - - flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeB")) - - // Send /next to start new Invoke - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - - assert.Equal(t, http.StatusOK, responseRecorder.Code) - - // Try to re-use the old invoke id - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/InvokeA/response", bytes.NewReader([]byte("{}")))) - - assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) - assertResponseErrorType(t, "InvalidRequestID", responseRecorder) -} - -// All invoke errors must be validated to ensure they use the active Invoke request-id -// This is a Security requirement -func TestSecurityInvokeErrorBadRequestId(t *testing.T) { - - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - router := NewRouter(flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService) - - flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeA")) - - // Try to use invoke id before next - responseRecorder := makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/InvokeA/error", bytes.NewReader([]byte("{}")))) - - // NOTE: InvalidStateTransition 403 - forbidden by the state machine - assert.Equal(t, http.StatusForbidden, responseRecorder.Code) - assertResponseErrorType(t, "InvalidStateTransition", responseRecorder) - - // Send /next to start Invoke A - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - - assert.Equal(t, http.StatusOK, responseRecorder.Code) - - // Send invocation response with invalid invoke id - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/InvokeZ/error", bytes.NewReader([]byte("{}")))) - - assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) - assertResponseErrorType(t, "InvalidRequestID", responseRecorder) - - // Send invocation error with correct Invoke Id - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/InvokeA/error", bytes.NewReader([]byte("{}")))) - - assert.Equal(t, http.StatusAccepted, responseRecorder.Code) - - flowTest.ConfigureForInvoke(context.Background(), createInvoke("InvokeB")) - - // Send /next to start Invoke B - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("GET", "/runtime/invocation/next", nil)) - - assert.Equal(t, http.StatusOK, responseRecorder.Code) - - // Try to re-use the previous invoke id - responseRecorder = makeTestRequest(t, router, httptest.NewRequest("POST", "/runtime/invocation/InvokeA/error", bytes.NewReader([]byte("{}")))) - - assert.Equal(t, http.StatusBadRequest, responseRecorder.Code) - assertResponseErrorType(t, "InvalidRequestID", responseRecorder) -} diff --git a/lambda/rapi/server.go b/lambda/rapi/server.go deleted file mode 100644 index b90f208..0000000 --- a/lambda/rapi/server.go +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -// LOCALSTACK CHANGES 2022-03-10: Chi logger middleware added -// LOCALSTACK CHANGES 2023-04-14: Replace Chi logger with the rapid AccessLogMiddleware based on Logrus - -package rapi - -import ( - "context" - "fmt" - "net" - "net/http" - - "github.com/go-chi/chi" - "go.amzn.com/lambda/appctx" - - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi/middleware" - "go.amzn.com/lambda/rapi/rendering" - "go.amzn.com/lambda/telemetry" - - log "github.com/sirupsen/logrus" -) - -const version20180601 = "/2018-06-01" -const version20200101 = "/2020-01-01" -const version20200815 = "/2020-08-15" -const version20210423 = "/2021-04-23" -const version20220701 = "/2022-07-01" - -// Server is a Runtime API server -type Server struct { - host string - port int - server *http.Server - listener net.Listener - exit chan error -} - -func SaveConnInContext(ctx context.Context, c net.Conn) context.Context { - return context.WithValue(ctx, interop.HTTPConnKey, c) -} - -// NewServer creates a new Runtime API Server -// -// Unlike net/http server's ListenAndServe, we separate Listen() -// and Serve(), this is done to guarantee order: call to Listen() -// should happen before provided runtime is started. -// -// When port is 0, OS will dynamically allocate the listening port. -func NewServer( - host string, - port int, - appCtx appctx.ApplicationContext, - registrationService core.RegistrationService, - renderingService *rendering.EventRenderingService, - telemetryAPIEnabled bool, - logsSubscriptionAPI telemetry.SubscriptionAPI, - telemetrySubscriptionAPI telemetry.SubscriptionAPI, - credentialsService core.CredentialsService, -) *Server { - - exitErrors := make(chan error, 1) - - router := chi.NewRouter() - router.Use(middleware.AccessLogMiddleware()) - router.Mount(version20180601, NewRouter(appCtx, registrationService, renderingService)) - router.Mount(version20200101, ExtensionsRouter(appCtx, registrationService, renderingService)) - - if telemetryAPIEnabled { - router.Mount(version20200815, LogsAPIRouter(registrationService, logsSubscriptionAPI)) - router.Mount(version20220701, TelemetryAPIRouter(registrationService, telemetrySubscriptionAPI)) - } else { - router.Mount(version20200815, LogsAPIStubRouter()) - router.Mount(version20220701, TelemetryAPIStubRouter()) - } - - if appctx.LoadInitType(appCtx) == appctx.InitCaching { - router.Mount(version20210423, CredentialsAPIRouter(credentialsService)) - } - - return &Server{ - host: host, - port: port, - server: &http.Server{Handler: router, ConnContext: SaveConnInContext}, - listener: nil, - exit: exitErrors, - } -} - -// Listen on port -func (s *Server) Listen() error { - addr := fmt.Sprintf("%s:%d", s.host, s.port) - - ln, err := net.Listen("tcp", addr) - if err != nil { - return err - } - - s.listener = ln - if s.port == 0 { - s.port = ln.Addr().(*net.TCPAddr).Port - log.WithField("port", s.port).Info("Listening port was dynamically allocated") - } - - log.Debugf("Runtime API Server listening on %s:%d", s.host, s.port) - - return nil -} - -// Serve requests and close on cancelation signals -func (s *Server) Serve(ctx context.Context) error { - defer s.Close() - - select { - case err := <-s.serveAsync(): - return err - - case err := <-s.exit: - log.Errorf("Error triggered exit: %s", err) - return err - - case <-ctx.Done(): - return ctx.Err() - } -} - -func (s *Server) serveAsync() chan error { - errors := make(chan error) - go func() { - errors <- s.server.Serve(s.listener) - }() - - return errors -} - -// Host is server's host -func (s *Server) Host() string { - return s.host -} - -// Port is server's port -func (s *Server) Port() int { - return s.port -} - -// URL is full server url for specified endpoint -func (s *Server) URL(endpoint string) string { - return fmt.Sprintf("http://%s:%d%s%s", s.Host(), s.Port(), version20180601, endpoint) -} - -// Close forcefully closes listeners & connections -func (s *Server) Close() error { - err := s.server.Close() - if err == nil { - log.Info("Runtime API Server closed") - } - return err -} - -// Shutdown gracefully shuts down server -func (s *Server) Shutdown() error { - return s.server.Shutdown(context.Background()) -} diff --git a/lambda/rapi/server_test.go b/lambda/rapi/server_test.go deleted file mode 100644 index cf31fab..0000000 --- a/lambda/rapi/server_test.go +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapi - -import ( - "context" - "errors" - "fmt" - "io" - "net/http" - "testing" - "time" - - "go.amzn.com/lambda/testdata" - - "github.com/stretchr/testify/assert" -) - -const nextAvailablePort = 0 // net.Listener convention for next available port -const serverAddress = "127.0.0.1" - -func createTestServer(handlerFunc http.HandlerFunc) (*Server, error) { - host, server := serverAddress, &http.Server{Handler: handlerFunc} - s := &Server{ - host: host, - port: nextAvailablePort, - server: server, - listener: nil, - exit: make(chan error, 1), - } - err := s.Listen() - return s, err -} - -func TestServerReturnsSuccessfulResponse(t *testing.T) { - expectedResponse := "foo" - testHandler := http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, expectedResponse) - }) - s, err := createTestServer(testHandler) - if err != nil { - assert.FailNowf(t, "Server failed to listen", err.Error()) - } - go func() { - s.Serve(context.Background()) - }() - - resp, err := http.Get(fmt.Sprintf("http://%s/", s.listener.Addr())) - if err != nil { - assert.FailNowf(t, "Failed to get response", err.Error()) - } - body, err := io.ReadAll(resp.Body) - if err != nil { - assert.FailNowf(t, "Failed to read response body", err.Error()) - } - resp.Body.Close() - s.Close() - - assert.Equal(t, expectedResponse, string(body)) -} - -func TestServerExitsOnContextCancelation(t *testing.T) { - testHandler := http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, "foo") - }) - ctx, cancel := context.WithCancel(context.Background()) - errChan := make(chan error, 1) - s, err := createTestServer(testHandler) - if err != nil { - assert.FailNowf(t, "Server failed to listen", err.Error()) - } - go func() { - errChan <- s.Serve(ctx) - }() - - pingRequestFunc := func() (bool, error) { - if _, err := http.Get(s.URL("/ping")); err != nil { - return false, err - } - return true, nil - } - serverStarted := testdata.Eventually(t, pingRequestFunc, 10*time.Millisecond, 10) - assert.True(t, serverStarted) - - cancel() - error := testdata.WaitForErrorWithTimeout(errChan, 2*time.Second) - s.Close() - assert.Error(t, error) - assert.Contains(t, error.Error(), ctx.Err().Error()) -} - -func TestServerExitsOnExitSignalFromHandler(t *testing.T) { - testHandler := http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, "foo") - }) - errChan := make(chan error, 1) - s, err := createTestServer(testHandler) - if err != nil { - assert.FailNowf(t, "Server failed to listen", err.Error()) - } - go func() { - errChan <- s.Serve(context.Background()) - }() - - exitError := errors.New("foo bar error") - s.exit <- exitError - - error := testdata.WaitForErrorWithTimeout(errChan, time.Second) - s.Close() - - assert.Error(t, error) - assert.Contains(t, error.Error(), exitError.Error()) -} diff --git a/lambda/rapi/telemetry_logs_fuzz_test.go b/lambda/rapi/telemetry_logs_fuzz_test.go deleted file mode 100644 index 89adbd1..0000000 --- a/lambda/rapi/telemetry_logs_fuzz_test.go +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapi - -import ( - "bytes" - "fmt" - "io" - "net/http" - "net/http/httptest" - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "go.amzn.com/lambda/extensions" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi/handler" - "go.amzn.com/lambda/telemetry" - "go.amzn.com/lambda/testdata" -) - -const ( - logsHandlerPath = "/logs" - telemetryHandlerPath = "/telemetry" - - samplePayload = `{"foo" : "bar"}` -) - -func FuzzTelemetryLogRouters(f *testing.F) { - extensions.Enable() - defer extensions.Disable() - - f.Add(makeTargetURL(logsHandlerPath, version20200815), []byte(samplePayload)) - f.Add(makeTargetURL(telemetryHandlerPath, version20220701), []byte(samplePayload)) - - logsPath := fmt.Sprintf("%s%s", version20200815, logsHandlerPath) - telemetryPath := fmt.Sprintf("%s%s", version20220701, telemetryHandlerPath) - - f.Fuzz(func(t *testing.T, rawPath string, payload []byte) { - u, err := parseToURLStruct(rawPath) - if err != nil { - t.Skipf("error parsing url: %v. Skipping test.", err) - } - - flowTest := testdata.NewFlowTest() - - rapiServer := makeRapiServerWithMockSubscriptionAPI(flowTest, newMockSubscriptionAPI(true), newMockSubscriptionAPI(true)) - - request := httptest.NewRequest("PUT", rawPath, bytes.NewReader(payload)) - responseRecorder := serveTestRequest(rapiServer, request) - - if u.Path == logsPath || u.Path == telemetryPath { - assertExpectedPathResponseCode(t, responseRecorder.Code, rawPath) - } else { - assertUnexpectedPathResponseCode(t, responseRecorder.Code, rawPath) - } - }) -} - -func FuzzLogsHandler(f *testing.F) { - extensions.Enable() - defer extensions.Disable() - - fuzzSubscriptionAPIHandler(f, logsHandlerPath, version20200815) -} - -func FuzzTelemetryHandler(f *testing.F) { - extensions.Enable() - defer extensions.Disable() - - fuzzSubscriptionAPIHandler(f, telemetryHandlerPath, version20220701) -} - -func fuzzSubscriptionAPIHandler(f *testing.F, handlerPath string, apiVersion string) { - flowTest := testdata.NewFlowTest() - agent := makeExternalAgent(flowTest.RegistrationService) - f.Add([]byte(samplePayload), agent.ID.String(), true) - f.Add([]byte(samplePayload), agent.ID.String(), false) - - f.Fuzz(func(t *testing.T, payload []byte, agentIdentifierHeader string, serviceOn bool) { - telemetrySubscriptionAPI := newMockSubscriptionAPI(serviceOn) - logsSubscriptionAPI := newMockSubscriptionAPI(serviceOn) - rapiServer := makeRapiServerWithMockSubscriptionAPI(flowTest, logsSubscriptionAPI, telemetrySubscriptionAPI) - - apiUnderTest := telemetrySubscriptionAPI - if handlerPath == logsHandlerPath { - apiUnderTest = logsSubscriptionAPI - } - - target := makeTargetURL(handlerPath, apiVersion) - request := httptest.NewRequest("PUT", target, bytes.NewReader(payload)) - request.Header.Set(handler.LambdaAgentIdentifier, agentIdentifierHeader) - - responseRecorder := serveTestRequest(rapiServer, request) - - if agentIdentifierHeader == "" { - assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierMissing) - return - } - - if _, err := uuid.Parse(agentIdentifierHeader); err != nil { - assertForbiddenErrorType(t, responseRecorder, handler.ErrAgentIdentifierInvalid) - return - } - - if agentIdentifierHeader != agent.ID.String() { - assertForbiddenErrorType(t, responseRecorder, "Extension.UnknownExtensionIdentifier") - return - } - - if !serviceOn { - assertForbiddenErrorType(t, responseRecorder, apiUnderTest.GetServiceClosedErrorType()) - return - } - - // assert that payload has been stored in the mock subscription api after the handler calls Subscribe() - assert.Equal(t, payload, apiUnderTest.receivedPayload) - }) -} - -func makeRapiServerWithMockSubscriptionAPI( - flowTest *testdata.FlowTest, - logsSubscription telemetry.SubscriptionAPI, - telemetrySubscription telemetry.SubscriptionAPI) *Server { - return NewServer( - "127.0.0.1", - 0, - flowTest.AppCtx, - flowTest.RegistrationService, - flowTest.RenderingService, - true, - logsSubscription, - telemetrySubscription, - flowTest.CredentialsService, - ) -} - -type mockSubscriptionAPI struct { - serviceOn bool - receivedPayload []byte -} - -func newMockSubscriptionAPI(serviceOn bool) *mockSubscriptionAPI { - return &mockSubscriptionAPI{ - serviceOn: serviceOn, - } -} - -func (m *mockSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) ([]byte, int, map[string][]string, error) { - if !m.serviceOn { - return nil, 0, map[string][]string{}, telemetry.ErrTelemetryServiceOff - } - - bodyBytes, err := io.ReadAll(body) - if err != nil { - return nil, 0, map[string][]string{}, fmt.Errorf("error Reading the body of subscription request: %s", err) - } - - m.receivedPayload = bodyBytes - - return []byte("OK"), http.StatusOK, map[string][]string{}, nil -} - -func (m *mockSubscriptionAPI) RecordCounterMetric(metricName string, count int) {} - -func (m *mockSubscriptionAPI) FlushMetrics() interop.TelemetrySubscriptionMetrics { - return nil -} - -func (m *mockSubscriptionAPI) Clear() {} - -func (m *mockSubscriptionAPI) TurnOff() {} - -func (m *mockSubscriptionAPI) GetEndpointURL() string { - return "/subscribe" -} - -func (m *mockSubscriptionAPI) GetServiceClosedErrorMessage() string { - return "Subscription API is closed" -} - -func (m *mockSubscriptionAPI) GetServiceClosedErrorType() string { - return "SubscriptionClosed" -} diff --git a/lambda/rapid/exit.go b/lambda/rapid/exit.go deleted file mode 100644 index a601efc..0000000 --- a/lambda/rapid/exit.go +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapid - -import ( - "time" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/extensions" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/telemetry" - - log "github.com/sirupsen/logrus" -) - -func handleInvokeError(execCtx *rapidContext, invokeRequest *interop.Invoke, invokeMx *invokeMetrics, err error) *interop.InvokeFailure { - invokeFailure := newInvokeFailureMsg(execCtx, invokeRequest, invokeMx, err) - - // This is the default error response that gets sent back as the function response in failure cases - invokeFailure.DefaultErrorResponse = interop.GetErrorResponseWithFormattedErrorMessage(invokeFailure.ErrorType, invokeFailure.ErrorMessage, invokeRequest.ID) - - // Invoke with extensions disabled maintains behaviour parity with pre-extensions rapid - if !extensions.AreEnabled() { - invokeFailure.RequestReset = false - return invokeFailure - } - - if err == errResetReceived { - // errResetReceived is returned when execution flow was interrupted by the Reset message, - // hence this error deserves special handling and we yield to main receive loop to handle it - invokeFailure.ResetReceived = true - return invokeFailure - } - - invokeFailure.RequestReset = true - return invokeFailure -} - -func newInvokeFailureMsg(execCtx *rapidContext, invokeRequest *interop.Invoke, invokeMx *invokeMetrics, err error) *interop.InvokeFailure { - errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) - if !found { - errorType = fatalerror.SandboxFailure - } - - invokeFailure := &interop.InvokeFailure{ - ErrorType: errorType, - ErrorMessage: err, - RequestReset: true, - ResetReceived: false, - RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), - NumActiveExtensions: execCtx.registrationService.CountAgents(), - InvokeReceivedTime: invokeRequest.InvokeReceivedTime, - } - - if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { - invokeFailure.ResponseMetrics.RuntimeResponseLatencyMs = telemetry.CalculateDuration(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics.StartReadingResponseMonoTimeMs) - invokeFailure.ResponseMetrics.RuntimeTimeThrottledMs = invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) - invokeFailure.ResponseMetrics.RuntimeProducedBytes = invokeRequest.InvokeResponseMetrics.ProducedBytes - invokeFailure.ResponseMetrics.RuntimeOutboundThroughputBps = invokeRequest.InvokeResponseMetrics.OutboundThroughputBps - } - - if invokeMx != nil { - invokeFailure.InvokeMetrics.InvokeRequestReadTimeNs = invokeMx.rendererMetrics.ReadTime.Nanoseconds() - invokeFailure.InvokeMetrics.InvokeRequestSizeBytes = int64(invokeMx.rendererMetrics.SizeBytes) - invokeFailure.InvokeMetrics.RuntimeReadyTime = int64(invokeMx.runtimeReadyTime) - invokeFailure.ExtensionNames = execCtx.GetExtensionNames() - } - - if execCtx.telemetryAPIEnabled { - invokeFailure.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) - } - - invokeFailure.InvokeResponseMode = invokeRequest.InvokeResponseMode - - return invokeFailure -} - -func generateInitFailureMsg(execCtx *rapidContext, err error) interop.InitFailure { - errorType, found := appctx.LoadFirstFatalError(execCtx.appCtx) - if !found { - errorType = fatalerror.SandboxFailure - } - - initFailureMsg := interop.InitFailure{ - RequestReset: true, - ErrorType: errorType, - ErrorMessage: err, - RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), - NumActiveExtensions: execCtx.registrationService.CountAgents(), - Ack: make(chan struct{}), - } - - if execCtx.telemetryAPIEnabled { - initFailureMsg.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) - } - - return initFailureMsg -} - -func handleInitError(execCtx *rapidContext, invokeID string, err error, initFailureResponse chan<- interop.InitFailure) { - log.WithError(err).WithField("InvokeID", invokeID).Error("Init failed") - initFailureMsg := generateInitFailureMsg(execCtx, err) - - if err == errResetReceived { - // errResetReceived is returned when execution flow was interrupted by the Reset message, - // hence this error deserves special handling and we yield to main receive loop to handle it - initFailureMsg.ResetReceived = true - initFailureResponse <- initFailureMsg - <-initFailureMsg.Ack - return - } - - if !execCtx.HasActiveExtensions() && !execCtx.standaloneMode { - // different behaviour when no extensions are present, - // for compatibility with previous implementations - initFailureMsg.RequestReset = false - } else { - initFailureMsg.RequestReset = true - } - - initFailureResponse <- initFailureMsg - <-initFailureMsg.Ack -} diff --git a/lambda/rapid/handlers.go b/lambda/rapid/handlers.go deleted file mode 100644 index 2e759e9..0000000 --- a/lambda/rapid/handlers.go +++ /dev/null @@ -1,1013 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -// Package rapid implements synchronous even dispatch loop. -package rapid - -import ( - "bytes" - "context" - "errors" - "fmt" - "os" - "path" - "strings" - "sync" - "time" - - "go.amzn.com/lambda/agents" - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/extensions" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapi" - "go.amzn.com/lambda/rapi/rendering" - "go.amzn.com/lambda/rapidcore/env" - supvmodel "go.amzn.com/lambda/supervisor/model" - "go.amzn.com/lambda/telemetry" - - "github.com/google/uuid" - log "github.com/sirupsen/logrus" -) - -const ( - RuntimeDomain = "runtime" - OperatorDomain = "operator" - defaultAgentLocation = "/opt/extensions" - runtimeProcessName = "runtime" -) - -const ( - // Same value as defined in LambdaSandbox minus 1. - maxExtensionNamesLength = 127 - standaloneShutdownReason = "spindown" -) - -var errResetReceived = errors.New("errResetReceived") - -type processSupervisor struct { - supvmodel.ProcessSupervisor - RootPath string -} - -type rapidContext struct { - interopServer interop.Server - server *rapi.Server - appCtx appctx.ApplicationContext - initDone bool - supervisor processSupervisor - runtimeDomainGeneration uint32 - initFlow core.InitFlowSynchronization - invokeFlow core.InvokeFlowSynchronization - registrationService core.RegistrationService - renderingService *rendering.EventRenderingService - telemetryAPIEnabled bool - logsSubscriptionAPI telemetry.SubscriptionAPI - telemetrySubscriptionAPI telemetry.SubscriptionAPI - logsEgressAPI telemetry.StdLogsEgressAPI - xray telemetry.Tracer - standaloneMode bool - eventsAPI interop.EventsAPI - initCachingEnabled bool - credentialsService core.CredentialsService - handlerExecutionMutex sync.Mutex - shutdownContext *shutdownContext - logStreamName string - - RuntimeStartedTime int64 - RuntimeOverheadStartedTime int64 - InvokeResponseMetrics *interop.InvokeResponseMetrics -} - -// Validate interface compliance -var _ interop.RapidContext = (*rapidContext)(nil) - -type invokeMetrics struct { - rendererMetrics rendering.InvokeRendererMetrics - - runtimeReadyTime int64 -} - -func (c *rapidContext) HasActiveExtensions() bool { - return extensions.AreEnabled() && c.registrationService.CountAgents() > 0 -} - -func (c *rapidContext) GetExtensionNames() string { - var extensionNamesList []string - for _, agent := range c.registrationService.AgentsInfo() { - extensionNamesList = append(extensionNamesList, agent.Name) - } - extensionNames := strings.Join(extensionNamesList, ";") - if len(extensionNames) > maxExtensionNamesLength { - if idx := strings.LastIndex(extensionNames[:maxExtensionNamesLength], ";"); idx != -1 { - return extensionNames[:idx] - } - return "" - } - return extensionNames -} - -func logAgentsInitStatus(execCtx *rapidContext) { - for _, agent := range execCtx.registrationService.AgentsInfo() { - extensionInitData := interop.ExtensionInitData{ - AgentName: agent.Name, - State: agent.State, - ErrorType: agent.ErrorType, - Subscriptions: agent.Subscriptions, - } - execCtx.eventsAPI.SendExtensionInit(extensionInitData) - } -} - -func agentLaunchError(agent *core.ExternalAgent, appCtx appctx.ApplicationContext, launchError error) { - if err := agent.LaunchError(launchError); err != nil { - log.Warnf("LaunchError transition fail for %s from %s: %s", agent, agent.GetState().Name(), err) - } - appctx.StoreFirstFatalError(appCtx, fatalerror.AgentLaunchError) -} - -func doInitExtensions(domain string, agentPaths []string, execCtx *rapidContext, env *env.Environment) error { - initFlow := execCtx.registrationService.InitFlow() - - // we don't bring it into the loop below because we don't want unnecessary broadcasts on agent gate - if err := initFlow.SetExternalAgentsRegisterCount(uint16(len(agentPaths))); err != nil { - return err - } - - for _, agentPath := range agentPaths { - // Using path.Base(agentPath) not agentName because the agent name is contact, as standalone can get the internal state. - agent, err := execCtx.registrationService.CreateExternalAgent(path.Base(agentPath)) - if err != nil { - return err - } - - if execCtx.registrationService.CountAgents() > core.MaxAgentsAllowed { - agentLaunchError(agent, execCtx.appCtx, core.ErrTooManyExtensions) - return core.ErrTooManyExtensions - } - - env := env.AgentExecEnv() - - agentStdoutWriter, agentStderrWriter, err := execCtx.logsEgressAPI.GetExtensionSockets() - if err != nil { - return err - } - agentName := fmt.Sprintf("extension-%s-%d", path.Base(agentPath), execCtx.runtimeDomainGeneration) - - err = execCtx.supervisor.Exec(context.Background(), &supvmodel.ExecRequest{ - Domain: domain, - Name: agentName, - Path: agentPath, - Env: &env, - Logging: supvmodel.Logging{ - Managed: supvmodel.ManagedLogging{ - Topic: supvmodel.RtExtensionManagedLoggingTopic, - Formats: []supvmodel.ManagedLoggingFormat{ - supvmodel.LineBasedManagedLogging, - }, - }, - }, - StdoutWriter: agentStdoutWriter, - StderrWriter: agentStderrWriter, - }) - if err != nil { - agentLaunchError(agent, execCtx.appCtx, err) - return err - } - - execCtx.shutdownContext.createExitedChannel(agentName) - } - - if err := initFlow.AwaitExternalAgentsRegistered(); err != nil { - return err - } - - return nil -} - -func doRuntimeBootstrap(execCtx *rapidContext, sbInfoFromInit interop.SandboxInfoFromInit) ([]string, map[string]string, string, []*os.File, error) { - env := sbInfoFromInit.EnvironmentVariables - runtimeBootstrap := sbInfoFromInit.RuntimeBootstrap - bootstrapCmd, err := runtimeBootstrap.Cmd() - if err != nil { - if fatalError, formattedLog, hasError := runtimeBootstrap.CachedFatalError(err); hasError { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.eventsAPI.SendImageErrorLog(interop.ImageErrorLogData(formattedLog)) - } else { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) - } - return []string{}, map[string]string{}, "", []*os.File{}, err - } - - bootstrapEnv := runtimeBootstrap.Env(env) - bootstrapCwd, err := runtimeBootstrap.Cwd() - if err != nil { - if fatalError, formattedLog, hasError := runtimeBootstrap.CachedFatalError(err); hasError { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.eventsAPI.SendImageErrorLog(interop.ImageErrorLogData(formattedLog)) - } else { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidWorkingDir) - } - return []string{}, map[string]string{}, "", []*os.File{}, err - } - - bootstrapExtraFiles := runtimeBootstrap.ExtraFiles() - - return bootstrapCmd, bootstrapEnv, bootstrapCwd, bootstrapExtraFiles, nil -} - -func (c *rapidContext) watchEvents(events <-chan supvmodel.Event) { - for event := range events { - var err error - log.Debugf("The events handler received the event %+v.", event) - if loss := event.Event.EventLoss(); loss != nil { - log.Panicf("Lost %d events from supervisor", *loss) - } - termination := event.Event.ProcessTerminated() - - // If we are not shutting down then we care if an unexpected exit happens. - if !c.shutdownContext.isShuttingDown() { - runtimeProcessName := fmt.Sprintf("%s-%d", runtimeProcessName, c.runtimeDomainGeneration) - - // If event from the runtime. - if *termination.Name == runtimeProcessName { - if termination.Success() { - err = fmt.Errorf("Runtime exited without providing a reason") - } else { - err = fmt.Errorf("Runtime exited with error: %s", termination.String()) - } - appctx.StoreFirstFatalError(c.appCtx, fatalerror.RuntimeExit) - } else { - if termination.Success() { - err = fmt.Errorf("exit code 0") - } else { - err = fmt.Errorf("%s", termination.String()) - } - - appctx.StoreFirstFatalError(c.appCtx, fatalerror.AgentCrash) - } - - log.Warnf("Process %s exited: %+v", *termination.Name, termination) - } - - // At the moment we only get termination events. - // When their are other event types then we would need to be selective, - // about what we send to handleShutdownEvent(). - c.shutdownContext.handleProcessExit(*termination) - c.registrationService.CancelFlows(err) - } -} - -// subscribe to /events for runtime domain in supervisor -func setupEventsWatcher(execCtx *rapidContext) error { - eventsRequest := supvmodel.EventsRequest{ - Domain: RuntimeDomain, - } - - events, err := execCtx.supervisor.Events(context.Background(), &eventsRequest) - if err != nil { - log.Errorf("Could not get events stream from supervisor: %s", err) - return err - } - - go execCtx.watchEvents(events) - return nil -} - -func doRuntimeDomainInit(execCtx *rapidContext, sbInfoFromInit interop.SandboxInfoFromInit, phase interop.LifecyclePhase) error { - initStartTime := metering.Monotime() - sendInitStartLogEvent(execCtx, sbInfoFromInit.SandboxType, phase) - defer sendInitReportLogEvent(execCtx, sbInfoFromInit.SandboxType, initStartTime, phase) - - execCtx.xray.RecordInitStartTime() - defer execCtx.xray.RecordInitEndTime() - - defer func() { - if extensions.AreEnabled() { - logAgentsInitStatus(execCtx) - } - }() - - execCtx.runtimeDomainGeneration++ - - if extensions.AreEnabled() { - runtimeExtensions := agents.ListExternalAgentPaths(defaultAgentLocation, - execCtx.supervisor.RootPath) - if err := doInitExtensions(RuntimeDomain, runtimeExtensions, execCtx, sbInfoFromInit.EnvironmentVariables); err != nil { - return err - } - } - - appctx.StoreSandboxType(execCtx.appCtx, sbInfoFromInit.SandboxType) - - initFlow := execCtx.registrationService.InitFlow() - - // Runtime state machine - runtime := core.NewRuntime(initFlow, execCtx.invokeFlow) - - // Registration service keeps track of parties registered in the system and events they are registered for. - // Runtime's use case is generalized, because runtime doesn't register itself, we preregister it in the system; - // runtime is implicitly subscribed for certain lifecycle events. - log.Debug("Preregister runtime") - registrationService := execCtx.registrationService - err := registrationService.PreregisterRuntime(runtime) - if err != nil { - return err - } - - bootstrapCmd, bootstrapEnv, bootstrapCwd, bootstrapExtraFiles, err := doRuntimeBootstrap(execCtx, sbInfoFromInit) - if err != nil { - return err - } - - runtimeStdoutWriter, runtimeStderrWriter, err := execCtx.logsEgressAPI.GetRuntimeSockets() - if err != nil { - return err - } - - log.Debug("Start runtime") - checkCredentials(execCtx, bootstrapEnv) - name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) - - err = execCtx.supervisor.Exec(context.Background(), &supvmodel.ExecRequest{ - Domain: RuntimeDomain, - Name: name, - Cwd: &bootstrapCwd, - Path: bootstrapCmd[0], - Args: bootstrapCmd[1:], - Env: &bootstrapEnv, - Logging: supvmodel.Logging{ - Managed: supvmodel.ManagedLogging{ - Topic: supvmodel.RuntimeManagedLoggingTopic, - Formats: []supvmodel.ManagedLoggingFormat{ - supvmodel.LineBasedManagedLogging, - supvmodel.MessageBasedManagedLogging, - }, - }, - }, - StdoutWriter: runtimeStdoutWriter, - StderrWriter: runtimeStderrWriter, - ExtraFiles: &bootstrapExtraFiles, - }) - - runtimeDoneStatus := telemetry.RuntimeDoneSuccess - - defer func() { - sendInitRuntimeDoneLogEvent(execCtx, sbInfoFromInit.SandboxType, runtimeDoneStatus, phase) - }() - - if err != nil { - if fatalError, formattedLog, hasError := sbInfoFromInit.RuntimeBootstrap.CachedFatalError(err); hasError { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalError) - execCtx.eventsAPI.SendImageErrorLog(interop.ImageErrorLogData(formattedLog)) - } else { - appctx.StoreFirstFatalError(execCtx.appCtx, fatalerror.InvalidEntrypoint) - } - - runtimeDoneStatus = telemetry.RuntimeDoneError - return err - } - - execCtx.shutdownContext.createExitedChannel(name) - - if err := initFlow.AwaitRuntimeRestoreReady(); err != nil { - runtimeDoneStatus = telemetry.RuntimeDoneError - return err - } - - runtimeDoneStatus = telemetry.RuntimeDoneSuccess - - // Registration phase finished for agents - no more agents can be registered with the system - registrationService.TurnOff() - if extensions.AreEnabled() { - // Initialize and activate the gate with the number of agent we wait to return ready - if err := initFlow.SetAgentsReadyCount(registrationService.GetRegisteredAgentsSize()); err != nil { - return err - } - if err := initFlow.AwaitAgentsReady(); err != nil { - runtimeDoneStatus = telemetry.RuntimeDoneError - return err - } - } - - // Logs API subscription phase finished for agents - no more agents can be subscribed to the Logs API - if execCtx.telemetryAPIEnabled { - execCtx.logsSubscriptionAPI.TurnOff() - execCtx.telemetrySubscriptionAPI.TurnOff() - } - - execCtx.initDone = true - - return nil -} - -func doInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, mx *invokeMetrics, sbInfoFromInit interop.SandboxInfoFromInit, requestBuffer *bytes.Buffer) error { - execCtx.eventsAPI.SetCurrentRequestID(interop.RequestID(invokeRequest.ID)) - appCtx := execCtx.appCtx - - xray := execCtx.xray - xray.Configure(invokeRequest) - - ctx := context.Background() - - return xray.CaptureInvokeSegment(ctx, xray.WithErrorCause(ctx, appCtx, func(ctx context.Context) error { - telemetryTracingCtx := xray.BuildTracingCtxForStart() - - if !execCtx.initDone { - // do inline init - if err := xray.CaptureInitSubsegment(ctx, func(ctx context.Context) error { - return doRuntimeDomainInit(execCtx, sbInfoFromInit, interop.LifecyclePhaseInvoke) - }); err != nil { - sendInvokeStartLogEvent(execCtx, invokeRequest.ID, telemetryTracingCtx) - return err - } - } else if sbInfoFromInit.SandboxType != interop.SandboxPreWarmed && !execCtx.initCachingEnabled { - xray.SendInitSubsegmentWithRecordedTimesOnce(ctx) - } - - xray.SendRestoreSubsegmentWithRecordedTimesOnce(ctx) - - sendInvokeStartLogEvent(execCtx, invokeRequest.ID, telemetryTracingCtx) - - invokeFlow := execCtx.invokeFlow - log.Debug("Initialize invoke flow barriers") - err := invokeFlow.InitializeBarriers() - if err != nil { - return err - } - - registrationService := execCtx.registrationService - runtime := registrationService.GetRuntime() - var intAgents []*core.InternalAgent - var extAgents []*core.ExternalAgent - - if extensions.AreEnabled() { - intAgents = registrationService.GetSubscribedInternalAgents(core.InvokeEvent) - extAgents = registrationService.GetSubscribedExternalAgents(core.InvokeEvent) - if err := invokeFlow.SetAgentsReadyCount(uint16(len(intAgents) + len(extAgents))); err != nil { - return err - } - } - - // Invoke - if err := xray.CaptureInvokeSubsegment(ctx, xray.WithError(ctx, appCtx, func(ctx context.Context) error { - log.Debug("Set renderer for invoke") - renderer := rendering.NewInvokeRenderer(ctx, invokeRequest, requestBuffer, xray.BuildTracingHeader()) - defer func() { - mx.rendererMetrics = renderer.GetMetrics() - }() - - execCtx.renderingService.SetRenderer(renderer) - if extensions.AreEnabled() { - log.Debug("Release agents conditions") - for _, agent := range extAgents { - //TODO handle Supervisors listening channel - agent.Release() - } - for _, agent := range intAgents { - //TODO handle Supervisors listening channel - agent.Release() - } - } - - log.Debug("Release runtime condition") - //TODO handle Supervisors listening channel - execCtx.SetRuntimeStartedTime(metering.Monotime()) - runtime.Release() - log.Debug("Await runtime response") - //TODO handle Supervisors listening channel - return invokeFlow.AwaitRuntimeResponse() - })); err != nil { - return err - } - - // Runtime overhead - if err := xray.CaptureOverheadSubsegment(ctx, func(ctx context.Context) error { - log.Debug("Await runtime ready") - execCtx.SetRuntimeOverheadStartedTime(metering.Monotime()) - //TODO handle Supervisors listening channel - return invokeFlow.AwaitRuntimeReady() - }); err != nil { - return err - } - mx.runtimeReadyTime = metering.Monotime() - - runtimeDoneEventData := interop.InvokeRuntimeDoneData{ - Status: telemetry.RuntimeDoneSuccess, - Metrics: telemetry.GetRuntimeDoneInvokeMetrics(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics, mx.runtimeReadyTime), - InternalMetrics: invokeRequest.InvokeResponseMetrics, - Tracing: xray.BuildTracingCtxAfterInvokeComplete(), - Spans: execCtx.eventsAPI.GetRuntimeDoneSpans(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics, execCtx.RuntimeOverheadStartedTime, mx.runtimeReadyTime), - } - log.Info(runtimeDoneEventData.String()) - if err := execCtx.eventsAPI.SendInvokeRuntimeDone(runtimeDoneEventData); err != nil { - log.Errorf("Failed to send INVOKE RTDONE: %s", err) - } - - // Extensions overhead - if execCtx.HasActiveExtensions() { - extensionOverheadStartTime := metering.Monotime() - execCtx.interopServer.SendRuntimeReady() - log.Debug("Await agents ready") - //TODO handle Supervisors listening channel - if err := invokeFlow.AwaitAgentsReady(); err != nil { - log.Warnf("AwaitAgentsReady() = %s", err) - return err - } - extensionOverheadEndTime := metering.Monotime() - extensionOverheadMsSpan := interop.Span{ - Name: "extensionOverhead", - Start: telemetry.GetEpochTimeInISO8601FormatFromMonotime(extensionOverheadStartTime), - DurationMs: telemetry.CalculateDuration(extensionOverheadStartTime, extensionOverheadEndTime), - } - if err := execCtx.eventsAPI.SendReportSpan(extensionOverheadMsSpan); err != nil { - log.WithError(err).Error("Failed to create REPORT Span") - } - } - - return nil - })) -} - -// acceptInitRequest is a second initialization phase, performed after receiving START -// initialized entities: _HANDLER, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN -func (c *rapidContext) acceptInitRequest(initRequest *interop.Init) *interop.Init { - initRequest.EnvironmentVariables.StoreEnvironmentVariablesFromInit( - initRequest.CustomerEnvironmentVariables, - initRequest.Handler, - initRequest.AwsKey, - initRequest.AwsSecret, - initRequest.AwsSession, - initRequest.FunctionName, - initRequest.FunctionVersion) - c.registrationService.SetFunctionMetadata(core.FunctionMetadata{ - AccountID: initRequest.AccountID, - FunctionName: initRequest.FunctionName, - FunctionVersion: initRequest.FunctionVersion, - InstanceMaxMemory: initRequest.InstanceMaxMemory, - Handler: initRequest.Handler, - RuntimeInfo: initRequest.RuntimeInfo, - }) - c.SetLogStreamName(initRequest.LogStreamName) - - return initRequest -} - -func (c *rapidContext) acceptInitRequestForInitCaching(initRequest *interop.Init) (*interop.Init, error) { - log.Info("Configure environment for Init Caching.") - randomUUID, err := uuid.NewRandom() - - if err != nil { - return initRequest, err - } - - initCachingToken := randomUUID.String() - - initRequest.EnvironmentVariables.StoreEnvironmentVariablesFromInitForInitCaching( - c.server.Host(), - c.server.Port(), - initRequest.CustomerEnvironmentVariables, - initRequest.Handler, - initRequest.FunctionName, - initRequest.FunctionVersion, - initCachingToken) - - c.registrationService.SetFunctionMetadata(core.FunctionMetadata{ - AccountID: initRequest.AccountID, - FunctionName: initRequest.FunctionName, - FunctionVersion: initRequest.FunctionVersion, - InstanceMaxMemory: initRequest.InstanceMaxMemory, - Handler: initRequest.Handler, - RuntimeInfo: initRequest.RuntimeInfo, - }) - c.SetLogStreamName(initRequest.LogStreamName) - - c.credentialsService.SetCredentials(initCachingToken, initRequest.AwsKey, initRequest.AwsSecret, initRequest.AwsSession, initRequest.CredentialsExpiry) - - return initRequest, nil -} - -func handleInit(execCtx *rapidContext, initRequest *interop.Init, initSuccessResponse chan<- interop.InitSuccess, initFailureResponse chan<- interop.InitFailure) { - if execCtx.initCachingEnabled { - var err error - if initRequest, err = execCtx.acceptInitRequestForInitCaching(initRequest); err != nil { - // TODO: call handleInitError only after sending the RUNNING, since - // Slicer will fail receiving DONEFAIL here as it is expecting RUNNING - handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) - return - } - } else { - initRequest = execCtx.acceptInitRequest(initRequest) - } - - if err := setupEventsWatcher(execCtx); err != nil { - handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) - return - } - - if !initRequest.SuppressInit { - // doRuntimeDomainInit() is used in both init/invoke, so the signature requires sbInfo arg - sbInfo := interop.SandboxInfoFromInit{ - EnvironmentVariables: initRequest.EnvironmentVariables, - SandboxType: initRequest.SandboxType, - RuntimeBootstrap: initRequest.Bootstrap, - } - if err := doRuntimeDomainInit(execCtx, sbInfo, interop.LifecyclePhaseInit); err != nil { - handleInitError(execCtx, initRequest.InvokeID, err, initFailureResponse) - return - } - } - - initSuccessMsg := interop.InitSuccess{ - RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), - NumActiveExtensions: execCtx.registrationService.CountAgents(), - ExtensionNames: execCtx.GetExtensionNames(), - Ack: make(chan struct{}), - } - - if execCtx.telemetryAPIEnabled { - initSuccessMsg.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) - } - - initSuccessResponse <- initSuccessMsg - <-initSuccessMsg.Ack -} - -func handleInvoke(execCtx *rapidContext, invokeRequest *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit, requestBuffer *bytes.Buffer, responseSender interop.InvokeResponseSender) (interop.InvokeSuccess, *interop.InvokeFailure) { - appctx.StoreResponseSender(execCtx.appCtx, responseSender) - invokeMx := invokeMetrics{} - - if err := doInvoke(execCtx, invokeRequest, &invokeMx, sbInfoFromInit, requestBuffer); err != nil { - log.WithError(err).WithField("InvokeID", invokeRequest.ID).Error("Invoke failed") - invokeFailure := handleInvokeError(execCtx, invokeRequest, &invokeMx, err) - invokeFailure.InvokeResponseMode = invokeRequest.InvokeResponseMode - - if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { - invokeFailure.ResponseMetrics = interop.ResponseMetrics{ - RuntimeResponseLatencyMs: telemetry.CalculateDuration(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics.StartReadingResponseMonoTimeMs), - RuntimeTimeThrottledMs: invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond), - RuntimeProducedBytes: invokeRequest.InvokeResponseMetrics.ProducedBytes, - RuntimeOutboundThroughputBps: invokeRequest.InvokeResponseMetrics.OutboundThroughputBps, - } - } - return interop.InvokeSuccess{}, invokeFailure - } - - var invokeCompletionTimeNs int64 - if responseTimeNs := execCtx.registrationService.GetRuntime().GetRuntimeDescription().State.ResponseTimeNs; responseTimeNs != 0 { - invokeCompletionTimeNs = time.Now().UnixNano() - responseTimeNs - } - - invokeSuccessMsg := interop.InvokeSuccess{ - RuntimeRelease: appctx.GetRuntimeRelease(execCtx.appCtx), - NumActiveExtensions: execCtx.registrationService.CountAgents(), - ExtensionNames: execCtx.GetExtensionNames(), - InvokeMetrics: interop.InvokeMetrics{ - InvokeRequestReadTimeNs: invokeMx.rendererMetrics.ReadTime.Nanoseconds(), - InvokeRequestSizeBytes: int64(invokeMx.rendererMetrics.SizeBytes), - RuntimeReadyTime: invokeMx.runtimeReadyTime, - }, - InvokeCompletionTimeNs: invokeCompletionTimeNs, - InvokeReceivedTime: invokeRequest.InvokeReceivedTime, - InvokeResponseMode: invokeRequest.InvokeResponseMode, - } - - if invokeRequest.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(invokeRequest.InvokeResponseMetrics) { - invokeSuccessMsg.ResponseMetrics = interop.ResponseMetrics{ - RuntimeResponseLatencyMs: telemetry.CalculateDuration(execCtx.RuntimeStartedTime, invokeRequest.InvokeResponseMetrics.StartReadingResponseMonoTimeMs), - RuntimeTimeThrottledMs: invokeRequest.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond), - RuntimeProducedBytes: invokeRequest.InvokeResponseMetrics.ProducedBytes, - RuntimeOutboundThroughputBps: invokeRequest.InvokeResponseMetrics.OutboundThroughputBps, - } - } - - if execCtx.telemetryAPIEnabled { - invokeSuccessMsg.LogsAPIMetrics = interop.MergeSubscriptionMetrics(execCtx.logsSubscriptionAPI.FlushMetrics(), execCtx.telemetrySubscriptionAPI.FlushMetrics()) - } - - return invokeSuccessMsg, nil -} - -func reinitialize(execCtx *rapidContext) { - execCtx.appCtx.Delete(appctx.AppCtxInvokeErrorTraceDataKey) - execCtx.appCtx.Delete(appctx.AppCtxRuntimeReleaseKey) - execCtx.appCtx.Delete(appctx.AppCtxFirstFatalErrorKey) - execCtx.renderingService.SetRenderer(nil) - execCtx.initDone = false - execCtx.registrationService.Clear() - execCtx.initFlow.Clear() - execCtx.invokeFlow.Clear() - if execCtx.telemetryAPIEnabled { - execCtx.logsSubscriptionAPI.Clear() - execCtx.telemetrySubscriptionAPI.Clear() - } -} - -// handle notification of reset -func handleReset(execCtx *rapidContext, resetEvent *interop.Reset, runtimeStartedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) (interop.ResetSuccess, *interop.ResetFailure) { - log.Warnf("Reset initiated: %s", resetEvent.Reason) - - // Only send RuntimeDone event if we get a reset during an Invoke - if resetEvent.Reason == "failure" || resetEvent.Reason == "timeout" { - var errorType *string - if resetEvent.Reason == "failure" { - firstFatalError, found := appctx.LoadFirstFatalError(execCtx.appCtx) - if !found { - firstFatalError = fatalerror.SandboxFailure - } - stringifiedError := string(firstFatalError) - errorType = &stringifiedError - } - - var status string - if resetEvent.Reason == "timeout" { - status = "timeout" - } else if strings.HasPrefix(*errorType, "Sandbox.") { - status = "failure" - } else { - status = "error" - } - - var runtimeReadyTime int64 = metering.Monotime() - runtimeDoneEventData := interop.InvokeRuntimeDoneData{ - Status: status, - InternalMetrics: invokeResponseMetrics, - Metrics: telemetry.GetRuntimeDoneInvokeMetrics(runtimeStartedTime, invokeResponseMetrics, runtimeReadyTime), - Tracing: execCtx.xray.BuildTracingCtxAfterInvokeComplete(), - Spans: execCtx.eventsAPI.GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics, execCtx.RuntimeOverheadStartedTime, runtimeReadyTime), - ErrorType: errorType, - } - if err := execCtx.eventsAPI.SendInvokeRuntimeDone(runtimeDoneEventData); err != nil { - log.Errorf("Failed to send INVOKE RTDONE: %s", err) - } - } - - extensionsResetMs, resetTimeout, _ := execCtx.shutdownContext.shutdown(execCtx, resetEvent.DeadlineNs, resetEvent.Reason) - - execCtx.runtimeDomainGeneration++ - - // Only used by standalone for more indepth assertions. - var fatalErrorType fatalerror.ErrorType - - if execCtx.standaloneMode { - fatalErrorType, _ = appctx.LoadFirstFatalError(execCtx.appCtx) - } - - // TODO: move interop.ResponseMetrics{} to a factory method and initialize it there. - // Initialization is very similar in handleInvoke's invokeFailure.ResponseMetrics and - // invokeSuccessMsg.ResponseMetrics - var responseMetrics interop.ResponseMetrics - if resetEvent.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(resetEvent.InvokeResponseMetrics) { - responseMetrics.RuntimeResponseLatencyMs = telemetry.CalculateDuration(execCtx.RuntimeStartedTime, resetEvent.InvokeResponseMetrics.StartReadingResponseMonoTimeMs) - responseMetrics.RuntimeTimeThrottledMs = resetEvent.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) - responseMetrics.RuntimeProducedBytes = resetEvent.InvokeResponseMetrics.ProducedBytes - responseMetrics.RuntimeOutboundThroughputBps = resetEvent.InvokeResponseMetrics.OutboundThroughputBps - } - - if resetTimeout { - return interop.ResetSuccess{}, &interop.ResetFailure{ - ExtensionsResetMs: extensionsResetMs, - ErrorType: fatalErrorType, - ResponseMetrics: responseMetrics, - InvokeResponseMode: resetEvent.InvokeResponseMode, - } - } - - return interop.ResetSuccess{ - ExtensionsResetMs: extensionsResetMs, - ErrorType: fatalErrorType, - ResponseMetrics: responseMetrics, - InvokeResponseMode: resetEvent.InvokeResponseMode, - }, nil -} - -// handle notification of shutdown -func handleShutdown(execCtx *rapidContext, shutdownEvent *interop.Shutdown, reason string) interop.ShutdownSuccess { - log.Warnf("Shutdown initiated: %s", reason) - // TODO Handle shutdown error - _, _, _ = execCtx.shutdownContext.shutdown(execCtx, shutdownEvent.DeadlineNs, reason) - - // Only used by standalone for more indepth assertions. - var fatalErrorType fatalerror.ErrorType - - if execCtx.standaloneMode { - fatalErrorType, _ = appctx.LoadFirstFatalError(execCtx.appCtx) - } - - return interop.ShutdownSuccess{ErrorType: fatalErrorType} -} - -func handleRestore(execCtx *rapidContext, restore *interop.Restore) (interop.RestoreResult, error) { - err := execCtx.credentialsService.UpdateCredentials(restore.AwsKey, restore.AwsSecret, restore.AwsSession, restore.CredentialsExpiry) - restoreStatus := telemetry.RuntimeDoneSuccess - - restoreResult := interop.RestoreResult{} - - defer func() { - sendRestoreRuntimeDoneLogEvent(execCtx, restoreStatus) - }() - - if err != nil { - log.Infof("error when updating credentials: %s", err) - return restoreResult, interop.ErrRestoreUpdateCredentials - } - - renderer := rendering.NewRestoreRenderer() - execCtx.renderingService.SetRenderer(renderer) - - registrationService := execCtx.registrationService - runtime := registrationService.GetRuntime() - - execCtx.SetLogStreamName(restore.LogStreamName) - - // If runtime has not called /restore/next then just return - // instead of releasing the Runtime since there is no need to release. - // Then the runtime should be released only during Invoke - if runtime.GetState() != runtime.RuntimeRestoreReadyState { - restoreStatus = telemetry.RuntimeDoneSuccess - log.Infof("Runtime is in state: %s just returning", runtime.GetState().Name()) - - return restoreResult, nil - } - - deadlineNs := time.Now().Add(time.Duration(restore.RestoreHookTimeoutMs) * time.Millisecond).UnixNano() - - ctx, ctxCancel := context.WithDeadline(context.Background(), time.Unix(0, deadlineNs)) - - defer ctxCancel() - - startTime := metering.Monotime() - - runtime.Release() - - initFlow := execCtx.initFlow - err = initFlow.AwaitRuntimeReadyWithDeadline(ctx) - - fatalErrorType, fatalErrorFound := appctx.LoadFirstFatalError(execCtx.appCtx) - - // If there is an error occured when waiting runtime to complete the restore hook execution, - // check if there is any error stored in appctx to get the root cause error type - // Runtime.ExitError is an example to such a scenario - if fatalErrorFound { - err = fmt.Errorf("%s", string(fatalErrorType)) - } - - if err != nil { - restoreStatus = telemetry.RuntimeDoneError - } - - endTime := metering.Monotime() - restoreDuration := time.Duration(endTime - startTime) - restoreResult.RestoreMs = restoreDuration.Milliseconds() - - return restoreResult, err -} - -func startRuntimeAPI(ctx context.Context, execCtx *rapidContext) { - // Start Runtime API Server - err := execCtx.server.Listen() - if err != nil { - log.WithError(err).Panic("Runtime API Server failed to listen") - } - - execCtx.server.Serve(ctx) // blocking until server exits - - // Note, most of initialization code should run before blocking to receive START, - // code before START runs in parallel with code downloads. -} - -func getFirstFatalError(execCtx *rapidContext, status string) *string { - if status == telemetry.RuntimeDoneSuccess { - return nil - } - - firstFatalError, found := appctx.LoadFirstFatalError(execCtx.appCtx) - if !found { - // We will set errorType to "Runtime.Unknown" in case of INIT timeout and RESTORE timeout - // This is a trade-off we are willing to make. We will improve this later - firstFatalError = fatalerror.RuntimeUnknown - } - stringifiedError := string(firstFatalError) - return &stringifiedError -} - -func sendRestoreRuntimeDoneLogEvent(execCtx *rapidContext, status string) { - firstFatalError := getFirstFatalError(execCtx, status) - - restoreRuntimeDoneData := interop.RestoreRuntimeDoneData{ - Status: status, - ErrorType: firstFatalError, - } - - if err := execCtx.eventsAPI.SendRestoreRuntimeDone(restoreRuntimeDoneData); err != nil { - log.Errorf("Failed to send RESTORE RTDONE: %s", err) - } -} - -func sendInitStartLogEvent(execCtx *rapidContext, sandboxType interop.SandboxType, phase interop.LifecyclePhase) { - initPhase, err := telemetry.InitPhaseFromLifecyclePhase(phase) - if err != nil { - log.Errorf("failed to convert lifecycle phase into init phase: %s", err) - return - } - - functionMetadata := execCtx.registrationService.GetFunctionMetadata() - initStartData := interop.InitStartData{ - InitializationType: telemetry.InferInitType(execCtx.initCachingEnabled, sandboxType), - RuntimeVersion: functionMetadata.RuntimeInfo.Version, - RuntimeVersionArn: functionMetadata.RuntimeInfo.Arn, - FunctionName: functionMetadata.FunctionName, - FunctionVersion: functionMetadata.FunctionVersion, - // based on https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/resource/semantic_conventions/faas.md - // we're sending the logStream as the instance id - InstanceID: execCtx.logStreamName, - InstanceMaxMemory: functionMetadata.InstanceMaxMemory, - Phase: initPhase, - } - log.Info(initStartData.String()) - - if err := execCtx.eventsAPI.SendInitStart(initStartData); err != nil { - log.Errorf("Failed to send INIT START: %s", err) - } -} - -func sendInitRuntimeDoneLogEvent(execCtx *rapidContext, sandboxType interop.SandboxType, status string, phase interop.LifecyclePhase) { - initPhase, err := telemetry.InitPhaseFromLifecyclePhase(phase) - if err != nil { - log.Errorf("failed to convert lifecycle phase into init phase: %s", err) - return - } - - firstFatalError := getFirstFatalError(execCtx, status) - - initRuntimeDoneData := interop.InitRuntimeDoneData{ - InitializationType: telemetry.InferInitType(execCtx.initCachingEnabled, sandboxType), - Status: status, - Phase: initPhase, - ErrorType: firstFatalError, - } - - log.Info(initRuntimeDoneData.String()) - - if err := execCtx.eventsAPI.SendInitRuntimeDone(initRuntimeDoneData); err != nil { - log.Errorf("Failed to send INIT RTDONE: %s", err) - } -} - -func sendInitReportLogEvent( - execCtx *rapidContext, - sandboxType interop.SandboxType, - initStartMonotime int64, - phase interop.LifecyclePhase, -) { - initPhase, err := telemetry.InitPhaseFromLifecyclePhase(phase) - if err != nil { - log.Errorf("failed to convert lifecycle phase into init phase: %s", err) - return - } - - initReportData := interop.InitReportData{ - InitializationType: telemetry.InferInitType(execCtx.initCachingEnabled, sandboxType), - Metrics: interop.InitReportMetrics{ - DurationMs: telemetry.CalculateDuration(initStartMonotime, metering.Monotime()), - }, - Phase: initPhase, - } - log.Info(initReportData.String()) - - if err = execCtx.eventsAPI.SendInitReport(initReportData); err != nil { - log.Errorf("Failed to send INIT REPORT: %s", err) - } -} - -func sendInvokeStartLogEvent(execCtx *rapidContext, invokeRequestID string, tracingCtx *interop.TracingCtx) { - invokeStartData := interop.InvokeStartData{ - RequestID: invokeRequestID, - Version: execCtx.registrationService.GetFunctionMetadata().FunctionVersion, - Tracing: tracingCtx, - } - log.Info(invokeStartData.String()) - - if err := execCtx.eventsAPI.SendInvokeStart(invokeStartData); err != nil { - log.Errorf("Failed to send INVOKE START: %s", err) - } -} - -// This function will log a line if AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, or AWS_SESSION_TOKEN is missing -// This is expected to happen in cases when credentials provider is not needed -func checkCredentials(execCtx *rapidContext, bootstrapEnv map[string]string) { - credentialsKeys := []string{"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"} - missingCreds := []string{} - - for _, credEnvVar := range credentialsKeys { - if val, keyExists := bootstrapEnv[credEnvVar]; !keyExists || val == "" { - missingCreds = append(missingCreds, credEnvVar) - } - } - - if len(missingCreds) > 0 { - log.Infof("Starting runtime without %s , Expected?: %t", strings.Join(missingCreds[:], ", "), execCtx.initCachingEnabled) - } -} diff --git a/lambda/rapid/handlers_test.go b/lambda/rapid/handlers_test.go deleted file mode 100644 index 089dbb7..0000000 --- a/lambda/rapid/handlers_test.go +++ /dev/null @@ -1,341 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapid - -import ( - "bytes" - "context" - "fmt" - "io" - "net/http" - "net/http/httptest" - "regexp" - "strconv" - "strings" - "sync" - "testing" - "time" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi" - "go.amzn.com/lambda/rapi/handler" - "go.amzn.com/lambda/rapi/rendering" - "go.amzn.com/lambda/rapidcore/env" - "go.amzn.com/lambda/supervisor/model" - "go.amzn.com/lambda/telemetry" - "go.amzn.com/lambda/testdata" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func BenchmarkChannelsSelect10(b *testing.B) { - c1 := make(chan int) - c2 := make(chan int) - c3 := make(chan int) - c4 := make(chan int) - c5 := make(chan int) - c6 := make(chan int) - c7 := make(chan int) - c8 := make(chan int) - c9 := make(chan int) - c10 := make(chan int) - - for n := 0; n < b.N; n++ { - select { - case <-c1: - case <-c2: - case <-c3: - case <-c4: - case <-c5: - case <-c6: - case <-c7: - case <-c8: - case <-c9: - case <-c10: - default: - } - } -} - -func BenchmarkChannelsSelect2(b *testing.B) { - c1 := make(chan int) - c2 := make(chan int) - - for n := 0; n < b.N; n++ { - select { - case <-c1: - case <-c2: - default: - } - } -} - -func TestGetExtensionNamesWithNoExtensions(t *testing.T) { - rs := core.NewRegistrationService(nil, nil) - - c := &rapidContext{ - registrationService: rs, - } - - assert.Equal(t, "", c.GetExtensionNames()) -} - -func TestGetExtensionNamesWithMultipleExtensions(t *testing.T) { - rs := core.NewRegistrationService(nil, nil) - _, _ = rs.CreateExternalAgent("Example1") - _, _ = rs.CreateInternalAgent("Example2") - _, _ = rs.CreateExternalAgent("Example3") - _, _ = rs.CreateInternalAgent("Example4") - - c := &rapidContext{ - registrationService: rs, - } - - r := regexp.MustCompile(`^(Example\d;){3}(Example\d)$`) - assert.True(t, r.MatchString(c.GetExtensionNames())) -} - -func TestGetExtensionNamesWithTooManyExtensions(t *testing.T) { - rs := core.NewRegistrationService(nil, nil) - for i := 10; i < 60; i++ { - _, _ = rs.CreateExternalAgent("E" + strconv.Itoa(i)) - } - - c := &rapidContext{ - registrationService: rs, - } - - output := c.GetExtensionNames() - - r := regexp.MustCompile(`^(E\d\d;){30}(E\d\d)$`) - assert.LessOrEqual(t, len(output), maxExtensionNamesLength) - assert.True(t, r.MatchString(output)) -} - -func TestGetExtensionNamesWithTooLongExtensionName(t *testing.T) { - rs := core.NewRegistrationService(nil, nil) - for i := 10; i < 60; i++ { - _, _ = rs.CreateExternalAgent(strings.Repeat("E", 130)) - } - - c := &rapidContext{ - registrationService: rs, - } - - assert.Equal(t, "", c.GetExtensionNames()) -} - -// This test confirms our assumption that http client can establish a tcp connection -// to a listening server. -func TestListen(t *testing.T) { - flowTest := testdata.NewFlowTest() - flowTest.ConfigureForInit() - flowTest.ConfigureForInvoke(context.Background(), &interop.Invoke{ID: "ID", DeadlineNs: "1", Payload: strings.NewReader("MyTest")}) - - ctx := context.Background() - telemetryAPIEnabled := true - server := rapi.NewServer("127.0.0.1", 0, flowTest.AppCtx, flowTest.RegistrationService, flowTest.RenderingService, telemetryAPIEnabled, flowTest.TelemetrySubscription, flowTest.TelemetrySubscription, flowTest.CredentialsService) - err := server.Listen() - assert.NoError(t, err) - - defer server.Close() - - go func() { - time.Sleep(time.Second) - fmt.Println("Serving...") - server.Serve(ctx) - }() - - done := make(chan struct{}) - - go func() { - fmt.Println("Connecting...") - resp, err1 := http.Get(fmt.Sprintf("http://%s:%d/2018-06-01/runtime/invocation/next", server.Host(), server.Port())) - assert.Nil(t, err1) - - body, err2 := io.ReadAll(resp.Body) - assert.Nil(t, err2) - - assert.Equal(t, "MyTest", string(body)) - - done <- struct{}{} - }() - - <-done -} - -func makeRapidContext(appCtx appctx.ApplicationContext, initFlow core.InitFlowSynchronization, invokeFlow core.InvokeFlowSynchronization, registrationService core.RegistrationService, supervisor *processSupervisor) *rapidContext { - - appctx.StoreInitType(appCtx, true) - appctx.StoreInteropServer(appCtx, MockInteropServer{}) - - renderingService := rendering.NewRenderingService() - - credentialsService := core.NewCredentialsService() - credentialsService.SetCredentials("token", "key", "secret", "session", time.Now()) - - // Runtime state machine - runtime := core.NewRuntime(initFlow, invokeFlow) - - registrationService.PreregisterRuntime(runtime) - runtime.SetState(runtime.RuntimeRestoreReadyState) - - rapidCtx := &rapidContext{ - // Internally initialized configurations - appCtx: appCtx, - initDone: true, - initFlow: initFlow, - invokeFlow: invokeFlow, - registrationService: registrationService, - renderingService: renderingService, - credentialsService: credentialsService, - handlerExecutionMutex: sync.Mutex{}, - shutdownContext: newShutdownContext(), - eventsAPI: &telemetry.NoOpEventsAPI{}, - } - if supervisor != nil { - rapidCtx.supervisor = *supervisor - } - - return rapidCtx -} - -const hookErrorType = "Runtime.RestoreHookUserErrorType" - -func makeRequest(appCtx appctx.ApplicationContext) *http.Request { - errorBody := []byte("My byte array is yours") - - request := appctx.RequestWithAppCtx(httptest.NewRequest("POST", "/", bytes.NewReader(errorBody)), appCtx) - - request.Header.Set("Content-Type", "application/MyBinaryType") - request.Header.Set("Lambda-Runtime-Function-Error-Type", hookErrorType) - - return request -} - -type MockInteropServer struct{} - -func (server MockInteropServer) GetCurrentInvokeID() string { - return "" -} - -func (server MockInteropServer) SendRuntimeReady() error { - return nil -} - -func (server MockInteropServer) SendInitErrorResponse(response *interop.ErrorInvokeResponse) error { - return nil -} - -func TestRestoreErrorAndAwaitRestoreCompletionRaceCondition(t *testing.T) { - appCtx := appctx.NewApplicationContext() - initFlow := core.NewInitFlowSynchronization() - invokeFlow := core.NewInvokeFlowSynchronization() - registrationService := core.NewRegistrationService(initFlow, invokeFlow) - - rapidCtx := makeRapidContext(appCtx, initFlow, invokeFlow, registrationService, nil /* don't set process supervisor */) - - // Runtime state machine - runtime := core.NewRuntime(initFlow, invokeFlow) - registrationService.PreregisterRuntime(runtime) - runtime.SetState(runtime.RuntimeRestoreReadyState) - - restore := &interop.Restore{ - AwsKey: "key", - AwsSecret: "secret", - AwsSession: "session", - CredentialsExpiry: time.Now(), - RestoreHookTimeoutMs: 10 * 1000, - } - - var wg sync.WaitGroup - - wg.Add(1) - - go func() { - defer wg.Done() - _, err := rapidCtx.HandleRestore(restore) - assert.Equal(t, err.Error(), "errRestoreHookUserError") - v, ok := err.(interop.ErrRestoreHookUserError) - assert.True(t, ok) - assert.Equal(t, v.UserError.Type, fatalerror.ErrorType(hookErrorType)) - }() - - responseRecorder := httptest.NewRecorder() - - handler := handler.NewRestoreErrorHandler(registrationService) - - request := makeRequest(appCtx) - - wg.Add(1) - - time.Sleep(1 * time.Second) - runtime.SetState(runtime.RuntimeRestoringState) - - go func() { - defer wg.Done() - handler.ServeHTTP(responseRecorder, request) - }() - - wg.Wait() -} - -type MockedProcessSupervisor struct { - mock.Mock -} - -func (supv *MockedProcessSupervisor) Exec(ctx context.Context, req *model.ExecRequest) error { - args := supv.Called(req) - return args.Error(0) -} - -func (supv *MockedProcessSupervisor) Events(ctx context.Context, req *model.EventsRequest) (<-chan model.Event, error) { - args := supv.Called(req) - err := args.Error(1) - if err != nil { - return nil, err - } - return args.Get(0).(<-chan model.Event), nil -} - -func (supv *MockedProcessSupervisor) Terminate(ctx context.Context, req *model.TerminateRequest) error { - args := supv.Called(req) - return args.Error(0) -} - -func (supv *MockedProcessSupervisor) Kill(ctx context.Context, req *model.KillRequest) error { - args := supv.Called(req) - return args.Error(0) -} - -var _ model.ProcessSupervisor = (*MockedProcessSupervisor)(nil) - -func TestSetupEventWatcherErrorHandling(t *testing.T) { - appCtx := appctx.NewApplicationContext() - initFlow := core.NewInitFlowSynchronization() - invokeFlow := core.NewInvokeFlowSynchronization() - registrationService := core.NewRegistrationService(initFlow, invokeFlow) - mockedProcessSupervisor := &MockedProcessSupervisor{} - mockedProcessSupervisor.On("Events", mock.Anything).Return(nil, fmt.Errorf("events call failed")) - procSupv := &processSupervisor{ProcessSupervisor: mockedProcessSupervisor} - - rapidCtx := makeRapidContext(appCtx, initFlow, invokeFlow, registrationService, procSupv) - - initSuccessResponseChan := make(chan interop.InitSuccess) - initFailureResponseChan := make(chan interop.InitFailure) - init := &interop.Init{EnvironmentVariables: env.NewEnvironment()} - - go assert.NotPanics(t, func() { - rapidCtx.HandleInit(init, initSuccessResponseChan, initFailureResponseChan) - }) - - failure := <-initFailureResponseChan - failure.Ack <- struct{}{} - errorType := interop.InitFailure(failure).ErrorType - assert.Equal(t, fatalerror.SandboxFailure, errorType) -} diff --git a/lambda/rapid/sandbox.go b/lambda/rapid/sandbox.go deleted file mode 100644 index 26eaff0..0000000 --- a/lambda/rapid/sandbox.go +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapid - -import ( - "bytes" - "context" - "fmt" - "io" - "sync" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi" - "go.amzn.com/lambda/rapi/rendering" - supvmodel "go.amzn.com/lambda/supervisor/model" - "go.amzn.com/lambda/telemetry" -) - -type Sandbox struct { - EnableTelemetryAPI bool - StandaloneMode bool - InteropServer interop.Server - Tracer telemetry.Tracer - LogsSubscriptionAPI telemetry.SubscriptionAPI - TelemetrySubscriptionAPI telemetry.SubscriptionAPI - LogsEgressAPI telemetry.StdLogsEgressAPI - RuntimeStdoutWriter io.Writer - RuntimeStderrWriter io.Writer - Handler string - EventsAPI interop.EventsAPI - InitCachingEnabled bool - Supervisor supvmodel.ProcessSupervisor - RuntimeFsRootPath string // path to the root of the domain within the root mnt namespace. Reqired to find extensions - RuntimeAPIHost string - RuntimeAPIPort int -} - -// Start pings Supervisor, and starts the Runtime API server. It allows the caller to configure: -// - Supervisor implementation: performs container construction & process management -// - Telemetry API and Logs API implementation: handling /logs and /telemetry of Runtime API -// - Events API implementation: handles platform log events emitted by Rapid (e.g. RuntimeDone, InitStart) -// - Logs Egress implementation: handling stdout/stderr logs from extension & runtime processes (TODO: remove & unify with Supervisor) -// - Tracer implementation: handling trace segments generate by platform (TODO: remove & unify with Events API) -// - InteropServer implementation: legacy interface for sending internal protocol messages, today only RuntimeReady remains (TODO: move RuntimeReady outside Core) -// - Feature flags: -// - StandaloneMode: indicates if being called by Rapid Core's standalone HTTP frontend (TODO: remove after unifying error reporting) -// - InitCachingEnabled: indicates if handlers must run Init Caching specific logic -// - TelemetryAPIEnabled: indicates if /telemetry and /logs endpoint HTTP handlers must be mounted -// -// - Contexts & Data: -// - ctx is used to gracefully terminate Runtime API HTTP Server on exit -func Start(ctx context.Context, s *Sandbox) (interop.RapidContext, interop.InternalStateGetter, string) { - // Initialize internal state objects required by Rapid handlers - appCtx := appctx.NewApplicationContext() - initFlow := core.NewInitFlowSynchronization() - invokeFlow := core.NewInvokeFlowSynchronization() - registrationService := core.NewRegistrationService(initFlow, invokeFlow) - renderingService := rendering.NewRenderingService() - credentialsService := core.NewCredentialsService() - - appctx.StoreInitType(appCtx, s.InitCachingEnabled) - - server := rapi.NewServer(s.RuntimeAPIHost, s.RuntimeAPIPort, appCtx, registrationService, renderingService, s.EnableTelemetryAPI, s.LogsSubscriptionAPI, s.TelemetrySubscriptionAPI, credentialsService) - runtimeAPIAddr := fmt.Sprintf("%s:%d", server.Host(), server.Port()) - - // TODO: pass this directly down to HTTP servers and handlers, instead of using - // global state to share the interop server implementation - appctx.StoreInteropServer(appCtx, s.InteropServer) - - execCtx := &rapidContext{ - // Internally initialized configurations - server: server, - appCtx: appCtx, - initDone: false, - initFlow: initFlow, - invokeFlow: invokeFlow, - registrationService: registrationService, - renderingService: renderingService, - credentialsService: credentialsService, - handlerExecutionMutex: sync.Mutex{}, - shutdownContext: newShutdownContext(), - - // Externally specified configurations (i.e. via SandboxBuilder) - telemetryAPIEnabled: s.EnableTelemetryAPI, - logsSubscriptionAPI: s.LogsSubscriptionAPI, - telemetrySubscriptionAPI: s.TelemetrySubscriptionAPI, - logsEgressAPI: s.LogsEgressAPI, - interopServer: s.InteropServer, - xray: s.Tracer, - standaloneMode: s.StandaloneMode, - eventsAPI: s.EventsAPI, - initCachingEnabled: s.InitCachingEnabled, - supervisor: processSupervisor{ - ProcessSupervisor: s.Supervisor, - RootPath: s.RuntimeFsRootPath, - }, - - RuntimeStartedTime: -1, - RuntimeOverheadStartedTime: -1, - InvokeResponseMetrics: nil, - } - - go startRuntimeAPI(ctx, execCtx) - - return execCtx, registrationService.GetInternalStateDescriptor(appCtx), runtimeAPIAddr -} - -func (r *rapidContext) HandleInit(init *interop.Init, initSuccessResponseChan chan<- interop.InitSuccess, initFailureResponseChan chan<- interop.InitFailure) { - r.handlerExecutionMutex.Lock() - defer r.handlerExecutionMutex.Unlock() - handleInit(r, init, initSuccessResponseChan, initFailureResponseChan) -} - -func (r *rapidContext) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit, requestBuffer *bytes.Buffer, responseSender interop.InvokeResponseSender) (interop.InvokeSuccess, *interop.InvokeFailure) { - r.handlerExecutionMutex.Lock() - defer r.handlerExecutionMutex.Unlock() - // Clear the context used by the last invoke - r.appCtx.Delete(appctx.AppCtxInvokeErrorTraceDataKey) - return handleInvoke(r, invoke, sbInfoFromInit, requestBuffer, responseSender) -} - -func (r *rapidContext) HandleReset(reset *interop.Reset) (interop.ResetSuccess, *interop.ResetFailure) { - // In the event of a Reset during init/invoke, CancelFlows cancels execution - // flows and return with the errResetReceived err - this error is special-cased - // and not handled by the init/invoke (unexpected) error handling functions - r.registrationService.CancelFlows(errResetReceived) - - // Wait until invoke error handling has returned before continuing execution - r.handlerExecutionMutex.Lock() - defer r.handlerExecutionMutex.Unlock() - - // Clear the context used by the last invoke - r.appCtx.Delete(appctx.AppCtxInvokeErrorTraceDataKey) - return handleReset(r, reset, r.RuntimeStartedTime, r.InvokeResponseMetrics) -} - -func (r *rapidContext) HandleShutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { - // Wait until invoke error handling has returned before continuing execution - r.handlerExecutionMutex.Lock() - defer r.handlerExecutionMutex.Unlock() - // Shutdown doesn't cancel flows, so it can block forever - return handleShutdown(r, shutdown, standaloneShutdownReason) -} - -func (r *rapidContext) HandleRestore(restore *interop.Restore) (interop.RestoreResult, error) { - return handleRestore(r, restore) -} - -func (r *rapidContext) Clear() { - reinitialize(r) -} - -func (r *rapidContext) SetRuntimeStartedTime(runtimeStartedTime int64) { - r.RuntimeStartedTime = runtimeStartedTime -} - -func (r *rapidContext) SetRuntimeOverheadStartedTime(runtimeOverheadStartedTime int64) { - r.RuntimeOverheadStartedTime = runtimeOverheadStartedTime -} - -func (r *rapidContext) SetInvokeResponseMetrics(metrics *interop.InvokeResponseMetrics) { - r.InvokeResponseMetrics = metrics -} - -func (r *rapidContext) SetLogStreamName(logStreamName string) { - r.logStreamName = logStreamName -} - -func (r *rapidContext) SetEventsAPI(eventsAPI interop.EventsAPI) { - r.eventsAPI = eventsAPI -} diff --git a/lambda/rapid/shutdown.go b/lambda/rapid/shutdown.go deleted file mode 100644 index 05695e3..0000000 --- a/lambda/rapid/shutdown.go +++ /dev/null @@ -1,368 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -// Package rapid implements synchronous even dispatch loop. -package rapid - -import ( - "context" - "errors" - "fmt" - "sync" - "time" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapi/model" - "go.amzn.com/lambda/rapi/rendering" - supvmodel "go.amzn.com/lambda/supervisor/model" - - log "github.com/sirupsen/logrus" -) - -const ( - // supervisor shutdown and kill operations block until the exit status of the - // interested process has been collected, or until the specified deadline expires - // Note that this deadline is mainly relevant when any of the domain - // processes are in uninterruptible sleep state (notable examples: syscall - // to read/write a networked driver) - // - // We set a non nil value for these timeouts so that RAPID doesn't block - // forever in one of the cases above. - supervisorBlockingMaxMillis = 9000 - runtimeDeadlineShare = 0.3 - - maxProcessExitWait = 2 * time.Second -) - -// TODO: aggregate struct's methods into an interface, so that we can mock in tests -type shutdownContext struct { - // Adding a mutex around shuttingDown because there may be concurrent reads/writes. - // Because the code in shutdown() and the seperate go routine created in setupEventsWatcher() - // could be concurrently accessing the field shuttingDown. - shuttingDownMutex sync.Mutex - shuttingDown bool - agentsAwaitingExit map[string]*core.ExternalAgent - // Adding a mutex around runtimeDomainExited because there may be concurrent reads/writes. - // The first reason this can be caused is by different go routines reading/writing different keys. - // The second reason this can be caused is between the code shutting down the runtime/extensions and - // handleProcessExit in a separate go routine, reading and writing to the same key. Caused by - // unexpected exits. - runtimeDomainExitedMutex sync.Mutex - // used to synchronize on processes exits. We create the channel when a - // process is started and we close it upon exit notification from - // supervisor. Closing the channel is basically a persistent broadcast of process exit. - // We never write anything to the channels - runtimeDomainExited map[string]chan struct{} -} - -func newShutdownContext() *shutdownContext { - return &shutdownContext{ - shuttingDownMutex: sync.Mutex{}, - shuttingDown: false, - agentsAwaitingExit: make(map[string]*core.ExternalAgent), - runtimeDomainExited: make(map[string]chan struct{}), - runtimeDomainExitedMutex: sync.Mutex{}, - } -} - -func (s *shutdownContext) isShuttingDown() bool { - s.shuttingDownMutex.Lock() - defer s.shuttingDownMutex.Unlock() - return s.shuttingDown -} - -func (s *shutdownContext) setShuttingDown(value bool) { - s.shuttingDownMutex.Lock() - defer s.shuttingDownMutex.Unlock() - s.shuttingDown = value -} - -func (s *shutdownContext) handleProcessExit(termination supvmodel.ProcessTermination) { - - name := *termination.Name - agent, found := s.agentsAwaitingExit[name] - - // If it is an agent registered to receive a shutdown event. - if found { - log.Debugf("Handling termination for %s", name) - exitStatus := termination.Exited() - if exitStatus != nil && *exitStatus == 0 { - // If the agent exited by itself after receiving the shutdown event. - stateErr := agent.Exited() - if stateErr != nil { - log.Warnf("%s failed to transition to EXITED: %s (current state: %s)", agent.String(), stateErr, agent.GetState().Name()) - } - } else { - // If the agent did not exit by itself, had to be SIGKILLed (only in standalone mode). - stateErr := agent.ShutdownFailed() - if stateErr != nil { - log.Warnf("%s failed to transition to ShutdownFailed: %s (current state: %s)", agent, stateErr, agent.GetState().Name()) - } - } - } - - exitedChannel, found := s.getExitedChannel(name) - - if !found { - log.Panicf("Unable to find an exitedChannel for '%s', it should have been created just after it was execed.", name) - } - // we close the channel so that whoever is blocked on it - // or will try to block on it in the future unblocks immediately - close(exitedChannel) -} - -func (s *shutdownContext) getExitedChannel(name string) (chan struct{}, bool) { - s.runtimeDomainExitedMutex.Lock() - defer s.runtimeDomainExitedMutex.Unlock() - exitedChannel, found := s.runtimeDomainExited[name] - return exitedChannel, found -} - -func (s *shutdownContext) createExitedChannel(name string) { - s.runtimeDomainExitedMutex.Lock() - defer s.runtimeDomainExitedMutex.Unlock() - - _, found := s.runtimeDomainExited[name] - - if found { - log.Panicf("Tried to create an exited channel for '%s' but one already exists.", name) - } - s.runtimeDomainExited[name] = make(chan struct{}) -} - -// Blocks until all the processes in the runtime domain generation have exited. -// This helps us have a nice sync point on Shutdown where we know for sure that -// all the processes have exited and the state has been cleared. The exception -// to that rule is that if any of the processes don't exit within -// maxProcessExitWait from the beginning of the waiting period, an error is -// returned, in order to prevent it from waiting forever if any of the processes -// cannot be killed. -// -// It is OK not to hold the lock because we know that this is called only during -// shutdown and nobody will start a new process during shutdown -func (s *shutdownContext) clearExitedChannel() error { - s.runtimeDomainExitedMutex.Lock() - mapLen := len(s.runtimeDomainExited) - channels := make([]chan struct{}, 0, mapLen) - for _, v := range s.runtimeDomainExited { - channels = append(channels, v) - } - s.runtimeDomainExitedMutex.Unlock() - - exitTimeout := time.After(maxProcessExitWait) - for _, v := range channels { - select { - case <-v: - case <-exitTimeout: - return errors.New("timed out waiting for runtime processes to exit") - } - } - - s.runtimeDomainExitedMutex.Lock() - s.runtimeDomainExited = make(map[string]chan struct{}, mapLen) - s.runtimeDomainExitedMutex.Unlock() - return nil -} - -func (s *shutdownContext) shutdownRuntime(execCtx *rapidContext, start time.Time, deadline time.Time) { - // If runtime is started: - // 1. SIGTERM and wait until deadline - // 2. SIGKILL on deadline - log.Debug("Shutting down the runtime.") - name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) - exitedChannel, found := s.getExitedChannel(name) - - if found { - - err := execCtx.supervisor.Terminate(context.Background(), &supvmodel.TerminateRequest{ - Domain: RuntimeDomain, - Name: name, - }) - if err != nil { - // We are not reporting the error upstream because we will anyway - // shut the domain out at the end of the shutdown sequence - log.WithError(err).Warn("Failed sending Termination signal to runtime") - } - - ctx, cancel := context.WithDeadline(context.Background(), deadline) - defer cancel() - - select { - case <-ctx.Done(): - log.Warnf("Deadline: The runtime did not exit after deadline %s; Killing it.", deadline) - - err = execCtx.supervisor.Kill(context.Background(), &supvmodel.KillRequest{ - Domain: RuntimeDomain, - Name: name, - Deadline: time.Now().Add(time.Millisecond * supervisorBlockingMaxMillis), - }) - - if err != nil { - // We are not reporting the error upstream because we will anyway - // shut the domain out at the end of the shutdown sequence - log.WithError(err).Warn("Failed sending Kill signal to runtime") - } - case <-exitedChannel: - } - } else { - log.Warn("The runtime was not started.") - } - log.Debug("Shutdown the runtime.") -} - -func (s *shutdownContext) shutdownAgents(execCtx *rapidContext, start time.Time, deadline time.Time, reason string) { - // For each external agent, if agent is launched: - // 1. Send Shutdown event if subscribed for it, else send SIGKILL to process group - // 2. Wait for all Shutdown-subscribed agents to exit with deadline - // 3. Send SIGKILL to process group for Shutdown-subscribed agents on deadline - - log.Debug("Shutting down the agents.") - execCtx.renderingService.SetRenderer( - &rendering.ShutdownRenderer{ - AgentEvent: model.AgentShutdownEvent{ - AgentEvent: &model.AgentEvent{ - EventType: "SHUTDOWN", - DeadlineMs: deadline.UnixNano() / (1000 * 1000), - }, - ShutdownReason: reason, - }, - }) - - var wg sync.WaitGroup - - // clear agentsAwaitingExit from last shutdownAgents - s.agentsAwaitingExit = make(map[string]*core.ExternalAgent) - - for _, a := range execCtx.registrationService.GetExternalAgents() { - name := fmt.Sprintf("extension-%s-%d", a.Name, execCtx.runtimeDomainGeneration) - exitedChannel, found := s.getExitedChannel(name) - - if !found { - log.Warnf("Agent %s failed to launch, therefore skipping shutting it down.", a) - continue - } - - wg.Add(1) - - if a.IsSubscribed(core.ShutdownEvent) { - log.Debugf("Agent %s is registered for the shutdown event.", a) - s.agentsAwaitingExit[name] = a - - go func(name string, agent *core.ExternalAgent) { - defer wg.Done() - - agent.Release() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - if execCtx.standaloneMode { - ctx, cancel = context.WithDeadline(ctx, deadline) - defer cancel() - } - - select { - case <-ctx.Done(): - log.Warnf("Deadline: the agent %s did not exit after deadline %s; Killing it.", name, deadline) - err := execCtx.supervisor.Kill(context.Background(), &supvmodel.KillRequest{ - Domain: RuntimeDomain, - Name: name, - Deadline: time.Now().Add(time.Millisecond * supervisorBlockingMaxMillis), - }) - if err != nil { - // We are not reporting the error upstream because we will anyway - // shut the domain out at the end of the shutdown sequence - log.WithError(err).Warn("Failed sending Kill signal to agent") - } - case <-exitedChannel: - } - }(name, a) - } else { - log.Debugf("Agent %s is not registered for the shutdown event, so just killing it.", a) - - go func(name string) { - defer wg.Done() - - err := execCtx.supervisor.Kill(context.Background(), &supvmodel.KillRequest{ - Domain: RuntimeDomain, - Name: name, - Deadline: time.Now().Add(time.Millisecond * supervisorBlockingMaxMillis), - }) - if err != nil { - log.WithError(err).Warn("Failed sending Kill signal to agent") - } - }(name) - } - } - - // Wait on the agents subscribed to the shutdown event to voluntary shutting down after receiving the shutdown event or be sigkilled. - // In addition to waiting on the agents not subscribed to the shutdown event being sigkilled. - wg.Wait() - log.Debug("Shutdown the agents.") -} - -func (s *shutdownContext) shutdown(execCtx *rapidContext, deadlineNs int64, reason string) (int64, bool, error) { - var err error - s.setShuttingDown(true) - defer s.setShuttingDown(false) - - // Fatal errors such as Runtime exit and Extension.Crash - // are ignored by the events watcher when shutting down - execCtx.appCtx.Delete(appctx.AppCtxFirstFatalErrorKey) - - runtimeDomainProfiler := &metering.ExtensionsResetDurationProfiler{} - - // We do not spend any compute time on runtime graceful shutdown if there are no agents - if execCtx.registrationService.CountAgents() == 0 { - name := fmt.Sprintf("%s-%d", runtimeProcessName, execCtx.runtimeDomainGeneration) - - _, found := s.getExitedChannel(name) - - if found { - log.Debug("SIGKILLing the runtime as no agents are registered.") - err = execCtx.supervisor.Kill(context.Background(), &supvmodel.KillRequest{ - Domain: RuntimeDomain, - Name: name, - Deadline: time.Now().Add(time.Millisecond * supervisorBlockingMaxMillis), - }) - if err != nil { - // We are not reporting the error upstream because we will anyway - // shut the domain out at the end of the shutdown sequence - log.WithError(err).Warn("Failed sending Kill signal to runtime") - } - } else { - log.Debugf("Could not find runtime process %s in processes map. Already exited/never started", name) - } - } else { - mono := metering.Monotime() - availableNs := deadlineNs - mono - - if availableNs < 0 { - log.Warnf("Deadline is in the past: %v, %v, %v", mono, deadlineNs, availableNs) - availableNs = 0 - } - - start := time.Now() - - runtimeDeadline := start.Add(time.Duration(float64(availableNs) * runtimeDeadlineShare)) - agentsDeadline := start.Add(time.Duration(availableNs)) - - runtimeDomainProfiler.AvailableNs = availableNs - runtimeDomainProfiler.Start() - - s.shutdownRuntime(execCtx, start, runtimeDeadline) - s.shutdownAgents(execCtx, start, agentsDeadline, reason) - - runtimeDomainProfiler.NumAgentsRegisteredForShutdown = len(s.agentsAwaitingExit) - } - - log.Info("Waiting for runtime domain processes termination") - if err := s.clearExitedChannel(); err != nil { - log.Error(err) - } - - runtimeDomainProfiler.Stop() - extensionsResetMs, timeout := runtimeDomainProfiler.CalculateExtensionsResetMs() - return extensionsResetMs, timeout, err -} diff --git a/lambda/rapidcore/env/constants.go b/lambda/rapidcore/env/constants.go deleted file mode 100644 index 50e9fcb..0000000 --- a/lambda/rapidcore/env/constants.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package env - -func predefinedInternalEnvVarKeys() map[string]bool { - return map[string]bool{ - "_LAMBDA_SB_ID": true, - "_LAMBDA_LOG_FD": true, - "_LAMBDA_SHARED_MEM_FD": true, - "_LAMBDA_CONTROL_SOCKET": true, - "_LAMBDA_DIRECT_INVOKE_SOCKET": true, - "_LAMBDA_RUNTIME_LOAD_TIME": true, - "_LAMBDA_CONSOLE_SOCKET": true, - // _X_AMZN_TRACE_ID is set by stock runtimes. Provided - // runtimes should set and mutate it on each invoke. - "_X_AMZN_TRACE_ID": true, - "_LAMBDA_TELEMETRY_API_PASSPHRASE": true, - } -} - -func predefinedPlatformEnvVarKeys() map[string]bool { - return map[string]bool{ - "AWS_REGION": true, - "AWS_DEFAULT_REGION": true, - "AWS_LAMBDA_FUNCTION_NAME": true, - "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": true, - "AWS_LAMBDA_FUNCTION_VERSION": true, - "AWS_LAMBDA_RUNTIME_API": true, - "TZ": true, - } -} - -func predefinedRuntimeEnvVarKeys() map[string]bool { - return map[string]bool{ - "_HANDLER": true, - "AWS_EXECUTION_ENV": true, - "AWS_LAMBDA_LOG_GROUP_NAME": true, - "AWS_LAMBDA_LOG_STREAM_NAME": true, - "LAMBDA_TASK_ROOT": true, - "LAMBDA_RUNTIME_DIR": true, - } -} - -func predefinedPlatformUnreservedEnvVarKeys() map[string]bool { - return map[string]bool{ - // AWS_XRAY_DAEMON_ADDRESS is unreserved but RAPID boot depends on it - "AWS_XRAY_DAEMON_ADDRESS": true, - } -} - -func predefinedCredentialsEnvVarKeys() map[string]bool { - return map[string]bool{ - "AWS_ACCESS_KEY_ID": true, - "AWS_SECRET_ACCESS_KEY": true, - "AWS_SESSION_TOKEN": true, - } -} - -func extensionExcludedKeys() map[string]bool { - return map[string]bool{ - "AWS_XRAY_CONTEXT_MISSING": true, - "_AWS_XRAY_DAEMON_ADDRESS": true, - "_AWS_XRAY_DAEMON_PORT": true, - "_LAMBDA_TELEMETRY_LOG_FD": true, - } -} diff --git a/lambda/rapidcore/env/customer.go b/lambda/rapidcore/env/customer.go deleted file mode 100644 index f784570..0000000 --- a/lambda/rapidcore/env/customer.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package env - -import ( - "os" - "strings" - - log "github.com/sirupsen/logrus" -) - -func isInternalEnvVar(envKey string) bool { - // the rule is no '_' prefixed env. variables will be propagated to the runtime but the ones explicitly exempted - allowedKeys := map[string]bool{ - "_HANDLER": true, - "_AWS_XRAY_DAEMON_ADDRESS": true, - "_AWS_XRAY_DAEMON_PORT": true, - "_LAMBDA_TELEMETRY_LOG_FD": true, - } - return strings.HasPrefix(envKey, "_") && !allowedKeys[envKey] -} - -// CustomerEnvironmentVariables parses all environment variables that are -// not internal/credential/platform, and must be called before agent bootstrap. -func CustomerEnvironmentVariables() map[string]string { - internalKeys := predefinedInternalEnvVarKeys() - platformKeys := predefinedPlatformEnvVarKeys() - runtimeKeys := predefinedRuntimeEnvVarKeys() - credentialKeys := predefinedCredentialsEnvVarKeys() - platformUnreservedKeys := predefinedPlatformUnreservedEnvVarKeys() - isCustomer := func(key string) bool { - return !internalKeys[key] && - !runtimeKeys[key] && - !platformKeys[key] && - !credentialKeys[key] && - !platformUnreservedKeys[key] && - !isInternalEnvVar(key) - } - - customerEnv := map[string]string{} - for _, keyval := range os.Environ() { - key, val, err := SplitEnvironmentVariable(keyval) - if err != nil { - log.Warnf("Customer environment variable with invalid format: %s", err) - continue - } - - if isCustomer(key) { - customerEnv[key] = val - } - } - - return customerEnv -} diff --git a/lambda/rapidcore/env/environment.go b/lambda/rapidcore/env/environment.go deleted file mode 100644 index fbe0ef2..0000000 --- a/lambda/rapidcore/env/environment.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package env - -import ( - "fmt" - "os" - "strings" - - log "github.com/sirupsen/logrus" -) - -const runtimeAPIAddressKey = "AWS_LAMBDA_RUNTIME_API" -const handlerEnvKey = "_HANDLER" -const executionEnvKey = "AWS_EXECUTION_ENV" -const taskRootEnvKey = "LAMBDA_TASK_ROOT" -const runtimeDirEnvKey = "LAMBDA_RUNTIME_DIR" - -// Environment holds env vars for runtime, agents, and for -// internal use, parsed during startup and from START msg -type Environment struct { - Customer map[string]string // customer & unreserved platform env vars, set on INIT - - rapid map[string]string // env vars req'd internally by RAPID - platform map[string]string // reserved platform env vars as per Lambda docs - runtime map[string]string // reserved runtime env vars as per Lambda docs - platformUnreserved map[string]string // unreserved platform env vars that customers can override - credentials map[string]string // reserved env vars for credentials, set on INIT - - runtimeAPISet bool - initEnvVarsSet bool -} - -func lookupEnv(keys map[string]bool) map[string]string { - res := map[string]string{} - for key := range keys { - val, ok := os.LookupEnv(key) - if ok { - res[key] = val - } - } - return res -} - -// NewEnvironment parses environment variables into an Environment object -func NewEnvironment() *Environment { - return &Environment{ - rapid: lookupEnv(predefinedInternalEnvVarKeys()), - platform: lookupEnv(predefinedPlatformEnvVarKeys()), - runtime: lookupEnv(predefinedRuntimeEnvVarKeys()), - platformUnreserved: lookupEnv(predefinedPlatformUnreservedEnvVarKeys()), - - Customer: map[string]string{}, - credentials: map[string]string{}, - - runtimeAPISet: false, - initEnvVarsSet: false, - } - -} - -// StoreRuntimeAPIEnvironmentVariable stores value for AWS_LAMBDA_RUNTIME_API -func (e *Environment) StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress string) { - e.platform[runtimeAPIAddressKey] = runtimeAPIAddress - e.runtimeAPISet = true -} - -// SetHandler sets _HANDLER env variable value for Runtime -func (e *Environment) SetHandler(handler string) { - e.runtime[handlerEnvKey] = handler -} - -// GetExecutionEnv returns the current setting for AWS_EXECUTION_ENV -func (e *Environment) GetExecutionEnv() string { - return e.runtime[executionEnvKey] -} - -// SetExecutionEnv sets AWS_EXECUTION_ENV variable value for Runtime -func (e *Environment) SetExecutionEnv(executionEnv string) { - e.runtime[executionEnvKey] = executionEnv -} - -// SetTaskRoot sets the LAMBDA_TASK_ROOT environment variable for Runtime -func (e *Environment) SetTaskRoot(taskRoot string) { - e.runtime[taskRootEnvKey] = taskRoot -} - -// SetRuntimeDir sets the LAMBDA_RUNTIME_DIR environment variable for Runtime -func (e *Environment) SetRuntimeDir(runtimeDir string) { - e.runtime[runtimeDirEnvKey] = runtimeDir -} - -// StoreEnvironmentVariablesFromInit sets the environment variables -// for credentials & _HANDLER which are received in the START message -func (e *Environment) StoreEnvironmentVariablesFromInit(customerEnv map[string]string, handler, awsKey, awsSecret, awsSession, funcName, funcVer string) { - - e.credentials["AWS_ACCESS_KEY_ID"] = awsKey - e.credentials["AWS_SECRET_ACCESS_KEY"] = awsSecret - e.credentials["AWS_SESSION_TOKEN"] = awsSession - - e.storeNonCredentialEnvironmentVariablesFromInit(customerEnv, handler, funcName, funcVer) -} - -func (e *Environment) StoreEnvironmentVariablesFromInitForInitCaching(host string, port int, customerEnv map[string]string, handler, funcName, funcVer, token string) { - e.credentials["AWS_CONTAINER_CREDENTIALS_FULL_URI"] = fmt.Sprintf("http://%s:%d/2021-04-23/credentials", host, port) - e.credentials["AWS_CONTAINER_AUTHORIZATION_TOKEN"] = token - - e.storeNonCredentialEnvironmentVariablesFromInit(customerEnv, handler, funcName, funcVer) -} - -func (e *Environment) storeNonCredentialEnvironmentVariablesFromInit(customerEnv map[string]string, handler, funcName, funcVer string) { - if handler != "" { - e.SetHandler(handler) - } - - if funcName != "" { - e.platform["AWS_LAMBDA_FUNCTION_NAME"] = funcName - } - - if funcVer != "" { - e.platform["AWS_LAMBDA_FUNCTION_VERSION"] = funcVer - } - - e.mergeCustomerEnvironmentVariables(customerEnv) // overrides env vars from CLI options - e.initEnvVarsSet = true -} - -// StoreEnvironmentVariablesFromCLIOptions sets the environment -// variables received via a CLI flag, for example LCIS config -func (e *Environment) StoreEnvironmentVariablesFromCLIOptions(envVars map[string]string) { - e.mergeCustomerEnvironmentVariables(envVars) -} - -// mergeCustomerEnvironmentVariables appends to customer env vars, overwriting entries if they exist -func (e *Environment) mergeCustomerEnvironmentVariables(envVars map[string]string) { - e.Customer = mapUnion(e.Customer, envVars) -} - -// RuntimeExecEnv returns the key=value strings of all environment variables -// passed to runtime process on exec() -func (e *Environment) RuntimeExecEnv() map[string]string { - if !e.initEnvVarsSet || !e.runtimeAPISet { - log.Fatal("credentials, customer and runtime API address must be set") - } - - return mapUnion(e.Customer, e.platformUnreserved, e.credentials, e.runtime, e.platform) -} - -// AgentExecEnv returns the key=value strings of all environment variables -// passed to agent process on exec() -func (e *Environment) AgentExecEnv() map[string]string { - if !e.initEnvVarsSet || !e.runtimeAPISet { - log.Fatal("credentials, customer and runtime API address must be set") - } - - excludedKeys := extensionExcludedKeys() - excludeCondition := func(key string) bool { return excludedKeys[key] || strings.HasPrefix(key, "_") } - return mapExclude(mapUnion(e.Customer, e.credentials, e.platform), excludeCondition) -} - -func mapUnion(maps ...map[string]string) map[string]string { - // last maps in argument overwrite values of ones before - union := map[string]string{} - for _, m := range maps { - for key, val := range m { - union[key] = val - } - } - return union -} - -func mapExclude(m map[string]string, excludeCondition func(string) bool) map[string]string { - res := map[string]string{} - for key, val := range m { - if !excludeCondition(key) { - res[key] = val - } - } - return res -} diff --git a/lambda/rapidcore/env/environment_test.go b/lambda/rapidcore/env/environment_test.go deleted file mode 100644 index 04c0494..0000000 --- a/lambda/rapidcore/env/environment_test.go +++ /dev/null @@ -1,315 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package env - -import ( - "fmt" - "os" - "testing" - - "github.com/stretchr/testify/assert" -) - -func envToSlice(env map[string]string) []string { - ret := make([]string, len(env)) - i := 0 - for key, val := range env { - ret[i] = key + "=" + val - i++ - } - return ret -} - -func TestRAPIDInternalConfig(t *testing.T) { - os.Clearenv() - os.Setenv("_LAMBDA_SB_ID", "sbid") - os.Setenv("_LAMBDA_LOG_FD", "1") - os.Setenv("_LAMBDA_SHARED_MEM_FD", "1") - os.Setenv("_LAMBDA_CONTROL_SOCKET", "1") - os.Setenv("_LAMBDA_CONSOLE_SOCKET", "1") - os.Setenv("_LAMBDA_RUNTIME_LOAD_TIME", "1") - os.Setenv("LAMBDA_TASK_ROOT", "a") - os.Setenv("AWS_XRAY_DAEMON_ADDRESS", "a") - os.Setenv("AWS_LAMBDA_FUNCTION_NAME", "a") - os.Setenv("_LAMBDA_TELEMETRY_API_PASSPHRASE", "a") - os.Setenv("_LAMBDA_DIRECT_INVOKE_SOCKET", "1") - NewRapidConfig(NewEnvironment()) -} - -func TestEnvironmentParsing(t *testing.T) { - internalEnvVal, platformEnvVal, credsEnvVal := "rapid", "platform", "creds" - runtimeEnvVal := "runtime" - customerEnvVal := "customer=foo=bar" - runtimeAPIAddress := "host:port" - - os.Clearenv() - setAll(predefinedInternalEnvVarKeys(), internalEnvVal) - setAll(predefinedPlatformEnvVarKeys(), platformEnvVal) - setAll(predefinedRuntimeEnvVarKeys(), runtimeEnvVal) - setAll(predefinedPlatformUnreservedEnvVarKeys(), customerEnvVal) - setAll(predefinedCredentialsEnvVarKeys(), credsEnvVal) - os.Setenv("MY_FOOBAR_ENV_1", customerEnvVal) - os.Setenv("MY_EMPTY_ENV", "") - os.Setenv("_UNKNOWN_INTERNAL_ENV", platformEnvVal) - - env := NewEnvironment() // parse environment variables - customerEnv := CustomerEnvironmentVariables() - - env.StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress) - env.StoreEnvironmentVariablesFromInit(customerEnv, runtimeEnvVal, credsEnvVal, credsEnvVal, credsEnvVal, platformEnvVal, platformEnvVal) - - for _, val := range env.rapid { - assert.Equal(t, internalEnvVal, val) - } - - for key, val := range env.platform { - if key == runtimeAPIAddressKey { - assert.Equal(t, runtimeAPIAddress, val) - } else { - assert.Equal(t, platformEnvVal, val) - } - } - - for _, val := range env.runtime { - assert.Equal(t, runtimeEnvVal, val) - } - - for key, val := range env.credentials { - assert.Equal(t, credsEnvVal, val) - assert.NotContains(t, env.Customer, key) - } - - for _, val := range env.platformUnreserved { - assert.Equal(t, customerEnvVal, val) - } - - assert.Equal(t, customerEnvVal, env.Customer["MY_FOOBAR_ENV_1"]) - assert.Equal(t, "", env.Customer["MY_EMPTY_ENV"]) - assert.Equal(t, "", env.Customer["_UNKNOWN_INTERNAL_ENV"]) -} - -func TestEnvironmentParsingUnsetPlatformAndInternalEnvVarKeysAreDeleted(t *testing.T) { - // Done to ensure that we can continue to distinguish between unset and empty env vars - os.Clearenv() - env := NewEnvironment() - - assert.Len(t, env.rapid, 0) - assert.Len(t, env.platform, 0) - assert.Len(t, env.platformUnreserved, 0) - assert.Len(t, env.credentials, 0) // uninitialized - assert.Len(t, env.Customer, 0) // uninitialized -} - -func TestRuntimeExecEnvironmentVariables(t *testing.T) { - internalEnvVal, platformEnvVal, credsEnvVal := "rapid", "platform", "creds" - customerEnvVal, platformUnreservedEnvVal := "customer", "platform-unreserved" - lcisCLIArgEnvVal := "lcis" - runtimeAPIAddress := "host:port" - runtimeEnvVal := "runtime" - - os.Clearenv() - setAll(predefinedInternalEnvVarKeys(), internalEnvVal) - setAll(predefinedPlatformEnvVarKeys(), platformEnvVal) - setAll(predefinedRuntimeEnvVarKeys(), runtimeEnvVal) - setAll(predefinedPlatformUnreservedEnvVarKeys(), platformUnreservedEnvVal) - setAll(predefinedCredentialsEnvVarKeys(), credsEnvVal) - customerEnv := map[string]string{ - "MY_FOOBAR_ENV_1": customerEnvVal, - } - - cliOptionsEnv := map[string]string{ - "LCIS_ARG1": lcisCLIArgEnvVal, - } - - env := NewEnvironment() - env.StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress) - env.StoreEnvironmentVariablesFromCLIOptions(cliOptionsEnv) - env.StoreEnvironmentVariablesFromInit(customerEnv, runtimeEnvVal, credsEnvVal, credsEnvVal, credsEnvVal, platformEnvVal, platformEnvVal) - - rapidEnvVars := env.RuntimeExecEnv() - - var rapidEnvKeys []string - for key := range rapidEnvVars { - rapidEnvKeys = append(rapidEnvKeys, key) - } - - rapidEnvVarsSlice := envToSlice(rapidEnvVars) - - for key := range env.rapid { - assert.NotContains(t, rapidEnvKeys, key) - } - - for key, val := range env.runtime { - assert.Contains(t, rapidEnvVarsSlice, key+"="+val) - } - - for key, val := range env.platform { - assert.Contains(t, rapidEnvVarsSlice, key+"="+val) - } - - for key, val := range env.platformUnreserved { - assert.Contains(t, rapidEnvVarsSlice, key+"="+val) - assert.NotContains(t, env.Customer, key) - } - - for key, val := range env.credentials { - assert.Contains(t, rapidEnvVarsSlice, key+"="+val) - } - - for key, val := range env.Customer { - assert.Contains(t, rapidEnvVarsSlice, key+"="+val) - assert.NotContains(t, env.platformUnreserved, key) - } -} - -func TestRuntimeExecEnvironmentVariablesPriority(t *testing.T) { - internalEnvVal, platformEnvVal, credsEnvVal := "rapid", "platform", "creds" - customerEnvVal, platformUnreservedEnvVal := "customer", "platform-unreserved" - runtimeEnvVal := "runtime" - lcisCLIArgEnvVal := "lcis" - runtimeAPIAddress := "host:port" - - os.Clearenv() - setAll(predefinedInternalEnvVarKeys(), internalEnvVal) - setAll(predefinedPlatformEnvVarKeys(), platformEnvVal) - setAll(predefinedPlatformUnreservedEnvVarKeys(), platformUnreservedEnvVal) - setAll(predefinedCredentialsEnvVarKeys(), credsEnvVal) - setAll(predefinedRuntimeEnvVarKeys(), runtimeEnvVal) - - conflictPlatformKeyFromInit := "AWS_REGION" - conflictPlatformKeyFromCLI := "LAMBDA_TASK_ROOT" - - customerEnv := map[string]string{ - "MY_FOOBAR_ENV_1": customerEnvVal, - conflictPlatformKeyFromInit: customerEnvVal, - } - - cliOptionsEnv := map[string]string{ - "LCIS_ARG1": lcisCLIArgEnvVal, - conflictPlatformKeyFromCLI: lcisCLIArgEnvVal, - } - - env := NewEnvironment() - env.StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress) - env.StoreEnvironmentVariablesFromCLIOptions(cliOptionsEnv) - env.StoreEnvironmentVariablesFromInit(customerEnv, runtimeEnvVal, credsEnvVal, credsEnvVal, credsEnvVal, platformEnvVal, platformEnvVal) - - assert.Equal(t, len(predefinedPlatformEnvVarKeys()), len(env.platform)) - assert.Equal(t, len(predefinedCredentialsEnvVarKeys()), len(env.credentials)) - assert.Equal(t, len(predefinedPlatformUnreservedEnvVarKeys()), len(env.platformUnreserved)) - assert.Equal(t, len(predefinedInternalEnvVarKeys()), len(env.rapid)) - assert.Equal(t, len(predefinedRuntimeEnvVarKeys()), len(env.runtime)) - - rapidEnvVars := envToSlice(env.RuntimeExecEnv()) - - // Customer env vars cannot override platform/internal ones - assert.NotContains(t, rapidEnvVars, conflictPlatformKeyFromInit+"="+customerEnvVal) - assert.NotContains(t, rapidEnvVars, conflictPlatformKeyFromCLI+"="+lcisCLIArgEnvVal) - assert.Contains(t, rapidEnvVars, conflictPlatformKeyFromInit+"="+platformEnvVal) - assert.Contains(t, rapidEnvVars, conflictPlatformKeyFromCLI+"="+runtimeEnvVal) -} - -func TestCustomerEnvironmentVariablesFromInitCanOverrideEnvironmentVariablesFromCLIOptions(t *testing.T) { - platformEnvVal, credsEnvVal, customerEnvVal := "platform", "creds", "customer" - lcisCLIArgEnvVal := "lcis" - runtimeAPIAddress := "host:port" - runtimeEnvVal := "runtime" - - os.Clearenv() - customerEnv := map[string]string{ - "MY_FOOBAR_ENV_1": customerEnvVal, - } - - cliOptionsEnv := map[string]string{ - "LCIS_ARG1": lcisCLIArgEnvVal, - "MY_FOOBAR_ENV_1": lcisCLIArgEnvVal, - } - - env := NewEnvironment() - env.StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress) - env.StoreEnvironmentVariablesFromCLIOptions(cliOptionsEnv) - env.StoreEnvironmentVariablesFromInit(customerEnv, runtimeEnvVal, credsEnvVal, credsEnvVal, credsEnvVal, platformEnvVal, platformEnvVal) - - assert.Equal(t, env.Customer["LCIS_ARG1"], lcisCLIArgEnvVal) - assert.Equal(t, env.Customer["MY_FOOBAR_ENV_1"], customerEnvVal) - - rapidEnvVars := envToSlice(env.RuntimeExecEnv()) - - assert.Contains(t, rapidEnvVars, "LCIS_ARG1="+lcisCLIArgEnvVal) - assert.Contains(t, rapidEnvVars, "MY_FOOBAR_ENV_1="+customerEnvVal) -} - -func TestAgentExecEnvironmentVariables(t *testing.T) { - internalEnvVal, platformEnvVal, credsEnvVal := "rapid", "platform", "creds" - customerEnvVal, platformUnreservedEnvVal := "customer", "platform-unreserved" - runtimeAPIAddress := "host:port" - runtimeEnvVal := "runtime" - - os.Clearenv() - setAll(predefinedInternalEnvVarKeys(), internalEnvVal) - setAll(predefinedPlatformEnvVarKeys(), platformEnvVal) - setAll(predefinedPlatformUnreservedEnvVarKeys(), platformUnreservedEnvVal) - setAll(predefinedCredentialsEnvVarKeys(), credsEnvVal) - customerEnv := map[string]string{"MY_FOOBAR_ENV_1": customerEnvVal} - - env := NewEnvironment() - env.StoreRuntimeAPIEnvironmentVariable(runtimeAPIAddress) - env.StoreEnvironmentVariablesFromInit(customerEnv, runtimeEnvVal, credsEnvVal, credsEnvVal, credsEnvVal, platformEnvVal, platformEnvVal) - - agentEnvVars := env.AgentExecEnv() - - var agentEnvKeys []string - for key := range agentEnvVars { - agentEnvKeys = append(agentEnvKeys, key) - } - - agentEnvVarsSlice := envToSlice(agentEnvVars) - - for key := range env.rapid { - assert.NotContains(t, agentEnvKeys, key) - } - - for key, val := range env.runtime { - assert.NotContains(t, agentEnvVarsSlice, key+"="+val) - } - - for key := range env.platform { - assert.Contains(t, agentEnvKeys, key) - } - - for key := range env.Customer { - assert.Contains(t, agentEnvKeys, key) - } - - for key, val := range env.credentials { - assert.Contains(t, agentEnvVarsSlice, key+"="+val) - } - - assert.Contains(t, agentEnvVarsSlice, runtimeAPIAddressKey+"="+env.platform[runtimeAPIAddressKey]) -} - -func TestStoreEnvironmentVariablesFromInitCaching(t *testing.T) { - host := "samplehost" - port := 1234 - handler := "samplehandler" - funcName := "samplefunctionname" - funcVer := "samplefunctionver" - token := "sampletoken" - env := NewEnvironment() - customerEnv := CustomerEnvironmentVariables() - - env.StoreEnvironmentVariablesFromInitForInitCaching("samplehost", 1234, customerEnv, handler, funcName, funcVer, token) - - assert.Equal(t, fmt.Sprintf("http://%s:%d/2021-04-23/credentials", host, port), env.credentials["AWS_CONTAINER_CREDENTIALS_FULL_URI"]) - assert.Equal(t, token, env.credentials["AWS_CONTAINER_AUTHORIZATION_TOKEN"]) - assert.Equal(t, funcName, env.platform["AWS_LAMBDA_FUNCTION_NAME"]) - assert.Equal(t, funcVer, env.platform["AWS_LAMBDA_FUNCTION_VERSION"]) - assert.Equal(t, handler, env.runtime["_HANDLER"]) -} - -func setAll(keys map[string]bool, value string) { - for key := range keys { - os.Setenv(key, value) - } -} diff --git a/lambda/rapidcore/env/rapidenv.go b/lambda/rapidcore/env/rapidenv.go deleted file mode 100644 index bc1a6ad..0000000 --- a/lambda/rapidcore/env/rapidenv.go +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package env - -import ( - "strconv" - "syscall" - - log "github.com/sirupsen/logrus" -) - -// RapidConfig holds config req'd for RAPID's internal -// operation, parsed from internal env vars. -// It should be build using `NewRapidConfig` to make sure that all the -// internal invariants are respected. -type RapidConfig struct { - SbID string - LogFd int - ShmFd int - CtrlFd int - CnslFd int - DirectInvokeFd int - LambdaTaskRoot string - XrayDaemonAddress string - PreLoadTimeNs int64 - FunctionName string - TelemetryAPIPassphrase string -} - -// Build the `RapidConfig` struct checking all the internal invariants -func NewRapidConfig(e *Environment) RapidConfig { - return RapidConfig{ - SbID: getStrEnvVarOrDie(e.rapid, "_LAMBDA_SB_ID"), - LogFd: getSocketEnvVarOrDie(e.rapid, "_LAMBDA_LOG_FD"), - ShmFd: getSocketEnvVarOrDie(e.rapid, "_LAMBDA_SHARED_MEM_FD"), - CtrlFd: getSocketEnvVarOrDie(e.rapid, "_LAMBDA_CONTROL_SOCKET"), - CnslFd: getSocketEnvVarOrDie(e.rapid, "_LAMBDA_CONSOLE_SOCKET"), - DirectInvokeFd: getOptionalSocketEnvVar(e.rapid, "_LAMBDA_DIRECT_INVOKE_SOCKET"), - PreLoadTimeNs: getInt64EnvVarOrDie(e.rapid, "_LAMBDA_RUNTIME_LOAD_TIME"), - LambdaTaskRoot: getStrEnvVarOrDie(e.runtime, "LAMBDA_TASK_ROOT"), - XrayDaemonAddress: getStrEnvVarOrDie(e.platformUnreserved, "AWS_XRAY_DAEMON_ADDRESS"), - FunctionName: getStrEnvVarOrDie(e.platform, "AWS_LAMBDA_FUNCTION_NAME"), - TelemetryAPIPassphrase: e.rapid["_LAMBDA_TELEMETRY_API_PASSPHRASE"], // TODO: Die if not set - } -} - -func getStrEnvVarOrDie(env map[string]string, name string) string { - val, ok := env[name] - if !ok { - log.WithField("name", name).Fatal("Environment variable is not set") - } - return val -} - -func getInt64EnvVarOrDie(env map[string]string, name string) int64 { - strval := getStrEnvVarOrDie(env, name) - val, err := strconv.ParseInt(strval, 10, 64) - if err != nil { - log.WithError(err).WithField("name", name).Fatal("Unable to parse int env var.") - } - return val -} - -func getIntEnvVarOrDie(env map[string]string, name string) int { - return int(getInt64EnvVarOrDie(env, name)) -} - -// getSocketEnvVarOrDie reads and returns an int value of the -// environment variable or dies, when unable to do so. -// It also makes CloseOnExec for this value. -func getSocketEnvVarOrDie(env map[string]string, name string) int { - sock := getIntEnvVarOrDie(env, name) - syscall.CloseOnExec(sock) - return sock -} - -// returns -1 if env variable was not set. Exits if it holds unexpected (non-int) value -func getOptionalSocketEnvVar(env map[string]string, name string) int { - val, found := env[name] - if !found { - return -1 - } - - sock, err := strconv.Atoi(val) - if err != nil { - log.WithError(err).WithField("name", name).Fatal("Unable to parse socket env var.") - } - - if sock < 0 { - log.WithError(err).WithField("name", name).Fatal("Negative socket descriptor value") - } - - syscall.CloseOnExec(sock) - return sock -} diff --git a/lambda/rapidcore/env/util.go b/lambda/rapidcore/env/util.go deleted file mode 100644 index 83a69fd..0000000 --- a/lambda/rapidcore/env/util.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package env - -import ( - "errors" - "strings" -) - -func SplitEnvironmentVariable(envKeyVal string) (string, string, error) { - splitKeyVal := strings.SplitN(envKeyVal, "=", 2) // values can contain '=' - if len(splitKeyVal) < 2 { - return "", "", errors.New("could not split env var by '=' delimiter") - } - return splitKeyVal[0], splitKeyVal[1], nil -} diff --git a/lambda/rapidcore/env/util_test.go b/lambda/rapidcore/env/util_test.go deleted file mode 100644 index 2b0b139..0000000 --- a/lambda/rapidcore/env/util_test.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package env - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestEnvironmentVariableSplitting(t *testing.T) { - envVar := "FOO=BAR" - k, v, err := SplitEnvironmentVariable(envVar) - assert.NoError(t, err) - assert.Equal(t, k, "FOO") - assert.Equal(t, v, "BAR") - - envVar = "FOO=BAR=BAZ" - k, v, err = SplitEnvironmentVariable(envVar) - assert.NoError(t, err) - assert.Equal(t, k, "FOO") - assert.Equal(t, v, "BAR=BAZ") - - envVar = "FOO=" - k, v, err = SplitEnvironmentVariable(envVar) - assert.NoError(t, err) - assert.Equal(t, k, "FOO") - assert.Equal(t, v, "") - - envVar = "FOO" - k, v, err = SplitEnvironmentVariable(envVar) - assert.Error(t, err) - assert.Equal(t, k, "") - assert.Equal(t, v, "") -} diff --git a/lambda/rapidcore/errors.go b/lambda/rapidcore/errors.go deleted file mode 100644 index 7f35ca8..0000000 --- a/lambda/rapidcore/errors.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapidcore - -import "errors" - -var ErrInitDoneFailed = errors.New("InitDoneFailed") -var ErrInitNotStarted = errors.New("InitNotStarted") -var ErrInitResetReceived = errors.New("InitResetReceived") - -var ErrNotReserved = errors.New("NotReserved") -var ErrAlreadyReserved = errors.New("AlreadyReserved") -var ErrAlreadyReplied = errors.New("AlreadyReplied") -var ErrAlreadyInvocating = errors.New("AlreadyInvocating") -var ErrReserveReservationDone = errors.New("ReserveReservationDone") - -var ErrInvokeResponseAlreadyWritten = errors.New("InvokeResponseAlreadyWritten") -var ErrInvokeDoneFailed = errors.New("InvokeDoneFailed") -var ErrInvokeReservationDone = errors.New("InvokeReservationDone") - -var ErrReleaseReservationDone = errors.New("ReleaseReservationDone") - -var ErrInternalServerError = errors.New("InternalServerError") -var ErrInvokeTimeout = errors.New("InvokeTimeout") diff --git a/lambda/rapidcore/runtime_release.go b/lambda/rapidcore/runtime_release.go deleted file mode 100644 index 3875209..0000000 --- a/lambda/rapidcore/runtime_release.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapidcore - -import ( - "bufio" - "fmt" - "os" - "strings" -) - -type Logging string - -const ( - AmznStdout Logging = "amzn-stdout" - AmznStdoutTLV Logging = "amzn-stdout-tlv" -) - -// RuntimeRelease stores runtime identification data -type RuntimeRelease struct { - Name string - Version string - Logging Logging -} - -const RuntimeReleasePath = "/var/runtime/runtime-release" - -// GetRuntimeRelease reads Runtime identification data from config file and parses it into a struct -func GetRuntimeRelease(path string) (*RuntimeRelease, error) { - pairs, err := ParsePropertiesFile(path) - if err != nil { - return nil, fmt.Errorf("could not parse %s: %w", path, err) - } - - return &RuntimeRelease{pairs["NAME"], pairs["VERSION"], Logging(pairs["LOGGING"])}, nil -} - -// ParsePropertiesFile reads key-value pairs from file in newline-separated list of environment-like -// shell-compatible variable assignments. -// Format: https://www.freedesktop.org/software/systemd/man/os-release.html -// Value quotes are trimmed. Latest write wins for duplicated keys. -func ParsePropertiesFile(path string) (map[string]string, error) { - f, err := os.Open(path) - if err != nil { - return nil, fmt.Errorf("could not open %s: %w", path, err) - } - defer f.Close() - - pairs := make(map[string]string) - - s := bufio.NewScanner(f) - for s.Scan() { - if s.Text() == "" || strings.HasPrefix(s.Text(), "#") { - continue - } - k, v, found := strings.Cut(s.Text(), "=") - if !found { - return nil, fmt.Errorf("could not parse key-value pair from a line: %s", s.Text()) - } - pairs[k] = strings.Trim(v, "'\"") - } - if err := s.Err(); err != nil { - return nil, fmt.Errorf("failed to read properties file: %w", err) - } - - return pairs, nil -} diff --git a/lambda/rapidcore/runtime_release_test.go b/lambda/rapidcore/runtime_release_test.go deleted file mode 100644 index 7397140..0000000 --- a/lambda/rapidcore/runtime_release_test.go +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapidcore - -import ( - "os" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestGetRuntimeRelease(t *testing.T) { - tests := []struct { - name string - content string - want *RuntimeRelease - }{ - { - "simple", - "NAME=foo\nVERSION=bar\nLOGGING=baz\n", - &RuntimeRelease{"foo", "bar", "baz"}, - }, - { - "no trailing new line", - "NAME=foo\nVERSION=bar\nLOGGING=baz", - &RuntimeRelease{"foo", "bar", "baz"}, - }, - { - "nonexistent keys", - "LOGGING=baz\n", - &RuntimeRelease{"", "", "baz"}, - }, - { - "empty value", - "NAME=\nVERSION=\nLOGGING=\n", - &RuntimeRelease{"", "", ""}, - }, - { - "delimiter in value", - "NAME=Foo=Bar\nVERSION=bar\nLOGGING=baz\n", - &RuntimeRelease{"Foo=Bar", "bar", "baz"}, - }, - { - "empty file", - "", - &RuntimeRelease{"", "", ""}, - }, - { - "quotes", - "NAME=\"foo\"\nVERSION='bar'\n", - &RuntimeRelease{"foo", "bar", ""}, - }, - { - "double quotes", - "NAME='\"foo\"'\nVERSION=\"'bar'\"\n", - &RuntimeRelease{"foo", "bar", ""}, - }, - { - "empty lines", // production runtime-release files have empty line in the end of the file - "\nNAME=foo\n\nVERSION=bar\n\nLOGGING=baz\n\n", - &RuntimeRelease{"foo", "bar", "baz"}, - }, - { - "comments", - "# comment 1\nNAME=foo\n# comment 2\nVERSION=bar\n# comment 3\nLOGGING=baz\n# comment 4\n", - &RuntimeRelease{"foo", "bar", "baz"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - f, err := os.CreateTemp(os.TempDir(), "runtime-release") - require.NoError(t, err) - _, err = f.WriteString(tt.content) - require.NoError(t, err) - got, err := GetRuntimeRelease(f.Name()) - assert.NoError(t, err) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestGetRuntimeRelease_NotFound(t *testing.T) { - _, err := GetRuntimeRelease("/sys/not-exists") - assert.Error(t, err) -} - -func TestGetRuntimeRelease_InvalidLine(t *testing.T) { - f, err := os.CreateTemp(os.TempDir(), "runtime-release") - require.NoError(t, err) - _, err = f.WriteString("NAME=foo\nVERSION=bar\nLOGGING=baz\nSOMETHING") - require.NoError(t, err) - _, err = GetRuntimeRelease(f.Name()) - assert.Error(t, err) -} diff --git a/lambda/rapidcore/sandbox_api.go b/lambda/rapidcore/sandbox_api.go deleted file mode 100644 index 2e8d713..0000000 --- a/lambda/rapidcore/sandbox_api.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapidcore - -import ( - "bytes" - - "go.amzn.com/lambda/extensions" - "go.amzn.com/lambda/interop" -) - -// SandboxContext and other structs form the implementation of the SandboxAPI -// interface defined in interop/sandbox_model.go, using the implementation of -// Init, Invoke and Reset handlers in rapid/sandbox.go -type SandboxContext struct { - rapidCtx interop.RapidContext - handler string - runtimeAPIAddress string -} - -// initContext and its methods model the initialization lifecycle -// of the Sandbox, which persist across invocations -type initContext struct { - initSuccessChan chan interop.InitSuccess - initFailureChan chan interop.InitFailure - rapidCtx interop.RapidContext - sbInfoFromInit interop.SandboxInfoFromInit // contains data that needs to be persisted from init for suppressed inits during invoke - invokeRequestBuffer *bytes.Buffer // byte buffer used to store the invoke request rendered to runtime (reused until reset) -} - -// invokeContext and its methods model the invocation lifecycle -type invokeContext struct { - rapidCtx interop.RapidContext - invokeRequestChan chan *interop.Invoke - invokeSuccessChan chan interop.InvokeSuccess - invokeFailureChan chan interop.InvokeFailure - sbInfoFromInit interop.SandboxInfoFromInit // contains data that needs to be persisted from init for suppressed inits during invoke - invokeRequestBuffer *bytes.Buffer // byte buffer used to store the invoke request rendered to runtime (reused until reset) -} - -// Validate interface compliance -var _ interop.SandboxContext = (*SandboxContext)(nil) -var _ interop.InitContext = (*initContext)(nil) -var _ interop.InvokeContext = (*invokeContext)(nil) - -// Init starts the runtime domain initialization in a separate goroutine. -// Return value indicates that init request has been accepted and started. -func (s SandboxContext) Init(init *interop.Init, timeoutMs int64) interop.InitContext { - initSuccessResponseChan := make(chan interop.InitSuccess) - initFailureResponseChan := make(chan interop.InitFailure) - - if len(s.handler) > 0 { - init.EnvironmentVariables.SetHandler(s.handler) - } - - init.EnvironmentVariables.StoreRuntimeAPIEnvironmentVariable(s.runtimeAPIAddress) - extensions.DisableViaMagicLayer() - - // We start initialization handling in a separate goroutine so that control can be returned back to - // caller, which can do work (e.g. notifying further upstream that initialization has started), and - // and call initCtx.Wait() to wait async for completion of initialization phase. - go s.rapidCtx.HandleInit(init, initSuccessResponseChan, initFailureResponseChan) - - sbMetadata := interop.SandboxInfoFromInit{ - EnvironmentVariables: init.EnvironmentVariables, - SandboxType: init.SandboxType, - RuntimeBootstrap: init.Bootstrap, - } - return newInitContext(s.rapidCtx, sbMetadata, initSuccessResponseChan, initFailureResponseChan) -} - -// Reset triggers a reset. In case of timeouts, the reset handler cancels all flows which triggers -// ongoing invoke handlers to return before proceeding with invoke -// TODO: move this method to the initialization context, since reset is conceptually on RT domain -func (s SandboxContext) Reset(reset *interop.Reset) (interop.ResetSuccess, *interop.ResetFailure) { - defer s.rapidCtx.Clear() - return s.rapidCtx.HandleReset(reset) -} - -// Reset triggers a shutdown. This is similar to a reset, except that this is a terminal state -// and no further invokes are allowed -func (s SandboxContext) Shutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { - return s.rapidCtx.HandleShutdown(shutdown) -} - -func (s SandboxContext) Restore(restore *interop.Restore) (interop.RestoreResult, error) { - return s.rapidCtx.HandleRestore(restore) -} - -func (s *SandboxContext) SetRuntimeStartedTime(runtimeStartedTime int64) { - s.rapidCtx.SetRuntimeStartedTime(runtimeStartedTime) -} - -func (s *SandboxContext) SetInvokeResponseMetrics(metrics *interop.InvokeResponseMetrics) { - s.rapidCtx.SetInvokeResponseMetrics(metrics) -} - -func newInitContext(r interop.RapidContext, sbMetadata interop.SandboxInfoFromInit, - initSuccessChan chan interop.InitSuccess, initFailureChan chan interop.InitFailure) initContext { - - // Invocation request buffer is initialized once per initialization - // to reduce memory usage & GC CPU time across invocations - var requestBuffer bytes.Buffer - - return initContext{ - initSuccessChan: initSuccessChan, - initFailureChan: initFailureChan, - rapidCtx: r, - sbInfoFromInit: sbMetadata, - invokeRequestBuffer: &requestBuffer, - } -} - -// Wait awaits until initialization phase is complete, i.e. one of: -// - until all runtime domain process call /next -// - any one of the runtime domain processes exit (init failure) -// Timeout handling is managed upstream entirely -func (i initContext) Wait() (interop.InitSuccess, *interop.InitFailure) { - select { - case initSuccess, isOpen := <-i.initSuccessChan: - if !isOpen { - // If init has already suceeded, we return quickly - return interop.InitSuccess{}, nil - } - return initSuccess, nil - case initFailure, isOpen := <-i.initFailureChan: - if !isOpen { - // If init has already failed, we return quickly for init to be suppressed - return interop.InitSuccess{}, &initFailure - } - return interop.InitSuccess{}, &initFailure - } -} - -// Reserve is used to initialize invoke-related state -func (i initContext) Reserve() interop.InvokeContext { - invokeRequestChan := make(chan *interop.Invoke) - invokeSuccessChan := make(chan interop.InvokeSuccess) - invokeFailureChan := make(chan interop.InvokeFailure) - - return invokeContext{ - rapidCtx: i.rapidCtx, - invokeRequestChan: invokeRequestChan, - invokeSuccessChan: invokeSuccessChan, - invokeFailureChan: invokeFailureChan, - sbInfoFromInit: i.sbInfoFromInit, - invokeRequestBuffer: i.invokeRequestBuffer, - } -} - -// SendRequest starts the invocation request handling in a separate goroutine, -// i.e. sending the request payload via /next response, -// and waiting for the synchronization points -func (invCtx invokeContext) SendRequest(invoke *interop.Invoke, responseSender interop.InvokeResponseSender) { - // Invoke handling needs to be in a separate goroutine so that control can - // be returned immediately to calling goroutine, which can do work and - // asynchronously call invCtx.Wait() to await completion of the invoke phase - go func() { - // For suppressed inits, invoke needs the runtime and agent env vars - invokeSuccess, invokeFailure := invCtx.rapidCtx.HandleInvoke(invoke, invCtx.sbInfoFromInit, invCtx.invokeRequestBuffer, responseSender) - if invokeFailure != nil { - invCtx.invokeFailureChan <- *invokeFailure - } else { - invCtx.invokeSuccessChan <- invokeSuccess - } - }() -} - -// Wait awaits invoke completion, i.e. one of the following cases: -// - until all runtime domain process call /next -// - until a process exit (that notifies upstream to trigger a reset due to "failure") -// - until a timeout (triggered by a reset from upstream due to "timeout") -func (invCtx invokeContext) Wait() (interop.InvokeSuccess, *interop.InvokeFailure) { - select { - case invokeSuccess := <-invCtx.invokeSuccessChan: - return invokeSuccess, nil - case invokeFailure := <-invCtx.invokeFailureChan: - return interop.InvokeSuccess{}, &invokeFailure - } -} diff --git a/lambda/rapidcore/sandbox_builder.go b/lambda/rapidcore/sandbox_builder.go deleted file mode 100644 index f51acda..0000000 --- a/lambda/rapidcore/sandbox_builder.go +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapidcore - -import ( - "context" - "io" - "net" - "os" - "os/signal" - "strconv" - "syscall" - - "go.amzn.com/lambda/extensions" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/logging" - "go.amzn.com/lambda/rapid" - "go.amzn.com/lambda/supervisor" - supvmodel "go.amzn.com/lambda/supervisor/model" - "go.amzn.com/lambda/telemetry" - - log "github.com/sirupsen/logrus" -) - -const ( - defaultSigtermResetTimeoutMs = int64(2000) -) - -type SandboxBuilder struct { - sandbox *rapid.Sandbox - sandboxContext interop.SandboxContext - lambdaInvokeAPI LambdaInvokeAPI - defaultInteropServer *Server - useCustomInteropServer bool - shutdownFuncs []func() - handler string -} - -type logSink int - -const ( - RuntimeLogSink logSink = iota - ExtensionLogSink -) - -func NewSandboxBuilder() *SandboxBuilder { - defaultInteropServer := NewServer() - - localSv := supervisor.NewLocalSupervisor() - b := &SandboxBuilder{ - sandbox: &rapid.Sandbox{ - StandaloneMode: true, - LogsEgressAPI: &telemetry.NoOpLogsEgressAPI{}, - EnableTelemetryAPI: false, - Tracer: telemetry.NewNoOpTracer(), - EventsAPI: &telemetry.NoOpEventsAPI{}, - InitCachingEnabled: false, - Supervisor: localSv, - RuntimeFsRootPath: localSv.RootPath, - RuntimeAPIHost: "127.0.0.1", - RuntimeAPIPort: 9001, - }, - defaultInteropServer: defaultInteropServer, - shutdownFuncs: []func(){}, - lambdaInvokeAPI: NewEmulatorAPI(defaultInteropServer), - } - - b.AddShutdownFunc(func() { - log.Info("Shutting down...") - defaultInteropServer.Reset("SandboxTerminated", defaultSigtermResetTimeoutMs) - }) - - return b -} - -func (b *SandboxBuilder) SetSupervisor(supervisor supvmodel.ProcessSupervisor) *SandboxBuilder { - b.sandbox.Supervisor = supervisor - return b -} - -func (b *SandboxBuilder) SetRuntimeFsRootPath(rootPath string) *SandboxBuilder { - b.sandbox.RuntimeFsRootPath = rootPath - return b -} - -func (b *SandboxBuilder) SetRuntimeAPIAddress(runtimeAPIAddress string) *SandboxBuilder { - host, port, err := net.SplitHostPort(runtimeAPIAddress) - if err != nil { - log.WithError(err).Warnf("Failed to parse RuntimeApiAddress: %s:", runtimeAPIAddress) - return b - } - - portInt, err := strconv.Atoi(port) - if err != nil { - log.WithError(err).Warnf("Failed to parse RuntimeApiPort: %s:", port) - return b - } - - b.sandbox.RuntimeAPIHost = host - b.sandbox.RuntimeAPIPort = portInt - return b -} - -func (b *SandboxBuilder) SetInteropServer(interopServer interop.Server) *SandboxBuilder { - b.sandbox.InteropServer = interopServer - b.useCustomInteropServer = true - return b -} - -func (b *SandboxBuilder) SetEventsAPI(eventsAPI interop.EventsAPI) *SandboxBuilder { - b.sandbox.EventsAPI = eventsAPI - return b -} - -func (b *SandboxBuilder) SetTracer(tracer telemetry.Tracer) *SandboxBuilder { - b.sandbox.Tracer = tracer - return b -} - -func (b *SandboxBuilder) DisableStandaloneMode() *SandboxBuilder { - b.sandbox.StandaloneMode = false - return b -} - -func (b *SandboxBuilder) SetExtensionsFlag(extensionsEnabled bool) *SandboxBuilder { - if extensionsEnabled { - extensions.Enable() - } else { - extensions.Disable() - } - return b -} - -func (b *SandboxBuilder) SetInitCachingFlag(initCachingEnabled bool) *SandboxBuilder { - b.sandbox.InitCachingEnabled = initCachingEnabled - return b -} - -func (b *SandboxBuilder) SetTelemetrySubscription(logsSubscriptionAPI telemetry.SubscriptionAPI, telemetrySubscriptionAPI telemetry.SubscriptionAPI) *SandboxBuilder { - b.sandbox.EnableTelemetryAPI = true - b.sandbox.LogsSubscriptionAPI = logsSubscriptionAPI - b.sandbox.TelemetrySubscriptionAPI = telemetrySubscriptionAPI - return b -} - -func (b *SandboxBuilder) SetLogsEgressAPI(logsEgressAPI telemetry.StdLogsEgressAPI) *SandboxBuilder { - b.sandbox.LogsEgressAPI = logsEgressAPI - return b -} - -func (b *SandboxBuilder) SetHandler(handler string) *SandboxBuilder { - b.handler = handler - return b -} - -func (b *SandboxBuilder) AddShutdownFunc(shutdownFunc func()) *SandboxBuilder { - b.shutdownFuncs = append(b.shutdownFuncs, shutdownFunc) - return b -} - -func (b *SandboxBuilder) Create() (interop.SandboxContext, interop.InternalStateGetter) { - if !b.useCustomInteropServer { - b.sandbox.InteropServer = b.defaultInteropServer - } - - ctx, cancel := context.WithCancel(context.Background()) - - // cancel is called when handling termination signals as a cancellation - // signal to the Runtime API sever to terminate gracefully - go signalHandler(cancel, b.shutdownFuncs) - - // rapid.Start, among other things, starts the Runtime API server and - // terminates it gracefully if the cxt is canceled - rapidCtx, internalStateFn, runtimeAPIAddr := rapid.Start(ctx, b.sandbox) - - b.sandboxContext = &SandboxContext{ - rapidCtx: rapidCtx, - handler: b.handler, - runtimeAPIAddress: runtimeAPIAddr, - } - - return b.sandboxContext, internalStateFn -} - -func (b *SandboxBuilder) DefaultInteropServer() *Server { - return b.defaultInteropServer -} - -func (b *SandboxBuilder) LambdaInvokeAPI() LambdaInvokeAPI { - return b.lambdaInvokeAPI -} - -// SetLogLevel sets the log level for internal logging. Needs to be called very -// early during startup to configure logs emitted during initialization -func SetLogLevel(logLevel string) { - level, err := log.ParseLevel(logLevel) - if err != nil { - log.WithError(err).Fatal("Failed to set log level. Valid log levels are:", log.AllLevels) - } - - log.SetLevel(level) - log.SetFormatter(&logging.InternalFormatter{}) -} - -func SetInternalLogOutput(w io.Writer) { - logging.SetOutput(w) -} - -// Trap SIGINT and SIGTERM signals, call shutdown function, and cancel the -// ctx to terminate gracefully the Runtime API server -func signalHandler(cancel context.CancelFunc, shutdownFuncs []func()) { - defer cancel() - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) - sigReceived := <-sig - log.WithField("signal", sigReceived.String()).Info("Received signal") - for _, shutdownFunc := range shutdownFuncs { - shutdownFunc() - } -} diff --git a/lambda/rapidcore/sandbox_emulator_api.go b/lambda/rapidcore/sandbox_emulator_api.go deleted file mode 100644 index 4cc2183..0000000 --- a/lambda/rapidcore/sandbox_emulator_api.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapidcore - -import ( - "go.amzn.com/lambda/interop" - - "net/http" -) - -// LambdaInvokeAPI are the methods used by the Runtime Interface Emulator -type LambdaInvokeAPI interface { - Init(i *interop.Init, invokeTimeoutMs int64) - Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error -} - -// EmulatorAPI wraps the standalone interop server to provide a convenient interface -// for Rapid Standalone -type EmulatorAPI struct { - server *Server -} - -// Validate interface compliance -var _ LambdaInvokeAPI = (*EmulatorAPI)(nil) - -func NewEmulatorAPI(s *Server) *EmulatorAPI { - return &EmulatorAPI{s} -} - -// Init method is only used by the Runtime interface emulator -func (l *EmulatorAPI) Init(i *interop.Init, timeoutMs int64) { - l.server.Init(&interop.Init{ - AccountID: i.AccountID, - Handler: i.Handler, - AwsKey: i.AwsKey, - AwsSecret: i.AwsSecret, - AwsSession: i.AwsSession, - XRayDaemonAddress: i.XRayDaemonAddress, - FunctionName: i.FunctionName, - FunctionVersion: i.FunctionVersion, - CustomerEnvironmentVariables: i.CustomerEnvironmentVariables, - RuntimeInfo: i.RuntimeInfo, - SandboxType: i.SandboxType, - Bootstrap: i.Bootstrap, - EnvironmentVariables: i.EnvironmentVariables, - }, timeoutMs) -} - -// Invoke method is only used by the Runtime interface emulator -func (l *EmulatorAPI) Invoke(w http.ResponseWriter, i *interop.Invoke) error { - return l.server.Invoke(w, i) -} diff --git a/lambda/rapidcore/server.go b/lambda/rapidcore/server.go deleted file mode 100644 index d0f7c7a..0000000 --- a/lambda/rapidcore/server.go +++ /dev/null @@ -1,894 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 -// LOCALSTACK CHANGES 2023-10-17: pass request metadata into .Reserve(invoke.ID, invoke.TraceID, invoke.LambdaSegmentID) - -package rapidcore - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "math" - "net/http" - "sync" - "time" - - "go.amzn.com/lambda/core/directinvoke" - "go.amzn.com/lambda/core/statejson" - "go.amzn.com/lambda/fatalerror" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" - - "github.com/google/uuid" - log "github.com/sirupsen/logrus" -) - -const ( - autoresetReasonTimeout = "Timeout" - autoresetReasonReserveFail = "ReserveFail" - autoresetReasonReleaseFail = "ReleaseFail" - standaloneVersionID = "1" - - resetDefaultTimeoutMs = 2000 -) - -type rapidPhase int - -const ( - phaseIdle rapidPhase = iota - phaseInitializing - phaseInvoking -) - -type runtimeState int - -const ( - runtimeNotStarted = iota - - runtimeInitError - runtimeInitComplete - runtimeInitFailed - - runtimeInvokeResponseSent - runtimeInvokeError - runtimeReady - runtimeInvokeComplete -) - -type DoneWithState struct { - *interop.Done - State statejson.InternalStateDescription -} - -func (s *DoneWithState) String() string { - return fmt.Sprintf("%v %v", *s.Done, string(s.State.AsJSON())) -} - -type InvokeContext struct { - Token interop.Token - ReplySent bool - ReplyStream http.ResponseWriter - Direct bool -} - -type Server struct { - InternalStateGetter interop.InternalStateGetter - - initChanOut chan *interop.Init - interruptedResponseChan chan *interop.Reset - - sendResponseChan chan *interop.InvokeResponseMetrics - doneChan chan *interop.Done - - InitDoneChan chan DoneWithState - InvokeDoneChan chan DoneWithState - ResetDoneChan chan *interop.Done - ShutdownDoneChan chan *interop.Done - - mutex sync.Mutex - invokeCtx *InvokeContext - invokeTimeout time.Duration - - reservationContext context.Context - reservationCancel func() - - rapidPhase rapidPhase - runtimeState runtimeState - - sandboxContext interop.SandboxContext - initContext interop.InitContext - invoker interop.InvokeContext - initFailures chan interop.InitFailure - cachedInitErrorResponse *interop.ErrorInvokeResponse -} - -// Validate interface compliance -var _ interop.Server = (*Server)(nil) - -func (s *Server) setRapidPhase(phase rapidPhase) { - s.mutex.Lock() - defer s.mutex.Unlock() - - s.rapidPhase = phase -} - -func (s *Server) getRapidPhase() rapidPhase { - s.mutex.Lock() - defer s.mutex.Unlock() - - return s.rapidPhase -} - -func (s *Server) setRuntimeState(state runtimeState) { - s.mutex.Lock() - defer s.mutex.Unlock() - - s.runtimeState = state -} - -func (s *Server) getRuntimeState() runtimeState { - s.mutex.Lock() - defer s.mutex.Unlock() - - return s.runtimeState -} - -func (s *Server) SetInvokeTimeout(timeout time.Duration) { - s.mutex.Lock() - defer s.mutex.Unlock() - - s.invokeTimeout = timeout -} - -func (s *Server) GetInvokeTimeout() time.Duration { - s.mutex.Lock() - defer s.mutex.Unlock() - - return s.invokeTimeout -} - -func (s *Server) GetInvokeContext() *InvokeContext { - s.mutex.Lock() - defer s.mutex.Unlock() - - ctx := *s.invokeCtx - return &ctx -} - -func (s *Server) setNewInvokeContext(invokeID string, traceID, lambdaSegmentID string) (*ReserveResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.invokeCtx != nil { - return nil, ErrAlreadyReserved - } - - s.invokeCtx = &InvokeContext{ - Token: interop.Token{ - ReservationToken: uuid.New().String(), - InvokeID: invokeID, - VersionID: standaloneVersionID, - FunctionTimeout: s.invokeTimeout, - TraceID: traceID, - LambdaSegmentID: lambdaSegmentID, - InvackDeadlineNs: math.MaxInt64, // no INVACK in standalone - }, - } - - resp := &ReserveResponse{ - Token: s.invokeCtx.Token, - } - - s.reservationContext, s.reservationCancel = context.WithCancel(context.Background()) - - return resp, nil -} - -type ReserveResponse struct { - Token interop.Token - InternalState *statejson.InternalStateDescription -} - -// Reserve allocates invoke context -func (s *Server) Reserve(id string, traceID, lambdaSegmentID string) (*ReserveResponse, error) { - invokeID := uuid.New().String() - if len(id) > 0 { - invokeID = id - } - resp, err := s.setNewInvokeContext(invokeID, traceID, lambdaSegmentID) - if err != nil { - return nil, err - } - - // The two errors reserve returns in standalone mode are INIT timeout - // and INIT failure (two types of failure: runtime exit, /init/error). Both require suppressed - // initialization, so we succeed the reservation. - invCtx := s.initContext.Reserve() - s.invoker = invCtx - resp.InternalState, err = s.InternalState() - - return resp, err -} - -func (s *Server) awaitInitCompletion() { - initSuccess, initFailure := s.initContext.Wait() - if initFailure != nil { - // In standalone, we don't have to block rapid start() goroutine until init failure is consumed - // because there is no channel back to the invoker until an invoke arrives via a Reserve() - initFailure.Ack <- struct{}{} - s.initFailures <- *initFailure - } else { - initSuccess.Ack <- struct{}{} - } - // always closing the channel makes this method idempotent - close(s.initFailures) -} - -func (s *Server) setReplyStream(w http.ResponseWriter, direct bool) (string, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.invokeCtx == nil { - return "", ErrNotReserved - } - - if s.invokeCtx.ReplySent { - return "", ErrAlreadyReplied - } - - if s.invokeCtx.ReplyStream != nil { - return "", ErrAlreadyInvocating - } - - s.invokeCtx.ReplyStream = w - s.invokeCtx.Direct = direct - return s.invokeCtx.Token.InvokeID, nil -} - -// Release closes the invocation, making server ready for reserve again -func (s *Server) Release() error { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.invokeCtx == nil { - return ErrNotReserved - } - - if s.reservationCancel != nil { - s.reservationCancel() - } - - s.sandboxContext.SetRuntimeStartedTime(-1) - s.sandboxContext.SetInvokeResponseMetrics(nil) - s.invokeCtx = nil - return nil -} - -// GetCurrentInvokeID -func (s *Server) GetCurrentInvokeID() string { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.invokeCtx == nil { - return "" - } - - return s.invokeCtx.Token.InvokeID -} - -// SetSandboxContext is used to set the sandbox context after intiialization of interop server. -// After refactoring all messages, this needs to be removed and made an struct parameter on initialization. -func (s *Server) SetSandboxContext(sbCtx interop.SandboxContext) { - s.sandboxContext = sbCtx -} - -// SetInternalStateGetter is used to set callback which returnes internal state for /test/internalState request -func (s *Server) SetInternalStateGetter(cb interop.InternalStateGetter) { - s.InternalStateGetter = cb -} - -func (s *Server) sendResponseUnsafe(invokeID string, additionalHeaders map[string]string, payload io.Reader, trailers http.Header, request *interop.CancellableRequest, runtimeCalledResponse bool) error { - if s.invokeCtx == nil || invokeID != s.invokeCtx.Token.InvokeID { - return interop.ErrInvalidInvokeID - } - - if s.invokeCtx.ReplySent { - return interop.ErrResponseSent - } - - if s.invokeCtx.ReplyStream == nil { - return fmt.Errorf("ReplyStream not available") - } - - var reportedErr error - if s.invokeCtx.Direct { - if err := directinvoke.SendDirectInvokeResponse(additionalHeaders, payload, trailers, s.invokeCtx.ReplyStream, s.interruptedResponseChan, s.sendResponseChan, request, runtimeCalledResponse, invokeID); err != nil { - // TODO: Do we need to drain the reader in case of a large payload and connection reuse? - log.Errorf("Failed to write response to %s: %s", invokeID, err) - reportedErr = err - } - } else { - data, err := io.ReadAll(payload) - if err != nil { - return fmt.Errorf("Failed to read response on %s: %s", invokeID, err) - } - if len(data) > interop.MaxPayloadSize { - return &interop.ErrorResponseTooLarge{ - ResponseSize: len(data), - MaxResponseSize: interop.MaxPayloadSize, - } - } - - startReadingResponseMonoTimeMs := metering.Monotime() - s.invokeCtx.ReplyStream.Header().Add(directinvoke.ContentTypeHeader, additionalHeaders[directinvoke.ContentTypeHeader]) - written, err := s.invokeCtx.ReplyStream.Write(data) - if err != nil { - return fmt.Errorf("Failed to write response to %s: %s", invokeID, err) - } - - s.sendResponseChan <- &interop.InvokeResponseMetrics{ - ProducedBytes: int64(written), - StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, - FinishReadingResponseMonoTimeMs: metering.Monotime(), - TimeShapedNs: int64(-1), - OutboundThroughputBps: int64(-1), - // FIXME: - // The runtime tells whether the function response mode is streaming or not. - // Ideally, we would want to use that value here. Since I'm just rebasing, I will leave - // as-is, but we should use that instead of relying on our memory to set this here - // because we "know" it's a streaming code path. - FunctionResponseMode: interop.FunctionResponseModeBuffered, - RuntimeCalledResponse: runtimeCalledResponse, - } - } - - s.invokeCtx.ReplySent = true - s.invokeCtx.Direct = false - return reportedErr -} - -func (s *Server) SendResponse(invokeID string, resp *interop.StreamableInvokeResponse) error { - s.setRuntimeState(runtimeInvokeResponseSent) - s.mutex.Lock() - defer s.mutex.Unlock() - runtimeCalledResponse := true - return s.sendResponseUnsafe(invokeID, resp.Headers, resp.Payload, resp.Trailers, resp.Request, runtimeCalledResponse) -} - -func (s *Server) SendInitErrorResponse(resp *interop.ErrorInvokeResponse) error { - log.Debugf("Sending Init Error Response: %s", resp.FunctionError.Type) - if s.getRapidPhase() == phaseInvoking { - // This branch occurs during suppressed init - return s.SendErrorResponse(s.GetCurrentInvokeID(), resp) - } - - // Handle an /init/error outside of the invoke phase - s.setCachedInitErrorResponse(resp) - s.setRuntimeState(runtimeInitError) - return nil -} - -func (s *Server) SendErrorResponse(invokeID string, resp *interop.ErrorInvokeResponse) error { - log.Debugf("Sending Error Response: %s", resp.FunctionError.Type) - s.setRuntimeState(runtimeInvokeError) - s.mutex.Lock() - defer s.mutex.Unlock() - additionalHeaders := map[string]string{ - directinvoke.ContentTypeHeader: resp.Headers.ContentType, - directinvoke.ErrorTypeHeader: string(resp.FunctionError.Type), - } - if functionResponseMode := resp.Headers.FunctionResponseMode; functionResponseMode != "" { - additionalHeaders[directinvoke.FunctionResponseModeHeader] = functionResponseMode - } - runtimeCalledResponse := false // we are sending an error here, so runtime called /error or crashed/timeout - return s.sendResponseUnsafe(invokeID, additionalHeaders, bytes.NewReader(resp.Payload), nil, nil, runtimeCalledResponse) -} - -func (s *Server) Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) { - // pass reset to rapid - reset := &interop.Reset{ - Reason: reason, - DeadlineNs: deadlineNsFromTimeoutMs(timeoutMs), - } - go func() { - select { - case s.interruptedResponseChan <- reset: - <-s.interruptedResponseChan // wait for response streaming metrics being added to reset struct - s.sandboxContext.SetInvokeResponseMetrics(reset.InvokeResponseMetrics) - default: - } - - resetSuccess, resetFailure := s.sandboxContext.Reset(reset) - s.Clear() // clear server state to prepare for new invokes - s.setRapidPhase(phaseIdle) - s.setRuntimeState(runtimeNotStarted) - - var meta interop.DoneMetadata - if reset.InvokeResponseMetrics != nil && interop.IsResponseStreamingMetrics(reset.InvokeResponseMetrics) { - meta.RuntimeTimeThrottledMs = reset.InvokeResponseMetrics.TimeShapedNs / int64(time.Millisecond) - meta.RuntimeProducedBytes = reset.InvokeResponseMetrics.ProducedBytes - meta.RuntimeOutboundThroughputBps = reset.InvokeResponseMetrics.OutboundThroughputBps - meta.MetricsDimensions = interop.DoneMetadataMetricsDimensions{ - InvokeResponseMode: reset.InvokeResponseMode, - } - - // These metrics aren't present in reset struct, therefore we need to get - // them from s.sandboxContext.Reset() response - if resetFailure != nil { - meta.RuntimeResponseLatencyMs = resetFailure.ResponseMetrics.RuntimeResponseLatencyMs - } else { - meta.RuntimeResponseLatencyMs = resetSuccess.ResponseMetrics.RuntimeResponseLatencyMs - } - } - - if resetFailure != nil { - meta.ExtensionsResetMs = resetFailure.ExtensionsResetMs - s.ResetDoneChan <- &interop.Done{ErrorType: resetFailure.ErrorType, Meta: meta} - } else { - meta.ExtensionsResetMs = resetSuccess.ExtensionsResetMs - s.ResetDoneChan <- &interop.Done{ErrorType: resetSuccess.ErrorType, Meta: meta} - } - }() - - done := <-s.ResetDoneChan - s.Release() - - if done.ErrorType != "" { - return nil, errors.New(string(done.ErrorType)) - } - - return &statejson.ResetDescription{ - ExtensionsResetMs: done.Meta.ExtensionsResetMs, - ResponseMetrics: statejson.ResponseMetrics{ - RuntimeResponseLatencyMs: done.Meta.RuntimeResponseLatencyMs, - Dimensions: statejson.ResponseMetricsDimensions{ - InvokeResponseMode: statejson.InvokeResponseMode( - done.Meta.MetricsDimensions.InvokeResponseMode, - ), - }, - }, - }, nil -} - -func NewServer() *Server { - s := &Server{ - initChanOut: make(chan *interop.Init), - interruptedResponseChan: make(chan *interop.Reset), - - sendResponseChan: make(chan *interop.InvokeResponseMetrics), - doneChan: make(chan *interop.Done), - - // These two channels are buffered, because they are depleted asynchronously (by reserve and waitUntilRelease) and we don't want to block in SendDone until they are called - InitDoneChan: make(chan DoneWithState, 1), - InvokeDoneChan: make(chan DoneWithState, 1), - - ResetDoneChan: make(chan *interop.Done), - ShutdownDoneChan: make(chan *interop.Done), - } - - return s -} - -func drainChannel(c chan DoneWithState) { - for { - select { - case dws := <-c: - log.Warnf("Discard DONE response: %s", dws.String()) - break - default: - return - } - } -} - -func (s *Server) Clear() { - // we do not drain InitDoneChannel, because Init is only done once during rapid lifetime - - drainChannel(s.InvokeDoneChan) - s.Release() -} - -func (s *Server) SendRuntimeReady() error { - // only called when extensions are enabled - s.setRuntimeState(runtimeReady) - return nil -} - -func deadlineNsFromTimeoutMs(timeoutMs int64) int64 { - mono := metering.Monotime() - return mono + timeoutMs*1000*1000 -} - -func (s *Server) setInitFailuresChan() { - s.mutex.Lock() - defer s.mutex.Unlock() - s.initFailures = make(chan interop.InitFailure) -} - -func (s *Server) getInitFailuresChan() chan interop.InitFailure { - s.mutex.Lock() - defer s.mutex.Unlock() - return s.initFailures -} - -func (s *Server) Init(i *interop.Init, invokeTimeoutMs int64) error { - s.SetInvokeTimeout(time.Duration(invokeTimeoutMs) * time.Millisecond) - s.setRapidPhase(phaseInitializing) - s.setInitFailuresChan() - initCtx := s.sandboxContext.Init(i, invokeTimeoutMs) - - s.initContext = initCtx - go s.awaitInitCompletion() - - return nil -} - -func (s *Server) FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error { - invokeID, err := s.setReplyStream(w, direct) - if err != nil { - return err - } - - s.setRapidPhase(phaseInvoking) - - i.ID = invokeID - - select { - case <-s.sendResponseChan: - // we didn't pass invoke to rapid yet, but rapid already has written some response - // It can happend if runtime/agent crashed even before we passed invoke to it - return ErrInvokeResponseAlreadyWritten - default: - } - - go func() { - if s.invoker == nil { - // Reset occurred, do not send invoke request - s.InvokeDoneChan <- DoneWithState{State: s.InternalStateGetter()} - s.setRuntimeState(runtimeInvokeComplete) - return - } - s.invoker.SendRequest(i, s) - invokeSuccess, invokeFailure := s.invoker.Wait() - if invokeFailure != nil { - if invokeFailure.ResetReceived { - return - } - - // Rapid constructs a response body itself when invoke fails, with error type. - // These are on the handleInvokeError path, may occur during timeout resets, - // failure reset (proc exit). It is expected to be non-nil on all invoke failures. - if invokeFailure.DefaultErrorResponse == nil { - log.Panicf("default error response was nil for invoke failure, %v", invokeFailure) - } - - if cachedInitError := s.getCachedInitErrorResponse(); cachedInitError != nil { - // /init/error was called - s.trySendDefaultErrorResponse(cachedInitError) - } else { - // sent only if /error and /response not called - s.trySendDefaultErrorResponse(invokeFailure.DefaultErrorResponse) - } - doneFail := doneFailFromInvokeFailure(invokeFailure) - s.InvokeDoneChan <- DoneWithState{ - Done: &interop.Done{ErrorType: doneFail.ErrorType, Meta: doneFail.Meta}, - State: s.InternalStateGetter(), - } - } else { - done := doneFromInvokeSuccess(invokeSuccess) - s.InvokeDoneChan <- DoneWithState{Done: done, State: s.InternalStateGetter()} - } - }() - - select { - case i.InvokeResponseMetrics = <-s.sendResponseChan: - s.sandboxContext.SetInvokeResponseMetrics(i.InvokeResponseMetrics) - break - case <-s.reservationContext.Done(): - return ErrInvokeReservationDone - } - - return nil -} - -func (s *Server) setCachedInitErrorResponse(errResp *interop.ErrorInvokeResponse) { - s.mutex.Lock() - defer s.mutex.Unlock() - s.cachedInitErrorResponse = errResp -} - -func (s *Server) getCachedInitErrorResponse() *interop.ErrorInvokeResponse { - s.mutex.Lock() - defer s.mutex.Unlock() - return s.cachedInitErrorResponse -} - -func (s *Server) trySendDefaultErrorResponse(resp *interop.ErrorInvokeResponse) { - if err := s.SendErrorResponse(s.GetCurrentInvokeID(), resp); err != nil { - if err != interop.ErrResponseSent { - log.Panicf("Failed to send default error response: %s", err) - } - } -} - -func (s *Server) CurrentToken() *interop.Token { - s.mutex.Lock() - defer s.mutex.Unlock() - if s.invokeCtx == nil { - return nil - } - tok := s.invokeCtx.Token - return &tok -} - -// Invoke is used by the Runtime Interface Emulator (Rapid Local) -// https://github.com/aws/aws-lambda-runtime-interface-emulator -func (s *Server) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invoke) error { - resetCtx, resetCancel := context.WithCancel(context.Background()) - defer resetCancel() - - timeoutChan := make(chan error) - go func() { - select { - case <-time.After(s.GetInvokeTimeout()): - log.Debug("Invoke() timeout") - timeoutChan <- ErrInvokeTimeout - case <-resetCtx.Done(): - log.Debugf("execute finished, autoreset cancelled") - } - }() - - initFailures := s.getInitFailuresChan() - if initFailures == nil { - return ErrInitNotStarted - } - - releaseErrChan := make(chan error) - releaseSuccessChan := make(chan struct{}) - go func() { - // This thread can block in one of two method calls Reserve() & AwaitRelease(), - // corresponding to Init and Invoke phase. - // FastInvoke is intended to be 'async' response stream copying. - // When a timeout occurs, we send a 'Reset' with the timeout reason - // When a Reset is sent, the reset handler in rapid lib cancels existing flows, - // including init/invoke. This causes either initFailure/invokeFailure, and then - // the Reset is handled and processed. - // TODO: however, ideally Reserve() does not block on init, but FastInvoke does - // The logic would be almost identical, except that init failures could manifest - // through return values of FastInvoke and not Reserve() - - reserveResp, err := s.Reserve(invoke.ID, invoke.TraceID, invoke.LambdaSegmentID) - if err != nil { - log.Infof("ReserveFailed: %s", err) - } - - invoke.DeadlineNs = fmt.Sprintf("%d", metering.Monotime()+reserveResp.Token.FunctionTimeout.Nanoseconds()) - go func() { - if initCompletionResp, err := s.awaitInitialized(); err != nil { - switch err { - case ErrInitResetReceived, ErrInitDoneFailed: - // For init failures, cache the response so they can be checked later - // We check if they have not already been set by a call to /init/error by runtime - if s.getCachedInitErrorResponse() == nil { - errType, errMsg := initCompletionResp.InitErrorType, initCompletionResp.InitErrorMessage.Error() - headers := interop.InvokeResponseHeaders{} - fnError := interop.FunctionError{Type: errType, Message: errMsg} - s.setCachedInitErrorResponse(&interop.ErrorInvokeResponse{Headers: headers, FunctionError: fnError, Payload: []byte{}}) - } - - // Init failed, so we explicitly shutdown runtime (cleanup unused extensions). - // Because following fast invoke will start new (supressed) Init phase without reset call - s.Shutdown(&interop.Shutdown{DeadlineNs: metering.Monotime() + int64(resetDefaultTimeoutMs*1000*1000)}) - } - } - - if err := s.FastInvoke(responseWriter, invoke, false); err != nil { - log.Debugf("FastInvoke() error: %s", err) - } - }() - - _, err = s.AwaitRelease() - if err != nil && err != ErrReleaseReservationDone { - log.Debugf("AwaitRelease() error: %s", err) - switch err { - case ErrReleaseReservationDone: // not an error, expected return value when Reset is called - if s.getCachedInitErrorResponse() != nil { - // For Init failures, AwaitRelease returns ErrReleaseReservationDone - // because the Reset calls Release & cancels the release context - // We rename the error to ErrInitDoneFailed - releaseErrChan <- ErrInitDoneFailed - } - case ErrInitDoneFailed, ErrInvokeDoneFailed: - // Reset when either init or invoke failrues occur, i.e. - // init/error, invocation/error, Runtime.ExitError, Extension.ExitError - s.Reset(autoresetReasonReleaseFail, resetDefaultTimeoutMs) - releaseErrChan <- err - default: - releaseErrChan <- err - } - return - } - - releaseSuccessChan <- struct{}{} - }() - - var err error - select { - case timeoutErr := <-timeoutChan: - s.Reset(autoresetReasonTimeout, resetDefaultTimeoutMs) - select { - case releaseErr := <-releaseErrChan: // when AwaitRelease() has errors - log.Debugf("Invoke() release error on Execute() timeout: %s", releaseErr) - case <-releaseSuccessChan: // when AwaitRelease() finishes cleanly - } - err = timeoutErr - case err = <-releaseErrChan: - log.Debug("Invoke() release error") - case <-releaseSuccessChan: - s.Release() - log.Debug("Invoke() success") - } - - return err -} - -type initCompletionResponse struct { - InitErrorType fatalerror.ErrorType - InitErrorMessage error -} - -func (s *Server) awaitInitialized() (initCompletionResponse, error) { - initFailure, awaitingInitStatus := <-s.getInitFailuresChan() - resp := initCompletionResponse{} - - if initFailure.ResetReceived { - // Resets during Init are only received in standalone - // during an invoke timeout - s.setRuntimeState(runtimeInitFailed) - resp.InitErrorType = initFailure.ErrorType - resp.InitErrorMessage = initFailure.ErrorMessage - return resp, ErrInitResetReceived - } - - if awaitingInitStatus { - // channel not closed, received init failure - // Sandbox can be reserved even if init failed (due to function errors) - s.setRuntimeState(runtimeInitFailed) - resp.InitErrorType = initFailure.ErrorType - resp.InitErrorMessage = initFailure.ErrorMessage - return resp, ErrInitDoneFailed - } - - // not awaiting init status (channel closed) - return resp, nil -} - -// AwaitInitialized waits until init is complete. It must be idempotent, -// since it can be called twice when a caller wants to wait until init is complete -func (s *Server) AwaitInitialized() error { - if _, err := s.awaitInitialized(); err != nil { - if releaseErr := s.Release(); err != nil { - log.Infof("Error releasing after init failure %s: %s", err, releaseErr) - } - s.setRuntimeState(runtimeInitFailed) - return err - } - s.setRuntimeState(runtimeInitComplete) - return nil -} - -func (s *Server) AwaitRelease() (*statejson.ReleaseResponse, error) { - defer func() { - s.setRapidPhase(phaseIdle) - s.setRuntimeState(runtimeInvokeComplete) - }() - - select { - case doneWithState := <-s.InvokeDoneChan: - if len(doneWithState.ErrorType) > 0 && string(doneWithState.ErrorType) == ErrInitDoneFailed.Error() { - return nil, ErrInitDoneFailed - } - - if len(doneWithState.ErrorType) > 0 { - log.Errorf("Invoke DONE failed: %s", doneWithState.ErrorType) - return nil, ErrInvokeDoneFailed - } - - releaseResponse := statejson.ReleaseResponse{ - InternalStateDescription: &doneWithState.State, - ResponseMetrics: statejson.ResponseMetrics{ - RuntimeResponseLatencyMs: doneWithState.Meta.RuntimeResponseLatencyMs, - Dimensions: statejson.ResponseMetricsDimensions{ - InvokeResponseMode: statejson.InvokeResponseMode( - doneWithState.Meta.MetricsDimensions.InvokeResponseMode, - ), - }, - }, - } - - s.Release() - return &releaseResponse, nil - - case <-s.reservationContext.Done(): - return nil, ErrReleaseReservationDone - } -} - -func (s *Server) Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription { - shutdownSuccess := s.sandboxContext.Shutdown(shutdown) - if len(shutdownSuccess.ErrorType) > 0 { - log.Errorf("Shutdown first fatal error: %s", shutdownSuccess.ErrorType) - } - - s.setRapidPhase(phaseIdle) - s.setRuntimeState(runtimeNotStarted) - - state := s.InternalStateGetter() - return &state -} - -func (s *Server) InternalState() (*statejson.InternalStateDescription, error) { - if s.InternalStateGetter == nil { - return nil, errors.New("InternalStateGetterNotSet") - } - - state := s.InternalStateGetter() - return &state, nil -} - -func (s *Server) Restore(restore *interop.Restore) (interop.RestoreResult, error) { - return s.sandboxContext.Restore(restore) -} - -func doneFromInvokeSuccess(successMsg interop.InvokeSuccess) *interop.Done { - return &interop.Done{ - Meta: interop.DoneMetadata{ - RuntimeRelease: successMsg.RuntimeRelease, - NumActiveExtensions: successMsg.NumActiveExtensions, - ExtensionNames: successMsg.ExtensionNames, - InvokeRequestReadTimeNs: successMsg.InvokeMetrics.InvokeRequestReadTimeNs, - InvokeRequestSizeBytes: successMsg.InvokeMetrics.InvokeRequestSizeBytes, - RuntimeReadyTime: successMsg.InvokeMetrics.RuntimeReadyTime, - - InvokeCompletionTimeNs: successMsg.InvokeCompletionTimeNs, - InvokeReceivedTime: successMsg.InvokeReceivedTime, - RuntimeResponseLatencyMs: successMsg.ResponseMetrics.RuntimeResponseLatencyMs, - RuntimeTimeThrottledMs: successMsg.ResponseMetrics.RuntimeTimeThrottledMs, - RuntimeProducedBytes: successMsg.ResponseMetrics.RuntimeProducedBytes, - RuntimeOutboundThroughputBps: successMsg.ResponseMetrics.RuntimeOutboundThroughputBps, - LogsAPIMetrics: successMsg.LogsAPIMetrics, - MetricsDimensions: interop.DoneMetadataMetricsDimensions{ - InvokeResponseMode: successMsg.InvokeResponseMode, - }, - }, - } -} - -func doneFailFromInvokeFailure(failureMsg *interop.InvokeFailure) *interop.DoneFail { - return &interop.DoneFail{ - ErrorType: failureMsg.ErrorType, - Meta: interop.DoneMetadata{ - RuntimeRelease: failureMsg.RuntimeRelease, - NumActiveExtensions: failureMsg.NumActiveExtensions, - InvokeReceivedTime: failureMsg.InvokeReceivedTime, - - RuntimeResponseLatencyMs: failureMsg.ResponseMetrics.RuntimeResponseLatencyMs, - RuntimeTimeThrottledMs: failureMsg.ResponseMetrics.RuntimeTimeThrottledMs, - RuntimeProducedBytes: failureMsg.ResponseMetrics.RuntimeProducedBytes, - RuntimeOutboundThroughputBps: failureMsg.ResponseMetrics.RuntimeOutboundThroughputBps, - - InvokeRequestReadTimeNs: failureMsg.InvokeMetrics.InvokeRequestReadTimeNs, - InvokeRequestSizeBytes: failureMsg.InvokeMetrics.InvokeRequestSizeBytes, - RuntimeReadyTime: failureMsg.InvokeMetrics.RuntimeReadyTime, - - ExtensionNames: failureMsg.ExtensionNames, - LogsAPIMetrics: failureMsg.LogsAPIMetrics, - - MetricsDimensions: interop.DoneMetadataMetricsDimensions{ - InvokeResponseMode: failureMsg.InvokeResponseMode, - }, - }, - } -} diff --git a/lambda/rapidcore/server_test.go b/lambda/rapidcore/server_test.go deleted file mode 100644 index 68ac30c..0000000 --- a/lambda/rapidcore/server_test.go +++ /dev/null @@ -1,544 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package rapidcore - -import ( - "bytes" - "context" - "errors" - "fmt" - "net/http/httptest" - "testing" - "time" - - "github.com/stretchr/testify/require" - "go.amzn.com/lambda/core/statejson" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore/env" -) - -func waitForChanWithTimeout(channel <-chan error, timeout time.Duration) error { - select { - case err := <-channel: - return err - case <-time.After(timeout): - return nil - } -} - -func sendInitSuccessResponse(responseChannel chan<- interop.InitSuccess, msg interop.InitSuccess) { - msg.Ack = make(chan struct{}) - responseChannel <- msg - <-msg.Ack -} - -func sendInitFailureResponse(responseChannel chan<- interop.InitFailure, msg interop.InitFailure) { - msg.Ack = make(chan struct{}) - responseChannel <- msg - <-msg.Ack -} - -type mockRapidCtx struct { - initHandler func(success chan<- interop.InitSuccess, fail chan<- interop.InitFailure) - invokeHandler func() (interop.InvokeSuccess, *interop.InvokeFailure) - resetHandler func() (interop.ResetSuccess, *interop.ResetFailure) -} - -func (r *mockRapidCtx) HandleInit(init *interop.Init, successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - r.initHandler(successResp, failureResp) -} - -func (r *mockRapidCtx) HandleInvoke(invoke *interop.Invoke, sbInfoFromInit interop.SandboxInfoFromInit, buf *bytes.Buffer, responseSender interop.InvokeResponseSender) (interop.InvokeSuccess, *interop.InvokeFailure) { - return r.invokeHandler() -} - -func (r *mockRapidCtx) HandleReset(reset *interop.Reset) (interop.ResetSuccess, *interop.ResetFailure) { - return r.resetHandler() -} - -func (r *mockRapidCtx) HandleShutdown(shutdown *interop.Shutdown) interop.ShutdownSuccess { - return interop.ShutdownSuccess{} -} - -func (r *mockRapidCtx) HandleRestore(restore *interop.Restore) (interop.RestoreResult, error) { - return interop.RestoreResult{}, nil -} - -func (r *mockRapidCtx) Clear() {} - -func (r *mockRapidCtx) SetRuntimeStartedTime(a int64) { -} - -func (r *mockRapidCtx) SetInvokeResponseMetrics(a *interop.InvokeResponseMetrics) { -} - -func (r *mockRapidCtx) SetEventsAPI(e interop.EventsAPI) { -} - -func TestReserveDoesNotDeadlockWhenCalledMultipleTimes(t *testing.T) { - srv := NewServer() - srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - - initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitSuccessResponse(successResp, interop.InitSuccess{}) - } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ - initHandler, - func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, - func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, - }, "handler", "runtimeAPIhost:999"}) - - srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - - _, err := srv.Reserve("", "", "") // reserve successfully - require.NoError(t, err) - - resp, err := srv.Reserve("", "", "") // attempt double reservation - require.Nil(t, resp) - require.Equal(t, ErrAlreadyReserved, err) - - successChan := make(chan error) - go func() { - resp, err := srv.Reserve("", "", "") - require.Nil(t, resp) - require.Equal(t, ErrAlreadyReserved, err) - successChan <- nil - }() - - select { - case <-time.After(1 * time.Second): - require.Fail(t, "Timed out while waiting for Reserve() response") - case <-successChan: - } -} - -func TestInitSuccess(t *testing.T) { - srv := NewServer() - srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - - initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitSuccessResponse(successResp, interop.InitSuccess{}) - } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ - initHandler, - func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, - func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, - }, "handler", "runtimeAPIhost:999"}) - - srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - require.Equal(t, phaseInitializing, srv.getRapidPhase()) - - _, err := srv.Reserve("", "", "") - require.NoError(t, err) -} - -func TestInitErrorBeforeReserve(t *testing.T) { - // Rapid thread sending init failure should not be blocked even if reserve hasn't arrived - srv := NewServer() - srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - - initErrorResponseSent := make(chan error) - initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - sendInitFailureResponse(failureResp, interop.InitFailure{}) - initErrorResponseSent <- errors.New("initErrorResponseSent") - } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ - initHandler, - func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, - func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, - }, "handler", "runtimeAPIhost:999"}) - - srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - - if msg := waitForChanWithTimeout(initErrorResponseSent, 1*time.Second); msg == nil { - require.Fail(t, "Timed out waiting for init error response to be sent") - } - - resp, err := srv.Reserve("", "", "") - require.NoError(t, err) - require.True(t, len(resp.Token.InvokeID) > 0) - - awaitInitErr := srv.AwaitInitialized() - require.Error(t, ErrInitDoneFailed, awaitInitErr) - - _, err = srv.AwaitRelease() - require.Error(t, err, ErrReleaseReservationDone) - require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) -} - -func TestInitErrorDuringReserve(t *testing.T) { - srv := NewServer() - srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - - initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - sendInitFailureResponse(failureResp, interop.InitFailure{}) - } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{ - initHandler, - func() (interop.InvokeSuccess, *interop.InvokeFailure) { return interop.InvokeSuccess{}, nil }, - func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil }, - }, "handler", "runtimeAPIhost:999"}) - - srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - resp, err := srv.Reserve("", "", "") - require.NoError(t, err) - require.True(t, len(resp.Token.InvokeID) > 0) - - awaitInitErr := srv.AwaitInitialized() - require.Error(t, ErrInitDoneFailed, awaitInitErr) - - _, err = srv.AwaitRelease() - require.Error(t, err, ErrReleaseReservationDone) - require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) -} - -func TestInvokeSuccess(t *testing.T) { - srv := NewServer() - srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - - releaseRuntimeInit := make(chan struct{}) - initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - <-releaseRuntimeInit - sendInitSuccessResponse(successResp, interop.InitSuccess{}) - } - - invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - response := &interop.StreamableInvokeResponse{Headers: map[string]string{"Content-Type": "application/json"}, Payload: bytes.NewReader([]byte("response"))} - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), response)) - require.NoError(t, srv.SendRuntimeReady()) - return interop.InvokeSuccess{}, nil - } - - resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { return interop.ResetSuccess{}, nil } - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) - - srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - require.Equal(t, phaseInitializing, srv.getRapidPhase()) - releaseRuntimeInit <- struct{}{} - - _, err := srv.Reserve("", "", "") - require.NoError(t, err) - require.Equal(t, phaseInitializing, srv.getRapidPhase()) // Reserve does not wait for init completion - - awaitInitErr := srv.AwaitInitialized() - require.NoError(t, awaitInitErr) - - responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) - require.NoError(t, invokeErr) - require.Equal(t, "response", responseRecorder.Body.String()) - require.Equal(t, "application/json", responseRecorder.Result().Header.Get("Content-Type")) - - _, err = srv.AwaitRelease() - require.NoError(t, err) - require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) -} - -func TestInvokeError(t *testing.T) { - srv := NewServer() - srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - - initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitSuccessResponse(successResp, interop.InitSuccess{}) - } - - invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - headers := interop.InvokeResponseHeaders{ContentType: "application/json"} - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }"), Headers: headers})) - require.NoError(t, srv.SendRuntimeReady()) - return interop.InvokeSuccess{}, nil - } - - resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { - return interop.ResetSuccess{}, nil - } - - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) - - srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - require.Equal(t, phaseInitializing, srv.getRapidPhase()) - - _, err := srv.Reserve("", "", "") - require.NoError(t, err) - require.Equal(t, phaseInitializing, srv.getRapidPhase()) - - awaitInitErr := srv.AwaitInitialized() - require.NoError(t, awaitInitErr) - - responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) - require.NoError(t, invokeErr) - require.Equal(t, "{ 'errorType': 'A.B' }", responseRecorder.Body.String()) - require.Equal(t, "application/json", responseRecorder.Result().Header.Get("Content-Type")) - - _, err = srv.AwaitRelease() - require.NoError(t, err) - require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) -} - -func TestInvokeWithSuppressedInitSuccess(t *testing.T) { - // Tests an init/error followed by suppressed init: - // Runtime may have called init/error before Reserve, in which case we - // expect a suppressed init, i.e. init during the invoke. - // The first Reserve() after init/error returns ErrInitError because - // SendDoneFail was called on init/error. - // We expect the caller to then call Reset() to prepare for suppressed init, - // followed by Reserve() so that a valid reservation context is available. - // Reserve() returns ErrInitAlreadyDone, since the server implementation - // closes the InitDone channel after the first InitDone message. - - srv := NewServer() - srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - - initErrorCompleted := make(chan error) - initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - sendInitFailureResponse(failureResp, interop.InitFailure{}) - initErrorCompleted <- errors.New("initErrorSequenceCompleted") - } - - invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - response := &interop.StreamableInvokeResponse{Payload: bytes.NewReader([]byte("response"))} - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), response)) - return interop.InvokeSuccess{}, nil - } - - resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { - return interop.ResetSuccess{}, nil - } - - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) - - srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - require.Equal(t, phaseInitializing, srv.getRapidPhase()) - - if msg := waitForChanWithTimeout(initErrorCompleted, 1*time.Second); msg == nil { - require.Fail(t, "Timed out waiting for init error sequence to be called") - } - - resp, err := srv.Reserve("", "", "") - require.NoError(t, err) - require.True(t, len(resp.Token.InvokeID) > 0) - - awaitInitErr := srv.AwaitInitialized() - require.Error(t, ErrInitDoneFailed, awaitInitErr) - - _, err = srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for suppressed init - require.NoError(t, err) - - _, err = srv.Reserve("", "", "") - require.NoError(t, err) - - responseRecorder := httptest.NewRecorder() - successChan := make(chan error) - go func() { - directInvoke := false - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, directInvoke) - require.NoError(t, invokeErr) - successChan <- errors.New("invokeResponseWritten") - }() - - invokeErr := waitForChanWithTimeout(successChan, 1*time.Second) - if invokeErr == nil { - require.Fail(t, "Timed out while waiting for invoke response") - } - - require.Equal(t, "response", responseRecorder.Body.String()) - - _, err = srv.AwaitRelease() - require.NoError(t, err) - require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) -} - -func TestInvokeWithSuppressedInitErrorDueToInitError(t *testing.T) { - // Tests init/error followed by init/error during suppressed init - srv := NewServer() - srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - - initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - sendInitFailureResponse(failureResp, interop.InitFailure{}) - } - - releaseChan := make(chan error) - invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - require.NoError(t, srv.SendErrorResponse(srv.GetCurrentInvokeID(), &interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - releaseChan <- nil - return interop.InvokeSuccess{}, &interop.InvokeFailure{ErrorType: "A.B", RequestReset: true, DefaultErrorResponse: &interop.ErrorInvokeResponse{}} - } - - resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { - return interop.ResetSuccess{}, nil - } - - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) - - srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - - resp, err := srv.Reserve("", "", "") - require.NoError(t, err) - require.True(t, len(resp.Token.InvokeID) > 0) - require.Equal(t, phaseInitializing, srv.getRapidPhase()) - - awaitInitErr := srv.AwaitInitialized() - require.Error(t, ErrInitDoneFailed, awaitInitErr) - - _, err = srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for invoke with suppressed init - require.NoError(t, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - - _, err = srv.Reserve("", "", "") - require.NoError(t, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - - responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) - require.NoError(t, invokeErr) - require.Equal(t, "{ 'errorType': 'A.B' }", responseRecorder.Body.String()) - require.Equal(t, phaseInvoking, srv.getRapidPhase()) - - <-releaseChan // Unblock gorotune to send donefail - _, err = srv.AwaitRelease() - require.EqualError(t, err, ErrInvokeDoneFailed.Error()) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) -} - -func TestInvokeWithSuppressedInitErrorDueToInvokeError(t *testing.T) { - // Tests init/error followed by init/error during suppressed init - srv := NewServer() - srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - - initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'A.B' }")})) - sendInitFailureResponse(failureResp, interop.InitFailure{}) - } - invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - require.NoError(t, srv.SendInitErrorResponse(&interop.ErrorInvokeResponse{Payload: []byte("{ 'errorType': 'B.C' }")})) - require.NoError(t, srv.SendRuntimeReady()) - return interop.InvokeSuccess{}, nil - } - - resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { - return interop.ResetSuccess{}, nil - } - - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) - - srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - require.Equal(t, phaseInitializing, srv.getRapidPhase()) - - resp, err := srv.Reserve("", "", "") - require.NoError(t, err) - require.True(t, len(resp.Token.InvokeID) > 0) - - awaitInitErr := srv.AwaitInitialized() - require.Error(t, ErrInitDoneFailed, awaitInitErr) - - _, err = srv.Reset(autoresetReasonReserveFail, resetDefaultTimeoutMs) // prepare for invoke with suppressed init - require.NoError(t, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - - _, err = srv.Reserve("", "", "") - require.NoError(t, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - - responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) - require.NoError(t, invokeErr) - require.Equal(t, "{ 'errorType': 'B.C' }", responseRecorder.Body.String()) - - _, err = srv.AwaitRelease() - require.NoError(t, err) // /invocation/error -> /invocation/next returns no error / donefail - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) -} - -func TestMultipleInvokeSuccess(t *testing.T) { - srv := NewServer() - srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return statejson.InternalStateDescription{} }) - - initHandler := func(successResp chan<- interop.InitSuccess, failureResp chan<- interop.InitFailure) { - sendInitSuccessResponse(successResp, interop.InitSuccess{}) - } - i := 0 - invokeHandler := func() (interop.InvokeSuccess, *interop.InvokeFailure) { - response := &interop.StreamableInvokeResponse{Payload: bytes.NewReader([]byte("response-" + fmt.Sprint(i)))} - require.NoError(t, srv.SendResponse(srv.GetCurrentInvokeID(), response)) - require.NoError(t, srv.SendRuntimeReady()) - i++ - return interop.InvokeSuccess{}, nil - } - - resetHandler := func() (interop.ResetSuccess, *interop.ResetFailure) { - return interop.ResetSuccess{}, nil - } - - srv.SetSandboxContext(&SandboxContext{&mockRapidCtx{initHandler, invokeHandler, resetHandler}, "handler", "runtimeAPIhost:999"}) - - srv.Init(&interop.Init{EnvironmentVariables: env.NewEnvironment()}, int64(1*time.Second*time.Millisecond)) - require.Equal(t, phaseInitializing, srv.getRapidPhase()) - - for i := 0; i < 3; i++ { - _, err := srv.Reserve("", "", "") - require.NoError(t, err) - - awaitInitErr := srv.AwaitInitialized() - require.NoError(t, awaitInitErr) - - responseRecorder := httptest.NewRecorder() - invokeErr := srv.FastInvoke(responseRecorder, &interop.Invoke{}, false) - require.NoError(t, invokeErr) - require.Equal(t, "response-"+fmt.Sprint(i), responseRecorder.Body.String()) - require.Equal(t, phaseInvoking, srv.getRapidPhase()) - - _, err = srv.AwaitRelease() - require.NoError(t, err) - require.Equal(t, phaseIdle, srv.getRapidPhase()) - require.Equal(t, runtimeState(runtimeInvokeComplete), srv.getRuntimeState()) - } -} - -func TestAwaitReleaseOnSuccess(t *testing.T) { - srv := NewServer() - - // mocks - internalStateDescription := statejson.InternalStateDescription{} - srv.SetInternalStateGetter(func() statejson.InternalStateDescription { return internalStateDescription }) - doneWithState := DoneWithState{ - State: internalStateDescription, - Done: &interop.Done{ - Meta: interop.DoneMetadata{ - RuntimeResponseLatencyMs: 12345, - MetricsDimensions: interop.DoneMetadataMetricsDimensions{ - InvokeResponseMode: interop.InvokeResponseModeStreaming, - }, - }, - }, - } - srv.InvokeDoneChan <- doneWithState - srv.reservationContext, srv.reservationCancel = context.WithCancel(context.Background()) - - // under test - responseAwaitRelease, err := srv.AwaitRelease() - - // assertions - require.NoError(t, err) - require.Equal(t, doneWithState.Done.Meta.RuntimeResponseLatencyMs, responseAwaitRelease.ResponseMetrics.RuntimeResponseLatencyMs) - require.Equal(t, string(doneWithState.Done.Meta.MetricsDimensions.InvokeResponseMode), string(responseAwaitRelease.ResponseMetrics.Dimensions.InvokeResponseMode)) - require.Equal(t, &doneWithState.State, responseAwaitRelease.InternalStateDescription) -} - -/* Unit tests remaining: -- Shutdown behaviour -- Reset behaviour during various phases -- Runtime / extensions process exit sequences -- Invoke() and Init() api tests -- How can we add handleRestore test here? - -See PlantUML state diagram for potential other uncovered paths -through the state machine -*/ diff --git a/lambda/rapidcore/standalone/directInvokeHandler.go b/lambda/rapidcore/standalone/directInvokeHandler.go deleted file mode 100644 index 1c7e7cb..0000000 --- a/lambda/rapidcore/standalone/directInvokeHandler.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "go.amzn.com/lambda/rapidcore" - - "net/http" - - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/core/directinvoke" -) - -func DirectInvokeHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { - tok := s.CurrentToken() - if tok == nil { - log.Errorf("Attempt to call directInvoke without Reserve") - w.WriteHeader(http.StatusBadRequest) - return - } - - invoke, err := directinvoke.ReceiveDirectInvoke(w, r, *tok) - if err != nil { - log.Errorf("direct invoke error: %s", err) - return - } - - if err := s.AwaitInitialized(); err != nil { - w.WriteHeader(DoneFailedHTTPCode) - if state, err := s.InternalState(); err == nil { - w.Write(state.AsJSON()) - } - return - } - - if err := s.FastInvoke(w, invoke, true); err != nil { - switch err { - case rapidcore.ErrNotReserved: - case rapidcore.ErrAlreadyReplied: - case rapidcore.ErrAlreadyInvocating: - log.Errorf("Failed to set reply stream: %s", err) - w.WriteHeader(http.StatusBadRequest) - return - case rapidcore.ErrInvokeReservationDone: - w.WriteHeader(http.StatusBadGateway) - } - } -} diff --git a/lambda/rapidcore/standalone/eventLogHandler.go b/lambda/rapidcore/standalone/eventLogHandler.go deleted file mode 100644 index e5bf7ac..0000000 --- a/lambda/rapidcore/standalone/eventLogHandler.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "encoding/json" - "fmt" - "net/http" - - "go.amzn.com/lambda/rapidcore/standalone/telemetry" -) - -func EventLogHandler(w http.ResponseWriter, r *http.Request, eventsAPI *telemetry.StandaloneEventsAPI) { - bytes, err := json.Marshal(eventsAPI.EventLog()) - if err != nil { - http.Error(w, fmt.Sprintf("marshalling error: %s", err), http.StatusInternalServerError) - return - } - w.Write(bytes) -} diff --git a/lambda/rapidcore/standalone/executeHandler.go b/lambda/rapidcore/standalone/executeHandler.go deleted file mode 100644 index 0c7162b..0000000 --- a/lambda/rapidcore/standalone/executeHandler.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "net/http" - - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapidcore" -) - -func Execute(w http.ResponseWriter, r *http.Request, sandbox rapidcore.LambdaInvokeAPI) { - - invokePayload := &interop.Invoke{ - TraceID: r.Header.Get("X-Amzn-Trace-Id"), - LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - Payload: r.Body, - InvokeReceivedTime: metering.Monotime(), - } - - // If we write to 'w' directly and waitUntilRelease fails, we won't be able to propagate error anymore - invokeResp := &ResponseWriterProxy{} - if err := sandbox.Invoke(invokeResp, invokePayload); err != nil { - switch err { - // Reserve errors: - case rapidcore.ErrAlreadyReserved: - log.WithError(err).Error("Failed to reserve as it is already reserved.") - w.WriteHeader(400) - case rapidcore.ErrInternalServerError: - log.WithError(err).Error("Failed to reserve from an internal server error.") - w.WriteHeader(http.StatusInternalServerError) - - // Invoke errors: - case rapidcore.ErrNotReserved, rapidcore.ErrAlreadyReplied, rapidcore.ErrAlreadyInvocating: - log.WithError(err).Error("Failed to invoke from setting the reply stream.") - w.WriteHeader(400) - - case rapidcore.ErrInvokeResponseAlreadyWritten: - return - case rapidcore.ErrInvokeTimeout, rapidcore.ErrInitResetReceived: - log.WithError(err).Error("Failed to invoke from an invoke timeout.") - w.WriteHeader(http.StatusGatewayTimeout) - - // DONE failures: - case rapidcore.ErrInvokeDoneFailed: - copyHeaders(invokeResp, w) - w.WriteHeader(DoneFailedHTTPCode) - w.Write(invokeResp.Body) - return - // Reservation canceled errors - case rapidcore.ErrReserveReservationDone, rapidcore.ErrInvokeReservationDone, rapidcore.ErrReleaseReservationDone, rapidcore.ErrInitNotStarted: - log.WithError(err).Error("Failed to cancel reservation.") - w.WriteHeader(http.StatusGatewayTimeout) - } - - return - } - - copyHeaders(invokeResp, w) - if invokeResp.StatusCode != 0 { - w.WriteHeader(invokeResp.StatusCode) - } - w.Write(invokeResp.Body) -} - -func copyHeaders(proxyWriter, writer http.ResponseWriter) { - for key, val := range proxyWriter.Header() { - writer.Header().Set(key, val[0]) - } -} diff --git a/lambda/rapidcore/standalone/initHandler.go b/lambda/rapidcore/standalone/initHandler.go deleted file mode 100644 index d60ec6f..0000000 --- a/lambda/rapidcore/standalone/initHandler.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "fmt" - "net/http" - "os" - "time" - - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore/env" -) - -type RuntimeInfo struct { - ImageJSON string `json:"runtimeImageJSON,omitempty"` - Arn string `json:"runtimeArn,omitempty"` - Version string `json:"runtimeVersion,omitempty"` -} - -// TODO: introduce suppress init flag -type InitBody struct { - Handler string `json:"handler"` - FunctionName string `json:"functionName"` - FunctionVersion string `json:"functionVersion"` - InvokeTimeoutMs int64 `json:"invokeTimeoutMs"` - RuntimeInfo RuntimeInfo `json:"runtimeInfo"` - Customer struct { - Environment map[string]string `json:"environment"` - } `json:"customer"` - AwsKey *string `json:"awskey"` - AwsSecret *string `json:"awssecret"` - AwsSession *string `json:"awssession"` - CredentialsExpiry time.Time `json:"credentialsExpiry"` - Throttled bool `json:"throttled"` -} - -type InitRequest struct { - InitBody - ReplyChan chan Reply -} - -func (c *InitBody) Validate() error { - // Handler is optional - if c.FunctionName == "" { - return fmt.Errorf("functionName missing") - } - if c.FunctionVersion == "" { - return fmt.Errorf("FunctionVersion missing") - } - if c.InvokeTimeoutMs == 0 { - return fmt.Errorf("invokeTimeoutMs missing") - } - - return nil -} - -func InitHandler(w http.ResponseWriter, r *http.Request, sandbox InteropServer, bs interop.Bootstrap) { - init := InitBody{} - if lerr := readBodyAndUnmarshalJSON(r, &init); lerr != nil { - lerr.Send(w, r) - return - } - - if err := init.Validate(); err != nil { - newErrorReply(ClientInvalidRequest, err.Error()).Send(w, r) - return - } - - for envKey, envVal := range init.Customer.Environment { - // We set environment variables to keep the env parsing & filtering - // logic consistent across standalone-mode and girp-mode - os.Setenv(envKey, envVal) - } - - awsKey, awsSecret, awsSession := getCredentials(init) - - sandboxType := interop.SandboxClassic - - if init.Throttled { - sandboxType = interop.SandboxPreWarmed - } - - // pass to rapid - sandbox.Init(&interop.Init{ - Handler: init.Handler, - AwsKey: awsKey, - AwsSecret: awsSecret, - AwsSession: awsSession, - CredentialsExpiry: init.CredentialsExpiry, - XRayDaemonAddress: "0.0.0.0:0", // TODO - FunctionName: init.FunctionName, - FunctionVersion: init.FunctionVersion, - RuntimeInfo: interop.RuntimeInfo{ - ImageJSON: init.RuntimeInfo.ImageJSON, - Arn: init.RuntimeInfo.Arn, - Version: init.RuntimeInfo.Version}, - CustomerEnvironmentVariables: env.CustomerEnvironmentVariables(), - SandboxType: sandboxType, - Bootstrap: bs, - EnvironmentVariables: env.NewEnvironment(), - }, init.InvokeTimeoutMs) -} - -func getCredentials(init InitBody) (string, string, string) { - // ToDo(guvfatih): I think instead of passing and getting these credentials values via environment variables - // we need to make StandaloneTests passing these via the Init request to be compliant with the existing protocol. - awsKey := os.Getenv("AWS_ACCESS_KEY_ID") - awsSecret := os.Getenv("AWS_SECRET_ACCESS_KEY") - awsSession := os.Getenv("AWS_SESSION_TOKEN") - - if init.AwsKey != nil { - awsKey = *init.AwsKey - } - - if init.AwsSecret != nil { - awsSecret = *init.AwsSecret - } - - if init.AwsSession != nil { - awsSession = *init.AwsSession - } - - return awsKey, awsSecret, awsSession -} diff --git a/lambda/rapidcore/standalone/internalStateHandler.go b/lambda/rapidcore/standalone/internalStateHandler.go deleted file mode 100644 index cb40c1c..0000000 --- a/lambda/rapidcore/standalone/internalStateHandler.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "net/http" -) - -func InternalStateHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { - state, err := s.InternalState() - if err != nil { - http.Error(w, "internal state callback not set", http.StatusInternalServerError) - return - } - - w.Write(state.AsJSON()) -} diff --git a/lambda/rapidcore/standalone/invokeHandler.go b/lambda/rapidcore/standalone/invokeHandler.go deleted file mode 100644 index 48a3a03..0000000 --- a/lambda/rapidcore/standalone/invokeHandler.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "fmt" - "net/http" - "strconv" - - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapidcore" - - log "github.com/sirupsen/logrus" -) - -func InvokeHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { - tok := s.CurrentToken() - if tok == nil { - log.Errorf("Attempt to call directInvoke without Reserve") - w.WriteHeader(http.StatusBadRequest) - return - } - - restoreDurationHeader := r.Header.Get("restore-duration") - restoreStartHeader := r.Header.Get("restore-start-time") - - var restoreDurationNs int64 = 0 - var restoreStartTimeMonotime int64 = 0 - if restoreDurationHeader != "" && restoreStartHeader != "" { - var err1, err2 error - restoreDurationNs, err1 = strconv.ParseInt(restoreDurationHeader, 10, 64) - restoreStartTimeMonotime, err2 = strconv.ParseInt(restoreStartHeader, 10, 64) - if err1 != nil || err2 != nil { - log.Errorf("Failed to parse 'restore-duration' from '%s' and/or 'restore-start-time' from '%s'", restoreDurationHeader, restoreStartHeader) - restoreDurationNs = 0 - restoreStartTimeMonotime = 0 - } - } - - invokePayload := &interop.Invoke{ - TraceID: r.Header.Get("X-Amzn-Trace-Id"), - LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - Payload: r.Body, - DeadlineNs: fmt.Sprintf("%d", metering.Monotime()+tok.FunctionTimeout.Nanoseconds()), - InvokeReceivedTime: metering.Monotime(), - RestoreDurationNs: restoreDurationNs, - RestoreStartTimeMonotime: restoreStartTimeMonotime, - } - - if err := s.AwaitInitialized(); err != nil { - w.WriteHeader(DoneFailedHTTPCode) - if state, err := s.InternalState(); err == nil { - w.Write(state.AsJSON()) - } - return - } - - if err := s.FastInvoke(w, invokePayload, false); err != nil { - switch err { - case rapidcore.ErrNotReserved: - case rapidcore.ErrAlreadyReplied: - case rapidcore.ErrAlreadyInvocating: - log.Errorf("Failed to set reply stream: %s", err) - w.WriteHeader(400) - return - case rapidcore.ErrInvokeReservationDone: - // TODO use http.StatusBadGateway - w.WriteHeader(http.StatusGatewayTimeout) - } - } -} diff --git a/lambda/rapidcore/standalone/middleware.go b/lambda/rapidcore/standalone/middleware.go deleted file mode 100644 index 06baae3..0000000 --- a/lambda/rapidcore/standalone/middleware.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "net/http" - - "github.com/go-chi/chi/middleware" - log "github.com/sirupsen/logrus" -) - -func standaloneAccessLogDecorator(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Debugf("standalone: -> %s %s %v", r.Method, r.URL, r.Header) - ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) - next.ServeHTTP(ww, r) - - status := 200 - if ww.Status() != 0 { - status = ww.Status() - } - - if status != 0 && status/100 != 2 { - log.Errorf("standalone: <- %s %d %v", r.URL, status, w.Header()) - } else { - log.Debugf("standalone: <- %s %d %v", r.URL, status, w.Header()) - } - }) -} diff --git a/lambda/rapidcore/standalone/pingHandler.go b/lambda/rapidcore/standalone/pingHandler.go deleted file mode 100644 index c6cb021..0000000 --- a/lambda/rapidcore/standalone/pingHandler.go +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "net/http" -) - -func PingHandler(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("pong")) -} diff --git a/lambda/rapidcore/standalone/reserveHandler.go b/lambda/rapidcore/standalone/reserveHandler.go deleted file mode 100644 index 52b51cd..0000000 --- a/lambda/rapidcore/standalone/reserveHandler.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "net/http" - - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/core/directinvoke" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore" -) - -const ( - ReservationTokenHeader = "Reservation-Token" - InvokeIDHeader = "Invoke-ID" - VersionIDHeader = "Version-ID" -) - -func tokenToHeaders(w http.ResponseWriter, token interop.Token) { - w.Header().Set(ReservationTokenHeader, token.ReservationToken) - w.Header().Set(directinvoke.InvokeIDHeader, token.InvokeID) - w.Header().Set(directinvoke.VersionIDHeader, token.VersionID) -} - -func ReserveHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { - reserveResp, err := s.Reserve("", r.Header.Get("X-Amzn-Trace-Id"), r.Header.Get("X-Amzn-Segment-Id")) - - if err != nil { - switch err { - case rapidcore.ErrReserveReservationDone: - // TODO use http.StatusBadGateway - w.WriteHeader(http.StatusGatewayTimeout) - default: - log.Errorf("Failed to reserve: %s", err) - w.WriteHeader(400) - } - return - } - - tokenToHeaders(w, reserveResp.Token) - w.Write(reserveResp.InternalState.AsJSON()) -} diff --git a/lambda/rapidcore/standalone/resetHandler.go b/lambda/rapidcore/standalone/resetHandler.go deleted file mode 100644 index 4f2ca2e..0000000 --- a/lambda/rapidcore/standalone/resetHandler.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "net/http" -) - -type resetAPIRequest struct { - Reason string `json:"reason"` - TimeoutMs int64 `json:"timeoutMs"` -} - -func ResetHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { - reset := resetAPIRequest{} - if lerr := readBodyAndUnmarshalJSON(r, &reset); lerr != nil { - lerr.Send(w, r) - return - } - - resetDescription, err := s.Reset(reset.Reason, reset.TimeoutMs) - if err != nil { - (&FailureReply{}).Send(w, r) - return - } - - w.Write(resetDescription.AsJSON()) -} diff --git a/lambda/rapidcore/standalone/restoreHandler.go b/lambda/rapidcore/standalone/restoreHandler.go deleted file mode 100644 index fdf7a5d..0000000 --- a/lambda/rapidcore/standalone/restoreHandler.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "encoding/json" - "net/http" - "strconv" - "time" - - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/interop" -) - -type RestoreBody struct { - AwsKey string `json:"awskey"` - AwsSecret string `json:"awssecret"` - AwsSession string `json:"awssession"` - CredentialsExpiry time.Time `json:"credentialsExpiry"` - RestoreHookTimeoutMs int64 `json:"restoreHookTimeoutMs"` -} - -func RestoreHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { - restoreRequest := RestoreBody{} - if lerr := readBodyAndUnmarshalJSON(r, &restoreRequest); lerr != nil { - lerr.Send(w, r) - return - } - - restore := &interop.Restore{ - AwsKey: restoreRequest.AwsKey, - AwsSecret: restoreRequest.AwsSecret, - AwsSession: restoreRequest.AwsSession, - CredentialsExpiry: restoreRequest.CredentialsExpiry, - RestoreHookTimeoutMs: restoreRequest.RestoreHookTimeoutMs, - } - - restoreResult, err := s.Restore(restore) - - responseMap := make(map[string]string) - - responseMap["restoreMs"] = strconv.FormatInt(restoreResult.RestoreMs, 10) - - if err != nil { - log.Errorf("Failed to restore: %s", err) - responseMap["restoreError"] = err.Error() - w.WriteHeader(http.StatusBadGateway) - } - - responseJSON, err := json.Marshal(responseMap) - - if err != nil { - log.Panicf("Cannot marshal the response map for RESTORE, %v", responseMap) - } - - w.Write(responseJSON) -} diff --git a/lambda/rapidcore/standalone/router.go b/lambda/rapidcore/standalone/router.go deleted file mode 100644 index 7957c32..0000000 --- a/lambda/rapidcore/standalone/router.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "context" - "net/http" - - "go.amzn.com/lambda/core/statejson" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapidcore" - "go.amzn.com/lambda/rapidcore/standalone/telemetry" - - "github.com/go-chi/chi" -) - -type InteropServer interface { - Init(i *interop.Init, invokeTimeoutMs int64) error - AwaitInitialized() error - FastInvoke(w http.ResponseWriter, i *interop.Invoke, direct bool) error - Reserve(id string, traceID, lambdaSegmentID string) (*rapidcore.ReserveResponse, error) - Reset(reason string, timeoutMs int64) (*statejson.ResetDescription, error) - AwaitRelease() (*statejson.ReleaseResponse, error) - Shutdown(shutdown *interop.Shutdown) *statejson.InternalStateDescription - InternalState() (*statejson.InternalStateDescription, error) - CurrentToken() *interop.Token - Restore(restore *interop.Restore) (interop.RestoreResult, error) -} - -func NewHTTPRouter(ipcSrv InteropServer, lambdaInvokeAPI rapidcore.LambdaInvokeAPI, eventsAPI *telemetry.StandaloneEventsAPI, shutdownFunc context.CancelFunc, bs interop.Bootstrap) *chi.Mux { - r := chi.NewRouter() - r.Use(standaloneAccessLogDecorator) - - r.Post("/2015-03-31/functions/*/invocations", func(w http.ResponseWriter, r *http.Request) { Execute(w, r, lambdaInvokeAPI) }) - r.Get("/test/ping", func(w http.ResponseWriter, r *http.Request) { PingHandler(w, r) }) - r.Post("/test/init", func(w http.ResponseWriter, r *http.Request) { InitHandler(w, r, ipcSrv, bs) }) - r.Post("/test/waitUntilInitialized", func(w http.ResponseWriter, r *http.Request) { WaitUntilInitializedHandler(w, r, ipcSrv) }) - r.Post("/test/reserve", func(w http.ResponseWriter, r *http.Request) { ReserveHandler(w, r, ipcSrv) }) - r.Post("/test/invoke", func(w http.ResponseWriter, r *http.Request) { InvokeHandler(w, r, ipcSrv) }) - r.Post("/test/waitUntilRelease", func(w http.ResponseWriter, r *http.Request) { WaitUntilReleaseHandler(w, r, ipcSrv) }) - r.Post("/test/reset", func(w http.ResponseWriter, r *http.Request) { ResetHandler(w, r, ipcSrv) }) - r.Post("/test/shutdown", func(w http.ResponseWriter, r *http.Request) { ShutdownHandler(w, r, ipcSrv, shutdownFunc) }) - r.Post("/test/directInvoke/{reservationtoken}", func(w http.ResponseWriter, r *http.Request) { DirectInvokeHandler(w, r, ipcSrv) }) - r.Get("/test/internalState", func(w http.ResponseWriter, r *http.Request) { InternalStateHandler(w, r, ipcSrv) }) - r.Get("/test/eventLog", func(w http.ResponseWriter, r *http.Request) { EventLogHandler(w, r, eventsAPI) }) - r.Post("/test/restore", func(w http.ResponseWriter, r *http.Request) { RestoreHandler(w, r, ipcSrv) }) - return r -} diff --git a/lambda/rapidcore/standalone/shutdownHandler.go b/lambda/rapidcore/standalone/shutdownHandler.go deleted file mode 100644 index 8085541..0000000 --- a/lambda/rapidcore/standalone/shutdownHandler.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "context" - "net/http" - - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" -) - -type shutdownAPIRequest struct { - TimeoutMs int64 `json:"timeoutMs"` -} - -func ShutdownHandler(w http.ResponseWriter, r *http.Request, s InteropServer, shutdownFunc context.CancelFunc) { - shutdown := shutdownAPIRequest{} - if lerr := readBodyAndUnmarshalJSON(r, &shutdown); lerr != nil { - lerr.Send(w, r) - return - } - - internalState := s.Shutdown(&interop.Shutdown{ - DeadlineNs: metering.Monotime() + int64(shutdown.TimeoutMs*1000*1000), - }) - - w.Write(internalState.AsJSON()) - - shutdownFunc() -} diff --git a/lambda/rapidcore/standalone/telemetry/agent_writer.go b/lambda/rapidcore/standalone/telemetry/agent_writer.go deleted file mode 100644 index 6ff2581..0000000 --- a/lambda/rapidcore/standalone/telemetry/agent_writer.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "bufio" - "bytes" -) - -type SandboxAgentWriter struct { - eventType string // 'runtime' or 'extension' - eventsAPI *StandaloneEventsAPI -} - -func NewSandboxAgentWriter(api *StandaloneEventsAPI, source string) *SandboxAgentWriter { - return &SandboxAgentWriter{ - eventType: source, - eventsAPI: api, - } -} - -func (w *SandboxAgentWriter) Write(logline []byte) (int, error) { - scanner := bufio.NewScanner(bytes.NewReader(logline)) - scanner.Split(bufio.ScanLines) - for scanner.Scan() { - w.eventsAPI.sendLogEvent(w.eventType, scanner.Text()) - } - return len(logline), nil -} diff --git a/lambda/rapidcore/standalone/telemetry/eventLog.go b/lambda/rapidcore/standalone/telemetry/eventLog.go deleted file mode 100644 index 0ab7c44..0000000 --- a/lambda/rapidcore/standalone/telemetry/eventLog.go +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -type EventLog struct { - Events []SandboxEvent `json:"events,omitempty"` // populated by the StandaloneEventLog object - Traces []TracingEvent `json:"traces,omitempty"` -} - -func NewEventLog() *EventLog { - return &EventLog{} -} diff --git a/lambda/rapidcore/standalone/telemetry/events_api.go b/lambda/rapidcore/standalone/telemetry/events_api.go deleted file mode 100644 index dcac7a3..0000000 --- a/lambda/rapidcore/standalone/telemetry/events_api.go +++ /dev/null @@ -1,293 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "encoding/json" - "sort" - "sync" - "time" - - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/telemetry" -) - -type EventType = string - -const ( - PlatformInitStart = EventType("platform.initStart") - PlatformInitRuntimeDone = EventType("platform.initRuntimeDone") - PlatformInitReport = EventType("platform.initReport") - PlatformRestoreRuntimeDone = EventType("platform.restoreRuntimeDone") - PlatformStart = EventType("platform.start") - PlatformRuntimeDone = EventType("platform.runtimeDone") - PlatformExtension = EventType("platform.extension") - PlatformEnd = EventType("platform.end") - PlatformReport = EventType("platform.report") - PlatformFault = EventType("platform.fault") -) - -/* -SandboxEvent represents a generic sandbox event. For example: - - { - "time": "2021-03-16T13:10:42.358Z", - "type": "platform.extension", - "platformEvent": { "name": "foo bar", "state": "Ready", "events": ["INVOKE", "SHUTDOWN"]} - } - -Or: - - { - "time": "2021-03-16T13:10:42.358Z", - "type": "extension", - "logMessage": "raw agent console output" - } - -FluxPump produces entries with a single field 'record', containing either an object or a string. -We make the distinction explicit by providing separate fields for the two cases, 'PlatformEvent' and 'LogMessage'. -Either one of the two would be populated, but not both. This makes code cleaner, but requires test client to merge -two fields back, producing a single 'record' entry again -- to match the FluxPump format that tests actually check. -*/ -type SandboxEvent struct { - Time string `json:"time"` - Type EventType `json:"type"` - PlatformEvent map[string]interface{} `json:"platformEvent,omitempty"` - LogMessage string `json:"logMessage,omitempty"` -} - -type tailLogs struct { - Events []SandboxEvent `json:"events,omitempty"` -} - -type StandaloneEventsAPI struct { - lock sync.Mutex - requestID interop.RequestID - eventLog EventLog -} - -func (s *StandaloneEventsAPI) LogTrace(entry TracingEvent) { - s.lock.Lock() - defer s.lock.Unlock() - s.eventLog.Traces = append(s.eventLog.Traces, entry) -} - -func (s *StandaloneEventsAPI) EventLog() *EventLog { - return &s.eventLog -} - -func (s *StandaloneEventsAPI) SetCurrentRequestID(requestID interop.RequestID) { - s.requestID = requestID -} - -func (s *StandaloneEventsAPI) SendInitStart(data interop.InitStartData) error { - record := map[string]interface{}{ - "initializationType": data.InitializationType, - "runtimeVersion": data.RuntimeVersion, - "runtimeArn": data.RuntimeVersionArn, - "runtimeVersionArn": data.RuntimeVersionArn, - "functionArn": data.FunctionArn, - "functionName": data.FunctionName, - "functionVersion": data.FunctionVersion, - "instanceId": data.InstanceID, - "instanceMaxMemory": data.InstanceMaxMemory, - "phase": data.Phase, - } - - s.addTracingToRecord(data.Tracing, record) - - return s.sendPlatformEvent(PlatformInitStart, record) -} - -func (s *StandaloneEventsAPI) SendInitRuntimeDone(data interop.InitRuntimeDoneData) error { - record := map[string]interface{}{ - "initializationType": data.InitializationType, - "status": data.Status, - "phase": data.Phase, - } - - s.addTracingToRecord(data.Tracing, record) - - if data.ErrorType != nil { - record["errorType"] = data.ErrorType - } - - return s.sendPlatformEvent(PlatformInitRuntimeDone, record) -} - -func (s *StandaloneEventsAPI) SendInitReport(data interop.InitReportData) error { - record := map[string]interface{}{ - "initializationType": data.InitializationType, - "metrics": data.Metrics, - "phase": data.Phase, - } - - s.addTracingToRecord(data.Tracing, record) - - return s.sendPlatformEvent(PlatformInitReport, record) -} - -func (s *StandaloneEventsAPI) SendRestoreRuntimeDone(data interop.RestoreRuntimeDoneData) error { - record := map[string]interface{}{"status": data.Status} - - s.addTracingToRecord(data.Tracing, record) - - if data.ErrorType != nil { - record["errorType"] = data.ErrorType - } - - return s.sendPlatformEvent(PlatformRestoreRuntimeDone, record) -} - -func (s *StandaloneEventsAPI) SendInvokeStart(data interop.InvokeStartData) error { - record := map[string]interface{}{ - "version": data.Version, - "requestId": data.RequestID, - } - - s.addTracingToRecord(data.Tracing, record) - - return s.sendPlatformEvent(PlatformStart, record) -} - -func (s *StandaloneEventsAPI) SendInvokeRuntimeDone(data interop.InvokeRuntimeDoneData) error { - record := map[string]interface{}{ - "requestId": s.requestID, - "status": data.Status, - "metrics": data.Metrics, - "internalMetrics": data.InternalMetrics, - "spans": data.Spans, - } - - if data.ErrorType != nil { - record["errorType"] = data.ErrorType - } - - s.addTracingToRecord(data.Tracing, record) - - return s.sendPlatformEvent(PlatformRuntimeDone, record) -} - -func (s *StandaloneEventsAPI) SendExtensionInit(data interop.ExtensionInitData) error { - sort.Strings(data.Subscriptions) - record := map[string]interface{}{ - "name": data.AgentName, - "state": data.State, - "events": data.Subscriptions, - } - if len(data.ErrorType) > 0 { - record["errorType"] = data.ErrorType - } - return s.sendPlatformEvent(PlatformExtension, record) -} - -func (s *StandaloneEventsAPI) SendImageErrorLog(interop.ImageErrorLogData) { - // Called on bootstrap exec errors for OCI error modes, e.g. InvalidEntrypoint etc. -} - -func (s *StandaloneEventsAPI) SendEnd(data interop.EndData) error { - record := map[string]interface{}{ - "requestId": data.RequestID, - } - - return s.sendPlatformEvent(PlatformEnd, record) -} - -func (s *StandaloneEventsAPI) SendReportSpan(interop.Span) error { - return nil -} - -func (s *StandaloneEventsAPI) SendReport(data interop.ReportData) error { - record := map[string]interface{}{ - "requestId": s.requestID, - "status": data.Status, - "metrics": data.Metrics, - "spans": data.Spans, - "tracing": data.Tracing, - } - if data.ErrorType != nil { - record["errorType"] = data.ErrorType - } - - return s.sendPlatformEvent(PlatformReport, record) -} - -func (s *StandaloneEventsAPI) SendFault(data interop.FaultData) error { - record := map[string]interface{}{ - "fault": data.String(), - } - - return s.sendPlatformEvent(PlatformFault, record) -} - -func (s *StandaloneEventsAPI) FetchTailLogs(string) (string, error) { - s.lock.Lock() - defer s.lock.Unlock() - - if len(s.eventLog.Events) == 0 { - return "", nil - } - - logs := tailLogs{Events: s.eventLog.Events} - logsBytes, err := json.Marshal(logs) - if err != nil { - return "", err - } - - s.eventLog.Events = nil - - return string(logsBytes), nil -} - -func (s *StandaloneEventsAPI) GetRuntimeDoneSpans( - runtimeStartedTime int64, - invokeResponseMetrics *interop.InvokeResponseMetrics, - runtimeOverheadStartedTime int64, - runtimeReadyTime int64, -) []interop.Span { - spans := telemetry.GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics) - return spans -} - -func (s *StandaloneEventsAPI) sendPlatformEvent(eventType string, record map[string]interface{}) error { - e := SandboxEvent{ - Time: time.Now().Format(time.RFC3339), - Type: eventType, - PlatformEvent: record, - } - s.appendEvent(e) - s.logEvent(e) - return nil -} - -func (s *StandaloneEventsAPI) sendLogEvent(eventType, logMessage string) error { - e := SandboxEvent{ - Time: time.Now().Format(time.RFC3339), - Type: eventType, - LogMessage: logMessage, - } - s.appendEvent(e) - s.logEvent(e) - return nil -} - -func (s *StandaloneEventsAPI) appendEvent(event SandboxEvent) { - s.lock.Lock() - defer s.lock.Unlock() - s.eventLog.Events = append(s.eventLog.Events, event) -} - -func (s *StandaloneEventsAPI) logEvent(e SandboxEvent) { - log.WithField("event", e).Info("sandbox event") -} - -func (s *StandaloneEventsAPI) addTracingToRecord(tracingData *interop.TracingCtx, record map[string]interface{}) { - if tracingData != nil { - record["tracing"] = map[string]string{ - "spanId": tracingData.SpanID, - "type": string(tracingData.Type), - "value": tracingData.Value, - } - } -} diff --git a/lambda/rapidcore/standalone/telemetry/logs_egress_api.go b/lambda/rapidcore/standalone/telemetry/logs_egress_api.go deleted file mode 100644 index 0f42dd1..0000000 --- a/lambda/rapidcore/standalone/telemetry/logs_egress_api.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import "io" - -type StandaloneLogsEgressAPI struct { - api *StandaloneEventsAPI -} - -func NewStandaloneLogsEgressAPI(api *StandaloneEventsAPI) *StandaloneLogsEgressAPI { - return &StandaloneLogsEgressAPI{ - api: api, - } -} - -func (s *StandaloneLogsEgressAPI) GetExtensionSockets() (io.Writer, io.Writer, error) { - w := NewSandboxAgentWriter(s.api, "extension") - return w, w, nil -} - -func (s *StandaloneLogsEgressAPI) GetRuntimeSockets() (io.Writer, io.Writer, error) { - w := NewSandboxAgentWriter(s.api, "function") - return w, w, nil -} diff --git a/lambda/rapidcore/standalone/telemetry/structured_logger.go b/lambda/rapidcore/standalone/telemetry/structured_logger.go deleted file mode 100644 index 8d9382b..0000000 --- a/lambda/rapidcore/standalone/telemetry/structured_logger.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "github.com/sirupsen/logrus" - "os" -) - -var log = getLogger() - -func getLogger() *logrus.Logger { - formatter := logrus.JSONFormatter{} - formatter.DisableTimestamp = true - logger := new(logrus.Logger) - logger.Out = os.Stdout - logger.Formatter = &formatter - logger.Level = logrus.InfoLevel - return logger -} diff --git a/lambda/rapidcore/standalone/telemetry/tracer.go b/lambda/rapidcore/standalone/telemetry/tracer.go deleted file mode 100644 index ba7f32d..0000000 --- a/lambda/rapidcore/standalone/telemetry/tracer.go +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "context" - "encoding/json" - "fmt" - "time" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" - "go.amzn.com/lambda/rapi/model" - "go.amzn.com/lambda/telemetry" - - "github.com/sirupsen/logrus" -) - -// InitSubsegmentName provides name attribute for Init subsegment -const InitSubsegmentName = "Initialization" - -// RestoreSubsegmentName provides name attribute for Restore subsegment -const RestoreSubsegmentName = "Restore" - -// InvokeSubsegmentName provides name attribute for Invoke subsegment -const InvokeSubsegmentName = "Invocation" - -// OverheadSubsegmentName provides name attribute for Overhead subsegment -const OverheadSubsegmentName = "Overhead" - -type StandaloneTracer struct { - startFunction func(ctx context.Context, invoke *interop.Invoke, segmentName string, timestamp int64) - endFunction func(ctx context.Context, invoke *interop.Invoke, segmentName string, timestamp int64) - invoke *interop.Invoke - tracingHeader string - rootTraceID string - parent string - sampled string - lineage string - invocationSubsegmentID string - initStartTime int64 - initEndTime int64 - restoreStartTime int64 - restoreEndTime int64 - restorePresent bool -} - -type TracingEvent struct { - Message string `json:"message"` - TraceID string `json:"trace_id"` - SegmentName string `json:"segment_name"` - SegmentID string `json:"segment_id"` - Timestamp int64 `json:"timestamp"` -} - -func (t *StandaloneTracer) Configure(invoke *interop.Invoke) { - t.invoke = invoke - t.tracingHeader = invoke.TraceID - t.invocationSubsegmentID = "" - t.rootTraceID, t.parent, t.sampled, t.lineage = telemetry.ParseTracingHeader(invoke.TraceID) - if invoke.RestoreDurationNs == 0 { - t.restorePresent = false - } else { - t.restorePresent = true - t.restoreStartTime = metering.MonoToEpoch(invoke.RestoreStartTimeMonotime) - t.restoreEndTime = t.restoreStartTime + invoke.RestoreDurationNs - } -} - -func (t *StandaloneTracer) CaptureInvokeSegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return t.withStartAndEnd(ctx, criticalFunction, "STANDALONE_FUNCTION_NAME") -} - -func (t *StandaloneTracer) CaptureInitSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return t.withStartAndEnd(ctx, criticalFunction, InitSubsegmentName) -} - -func (t *StandaloneTracer) CaptureInvokeSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - t.invocationSubsegmentID = InvokeSubsegmentName - return t.withStartAndEnd(ctx, criticalFunction, InvokeSubsegmentName) -} - -func (t *StandaloneTracer) CaptureOverheadSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return t.withStartAndEnd(ctx, criticalFunction, OverheadSubsegmentName) -} - -func (t *StandaloneTracer) withStartAndEnd(ctx context.Context, criticalFunction func(context.Context) error, segmentName string) error { - ctx = telemetry.NewTraceContext(ctx, t.rootTraceID, segmentName) - t.startFunction(ctx, t.invoke, segmentName, time.Now().UnixNano()) - err := criticalFunction(ctx) - t.endFunction(ctx, t.invoke, segmentName, time.Now().UnixNano()) - return err -} - -func (t *StandaloneTracer) RecordInitStartTime() { - t.initStartTime = time.Now().UnixNano() -} - -func (t *StandaloneTracer) RecordInitEndTime() { - t.initEndTime = time.Now().UnixNano() - -} - -func (t *StandaloneTracer) sendPrepSubsegment(ctx context.Context, subsegmentName string, startTime int64, endTime int64) { - ctx = telemetry.NewTraceContext(ctx, t.rootTraceID, subsegmentName) - t.startFunction(ctx, t.invoke, subsegmentName, startTime) - t.endFunction(ctx, t.invoke, subsegmentName, endTime) -} - -func (t *StandaloneTracer) SendInitSubsegmentWithRecordedTimesOnce(ctx context.Context) { - t.sendPrepSubsegment(ctx, InitSubsegmentName, t.initStartTime, t.initEndTime) -} -func (t *StandaloneTracer) SendRestoreSubsegmentWithRecordedTimesOnce(ctx context.Context) { - if t.restorePresent { - t.sendPrepSubsegment(ctx, RestoreSubsegmentName, t.restoreStartTime, t.restoreEndTime) - } -} -func (t *StandaloneTracer) MarkError(ctx context.Context) {} -func (t *StandaloneTracer) AttachErrorCause(ctx context.Context, errorCause json.RawMessage) {} - -func (t *StandaloneTracer) WithErrorCause(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { - return criticalFunction -} -func (t *StandaloneTracer) WithError(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { - return criticalFunction -} - -func (t *StandaloneTracer) BuildTracingHeader() func(ctx context.Context) string { - // extract root trace ID and parent from context and build the tracing header - return func(ctx context.Context) string { - var parent string - var ok bool - - if parent, ok = ctx.Value(telemetry.DocumentIDKey).(string); !ok || parent == "" { - return t.invoke.TraceID - } - - if t.rootTraceID == "" || t.sampled == "" { - return "" - } - - var tracingHeader = "Root=%s;Parent=%s;Sampled=%s" - - if t.lineage == "" { - return fmt.Sprintf(tracingHeader, t.rootTraceID, parent, t.sampled) - } - - return fmt.Sprintf(tracingHeader+";Lineage=%s", t.rootTraceID, parent, t.sampled, t.lineage) - } -} - -func (t *StandaloneTracer) BuildTracingCtxForStart() *interop.TracingCtx { - if t.rootTraceID == "" || t.sampled != model.XRaySampled { - return nil - } - - return &interop.TracingCtx{ - SpanID: t.parent, - Type: model.XRayTracingType, - Value: telemetry.BuildFullTraceID(t.rootTraceID, t.invoke.LambdaSegmentID, t.sampled), - } -} -func (t *StandaloneTracer) BuildTracingCtxAfterInvokeComplete() *interop.TracingCtx { - if t.rootTraceID == "" || t.sampled != model.XRaySampled || t.invocationSubsegmentID == "" { - return nil - } - - return &interop.TracingCtx{ - SpanID: t.invocationSubsegmentID, - Type: model.XRayTracingType, - Value: t.tracingHeader, - } -} - -func isTracingEnabled(root, parent, sampled string) bool { - return len(root) != 0 && len(parent) != 0 && sampled == "1" -} - -func NewStandaloneTracer(api *StandaloneEventsAPI) *StandaloneTracer { - startCaptureFn := func(ctx context.Context, i *interop.Invoke, segmentName string, timestamp int64) { - root, parent, sampled, _ := telemetry.ParseTracingHeader(i.TraceID) - if isTracingEnabled(root, parent, sampled) { - e := TracingEvent{ - Message: "START", - TraceID: root, - SegmentName: segmentName, - SegmentID: parent, - Timestamp: timestamp / int64(time.Millisecond), - } - api.LogTrace(e) - log.WithFields(logrus.Fields{"trace": e}).Info("sandbox trace") - } - } - - endCaptureFn := func(ctx context.Context, i *interop.Invoke, segmentName string, timestamp int64) { - root, parent, sampled, _ := telemetry.ParseTracingHeader(i.TraceID) - if isTracingEnabled(root, parent, sampled) { - e := TracingEvent{ - Message: "END", - TraceID: root, - SegmentName: "", - SegmentID: parent, - Timestamp: timestamp / int64(time.Millisecond), - } - api.LogTrace(e) - log.WithFields(logrus.Fields{"trace": e}).Info("sandbox trace") - } - } - - return &StandaloneTracer{ - startFunction: startCaptureFn, - endFunction: endCaptureFn, - } -} diff --git a/lambda/rapidcore/standalone/util.go b/lambda/rapidcore/standalone/util.go deleted file mode 100644 index 7ba7420..0000000 --- a/lambda/rapidcore/standalone/util.go +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/rapi/model" -) - -const ( - DoneFailedHTTPCode = 502 -) - -type ErrorType int - -const ( - ClientInvalidRequest ErrorType = iota -) - -func (t ErrorType) String() string { - switch t { - case ClientInvalidRequest: - return "Client.InvalidRequest" - } - return fmt.Sprintf("Cannot stringify standalone.ErrorType.%d", int(t)) -} - -type ResponseWriterProxy struct { - Body []byte - StatusCode int - header http.Header -} - -func (w *ResponseWriterProxy) Header() http.Header { - if w.header == nil { - w.header = http.Header{} - } - return w.header -} - -func (w *ResponseWriterProxy) Write(b []byte) (int, error) { - w.Body = b - return 0, nil -} - -func (w *ResponseWriterProxy) WriteHeader(statusCode int) { - w.StatusCode = statusCode -} - -func (w *ResponseWriterProxy) IsError() bool { - return w.StatusCode != 0 && w.StatusCode/100 != 2 -} - -func readBodyAndUnmarshalJSON(r *http.Request, dst interface{}) *ErrorReply { - bodyBytes, err := io.ReadAll(r.Body) - if err != nil { - return newErrorReply(ClientInvalidRequest, fmt.Sprintf("Failed to read full body: %s", err)) - } - - if err = json.Unmarshal(bodyBytes, dst); err != nil { - return newErrorReply(ClientInvalidRequest, fmt.Sprintf("Invalid json %s: %s", string(bodyBytes), err)) - } - - return nil -} - -type ErrorReply struct { - model.ErrorResponse -} - -type RuntimeErrorReply struct { - Payload []byte -} - -func (e *RuntimeErrorReply) Send(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(500) - w.Write(e.Payload) -} - -func newErrorReply(errType ErrorType, errMsg string) *ErrorReply { - return &ErrorReply{ErrorResponse: model.ErrorResponse{ErrorType: errType.String(), ErrorMessage: errMsg}} -} - -func (e *ErrorReply) Send(w http.ResponseWriter, r *http.Request) { - http.Error(w, e.ErrorType, 400) - bodyJSON, err := json.Marshal(*e) - if err != nil { - http.Error(w, "Invalid format", 500) - log.Errorf("Failed to Marshal(%#v): %s", e, err) - } else { - w.Write(bodyJSON) - } -} - -type SuccessReply struct { - Body []byte -} - -func (s *SuccessReply) Send(w http.ResponseWriter, r *http.Request) { - w.Write(s.Body) -} - -type FailureReply struct { - Body []byte -} - -func (s *FailureReply) Send(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(DoneFailedHTTPCode) - w.Write(s.Body) -} - -type Reply interface { - Send(http.ResponseWriter, *http.Request) -} diff --git a/lambda/rapidcore/standalone/waitUntilInitializedHandler.go b/lambda/rapidcore/standalone/waitUntilInitializedHandler.go deleted file mode 100644 index 95d64ac..0000000 --- a/lambda/rapidcore/standalone/waitUntilInitializedHandler.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "net/http" - - "go.amzn.com/lambda/rapidcore" -) - -func WaitUntilInitializedHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { - err := s.AwaitInitialized() - if err != nil { - switch err { - case rapidcore.ErrInitDoneFailed: - w.WriteHeader(DoneFailedHTTPCode) - case rapidcore.ErrInitResetReceived: - w.WriteHeader(DoneFailedHTTPCode) - } - } - w.WriteHeader(http.StatusOK) -} diff --git a/lambda/rapidcore/standalone/waitUntilReleaseHandler.go b/lambda/rapidcore/standalone/waitUntilReleaseHandler.go deleted file mode 100644 index 1caeb8c..0000000 --- a/lambda/rapidcore/standalone/waitUntilReleaseHandler.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package standalone - -import ( - "net/http" - - "go.amzn.com/lambda/rapidcore" -) - -func WaitUntilReleaseHandler(w http.ResponseWriter, r *http.Request, s InteropServer) { - releaseAwait, err := s.AwaitRelease() - if err != nil { - switch err { - case rapidcore.ErrInvokeDoneFailed: - w.WriteHeader(http.StatusBadGateway) - case rapidcore.ErrReleaseReservationDone: - // TODO return sandbox status when we implement async reset handling - // TODO use http.StatusOK - w.WriteHeader(http.StatusGatewayTimeout) - return - case rapidcore.ErrInitDoneFailed: - w.WriteHeader(DoneFailedHTTPCode) - w.Write(releaseAwait.AsJSON()) - return - } - } - - w.Write(releaseAwait.AsJSON()) -} diff --git a/lambda/supervisor/local_supervisor.go b/lambda/supervisor/local_supervisor.go deleted file mode 100644 index 4405686..0000000 --- a/lambda/supervisor/local_supervisor.go +++ /dev/null @@ -1,303 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package supervisor - -import ( - "context" - "errors" - "fmt" - "os/exec" - "runtime" - "sync" - "syscall" - "time" - - log "github.com/sirupsen/logrus" - "go.amzn.com/lambda/supervisor/model" -) - -// typecheck interface compliance -var _ model.SupervisorClient = (*LocalSupervisor)(nil) - -type process struct { - // pid of the running process - pid int - // channel that can be use to block - // while waiting on process termination. - termination chan struct{} -} - -type LocalSupervisor struct { - events chan model.Event - processMapLock sync.Mutex - processMap map[string]process - freezeThawCycleStart time.Time - - RootPath string -} - -func NewLocalSupervisor() *LocalSupervisor { - return &LocalSupervisor{ - events: make(chan model.Event), - processMap: make(map[string]process), - RootPath: "/", - } -} - -func (*LocalSupervisor) Start(ctx context.Context, req *model.StartRequest) error { - return nil -} -func (*LocalSupervisor) Configure(ctx context.Context, req *model.ConfigureRequest) error { - return nil -} -func (*LocalSupervisor) Exit(ctx context.Context) {} - -func (s *LocalSupervisor) Exec(ctx context.Context, req *model.ExecRequest) error { - if req.Domain != "runtime" { - log.Debug("Exec is a no op if domain != runtime") - return nil - } - command := exec.Command(req.Path, req.Args...) - - if req.Env != nil { - envStrings := make([]string, 0, len(*req.Env)) - for key, value := range *req.Env { - envStrings = append(envStrings, key+"="+value) - } - command.Env = envStrings - } - - if req.Cwd != nil && *req.Cwd != "" { - command.Dir = *req.Cwd - } - - if req.ExtraFiles != nil { - command.ExtraFiles = *req.ExtraFiles - } - - command.Stdout = req.StdoutWriter - command.Stderr = req.StderrWriter - - command.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - - err := command.Start() - - if err != nil { - return err - // TODO Use supevisor specific error - } - - pid := command.Process.Pid - termination := make(chan struct{}) - s.processMapLock.Lock() - s.processMap[req.Name] = process{ - pid: pid, - termination: termination, - } - s.processMapLock.Unlock() - - // The first freeze thaw cycle starts on Exec() at init time - s.freezeThawCycleStart = time.Now() - - go func() { - err = command.Wait() - // close the termination channel to unblock whoever's blocked on - // it (used to implement kill's blocking behaviour) - close(termination) - - var cell int32 - var exitStatus *int32 - var signo *int32 - var exitErr *exec.ExitError - - if err == nil { - exitStatus = &cell - } else if errors.As(err, &exitErr) { - if status, ok := exitErr.Sys().(syscall.WaitStatus); ok { - if code := status.ExitStatus(); code >= 0 { - cell = int32(code) - exitStatus = &cell - } else { - cell = int32(status.Signal()) - signo = &cell - } - } - } - - if signo == nil && exitStatus == nil { - log.Error("Cannot convert process exit status to unix WaitStatus. This is unexpected. Assuming ExitStatus 1") - cell = 1 - exitStatus = &cell - } - s.events <- model.Event{ - Time: uint64(time.Now().UnixMilli()), - Event: model.EventData{ - Domain: &req.Domain, - Name: &req.Name, - Signo: signo, - ExitStatus: exitStatus, - }, - } - }() - - return nil -} - -func kill(p process, name string, deadline time.Time) error { - // kill should report success if the process terminated by the time - //supervisor receives the request. - select { - // if this case is selected, the channel is closed, - // which means the process is terminated - case <-p.termination: - log.Debugf("Process %s already terminated.", name) - return nil - default: - log.Infof("Sending SIGKILL to %s(%d).", name, p.pid) - } - - if (time.Since(deadline)) > 0 { - return fmt.Errorf("invalid timeout while killing %s", name) - } - - pgid, err := syscall.Getpgid(p.pid) - - if err == nil { - // Negative pid sends signal to all in process group - syscall.Kill(-pgid, syscall.SIGKILL) - } else { - syscall.Kill(p.pid, syscall.SIGKILL) - } - - ctx, cancel := context.WithDeadline(context.Background(), deadline) - defer cancel() - - // block until the (main) process exits - // or the timeout fires - select { - case <-p.termination: - return nil - case <-ctx.Done(): - return fmt.Errorf("timed out while trying to SIGKILL %s", name) - } -} - -func (s *LocalSupervisor) Kill(ctx context.Context, req *model.KillRequest) error { - if req.Domain != "runtime" { - log.Debug("Kill is a no op if domain != runtime") - return nil - } - s.processMapLock.Lock() - process, ok := s.processMap[req.Name] - s.processMapLock.Unlock() - if !ok { - msg := "Unknown process" - return &model.SupervisorError{ - Kind: model.NoSuchEntity, - Message: &msg, - } - } - - return kill(process, req.Name, req.Deadline) -} - -func (s *LocalSupervisor) Terminate(ctx context.Context, req *model.TerminateRequest) error { - if req.Domain != "runtime" { - log.Debug("Terminate is no op if domain != runtime") - return nil - } - s.processMapLock.Lock() - process, ok := s.processMap[req.Name] - pid := process.pid - s.processMapLock.Unlock() - if !ok { - msg := "Unknown process" - err := &model.SupervisorError{ - Kind: model.NoSuchEntity, - Message: &msg, - } - log.WithError(err).Errorf("Process %s not found in local supervisor map", req.Name) - return err - } - - pgid, err := syscall.Getpgid(pid) - - if err == nil { - // Negative pid sends signal to all in process group - // best effort, ignore errors - _ = syscall.Kill(-pgid, syscall.SIGTERM) - } else { - _ = syscall.Kill(pid, syscall.SIGTERM) - } - - return nil -} - -func (s *LocalSupervisor) Stop(ctx context.Context, req *model.StopRequest) (*model.StopResponse, error) { - if req.Domain != "runtime" { - log.Debug("Shutdown is no op if domain != runtime") - return &model.StopResponse{}, nil - } - - // shut down kills all the processes in the map - s.processMapLock.Lock() - defer s.processMapLock.Unlock() - - nprocs := len(s.processMap) - - successes := make(chan struct{}) - errors := make(chan error) - for name, proc := range s.processMap { - go func(n string, p process) { - log.Debugf("Killing %s", n) - err := kill(p, n, req.Deadline) - if err != nil { - errors <- err - } else { - successes <- struct{}{} - } - - }(name, proc) - } - - var err error - for i := 0; i < nprocs; i++ { - select { - case <-successes: - case e := <-errors: - if err == nil { - err = fmt.Errorf("shutdown failed: %s", e.Error()) - } - } - - } - - s.processMap = make(map[string]process) - return nil, err -} - -func (s *LocalSupervisor) Freeze(ctx context.Context, req *model.FreezeRequest) (*model.FreezeResponse, error) { - // We return mocked freeze/thaw cycle metrics to mimic usage metrics in standalone mode - var m runtime.MemStats - runtime.ReadMemStats(&m) - return &model.FreezeResponse{ - CycleDeltaMetrics: model.CycleDeltaMetrics{ - DomainCPURunNs: uint64(time.Since(s.freezeThawCycleStart).Nanoseconds()), - DomainRunNs: uint64(time.Since(s.freezeThawCycleStart).Nanoseconds()), - DomainMaxMemoryUsageBytes: m.Alloc, - MicrovmCPURunNs: uint64(time.Since(s.freezeThawCycleStart).Nanoseconds()), - }, - }, nil -} -func (s *LocalSupervisor) Thaw(ctx context.Context, req *model.ThawRequest) error { - s.freezeThawCycleStart = time.Now() - return nil -} -func (s *LocalSupervisor) Ping(ctx context.Context) error { - return nil -} - -func (s *LocalSupervisor) Events(ctx context.Context, req *model.EventsRequest) (<-chan model.Event, error) { - return s.events, nil -} diff --git a/lambda/supervisor/local_supervisor_test.go b/lambda/supervisor/local_supervisor_test.go deleted file mode 100644 index 02a06f6..0000000 --- a/lambda/supervisor/local_supervisor_test.go +++ /dev/null @@ -1,238 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package supervisor - -import ( - "context" - "errors" - "fmt" - "syscall" - "testing" - "time" - - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.amzn.com/lambda/supervisor/model" -) - -func TestRuntimeDomainExec(t *testing.T) { - supv := NewLocalSupervisor() - err := supv.Exec(context.Background(), &model.ExecRequest{ - Domain: "runtime", - Name: "agent", - Path: "/bin/bash", - }) - - assert.Nil(t, err) -} - -func TestInvalidRuntimeDomainExec(t *testing.T) { - supv := NewLocalSupervisor() - err := supv.Exec(context.Background(), &model.ExecRequest{ - Domain: "runtime", - Name: "agent", - Path: "/bin/none", - }) - - require.Error(t, err) -} - -func TestEvents(t *testing.T) { - supv := NewLocalSupervisor() - sync := make(chan struct{}) - go func() { - eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ - Domain: "runtime", - }) - require.NoError(t, err) - - evt, ok := <-eventCh - require.True(t, ok) - termination := evt.Event.ProcessTerminated() - require.NotNil(t, termination) - assert.Equal(t, "runtime", *termination.Domain) - assert.Equal(t, "agent", *termination.Name) - sync <- struct{}{} - }() - - err := supv.Exec(context.Background(), &model.ExecRequest{ - Domain: "runtime", - Name: "agent", - Path: "/bin/bash", - }) - require.NoError(t, err) - <-sync -} - -func TestTerminate(t *testing.T) { - supv := NewLocalSupervisor() - err := supv.Exec(context.Background(), &model.ExecRequest{ - Domain: "runtime", - Name: "agent", - Path: "/bin/bash", - Args: []string{"-c", "sleep 10s"}, - }) - require.NoError(t, err) - time.Sleep(100 * time.Millisecond) - err = supv.Terminate(context.Background(), &model.TerminateRequest{ - Domain: "runtime", - Name: "agent", - }) - require.NoError(t, err) - // wait for process exit notification - eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ - Domain: "runtime", - }) - require.NoError(t, err) - ev := <-eventCh - - require.NotNil(t, ev.Event.ProcessTerminated()) - term := *ev.Event.ProcessTerminated() - require.Nil(t, term.Exited()) - require.NotNil(t, term.Signaled()) - require.EqualValues(t, syscall.SIGTERM, *term.Signo) -} - -// Termiante should not fail if the message is not delivered -func TestTerminateExited(t *testing.T) { - supv := NewLocalSupervisor() - err := supv.Exec(context.Background(), &model.ExecRequest{ - Domain: "runtime", - Name: "agent", - Path: "/bin/bash", - }) - require.NoError(t, err) - // wait a short bit for bash to exit - time.Sleep(100 * time.Millisecond) - err = supv.Terminate(context.Background(), &model.TerminateRequest{ - Domain: "runtime", - Name: "agent", - }) - require.NoError(t, err) -} - -func TestKill(t *testing.T) { - supv := NewLocalSupervisor() - err := supv.Exec(context.Background(), &model.ExecRequest{ - Domain: "runtime", - Name: "agent", - Path: "/bin/bash", - Args: []string{"-c", "sleep 10s"}, - }) - require.NoError(t, err) - err = supv.Kill(context.Background(), &model.KillRequest{ - Domain: "runtime", - Name: "agent", - Deadline: time.Now().Add(time.Second), - }) - require.NoError(t, err) - timer := time.NewTimer(50 * time.Millisecond) - eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ - Domain: "runtime", - }) - require.NoError(t, err) - - select { - case _, ok := <-eventCh: - assert.True(t, ok) - case <-timer.C: - require.Fail(t, "Process should have exited by the time kill returns") - } -} - -func TestKillExited(t *testing.T) { - supv := NewLocalSupervisor() - err := supv.Exec(context.Background(), &model.ExecRequest{ - Domain: "runtime", - Name: "agent", - Path: "/bin/bash", - }) - require.NoError(t, err) - //wait for natural exit event - eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ - Domain: "runtime", - }) - require.NoError(t, err) - <-eventCh - err = supv.Kill(context.Background(), &model.KillRequest{ - Domain: "runtime", - Name: "agent", - Deadline: time.Now().Add(time.Second), - }) - require.NoError(t, err, "Kill should succeed for exited processes") -} - -func TestKillUnknown(t *testing.T) { - supv := NewLocalSupervisor() - err := supv.Kill(context.Background(), &model.KillRequest{ - Domain: "runtime", - Name: "unknown", - Deadline: time.Now().Add(time.Second), - }) - require.Error(t, err) - var supvError *model.SupervisorError - assert.True(t, errors.As(err, &supvError)) - assert.Equal(t, supvError.Kind, model.NoSuchEntity) -} - -func TestShutdown(t *testing.T) { - supv := NewLocalSupervisor() - log.Debug("hello") - // start a bunch of processes, some short running, some longer running - err := supv.Exec(context.Background(), &model.ExecRequest{ - Domain: "runtime", - Name: "agent-0", - Path: "/bin/bash", - Args: []string{"-c", "sleep 1s"}, - }) - require.NoError(t, err) - - err = supv.Exec(context.Background(), &model.ExecRequest{ - Domain: "runtime", - Name: "agent-1", - Path: "/bin/bash", - }) - require.NoError(t, err) - - err = supv.Exec(context.Background(), &model.ExecRequest{ - Domain: "runtime", - Name: "agent-2", - Path: "/bin/bash", - Args: []string{"-c", "sleep 2s"}, - }) - require.NoError(t, err) - time.Sleep(100 * time.Millisecond) - _, err = supv.Stop(context.Background(), &model.StopRequest{ - Domain: "runtime", - Deadline: time.Now().Add(time.Second), - }) - require.NoError(t, err) - // Shutdown is expected to block untill all processes have exited - expected := map[string]struct{}{ - "agent-0": {}, - "agent-1": {}, - "agent-2": {}, - } - done := false - timer := time.NewTimer(200 * time.Millisecond) - eventCh, err := supv.Events(context.Background(), &model.EventsRequest{ - Domain: "runtime", - }) - require.NoError(t, err) - for !done { - select { - case ev := <-eventCh: - data := ev.Event.ProcessTerminated() - assert.NotNil(t, data) - _, ok := expected[*data.Name] - assert.True(t, ok) - delete(expected, *data.Name) - case <-timer.C: - fmt.Print(expected) - assert.Equal(t, 0, len(expected), "All process should terminate at shutdown") - done = true - } - } -} diff --git a/lambda/supervisor/model/model.go b/lambda/supervisor/model/model.go deleted file mode 100644 index d89ec18..0000000 --- a/lambda/supervisor/model/model.go +++ /dev/null @@ -1,376 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package model - -import ( - "context" - "encoding/json" - "fmt" - "io" - "os" - "syscall" - "time" -) - -// Start, Stop and Configure methods are not used in Core anymore. -// Client interface splitted into Launcher and Executer parts for backward compatibility of dependent packages. -type ContainerSupervisor interface { - Start(context.Context, *StartRequest) error - Configure(context.Context, *ConfigureRequest) error - Stop(context.Context, *StopRequest) (*StopResponse, error) - Freeze(context.Context, *FreezeRequest) (*FreezeResponse, error) - Thaw(context.Context, *ThawRequest) error - Exit(context.Context) -} - -type ProcessSupervisor interface { - Exec(context.Context, *ExecRequest) error - Terminate(context.Context, *TerminateRequest) error - Kill(context.Context, *KillRequest) error - Events(context.Context, *EventsRequest) (<-chan Event, error) -} - -type SupervisorClient interface { - ContainerSupervisor - ProcessSupervisor - Ping(ctx context.Context) error -} - -type StartRequest struct { - Domain string `json:"domain"` -} - -type Mount struct { - DriveMount DriveMount - BindMount BindMount - MountType MountType -} - -type MountType int - -const ( - _ MountType = iota - MountTypeDrive - MountTypeBind -) - -type CgroupProfileName string - -const ( - Throttled CgroupProfileName = "throttled" - Unthrottled CgroupProfileName = "unthrottled" -) - -func (m *Mount) MarshalJSON() ([]byte, error) { - switch m.MountType { - case MountTypeDrive: - return m.DriveMount.MarshalJSON() - case MountTypeBind: - return m.BindMount.MarshalJSON() - default: - return nil, fmt.Errorf("invalid mount type: %v", m.MountType) - } -} - -// Mount in lockhard::mnt is a Rust enum, an algebraic type, where each case has different set of fields. -// This models only the Mount::Drive case, the only one we need for now. -type DriveMount struct { - Source string `json:"source,omitempty"` - Destination string `json:"destination,omitempty"` - FsType string `json:"fs_type,omitempty"` - Options []string `json:"options,omitempty"` - Chowner []uint32 `json:"chowner,omitempty"` // array of two integers representing a tuple - Chmode uint32 `json:"chmode,omitempty"` - // Lockhard also expects a "type" field here, which in our case is constant, so we provide it upon serialization below -} - -// Adds the "type": "drive" to json -func (m *DriveMount) MarshalJSON() ([]byte, error) { - type driveMountAlias DriveMount - - return json.Marshal(&struct { - Type string `json:"type,omitempty"` - *driveMountAlias - }{ - Type: "drive", - driveMountAlias: (*driveMountAlias)(m), - }) -} - -type BindMount struct { - Source string `json:"source,omitempty"` - Destination string `json:"destination,omitempty"` - Options []string `json:"options,omitempty"` -} - -func (m *BindMount) MarshalJSON() ([]byte, error) { - type bindMountAlias BindMount - - return json.Marshal(&struct { - Type string `json:"type,omitempty"` - *bindMountAlias - }{ - Type: "bind", - bindMountAlias: (*bindMountAlias)(m), - }) -} - -type Capabilities struct { - Ambient []string `json:"ambient,omitempty"` - Bounding []string `json:"bounding,omitempty"` - Effective []string `json:"effective,omitempty"` - Inheritable []string `json:"inheritable,omitempty"` - Permitted []string `json:"permitted,omitempty"` -} - -type CgroupProfiles struct { - Throttled CgroupProfileConfig `json:"throttled"` - Unthrottled CgroupProfileConfig `json:"unthrottled"` -} - -type CgroupProfileConfig struct { - CPULimit float64 `json:"cpu_limit"` - MemoryLimitBytes uint64 `json:"memory_limit_bytes"` -} - -type ExecUser struct { - UID *uint32 `json:"uid"` - GID *uint32 `json:"gid"` -} - -type ConfigureRequest struct { - // domain to configure - Domain string `json:"domain"` - Mounts []Mount `json:"mounts,omitempty"` - Capabilities *Capabilities `json:"capabilities,omitempty"` - SeccompFilters []string `json:"seccomp_filters,omitempty"` - // list of cgroup profiles available for the domain - // cgroup profiles are set on start and thaw request. Start profile - // if configured (as it can vary), thaw profile is always the same (throttled) - CgroupProfiles *CgroupProfiles `json:"cgroup_profiles,omitempty"` - // name of the cgroup profile to enforce at domain start - StartProfile CgroupProfileName `json:"start_profile,omitempty"` - // uid and gid of the user the spawned process runs as (w.r.t. the domain user namespace). - // If nil, Supervisor will use the ExecUser specified in the domain configuration file - ExecUser *ExecUser `json:"exec_user,omitempty"` - // additional hooks to execute on domain start - AdditionalStartHooks []Hook `json:"additional_start_hooks,omitempty"` -} - -type EventsRequest struct { - Domain string `json:"domain"` -} - -type Event struct { - Time uint64 `json:"timestamp_millis"` - Event EventData `json:"event"` -} - -// EventData is a union type tagged by the "EventType" -// and "Cause" strings. -// you can use ProcessTermination() or EventLoss() to access -// the correct type of Event. -type EventData struct { - EvType string `json:"type"` - Domain *string `json:"domain"` - Name *string `json:"name"` - Cause *string `json:"cause"` - Signo *int32 `json:"signo"` - ExitStatus *int32 `json:"exit_status"` - Size *uint64 `json:"size"` -} - -// returns nil if the event is not a EventLoss event -// otherwise returns how many events were lost due to -// backpressure (slow reader) -func (d EventData) EventLoss() *uint64 { - return d.Size -} - -// Returns a ProcessTermination struct that describe the process -// which terminated. Use Signaled() or Exited() to check whether -// the process terminated because of a signal or exited on its own -func (d EventData) ProcessTerminated() *ProcessTermination { - if d.Signo != nil || d.ExitStatus != nil { - return &ProcessTermination{ - Domain: d.Domain, - Name: d.Name, - Signo: d.Signo, - ExitStatus: d.ExitStatus, - } - } - return nil -} - -// Event signalling that a process exited -type ProcessTermination struct { - Domain *string - Name *string - Signo *int32 - ExitStatus *int32 -} - -// If not nil, the process was terminated by an unhandled signal. -// The returned value is the number of the signal that terminated the process -func (t ProcessTermination) Signaled() *int32 { - return t.Signo -} - -// It not nil, the process exited (as opposed to killed by a signal). -// The returned value is the exit_status returned by the process -func (t ProcessTermination) Exited() *int32 { - return t.ExitStatus -} - -func (t ProcessTermination) Success() bool { - return t.ExitStatus != nil && *t.ExitStatus == 0 -} - -// Transform the process termination status in a string that -// is equal to what would be returned by golang exec.ExitError.Error() -// We used to rely on this format to report errors to customer (sigh) -// so we keep this for backwards compatibility -func (t ProcessTermination) String() string { - if t.ExitStatus != nil { - return fmt.Sprintf("exit status %d", *t.ExitStatus) - } - sig := syscall.Signal(*t.Signo) - return fmt.Sprintf("signal: %s", sig.String()) -} - -type Hook struct { - // Unique name identifying the hook - Name string `json:"name"` - // Path in the parent domain mount namespace that locates - // the executable to run as the hook - Path string `json:"path"` - // Args for the hook - Args []string `json:"args,omitempty"` - // Map of ENV variables to set when running the hook - Env *map[string]string `json:"envs,omitempty"` -} - -type ExecRequest struct { - // Identifier that Supervisor will assign to the spawned process. - // The tuple (Domain,Name) must be unique. It is the caller's responsibility - // to generate the unique name - Name string `json:"name"` - Domain string `json:"domain"` - // Path pointing to the exectuable file within the domain's root filesystem - Path string `json:"path"` - Args []string `json:"args,omitempty"` - // If nil, root of the domain - Cwd *string `json:"cwd,omitempty"` - Env *map[string]string `json:"env,omitempty"` - Logging Logging `json:"log_config"` - StdoutWriter io.Writer `json:"-"` - StderrWriter io.Writer `json:"-"` - ExtraFiles *[]*os.File `json:"-"` -} - -// Logging specifies where Supervisor should send Command's logs to -type Logging struct { - Managed ManagedLogging `json:"managed"` -} - -type ManagedLogging struct { - Topic ManagedLoggingTopic `json:"topic"` - Formats []ManagedLoggingFormat `json:"formats"` -} - -type ManagedLoggingTopic string - -const ( - RuntimeManagedLoggingTopic ManagedLoggingTopic = "runtime" - RtExtensionManagedLoggingTopic ManagedLoggingTopic = "runtime_extension" -) - -type ManagedLoggingFormat string - -const ( - LineBasedManagedLogging ManagedLoggingFormat = "line" - MessageBasedManagedLogging ManagedLoggingFormat = "message" -) - -type ErrorKind string - -const ( - // operation on an unkown entity (e.g., domain process) - NoSuchEntity ErrorKind = "no_such_entity" - // operation not allowed in the current state (e.g., tried to exec a proces in a domain which is not booted) - InvalidState ErrorKind = "invalid_state" - // Serialization or derserialization issue in the communication - Serde ErrorKind = "serde" - // Unhandled Supervisor server error - Failure ErrorKind = "failure" -) - -type SupervisorError struct { - Kind ErrorKind `json:"error_kind"` - Message *string `json:"message"` -} - -func (e *SupervisorError) Error() string { - return string(e.Kind) -} - -// Send SIGETERM asynchrnously to a process -type TerminateRequest struct { - Name string `json:"name"` - Domain string `json:"domain"` -} - -// Force terminate a process (SIGKILL) -// Block until process is exited or timeout -// Deadline needs to be in the future -type KillRequest struct { - Name string `json:"name"` - Domain string `json:"domain"` - Deadline time.Time `json:"deadline"` -} - -// Stop the domain. -type StopRequest struct { - Domain string `json:"domain"` - Deadline time.Time `json:"deadline"` -} - -type StopResponse struct { - CycleDeltaMetrics CycleDeltaMetrics `json:"cycle_delta_metrics"` -} - -type FreezeRequest struct { - Domain string `json:"domain"` -} - -type FreezeResponse struct { - CycleDeltaMetrics CycleDeltaMetrics `json:"cycle_delta_metrics"` -} - -type MicrovmNetworkInterfaceMetrics struct { - ReceivedBytes uint64 `json:"received_bytes"` - TransmittedBytes uint64 `json:"transmitted_bytes"` -} - -type CycleDeltaMetrics struct { - // CPU time (in nanoseconds) obtained by domain cgroup from cpuacct.usage - // https://www.kernel.org/doc/Documentation/cgroup-v1/cpuacct.txt - DomainCPURunNs uint64 `json:"domain_cpu_run_ns"` - // time (in nanoseconds) for domain cycle - DomainRunNs uint64 `json:"domain_run_ns"` - // CPU delta time for service cgroup - ServiceCPURunNs uint64 `json:"service_cpu_run_ns"` - // Maximum memory used (in bytes) for domain - DomainMaxMemoryUsageBytes uint64 `json:"domain_max_memory_usage_bytes"` - // CPU delta time (in nanoseconds) obtained from /sys/fs/cgroup/cpu,cpuacct/cpuacct.usage - MicrovmCPURunNs uint64 `json:"microvm_cpu_run_ns"` - // Map with network interface name as key and network metrics as a value - MicrovmNetworksBytes map[string]MicrovmNetworkInterfaceMetrics `json:"microvm_network_interfaces"` - // time ( in nanoseconds ) for idle cpu time - InvokeIdleCPURunNs uint64 `json:"idle_cpu_run_ns"` -} - -type ThawRequest struct { - Domain string `json:"domain"` -} diff --git a/lambda/supervisor/model/model_test.go b/lambda/supervisor/model/model_test.go deleted file mode 100644 index ea39580..0000000 --- a/lambda/supervisor/model/model_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package model - -import ( - "encoding/json" - "testing" - "time" -) - -// LockHard accepts deadlines encoded as RFC3339 - we enforce this with a test -func Test_KillDeadlineIsMarshalledIntoRFC3339(t *testing.T) { - deadline, err := time.Parse(time.RFC3339, "2022-12-21T10:00:00Z") - if err != nil { - t.Error(err) - } - k := KillRequest{ - Name: "", - Domain: "", - Deadline: deadline, - } - bytes, err := json.Marshal(k) - if err != nil { - t.Error(err) - } - exepected := `{"name":"","domain":"","deadline":"2022-12-21T10:00:00Z"}` - if string(bytes) != exepected { - t.Errorf("error in marshaling `KillRequest` it does not match the expected string (Expected(%q) != Got(%q))", exepected, string(bytes)) - } -} diff --git a/lambda/telemetry/constants.go b/lambda/telemetry/constants.go deleted file mode 100644 index 0198660..0000000 --- a/lambda/telemetry/constants.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import "errors" - -const ( - // Metrics - SubscribeSuccess = "logs_api_subscribe_success" - SubscribeClientErr = "logs_api_subscribe_client_err" - SubscribeServerErr = "logs_api_subscribe_server_err" - NumSubscribers = "logs_api_num_subscribers" -) - -// ErrTelemetryServiceOff returned on attempt to subscribe after telemetry service has been turned off. -var ErrTelemetryServiceOff = errors.New("ErrTelemetryServiceOff") diff --git a/lambda/telemetry/events_api.go b/lambda/telemetry/events_api.go deleted file mode 100644 index 371f439..0000000 --- a/lambda/telemetry/events_api.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "fmt" - "time" - - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" -) - -func GetRuntimeDoneInvokeMetrics(runtimeStartedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics, runtimeDoneTime int64) *interop.RuntimeDoneInvokeMetrics { - // time taken from sending the invoke to the sandbox until the runtime calls GET /next - duration := CalculateDuration(runtimeStartedTime, runtimeDoneTime) - if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && runtimeStartedTime != -1 { - return &interop.RuntimeDoneInvokeMetrics{ - ProducedBytes: invokeResponseMetrics.ProducedBytes, - DurationMs: duration, - } - } - - // when we get a reset before runtime called /response - if runtimeStartedTime != -1 { - return &interop.RuntimeDoneInvokeMetrics{ - ProducedBytes: int64(0), - DurationMs: duration, - } - } - - // We didn't have time to register the invokeReceiveTime, which means we crash/reset very early, - // too early for the runtime to actual run. In such case, the runtimeDone event shouldn't be sent - // Not returning Nil even in this improbable case guarantees that we will always have some metrics to send to FluxPump - return &interop.RuntimeDoneInvokeMetrics{ - ProducedBytes: int64(0), - DurationMs: float64(0), - } -} - -const ( - InitInsideInitPhase interop.InitPhase = "init" - InitInsideInvokePhase interop.InitPhase = "invoke" -) - -func InitPhaseFromLifecyclePhase(phase interop.LifecyclePhase) (interop.InitPhase, error) { - switch phase { - case interop.LifecyclePhaseInit: - return InitInsideInitPhase, nil - case interop.LifecyclePhaseInvoke: - return InitInsideInvokePhase, nil - default: - return interop.InitPhase(""), fmt.Errorf("unexpected lifecycle phase: %v", phase) - } -} - -func GetRuntimeDoneSpans(runtimeStartedTime int64, invokeResponseMetrics *interop.InvokeResponseMetrics) []interop.Span { - if invokeResponseMetrics != nil && invokeResponseMetrics.RuntimeCalledResponse && runtimeStartedTime != -1 { - // time span from when the invoke is received in the sandbox to the moment the runtime calls PUT /response - responseLatencyMsSpan := interop.Span{ - Name: "responseLatency", - Start: GetEpochTimeInISO8601FormatFromMonotime(runtimeStartedTime), - DurationMs: CalculateDuration(runtimeStartedTime, invokeResponseMetrics.StartReadingResponseMonoTimeMs), - } - - // time span from when the runtime called PUT /response to the moment the body of the response is fully sent - responseDurationMsSpan := interop.Span{ - Name: "responseDuration", - Start: GetEpochTimeInISO8601FormatFromMonotime(invokeResponseMetrics.StartReadingResponseMonoTimeMs), - DurationMs: CalculateDuration(invokeResponseMetrics.StartReadingResponseMonoTimeMs, invokeResponseMetrics.FinishReadingResponseMonoTimeMs), - } - return []interop.Span{responseLatencyMsSpan, responseDurationMsSpan} - } - - return []interop.Span{} -} - -// CalculateDuration calculates duration between two moments. -// The result is milliseconds with microsecond precision. -// Two assumptions here: -// 1. the passed values are nanoseconds -// 2. endNs > startNs -func CalculateDuration(startNs, endNs int64) float64 { - microseconds := int64(endNs-startNs) / int64(time.Microsecond) - return float64(microseconds) / 1000 -} - -const ( - InitTypeOnDemand interop.InitType = "on-demand" - InitTypeProvisionedConcurrency interop.InitType = "provisioned-concurrency" - InitTypeInitCaching interop.InitType = "snap-start" -) - -func InferInitType(initCachingEnabled bool, sandboxType interop.SandboxType) interop.InitType { - initSource := InitTypeOnDemand - - // ToDo: Unify this selection of SandboxType by using the START message - // after having a roadmap on the combination of INIT modes - if initCachingEnabled { - initSource = InitTypeInitCaching - } else if sandboxType == interop.SandboxPreWarmed { - initSource = InitTypeProvisionedConcurrency - } - - return initSource -} - -func GetEpochTimeInISO8601FormatFromMonotime(monotime int64) string { - return time.Unix(0, metering.MonoToEpoch(monotime)).Format("2006-01-02T15:04:05.000Z") -} - -const ( - RuntimeDoneSuccess = "success" - RuntimeDoneError = "error" -) - -type NoOpEventsAPI struct{} - -func (s *NoOpEventsAPI) SetCurrentRequestID(interop.RequestID) {} - -func (s *NoOpEventsAPI) SendInitStart(interop.InitStartData) error { return nil } - -func (s *NoOpEventsAPI) SendInitRuntimeDone(interop.InitRuntimeDoneData) error { return nil } - -func (s *NoOpEventsAPI) SendInitReport(interop.InitReportData) error { return nil } - -func (s *NoOpEventsAPI) SendRestoreRuntimeDone(interop.RestoreRuntimeDoneData) error { return nil } - -func (s *NoOpEventsAPI) SendInvokeStart(interop.InvokeStartData) error { return nil } - -func (s *NoOpEventsAPI) SendInvokeRuntimeDone(interop.InvokeRuntimeDoneData) error { return nil } - -func (s *NoOpEventsAPI) SendExtensionInit(interop.ExtensionInitData) error { return nil } - -func (s *NoOpEventsAPI) SendEnd(interop.EndData) error { return nil } - -func (s *NoOpEventsAPI) SendReportSpan(interop.Span) error { return nil } - -func (s *NoOpEventsAPI) SendReport(interop.ReportData) error { return nil } - -func (s *NoOpEventsAPI) SendFault(interop.FaultData) error { return nil } - -func (s *NoOpEventsAPI) SendImageErrorLog(interop.ImageErrorLogData) {} - -func (s *NoOpEventsAPI) FetchTailLogs(string) (string, error) { return "", nil } - -func (s *NoOpEventsAPI) GetRuntimeDoneSpans( - runtimeStartedTime int64, - invokeResponseMetrics *interop.InvokeResponseMetrics, - runtimeOverheadStartedTime int64, - runtimeReadyTime int64, -) []interop.Span { - return []interop.Span{} -} diff --git a/lambda/telemetry/events_api_test.go b/lambda/telemetry/events_api_test.go deleted file mode 100644 index f69e4ea..0000000 --- a/lambda/telemetry/events_api_test.go +++ /dev/null @@ -1,212 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/metering" -) - -func TestGetRuntimeDoneInvokeMetrics(t *testing.T) { - now := metering.Monotime() - - runtimeStartedTime := now - invokeResponseMetrics := &interop.InvokeResponseMetrics{ - ProducedBytes: int64(100), - RuntimeCalledResponse: true, - } - runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) - - expected := &interop.RuntimeDoneInvokeMetrics{ - ProducedBytes: int64(100), - DurationMs: float64(10), - } - - assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(runtimeStartedTime, invokeResponseMetrics, runtimeDoneTime)) -} - -func TestGetRuntimeDoneInvokeMetricsWhenRuntimeCalledError(t *testing.T) { - now := metering.Monotime() - - runtimeStartedTime := now - invokeResponseMetrics := &interop.InvokeResponseMetrics{ - ProducedBytes: int64(100), - RuntimeCalledResponse: false, - } - // validating microsecond precision - runtimeDoneTime := now + int64(time.Duration(10)*time.Millisecond+time.Duration(50)*time.Microsecond) - - expected := &interop.RuntimeDoneInvokeMetrics{ - ProducedBytes: int64(0), - DurationMs: float64(10.05), - } - - assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(runtimeStartedTime, invokeResponseMetrics, runtimeDoneTime)) -} - -func TestGetRuntimeDoneInvokeMetricsWhenRuntimeStartedTimeIsMinusOne(t *testing.T) { - now := int64(-1) - runtimeStartedTime := now - - runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) - - expected := &interop.RuntimeDoneInvokeMetrics{ - ProducedBytes: int64(0), - DurationMs: float64(0), - } - actual := GetRuntimeDoneInvokeMetrics(runtimeStartedTime, nil, runtimeDoneTime) - assert.Equal(t, expected, actual) -} - -func TestGetRuntimeDoneInvokeMetricsWhenInvokeResponseMetricsIsNil(t *testing.T) { - now := metering.Monotime() - runtimeStartedTime := now - - runtimeDoneTime := now + int64(time.Millisecond*time.Duration(10)) - - expected := &interop.RuntimeDoneInvokeMetrics{ - ProducedBytes: int64(0), - DurationMs: float64(10), - } - - assert.Equal(t, expected, GetRuntimeDoneInvokeMetrics(runtimeStartedTime, nil, runtimeDoneTime)) -} - -func TestGetRuntimeDoneSpans(t *testing.T) { - now := metering.Monotime() - startReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(5)) - finishReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(7)) - - runtimeStartedTime := now - invokeResponseMetrics := &interop.InvokeResponseMetrics{ - StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, - FinishReadingResponseMonoTimeMs: finishReadingResponseMonoTimeMs, - RuntimeCalledResponse: true, - } - - expectedResponseLatencyMsStartTime := GetEpochTimeInISO8601FormatFromMonotime(now) - expectedResponseDurationMsStartTime := GetEpochTimeInISO8601FormatFromMonotime(startReadingResponseMonoTimeMs) - expected := []interop.Span{ - { - Name: "responseLatency", - Start: expectedResponseLatencyMsStartTime, - DurationMs: 5, - }, - { - Name: "responseDuration", - Start: expectedResponseDurationMsStartTime, - DurationMs: 2, - }, - } - - assert.Equal(t, expected, GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics)) -} - -func TestGetRuntimeDoneSpansWhenRuntimeCalledError(t *testing.T) { - now := metering.Monotime() - startReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(5)) - finishReadingResponseMonoTimeMs := now + int64(time.Millisecond*time.Duration(7)) - - runtimeStartedTime := now - invokeResponseMetrics := &interop.InvokeResponseMetrics{ - StartReadingResponseMonoTimeMs: startReadingResponseMonoTimeMs, - FinishReadingResponseMonoTimeMs: finishReadingResponseMonoTimeMs, - RuntimeCalledResponse: false, - } - - assert.Equal(t, []interop.Span{}, GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics)) -} - -func TestGetRuntimeDoneSpansWhenInvokeResponseMetricsNil(t *testing.T) { - runtimeStartedTime := metering.Monotime() - - assert.Equal(t, []interop.Span{}, GetRuntimeDoneSpans(runtimeStartedTime, nil)) -} - -func TestGetRuntimeDoneSpansWhenRuntimeStartedTimeIsMinusOne(t *testing.T) { - now := int64(-1) - runtimeStartedTime := now - invokeResponseMetrics := &interop.InvokeResponseMetrics{ - StartReadingResponseMonoTimeMs: now + int64(time.Millisecond*time.Duration(5)), - FinishReadingResponseMonoTimeMs: now + int64(time.Millisecond*time.Duration(7)), - } - - assert.Equal(t, []interop.Span{}, GetRuntimeDoneSpans(runtimeStartedTime, invokeResponseMetrics)) -} - -func TestInferInitType(t *testing.T) { - testCases := map[string]struct { - initCachingEnabled bool - sandboxType interop.SandboxType - expected interop.InitType - }{ - "on demand": { - initCachingEnabled: false, - sandboxType: interop.SandboxClassic, - expected: InitTypeOnDemand, - }, - "pc": { - initCachingEnabled: false, - sandboxType: interop.SandboxPreWarmed, - expected: InitTypeProvisionedConcurrency, - }, - "snap-start for OD": { - initCachingEnabled: true, - sandboxType: interop.SandboxClassic, - expected: InitTypeInitCaching, - }, - "snap-start for PC": { - initCachingEnabled: true, - sandboxType: interop.SandboxPreWarmed, - expected: InitTypeInitCaching, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - initType := InferInitType(tc.initCachingEnabled, tc.sandboxType) - assert.Equal(t, tc.expected, initType) - }) - } -} - -func TestCalculateDuration(t *testing.T) { - testCases := map[string]struct { - start int64 - end int64 - expected float64 - }{ - "milliseconds only": { - start: int64(100 * time.Millisecond), - end: int64(120 * time.Millisecond), - expected: 20, - }, - "with microseconds": { - start: int64(100 * time.Millisecond), - end: int64(210*time.Millisecond + 65*time.Microsecond), - expected: 110.065, - }, - "nanoseconds must be dropped": { - start: int64(100 * time.Millisecond), - end: int64(140*time.Millisecond + 999*time.Nanosecond), - expected: 40, - }, - "microseconds presented, nanoseconds dropped": { - start: int64(100 * time.Millisecond), - end: int64(150*time.Millisecond + 2*time.Microsecond + 999*time.Nanosecond), - expected: 50.002, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - actual := CalculateDuration(tc.start, tc.end) - assert.Equal(t, tc.expected, actual) - }) - } -} diff --git a/lambda/telemetry/logs_egress_api.go b/lambda/telemetry/logs_egress_api.go deleted file mode 100644 index f4da62d..0000000 --- a/lambda/telemetry/logs_egress_api.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "io" - "os" -) - -// StdLogsEgressAPI is the interface that wraps the basic methods required to setup -// logs channels for Runtime's stdout/stderr and Extension's stdout/stderr. -// -// Implementation should return a Writer implementor for stdout and another for -// stderr on success and an error on failure. -type StdLogsEgressAPI interface { - GetExtensionSockets() (io.Writer, io.Writer, error) - GetRuntimeSockets() (io.Writer, io.Writer, error) -} - -type NoOpLogsEgressAPI struct{} - -func (s *NoOpLogsEgressAPI) GetExtensionSockets() (io.Writer, io.Writer, error) { - // os.Stderr can not be used for the stderrWriter because stderr is for internal logging (not customer visible). - return os.Stdout, os.Stdout, nil -} - -func (s *NoOpLogsEgressAPI) GetRuntimeSockets() (io.Writer, io.Writer, error) { - // os.Stderr can not be used for the stderrWriter because stderr is for internal logging (not customer visible). - return os.Stdout, os.Stdout, nil -} - -var _ StdLogsEgressAPI = (*NoOpLogsEgressAPI)(nil) diff --git a/lambda/telemetry/logs_subscription_api.go b/lambda/telemetry/logs_subscription_api.go deleted file mode 100644 index 2fa39f0..0000000 --- a/lambda/telemetry/logs_subscription_api.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "io" - "net/http" - - "go.amzn.com/lambda/interop" -) - -// SubscriptionAPI represents interface that implementations of Telemetry API have to satisfy to be RAPID-compatible -type SubscriptionAPI interface { - Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) (resp []byte, status int, respHeaders map[string][]string, err error) - RecordCounterMetric(metricName string, count int) - FlushMetrics() interop.TelemetrySubscriptionMetrics - Clear() - TurnOff() - GetEndpointURL() string - GetServiceClosedErrorMessage() string - GetServiceClosedErrorType() string -} - -type NoOpSubscriptionAPI struct{} - -// Subscribe writes response to a shared memory -func (m *NoOpSubscriptionAPI) Subscribe(agentName string, body io.Reader, headers map[string][]string, remoteAddr string) ([]byte, int, map[string][]string, error) { - return []byte(`{}`), http.StatusOK, map[string][]string{}, nil -} - -func (m *NoOpSubscriptionAPI) RecordCounterMetric(metricName string, count int) {} - -func (m *NoOpSubscriptionAPI) FlushMetrics() interop.TelemetrySubscriptionMetrics { - return interop.TelemetrySubscriptionMetrics(map[string]int{}) -} - -func (m *NoOpSubscriptionAPI) Clear() {} - -func (m *NoOpSubscriptionAPI) TurnOff() {} - -func (m *NoOpSubscriptionAPI) GetEndpointURL() string { return "" } - -func (m *NoOpSubscriptionAPI) GetServiceClosedErrorMessage() string { return "" } - -func (m *NoOpSubscriptionAPI) GetServiceClosedErrorType() string { return "" } diff --git a/lambda/telemetry/tracer.go b/lambda/telemetry/tracer.go deleted file mode 100644 index 889682b..0000000 --- a/lambda/telemetry/tracer.go +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "context" - "encoding/json" - "fmt" - "strings" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi/model" -) - -type traceContextKey int - -const ( - TraceIDKey traceContextKey = iota - DocumentIDKey -) - -type Tracer interface { - Configure(invoke *interop.Invoke) - CaptureInvokeSegment(ctx context.Context, criticalFunction func(context.Context) error) error - CaptureInitSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error - CaptureInvokeSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error - CaptureOverheadSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error - RecordInitStartTime() - RecordInitEndTime() - SendInitSubsegmentWithRecordedTimesOnce(ctx context.Context) - SendRestoreSubsegmentWithRecordedTimesOnce(ctx context.Context) - MarkError(ctx context.Context) - AttachErrorCause(ctx context.Context, errorCause json.RawMessage) - WithErrorCause(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error - WithError(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error - BuildTracingHeader() func(context.Context) string - BuildTracingCtxForStart() *interop.TracingCtx - BuildTracingCtxAfterInvokeComplete() *interop.TracingCtx -} - -type NoOpTracer struct{} - -func (t *NoOpTracer) Configure(invoke *interop.Invoke) {} - -func (t *NoOpTracer) CaptureInvokeSegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return criticalFunction(ctx) -} - -func (t *NoOpTracer) CaptureInitSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return criticalFunction(ctx) -} - -func (t *NoOpTracer) CaptureInvokeSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return criticalFunction(ctx) -} - -func (t *NoOpTracer) CaptureOverheadSubsegment(ctx context.Context, criticalFunction func(context.Context) error) error { - return criticalFunction(ctx) -} - -func (t *NoOpTracer) RecordInitStartTime() {} -func (t *NoOpTracer) RecordInitEndTime() {} -func (t *NoOpTracer) SendInitSubsegmentWithRecordedTimesOnce(ctx context.Context) {} -func (t *NoOpTracer) SendRestoreSubsegmentWithRecordedTimesOnce(ctx context.Context) {} -func (t *NoOpTracer) MarkError(ctx context.Context) {} -func (t *NoOpTracer) AttachErrorCause(ctx context.Context, errorCause json.RawMessage) {} - -func (t *NoOpTracer) WithErrorCause(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { - return criticalFunction -} -func (t *NoOpTracer) WithError(ctx context.Context, appCtx appctx.ApplicationContext, criticalFunction func(ctx context.Context) error) func(ctx context.Context) error { - return criticalFunction -} -func (t *NoOpTracer) BuildTracingHeader() func(context.Context) string { - // extract root trace ID and parent from context and build the tracing header - return func(ctx context.Context) string { - root, _ := ctx.Value(TraceIDKey).(string) - parent, _ := ctx.Value(DocumentIDKey).(string) - - if root != "" && parent != "" { - return fmt.Sprintf("Root=%s;Parent=%s;Sampled=1", root, parent) - } - - return "" - } -} - -func (t *NoOpTracer) BuildTracingCtxForStart() *interop.TracingCtx { - return nil -} -func (t *NoOpTracer) BuildTracingCtxAfterInvokeComplete() *interop.TracingCtx { - return nil -} - -func NewNoOpTracer() *NoOpTracer { - return &NoOpTracer{} -} - -// NewTraceContext returns new derived context with trace config set for testing -func NewTraceContext(ctx context.Context, root string, parent string) context.Context { - ctxWithRoot := context.WithValue(ctx, TraceIDKey, root) - return context.WithValue(ctxWithRoot, DocumentIDKey, parent) -} - -// ParseTracingHeader extracts RootTraceID, ParentID, Sampled, and Lineage from a tracing header. -// Tracing header format is defined here: -// https://docs.aws.amazon.com/xray/latest/devguide/xray-concepts.html#xray-concepts-tracingheader -func ParseTracingHeader(tracingHeader string) (rootID, parentID, sampled, lineage string) { - keyValuePairs := strings.Split(tracingHeader, ";") - for _, pair := range keyValuePairs { - var key, value string - keyValue := strings.Split(pair, "=") - if len(keyValue) == 2 { - key = keyValue[0] - value = keyValue[1] - } - switch key { - case "Root": - rootID = value - case "Parent": - parentID = value - case "Sampled": - sampled = value - case "Lineage": - lineage = value - } - } - return -} - -// BuildFullTraceID takes individual components of X-Ray trace header -// and puts them together into a formatted trace header. -// If root is empty, returns an empty string. -func BuildFullTraceID(root, parent, sample string) string { - if root == "" { - return "" - } - - parts := make([]string, 0, 3) - parts = append(parts, "Root="+root) - if parent != "" { - parts = append(parts, "Parent="+parent) - } - if sample == "" { - sample = model.XRayNonSampled - } - parts = append(parts, "Sampled="+sample) - - return strings.Join(parts, ";") -} diff --git a/lambda/telemetry/tracer_test.go b/lambda/telemetry/tracer_test.go deleted file mode 100644 index d67c389..0000000 --- a/lambda/telemetry/tracer_test.go +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package telemetry - -import ( - "context" - "fmt" - "strings" - "testing" - - "go.amzn.com/lambda/rapi/model" -) - -var BigString = strings.Repeat("a", 255) - -var parserTests = []struct { - tracingHeaderIn string - rootIDOut string - parentIDOut string - sampledOut string - lineageOut string -}{ - {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=1", "1-5b3cc918-939afd635f8891ba6a9e1df6", "c88d77b0aef840e9", "1", ""}, - {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9", "1-5b3cc918-939afd635f8891ba6a9e1df6", "c88d77b0aef840e9", "", ""}, - {"1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=1", "", "c88d77b0aef840e9", "1", ""}, - {"Root=1-5b3cc918-939afd635f8891ba6a9e1df6", "1-5b3cc918-939afd635f8891ba6a9e1df6", "", "", ""}, - {"", "", "", "", ""}, - {"abc;;", "", "", "", ""}, - {"abc", "", "", "", ""}, - {"abc;asd", "", "", "", ""}, - {"abc=as;asd=as", "", "", "", ""}, - {"Root=abc", "abc", "", "", ""}, - {"Root=abc;Parent=zxc;Sampled=1", "abc", "zxc", "1", ""}, - {"Root=root;Parent=par", "root", "par", "", ""}, - {"Root=root;Par", "root", "", "", ""}, - {"Root=", "", "", "", ""}, - {";Root=root;;", "root", "", "", ""}, - {"Root=root;Parent=parent;", "root", "parent", "", ""}, - {"Root=;Parent=parent;Sampled=1", "", "parent", "1", ""}, - {"Root=abc;Parent=zxc;Sampled=1;Lineage", "abc", "zxc", "1", ""}, - {"Root=abc;Parent=zxc;Sampled=1;Lineage=", "abc", "zxc", "1", ""}, - {"Root=abc;Parent=zxc;Sampled=1;Lineage=foo:1|bar:65535", "abc", "zxc", "1", "foo:1|bar:65535"}, - {"Root=abc;Parent=zxc;Lineage=foo:1|bar:65535;Sampled=1", "abc", "zxc", "1", "foo:1|bar:65535"}, - {fmt.Sprintf("Root=%s;Parent=%s;Sampled=1;Lineage=%s", BigString, BigString, BigString), BigString, BigString, "1", BigString}, -} - -func TestParseTracingHeader(t *testing.T) { - for _, tt := range parserTests { - t.Run(tt.tracingHeaderIn, func(t *testing.T) { - rootID, parentID, sampled, lineage := ParseTracingHeader(tt.tracingHeaderIn) - if rootID != tt.rootIDOut { - t.Errorf("Parsing %q got %q, wanted %q", tt.tracingHeaderIn, rootID, tt.rootIDOut) - } - if parentID != tt.parentIDOut { - t.Errorf("Parsing %q got %q, wanted %q", tt.tracingHeaderIn, parentID, tt.parentIDOut) - } - if sampled != tt.sampledOut { - t.Errorf("Parsing %q got %q, wanted %q", tt.tracingHeaderIn, sampled, tt.sampledOut) - } - if lineage != tt.lineageOut { - t.Errorf("Parsing %q got %q, wanted %q", tt.tracingHeaderIn, lineage, tt.lineageOut) - } - if lineage != tt.lineageOut { - t.Errorf("got %q, wanted %q", lineage, tt.lineageOut) - } - }) - } -} - -func TestBuildFullTraceID(t *testing.T) { - specs := map[string]struct { - root string - parent string - sample string - expectedTraceID string - }{ - "all non-empty components, sampled": { - root: "1-5b3cc918-939afd635f8891ba6a9e1df6", - parent: "c88d77b0aef840e9", - sample: model.XRaySampled, - expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=1", - }, - "all non-empty components, non-sampled": { - root: "1-5b3cc918-939afd635f8891ba6a9e1df6", - parent: "c88d77b0aef840e9", - sample: model.XRayNonSampled, - expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=0", - }, - "root is non-empty, parent and sample are empty": { - root: "1-5b3cc918-939afd635f8891ba6a9e1df6", - expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Sampled=0", - }, - "root is empty": { - parent: "c88d77b0aef840e9", - expectedTraceID: "", - }, - "sample is empty": { - root: "1-5b3cc918-939afd635f8891ba6a9e1df6", - parent: "c88d77b0aef840e9", - expectedTraceID: "Root=1-5b3cc918-939afd635f8891ba6a9e1df6;Parent=c88d77b0aef840e9;Sampled=0", - }, - } - - for name, spec := range specs { - t.Run(name, func(t *testing.T) { - actual := BuildFullTraceID(spec.root, spec.parent, spec.sample) - if actual != spec.expectedTraceID { - t.Errorf("got %q, wanted %q", actual, spec.expectedTraceID) - } - }) - } -} - -func TestTracerDoesntSwallowErrorsFromCriticalFunctions(t *testing.T) { - ctx := context.Background() - - testCases := []struct { - name string - tracer Tracer - expectedError error - }{ - { - name: "NoOpTracer-success", - tracer: &NoOpTracer{}, - expectedError: nil, - }, - { - name: "NoOpTracer-fail", - tracer: &NoOpTracer{}, - expectedError: fmt.Errorf("invoke error"), - }, - } - - for _, test := range testCases { - t.Run(test.name, func(t *testing.T) { - criticalFunction := func(ctx context.Context) error { - return test.expectedError - } - - if err := test.tracer.CaptureInvokeSegment(ctx, criticalFunction); err != test.expectedError { - t.Errorf("CaptureInvokeSegment failed; expected: '%v', but got: '%v'", test.expectedError, err) - } - if err := test.tracer.CaptureInitSubsegment(ctx, criticalFunction); err != test.expectedError { - t.Errorf("CaptureInitSubsegment failed; expected: '%v', but got: '%v'", test.expectedError, err) - } - if err := test.tracer.CaptureInvokeSubsegment(ctx, criticalFunction); err != test.expectedError { - t.Errorf("CaptureInvokeSubsegment failed; expected: '%v', but got: '%v'", test.expectedError, err) - } - if err := test.tracer.CaptureOverheadSubsegment(ctx, criticalFunction); err != test.expectedError { - t.Errorf("CaptureOverheadSubsegment failed; expected: '%v', but got: '%v'", test.expectedError, err) - } - }) - } -} diff --git a/lambda/testdata/agents/bash_true.sh b/lambda/testdata/agents/bash_true.sh deleted file mode 100755 index f1f641a..0000000 --- a/lambda/testdata/agents/bash_true.sh +++ /dev/null @@ -1 +0,0 @@ -#!/usr/bin/env bash diff --git a/lambda/testdata/async_assertion_utils.go b/lambda/testdata/async_assertion_utils.go deleted file mode 100644 index c29e4a0..0000000 --- a/lambda/testdata/async_assertion_utils.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package testdata - -import ( - "testing" - "time" -) - -func WaitForErrorWithTimeout(channel <-chan error, timeout time.Duration) error { - select { - case err := <-channel: - return err - case <-time.After(timeout): - return nil - } -} - -func Eventually(t *testing.T, testFunc func() (bool, error), pollingIntervalMultiple time.Duration, retries int) bool { - for try := 0; try < retries; try++ { - success, err := testFunc() - if success { - return true - } - if err != nil { - t.Logf("try %d: %v", try, err) - } - time.Sleep(time.Duration(try) * pollingIntervalMultiple) - } - return false -} \ No newline at end of file diff --git a/lambda/testdata/bash_function.sh b/lambda/testdata/bash_function.sh deleted file mode 100755 index c5c370b..0000000 --- a/lambda/testdata/bash_function.sh +++ /dev/null @@ -1,7 +0,0 @@ -function handler () { - EVENT_DATA=$1 - echo "$EVENT_DATA" 1>&2; - RESPONSE="Echoing request: '$EVENT_DATA'" - - echo $RESPONSE -} \ No newline at end of file diff --git a/lambda/testdata/bash_runtime.sh b/lambda/testdata/bash_runtime.sh deleted file mode 100755 index f568d07..0000000 --- a/lambda/testdata/bash_runtime.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/sh - -set -euo pipefail - -# Initialization - load function handler -source $LAMBDA_TASK_ROOT/"bash_function.sh" - -# Processing -while true -do - HEADERS="$(mktemp)" - # Get an event - EVENT_DATA=$(curl -sS -LD "$HEADERS" -X GET "http://${AWS_LAMBDA_RUNTIME_API}/2018-06-01/runtime/invocation/next") - REQUEST_ID=$(grep -Fi Lambda-Runtime-Aws-Request-Id "$HEADERS" | tr -d '[:space:]' | cut -d: -f2) - - # Execute the handler function from the script - FN_PATH=$LAMBDA_TASK_ROOT/"bash_function.sh" - RESPONSE=$($FN_PATH "$EVENT_DATA") - - # Send the response - curl -X POST "http://${AWS_LAMBDA_RUNTIME_API}/2018-06-01/runtime/invocation/$REQUEST_ID/response" -d "response_from_runtime" -done \ No newline at end of file diff --git a/lambda/testdata/bash_script_with_child_proc.sh b/lambda/testdata/bash_script_with_child_proc.sh deleted file mode 100755 index bdde5ab..0000000 --- a/lambda/testdata/bash_script_with_child_proc.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/sh - -# Spawn one child process recursively and spin -# When parent process receives a SIGTERM, child process doesn't exit - -if [ -z "$DONT_SPAWN" ] -then - DONT_SPAWN=true ./$0 & -fi - -while true -do - sleep 1 -done \ No newline at end of file diff --git a/lambda/testdata/env_setup_helpers.go b/lambda/testdata/env_setup_helpers.go deleted file mode 100644 index 07a3c5d..0000000 --- a/lambda/testdata/env_setup_helpers.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package testdata - -import ( - "net" -) - -// Test helpers -type TestSocketsRapid struct { - CtrlFd int - CnslFd int -} - -type TestSocketsSlicer struct { - CtrlSock net.Conn - CnslSock net.Conn - CtrlFd int - CnslFd int -} diff --git a/lambda/testdata/flowtesting.go b/lambda/testdata/flowtesting.go deleted file mode 100644 index e2c4b49..0000000 --- a/lambda/testdata/flowtesting.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package testdata - -import ( - "bytes" - "context" - "io/ioutil" - "time" - - "go.amzn.com/lambda/appctx" - "go.amzn.com/lambda/core" - "go.amzn.com/lambda/interop" - "go.amzn.com/lambda/rapi/rendering" - "go.amzn.com/lambda/telemetry" - "go.amzn.com/lambda/testdata/mockthread" -) - -const ( - contentTypeHeader = "Content-Type" - functionResponseModeHeader = "Lambda-Runtime-Function-Response-Mode" -) - -type MockInteropServer struct { - Response []byte - ErrorResponse *interop.ErrorInvokeResponse - ResponseContentType string - FunctionResponseMode string - ActiveInvokeID string -} - -// SendResponse writes response to a shared memory. -func (i *MockInteropServer) SendResponse(invokeID string, resp *interop.StreamableInvokeResponse) error { - bytes, err := ioutil.ReadAll(resp.Payload) - if err != nil { - return err - } - if len(bytes) > interop.MaxPayloadSize { - return &interop.ErrorResponseTooLarge{ - ResponseSize: len(bytes), - MaxResponseSize: interop.MaxPayloadSize, - } - } - i.Response = bytes - i.ResponseContentType = resp.Headers[contentTypeHeader] - i.FunctionResponseMode = resp.Headers[functionResponseModeHeader] - return nil -} - -// SendErrorResponse writes error response to a shared memory and sends GIRD FAULT. -func (i *MockInteropServer) SendErrorResponse(invokeID string, response *interop.ErrorInvokeResponse) error { - i.ErrorResponse = response - i.ResponseContentType = response.Headers.ContentType - i.FunctionResponseMode = response.Headers.FunctionResponseMode - return nil -} - -// SendInitErrorResponse writes error response during init to a shared memory and sends GIRD FAULT. -func (i *MockInteropServer) SendInitErrorResponse(response *interop.ErrorInvokeResponse) error { - i.ErrorResponse = response - i.ResponseContentType = response.Headers.ContentType - return nil -} - -func (i *MockInteropServer) GetCurrentInvokeID() string { - return i.ActiveInvokeID -} - -func (i *MockInteropServer) SendRuntimeReady() error { return nil } - -// FlowTest provides configuration for tests that involve synchronization flows. -type FlowTest struct { - AppCtx appctx.ApplicationContext - InitFlow core.InitFlowSynchronization - InvokeFlow core.InvokeFlowSynchronization - RegistrationService core.RegistrationService - RenderingService *rendering.EventRenderingService - Runtime *core.Runtime - InteropServer *MockInteropServer - TelemetrySubscription *telemetry.NoOpSubscriptionAPI - CredentialsService core.CredentialsService - EventsAPI interop.EventsAPI -} - -// ConfigureForInit initialize synchronization gates and states for init. -func (s *FlowTest) ConfigureForInit() { - s.RegistrationService.PreregisterRuntime(s.Runtime) -} - -// ConfigureForInvoke initialize synchronization gates and states for invoke. -func (s *FlowTest) ConfigureForInvoke(ctx context.Context, invoke *interop.Invoke) { - s.InteropServer.ActiveInvokeID = invoke.ID - s.InvokeFlow.InitializeBarriers() - var buf bytes.Buffer // create default invoke renderer with new request buffer each time - s.ConfigureInvokeRenderer(ctx, invoke, &buf) -} - -// ConfigureInvokeRenderer overrides default invoke renderer to reuse request buffers (for benchmarks), etc. -func (s *FlowTest) ConfigureInvokeRenderer(ctx context.Context, invoke *interop.Invoke, buf *bytes.Buffer) { - s.RenderingService.SetRenderer(rendering.NewInvokeRenderer(ctx, invoke, buf, telemetry.NewNoOpTracer().BuildTracingHeader())) -} - -func (s *FlowTest) ConfigureForRestore() { - s.RenderingService.SetRenderer(rendering.NewRestoreRenderer()) -} - -func (s *FlowTest) ConfigureForRestoring() { - s.RegistrationService.PreregisterRuntime(s.Runtime) - s.Runtime.SetState(s.Runtime.RuntimeRestoringState) - s.RenderingService.SetRenderer(rendering.NewRestoreRenderer()) -} - -func (s *FlowTest) ConfigureForInitCaching(token, awsKey, awsSecret, awsSession string) { - credentialsExpiration := time.Now().Add(30 * time.Minute) - s.CredentialsService.SetCredentials(token, awsKey, awsSecret, awsSession, credentialsExpiration) -} - -// NewFlowTest returns new FlowTest configuration. -func NewFlowTest() *FlowTest { - appCtx := appctx.NewApplicationContext() - initFlow := core.NewInitFlowSynchronization() - invokeFlow := core.NewInvokeFlowSynchronization() - registrationService := core.NewRegistrationService(initFlow, invokeFlow) - renderingService := rendering.NewRenderingService() - credentialsService := core.NewCredentialsService() - runtime := core.NewRuntime(initFlow, invokeFlow) - runtime.ManagedThread = &mockthread.MockManagedThread{} - interopServer := &MockInteropServer{} - eventsAPI := telemetry.NoOpEventsAPI{} - appctx.StoreInteropServer(appCtx, interopServer) - appctx.StoreResponseSender(appCtx, interopServer) - - return &FlowTest{ - AppCtx: appCtx, - InitFlow: initFlow, - InvokeFlow: invokeFlow, - RegistrationService: registrationService, - RenderingService: renderingService, - TelemetrySubscription: &telemetry.NoOpSubscriptionAPI{}, - Runtime: runtime, - InteropServer: interopServer, - CredentialsService: credentialsService, - EventsAPI: &eventsAPI, - } -} diff --git a/lambda/testdata/mockcommand.go b/lambda/testdata/mockcommand.go deleted file mode 100644 index b99a226..0000000 --- a/lambda/testdata/mockcommand.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -/* Package testdata holds ancillary data needed by - the tests. The go tool ignores this directory. - https://golang.org/pkg/cmd/go/internal/test/ */ -package testdata - -import ( - "context" -) - -// MockCommand represents a mock of os/exec.Cmd for testing -type MockCommand struct { - done chan error - ctx context.Context -} - -// NewMockCommand returns a MockCommand that satisfies -// the command interface -func NewMockCommand(ctx context.Context) MockCommand { - done := make(chan error) - return MockCommand{done, ctx} -} - -// Start represents a successful cmd.Start() without -// errors -func (c MockCommand) Start() error { - return nil -} - -// Wait represents a cancelable call to cmd.Wait() -func (c MockCommand) Wait() error { - select { - case <-c.done: - return nil - case <-c.ctx.Done(): - return c.ctx.Err() - } -} - -// ForceExit tells goroutine blocking on Wait() to exit -func (c MockCommand) ForceExit() { - c.done <- nil -} diff --git a/lambda/testdata/mockthread/mockthread.go b/lambda/testdata/mockthread/mockthread.go deleted file mode 100644 index 2834d45..0000000 --- a/lambda/testdata/mockthread/mockthread.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package mockthread - -// MockManagedThread implements core.Suspendable interface but -// does not suspend running thread on condition. -type MockManagedThread struct{} - -// SuspendUnsafe does not suspend running thread. -func (s *MockManagedThread) SuspendUnsafe() {} - -// Release resumes suspended thread. -func (s *MockManagedThread) Release() {} - -// Lock: no-op -func (s *MockManagedThread) Lock() {} - -// Unlock: no-op -func (s *MockManagedThread) Unlock() {} diff --git a/lambda/testdata/mocktracer/mocktracer.go b/lambda/testdata/mocktracer/mocktracer.go deleted file mode 100644 index 3fb7054..0000000 --- a/lambda/testdata/mocktracer/mocktracer.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package mocktracer - -import ( - "context" - "time" - - "go.amzn.com/lambda/xray" -) - -// MockStartTime is start time set in Start method -var MockStartTime = time.Now().UnixNano() - -// MockEndTime is end time set in End method -var MockEndTime = time.Now().UnixNano() + 1 - -// MockTracer is used for unit tests -type MockTracer struct { - documentsMap map[xray.DocumentKey]xray.Document - sentDocuments []xray.Document -} - -// Send will add document name to sentDocuments list -func (m *MockTracer) Send(document xray.Document) (dk xray.DocumentKey, err error) { - if len(document.ID) == 0 { - // Give it a predictable ID that we could use in our assertions. - document.ID = IDFor(document.Name) - } - m.sentDocuments = append(m.sentDocuments, document) - return xray.DocumentKey{ - TraceID: document.TraceID, - DocumentID: document.ID, - }, nil -} - -// Start will save document in documentsMap -func (m *MockTracer) Start(document xray.Document) (dk xray.DocumentKey, err error) { - document.StartTime = float64(MockStartTime) / xray.TimeDenominator - document.InProgress = true - dk, err = m.Send(document) - m.documentsMap[dk] = document - return -} - -// SetOptions will set value of a field on a saved document -func (m *MockTracer) SetOptions(dk xray.DocumentKey, documentOptions ...xray.DocumentOption) (err error) { - document := m.documentsMap[dk] - - for _, fieldValueSetter := range documentOptions { - fieldValueSetter(&document) - } - - m.documentsMap[dk] = document - - return nil -} - -// End will delete the key-value pair in documentsMap -func (m *MockTracer) End(dk xray.DocumentKey) (err error) { - document := m.documentsMap[dk] - document.EndTime = float64(MockEndTime) / xray.TimeDenominator - document.InProgress = false - - m.Send(document) - delete(m.documentsMap, dk) - return -} - -// GetSentDocuments will return sentDocuments for unit test to verify -func (m *MockTracer) GetSentDocuments() []xray.Document { - return m.sentDocuments -} - -// ResetSentDocuments resets captured documents list to an empty list. -func (m *MockTracer) ResetSentDocuments() { - m.sentDocuments = []xray.Document{} -} - -// SetDocumentMap sets internal state. -func (m *MockTracer) SetDocumentMap(dm map[xray.DocumentKey]xray.Document) { - m.documentsMap = dm -} - -// Capture mock method for capturing segments. -func (m *MockTracer) Capture(ctx context.Context, document xray.Document, criticalFunction func(context.Context) error) error { - return nil -} - -// SetOptionsCtx contextual SetOptions. -func (m *MockTracer) SetOptionsCtx(ctx context.Context, documentOptions ...xray.DocumentOption) (err error) { - return nil -} - -// NewMockTracer is the constructor for mock tracer -func NewMockTracer() xray.Tracer { - return &MockTracer{ - documentsMap: make(map[xray.DocumentKey]xray.Document), - sentDocuments: []xray.Document{}, - } -} - -// IDFor constructs a predictable id for a given name. -func IDFor(name string) string { - return name + "_SEGMID" -} diff --git a/lambda/testdata/parametrization.go b/lambda/testdata/parametrization.go deleted file mode 100644 index 90a2dde..0000000 --- a/lambda/testdata/parametrization.go +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package testdata - -// SuppressInitTests is a parametrization vector for testing suppress init behavior. -var SuppressInitTests = []struct { - TestName string - SuppressInit bool -}{ - {"Unsuppressed", false}, - {"Suppressed", true}, -}