Skip to content

Commit

Permalink
RequestServer: Make WebSocket IPC APIs asynchronous
Browse files Browse the repository at this point in the history
This fixes deadlocking when interacting with WebSockets while
RequestServer is trying to stream downloaded data to WebContent.
  • Loading branch information
awesomekling committed Sep 19, 2024
1 parent 853a75c commit e205723
Show file tree
Hide file tree
Showing 10 changed files with 129 additions and 94 deletions.
46 changes: 31 additions & 15 deletions Userland/Libraries/LibRequests/RequestClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,45 +96,61 @@ void RequestClient::certificate_requested(i32 request_id)

RefPtr<WebSocket> RequestClient::websocket_connect(const URL::URL& url, ByteString const& origin, Vector<ByteString> const& protocols, Vector<ByteString> const& extensions, HTTP::HeaderMap const& request_headers)
{
auto connection_id = IPCProxy::websocket_connect(url, origin, protocols, extensions, request_headers);
if (connection_id < 0)
return nullptr;
auto connection = WebSocket::create_from_id({}, *this, connection_id);
m_websockets.set(connection_id, connection);
auto websocket_id = m_next_websocket_id++;
IPCProxy::async_websocket_connect(websocket_id, url, origin, protocols, extensions, request_headers);
auto connection = WebSocket::create_from_id({}, *this, websocket_id);
m_websockets.set(websocket_id, connection);
return connection;
}

void RequestClient::websocket_connected(i32 connection_id)
void RequestClient::websocket_connected(i64 websocket_id)
{
auto maybe_connection = m_websockets.get(connection_id);
auto maybe_connection = m_websockets.get(websocket_id);
if (maybe_connection.has_value())
maybe_connection.value()->did_open({});
}

void RequestClient::websocket_received(i32 connection_id, bool is_text, ByteBuffer const& data)
void RequestClient::websocket_received(i64 websocket_id, bool is_text, ByteBuffer const& data)
{
auto maybe_connection = m_websockets.get(connection_id);
auto maybe_connection = m_websockets.get(websocket_id);
if (maybe_connection.has_value())
maybe_connection.value()->did_receive({}, data, is_text);
}

void RequestClient::websocket_errored(i32 connection_id, i32 message)
void RequestClient::websocket_errored(i64 websocket_id, i32 message)
{
auto maybe_connection = m_websockets.get(connection_id);
auto maybe_connection = m_websockets.get(websocket_id);
if (maybe_connection.has_value())
maybe_connection.value()->did_error({}, message);
}

void RequestClient::websocket_closed(i32 connection_id, u16 code, ByteString const& reason, bool clean)
void RequestClient::websocket_closed(i64 websocket_id, u16 code, ByteString const& reason, bool clean)
{
auto maybe_connection = m_websockets.get(connection_id);
auto maybe_connection = m_websockets.get(websocket_id);
if (maybe_connection.has_value())
maybe_connection.value()->did_close({}, code, reason, clean);
}

void RequestClient::websocket_certificate_requested(i32 connection_id)
void RequestClient::websocket_ready_state_changed(i64 websocket_id, u32 ready_state)
{
auto maybe_connection = m_websockets.get(websocket_id);
if (maybe_connection.has_value()) {
VERIFY(ready_state <= static_cast<u32>(WebSocket::ReadyState::Closed));
maybe_connection.value()->set_ready_state(static_cast<WebSocket::ReadyState>(ready_state));
}
}

void RequestClient::websocket_subprotocol(i64 websocket_id, ByteString const& subprotocol)
{
auto maybe_connection = m_websockets.get(websocket_id);
if (maybe_connection.has_value()) {
maybe_connection.value()->set_subprotocol_in_use(subprotocol);
}
}

void RequestClient::websocket_certificate_requested(i64 websocket_id)
{
auto maybe_connection = m_websockets.get(connection_id);
auto maybe_connection = m_websockets.get(websocket_id);
if (maybe_connection.has_value())
maybe_connection.value()->did_request_certificates({});
}
Expand Down
16 changes: 10 additions & 6 deletions Userland/Libraries/LibRequests/RequestClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,18 @@ class RequestClient final
virtual void certificate_requested(i32) override;
virtual void headers_became_available(i32, HTTP::HeaderMap const&, Optional<u32> const&) override;

