Skip to content

Add close/1 #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion c_src/ex_dtls/native.c
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ UNIFEX_TERM do_init(UnifexEnv *env, char *mode_str, int dtls_srtp,
state->x509 = NULL;
state->mode = 0;
state->hsk_finished = 0;
state->closed = 0;
state->env = unifex_alloc_env(env);

int mode;
Expand Down Expand Up @@ -244,6 +245,10 @@ UNIFEX_TERM get_cert_fingerprint(UnifexEnv *env, UnifexPayload *cert) {
}

UNIFEX_TERM do_handshake(UnifexEnv *env, State *state) {
if (state->closed == 1) {
return do_handshake_result_error_closed(env);
}

SSL_do_handshake(state->ssl);

UnifexPayload **gen_packets = NULL;
Expand All @@ -258,14 +263,19 @@ UNIFEX_TERM do_handshake(UnifexEnv *env, State *state) {
} else {
int timeout = get_timeout(state->ssl);
UNIFEX_TERM res_term =
do_handshake_result(env, gen_packets, gen_packets_size, timeout);
do_handshake_result_ok(env, gen_packets, gen_packets_size, timeout);
free_payload_array(gen_packets, gen_packets_size);

return res_term;
}
}

UNIFEX_TERM write_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
if (state->closed == 1) {
DEBUG("Cannot write, connection closed");
return write_data_result_error_closed(env);
}

if (state->hsk_finished != 1) {
DEBUG("Cannot write, handshake not finished");
return write_data_result_error_handshake_not_finished(env);
Expand Down Expand Up @@ -303,6 +313,10 @@ UNIFEX_TERM write_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
}

UNIFEX_TERM handle_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
if (state->closed == 1) {
return handle_data_result_error_closed(env);
}

(void)env;

if (payload->size != 0) {
Expand Down Expand Up @@ -332,6 +346,32 @@ UNIFEX_TERM handle_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
}
}

// prefix close with exd (ex_dtls) as close is defined in unistd.h
UNIFEX_TERM exd_close(UnifexEnv *env, State *state) {
if (state->closed == 1) {
return exd_close_result_ok(env, NULL, 0);
}

state->closed = 1;
if (SSL_shutdown(state->ssl) < 0) {
return exd_close_result_ok(env, NULL, 0);
} else {
UnifexPayload **gen_packets = NULL;
int gen_packets_size = 0;
read_pending_data(&gen_packets, &gen_packets_size, state);

if (gen_packets == NULL) {
return unifex_raise(state->env,
"Close failed: couldn't read pending data");
} else {
UNIFEX_TERM res_term =
exd_close_result_ok(env, gen_packets, gen_packets_size);
free_payload_array(gen_packets, gen_packets_size);
return res_term;
}
}
}

