Skip to content

Commit b367feb

Browse files
committed
refactor
1 parent fede71f commit b367feb

File tree

3 files changed

+65
-74
lines changed

3 files changed

+65
-74
lines changed

lib/mysql.rb

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,20 @@ def initialize
8686
@fields = nil
8787
@protocol = nil
8888
@charset = nil
89-
@connect_timeout = nil
90-
@read_timeout = nil
91-
@write_timeout = nil
9289
@init_command = nil
9390
@sqlstate = "00000"
9491
@query_with_result = true
9592
@host_info = nil
9693
@last_error = nil
9794
@result_exist = false
98-
@local_infile = nil
99-
@ssl_mode = SSL_MODE_PREFERRED
100-
@get_server_public_key = false
95+
@opts = {
96+
connect_timeout: nil,
97+
read_timeout: nil,
98+
write_timeout: nil,
99+
local_infile: nil,
100+
ssl_mode: SSL_MODE_PREFERRED,
101+
get_server_public_key: false,
102+
}
101103
end
102104

103105
# Connect to mysqld.
@@ -114,7 +116,7 @@ def connect(host=nil, user=nil, passwd=nil, db=nil, port=nil, socket=nil, flag=0
114116
warn 'unsupported flag: CLIENT_COMPRESS' if $VERBOSE
115117
flag &= ~CLIENT_COMPRESS
116118
end
117-
@protocol = Protocol.new host, port, socket, @connect_timeout, @read_timeout, @write_timeout, @local_infile, @ssl_mode, @get_server_public_key
119+
@protocol = Protocol.new(host, port, socket, @opts)
118120
@protocol.authenticate user, passwd, db, flag, @charset
119121
@charset ||= @protocol.charset
120122
@host_info = (host.nil? || host == "localhost") ? 'Localhost via UNIX socket' : "#{host} via TCP/IP"
@@ -166,20 +168,20 @@ def options(opt, value=nil)
166168
# when Mysql::OPT_CONNECT_ATTR_DELETE
167169
# when Mysql::OPT_CONNECT_ATTR_RESET
168170
when Mysql::OPT_CONNECT_TIMEOUT
169-
@connect_timeout = value
171+
@opts[:connect_timeout] = value
170172
when Mysql::OPT_GET_SERVER_PUBLIC_KEY
171-
@get_server_public_key = value
173+
@opts[:get_server_public_key] = value
172174
when Mysql::OPT_LOAD_DATA_LOCAL_DIR
173-
@local_infile = value
175+
@opts[:local_infile] = value
174176
when Mysql::OPT_LOCAL_INFILE
175-
@local_infile = value ? '' : nil
177+
@opts[:local_infile] = value ? '' : nil
176178
# when Mysql::OPT_MAX_ALLOWED_PACKET
177179
# when Mysql::OPT_NAMED_PIPE
178180
# when Mysql::OPT_NET_BUFFER_LENGTH
179181
# when Mysql::OPT_OPTIONAL_RESULTSET_METADATA
180182
# when Mysql::OPT_PROTOCOL
181183
when Mysql::OPT_READ_TIMEOUT
182-
@read_timeout = value.to_i
184+
@opts[:read_timeout] = value.to_i
183185
# when Mysql::OPT_RECONNECT
184186
# when Mysql::OPT_RETRY_COUNT
185187
# when Mysql::SET_CLIENT_IP
@@ -192,12 +194,12 @@ def options(opt, value=nil)
192194
# when Mysql::OPT_SSL_FIPS_MODE
193195
# when Mysql::OPT_SSL_KEY
194196
when Mysql::OPT_SSL_MODE
195-
@ssl_mode = value
197+
@opts[:ssl_mode] = value
196198
# when Mysql::OPT_TLS_CIPHERSUITES
197199
# when Mysql::OPT_TLS_VERSION
198200
# when Mysql::OPT_USE_RESULT
199201
when Mysql::OPT_WRITE_TIMEOUT
200-
@write_timeout = value.to_i
202+
@opts[:write_timeout] = value.to_i
201203
# when Mysql::OPT_ZSTD_COMPRESSION_LEVEL
202204
# when Mysql::PLUGIN_DIR
203205
# when Mysql::READ_DEFAULT_FILE

lib/mysql/protocol.rb

Lines changed: 47 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -130,45 +130,38 @@ def self.value2net(v)
130130
# :RESULT :: After retr_fields(), retr_all_records() or stmt_retr_all_records() is needed.
131131

132132
# make socket connection to server.
133-
# === Argument
134-
# host :: [String] if "localhost" or "" nil then use UNIXSocket. Otherwise use TCPSocket
135-
# port :: [Integer] port number using by TCPSocket
136-
# socket :: [String] socket file name using by UNIXSocket
137-
# conn_timeout :: [Integer] connect timeout (sec).
138-
# read_timeout :: [Integer] read timeout (sec).
139-
# write_timeout :: [Integer] write timeout (sec).
140-
# local_infile :: [String] local infile path
141-
# ssl_mode :: [Integer]
142-
# get_server_public_key :: [Boolean]
143-
# === Exception
144-
# [ClientError] :: connection timeout
145-
def initialize(host, port, socket, conn_timeout, read_timeout, write_timeout, local_infile, ssl_mode, get_server_public_key)
133+
# @param host [String] if "localhost" or "" or nil then use UNIX socket. Otherwise use TCP socket
134+
# @param port [Integer] port number using by TCP socket
135+
# @param socket [String] socket file name using by UNIX socket
136+
# @param [Hash] opts
137+
# @option opts :conn_timeout [Integer] connect timeout (sec).
138+
# @option opts :read_timeout [Integer] read timeout (sec).
139+
# @option opts :write_timeout [Integer] write timeout (sec).
140+
# @option opts :local_infile [String] local infile path
141+
# @option opts :get_server_public_key [Boolean]
142+
# @raise [ClientError] connection timeout
143+
def initialize(host, port, socket, opts)
144+
@opts = opts
146145
@insert_id = 0
147146
@warning_count = 0
148147
@gc_stmt_queue = [] # stmt id list which GC destroy.
149148
set_state :INIT
150-
@read_timeout = read_timeout
151-
@write_timeout = write_timeout
152-
@local_infile = local_infile
153-
@ssl_mode = ssl_mode
154-
@get_server_public_key = get_server_public_key
149+
@get_server_public_key = @opts[:get_server_public_key]
155150
begin
156-
Timeout.timeout conn_timeout do
157-
if host.nil? or host.empty? or host == "localhost"
158-
socket ||= ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
159-
@sock = UNIXSocket.new socket
160-
else
161-
port ||= ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
162-
@sock = TCPSocket.new host, port
163-
end
151+
if host.nil? or host.empty? or host == "localhost"
152+
socket ||= ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
153+
@socket = Socket.unix(socket)
154+
else
155+
port ||= ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
156+
@socket = Socket.tcp(host, port, connect_timeout: @opts[:connect_timeout])
164157
end
165-
rescue Timeout::Error
158+
rescue Errno::ETIMEDOUT
166159
raise ClientError, "connection timeout"
167160
end
168161
end
169162

170163
def close
171-
@sock.close
164+
@socket.close
172165
end
173166

174167
# initial negotiate and authenticate.
@@ -190,7 +183,7 @@ def authenticate(user, passwd, db, flag, charset)
190183
@server_capabilities = init_packet.server_capabilities
191184
@thread_id = init_packet.thread_id
192185
@client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH
193-
@client_flags |= CLIENT_LOCAL_FILES if @local_infile
186+
@client_flags |= CLIENT_LOCAL_FILES if @opts[:local_infile]
194187
@client_flags |= CLIENT_CONNECT_WITH_DB if db
195188
@client_flags |= flag
196189
@charset = charset
@@ -204,28 +197,28 @@ def authenticate(user, passwd, db, flag, charset)
204197
end
205198

206199
def enable_ssl
207-
case @ssl_mode
200+
case @opts[:ssl_mode]
208201
when SSL_MODE_DISABLED
209202
return
210203
when SSL_MODE_PREFERRED
211-
return if @sock.is_a? UNIXSocket
204+
return if @socket.local_address.unix?
212205
return if @server_capabilities & CLIENT_SSL == 0
213206
when SSL_MODE_REQUIRED
214207
if @server_capabilities & CLIENT_SSL == 0
215208
raise ClientError::SslConnectionError, "SSL is required but the server doesn't support it"
216209
end
217210
else
218-
raise ClientError, "ssl_mode #{@ssl_mode} is not supported"
211+
raise ClientError, "ssl_mode #{@opts[:ssl_mode]} is not supported"
219212
end
220213
begin
221214
@client_flags |= CLIENT_SSL
222215
write Protocol::TlsAuthenticationPacket.serialize(@client_flags, 1024**3, @charset.number)
223-
@sock = OpenSSL::SSL::SSLSocket.new(@sock)
224-
@sock.sync_close = true
225-
@sock.connect
216+
@socket = OpenSSL::SSL::SSLSocket.new(@socket)
217+
@socket.sync_close = true
218+
@socket.connect
226219
rescue => e
227220
@client_flags &= ~CLIENT_SSL
228-
return if @ssl_mode == SSL_MODE_PREFERRED
221+
return if @opts[:ssl_mode] == SSL_MODE_PREFERRED
229222
raise e
230223
end
231224
end
@@ -282,7 +275,7 @@ def get_result
282275
# send local file to server
283276
def send_local_file(filename)
284277
filename = File.absolute_path(filename)
285-
if filename.start_with? @local_infile
278+
if filename.start_with? @opts[:local_infile]
286279
File.open(filename){|f| write f}
287280
else
288281
raise ClientError::LoadDataLocalInfileRejected, 'LOAD DATA LOCAL INFILE file request rejected due to restrictions on access.'
@@ -482,7 +475,7 @@ def check_state(st)
482475

483476
def set_state(st)
484477
@state = st
485-
if st == :READY
478+
if st == :READY && !@gc_stmt_queue.empty?
486479
gc_disabled = GC.disable
487480
begin
488481
while st = @gc_stmt_queue.shift
@@ -518,14 +511,14 @@ def read
518511
data = ''
519512
len = nil
520513
begin
521-
Timeout.timeout @read_timeout do
522-
header = @sock.read(4)
514+
Timeout.timeout @opts[:read_timeout] do
515+
header = @socket.read(4)
523516
raise EOFError unless header && header.length == 4
524517
len1, len2, seq = header.unpack("CvC")
525518
len = (len2 << 8) + len1
526519
raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
527520
@seq = (@seq + 1) % 256
528-
ret = @sock.read(len)
521+
ret = @socket.read(len)
529522
raise EOFError unless ret && ret.length == len
530523
data.concat ret
531524
end
@@ -558,25 +551,21 @@ def read
558551
# data :: [String / IO] packet data. If data is nil, write empty packet.
559552
def write(data)
560553
begin
561-
@sock.sync = false
562-
if data.nil?
563-
Timeout.timeout @write_timeout do
564-
@sock.write [0, 0, @seq].pack("CvC")
565-
end
566-
@seq = (@seq + 1) % 256
567-
else
568-
data = StringIO.new data if data.is_a? String
569-
while d = data.read(MAX_PACKET_LENGTH)
570-
Timeout.timeout @write_timeout do
571-
@sock.write [d.length%256, d.length/256, @seq].pack("CvC")
572-
@sock.write d
573-
end
554+
Timeout.timeout @opts[:write_timeout] do
555+
@socket.sync = false
556+
if data.nil?
557+
@socket.write [0, 0, @seq].pack("CvC")
574558
@seq = (@seq + 1) % 256
559+
else
560+
data = StringIO.new data if data.is_a? String
561+
while d = data.read(MAX_PACKET_LENGTH)
562+
@socket.write [d.length%256, d.length/256, @seq].pack("CvC")
563+
@socket.write d
564+
@seq = (@seq + 1) % 256
565+
end
575566
end
576-
end
577-
@sock.sync = true
578-
Timeout.timeout @write_timeout do
579-
@sock.flush
567+
@socket.sync = true
568+
@socket.flush
580569
end
581570
rescue Errno::EPIPE
582571
raise ClientError::ServerGoneError, 'MySQL server has gone away'

test/test_mysql.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ class TestMysql < Test::Unit::TestCase
140140
end
141141
test 'OPT_CONNECT_TIMEOUT: set timeout for connecting' do
142142
assert{ @m.options(Mysql::OPT_CONNECT_TIMEOUT, 0.1) == @m }
143-
stub(UNIXSocket).new{ sleep 1}
144-
stub(TCPSocket).new{ sleep 1}
143+
stub(Socket).tcp{ raise Errno::ETIMEDOUT }
144+
stub(Socket).unix{ raise Errno::ETIMEDOUT }
145145
assert_raise Mysql::ClientError, 'connection timeout' do
146146
@m.connect(MYSQL_SERVER, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE, MYSQL_PORT, MYSQL_SOCKET)
147147
end

0 commit comments

Comments
 (0)