Skip to content

Commit 9753917

Browse files
authored
Add --database to sqlcmd config cs (#280)
1 parent 21afc1f commit 9753917

File tree

10 files changed

+87
-21
lines changed

10 files changed

+87
-21
lines changed

cmd/modern/main.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/microsoft/go-sqlcmd/internal/io/file"
2020
"github.com/microsoft/go-sqlcmd/internal/output"
2121
"github.com/microsoft/go-sqlcmd/internal/output/verbosity"
22+
"github.com/microsoft/go-sqlcmd/internal/pal"
2223
"github.com/microsoft/go-sqlcmd/pkg/sqlcmd"
2324
"github.com/spf13/cobra"
2425
"path"
@@ -142,7 +143,7 @@ func initializeCallback() {
142143
// To aid debugging issues, if the logging level is > 2 (e.g. --verbosity 3 or --verbosity 4), we
143144
// panic which outputs a stacktrace.
144145
func checkErr(err error) {
145-
if rootCmd.loggingLevel > 2 {
146+
if rootCmd != nil && rootCmd.loggingLevel > 2 {
146147
if err != nil {
147148
panic(err)
148149
}
@@ -155,7 +156,7 @@ func checkErr(err error) {
155156
// to make progress. displayHints is injected into dependencies (helpers etc.)
156157
func displayHints(hints []string) {
157158
if len(hints) > 0 {
158-
outputter.Infof("%vHINT:", sqlcmd.SqlcmdEol)
159+
outputter.Infof("%vHINT:", pal.LineBreak())
159160
for i, hint := range hints {
160161
outputter.Infof(" %d. %v", i+1, hint)
161162
}

cmd/modern/main_test.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ package main
55

66
import (
77
"errors"
8+
"github.com/microsoft/go-sqlcmd/internal/buffer"
89
"github.com/microsoft/go-sqlcmd/internal/cmdparser"
910
"github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency"
1011
"github.com/microsoft/go-sqlcmd/internal/output"
1112
"github.com/microsoft/go-sqlcmd/internal/pal"
12-
"github.com/microsoft/go-sqlcmd/internal/test"
1313
"github.com/stretchr/testify/assert"
1414
"os"
1515
"testing"
@@ -26,14 +26,15 @@ func TestInitializeCallback(t *testing.T) {
2626
}
2727

2828
func TestDisplayHints(t *testing.T) {
29-
buf := test.NewMemoryBuffer()
30-
defer checkErr(buf.Close())
29+
buf := buffer.NewMemoryBuffer()
3130
outputter = output.New(output.Options{StandardWriter: buf})
3231
displayHints([]string{"This is a hint"})
3332
assert.Equal(t, pal.LineBreak()+
3433
"HINT:"+
3534
pal.LineBreak()+
3635
" 1. This is a hint"+pal.LineBreak()+pal.LineBreak(), buf.String())
36+
err := buf.Close()
37+
checkErr(err)
3738
}
3839

3940
func TestCheckErr(t *testing.T) {

cmd/modern/root/config/connection-strings.go

+32-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ package config
66
import (
77
"fmt"
88
"github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig"
9+
"github.com/microsoft/go-sqlcmd/internal/container"
10+
"github.com/microsoft/go-sqlcmd/internal/sql"
911
"strings"
1012

1113
"github.com/microsoft/go-sqlcmd/internal/cmdparser"
@@ -17,6 +19,8 @@ import (
1719
// ConnectionStrings implements the `sqlcmd config connection-strings` command
1820
type ConnectionStrings struct {
1921
cmdparser.Cmd
22+
23+
database string
2024
}
2125

2226
func (c *ConnectionStrings) DefineCommand(...cmdparser.CommandOptions) {
@@ -36,6 +40,13 @@ func (c *ConnectionStrings) DefineCommand(...cmdparser.CommandOptions) {
3640
}
3741

3842
c.Cmd.DefineCommand(options)
43+
44+
c.AddFlag(cmdparser.FlagOptions{
45+
String: &c.database,
46+
Name: "database",
47+
DefaultString: "",
48+
Shorthand: "d",
49+
Usage: "Database for the connection string (default is taken from the T/SQL login)"})
3950
}
4051

4152
// run generates connection strings for the current context in multiple formats.
@@ -48,12 +59,27 @@ func (c *ConnectionStrings) run() {
4859
"ADO.NET": "Server=tcp:%s,%d;Initial Catalog=%s;Persist Security Info=False;User ID=%s;Password=%s;MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=%s;Connection Timeout=30;",
4960
"JDBC": "jdbc:sqlserver://%s:%d;database=%s;user=%s;password=%s;encrypt=true;trustServerCertificate=%s;loginTimeout=30;",
5061
"ODBC": "Driver={ODBC Driver 18 for SQL Server};Server=tcp:%s,%d;Database=%s;Uid=%s;Pwd=%s;Encrypt=yes;TrustServerCertificate=%s;Connection Timeout=30;",
51-
"GO": "sqlserver://%s:%s@%s,%d?database=master;encrypt=true;trustServerCertificate=%s;dial+timeout=30",
62+
"GO": "sqlserver://%s:%s@%s,%d?database=%s;encrypt=true;trustServerCertificate=%s;dial+timeout=30",
5263
"SQLCMD": "sqlcmd -S %s,%d -U %s",
5364
}
5465

5566
endpoint, user := config.CurrentContext()
5667

68+
if c.database == "" {
69+
if endpoint.AssetDetails != nil && endpoint.AssetDetails.ContainerDetails != nil {
70+
controller := container.NewController()
71+
if controller.ContainerRunning(endpoint.AssetDetails.ContainerDetails.Id) {
72+
s := sql.New(sql.SqlOptions{})
73+
s.Connect(endpoint, user, sql.ConnectOptions{Interactive: false})
74+
c.database = s.ScalarString("PRINT DB_NAME()")
75+
} else {
76+
c.database = "master"
77+
}
78+
} else {
79+
c.database = "master"
80+
}
81+
}
82+
5783
if user != nil {
5884
for k, v := range connectionStringFormats {
5985
if k == "GO" {
@@ -63,23 +89,25 @@ func (c *ConnectionStrings) run() {
6389
secret.Decode(user.BasicAuth.Password, user.BasicAuth.PasswordEncrypted),
6490
endpoint.EndpointDetails.Address,
6591
endpoint.EndpointDetails.Port,
92+
c.database,
6693
c.stringForBoolean(c.trustServerCertificate(endpoint), k))
6794
} else if k == "SQLCMD" {
6895
format := pal.CmdLineWithEnvVars(
6996
[]string{"SQLCMDPASSWORD=%s"},
70-
"sqlcmd -S %s,%d -U %s",
97+
"sqlcmd -S %s,%d -U %s -d %s",
7198
)
7299

73100
connectionStringFormats[k] = fmt.Sprintf(format,
74101
secret.Decode(user.BasicAuth.Password, user.BasicAuth.PasswordEncrypted),
75102
endpoint.EndpointDetails.Address,
76103
endpoint.EndpointDetails.Port,
77-
user.BasicAuth.Username)
104+
user.BasicAuth.Username,
105+
c.database)
78106
} else {
79107
connectionStringFormats[k] = fmt.Sprintf(v,
80108
endpoint.EndpointDetails.Address,
81109
endpoint.EndpointDetails.Port,
82-
"master",
110+
c.database,
83111
user.BasicAuth.Username,
84112
secret.Decode(user.BasicAuth.Password, user.BasicAuth.PasswordEncrypted),
85113
c.stringForBoolean(c.trustServerCertificate(endpoint), k))

cmd/modern/root/config/connection-strings_test.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"github.com/microsoft/go-sqlcmd/internal"
88
"github.com/microsoft/go-sqlcmd/internal/cmdparser"
99
"github.com/microsoft/go-sqlcmd/internal/output"
10+
"github.com/stretchr/testify/assert"
1011
"os"
1112
"testing"
1213
)
@@ -48,5 +49,10 @@ func TestConnectionStrings(t *testing.T) {
4849
cmdparser.TestCmd[*AddUser]("--username user")
4950
cmdparser.TestCmd[*AddContext]("--endpoint endpoint2 --user user")
5051

51-
cmdparser.TestCmd[*ConnectionStrings]()
52+
result := cmdparser.TestCmd[*ConnectionStrings]()
53+
assert.Contains(t, result, "database=master")
54+
55+
result = cmdparser.TestCmd[*ConnectionStrings]("--database tempdb")
56+
assert.NotContains(t, result, "database=master")
57+
assert.Contains(t, result, "database=tempdb")
5258
}

internal/test/memory-buffer.go renamed to internal/buffer/memory-buffer.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Copyright (c) Microsoft Corporation.
22
// Licensed under the MIT license.
33

4-
package test
4+
package buffer
55

66
import "bytes"
77

@@ -16,6 +16,8 @@ func (b *MemoryBuffer) Write(p []byte) (n int, err error) {
1616
}
1717

1818
func (b *MemoryBuffer) Close() error {
19+
b.buf = nil
20+
1921
return nil
2022
}
2123

internal/test/memory-buffer_test.go renamed to internal/buffer/memory-buffer_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Copyright (c) Microsoft Corporation.
22
// Licensed under the MIT license.
33

4-
package test
4+
package buffer
55

66
import (
77
"testing"

internal/cmdparser/test.go

+15-9
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package cmdparser
55

66
import (
77
"github.com/microsoft/go-sqlcmd/internal"
8+
"github.com/microsoft/go-sqlcmd/internal/buffer"
89
"github.com/microsoft/go-sqlcmd/internal/cmdparser/dependency"
910
"github.com/microsoft/go-sqlcmd/internal/config"
1011
"github.com/microsoft/go-sqlcmd/internal/output"
@@ -44,21 +45,26 @@ func TestSetup(t *testing.T) {
4445
}
4546

4647
// Run a command expecing it to pass, passing in any supplied args (args are split on " " (space))
47-
func TestCmd[T PtrAsReceiverWrapper[pointerType], pointerType any](args ...string) {
48-
err := testCmd[T](args...)
48+
func TestCmd[T PtrAsReceiverWrapper[pointerType], pointerType any](args ...string) string {
49+
result, err := testCmd[T](args...)
4950

50-
// DEVNOTE: I don't think the code will ever get here (c.Command().Execute() will
51-
// always panic first. This is here to silence code checkers, that require the err return
52-
// variable be checked.
5351
if err != nil {
52+
53+
// DEVNOTE: I don't think the code will ever get here (c.Command().Execute() will
54+
// always panic first. This is here to silence code checkers, that require the err return
55+
// variable be used.
5456
panic(err)
5557
}
58+
return result
5659
}
5760

58-
func testCmd[T PtrAsReceiverWrapper[pointerType], pointerType any](args ...string) error {
61+
func testCmd[T PtrAsReceiverWrapper[pointerType], pointerType any](args ...string) (result string, err error) {
62+
buf := buffer.NewMemoryBuffer()
63+
defer func() { buf.Close() }()
5964
c := New[T](dependency.Options{
6065
Output: output.New(output.Options{
61-
LoggingLevel: verbosity.Trace}),
66+
StandardWriter: buf,
67+
LoggingLevel: verbosity.Trace}),
6268
})
6369
c.DefineCommand()
6470
if len(args) > 1 {
@@ -68,8 +74,8 @@ func testCmd[T PtrAsReceiverWrapper[pointerType], pointerType any](args ...strin
6874
} else {
6975
c.SetArgsForUnitTesting([]string{})
7076
}
71-
err := c.Command().Execute()
72-
return err
77+
err = c.Command().Execute()
78+
return buf.String(), err
7379
}
7480

7581
// splitStringIntoArgsSlice uses a regular expression that matches either a

internal/sql/interface.go

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
type Sql interface {
1111
Connect(endpoint Endpoint, user *User, options ConnectOptions)
1212
Query(text string)
13+
ScalarString(query string) string
1314
}
1415

1516
type ConnectOptions struct {

internal/sql/mock.go

+4
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,7 @@ func (m *mock) Connect(
1616
// Query is a mock implementation used to speed up unit testing of other units
1717
func (m *mock) Query(text string) {
1818
}
19+
20+
func (m *mock) ScalarString(query string) string {
21+
return ""
22+
}

internal/sql/mssql.go

+17
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ package sql
55

66
import (
77
"fmt"
8+
"github.com/microsoft/go-sqlcmd/internal/buffer"
89
"github.com/microsoft/go-sqlcmd/pkg/console"
910
"os"
11+
"strings"
1012

1113
"github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig"
1214
"github.com/microsoft/go-sqlcmd/pkg/sqlcmd"
@@ -77,3 +79,18 @@ func (m *mssql) Query(text string) {
7779
checkErr(err)
7880
}
7981
}
82+
83+
func (m *mssql) ScalarString(query string) string {
84+
buf := buffer.NewMemoryBuffer()
85+
defer func() { _ = buf.Close() }()
86+
87+
m.sqlcmd.Query = query
88+
m.sqlcmd.SetOutput(buf)
89+
m.sqlcmd.SetError(os.Stderr)
90+
91+
trace("Running query: %v", query)
92+
err := m.sqlcmd.Run(true, false)
93+
checkErr(err)
94+
95+
return strings.TrimRight(buf.String(), "\r\n")
96+
}

0 commit comments

Comments
 (0)