@@ -6,6 +6,8 @@ package config
6
6
import (
7
7
"fmt"
8
8
"github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig"
9
+ "github.com/microsoft/go-sqlcmd/internal/container"
10
+ "github.com/microsoft/go-sqlcmd/internal/sql"
9
11
"strings"
10
12
11
13
"github.com/microsoft/go-sqlcmd/internal/cmdparser"
@@ -17,6 +19,8 @@ import (
17
19
// ConnectionStrings implements the `sqlcmd config connection-strings` command
18
20
type ConnectionStrings struct {
19
21
cmdparser.Cmd
22
+
23
+ database string
20
24
}
21
25
22
26
func (c * ConnectionStrings ) DefineCommand (... cmdparser.CommandOptions ) {
@@ -36,6 +40,13 @@ func (c *ConnectionStrings) DefineCommand(...cmdparser.CommandOptions) {
36
40
}
37
41
38
42
c .Cmd .DefineCommand (options )
43
+
44
+ c .AddFlag (cmdparser.FlagOptions {
45
+ String : & c .database ,
46
+ Name : "database" ,
47
+ DefaultString : "" ,
48
+ Shorthand : "d" ,
49
+ Usage : "Database for the connection string (default is taken from the T/SQL login)" })
39
50
}
40
51
41
52
// run generates connection strings for the current context in multiple formats.
@@ -48,12 +59,27 @@ func (c *ConnectionStrings) run() {
48
59
"ADO.NET" : "Server=tcp:%s,%d;Initial Catalog=%s;Persist Security Info=False;User ID=%s;Password=%s;MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=%s;Connection Timeout=30;" ,
49
60
"JDBC" : "jdbc:sqlserver://%s:%d;database=%s;user=%s;password=%s;encrypt=true;trustServerCertificate=%s;loginTimeout=30;" ,
50
61
"ODBC" : "Driver={ODBC Driver 18 for SQL Server};Server=tcp:%s,%d;Database=%s;Uid=%s;Pwd=%s;Encrypt=yes;TrustServerCertificate=%s;Connection Timeout=30;" ,
51
- "GO" : "sqlserver://%s:%s@%s,%d?database=master ;encrypt=true;trustServerCertificate=%s;dial+timeout=30" ,
62
+ "GO" : "sqlserver://%s:%s@%s,%d?database=%s ;encrypt=true;trustServerCertificate=%s;dial+timeout=30" ,
52
63
"SQLCMD" : "sqlcmd -S %s,%d -U %s" ,
53
64
}
54
65
55
66
endpoint , user := config .CurrentContext ()
56
67
68
+ if c .database == "" {
69
+ if endpoint .AssetDetails != nil && endpoint .AssetDetails .ContainerDetails != nil {
70
+ controller := container .NewController ()
71
+ if controller .ContainerRunning (endpoint .AssetDetails .ContainerDetails .Id ) {
72
+ s := sql .New (sql.SqlOptions {})
73
+ s .Connect (endpoint , user , sql.ConnectOptions {Interactive : false })
74
+ c .database = s .ScalarString ("PRINT DB_NAME()" )
75
+ } else {
76
+ c .database = "master"
77
+ }
78
+ } else {
79
+ c .database = "master"
80
+ }
81
+ }
82
+
57
83
if user != nil {
58
84
for k , v := range connectionStringFormats {
59
85
if k == "GO" {
@@ -63,23 +89,25 @@ func (c *ConnectionStrings) run() {
63
89
secret .Decode (user .BasicAuth .Password , user .BasicAuth .PasswordEncrypted ),
64
90
endpoint .EndpointDetails .Address ,
65
91
endpoint .EndpointDetails .Port ,
92
+ c .database ,
66
93
c .stringForBoolean (c .trustServerCertificate (endpoint ), k ))
67
94
} else if k == "SQLCMD" {
68
95
format := pal .CmdLineWithEnvVars (
69
96
[]string {"SQLCMDPASSWORD=%s" },
70
- "sqlcmd -S %s,%d -U %s" ,
97
+ "sqlcmd -S %s,%d -U %s -d %s " ,
71
98
)
72
99
73
100
connectionStringFormats [k ] = fmt .Sprintf (format ,
74
101
secret .Decode (user .BasicAuth .Password , user .BasicAuth .PasswordEncrypted ),
75
102
endpoint .EndpointDetails .Address ,
76
103
endpoint .EndpointDetails .Port ,
77
- user .BasicAuth .Username )
104
+ user .BasicAuth .Username ,
105
+ c .database )
78
106
} else {
79
107
connectionStringFormats [k ] = fmt .Sprintf (v ,
80
108
endpoint .EndpointDetails .Address ,
81
109
endpoint .EndpointDetails .Port ,
82
- "master" ,
110
+ c . database ,
83
111
user .BasicAuth .Username ,
84
112
secret .Decode (user .BasicAuth .Password , user .BasicAuth .PasswordEncrypted ),
85
113
c .stringForBoolean (c .trustServerCertificate (endpoint ), k ))
0 commit comments