diff --git a/components/ambient-control-plane/cmd/ambient-control-plane/main.go b/components/ambient-control-plane/cmd/ambient-control-plane/main.go index 8614aa273..f1e66099b 100644 --- a/components/ambient-control-plane/cmd/ambient-control-plane/main.go +++ b/components/ambient-control-plane/cmd/ambient-control-plane/main.go @@ -15,6 +15,7 @@ import ( "github.com/ambient-code/platform/components/ambient-control-plane/internal/auth" "github.com/ambient-code/platform/components/ambient-control-plane/internal/config" "github.com/ambient-code/platform/components/ambient-control-plane/internal/informer" + "github.com/ambient-code/platform/components/ambient-control-plane/internal/keypair" "github.com/ambient-code/platform/components/ambient-control-plane/internal/kubeclient" "github.com/ambient-code/platform/components/ambient-control-plane/internal/reconciler" "github.com/ambient-code/platform/components/ambient-control-plane/internal/tokenserver" @@ -25,8 +26,6 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" - "k8s.io/client-go/rest" - "k8s.io/client-go/tools/clientcmd" ) var ( @@ -123,6 +122,12 @@ func runKubeMode(ctx context.Context, cfg *config.ControlPlaneConfig) error { provisioner := buildNamespaceProvisioner(cfg, provisionerKube) tokenProvider := buildTokenProvider(cfg, log.Logger) + kp, err := keypair.EnsureKeypairSecret(ctx, provisionerKube, cfg.CPRuntimeNamespace, log.Logger) + if err != nil { + return fmt.Errorf("bootstrapping CP token keypair: %w", err) + } + log.Info().Str("namespace", cfg.CPRuntimeNamespace).Msg("CP token keypair ready") + factory := reconciler.NewSDKClientFactory(cfg.APIServerURL, tokenProvider, log.Logger) kubeReconcilerCfg := reconciler.KubeReconcilerConfig{ RunnerImage: cfg.RunnerImage, @@ -142,6 +147,7 @@ func runKubeMode(ctx context.Context, cfg *config.ControlPlaneConfig) error { RunnerLogLevel: cfg.RunnerLogLevel, CPRuntimeNamespace: cfg.CPRuntimeNamespace, CPTokenURL: cfg.CPTokenURL, + CPTokenPublicKey: string(kp.PublicKeyPEM), } conn, err := grpc.NewClient(cfg.GRPCServerAddr, grpc.WithTransportCredentials(grpcCredentials(cfg.GRPCUseTLS))) @@ -184,7 +190,7 @@ func runKubeMode(ctx context.Context, cfg *config.ControlPlaneConfig) error { tsErrCh := make(chan error, 1) go func() { - tsErrCh <- startTokenServer(ctx, cfg, tokenProvider) + tsErrCh <- startTokenServer(ctx, cfg, tokenProvider, kp) }() infErrCh := make(chan error, 1) @@ -203,24 +209,18 @@ func runKubeMode(ctx context.Context, cfg *config.ControlPlaneConfig) error { } } -func startTokenServer(ctx context.Context, cfg *config.ControlPlaneConfig, tokenProvider auth.TokenProvider) error { - k8sConfig, err := buildK8sRestConfig(cfg.Kubeconfig) +func startTokenServer(ctx context.Context, cfg *config.ControlPlaneConfig, tokenProvider auth.TokenProvider, kp *keypair.KeyPair) error { + privKey, err := keypair.ParsePrivateKey(kp.PrivateKeyPEM) if err != nil { - return fmt.Errorf("building k8s rest config for token server: %w", err) + return fmt.Errorf("parsing CP token private key: %w", err) } - ts, err := tokenserver.New(cfg.CPTokenListenAddr, tokenProvider, k8sConfig, log.Logger) + ts, err := tokenserver.New(cfg.CPTokenListenAddr, tokenProvider, privKey, log.Logger) if err != nil { return fmt.Errorf("creating token server: %w", err) } return ts.Start(ctx) } -func buildK8sRestConfig(kubeconfig string) (*rest.Config, error) { - if kubeconfig != "" { - return clientcmd.BuildConfigFromFlags("", kubeconfig) - } - return rest.InClusterConfig() -} func createSessionReconcilers(reconcilerTypes []string, factory *reconciler.SDKClientFactory, kube *kubeclient.KubeClient, projectKube *kubeclient.KubeClient, provisioner kubeclient.NamespaceProvisioner, cfg reconciler.KubeReconcilerConfig, logger zerolog.Logger) []reconciler.Reconciler { var reconcilers []reconciler.Reconciler diff --git a/components/ambient-control-plane/internal/keypair/bootstrap.go b/components/ambient-control-plane/internal/keypair/bootstrap.go new file mode 100644 index 000000000..69c1ad119 --- /dev/null +++ b/components/ambient-control-plane/internal/keypair/bootstrap.go @@ -0,0 +1,134 @@ +package keypair + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + + "github.com/ambient-code/platform/components/ambient-control-plane/internal/kubeclient" + "github.com/rs/zerolog" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" +) + +const ( + SecretName = "ambient-cp-token-keypair" + privateKeyKey = "private.pem" + publicKeyKey = "public.pem" + rsaKeyBits = 4096 +) + +type KeyPair struct { + PrivateKeyPEM []byte + PublicKeyPEM []byte +} + +func EnsureKeypairSecret(ctx context.Context, kube *kubeclient.KubeClient, namespace string, logger zerolog.Logger) (*KeyPair, error) { + existing, err := kube.GetSecret(ctx, namespace, SecretName) + if err == nil { + return keypairFromSecret(existing) + } + if !k8serrors.IsNotFound(err) { + return nil, fmt.Errorf("checking for keypair secret: %w", err) + } + + logger.Info().Str("namespace", namespace).Str("secret", SecretName).Msg("keypair secret not found, generating new RSA keypair") + + kp, err := generateKeypair() + if err != nil { + return nil, fmt.Errorf("generating RSA keypair: %w", err) + } + + secret := &unstructured.Unstructured{ + Object: map[string]interface{}{ + "apiVersion": "v1", + "kind": "Secret", + "metadata": map[string]interface{}{ + "name": SecretName, + "namespace": namespace, + "labels": map[string]interface{}{ + "app": "ambient-control-plane", + "ambient-code.io/managed-by": "ambient-control-plane", + }, + }, + "type": "Opaque", + "data": map[string]interface{}{ + privateKeyKey: base64.StdEncoding.EncodeToString(kp.PrivateKeyPEM), + publicKeyKey: base64.StdEncoding.EncodeToString(kp.PublicKeyPEM), + }, + }, + } + + if _, createErr := kube.CreateSecret(ctx, secret); createErr != nil { + if !k8serrors.IsAlreadyExists(createErr) { + return nil, fmt.Errorf("creating keypair secret: %w", createErr) + } + existing, err = kube.GetSecret(ctx, namespace, SecretName) + if err != nil { + return nil, fmt.Errorf("re-reading keypair secret after race: %w", err) + } + return keypairFromSecret(existing) + } + + logger.Info().Str("namespace", namespace).Str("secret", SecretName).Msg("RSA keypair secret created") + return kp, nil +} + +func keypairFromSecret(secret *unstructured.Unstructured) (*KeyPair, error) { + data, _, _ := unstructured.NestedMap(secret.Object, "data") + + privB64, ok := data[privateKeyKey].(string) + if !ok || privB64 == "" { + return nil, fmt.Errorf("keypair secret missing %q key", privateKeyKey) + } + pubB64, ok := data[publicKeyKey].(string) + if !ok || pubB64 == "" { + return nil, fmt.Errorf("keypair secret missing %q key", publicKeyKey) + } + + privPEM, err := base64.StdEncoding.DecodeString(privB64) + if err != nil { + return nil, fmt.Errorf("decoding private key from secret: %w", err) + } + pubPEM, err := base64.StdEncoding.DecodeString(pubB64) + if err != nil { + return nil, fmt.Errorf("decoding public key from secret: %w", err) + } + + return &KeyPair{PrivateKeyPEM: privPEM, PublicKeyPEM: pubPEM}, nil +} + +func generateKeypair() (*KeyPair, error) { + privKey, err := rsa.GenerateKey(rand.Reader, rsaKeyBits) + if err != nil { + return nil, fmt.Errorf("generating RSA key: %w", err) + } + + privPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privKey), + }) + + pubDER, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey) + if err != nil { + return nil, fmt.Errorf("marshaling public key: %w", err) + } + pubPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: pubDER, + }) + + return &KeyPair{PrivateKeyPEM: privPEM, PublicKeyPEM: pubPEM}, nil +} + +func ParsePrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block for private key") + } + return x509.ParsePKCS1PrivateKey(block.Bytes) +} diff --git a/components/ambient-control-plane/internal/keypair/bootstrap_test.go b/components/ambient-control-plane/internal/keypair/bootstrap_test.go new file mode 100644 index 000000000..fdf0362cc --- /dev/null +++ b/components/ambient-control-plane/internal/keypair/bootstrap_test.go @@ -0,0 +1,160 @@ +package keypair + +import ( + "context" + "crypto/rsa" + "encoding/base64" + "testing" + + "github.com/rs/zerolog" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/dynamic/fake" + + "github.com/ambient-code/platform/components/ambient-control-plane/internal/kubeclient" +) + +func newFakeKubeClient(objects ...runtime.Object) *kubeclient.KubeClient { + scheme := runtime.NewScheme() + dynClient := fake.NewSimpleDynamicClient(scheme, objects...) + return kubeclient.NewFromDynamic(dynClient, zerolog.Nop()) +} + +func TestGenerateKeypair(t *testing.T) { + kp, err := generateKeypair() + if err != nil { + t.Fatalf("generateKeypair() error: %v", err) + } + if len(kp.PrivateKeyPEM) == 0 { + t.Error("PrivateKeyPEM is empty") + } + if len(kp.PublicKeyPEM) == 0 { + t.Error("PublicKeyPEM is empty") + } +} + +func TestParsePrivateKey(t *testing.T) { + kp, err := generateKeypair() + if err != nil { + t.Fatalf("generateKeypair() error: %v", err) + } + privKey, err := ParsePrivateKey(kp.PrivateKeyPEM) + if err != nil { + t.Fatalf("ParsePrivateKey() error: %v", err) + } + if privKey == nil { + t.Fatal("ParsePrivateKey() returned nil") + } + if _, ok := interface{}(privKey).(*rsa.PrivateKey); !ok { + t.Error("parsed key is not *rsa.PrivateKey") + } +} + +func TestParsePrivateKey_InvalidPEM(t *testing.T) { + _, err := ParsePrivateKey([]byte("not a pem block")) + if err == nil { + t.Error("expected error for invalid PEM, got nil") + } +} + +func TestKeypairFromSecret_MissingPrivateKey(t *testing.T) { + secret := &unstructured.Unstructured{ + Object: map[string]interface{}{ + "apiVersion": "v1", + "kind": "Secret", + "metadata": map[string]interface{}{"name": SecretName, "namespace": "test"}, + "data": map[string]interface{}{ + publicKeyKey: base64.StdEncoding.EncodeToString([]byte("pub")), + }, + }, + } + _, err := keypairFromSecret(secret) + if err == nil { + t.Error("expected error for missing private key, got nil") + } +} + +func TestKeypairFromSecret_MissingPublicKey(t *testing.T) { + secret := &unstructured.Unstructured{ + Object: map[string]interface{}{ + "apiVersion": "v1", + "kind": "Secret", + "metadata": map[string]interface{}{"name": SecretName, "namespace": "test"}, + "data": map[string]interface{}{ + privateKeyKey: base64.StdEncoding.EncodeToString([]byte("priv")), + }, + }, + } + _, err := keypairFromSecret(secret) + if err == nil { + t.Error("expected error for missing public key, got nil") + } +} + +func TestKeypairFromSecret_ValidSecret(t *testing.T) { + kp, err := generateKeypair() + if err != nil { + t.Fatalf("generateKeypair() error: %v", err) + } + secret := &unstructured.Unstructured{ + Object: map[string]interface{}{ + "apiVersion": "v1", + "kind": "Secret", + "metadata": map[string]interface{}{"name": SecretName, "namespace": "test"}, + "data": map[string]interface{}{ + privateKeyKey: base64.StdEncoding.EncodeToString(kp.PrivateKeyPEM), + publicKeyKey: base64.StdEncoding.EncodeToString(kp.PublicKeyPEM), + }, + }, + } + got, err := keypairFromSecret(secret) + if err != nil { + t.Fatalf("keypairFromSecret() error: %v", err) + } + if string(got.PrivateKeyPEM) != string(kp.PrivateKeyPEM) { + t.Error("PrivateKeyPEM mismatch") + } + if string(got.PublicKeyPEM) != string(kp.PublicKeyPEM) { + t.Error("PublicKeyPEM mismatch") + } +} + +func TestEnsureKeypairSecret_CreatesWhenMissing(t *testing.T) { + kube := newFakeKubeClient() + ctx := context.Background() + + kp, err := EnsureKeypairSecret(ctx, kube, "test-ns", zerolog.Nop()) + if err != nil { + t.Fatalf("EnsureKeypairSecret() error: %v", err) + } + if len(kp.PrivateKeyPEM) == 0 || len(kp.PublicKeyPEM) == 0 { + t.Error("returned keypair has empty PEM fields") + } + + privKey, err := ParsePrivateKey(kp.PrivateKeyPEM) + if err != nil { + t.Fatalf("generated private key is not parseable: %v", err) + } + if privKey.N.BitLen() != rsaKeyBits { + t.Errorf("key size: got %d, want %d", privKey.N.BitLen(), rsaKeyBits) + } +} + +func TestEnsureKeypairSecret_ReturnsExistingWhenPresent(t *testing.T) { + ctx := context.Background() + kube := newFakeKubeClient() + + first, err := EnsureKeypairSecret(ctx, kube, "test-ns", zerolog.Nop()) + if err != nil { + t.Fatalf("first call error: %v", err) + } + + second, err := EnsureKeypairSecret(ctx, kube, "test-ns", zerolog.Nop()) + if err != nil { + t.Fatalf("second call error: %v", err) + } + + if string(first.PrivateKeyPEM) != string(second.PrivateKeyPEM) { + t.Error("second call returned different private key — should reuse existing Secret") + } +} diff --git a/components/ambient-control-plane/internal/reconciler/kube_reconciler.go b/components/ambient-control-plane/internal/reconciler/kube_reconciler.go index e2053b3cd..03bf22b94 100644 --- a/components/ambient-control-plane/internal/reconciler/kube_reconciler.go +++ b/components/ambient-control-plane/internal/reconciler/kube_reconciler.go @@ -39,6 +39,7 @@ type KubeReconcilerConfig struct { RunnerLogLevel string CPRuntimeNamespace string CPTokenURL string + CPTokenPublicKey string } type SimpleKubeReconciler struct { @@ -582,6 +583,7 @@ func (r *SimpleKubeReconciler) buildEnv(ctx context.Context, session types.Sessi envVar("USE_VERTEX", useVertex), envVar("CLAUDE_CODE_USE_VERTEX", useVertex), envVar("AMBIENT_CP_TOKEN_URL", r.cfg.CPTokenURL), + envVar("AMBIENT_CP_TOKEN_PUBLIC_KEY", r.cfg.CPTokenPublicKey), envVar("AMBIENT_GRPC_URL", r.cfg.RunnerGRPCURL), envVar("AMBIENT_GRPC_ENABLED", boolToStr(r.cfg.RunnerGRPCURL != "")), envVar("AMBIENT_GRPC_USE_TLS", boolToStr(r.cfg.RunnerGRPCUseTLS)), @@ -831,6 +833,7 @@ func (r *SimpleKubeReconciler) buildMCPSidecar() interface{} { envVar("MCP_BIND_ADDR", fmt.Sprintf(":%d", mcpSidecarPort)), envVar("AMBIENT_API_URL", r.cfg.MCPAPIServerURL), envVar("AMBIENT_CP_TOKEN_URL", r.cfg.CPTokenURL), + envVar("AMBIENT_CP_TOKEN_PUBLIC_KEY", r.cfg.CPTokenPublicKey), }, "resources": map[string]interface{}{ "requests": map[string]interface{}{ diff --git a/components/ambient-control-plane/internal/tokenserver/handler.go b/components/ambient-control-plane/internal/tokenserver/handler.go index fd0bc721a..5436753e4 100644 --- a/components/ambient-control-plane/internal/tokenserver/handler.go +++ b/components/ambient-control-plane/internal/tokenserver/handler.go @@ -1,25 +1,17 @@ package tokenserver import ( - "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" "encoding/json" "fmt" "net/http" "strings" - "time" "github.com/ambient-code/platform/components/ambient-control-plane/internal/auth" "github.com/rs/zerolog" - authv1 "k8s.io/api/authentication/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/kubernetes" -) - -const ( - runnerSAPrefix = "system:serviceaccount:" - sessionSAInfix = ":session-" - sessionSASuffix = "-sa" - tokenReviewTimeout = 10 * time.Second ) type tokenResponse struct { @@ -28,7 +20,7 @@ type tokenResponse struct { type handler struct { tokenProvider auth.TokenProvider - k8sClient kubernetes.Interface + privateKey *rsa.PrivateKey logger zerolog.Logger } @@ -38,34 +30,34 @@ func (h *handler) handleToken(w http.ResponseWriter, r *http.Request) { return } - saToken, err := extractBearerToken(r) + ciphertext, err := extractBearerToken(r) if err != nil { h.logger.Warn().Err(err).Msg("token request: missing or malformed Authorization header") http.Error(w, "unauthorized", http.StatusUnauthorized) return } - username, err := h.validateSAToken(r.Context(), saToken) + sessionID, err := h.decryptSessionID(ciphertext) if err != nil { - h.logger.Warn().Err(err).Msg("token request: SA token validation failed") + h.logger.Warn().Err(err).Msg("token request: session ID decryption failed") http.Error(w, "unauthorized", http.StatusUnauthorized) return } - if !isRunnerSA(username) { - h.logger.Warn().Str("username", username).Msg("token request: username does not match runner SA pattern") + if !isValidSessionID(sessionID) { + h.logger.Warn().Str("session_id", sessionID).Msg("token request: decrypted value does not match session ID pattern") http.Error(w, "forbidden", http.StatusForbidden) return } apiToken, err := h.tokenProvider.Token(r.Context()) if err != nil { - h.logger.Error().Err(err).Str("username", username).Msg("token request: failed to mint API token") + h.logger.Error().Err(err).Str("session_id", sessionID).Msg("token request: failed to mint API token") http.Error(w, "internal server error", http.StatusInternalServerError) return } - h.logger.Info().Str("username", username).Msg("token request: issued fresh API token") + h.logger.Info().Str("session_id", sessionID).Msg("token request: issued fresh API token") resp := tokenResponse{Token: apiToken} w.Header().Set("Content-Type", "application/json") @@ -74,50 +66,39 @@ func (h *handler) handleToken(w http.ResponseWriter, r *http.Request) { } } -func (h *handler) validateSAToken(ctx context.Context, token string) (string, error) { - ctx, cancel := context.WithTimeout(ctx, tokenReviewTimeout) - defer cancel() - - tr := &authv1.TokenReview{ - Spec: authv1.TokenReviewSpec{ - Token: token, - }, - } - - result, err := h.k8sClient.AuthenticationV1().TokenReviews().Create(ctx, tr, metav1.CreateOptions{}) +func (h *handler) decryptSessionID(ciphertext string) (string, error) { + ciphertextBytes, err := base64.StdEncoding.DecodeString(ciphertext) if err != nil { - return "", fmt.Errorf("TokenReview API call failed: %w", err) + return "", fmt.Errorf("base64-decoding bearer token: %w", err) } - if !result.Status.Authenticated { - return "", fmt.Errorf("token not authenticated: %s", result.Status.Error) + plaintext, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, h.privateKey, ciphertextBytes, nil) + if err != nil { + return "", fmt.Errorf("RSA decryption failed: %w", err) } - - return result.Status.User.Username, nil + return string(plaintext), nil } -func isRunnerSA(username string) bool { - if !strings.HasPrefix(username, runnerSAPrefix) { - return false - } - rest := strings.TrimPrefix(username, runnerSAPrefix) - idx := strings.Index(rest, sessionSAInfix) - if idx < 0 { - return false - } - return strings.HasSuffix(rest, sessionSASuffix) +func isValidSessionID(sessionID string) bool { + return len(sessionID) >= 8 && !strings.ContainsAny(sessionID, " \t\n\r") } func extractBearerToken(r *http.Request) (string, error) { - auth := r.Header.Get("Authorization") - if auth == "" { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { return "", fmt.Errorf("Authorization header missing") } - if !strings.HasPrefix(auth, "Bearer ") { + if !strings.HasPrefix(authHeader, "Bearer ") { return "", fmt.Errorf("Authorization header must use Bearer scheme") } - token := strings.TrimPrefix(auth, "Bearer ") + token := strings.TrimPrefix(authHeader, "Bearer ") if token == "" { return "", fmt.Errorf("empty bearer token") } return token, nil } + +func handleHealthz(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) +} + diff --git a/components/ambient-control-plane/internal/tokenserver/handler_test.go b/components/ambient-control-plane/internal/tokenserver/handler_test.go new file mode 100644 index 000000000..467313e70 --- /dev/null +++ b/components/ambient-control-plane/internal/tokenserver/handler_test.go @@ -0,0 +1,162 @@ +package tokenserver + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "net/http" + "net/http/httptest" + "testing" + + "github.com/rs/zerolog" +) + +type staticTokenProvider struct{ token string } + +func (s *staticTokenProvider) Token(_ context.Context) (string, error) { + return s.token, nil +} + +func newTestHandler(t *testing.T) (*handler, *rsa.PrivateKey) { + t.Helper() + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generating RSA key: %v", err) + } + h := &handler{ + tokenProvider: &staticTokenProvider{token: "test-api-token"}, + privateKey: privKey, + logger: zerolog.Nop(), + } + return h, privKey +} + +func encryptSessionID(t *testing.T, pubKey *rsa.PublicKey, sessionID string) string { + t.Helper() + ciphertext, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, pubKey, []byte(sessionID), nil) + if err != nil { + t.Fatalf("encrypting session ID: %v", err) + } + return base64.StdEncoding.EncodeToString(ciphertext) +} + +func TestHandleToken_Success(t *testing.T) { + h, privKey := newTestHandler(t) + bearer := encryptSessionID(t, &privKey.PublicKey, "abc123session") + + req := httptest.NewRequest(http.MethodGet, "/token", nil) + req.Header.Set("Authorization", "Bearer "+bearer) + rr := httptest.NewRecorder() + + h.handleToken(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("status: got %d, want %d — body: %s", rr.Code, http.StatusOK, rr.Body.String()) + } +} + +func TestHandleToken_MissingAuthHeader(t *testing.T) { + h, _ := newTestHandler(t) + req := httptest.NewRequest(http.MethodGet, "/token", nil) + rr := httptest.NewRecorder() + + h.handleToken(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("status: got %d, want %d", rr.Code, http.StatusUnauthorized) + } +} + +func TestHandleToken_WrongBearerScheme(t *testing.T) { + h, _ := newTestHandler(t) + req := httptest.NewRequest(http.MethodGet, "/token", nil) + req.Header.Set("Authorization", "Basic dXNlcjpwYXNz") + rr := httptest.NewRecorder() + + h.handleToken(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("status: got %d, want %d", rr.Code, http.StatusUnauthorized) + } +} + +func TestHandleToken_InvalidBase64(t *testing.T) { + h, _ := newTestHandler(t) + req := httptest.NewRequest(http.MethodGet, "/token", nil) + req.Header.Set("Authorization", "Bearer not-valid-base64!!!") + rr := httptest.NewRecorder() + + h.handleToken(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("status: got %d, want %d", rr.Code, http.StatusUnauthorized) + } +} + +func TestHandleToken_WrongKey(t *testing.T) { + h, _ := newTestHandler(t) + + otherKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generating other RSA key: %v", err) + } + bearer := encryptSessionID(t, &otherKey.PublicKey, "abc123session") + + req := httptest.NewRequest(http.MethodGet, "/token", nil) + req.Header.Set("Authorization", "Bearer "+bearer) + rr := httptest.NewRecorder() + + h.handleToken(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("status: got %d, want %d", rr.Code, http.StatusUnauthorized) + } +} + +func TestHandleToken_MethodNotAllowed(t *testing.T) { + h, _ := newTestHandler(t) + req := httptest.NewRequest(http.MethodPost, "/token", nil) + rr := httptest.NewRecorder() + + h.handleToken(rr, req) + + if rr.Code != http.StatusMethodNotAllowed { + t.Errorf("status: got %d, want %d", rr.Code, http.StatusMethodNotAllowed) + } +} + +func TestIsValidSessionID(t *testing.T) { + cases := []struct { + id string + valid bool + }{ + {"abc12345", true}, + {"3BurtLWQNFMLp61XAGFKILYiHoN", true}, + {"short", false}, + {"has space", false}, + {"has\nnewline", false}, + {"", false}, + } + for _, tc := range cases { + got := isValidSessionID(tc.id) + if got != tc.valid { + t.Errorf("isValidSessionID(%q) = %v, want %v", tc.id, got, tc.valid) + } + } +} + +func TestDecryptSessionID_RoundTrip(t *testing.T) { + h, privKey := newTestHandler(t) + want := "my-session-id-xyz" + bearer := encryptSessionID(t, &privKey.PublicKey, want) + + got, err := h.decryptSessionID(bearer) + if err != nil { + t.Fatalf("decryptSessionID() error: %v", err) + } + if got != want { + t.Errorf("decryptSessionID() = %q, want %q", got, want) + } +} diff --git a/components/ambient-control-plane/internal/tokenserver/server.go b/components/ambient-control-plane/internal/tokenserver/server.go index 25f5f3ce5..c69d98925 100644 --- a/components/ambient-control-plane/internal/tokenserver/server.go +++ b/components/ambient-control-plane/internal/tokenserver/server.go @@ -2,14 +2,13 @@ package tokenserver import ( "context" + "crypto/rsa" "fmt" "net/http" "time" "github.com/ambient-code/platform/components/ambient-control-plane/internal/auth" "github.com/rs/zerolog" - "k8s.io/client-go/kubernetes" - "k8s.io/client-go/rest" ) const ( @@ -28,17 +27,12 @@ type Server struct { func New( listenAddr string, tokenProvider auth.TokenProvider, - k8sConfig *rest.Config, + privateKey *rsa.PrivateKey, logger zerolog.Logger, ) (*Server, error) { - k8sClient, err := kubernetes.NewForConfig(k8sConfig) - if err != nil { - return nil, fmt.Errorf("creating k8s client for token server: %w", err) - } - h := &handler{ tokenProvider: tokenProvider, - k8sClient: k8sClient, + privateKey: privateKey, logger: logger.With().Str("component", "tokenserver").Logger(), } @@ -79,8 +73,3 @@ func (s *Server) Start(ctx context.Context) error { return fmt.Errorf("token server: %w", err) } } - -func handleHealthz(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) -} diff --git a/components/manifests/overlays/mpp-openshift/ambient-cp-token-netpol.yaml b/components/manifests/overlays/mpp-openshift/ambient-cp-token-netpol.yaml new file mode 100644 index 000000000..aa11c728d --- /dev/null +++ b/components/manifests/overlays/mpp-openshift/ambient-cp-token-netpol.yaml @@ -0,0 +1,21 @@ +apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: allow-runner-token-fetch + namespace: ambient-code--runtime-int + labels: + app: ambient-control-plane +spec: + podSelector: + matchLabels: + app: ambient-control-plane + ingress: + - from: + - namespaceSelector: + matchLabels: + tenant.paas.redhat.com/tenant: ambient-code + ports: + - protocol: TCP + port: 8080 + policyTypes: + - Ingress diff --git a/components/manifests/overlays/mpp-openshift/kustomization.yaml b/components/manifests/overlays/mpp-openshift/kustomization.yaml index 5f38a1c27..de7217cf3 100644 --- a/components/manifests/overlays/mpp-openshift/kustomization.yaml +++ b/components/manifests/overlays/mpp-openshift/kustomization.yaml @@ -9,6 +9,7 @@ resources: - ambient-api-server.yaml - ambient-control-plane.yaml - ambient-control-plane-svc.yaml +- ambient-cp-token-netpol.yaml - ambient-api-server-route.yaml - ambient-control-plane-sa.yaml - tenant-rbac/ diff --git a/components/runners/ambient-runner/ambient_runner/_grpc_client.py b/components/runners/ambient-runner/ambient_runner/_grpc_client.py index a46662bf6..2f7782af2 100644 --- a/components/runners/ambient-runner/ambient_runner/_grpc_client.py +++ b/components/runners/ambient-runner/ambient_runner/_grpc_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 import json import logging import os @@ -10,6 +11,9 @@ from pathlib import Path from typing import Optional +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding + import grpc logger = logging.getLogger(__name__) @@ -17,6 +21,8 @@ _ENV_GRPC_URL = "AMBIENT_GRPC_URL" _ENV_TOKEN = "BOT_TOKEN" _ENV_CP_TOKEN_URL = "AMBIENT_CP_TOKEN_URL" +_ENV_CP_TOKEN_PUBLIC_KEY = "AMBIENT_CP_TOKEN_PUBLIC_KEY" +_ENV_SESSION_ID = "SESSION_ID" _ENV_USE_TLS = "AMBIENT_GRPC_USE_TLS" _ENV_CA_CERT = "AMBIENT_GRPC_CA_CERT_FILE" _DEFAULT_GRPC_URL = "ambient-api-server:9000" @@ -28,8 +34,22 @@ _CP_TOKEN_FETCH_TIMEOUT = 10 +def _encrypt_session_id(public_key_pem: str, session_id: str) -> str: + """RSA-OAEP encrypt session_id with the CP public key, return base64-encoded ciphertext.""" + public_key = serialization.load_pem_public_key(public_key_pem.encode()) + ciphertext = public_key.encrypt( + session_id.encode(), + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None, + ), + ) + return base64.b64encode(ciphertext).decode() + + def _validate_cp_token_url(url: str) -> None: - """Reject non-http(s) or credential-bearing URLs to prevent SA token exfiltration.""" + """Reject non-http(s) or credential-bearing URLs to prevent exfiltration.""" parsed = urllib.parse.urlparse(url) if ( parsed.scheme not in {"http", "https"} @@ -42,18 +62,15 @@ def _validate_cp_token_url(url: str) -> None: ) -def _fetch_token_from_cp(cp_token_url: str) -> str: - """Fetch a fresh API token from the CP /token endpoint using the pod SA token. +def _fetch_token_from_cp(cp_token_url: str, public_key_pem: str, session_id: str) -> str: + """Fetch a fresh API token from the CP /token endpoint. - Retries up to _CP_TOKEN_FETCH_ATTEMPTS times with exponential backoff - to handle transient CP unavailability. + Encrypts the session ID with the CP public key and sends it as a Bearer token. + Retries up to _CP_TOKEN_FETCH_ATTEMPTS times with exponential backoff. """ _validate_cp_token_url(cp_token_url) - try: - sa_token = _SA_TOKEN_FILE.read_text().strip() - except OSError as e: - raise RuntimeError(f"cannot read SA token from {_SA_TOKEN_FILE}: {e}") from e + bearer = _encrypt_session_id(public_key_pem, session_id) last_err: Exception = RuntimeError("no attempts made") for attempt in range(_CP_TOKEN_FETCH_ATTEMPTS): @@ -70,7 +87,7 @@ def _fetch_token_from_cp(cp_token_url: str) -> str: try: req = urllib.request.Request( cp_token_url, - headers={"Authorization": f"Bearer {sa_token}"}, + headers={"Authorization": f"Bearer {bearer}"}, ) with urllib.request.urlopen(req, timeout=_CP_TOKEN_FETCH_TIMEOUT) as resp: body = json.loads(resp.read()) @@ -164,10 +181,16 @@ def from_env(cls) -> AmbientGRPCClient: use_tls = os.environ.get(_ENV_USE_TLS, "").lower() in ("true", "1", "yes") ca_cert_file = os.environ.get(_ENV_CA_CERT) if cp_token_url: + public_key_pem = os.environ.get(_ENV_CP_TOKEN_PUBLIC_KEY, "") + session_id = os.environ.get(_ENV_SESSION_ID, "") + if not public_key_pem: + raise RuntimeError("AMBIENT_CP_TOKEN_PUBLIC_KEY env var is required when AMBIENT_CP_TOKEN_URL is set") + if not session_id: + raise RuntimeError("SESSION_ID env var is required when AMBIENT_CP_TOKEN_URL is set") logger.info( "[GRPC CLIENT] Fetching token from CP endpoint: url=%s", cp_token_url ) - token = _fetch_token_from_cp(cp_token_url) + token = _fetch_token_from_cp(cp_token_url, public_key_pem, session_id) else: token = os.environ.get(_ENV_TOKEN, "") logger.info("[GRPC CLIENT] Using BOT_TOKEN env var (local dev mode)") @@ -188,7 +211,9 @@ def from_env(cls) -> AmbientGRPCClient: def reconnect(self) -> None: """Close the existing channel and rebuild with a fresh token from the CP endpoint.""" if self._cp_token_url: - fresh_token = _fetch_token_from_cp(self._cp_token_url) + public_key_pem = os.environ.get(_ENV_CP_TOKEN_PUBLIC_KEY, "") + session_id = os.environ.get(_ENV_SESSION_ID, "") + fresh_token = _fetch_token_from_cp(self._cp_token_url, public_key_pem, session_id) else: fresh_token = os.environ.get(_ENV_TOKEN, "") logger.info( diff --git a/components/runners/ambient-runner/pyproject.toml b/components/runners/ambient-runner/pyproject.toml index 68aa69058..7af247611 100644 --- a/components/runners/ambient-runner/pyproject.toml +++ b/components/runners/ambient-runner/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "aiohttp>=3.8.0", "requests>=2.31.0", "pyjwt>=2.8.0", + "cryptography>=42.0.0", "grpcio>=1.60.0", "protobuf>=4.25.0", ] diff --git a/components/runners/ambient-runner/tests/test_grpc_client.py b/components/runners/ambient-runner/tests/test_grpc_client.py new file mode 100644 index 000000000..baf379fd2 --- /dev/null +++ b/components/runners/ambient-runner/tests/test_grpc_client.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +import base64 +import json +import os +from unittest.mock import MagicMock, patch + +import pytest +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa + +from ambient_runner._grpc_client import ( + _encrypt_session_id, + _fetch_token_from_cp, + _validate_cp_token_url, +) + + +def generate_keypair(): + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + public_pem = private_key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode() + return private_key, private_pem, public_pem + + +class TestValidateCPTokenURL: + def test_valid_http(self): + _validate_cp_token_url("http://ambient-control-plane.svc:8080/token") + + def test_valid_https(self): + _validate_cp_token_url("https://ambient-control-plane.svc:8080/token") + + def test_rejects_ftp(self): + with pytest.raises(RuntimeError, match="invalid CP token URL"): + _validate_cp_token_url("ftp://example.com/token") + + def test_rejects_file(self): + with pytest.raises(RuntimeError, match="invalid CP token URL"): + _validate_cp_token_url("file:///etc/passwd") + + def test_rejects_credentials_in_url(self): + with pytest.raises(RuntimeError, match="invalid CP token URL"): + _validate_cp_token_url("http://user:pass@example.com/token") + + def test_rejects_empty(self): + with pytest.raises(RuntimeError, match="invalid CP token URL"): + _validate_cp_token_url("") + + def test_rejects_no_host(self): + with pytest.raises(RuntimeError, match="invalid CP token URL"): + _validate_cp_token_url("http:///token") + + +class TestEncryptSessionID: + def test_produces_base64_ciphertext(self): + _, _, public_pem = generate_keypair() + result = _encrypt_session_id(public_pem, "my-session-id") + decoded = base64.b64decode(result) + assert len(decoded) > 0 + + def test_decryptable_with_private_key(self): + private_key, _, public_pem = generate_keypair() + session_id = "3BurtLWQNFMLp61XAGFKILYiHoN" + + ciphertext_b64 = _encrypt_session_id(public_pem, session_id) + ciphertext = base64.b64decode(ciphertext_b64) + + plaintext = private_key.decrypt( + ciphertext, + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None, + ), + ) + assert plaintext.decode() == session_id + + def test_different_ciphertexts_for_same_input(self): + _, _, public_pem = generate_keypair() + result1 = _encrypt_session_id(public_pem, "session-abc") + result2 = _encrypt_session_id(public_pem, "session-abc") + assert result1 != result2 + + def test_invalid_public_key_raises(self): + with pytest.raises(Exception): + _encrypt_session_id("not a pem key", "session-id") + + +class TestFetchTokenFromCP: + def _mock_successful_response(self, token: str = "api-token-xyz"): + import urllib.request + + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps({"token": token}).encode() + mock_resp.__enter__ = MagicMock(return_value=mock_resp) + mock_resp.__exit__ = MagicMock(return_value=False) + return mock_resp + + def test_success(self): + _, _, public_pem = generate_keypair() + mock_resp = self._mock_successful_response("test-api-token") + + with patch("urllib.request.urlopen", return_value=mock_resp): + token = _fetch_token_from_cp( + "http://cp.svc:8080/token", public_pem, "session-12345678" + ) + + assert token == "test-api-token" + + def test_sends_encrypted_bearer(self): + _, _, public_pem = generate_keypair() + mock_resp = self._mock_successful_response() + captured_req = {} + + def fake_urlopen(req, timeout=None): + captured_req["req"] = req + return mock_resp + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + _fetch_token_from_cp("http://cp.svc:8080/token", public_pem, "session-abc") + + auth = captured_req["req"].get_header("Authorization") + assert auth.startswith("Bearer ") + b64_part = auth[len("Bearer "):] + decoded = base64.b64decode(b64_part) + assert len(decoded) > 0 + + def test_retries_on_failure_then_succeeds(self): + _, _, public_pem = generate_keypair() + mock_resp = self._mock_successful_response() + import urllib.error + + call_count = [0] + + def fake_urlopen(req, timeout=None): + call_count[0] += 1 + if call_count[0] < 3: + raise urllib.error.URLError("connection refused") + return mock_resp + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + with patch("time.sleep"): + token = _fetch_token_from_cp( + "http://cp.svc:8080/token", public_pem, "session-12345678" + ) + + assert token == "api-token-xyz" + assert call_count[0] == 3 + + def test_raises_after_all_attempts_fail(self): + _, _, public_pem = generate_keypair() + import urllib.error + + with patch("urllib.request.urlopen", side_effect=urllib.error.URLError("refused")): + with patch("time.sleep"): + with pytest.raises(RuntimeError, match="CP token endpoint unreachable"): + _fetch_token_from_cp( + "http://cp.svc:8080/token", public_pem, "session-12345678" + ) + + def test_includes_http_error_body_in_exception(self): + _, _, public_pem = generate_keypair() + import urllib.error + + err_body = b"unauthorized: invalid token" + http_err = urllib.error.HTTPError( + url="http://cp.svc:8080/token", + code=401, + msg="Unauthorized", + hdrs=None, + fp=MagicMock(read=MagicMock(return_value=err_body)), + ) + + with patch("urllib.request.urlopen", side_effect=http_err): + with patch("time.sleep"): + with pytest.raises(RuntimeError, match="CP /token HTTP 401"): + _fetch_token_from_cp( + "http://cp.svc:8080/token", public_pem, "session-12345678" + ) + + def test_raises_on_missing_token_field(self): + _, _, public_pem = generate_keypair() + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps({"other": "field"}).encode() + mock_resp.__enter__ = MagicMock(return_value=mock_resp) + mock_resp.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", return_value=mock_resp): + with patch("time.sleep"): + with pytest.raises(RuntimeError, match="missing 'token' field"): + _fetch_token_from_cp( + "http://cp.svc:8080/token", public_pem, "session-12345678" + ) + + +class TestFromEnvIntegration: + def test_uses_encrypted_session_id_when_cp_token_url_set(self): + _, _, public_pem = generate_keypair() + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps({"token": "env-token"}).encode() + mock_resp.__enter__ = MagicMock(return_value=mock_resp) + mock_resp.__exit__ = MagicMock(return_value=False) + + env = { + "AMBIENT_GRPC_URL": "localhost:9000", + "AMBIENT_CP_TOKEN_URL": "http://cp.svc:8080/token", + "AMBIENT_CP_TOKEN_PUBLIC_KEY": public_pem, + "SESSION_ID": "session-test-1234", + "AMBIENT_GRPC_USE_TLS": "false", + } + + with patch.dict(os.environ, env, clear=False): + with patch("urllib.request.urlopen", return_value=mock_resp): + from ambient_runner._grpc_client import AmbientGRPCClient + + client = AmbientGRPCClient.from_env() + + assert client._token == "env-token" + + def test_falls_back_to_bot_token_when_no_cp_url(self): + env = { + "AMBIENT_GRPC_URL": "localhost:9000", + "BOT_TOKEN": "static-bot-token", + "AMBIENT_GRPC_USE_TLS": "false", + } + env_without_cp = {k: v for k, v in env.items()} + + with patch.dict(os.environ, env_without_cp, clear=False): + with patch.dict(os.environ, {"AMBIENT_CP_TOKEN_URL": ""}, clear=False): + from ambient_runner._grpc_client import AmbientGRPCClient + + client = AmbientGRPCClient.from_env() + + assert client._token == "static-bot-token" diff --git a/components/runners/ambient-runner/uv.lock b/components/runners/ambient-runner/uv.lock index 67b949620..39bc614f2 100644 --- a/components/runners/ambient-runner/uv.lock +++ b/components/runners/ambient-runner/uv.lock @@ -145,6 +145,7 @@ source = { editable = "." } dependencies = [ { name = "ag-ui-protocol" }, { name = "aiohttp" }, + { name = "cryptography" }, { name = "fastapi" }, { name = "grpcio" }, { name = "protobuf" }, @@ -189,6 +190,7 @@ requires-dist = [ { name = "ambient-runner", extras = ["claude", "observability", "mcp-atlassian"], marker = "extra == 'all'" }, { name = "anthropic", extras = ["vertex"], marker = "extra == 'claude'", specifier = ">=0.86.0" }, { name = "claude-agent-sdk", marker = "extra == 'claude'", specifier = ">=0.1.50" }, + { name = "cryptography", specifier = ">=42.0.0" }, { name = "fastapi", specifier = ">=0.100.0" }, { name = "grpcio", specifier = ">=1.60.0" }, { name = "langfuse", marker = "extra == 'observability'", specifier = ">=3.0.0" },