Skip to content

Commit 307b729

Browse files
authored
Accept large data transfer over SSL (#1261)
* Add large data transfer test * Replace `SSL_read` and `SSL_write` with `ex` functions * Reflect review comment * Fix return value of `SSLSocketStream::read/write` * Fix return value in the case of `SSL_ERROR_ZERO_RETURN` * Disable `LargeDataTransfer` test due to OoM in CI
1 parent 696239d commit 307b729

File tree

2 files changed

+87
-42
lines changed

2 files changed

+87
-42
lines changed

httplib.h

Lines changed: 43 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7221,62 +7221,63 @@ inline bool SSLSocketStream::is_writable() const {
72217221
}
72227222

72237223
inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
7224+
size_t readbytes = 0;
72247225
if (SSL_pending(ssl_) > 0) {
7225-
return SSL_read(ssl_, ptr, static_cast<int>(size));
7226-
} else if (is_readable()) {
7227-
auto ret = SSL_read(ssl_, ptr, static_cast<int>(size));
7228-
if (ret < 0) {
7229-
auto err = SSL_get_error(ssl_, ret);
7230-
int n = 1000;
7226+
auto ret = SSL_read_ex(ssl_, ptr, size, &readbytes);
7227+
if (ret == 1) { return static_cast<ssize_t>(readbytes); }
7228+
if (SSL_get_error(ssl_, ret) == SSL_ERROR_ZERO_RETURN) { return 0; }
7229+
return -1;
7230+
}
7231+
if (!is_readable()) { return -1; }
7232+
7233+
auto ret = SSL_read_ex(ssl_, ptr, size, &readbytes);
7234+
if (ret == 1) { return static_cast<ssize_t>(readbytes); }
7235+
auto err = SSL_get_error(ssl_, ret);
7236+
int n = 1000;
72317237
#ifdef _WIN32
7232-
while (--n >= 0 && (err == SSL_ERROR_WANT_READ ||
7233-
(err == SSL_ERROR_SYSCALL &&
7234-
WSAGetLastError() == WSAETIMEDOUT))) {
7238+
while (--n >= 0 &&
7239+
(err == SSL_ERROR_WANT_READ ||
7240+
(err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) {
72357241
#else
7236-
while (--n >= 0 && err == SSL_ERROR_WANT_READ) {
7242+
while (--n >= 0 && err == SSL_ERROR_WANT_READ) {
72377243
#endif
7238-
if (SSL_pending(ssl_) > 0) {
7239-
return SSL_read(ssl_, ptr, static_cast<int>(size));
7240-
} else if (is_readable()) {
7241-
std::this_thread::sleep_for(std::chrono::milliseconds(1));
7242-
ret = SSL_read(ssl_, ptr, static_cast<int>(size));
7243-
if (ret >= 0) { return ret; }
7244-
err = SSL_get_error(ssl_, ret);
7245-
} else {
7246-
return -1;
7247-
}
7248-
}
7244+
if (SSL_pending(ssl_) > 0) {
7245+
ret = SSL_read_ex(ssl_, ptr, size, &readbytes);
7246+
if (ret == 1) { return static_cast<ssize_t>(readbytes); }
7247+
if (SSL_get_error(ssl_, ret) == SSL_ERROR_ZERO_RETURN) { return 0; }
7248+
return -1;
72497249
}
7250-
return ret;
7250+
if (!is_readable()) { return -1; }
7251+
std::this_thread::sleep_for(std::chrono::milliseconds(1));
7252+
ret = SSL_read_ex(ssl_, ptr, size, &readbytes);
7253+
if (ret == 1) { return static_cast<ssize_t>(readbytes); }
7254+
err = SSL_get_error(ssl_, ret);
72517255
}
7256+
if (err == SSL_ERROR_ZERO_RETURN) { return 0; }
72527257
return -1;
72537258
}
72547259

72557260
inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
7256-
if (is_writable()) {
7257-
auto ret = SSL_write(ssl_, ptr, static_cast<int>(size));
7258-
if (ret < 0) {
7259-
auto err = SSL_get_error(ssl_, ret);
7260-
int n = 1000;
7261+
if (!is_writable()) { return -1; }
7262+
size_t written = 0;
7263+
auto ret = SSL_write_ex(ssl_, ptr, size, &written);
7264+
if (ret == 1) { return static_cast<ssize_t>(written); }
7265+
auto err = SSL_get_error(ssl_, ret);
7266+
int n = 1000;
72617267
#ifdef _WIN32
7262-
while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE ||
7263-
(err == SSL_ERROR_SYSCALL &&
7264-
WSAGetLastError() == WSAETIMEDOUT))) {
7268+
while (--n >= 0 &&
7269+
(err == SSL_ERROR_WANT_WRITE ||
7270+
(err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) {
72657271
#else
7266-
while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) {
7272+
while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) {
72677273
#endif
7268-
if (is_writable()) {
7269-
std::this_thread::sleep_for(std::chrono::milliseconds(1));
7270-
ret = SSL_write(ssl_, ptr, static_cast<int>(size));
7271-
if (ret >= 0) { return ret; }
7272-
err = SSL_get_error(ssl_, ret);
7273-
} else {
7274-
return -1;
7275-
}
7276-
}
7277-
}
7278-
return ret;
7274+
if (!is_writable()) { return -1; }
7275+
std::this_thread::sleep_for(std::chrono::milliseconds(1));
7276+
ret = SSL_write_ex(ssl_, ptr, size, &written);
7277+
if (ret == 1) { return static_cast<ssize_t>(written); }
7278+
err = SSL_get_error(ssl_, ret);
72797279
}
7280+
if (err == SSL_ERROR_ZERO_RETURN) { return 0; }
72807281
return -1;
72817282
}
72827283

test/test.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4660,6 +4660,50 @@ TEST(SSLClientServerTest, CustomizeServerSSLCtx) {
46604660

46614661
t.join();
46624662
}
4663+
4664+
// Disabled due to the out-of-memory problem on GitHub Actions Workflows
4665+
TEST(SSLClientServerTest, DISABLED_LargeDataTransfer) {
4666+
4667+
// prepare large data
4668+
std::random_device seed_gen;
4669+
std::mt19937 random(seed_gen());
4670+
constexpr auto large_size_byte = 2147483648UL + 1048576UL; // 2GiB + 1MiB
4671+
std::vector<std::uint32_t> binary(large_size_byte / sizeof(std::uint32_t));
4672+
std::generate(binary.begin(), binary.end(), [&random]() { return random(); });
4673+
4674+
// server
4675+
SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);
4676+
ASSERT_TRUE(svr.is_valid());
4677+
4678+
svr.Post("/binary", [&](const Request &req, Response &res) {
4679+
EXPECT_EQ(large_size_byte, req.body.size());
4680+
EXPECT_EQ(0, std::memcmp(binary.data(), req.body.data(), large_size_byte));
4681+
res.set_content(req.body, "application/octet-stream");
4682+
});
4683+
4684+
auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); });
4685+
while (!svr.is_running()) {
4686+
std::this_thread::sleep_for(std::chrono::milliseconds(1));
4687+
}
4688+
4689+
// client POST
4690+
SSLClient cli("localhost", PORT);
4691+
cli.enable_server_certificate_verification(false);
4692+
cli.set_read_timeout(std::chrono::seconds(100));
4693+
cli.set_write_timeout(std::chrono::seconds(100));
4694+
auto res = cli.Post("/binary", reinterpret_cast<char *>(binary.data()),
4695+
large_size_byte, "application/octet-stream");
4696+
4697+
// compare
4698+
EXPECT_EQ(200, res->status);
4699+
EXPECT_EQ(large_size_byte, res->body.size());
4700+
EXPECT_EQ(0, std::memcmp(binary.data(), res->body.data(), large_size_byte));
4701+
4702+
// cleanup
4703+
svr.stop();
4704+
listen_thread.join();
4705+
ASSERT_FALSE(svr.is_running());
4706+
}
46634707
#endif
46644708

46654709
#ifdef _WIN32

0 commit comments

Comments
 (0)