diff --git a/config.go b/config.go index f5b45cc..f2f129f 100644 --- a/config.go +++ b/config.go @@ -40,10 +40,26 @@ func (c *Config) InitDefaults() { } } +// dsn is a parsed "scheme://address" RPC listen string. +type dsn struct { + scheme string + addr string +} + +// parseDSN splits a "scheme://address" listen string into its scheme and +// address. It errors unless the string contains exactly one "://" separator. +func parseDSN(listen string) (dsn, error) { + scheme, addr, ok := strings.Cut(listen, "://") + if !ok || strings.Contains(addr, "://") { + return dsn{}, errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)") + } + return dsn{scheme: scheme, addr: addr}, nil +} + // Valid returns nil if config is valid. func (c *Config) Valid() error { - if dsn := strings.Split(c.Listen, "://"); len(dsn) != 2 { - return errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)") + if _, err := parseDSN(c.Listen); err != nil { + return err } if c.RequestTimeout < 0 { return errors.New("rpc request_timeout must be non-negative") @@ -63,10 +79,10 @@ func (c *Config) Listener() (net.Listener, error) { // Dialer creates rpc socket Dialer. func (c *Config) Dialer() (net.Conn, error) { - dsn := strings.Split(c.Listen, "://") - if len(dsn) != 2 { - return nil, errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)") + parsed, err := parseDSN(c.Listen) + if err != nil { + return nil, err } var d net.Dialer - return d.DialContext(context.Background(), dsn[0], dsn[1]) + return d.DialContext(context.Background(), parsed.scheme, parsed.addr) } diff --git a/go.work.sum b/go.work.sum index f0a183d..fce1f3f 100644 --- a/go.work.sum +++ b/go.work.sum @@ -364,6 +364,8 @@ golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOM golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6 h1:QE6XYQK6naiK1EPAe1g/ILLxN5RBoH5xkJk3CqlMI/Y= golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4= golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5 h1:2M3HP5CCK1Si9FQhwnzYhXdG6DXeebvUHFpre8QvbyI= @@ -416,6 +418,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= +golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= golang.org/x/time v0.1.0 h1:xYY+Bajn2a7VBmTM5GikTmnK8ZuX8YgnQCqZpbBNtmA= golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= diff --git a/plugin.go b/plugin.go index e590aff..526d1a1 100644 --- a/plugin.go +++ b/plugin.go @@ -85,13 +85,13 @@ func (s *Plugin) Init(cfg Configurer, log Logger) error { return errors.E(op, err) } - var WholeCfg any - err = cfg.Unmarshal(&WholeCfg) + var wholeCfg any + err = cfg.Unmarshal(&wholeCfg) if err != nil { return errors.E(op, err) } - s.wcfg, err = json.Marshal(WholeCfg) + s.wcfg, err = json.Marshal(wholeCfg) if err != nil { return err } @@ -125,10 +125,7 @@ func (s *Plugin) Serve() chan error { mux.Handle(path, handler) // derive the gRPC service name from the mount path // (`//` or `//`) - svc := strings.TrimPrefix(path, "/") - if i := strings.Index(svc, "/"); i >= 0 { - svc = svc[:i] - } + svc, _, _ := strings.Cut(strings.TrimPrefix(path, "/"), "/") services = append(services, svc) } diff --git a/tests/config_test.go b/tests/config_test.go index 6d23f6d..9482e4e 100644 --- a/tests/config_test.go +++ b/tests/config_test.go @@ -156,6 +156,18 @@ func Test_Config_DialerErrorMethod(t *testing.T) { assert.Error(t, err) } +func Test_Config_MultipleSeparators(t *testing.T) { + // A DSN with more than one "://" must be rejected by both Valid and Dialer. + cfg := &rpc.Config{Listen: "tcp://host://6001"} + + assert.Error(t, cfg.Valid()) + + conn, err := cfg.Dialer() + assert.Nil(t, conn) + assert.Error(t, err) + assert.Equal(t, "invalid socket DSN (tcp://:6001, unix://file.sock)", err.Error()) +} + func Test_Config_Defaults(t *testing.T) { c := &rpc.Config{} c.InitDefaults()