@@ -16,6 +16,7 @@ package proxy
1616
1717import (
1818 "context"
19+ "errors"
1920 "fmt"
2021 "io"
2122 "net"
@@ -637,7 +638,7 @@ func (c *Client) Serve(ctx context.Context, notify func()) error {
637638 for _ , m := range c .mnts {
638639 go func (mnt * socketMount ) {
639640 err := c .serveSocketMount (ctx , mnt )
640- if err != nil {
641+ if err != nil && errors . Is () {
641642 select {
642643 // Best effort attempt to send error.
643644 // If this send fails, it means the reading goroutine has
@@ -731,50 +732,63 @@ func (c *Client) Close() error {
731732
732733// serveSocketMount persistently listens to the socketMounts listener and proxies connections to a
733734// given Cloud SQL instance.
734- func (c * Client ) serveSocketMount (_ context.Context , s * socketMount ) error {
735+ func (c * Client ) serveSocketMount (ctx context.Context , s * socketMount ) error {
735736 for {
736- cConn , err := s .Accept ()
737- if err != nil {
738- if nerr , ok := err .(net.Error ); ok && nerr .Timeout () {
739- c .logger .Errorf ("[%s] Error accepting connection: %v" , s .inst , err )
740- // For transient errors, wait a small amount of time to see if it resolves itself
741- time .Sleep (10 * time .Millisecond )
742- continue
743- }
744- return err
745- }
746- // handle the connection in a separate goroutine
747- go func () {
748- c .logger .Infof ("[%s] Accepted connection from %s" , s .inst , cConn .RemoteAddr ())
749-
750- // A client has established a connection to the local socket. Before
751- // we initiate a connection to the Cloud SQL backend, increment the
752- // connection counter. If the total number of connections exceeds
753- // the maximum, refuse to connect and close the client connection.
754- count := atomic .AddUint64 (& c .connCount , 1 )
755- defer atomic .AddUint64 (& c .connCount , ^ uint64 (0 ))
756-
757- if c .conf .MaxConnections > 0 && count > c .conf .MaxConnections {
758- c .logger .Infof ("max connections (%v) exceeded, refusing new connection" , c .conf .MaxConnections )
759- if c .connRefuseNotify != nil {
760- go c .connRefuseNotify ()
737+ select {
738+ case <- ctx .Done ():
739+ // If the context was canceled, do not accept any more connections,
740+ // exit gracefully.
741+ return nil
742+ default :
743+ // Wait to accept a connection. When s.Accept() returns io.EOF, exit
744+ // gracefully.
745+ cConn , err := s .Accept ()
746+ if err != nil {
747+ if nerr , ok := err .(net.Error ); ok && nerr .Timeout () {
748+ c .logger .Errorf ("[%s] Error accepting connection: %v" , s .inst , err )
749+ // For transient errors, wait a small amount of time to see if it resolves itself
750+ time .Sleep (10 * time .Millisecond )
751+ continue
752+ } else if err == io .EOF {
753+ // The socket was closed gracefully. Stop processing connections.
754+ return nil
761755 }
762- _ = cConn .Close ()
763- return
756+ return err
764757 }
765758
766- // give a max of 30 seconds to connect to the instance
767- ctx , cancel := context .WithTimeout (context .Background (), 30 * time .Second )
768- defer cancel ()
759+ // handle the connection in a separate goroutine
760+ go func () {
761+ c .logger .Infof ("[%s] Accepted connection from %s" , s .inst , cConn .RemoteAddr ())
762+
763+ // A client has established a connection to the local socket. Before
764+ // we initiate a connection to the Cloud SQL backend, increment the
765+ // connection counter. If the total number of connections exceeds
766+ // the maximum, refuse to connect and close the client connection.
767+ count := atomic .AddUint64 (& c .connCount , 1 )
768+ defer atomic .AddUint64 (& c .connCount , ^ uint64 (0 ))
769+
770+ if c .conf .MaxConnections > 0 && count > c .conf .MaxConnections {
771+ c .logger .Infof ("max connections (%v) exceeded, refusing new connection" , c .conf .MaxConnections )
772+ if c .connRefuseNotify != nil {
773+ go c .connRefuseNotify ()
774+ }
775+ _ = cConn .Close ()
776+ return
777+ }
769778
770- sConn , err := c .dialer .Dial (ctx , s .inst , s .dialOpts ... )
771- if err != nil {
772- c .logger .Errorf ("[%s] failed to connect to instance: %v" , s .inst , err )
773- _ = cConn .Close ()
774- return
775- }
776- c .proxyConn (s .inst , cConn , sConn )
777- }()
779+ // give a max of 30 seconds to connect to the instance
780+ ctx , cancel := context .WithTimeout (context .Background (), 30 * time .Second )
781+ defer cancel ()
782+
783+ sConn , err := c .dialer .Dial (ctx , s .inst , s .dialOpts ... )
784+ if err != nil {
785+ c .logger .Errorf ("[%s] failed to connect to instance: %v" , s .inst , err )
786+ _ = cConn .Close ()
787+ return
788+ }
789+ c .proxyConn (s .inst , cConn , sConn )
790+ }()
791+ }
778792 }
779793}
780794
0 commit comments