Skip to content

Commit

Permalink
Deduplicate HTTPSession implementation (#6823)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyashton authored Feb 12, 2025
1 parent 6866055 commit 1649f0c
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 202 deletions.
108 changes: 108 additions & 0 deletions src/enclave/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "ccf/node/session.h"
#include "ds/thread_messaging.h"
#include "enclave/tls_session.h"
#include "tcp/msg_types.h"

#include <span>
Expand Down Expand Up @@ -53,4 +54,111 @@ namespace ccf

virtual void handle_incoming_data_thread(std::vector<uint8_t>&& data) = 0;
};

class EncryptedSession : public ThreadedSession
{
public:
virtual bool parse(std::span<const uint8_t> data) = 0;

protected:
std::shared_ptr<ccf::TLSSession> tls_io;
::tcp::ConnID session_id;

EncryptedSession(
::tcp::ConnID session_id_,
ringbuffer::AbstractWriterFactory& writer_factory,
std::unique_ptr<ccf::tls::Context> ctx) :
ThreadedSession(session_id_),
tls_io(std::make_shared<ccf::TLSSession>(
session_id_, writer_factory, std::move(ctx))),
session_id(session_id_)
{}

public:
void send_data(std::span<const uint8_t> data) override
{
tls_io->send_raw(data.data(), data.size());
}

void close_session() override
{
tls_io->close();
}

void handle_incoming_data_thread(std::vector<uint8_t>&& data) override
{
tls_io->recv_buffered(data.data(), data.size());

LOG_TRACE_FMT("recv called with {} bytes", data.size());

// Try to parse all incoming data, reusing the vector we were just passed
// for storage. Increase the size if the received vector was too small
// (for the case where this chunk is very small, but we had some previous
// data to continue reading).
constexpr auto min_read_block_size = 4096;
if (data.size() < min_read_block_size)
{
data.resize(min_read_block_size);
}

auto n_read = tls_io->read(data.data(), data.size(), false);

while (true)
{
if (n_read == 0)
{
return;
}

LOG_TRACE_FMT("Going to parse {} bytes", n_read);

bool cont = parse({data.data(), n_read});
if (!cont)
{
return;
}

// Used all provided bytes - check if more are available
n_read = tls_io->read(data.data(), data.size(), false);
}
}
};

class UnencryptedSession : public ccf::ThreadedSession
{
public:
virtual bool parse(std::span<const uint8_t> data) = 0;

protected:
::tcp::ConnID session_id;
ringbuffer::WriterPtr to_host;

UnencryptedSession(
::tcp::ConnID session_id_,
ringbuffer::AbstractWriterFactory& writer_factory_) :
ccf::ThreadedSession(session_id_),
session_id(session_id_),
to_host(writer_factory_.create_writer_to_outside())
{}

void send_data(std::span<const uint8_t> data) override
{
RINGBUFFER_WRITE_MESSAGE(
::tcp::tcp_outbound,
to_host,
session_id,
serializer::ByteRange{data.data(), data.size()});
}

void close_session() override
{
RINGBUFFER_WRITE_MESSAGE(
::tcp::tcp_stop, to_host, session_id, std::string("Session closed"));
}

void handle_incoming_data_thread(std::vector<uint8_t>&& data) override
{
parse(data);
}
};
}
1 change: 0 additions & 1 deletion src/enclave/tls_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "ds/messaging.h"
#include "ds/ring_buffer.h"
#include "ds/thread_messaging.h"
#include "enclave/session.h"
#include "tcp/msg_types.h"
#include "tls/context.h"
#include "tls/tls.h"
Expand Down
75 changes: 4 additions & 71 deletions src/http/http2_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,76 +13,7 @@