virtual void websocket_connected(i32) override;
virtual void websocket_received(i32, bool, ByteBuffer const&) override;
virtual void websocket_errored(i32, i32) override;
virtual void websocket_closed(i32, u16, ByteString const&, bool) override;
virtual void websocket_certificate_requested(i32) override;
virtual void websocket_connected(i64 websocket_id) override;
virtual void websocket_received(i64 websocket_id, bool, ByteBuffer const&) override;
virtual void websocket_errored(i64 websocket_id, i32) override;
virtual void websocket_closed(i64 websocket_id, u16, ByteString const&, bool) override;
virtual void websocket_ready_state_changed(i64 websocket_id, u32 ready_state) override;
virtual void websocket_subprotocol(i64 websocket_id, ByteString const& subprotocol) override;
virtual void websocket_certificate_requested(i64 websocket_id) override;

HashMap<i32, RefPtr<Request>> m_requests;
HashMap<i32, NonnullRefPtr<WebSocket>> m_websockets;
HashMap<i64, NonnullRefPtr<WebSocket>> m_websockets;

i64 m_next_websocket_id { 0 };
};

}
24 changes: 17 additions & 7 deletions Userland/Libraries/LibRequests/WebSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,35 @@

namespace Requests {

WebSocket::WebSocket(RequestClient& client, i32 connection_id)
WebSocket::WebSocket(RequestClient& client, i64 connection_id)
: m_client(client)
, m_connection_id(connection_id)
, m_websocket_id(connection_id)
{
}

WebSocket::ReadyState WebSocket::ready_state()
{
return static_cast<WebSocket::ReadyState>(m_client->websocket_ready_state(m_connection_id));
return m_ready_state;
}

void WebSocket::set_ready_state(ReadyState ready_state)
{
m_ready_state = ready_state;
}

ByteString WebSocket::subprotocol_in_use()
{
return m_client->websocket_subprotocol_in_use(m_connection_id);
return m_subprotocol;
}

void WebSocket::set_subprotocol_in_use(ByteString subprotocol)
{
m_subprotocol = move(subprotocol);
}

void WebSocket::send(ByteBuffer binary_or_text_message, bool is_text)
{
m_client->async_websocket_send(m_connection_id, is_text, move(binary_or_text_message));
m_client->async_websocket_send(m_websocket_id, is_text, move(binary_or_text_message));
}

void WebSocket::send(StringView text_message)
Expand All @@ -37,7 +47,7 @@ void WebSocket::send(StringView text_message)

void WebSocket::close(u16 code, ByteString reason)
{
m_client->async_websocket_close(m_connection_id, code, move(reason));
m_client->async_websocket_close(m_websocket_id, code, move(reason));
}

void WebSocket::did_open(Badge<RequestClient>)
Expand Down Expand Up @@ -68,7 +78,7 @@ void WebSocket::did_request_certificates(Badge<RequestClient>)
{
if (on_certificate_requested) {
auto result = on_certificate_requested();
if (!m_client->websocket_set_certificate(m_connection_id, result.certificate, result.key))
if (!m_client->websocket_set_certificate(m_websocket_id, result.certificate, result.key))
dbgln("WebSocket: set_certificate failed");
}
}
Expand Down
14 changes: 9 additions & 5 deletions Userland/Libraries/LibRequests/WebSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,18 @@ class WebSocket : public RefCounted<WebSocket> {
Closed = 3,
};

static NonnullRefPtr<WebSocket> create_from_id(Badge<RequestClient>, RequestClient& client, i32 connection_id)
static NonnullRefPtr<WebSocket> create_from_id(Badge<RequestClient>, RequestClient& client, i64 websocket_id)
{
return adopt_ref(*new WebSocket(client, connection_id));
return adopt_ref(*new WebSocket(client, websocket_id));
}

int id() const { return m_connection_id; }
i64 id() const { return m_websocket_id; }

ReadyState ready_state();
void set_ready_state(ReadyState);

ByteString subprotocol_in_use();
void set_subprotocol_in_use(ByteString);

void send(ByteBuffer binary_or_text_message, bool is_text);
void send(StringView text_message);
Expand All @@ -72,9 +74,11 @@ class WebSocket : public RefCounted<WebSocket> {
void did_request_certificates(Badge<RequestClient>);

private:
explicit WebSocket(RequestClient&, i32 connection_id);
explicit WebSocket(RequestClient&, i64 websocket_id);
WeakPtr<RequestClient> m_client;
int m_connection_id { -1 };
ReadyState m_ready_state { ReadyState::Connecting };
ByteString m_subprotocol;
i64 m_websocket_id { -1 };
};

}
33 changes: 23 additions & 10 deletions Userland/Libraries/LibWebSocket/WebSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ void WebSocket::start()
m_impl->on_connected = [this] {
if (m_state != WebSocket::InternalState::EstablishingProtocolConnection)
return;
m_state = WebSocket::InternalState::SendingClientHandshake;
set_state(WebSocket::InternalState::SendingClientHandshake);
send_client_handshake();
drain_read();
};
m_impl->on_ready_to_read = [this] {
drain_read();
};
m_state = WebSocket::InternalState::EstablishingProtocolConnection;
set_state(WebSocket::InternalState::EstablishingProtocolConnection);
m_impl->connect(m_connection);
}