UNIFEX_TERM handle_regular_read(State *state, char data[], int ret) {
if (ret > 0) {
UnifexPayload packets;
Expand All @@ -351,6 +391,7 @@ UNIFEX_TERM handle_read_error(State *state, int ret) {
int error = SSL_get_error(state->ssl, ret);
switch (error) {
case SSL_ERROR_ZERO_RETURN:
state->closed = 1;
return handle_data_result_error_peer_closed_for_writing(state->env);
case SSL_ERROR_WANT_READ:
DEBUG("SSL WANT READ. This is workaround. Did we get retransmission?");
Expand Down Expand Up @@ -452,6 +493,10 @@ UNIFEX_TERM handle_handshake_in_progress(State *state, int ret) {
}

UNIFEX_TERM handle_timeout(UnifexEnv *env, State *state) {
if (state->closed == 1) {
return handle_timeout_result_error_closed(env);
}

long result = DTLSv1_handle_timeout(state->ssl);
if (result != 1)
return handle_timeout_result_ok(env);
Expand Down
1 change: 1 addition & 0 deletions c_src/ex_dtls/native.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ struct State {
X509 *x509;
int mode;
int hsk_finished;
int closed;
};

#include "_generated/native.h"
18 changes: 13 additions & 5 deletions c_src/ex_dtls/native.spec.exs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,17 @@ spec get_peer_cert(state) :: payload | (nil :: label)

spec get_cert_fingerprint(payload) :: payload

spec do_handshake(state) :: {packets :: [payload], timeout :: int}
spec do_handshake(state) :: {:ok :: label, packets :: [payload], timeout :: int} | {:error :: label, :closed :: label}

spec handle_timeout(state) :: (:ok :: label) | {:retransmit :: label, packets :: [payload], timeout :: int}
spec handle_timeout(state) ::
(:ok :: label)
| {:retransmit :: label, packets :: [payload], timeout :: int}
| {:error :: label, :closed :: label}

spec write_data(state, packets :: payload) :: {:ok :: label, packets :: [payload]} | {:error :: label, :handshake_not_finished :: label}
spec write_data(state, packets :: payload) ::
{:ok :: label, packets :: [payload]}
| {:error :: label, :handshake_not_finished :: label}
| {:error :: label, :closed :: label}

spec handle_data(state, packets :: payload) ::
{:ok :: label, packets :: payload}
Expand All @@ -34,5 +40,7 @@ spec handle_data(state, packets :: payload) ::
| {:handshake_finished :: label, client_keying_material :: payload,
server_keying_material :: payload, protection_profile :: int, packets :: [payload]}
| {:error :: label, :peer_closed_for_writing :: label}
| {:error :: label, :handshake_error :: label
}
| {:error :: label, :handshake_error :: label}
| {:error :: label, :closed :: label}

spec exd_close(state) :: {:ok :: label, packets :: [payload]}
19 changes: 15 additions & 4 deletions lib/ex_dtls.ex
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ defmodule ExDTLS do

`timeout` is a time in ms after which `handle_timeout/1` should be called.
"""
@spec do_handshake(dtls()) :: {packets :: [binary()], timeout :: integer()}
@spec do_handshake(dtls()) ::
{:ok, packets :: [binary()], timeout :: integer()} | {:error, :closed}
defdelegate do_handshake(dtls), to: Native

@doc """
Expand All @@ -148,7 +149,7 @@ defmodule ExDTLS do
Generates encrypted packets that need to be passed to the second host.
"""
@spec write_data(dtls(), data :: binary()) ::
{:ok, packets :: [binary()]} | {:error, :handshake_not_finished}
{:ok, packets :: [binary()]} | {:error, :handshake_not_finished | :closed}
defdelegate write_data(dtls, data), to: Native

@doc """
Expand All @@ -172,7 +173,7 @@ defmodule ExDTLS do
remote_keying_material :: binary(), protection_profile_t(), packets :: [binary()]}
| {:handshake_finished, local_keying_material :: binary(),
remote_keying_material :: binary(), protection_profile_t()}
| {:error, :handshake_error | :peer_closed_for_writing}
| {:error, :handshake_error | :peer_closed_for_writing | :closed}
def handle_data(dtls, packets) do
case Native.handle_data(dtls, packets) do
{:handshake_finished, lkm, rkm, protection_profile, []} ->
Expand All @@ -192,6 +193,16 @@ defmodule ExDTLS do

If there is no timeout to handle, simple `{:ok, dtls()}` tuple is returned.
"""
@spec handle_timeout(dtls()) :: :ok | {:retransmit, packets :: [binary()], timeout :: integer()}
@spec handle_timeout(dtls()) ::
:ok | {:retransmit, packets :: [binary()], timeout :: integer()} | {:error, :closed}
defdelegate handle_timeout(dtls), to: Native

@doc """
Irreversibly closes DTLS session.

If a handshake has been finished, this function will generate `close_notify` DTLS alert
that should be sent to the other side.
"""
@spec close(dtls()) :: {:ok, packets :: [binary()]}
defdelegate close(dtls), to: Native, as: :exd_close
end
50 changes: 46 additions & 4 deletions test/integration_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ defmodule ExDTLS.IntegrationTest do
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true)
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)

{packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)

assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets)

Expand All @@ -17,7 +17,7 @@ defmodule ExDTLS.IntegrationTest do
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true)
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true)

{packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)

assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets)

Expand All @@ -34,7 +34,7 @@ defmodule ExDTLS.IntegrationTest do
assert {:error, :handshake_not_finished} = ExDTLS.write_data(sr_dtls, <<1, 2, 3>>)
assert {:error, :handshake_not_finished} = ExDTLS.write_data(cl_dtls, <<1, 2, 3>>)

{packets, _timeout} = ExDTLS.do_handshake(cl_dtls)
{:ok, packets, _timeout} = ExDTLS.do_handshake(cl_dtls)
assert :ok == loop({sr_dtls, false}, {cl_dtls, false}, packets)

msg = <<1, 3, 2, 5>>
Expand All @@ -55,11 +55,53 @@ defmodule ExDTLS.IntegrationTest do

tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)

{packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
{:handshake_packets, packets, _timeout} = feed_packets(rx_dtls, packets)
assert {:error, :handshake_error} = feed_packets(tx_dtls, packets)
end

describe "close/1" do
test "before handshake has finished (client mode)" do
dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)
assert {:ok, []} = ExDTLS.close(dtls)
# assert that handshake can't be started
assert {:error, :closed} = ExDTLS.do_handshake(dtls)
end

test "before handshake has finished (server mode)" do
dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true)
assert {:ok, []} = ExDTLS.close(dtls)
# assert that handshake can't be started
assert {:error, :closed} = ExDTLS.do_handshake(dtls)
end

test "after handshake has finished (client mode)" do
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true)
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)

{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)

assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets)
assert {:ok, [packet]} = ExDTLS.close(tx_dtls)
assert {:error, :peer_closed_for_writing} = ExDTLS.handle_data(rx_dtls, packet)
assert {:error, :closed} = ExDTLS.handle_timeout(tx_dtls)
assert {:error, :closed} = ExDTLS.handle_timeout(rx_dtls)
end

test "after handshake has finished (server mode)" do
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true)
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)

{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)

assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets)
assert {:ok, [packet]} = ExDTLS.close(rx_dtls)
assert {:error, :peer_closed_for_writing} = ExDTLS.handle_data(tx_dtls, packet)
assert {:error, :closed} = ExDTLS.handle_timeout(tx_dtls)
assert {:error, :closed} = ExDTLS.handle_timeout(rx_dtls)
end
end

defp loop({_dtls1, true}, {_dtls2, true}, _packets) do
:ok
end
Expand Down
2 changes: 1 addition & 1 deletion test/retransmission_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ defmodule ExDTLS.RetransmissionTest do
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true)
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true)

{_packets, timeout} = ExDTLS.do_handshake(tx_dtls)
{:ok, _packets, timeout} = ExDTLS.do_handshake(tx_dtls)
Process.send_after(self(), {:handle_timeout, :tx}, timeout)
{:retransmit, packets, timeout} = wait_for_timeout(tx_dtls, :tx)
Process.send_after(self(), {:handle_timeout, :tx}, timeout)
Expand Down
Loading