namespace http
{
class HTTP2Session : public ccf::ThreadedSession
{
protected:
std::shared_ptr<ccf::TLSSession> tls_io;
std::shared_ptr<ErrorReporter> error_reporter;
::tcp::ConnID session_id;

HTTP2Session(
::tcp::ConnID session_id_,
ringbuffer::AbstractWriterFactory& writer_factory,
std::unique_ptr<ccf::tls::Context> ctx,
const std::shared_ptr<ErrorReporter>& error_reporter = nullptr) :
ccf::ThreadedSession(session_id_),
tls_io(std::make_shared<ccf::TLSSession>(
session_id_, writer_factory, std::move(ctx))),
error_reporter(error_reporter),
session_id(session_id_)
{}

public:
virtual bool parse(std::span<const uint8_t> data) = 0;

void send_data(std::span<const uint8_t> data) override
{
tls_io->send_raw(data.data(), data.size());
}

void close_session() override
{
tls_io->close();
}

void handle_incoming_data_thread(std::vector<uint8_t>&& data) override
{
tls_io->recv_buffered(data.data(), data.size());

LOG_TRACE_FMT("recv called with {} bytes", data.size());

// Try to parse all incoming data, reusing the vector we were just passed
// for storage. Increase the size if the received vector was too small
// (for the case where this chunk is very small, but we had some previous
// data to continue reading).
constexpr auto min_read_block_size = 4096;
if (data.size() < min_read_block_size)
{
data.resize(min_read_block_size);
}

auto n_read = tls_io->read(data.data(), data.size(), false);

while (true)
{
if (n_read == 0)
{
return;
}

LOG_TRACE_FMT("Going to parse {} bytes", n_read);

bool cont = parse({data.data(), n_read});
if (!cont)
{
return;
}

// Used all provided bytes - check if more are available
n_read = tls_io->read(data.data(), data.size(), false);
}
}
};
using HTTP2Session = ccf::EncryptedSession;

struct HTTP2SessionContext : public ccf::SessionContext
{
Expand Down Expand Up @@ -251,6 +182,7 @@ namespace http

std::shared_ptr<ccf::RPCMap> rpc_map;
std::shared_ptr<ccf::RpcHandler> handler;
std::shared_ptr<ErrorReporter> error_reporter;
ccf::ListenInterfaceID interface_id;

http::ResponderLookup& responder_lookup;
Expand Down Expand Up @@ -317,10 +249,11 @@ namespace http
const ccf::http::ParserConfiguration& configuration,
const std::shared_ptr<ErrorReporter>& error_reporter,
http::ResponderLookup& responder_lookup_) :
HTTP2Session(session_id_, writer_factory, std::move(ctx), error_reporter),
HTTP2Session(session_id_, writer_factory, std::move(ctx)),
server_parser(
std::make_shared<http2::ServerParser>(*this, configuration)),
rpc_map(rpc_map),
error_reporter(error_reporter),
interface_id(interface_id),
responder_lookup(responder_lookup_)
{
Expand Down
135 changes: 5 additions & 130 deletions src/http/http_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,76 +13,7 @@

namespace http
{
class HTTPSession : public ccf::ThreadedSession
{
protected:
std::shared_ptr<ccf::TLSSession> tls_io;
std::shared_ptr<ErrorReporter> error_reporter;
::tcp::ConnID session_id;

HTTPSession(
::tcp::ConnID session_id_,
ringbuffer::AbstractWriterFactory& writer_factory,
std::unique_ptr<ccf::tls::Context> ctx,
const std::shared_ptr<ErrorReporter>& error_reporter = nullptr) :
ccf::ThreadedSession(session_id_),
tls_io(std::make_shared<ccf::TLSSession>(
session_id_, writer_factory, std::move(ctx))),
error_reporter(error_reporter),
session_id(session_id_)
{}

public:
virtual bool parse(std::span<const uint8_t> data) = 0;

void send_data(std::span<const uint8_t> data) override
{
tls_io->send_raw(data.data(), data.size());
}

void close_session() override
{
tls_io->close();
}

void handle_incoming_data_thread(std::vector<uint8_t>&& data) override
{
tls_io->recv_buffered(data.data(), data.size());

LOG_TRACE_FMT("recv called with {} bytes", data.size());

// Try to parse all incoming data, reusing the vector we were just passed
// for storage. Increase the size if the received vector was too small
// (for the case where this chunk is very small, but we had some previous
// data to continue reading).
constexpr auto min_read_block_size = 4096;
if (data.size() < min_read_block_size)
{
data.resize(min_read_block_size);
}

auto n_read = tls_io->read(data.data(), data.size(), false);

while (true)
{
if (n_read == 0)
{
return;
}

LOG_TRACE_FMT("Going to parse {} bytes", n_read);

bool cont = parse({data.data(), n_read});
if (!cont)
{
return;
}

// Used all provided bytes - check if more are available
n_read = tls_io->read(data.data(), data.size(), false);
}
}
};
using HTTPSession = ccf::EncryptedSession;

class HTTPServerSession : public HTTPSession,
public http::RequestProcessor,
Expand All @@ -94,6 +25,7 @@ namespace http
std::shared_ptr<ccf::RPCMap> rpc_map;
std::shared_ptr<ccf::RpcHandler> handler;
std::shared_ptr<ccf::SessionContext> session_ctx;
std::shared_ptr<ErrorReporter> error_reporter;
ccf::ListenInterfaceID interface_id;

public:
Expand All @@ -105,9 +37,10 @@ namespace http
std::unique_ptr<ccf::tls::Context> ctx,
const ccf::http::ParserConfiguration& configuration,
const std::shared_ptr<ErrorReporter>& error_reporter = nullptr) :
HTTPSession(session_id_, writer_factory, std::move(ctx), error_reporter),
HTTPSession(session_id_, writer_factory, std::move(ctx)),
request_parser(*this, configuration),
rpc_map(rpc_map),
error_reporter(error_reporter),
interface_id(interface_id)
{}

Expand Down Expand Up @@ -396,65 +329,7 @@ namespace http
}
};

