Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/kafka-proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func initFlags() {
// Connect through Socks5 or HTTP CONNECT to Kafka
Server.Flags().StringVar(&c.ForwardProxy.Url, "forward-proxy", "", "URL of the forward proxy. Supported schemas are socks5 and http")

viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
viper.SetEnvKeyReplacer(strings.NewReplacer("_", "-"))
viper.AutomaticEnv() // read in environment variables that match
}

Expand Down
3 changes: 3 additions & 0 deletions proxy/protocol/request_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ func (r *Request) decode(pd packetDecoder) (err error) {
if version, err = pd.getInt16(); err != nil {
return err
}
if version == -1 {
version = r.Body.version()
}
if r.Body.key() != key || r.Body.version() != version {
return PacketDecodingError{fmt.Sprintf("expected request key,version %d,%d but got %d,%d", r.Body.key(), r.Body.version(), key, version)}
}
Expand Down
3 changes: 3 additions & 0 deletions proxy/protocol/request_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ func (r *RequestV2) decode(pd packetDecoder) (err error) {
if version, err = pd.getInt16(); err != nil {
return err
}
if version == -1 {
version = r.Body.version()
}
if r.Body.key() != key || r.Body.version() != version {
return PacketDecodingError{fmt.Sprintf("expected request key,version %d,%d but got %d,%d", r.Body.key(), r.Body.version(), key, version)}
}
Expand Down
92 changes: 92 additions & 0 deletions proxy/protocol/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
const (
apiKeyMetadata = 3
apiKeyFindCoordinator = 10
apiKeySaslHandshake = 17
apiKeyApiVersions = 18

brokersKeyName = "brokers"
hostKeyName = "host"
Expand All @@ -23,8 +25,62 @@ const (
var (
metadataResponseSchemaVersions = createMetadataResponseSchemaVersions()
findCoordinatorResponseSchemaVersions = createFindCoordinatorResponseSchemaVersions()
apiVersionsResponseSchemaVersions = createApiVersionsResponseSchemaVersions()
)

func createApiVersionsResponseSchemaVersions() []Schema {
apiVersionV0 := NewSchema("api_version",
&Mfield{Name: "api_key", Ty: TypeInt16},
&Mfield{Name: "min_version", Ty: TypeInt16},
&Mfield{Name: "max_version", Ty: TypeInt16},
)

// Version 0: error_code + api_keys
apiVersionsResponseV0 := NewSchema("api_versions_response_v0",
&Mfield{Name: "error_code", Ty: TypeInt16},
&Array{Name: "api_keys", Ty: apiVersionV0},
)

// Version 1: error_code + api_keys + throttle_time_ms
apiVersionsResponseV1 := NewSchema("api_versions_response_v1",
&Mfield{Name: "error_code", Ty: TypeInt16},
&Array{Name: "api_keys", Ty: apiVersionV0},
&Mfield{Name: "throttle_time_ms", Ty: TypeInt32},
)

// Version 2: Same as version 1
apiVersionsResponseV2 := apiVersionsResponseV1

// ApiVersion struct for flexible versions (v3+) with compact arrays
apiVersionV3 := NewSchema("api_version_v3",
&Mfield{Name: "api_key", Ty: TypeInt16},
&Mfield{Name: "min_version", Ty: TypeInt16},
&Mfield{Name: "max_version", Ty: TypeInt16},
&SchemaTaggedFields{Name: "api_version_tagged_fields"},
)

// Version 3: Flexible version with tagged fields
// Tagged fields: supported_features (tag 0), finalized_features_epoch (tag 1),
// finalized_features (tag 2), zk_migration_ready (tag 3)
apiVersionsResponseV3 := NewSchema("api_versions_response_v3",
&Mfield{Name: "error_code", Ty: TypeInt16},
&CompactArray{Name: "api_keys", Ty: apiVersionV3},
&Mfield{Name: "throttle_time_ms", Ty: TypeInt32},
&SchemaTaggedFields{Name: "response_tagged_fields"},
)

// Version 4: Same as version 3
apiVersionsResponseV4 := apiVersionsResponseV3

return []Schema{
apiVersionsResponseV0,
apiVersionsResponseV1,
apiVersionsResponseV2,
apiVersionsResponseV3,
apiVersionsResponseV4,
}
}

func createMetadataResponseSchemaVersions() []Schema {
metadataBrokerV0 := NewSchema("metadata_broker_v0",
&Mfield{Name: nodeKeyName, Ty: TypeInt32},
Expand Down Expand Up @@ -325,6 +381,40 @@ func createFindCoordinatorResponseSchemaVersions() []Schema {
return []Schema{findCoordinatorResponseV0, findCoordinatorResponseV1, findCoordinatorResponseV2, findCoordinatorResponseV3, findCoordinatorResponseV4, findCoordinatorResponseV5, findCoordinatorResponseV6}
}

func modifyApiVersionsResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc) error {
if decodedStruct == nil {
return errors.New("decoded struct must not be nil")
}

versions, ok := decodedStruct.Get("api_keys").([]any)
if !ok || len(versions) == 0 {
return errors.New("versions not found")
}
for _, versionElement := range versions {
version := versionElement.(*Struct)
if version.Get("api_key").(int16) == apiKeySaslHandshake {
return nil
}
}

schema := versions[0].(*Struct).GetSchema()

// v1 Sasl auth does not seem to work with KafkaJS so pin to v0
values := []any{int16(17), int16(0), int16(0)}

// version 3+ of the api versions response
if len(schema.GetFields()) > 3 {
values = append(values, []rawTaggedField{})
}

versions = append(versions, &Struct{
Schema: schema,
Values: values,
})

return decodedStruct.Replace("api_keys", versions)
}

func modifyMetadataResponse(decodedStruct *Struct, fn config.NetAddressMappingFunc) error {
if decodedStruct == nil {
return errors.New("decoded struct must not be nil")
Expand Down Expand Up @@ -467,6 +557,8 @@ func (f *responseModifier) Apply(resp []byte) ([]byte, error) {

func GetResponseModifier(apiKey int16, apiVersion int16, addressMappingFunc config.NetAddressMappingFunc) (ResponseModifier, error) {
switch apiKey {
case apiKeyApiVersions:
return newResponseModifier(apiKey, apiVersion, addressMappingFunc, apiVersionsResponseSchemaVersions, modifyApiVersionsResponse)
case apiKeyMetadata:
return newResponseModifier(apiKey, apiVersion, addressMappingFunc, metadataResponseSchemaVersions, modifyMetadataResponse)
case apiKeyFindCoordinator:
Expand Down
9 changes: 7 additions & 2 deletions proxy/sasl_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"bytes"
"encoding/binary"
"fmt"
"io"
"time"

"github.com/grepplabs/kafka-proxy/pkg/apis"
"github.com/grepplabs/kafka-proxy/proxy/protocol"
"github.com/pkg/errors"
"io"
"time"
)

type LocalSasl struct {
Expand Down Expand Up @@ -146,6 +147,10 @@ func (p *LocalSasl) receiveAndSendAuthV1(conn DeadlineReaderWriter, localSaslAut
return errors.Errorf("SaslAuthenticate is expected, but got apiKey %d", requestKeyVersion.ApiKey)
}

if requestKeyVersion.ApiVersion == -1 {
requestKeyVersion.ApiVersion = 1
}

if requestKeyVersion.Length > protocol.MaxRequestSize {
return protocol.PacketDecodingError{Info: fmt.Sprintf("sasl authenticate message of length %d too large", requestKeyVersion.Length)}
}
Expand Down