Expand Down Expand Up @@ -100,15 +100,15 @@ void WebSocket::close(u16 code, ByteString const& message)
case InternalState::SendingClientHandshake:
case InternalState::WaitingForServerHandshake:
// FIXME: Fail the connection.
m_state = InternalState::Closing;
set_state(InternalState::Closing);
break;
case InternalState::Open: {
auto message_bytes = message.bytes();
auto close_payload = ByteBuffer::create_uninitialized(message_bytes.size() + 2).release_value_but_fixme_should_propagate_errors(); // FIXME: Handle possible OOM situation.
close_payload.overwrite(0, (u8*)&code, 2);
close_payload.overwrite(2, message_bytes.data(), message_bytes.size());
send_frame(WebSocket::OpCode::ConnectionClose, close_payload, true);
m_state = InternalState::Closing;
set_state(InternalState::Closing);
break;
}
default:
Expand All @@ -120,7 +120,7 @@ void WebSocket::drain_read()
{
if (m_impl->eof()) {
// The connection got closed by the server
m_state = WebSocket::InternalState::Closed;
set_state(WebSocket::InternalState::Closed);
notify_close(m_last_close_code, m_last_close_message, true);
discard_connection();
return;
Expand Down Expand Up @@ -218,7 +218,7 @@ void WebSocket::send_client_handshake()

builder.append("\r\n"sv);

m_state = WebSocket::InternalState::WaitingForServerHandshake;
set_state(WebSocket::InternalState::WaitingForServerHandshake);
auto success = m_impl->send(builder.string_view().bytes());
VERIFY(success);
}
Expand Down Expand Up @@ -282,7 +282,7 @@ void WebSocket::read_server_handshake()
return;
}

m_state = WebSocket::InternalState::Open;
set_state(WebSocket::InternalState::Open);
notify_open();
return;
}
Expand Down Expand Up @@ -400,7 +400,7 @@ void WebSocket::read_frame()
auto head_bytes = get_buffered_bytes(2);
if (head_bytes.is_null() || head_bytes.is_empty()) {
// The connection got closed.
m_state = WebSocket::InternalState::Closed;
set_state(WebSocket::InternalState::Closed);
notify_close(m_last_close_code, m_last_close_message, true);
discard_connection();
return;
Expand Down Expand Up @@ -487,7 +487,7 @@ void WebSocket::read_frame()
m_last_close_code = (((u16)(payload[0] & 0xff) << 8) | ((u16)(payload[1] & 0xff)));
m_last_close_message = ByteString(ReadonlyBytes(payload.offset_pointer(2), payload.size() - 2));
}
m_state = WebSocket::InternalState::Closing;
set_state(WebSocket::InternalState::Closing);
return;
}
if (op_code == WebSocket::OpCode::Ping) {
Expand Down Expand Up @@ -608,7 +608,7 @@ void WebSocket::send_frame(WebSocket::OpCode op_code, ReadonlyBytes payload, boo

void WebSocket::fatal_error(WebSocket::Error error)
{
m_state = WebSocket::InternalState::Errored;
set_state(WebSocket::InternalState::Errored);
notify_error(error);
discard_connection();
}
Expand Down Expand Up @@ -653,4 +653,17 @@ void WebSocket::notify_message(Message message)
on_message(move(message));
}

void WebSocket::set_state(InternalState state)
{
if (m_state == state)
return;
auto old_ready_state = ready_state();
m_state = state;
auto new_ready_state = ready_state();
if (old_ready_state != new_ready_state) {
if (on_ready_state_change)
on_ready_state_change(ready_state());
}
}

}
4 changes: 4 additions & 0 deletions Userland/Libraries/LibWebSocket/WebSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class WebSocket final : public Core::EventReceiver {
Function<void()> on_open;
Function<void(u16 code, ByteString reason, bool was_clean)> on_close;
Function<void(Message message)> on_message;
Function<void(ReadyState)> on_ready_state_change;
Function<void(ByteString)> on_subprotocol;

enum class Error {
CouldNotEstablishConnection,
Expand Down Expand Up @@ -97,6 +99,8 @@ class WebSocket final : public Core::EventReceiver {

InternalState m_state { InternalState::NotStarted };

void set_state(InternalState);

ByteString m_subprotocol_in_use { ByteString::empty() };

ByteString m_websocket_key;
Expand Down
Loading

0 comments on commit e205723

Please sign in to comment.