class UnencryptedHTTPSession : public ccf::ThreadedSession
{
protected:
std::shared_ptr<ErrorReporter> error_reporter;
::tcp::ConnID session_id;
ringbuffer::AbstractWriterFactory& writer_factory;
ringbuffer::WriterPtr to_host;
size_t execution_thread;

UnencryptedHTTPSession(
::tcp::ConnID session_id_,
ringbuffer::AbstractWriterFactory& writer_factory_,
const std::shared_ptr<ErrorReporter>& error_reporter = nullptr) :
ccf::ThreadedSession(session_id_),
error_reporter(error_reporter),
session_id(session_id_),
writer_factory(writer_factory_),
to_host(writer_factory.create_writer_to_outside())
{
execution_thread =
threading::ThreadMessaging::instance().get_execution_thread(
session_id_);
}

public:
virtual bool parse(std::span<const uint8_t> data) = 0;

void send_data(std::span<const uint8_t> data) override
{
if (ccf::threading::get_current_thread_id() != execution_thread)
{
throw std::logic_error(
"Called UnencryptedHTTPSession::send_data "
"from wrong thread");
}
RINGBUFFER_WRITE_MESSAGE(
::tcp::tcp_outbound,
to_host,
session_id,
serializer::ByteRange{data.data(), data.size()});
}

void close_session() override
{
if (ccf::threading::get_current_thread_id() != execution_thread)
{
throw std::logic_error(
"Called UnencryptedHTTPSession::close_session "
"from wrong thread");
}
RINGBUFFER_WRITE_MESSAGE(
::tcp::tcp_stop, to_host, session_id, std::string("Session closed"));
}

void handle_incoming_data_thread(std::vector<uint8_t>&& data) override
{
parse(data);
}
};
using UnencryptedHTTPSession = ccf::UnencryptedSession;

class UnencryptedHTTPClientSession : public UnencryptedHTTPSession,
public ccf::ClientSession,
Expand Down

0 comments on commit 1649f0c

Please sign in to comment.