Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/crow/app.h
Original file line number Diff line number Diff line change
Expand Up @@ -627,12 +627,12 @@ namespace crow
}


void add_websocket(crow::websocket::connection* conn)
void add_websocket(std::shared_ptr<websocket::connection> conn)
{
websockets_.push_back(conn);
}

void remove_websocket(crow::websocket::connection* conn)
void remove_websocket(std::shared_ptr<websocket::connection> conn)
{
websockets_.erase(std::remove(websockets_.begin(), websockets_.end(), conn), websockets_.end());
}
Expand Down Expand Up @@ -846,7 +846,7 @@ namespace crow
bool server_started_{false};
std::condition_variable cv_started_;
std::mutex start_mutex_;
std::vector<crow::websocket::connection*> websockets_;
std::vector<std::shared_ptr<websocket::connection>> websockets_;
};

/// \brief Alias of Crow<Middlewares...>. Useful if you want
Expand Down
8 changes: 5 additions & 3 deletions include/crow/routing.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,17 +445,19 @@ namespace crow // NOTE: Already documented in "crow/app.h"
void handle_upgrade(const request& req, response&, SocketAdaptor&& adaptor) override
{
max_payload_ = max_payload_override_ ? max_payload_ : app_->websocket_max_payload();
new crow::websocket::Connection<SocketAdaptor, App>(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
crow::websocket::Connection<SocketAdaptor, App>::create(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
}

void handle_upgrade(const request& req, response&, UnixSocketAdaptor&& adaptor) override
{
max_payload_ = max_payload_override_ ? max_payload_ : app_->websocket_max_payload();
new crow::websocket::Connection<UnixSocketAdaptor, App>(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
crow::websocket::Connection<UnixSocketAdaptor, App>::create(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
}

#ifdef CROW_ENABLE_SSL
void handle_upgrade(const request& req, response&, SSLAdaptor&& adaptor) override
{
new crow::websocket::Connection<SSLAdaptor, App>(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
crow::websocket::Connection<SSLAdaptor, App>::create(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
}
#endif

Expand Down
156 changes: 79 additions & 77 deletions include/crow/websocket.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include <array>
#include <memory>
#include "crow/logging.h"
#include "crow/socket_adaptors.h"
#include "crow/http_request.h"
Expand Down Expand Up @@ -102,36 +103,31 @@ namespace crow // NOTE: Already documented in "crow/app.h"
/// A websocket connection.

template<typename Adaptor, typename Handler>
class Connection : public connection
class Connection : public connection, public std::enable_shared_from_this<Connection<Adaptor, Handler>>
{
public:
/// Constructor for a connection.

/// Factory for a connection.
///
/// Requires a request with an "Upgrade: websocket" header.<br>
/// Automatically handles the handshake.
Connection(const crow::request& req, Adaptor&& adaptor, Handler* handler,
uint64_t max_payload, const std::vector<std::string>& subprotocols,
std::function<void(crow::websocket::connection&)> open_handler,
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler,
std::function<void(crow::websocket::connection&, const std::string&, uint16_t)> close_handler,
std::function<void(crow::websocket::connection&, const std::string&)> error_handler,
std::function<bool(const crow::request&, void**)> accept_handler,
bool mirror_protocols):
adaptor_(std::move(adaptor)),
handler_(handler),
max_payload_bytes_(max_payload),
open_handler_(std::move(open_handler)),
message_handler_(std::move(message_handler)),
close_handler_(std::move(close_handler)),
error_handler_(std::move(error_handler)),
accept_handler_(std::move(accept_handler))
static void create(const crow::request& req, Adaptor adaptor, Handler* handler,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shoudn't std::make_shared(...) do the same as this function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh yes this is a good question, I should have probably explained this design choice.

To correctly keep track of the open websockets::connections and destory them, we need to register them in App::websockets_ once they are created.
Inside the constructor the shared_ptr is not available yet so an option would be to add a method in the Connection class to register itself to the App::websockets_. This would have been an easier solution but would have left the responsibility to the user to actually register the websocket::connection after it is instantiated by calling the registration method.

I think that the registration should instead happen by design. To achieve this I made the constructor private so that it's not possible to directly instantiate a websocket::connection. The instantiation is instead possible through this factory method. This factory instantiate, validates and registers the connection.

Before, the constructor was taking care of validation as well while the registration was missing causing the memory leak.

Hope it was a clear explanation :)

uint64_t max_payload, const std::vector<std::string>& subprotocols,
std::function<void(crow::websocket::connection&)> open_handler,
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler,
std::function<void(crow::websocket::connection&, const std::string&, uint16_t)> close_handler,
std::function<void(crow::websocket::connection&, const std::string&)> error_handler,
std::function<bool(const crow::request&, void**)> accept_handler,
bool mirror_protocols)
{
auto conn = std::shared_ptr<Connection>(
new Connection(std::move(adaptor), handler, max_payload,
std::move(open_handler), std::move(message_handler), std::move(close_handler),
std::move(error_handler), std::move(accept_handler)));

// Perform handshake validation
if (!utility::string_equals(req.get_header_value("upgrade"), "websocket"))
{
adaptor_.close();
handler_->remove_websocket(this);
delete this;
conn->adaptor_.close();
return;
}

Expand All @@ -142,26 +138,24 @@ namespace crow // NOTE: Already documented in "crow/app.h"
auto subprotocol = utility::find_first_of(subprotocols.begin(), subprotocols.end(), requested_subprotocols.begin(), requested_subprotocols.end());
if (subprotocol != subprotocols.end())
{
subprotocol_ = *subprotocol;
conn->subprotocol_ = *subprotocol;
}
}

if (mirror_protocols & !requested_subprotocols_header.empty())
{
subprotocol_ = requested_subprotocols_header;
conn->subprotocol_ = requested_subprotocols_header;
}

if (accept_handler_)
if (conn->accept_handler_)
{
void* ud = nullptr;
if (!accept_handler_(req, &ud))
if (!conn->accept_handler_(req, &ud))
{
adaptor_.close();
handler_->remove_websocket(this);
delete this;
conn->adaptor_.close();
return;
}
userdata(ud);
conn->userdata(ud);
}

// Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Expand All @@ -172,22 +166,11 @@ namespace crow // NOTE: Already documented in "crow/app.h"
uint8_t digest[20];
s.getDigestBytes(digest);

start(crow::utility::base64encode((unsigned char*)digest, 20));
conn->handler_->add_websocket(conn);
conn->start(crow::utility::base64encode((unsigned char*)digest, 20));
}

~Connection() noexcept override
{
// Do not modify anchor_ here since writing shared_ptr is not atomic.
auto watch = std::weak_ptr<void>{anchor_};

// Wait until all unhandled asynchronous operations to join.
// As the deletion occurs inside 'check_destroy()', which already locks
// anchor, use count can be 1 on valid deletion context.
while (watch.use_count() > 2) // 1 for 'check_destroy() routine', 1 for 'this->anchor_'
{
std::this_thread::yield();
}
}
~Connection() noexcept override = default;

template<typename Callable>
struct WeakWrappedMessage
Expand Down Expand Up @@ -717,38 +700,38 @@ namespace crow // NOTE: Already documented in "crow/app.h"
/// Also destroys the object if the Close flag is set.
void do_write()
{
if (sending_buffers_.empty())
if (write_buffers_.empty()) return;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was checking the wrong buffer, update to check write_buggers


sending_buffers_.swap(write_buffers_);
std::vector<asio::const_buffer> buffers;
buffers.reserve(sending_buffers_.size());
for (auto& s : sending_buffers_)
{
sending_buffers_.swap(write_buffers_);
std::vector<asio::const_buffer> buffers;
buffers.reserve(sending_buffers_.size());
for (auto& s : sending_buffers_)
{
buffers.emplace_back(asio::buffer(s));
}
auto watch = std::weak_ptr<void>{anchor_};
asio::async_write(
adaptor_.socket(), buffers,
[&, watch](const error_code& ec, std::size_t /*bytes_transferred*/) {
if (!ec && !close_connection_)
{
sending_buffers_.clear();
if (!write_buffers_.empty())
do_write();
if (has_sent_close_)
close_connection_ = true;
}
else
{
auto anchor = watch.lock();
if (anchor == nullptr) { return; }

sending_buffers_.clear();
close_connection_ = true;
check_destroy();
}
});
buffers.emplace_back(asio::buffer(s));
}
auto watch = std::weak_ptr<void>{anchor_};
asio::async_write(
adaptor_.socket(), buffers,
[this, watch](const error_code& ec, std::size_t /*bytes_transferred*/) {
auto anchor = watch.lock();
if (anchor == nullptr)
return;

if (!ec && !close_connection_)
{
sending_buffers_.clear();
if (!write_buffers_.empty())
do_write();
if (has_sent_close_)
close_connection_ = true;
}
else
{
sending_buffers_.clear();
close_connection_ = true;
check_destroy();
}
});
}

/// Destroy the Connection.
Expand All @@ -757,11 +740,14 @@ namespace crow // NOTE: Already documented in "crow/app.h"
// Note that if the close handler was not yet called at this point we did not receive a close packet (or send one)
// and thus we use ClosedAbnormally unless instructed otherwise
if (!is_close_handler_called_)
{
if (close_handler_)
{
close_handler_(*this, "uncleanly", code);
handler_->remove_websocket(this);
if (sending_buffers_.empty() && !is_reading)
delete this;
}
}

handler_->remove_websocket(this->shared_from_this());
}


Expand Down Expand Up @@ -796,6 +782,22 @@ namespace crow // NOTE: Already documented in "crow/app.h"
}

private:
Connection(Adaptor&& adaptor, Handler* handler, uint64_t max_payload,
std::function<void(crow::websocket::connection&)> open_handler,
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler,
std::function<void(crow::websocket::connection&, const std::string&, uint16_t)> close_handler,
std::function<void(crow::websocket::connection&, const std::string&)> error_handler,
std::function<bool(const crow::request&, void**)> accept_handler):
adaptor_(std::move(adaptor)),
handler_(handler),
max_payload_bytes_(max_payload),
open_handler_(std::move(open_handler)),
message_handler_(std::move(message_handler)),
close_handler_(std::move(close_handler)),
error_handler_(std::move(error_handler)),
accept_handler_(std::move(accept_handler))
{}

Adaptor adaptor_;
Handler* handler_;

Expand Down