diff --git a/lib/myxql.ex b/lib/myxql.ex index 0b790d8..f171c13 100644 --- a/lib/myxql.ex +++ b/lib/myxql.ex @@ -23,6 +23,7 @@ defmodule MyXQL do | {:prepare, :force_named | :named | :unnamed} | {:disconnect_on_error_codes, [atom()]} | {:enable_cleartext_plugin, boolean()} + | {:local_infile, boolean()} | DBConnection.start_option() @type option() :: DBConnection.option() @@ -103,6 +104,8 @@ defmodule MyXQL do * `:enable_cleartext_plugin` - Set to `true` to send password as cleartext (default: `false`) + * `:local_infile` - Set to `true` to enable LOCAL INFILE capability (default: `false`) + The given options are passed down to DBConnection, some of the most commonly used ones are documented below: diff --git a/lib/myxql/client.ex b/lib/myxql/client.ex index c2e057a..0aa569c 100644 --- a/lib/myxql/client.ex +++ b/lib/myxql/client.ex @@ -27,7 +27,8 @@ defmodule MyXQL.Client do :max_packet_size, :charset, :collation, - :enable_cleartext_plugin + :enable_cleartext_plugin, + :local_infile ] @sock_opts [mode: :binary, packet: :raw, active: false] @@ -67,7 +68,8 @@ defmodule MyXQL.Client do socket_options: (opts[:socket_options] || []) ++ @sock_opts, charset: Keyword.get(opts, :charset), collation: Keyword.get(opts, :collation), - enable_cleartext_plugin: Keyword.get(opts, :enable_cleartext_plugin, false) + enable_cleartext_plugin: Keyword.get(opts, :enable_cleartext_plugin, false), + local_infile: Keyword.get(opts, :local_infile, false) } end @@ -135,7 +137,32 @@ defmodule MyXQL.Client do def com_query(client, statement, result_state \\ :single) do with :ok <- send_com(client, {:com_query, statement}) do - recv_packets(client, &decode_com_query_response/3, :initial, result_state) + case recv_packets(client, &decode_com_query_response/3, :initial, result_state) do + {:ok, {:local_infile, filename}} -> + case handle_local_infile(client, filename) do + :ok -> + recv_packets(client, &decode_com_query_response/3, :initial, result_state) + + error -> + error + end + + other -> + other + end + end + end + + def handle_local_infile(client, filename) do + case File.read(filename) do + {:ok, content} -> + with :ok <- send_packet(client, content, 2), + :ok <- send_packet(client, <<>>, 3) do + :ok + end + + {:error, _reason} -> + send_packet(client, <<>>, 2) end end @@ -265,6 +292,15 @@ defmodule MyXQL.Client do {:many, results} -> {:ok, [result | results]} end + {:local_infile, filename} -> + case handle_local_infile(client, filename) do + :ok -> + recv_packets(rest, decoder, decoder_state, result_state, timeout, client) + + {:error, reason} -> + {:error, reason} + end + {:error, _} = error -> error end diff --git a/lib/myxql/protocol.ex b/lib/myxql/protocol.ex index 74dd083..270c220 100644 --- a/lib/myxql/protocol.ex +++ b/lib/myxql/protocol.ex @@ -180,6 +180,7 @@ defmodule MyXQL.Protocol do ]) |> maybe_put_capability_flag(:client_connect_with_db, !is_nil(config.database)) |> maybe_put_capability_flag(:client_ssl, is_list(config.ssl_opts)) + |> maybe_put_capability_flag(:client_local_files, config.local_infile) if config.ssl_opts && !has_capability_flag?(server_capability_flags, :client_ssl) do {:error, :server_does_not_support_ssl} @@ -331,6 +332,12 @@ defmodule MyXQL.Protocol do {:halt, decode_ok_packet_body(rest)} end + def decode_com_query_response(<<0xFB, rest::binary>>, "", :initial) do + {filename, ""} = take_string_nul(rest) + + {:local_infile, filename} + end + def decode_com_query_response(<<0xFF, rest::binary>>, "", :initial) do {:halt, decode_err_packet_body(rest)} end @@ -513,6 +520,11 @@ defmodule MyXQL.Protocol do {:cont, :initial, {:many, [resultset | results]}} end + defp decode_resultset(<<0xFB, rest::binary>>, _next_data, :initial, _row_decoder) do + {filename, ""} = take_string_nul(rest) + {:local_infile, filename} + end + defp decode_resultset(payload, _next_data, :initial, _row_decoder) do {:cont, {:column_defs, decode_int_lenenc(payload), []}} end diff --git a/lib/myxql/protocol/types.ex b/lib/myxql/protocol/types.ex index 572331d..71867c9 100644 --- a/lib/myxql/protocol/types.ex +++ b/lib/myxql/protocol/types.ex @@ -27,14 +27,16 @@ defmodule MyXQL.Protocol.Types do def encode_int_lenenc(int) when int < 0xFFFFFFFFFFFFFFFF, do: <<0xFE, int::uint8()>> def decode_int_lenenc(binary) do - {integer, ""} = take_int_lenenc(binary) + {integer, _rest} = take_int_lenenc(binary) integer end def take_int_lenenc(<>) when int < 251, do: {int, rest} + def take_int_lenenc(<<0xFB, rest::binary>>), do: {nil, rest} def take_int_lenenc(<<0xFC, int::uint2(), rest::binary>>), do: {int, rest} def take_int_lenenc(<<0xFD, int::uint3(), rest::binary>>), do: {int, rest} def take_int_lenenc(<<0xFE, int::uint8(), rest::binary>>), do: {int, rest} + def take_int_lenenc(<<0xFF, rest::binary>>), do: {:error, rest} # https://dev.mysql.com/doc/internals/en/string.html#packet-Protocol::FixedLengthString defmacro string(size) do @@ -68,8 +70,13 @@ defmodule MyXQL.Protocol.Types do def take_string_nul(""), do: {nil, ""} - def take_string_nul(binary) do - [string, rest] = :binary.split(binary, <<0>>) - {string, rest} + def take_string_nul(binary) when is_binary(binary) do + case :binary.split(binary, <<0>>) do + [string] -> + {string, ""} + + [string, rest] -> + {string, rest} + end end end