Skip to content

Commit e3e6e71

Browse files
authored
Add named pipes and shared memory connections (#188)
* add named pipe support * update go-mssqldb * add shared memory * generalize protocol handling * update NOTICE
1 parent 578326e commit e3e6e71

File tree

12 files changed

+148
-44
lines changed

12 files changed

+148
-44
lines changed

.vscode/launch.json

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
55
"version": "0.2.0",
66
"configurations": [
7+
78

89

910
{
@@ -24,16 +25,16 @@
2425
"type" : "go",
2526
"request": "launch",
2627
"mode" : "auto",
27-
"program": "${workspaceFolder}/cmd/sqlcmd",
28-
"args" : ["-Q", "EXIT(select 100 as Count)"],
28+
"program": "${workspaceFolder}/cmd/modern",
29+
"args" : ["-Q", "EXIT(select net_transport from sys.dm_exec_connections)"],
2930
},
3031
{
3132
"name" : "Run file query",
3233
"type" : "go",
3334
"request": "launch",
3435
"mode" : "auto",
35-
"program": "${workspaceFolder}/cmd/sqlcmd",
36-
"args" : ["-i", "testdata\\select100.sql"],
36+
"program": "${workspaceFolder}/cmd/modern",
37+
"args" : ["-S", "np:.", "-i", "${workspaceFolder}/cmd/sqlcmd/testdata/select100.sql"],
3738
},
3839
]
3940
}

NOTICE.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5003,6 +5003,23 @@ third-party archives.
50035003
50045004
```
50055005

5006+
## gopkg.in/natefinch/npipe.v2
5007+
5008+
* Name: gopkg.in/natefinch/npipe.v2
5009+
* Version: v2.0.0-20160621034901-c1b8fa8bdcce
5010+
* License: [MIT](https://github.com/natefinch/npipe/blob/c1b8fa8bdcce/LICENSE.txt)
5011+
5012+
```
5013+
The MIT License (MIT)
5014+
Copyright (c) 2013 npipe authors
5015+
5016+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
5017+
5018+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
5019+
5020+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
5021+
```
5022+
50065023
## gopkg.in/yaml.v2
50075024

50085025
* Name: gopkg.in/yaml.v2

README.md

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,28 @@ We will be implementing command line switches and behaviors over time. Several s
3434
- The new `--driver-logging-level` command line parameter allows you to see traces from the `go-mssqldb` client driver. Use `64` to see all traces.
3535
- Sqlcmd can now print results using a vertical format. Use the new `-F vertical` command line option to set it. It's also controlled by the `SQLCMDFORMAT` scripting variable.
3636

37+
```
38+
39+
1> select session_id, client_interface_name, program_name from sys.dm_exec_sessions where session_id=@@spid
40+
2> go
41+
session_id 58
42+
client_interface_name go-mssqldb
43+
program_name sqlcmd
44+
45+
```
46+
- Sqlcmd now supports shared memory and named pipe transport. Use the appropriate protocol prefix on the server name to force a protocol
47+
* `lpc` for shared memory, only for a localhost. `sqlcmd -S lpc:.`
48+
* `np` for named pipes. Or use the UNC named pipe path as the server name: `sqlcmd -S \\myserver\pipe\sql\query`
49+
* `tcp` for tcp `sqlcmd -S tcp:myserver,1234`
50+
If no protocol is specified, sqlcmd will attempt to dial in this order: lpc->np->tcp. If dialing a remote host, `lpc` will be skipped.
51+
52+
```
53+
1> select net_transport from sys.dm_exec_connections where session_id=@@spid
54+
2> go
55+
net_transport Named pipe
56+
57+
```
58+
3759
### Azure Active Directory Authentication
3860

3961
This version of sqlcmd supports a broader range of AAD authentication models, based on the [azidentity package](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity). The implementation relies on an AAD Connector in the [driver](https://github.com/microsoft/go-mssqldb).
@@ -105,18 +127,13 @@ pkg/sqlcmd is consumable by other hosts. Go docs for the package are forthcoming
105127

106128
## Building
107129

108-
To add version data to your build using `go-winres`, add `GOBIN` to your `PATH` then use `go generate`
109-
The version on the binary will match the version tag of the branch.
110130

111131
```sh
112132

113-
go install github.com/tc-hib/go-winres@latest
114-
cd cmd/modern
115-
go generate
133+
build/build
116134

