Skip to content

Commit fede71f

Browse files
committed
support OPT_GET_SERVER_PUBLIC_KEY
1 parent 4eeba6a commit fede71f

File tree

3 files changed

+21
-6
lines changed

3 files changed

+21
-6
lines changed

lib/mysql.rb

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def initialize
9797
@result_exist = false
9898
@local_infile = nil
9999
@ssl_mode = SSL_MODE_PREFERRED
100+
@get_server_public_key = false
100101
end
101102

102103
# Connect to mysqld.
@@ -113,7 +114,7 @@ def connect(host=nil, user=nil, passwd=nil, db=nil, port=nil, socket=nil, flag=0
113114
warn 'unsupported flag: CLIENT_COMPRESS' if $VERBOSE
114115
flag &= ~CLIENT_COMPRESS
115116
end
116-
@protocol = Protocol.new host, port, socket, @connect_timeout, @read_timeout, @write_timeout, @local_infile, @ssl_mode
117+
@protocol = Protocol.new host, port, socket, @connect_timeout, @read_timeout, @write_timeout, @local_infile, @ssl_mode, @get_server_public_key
117118
@protocol.authenticate user, passwd, db, flag, @charset
118119
@charset ||= @protocol.charset
119120
@host_info = (host.nil? || host == "localhost") ? 'Localhost via UNIX socket' : "#{host} via TCP/IP"
@@ -145,7 +146,8 @@ def close!
145146
# Set option for connection.
146147
#
147148
# Available options:
148-
# Mysql::INIT_COMMAND, Mysql::OPT_CONNECT_TIMEOUT, Mysql::OPT_READ_TIMEOUT,
149+
# Mysql::INIT_COMMAND, Mysql::OPT_CONNECT_TIMEOUT, Mysql::OPT_GET_SERVER_PUBLIC_KEY,
150+
# Mysql::OPT_LOAD_DATA_LOCAL_DIR, Mysql::OPT_LOCAL_INFILE, Mysql::OPT_READ_TIMEOUT,
149151
# Mysql::OPT_SSL_MODE, Mysql::OPT_WRITE_TIMEOUT, Mysql::SET_CHARSET_NAME
150152
# @param [Integer] opt option
151153
# @param [Integer] value option value that is depend on opt
@@ -165,7 +167,8 @@ def options(opt, value=nil)
165167
# when Mysql::OPT_CONNECT_ATTR_RESET
166168
when Mysql::OPT_CONNECT_TIMEOUT
167169
@connect_timeout = value
168-
# when Mysql::OPT_GET_SERVER_PUBLIC_KEY
170+
when Mysql::OPT_GET_SERVER_PUBLIC_KEY
171+
@get_server_public_key = value
169172
when Mysql::OPT_LOAD_DATA_LOCAL_DIR
170173
@local_infile = value
171174
when Mysql::OPT_LOCAL_INFILE

lib/mysql/authenticator/caching_sha2_password.rb

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,19 @@ def authenticate(passwd, scramble)
2626
when "\x03" # fast_auth_success
2727
# OK
2828
when "\x04" # perform_full_authentication
29-
if @protocol.client_flags & CLIENT_SSL == 0
29+
if @protocol.client_flags & CLIENT_SSL != 0
30+
@protocol.write passwd+"\0"
31+
elsif !@protocol.get_server_public_key
3032
raise 'Authentication requires secure connection'
33+
else
34+
@protocol.write "\2" # request public key
35+
pkt = @protocol.read
36+
pkt.utiny # skip
37+
pubkey = pkt.to_s
38+
hash = (passwd+"\0").unpack("C*").zip(scramble.unpack("C*")).map{|a, b| a ^ b}.pack("C*")
39+
enc = OpenSSL::PKey::RSA.new(pubkey).public_encrypt(hash, OpenSSL::PKey::RSA::PKCS1_OAEP_PADDING)
40+
@protocol.write enc
3141
end
32-
@protocol.write passwd+"\0"
3342
else
3443
raise "invalid auth reply packet: #{data.inspect}"
3544
end

lib/mysql/protocol.rb

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def self.value2net(v)
120120
attr_reader :server_status
121121
attr_reader :warning_count
122122
attr_reader :message
123+
attr_reader :get_server_public_key
123124
attr_accessor :charset
124125

125126
# @state variable keep state for connection.
@@ -138,9 +139,10 @@ def self.value2net(v)
138139
# write_timeout :: [Integer] write timeout (sec).
139140
# local_infile :: [String] local infile path
140141
# ssl_mode :: [Integer]
142+
# get_server_public_key :: [Boolean]
141143
# === Exception
142144
# [ClientError] :: connection timeout
143-
def initialize(host, port, socket, conn_timeout, read_timeout, write_timeout, local_infile, ssl_mode)
145+
def initialize(host, port, socket, conn_timeout, read_timeout, write_timeout, local_infile, ssl_mode, get_server_public_key)
144146
@insert_id = 0
145147
@warning_count = 0
146148
@gc_stmt_queue = [] # stmt id list which GC destroy.
@@ -149,6 +151,7 @@ def initialize(host, port, socket, conn_timeout, read_timeout, write_timeout, lo
149151
@write_timeout = write_timeout
150152
@local_infile = local_infile
151153
@ssl_mode = ssl_mode
154+
@get_server_public_key = get_server_public_key
152155
begin
153156
Timeout.timeout conn_timeout do
154157
if host.nil? or host.empty? or host == "localhost"

0 commit comments

Comments
 (0)