117135
```
118136

119-
Scripts to build the binaries and package them for release will be added in a build folder off the root. We will also add Azure Devops pipeline yml files there to initiate builds and releases. Until then just use `go build ./cmd/sqlcmd` to create a sqlcmd binary.
120137

121138
## Testing
122139

cmd/sqlcmd/sqlcmd.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type SQLCmdArguments struct {
3232
InitialQuery string `short:"q" xor:"input1" help:"Executes a query when sqlcmd starts, but does not exit sqlcmd when the query has finished running. Multiple-semicolon-delimited queries can be executed."`
3333
// Query to run then exit
3434
Query string `short:"Q" xor:"input2" help:"Executes a query when sqlcmd starts and then immediately exits sqlcmd. Multiple-semicolon-delimited queries can be executed."`
35-
Server string `short:"S" help:"[tcp:]server[\\instance_name][,port]Specifies the instance of SQL Server to which to connect. It sets the sqlcmd scripting variable SQLCMDSERVER."`
35+
Server string `short:"S" help:"[[tcp:]|[lpc:]|[np:]]server[\\instance_name][,port]Specifies the instance of SQL Server to which to connect. It sets the sqlcmd scripting variable SQLCMDSERVER."`
3636
// Disable syscommands with a warning
3737
DisableCmdAndWarn bool `short:"X" xor:"syscmd" help:"Disables commands that might compromise system security. Sqlcmd issues a warning and continues."`
3838
// AuthenticationMethod is new for go-sqlcmd
@@ -291,6 +291,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
291291
// connect using no overrides
292292
err = s.ConnectDb(nil, line == nil)
293293
if err != nil {
294+
s.WriteError(s.GetError(), err)
294295
return 1, err
295296
}
296297

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ require (
6969
golang.org/x/tools v0.1.12 // indirect
7070
google.golang.org/protobuf v1.28.1 // indirect
7171
gopkg.in/ini.v1 v1.67.0 // indirect
72+
gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce // indirect
7273
gopkg.in/yaml.v3 v3.0.1 // indirect
7374
gotest.tools/v3 v3.4.0 // indirect
7475
)

pkg/sqlcmd/connect.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package sqlcmd
66
import (
77
"fmt"
88
"net/url"
9+
"strings"
910

1011
"github.com/microsoft/go-mssqldb/azuread"
1112
)
@@ -81,7 +82,7 @@ func (connect ConnectSettings) RequiresPassword() bool {
8182

8283
// ConnectionString returns the go-mssql connection string to use for queries
8384
func (connect ConnectSettings) ConnectionString() (connectionString string, err error) {
84-
serverName, instance, port, err := splitServer(connect.ServerName)
85+
serverName, instance, port, protocol, err := splitServer(connect.ServerName)
8586
if serverName == "" {
8687
serverName = "."
8788
}
@@ -100,6 +101,16 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err
100101
if (connect.authenticationMethod() == azuread.ActiveDirectoryMSI || connect.authenticationMethod() == azuread.ActiveDirectoryManagedIdentity) && connect.UserName != "" {
101102
connectionURL.User = url.UserPassword(connect.UserName, connect.Password)
102103
}
104+
105+
if strings.HasPrefix(serverName, `\\`) {
106+
// passing a pipe name of the format \\server\pipe\<pipename>
107+
pipeParts := strings.SplitN(string(serverName[2:]), `\`, 3)
108+
if len(pipeParts) != 3 {
109+
return "", &InvalidServerName
110+
}
111+
serverName = pipeParts[0]
112+
query.Add("pipe", pipeParts[2])
113+
}
103114
if port > 0 {
104115
connectionURL.Host = fmt.Sprintf("%s:%d", serverName, port)
105116
} else {
@@ -130,6 +141,9 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err
130141
if connect.LogLevel > 0 {
131142
query.Add("log", fmt.Sprint(connect.LogLevel))
132143
}
144+
if protocol != "" {
145+
query.Add("protocol", protocol)
146+
}
133147
if connect.ApplicationName != "" {
134148
query.Add(`app name`, connect.ApplicationName)
135149
}

pkg/sqlcmd/errors.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func (e *ArgumentError) IsSqlcmdErr() bool {
5252
// InvalidServerName indicates the SQLCMDSERVER variable has an incorrect format
5353
var InvalidServerName = ArgumentError{
5454
Parameter: "server",
55-
Rule: "server must be of the form [tcp]:server[[/instance]|[,port]]",
55+
Rule: "server must be of the form [[np]|[lpc][tcp]]:server[[/instance]|[,port]]",
5656
}
5757

5858
// VariableError is an error about scripting variables

pkg/sqlcmd/sqlcmd.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,15 @@ import (
2121
"github.com/golang-sql/sqlexp"
2222
mssql "github.com/microsoft/go-mssqldb"
2323
"github.com/microsoft/go-mssqldb/msdsn"
24+
_ "github.com/microsoft/go-mssqldb/namedpipe"
25+
_ "github.com/microsoft/go-mssqldb/sharedmemory"
2426
"golang.org/x/text/encoding/unicode"
2527
"golang.org/x/text/transform"
2628
)
2729

30+
// Note: The order of includes above matters for namedpipe and sharedmemory.
31+
// init() swaps shared memory protocol with tcp so it gets priority when dialing.
32+
2833
var (
2934
// ErrExitRequested tells the hosting application to exit immediately
3035
ErrExitRequested = errors.New("exit")
@@ -534,3 +539,13 @@ func (s Sqlcmd) Log(_ context.Context, _ msdsn.Log, msg string) {
534539
_, _ = s.GetOutput().Write([]byte("DRIVER:" + msg))
535540
_, _ = s.GetOutput().Write([]byte(SqlcmdEol))
536541
}
542+
543+
func init() {
544+
if len(msdsn.ProtocolParsers) == 3 {
545+
// reorder the protocol parsers to lpc->np->tcp
546+
// ODBC follows this same order.
547+
var tcp = msdsn.ProtocolParsers[0]
548+
msdsn.ProtocolParsers[0] = msdsn.ProtocolParsers[2]
549+
msdsn.ProtocolParsers[2] = tcp
550+
}
551+
}

pkg/sqlcmd/sqlcmd_test.go

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ import (
1010
"io"
1111
"os"
1212
"os/user"
13+
"runtime"
1314
"strings"
1415
"testing"
1516

1617
"github.com/microsoft/go-mssqldb/azuread"
18+
"github.com/microsoft/go-mssqldb/msdsn"
1719

1820
"github.com/google/uuid"
1921
"github.com/stretchr/testify/assert"
@@ -46,16 +48,20 @@ func TestConnectionStringFromSqlCmd(t *testing.T) {
4648
},
4749
{
4850
&ConnectSettings{TrustServerCertificate: true, UseTrustedConnection: true, Password: pwd, ServerName: `tcp:someserver,1045`, UserName: "someuser"},
49-
"sqlserver://someserver:1045?trustservercertificate=true",
51+
"sqlserver://someserver:1045?protocol=tcp&trustservercertificate=true",
5052
},
5153
{
5254
&ConnectSettings{ServerName: `tcp:someserver,1045`},
53-
"sqlserver://someserver:1045",
55+
"sqlserver://someserver:1045?protocol=tcp",
5456
},
5557
{
5658
&ConnectSettings{ServerName: "someserver", AuthenticationMethod: azuread.ActiveDirectoryServicePrincipal, UserName: "myapp@mytenant", Password: pwd},
5759
fmt.Sprintf("sqlserver://myapp%%40mytenant:%s@someserver", pwd),
5860
},
61+
{
62+
&ConnectSettings{ServerName: `\\someserver\pipe\sql\query`},
63+
"sqlserver://someserver?pipe=sql%5Cquery&protocol=np",
64+
},
5965
}
6066

6167
for i, test := range commands {
@@ -356,16 +362,20 @@ func TestPromptForPasswordNegative(t *testing.T) {
356362
}
357363
v := InitializeVariables(true)
358364
s := New(console, "", v)
365+
c := newConnect(t)
359366
s.Connect.UserName = "someuser"
367+
s.Connect.ServerName = c.ServerName
360368
err := s.ConnectDb(nil, false)
361369
assert.True(t, prompted, "Password prompt not shown for SQL auth")
362370
assert.Error(t, err, "ConnectDb")
363371
prompted = false
364-
s.Connect.AuthenticationMethod = azuread.ActiveDirectoryPassword
365-
err = s.ConnectDb(nil, false)
366-
assert.True(t, prompted, "Password prompt not shown for AD Password auth")
367-
assert.Error(t, err, "ConnectDb")
368-
prompted = false
372+
if canTestAzureAuth() {
373+
s.Connect.AuthenticationMethod = azuread.ActiveDirectoryPassword
374+
err = s.ConnectDb(nil, false)
375+
assert.True(t, prompted, "Password prompt not shown for AD Password auth")
376+
assert.Error(t, err, "ConnectDb")
377+
prompted = false
378+
}
369379
}
370380

371381
func TestPromptForPasswordPositive(t *testing.T) {
@@ -619,3 +629,12 @@ func newConnect(t testing.TB) *ConnectSettings {
619629
}
620630
return &connect
621631
}
632+
633+
func TestSqlcmdPrefersSharedMemoryProtocol(t *testing.T) {
634+
if runtime.GOOS != "windows" {
635+
t.Skip()
636+
}
637+
assert.EqualValuesf(t, "lpc", msdsn.ProtocolParsers[0].Protocol(), "lpc should be first protocol")
638+
assert.EqualValuesf(t, "np", msdsn.ProtocolParsers[1].Protocol(), "np should be second protocol")
639+
assert.EqualValuesf(t, "tcp", msdsn.ProtocolParsers[2].Protocol(), "tcp should be third protocol")
640+
}

pkg/sqlcmd/util.go

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,40 +6,57 @@ package sqlcmd
66
import (
77
"strconv"
88
"strings"
9+
10+
"github.com/microsoft/go-mssqldb/msdsn"
911
)
1012

1113
// splitServer extracts connection parameters from a server name input
12-
func splitServer(serverName string) (string, string, uint64, error) {
13-
instance := ""
14-
port := uint64(0)
15-
if strings.HasPrefix(serverName, "tcp:") {
16-
if len(serverName) == 4 {
17-
return "", "", 0, &InvalidServerName
14+
func splitServer(serverName string) (string, instance string, port uint64, protocol string, err error) {
15+
instance = ""
16+
port = uint64(0)
17+
protocol = ""
18+
err = nil
19+
// We don't just look for : due to possible IPv6 address
20+
for _, p := range msdsn.ProtocolParsers {
21+
prefix := p.Protocol() + ":"
22+
if strings.HasPrefix(serverName, prefix) {
23+
if len(serverName) == len(prefix) {
24+
serverName = "."
25+
} else {
26+
serverName = serverName[len(prefix):]
27+
}
28+
protocol = p.Protocol()
1829
}
19-
serverName = serverName[4:]
20-
}
21-
serverNameParts := strings.Split(serverName, ",")
22-
if len(serverNameParts) > 2 {
23-
return "", "", 0, &InvalidServerName
2430
}
25-
if len(serverNameParts) == 2 {
26-
var err error
27-
port, err = strconv.ParseUint(serverNameParts[1], 10, 16)
28-
if err != nil {
29-
return "", "", 0, &InvalidServerName
31+
if strings.HasPrefix(serverName, `\\`) {
32+
if protocol != "np" && protocol != "" || len(serverName) == 2 {
33+
return "", "", 0, "", &InvalidServerName
3034
}
31-
serverName = serverNameParts[0]
35+
protocol = "np"
3236
} else {
33-
serverNameParts = strings.Split(serverName, "\\")
37+
serverNameParts := strings.Split(serverName, ",")
3438
if len(serverNameParts) > 2 {
35-
return "", "", 0, &InvalidServerName
39+
return "", "", 0, "", &InvalidServerName
3640
}
3741
if len(serverNameParts) == 2 {
38-
instance = serverNameParts[1]
42+
var err error
43+
port, err = strconv.ParseUint(serverNameParts[1], 10, 16)
44+
if err != nil {
45+
return "", "", 0, "", &InvalidServerName
46+
}
3947
serverName = serverNameParts[0]
48+
} else {
49+
serverNameParts = strings.Split(serverName, "\\")
50+
if len(serverNameParts) > 2 {
51+
return "", "", 0, "", &InvalidServerName
52+
}
53+
if len(serverNameParts) == 2 {
54+
instance = serverNameParts[1]
55+
serverName = serverNameParts[0]
56+
}
4057
}
4158
}
42-
return serverName, instance, port, nil
59+
return serverName, instance, port, protocol, err
4360
}
4461

4562
// padRight appends c instances of s to builder

pkg/sqlcmd/variables.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func (v Variables) SQLCmdUser() string {
109109
}
110110

111111
// SQLCmdServer returns the server connection parameters derived from the SQLCMDSERVER variable value
112-
func (v Variables) SQLCmdServer() (serverName string, instance string, port uint64, err error) {
112+
func (v Variables) SQLCmdServer() (serverName string, instance string, port uint64, protocol string, err error) {
113113
serverName = v[SQLCMDSERVER]
114114
return splitServer(serverName)
115115
}

pkg/sqlcmd/variables_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,22 @@ func TestSqlServerSplitsName(t *testing.T) {
4949
vars := Variables{
5050
SQLCMDSERVER: `tcp:someserver\someinstance`,
5151
}
52-
serverName, instance, port, err := vars.SQLCmdServer()
52+
serverName, instance, port, protocol, err := vars.SQLCmdServer()
5353
if assert.NoError(t, err, "tcp:server\\someinstance") {
5454
assert.Equal(t, "someserver", serverName, "server name for instance")
5555
assert.Equal(t, uint64(0), port, "port for instance")
5656
assert.Equal(t, "someinstance", instance, "instance for instance")
57+
assert.Equal(t, "tcp", protocol, "protocol for instance")
5758
}
5859
vars = Variables{
5960
SQLCMDSERVER: `tcp:someserver,1111`,
6061
}
61-
serverName, instance, port, err = vars.SQLCmdServer()
62+
serverName, instance, port, protocol, err = vars.SQLCmdServer()
6263
if assert.NoError(t, err, "tcp:server,1111") {
6364
assert.Equal(t, "someserver", serverName, "server name for port number")
6465
assert.Equal(t, uint64(1111), port, "port for port number")
6566
assert.Equal(t, "", instance, "instance for port number")
67+
assert.Equal(t, "tcp", protocol, "protocol for port number")
6668
}
6769
}
6870

0 commit comments

Comments
 (0)