Skip to content

Commit

Permalink
cool#9833: Mitigate connection count limitation, fix timeout-checking…
Browse files Browse the repository at this point in the history
… code (*WIP*)

This commit is a WIP, pushed for discussion and to be amended.
(Cleaned up version with general code cleanups and Util artefacts removed for clarity)

- ProtocolHandlerInterface::checkTimeout(..)
  - Add bool return value: true -> shutdown connection, caller shall stop processing
  - Implemented for http::Session
    - Timeout (30s) with missing response
  - Implemented for WebSocketHandler
    - Timeout (64s = SocketPoll::DefaultPollTimeoutMicroS)
      after missing pong (server only)

- StreamSocket -> Socket (properties moved)
  - bytes sent/received
  - closed state

- Socket (added properties)
  - creation- and last-seen -time
  - socket type and port
  - checkForcedRemoval(..) *WIP*
    - called directly from SocketPoll::poll()
    - only for IPv4/v6 network connections
    - similar to ProtocolHandlerInterface::checkTimeout(..)
    - added further criteria (age, throughput, ..)
      - Timeout (64s = SocketPoll::DefaultPollTimeoutMicroS)
        if (now - lastSeen) > timeout
      - Timeout (12 hours)
        if (now - creationTime) > timeout
      - TODO: Throughput/bandwitdh disabled, find proper metrics/timing
      - TODO: Add maximimal IPv4/IPv6 socket-count criteria, drop oldest.

- SocketPoll::poll()
  - Additionally erases if !socket->isOpen() || socket->checkForcedRemoval()

- TODO
  - Facility to configure timeouts, at least for testing!
    - Marked code with `TODO Timeout: Config`
  - More elaborated tests
    - WebSocket
    - ..

Signed-off-by: Sven Göthel <[email protected]>
Change-Id: I7e1a9329e0848c40a210f6250e29e26950da6fbc
  • Loading branch information
Sven Göthel committed Sep 2, 2024
1 parent b490637 commit cb801db
Show file tree
Hide file tree
Showing 9 changed files with 730 additions and 130 deletions.
22 changes: 18 additions & 4 deletions net/HttpRequest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1432,7 +1432,10 @@ class Session final : public ProtocolHandlerInterface
while (!_response->done())
{
const auto now = std::chrono::steady_clock::now();
checkTimeout(now);
if (checkTimeout(now))
{
return false;
}

const auto remaining =
std::chrono::duration_cast<std::chrono::microseconds>(deadline - now);
Expand Down Expand Up @@ -1692,14 +1695,18 @@ class Session final : public ProtocolHandlerInterface
net::asyncConnect(_host, _port, isSecure(), shared_from_this(), pushConnectCompleteToPoll);
}

void checkTimeout(std::chrono::steady_clock::time_point now) override
bool checkTimeout(std::chrono::steady_clock::time_point now) override
{
if (!_response || _response->done())
return;
{
return false;
}

const std::chrono::microseconds timeout = getTimeout();
const auto duration =
std::chrono::duration_cast<std::chrono::milliseconds>(now - _startTime);
if (now < _startTime || duration > getTimeout() || SigUtil::getTerminationFlag())

if (now < _startTime || duration > timeout || SigUtil::getTerminationFlag())
{
LOG_WRN("Timed out while requesting [" << _request.getVerb() << ' ' << _host
<< _request.getUrl() << "] after " << duration);
Expand All @@ -1712,7 +1719,14 @@ class Session final : public ProtocolHandlerInterface
// no good maintaining a poor connection (if that's the issue).
onDisconnect(); // Trigger manually (why wait for poll to do it?).
assert(isConnected() == false);
return true;
} else {
// FIXME: Remove!

Check notice

Code scanning / CodeQL

FIXME comment Note

FIXME comment: Remove!
LOG_DBG("Timeout check while requesting [" << _request.getVerb() << ' ' << _host
<< _request.getUrl() << "] after "
<< duration << " <= " << timeout);
}
return false;
}

int sendTextMessage(const char*, const size_t, bool) const override { return 0; }
Expand Down
2 changes: 1 addition & 1 deletion net/ServerSocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class LocalServerSocket : public ServerSocket
ServerSocket(Socket::Type::Unix, clientPoller, std::move(sockFactory))
{
}
~LocalServerSocket();
~LocalServerSocket() override;

virtual bool bind(Type, int) override { assert(false); return false; }
virtual std::shared_ptr<Socket> accept() override;
Expand Down
192 changes: 171 additions & 21 deletions net/Socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

#include <chrono>
#include <cstring>
#include <ctype.h>
#include <cctype>
#include <iomanip>
#include <memory>
#include <ratio>
#include <sstream>
#include <stdio.h>
#include <cstdio>
#include <string>
#include <unistd.h>
#include <sys/stat.h>
Expand Down Expand Up @@ -67,6 +67,22 @@ std::unique_ptr<Watchdog> SocketPoll::PollWatchdog;

#define SOCKET_ABSTRACT_UNIX_NAME "0coolwsd-"

std::string Socket::toString(Type t) noexcept
{
switch (t)
{
case Type::IPv4:
return "IPv4";
case Type::IPv6:
return "IPv6";
case Type::All:
return "All";
case Type::Unix:
return "Unix";
}
return "Unknown";
}

int Socket::createSocket([[maybe_unused]] Socket::Type type)
{
#if !MOBILEAPP
Expand All @@ -86,9 +102,69 @@ int Socket::createSocket([[maybe_unused]] Socket::Type type)
#endif
}

const std::string Socket::getClientAddressAndPort() const noexcept
{
std::string s;
if (Type::IPv6 == type())
{
s.append("[").append(clientAddress()).append("]:").append(std::to_string(clientPort()));
}
else
{
s.append(clientAddress()).append(":").append(std::to_string(clientPort()));
}
return s;
}

std::string Socket::getStatsString(std::chrono::steady_clock::time_point now) const noexcept
{
const auto durTotal = std::chrono::duration_cast<std::chrono::milliseconds>(now - _creationTime);
const auto durLast = std::chrono::duration_cast<std::chrono::milliseconds>(now - _lastSeenTime);

float kBpsIn, kBpsOut;
if (durTotal.count() > 0)
{
kBpsIn = (float)_bytesRcvd / (float)durTotal.count();
kBpsOut = (float)_bytesSent / (float)durTotal.count();
}
else
{
kBpsIn = (float)_bytesRcvd / 1000.0f;
kBpsOut = (float)_bytesSent / 1000.0f;
}
std::ostringstream oss;
oss.precision(1);
oss << "Socket[#" << getFD() << ", dur[total"
<< durTotal.count() << "ms, last "
<< durLast.count() << ",ms], kBps[in "
<< kBpsIn << ", out " << kBpsOut
<< "], " << toString(type()) << " @ ";
if (Type::IPv6 == type())
{
oss << "[" << clientAddress() << "]:" << clientPort();
}
else
{
oss << clientAddress() << ":" << clientPort();
}
oss << "]";
return oss.str();
}

std::string Socket::toString() const noexcept
{
std::string s("Socket[#");
s.append(std::to_string(getFD()))
.append(", ")
.append(toString(type()))
.append(" @ ")
.append(getClientAddressAndPort())
.append("]");
return s;
}

bool StreamSocket::socketpair(std::shared_ptr<StreamSocket> &parent,
std::shared_ptr<StreamSocket> &child)
bool StreamSocket::socketpair(std::shared_ptr<StreamSocket>& parent,
std::shared_ptr<StreamSocket>& child)
{
#if MOBILEAPP
return false;
Expand All @@ -98,10 +174,10 @@ bool StreamSocket::socketpair(std::shared_ptr<StreamSocket> &parent,
if (rc != 0)
return false;

child = std::shared_ptr<StreamSocket>(new StreamSocket("save-child", pair[0], Socket::Type::Unix, true));
child = std::make_shared<StreamSocket>("save-child", pair[0], Socket::Type::Unix, true);
child->setNoShutdown();
child->setClientAddress("save-child");
parent = std::shared_ptr<StreamSocket>(new StreamSocket("save-kit-parent", pair[1], Socket::Type::Unix, true));
parent = std::make_shared<StreamSocket>("save-kit-parent", pair[1], Socket::Type::Unix, true);
parent->setNoShutdown();
parent->setClientAddress("save-parent");

Expand Down Expand Up @@ -234,7 +310,7 @@ SocketPoll::SocketPoll(std::string threadName)

static bool watchDogProfile = !!getenv("COOL_WATCHDOG");
if (watchDogProfile && !PollWatchdog)
PollWatchdog.reset(new Watchdog());
PollWatchdog = std::make_unique<Watchdog>();

_wakeup[0] = -1;
_wakeup[1] = -1;
Expand Down Expand Up @@ -412,7 +488,7 @@ int SocketPoll::poll(int64_t timeoutMaxMicroS)
socketErrorCount++;
#endif

std::chrono::steady_clock::time_point now =
const std::chrono::steady_clock::time_point now =
std::chrono::steady_clock::now();

// The events to poll on change each spin of the loop.
Expand Down Expand Up @@ -528,7 +604,7 @@ int SocketPoll::poll(int64_t timeoutMaxMicroS)
assert(!_pollSockets.empty() && "All existing sockets disappeared from the SocketPoll");

// Fire the poll callbacks and remove dead fds.
std::chrono::steady_clock::time_point newNow = std::chrono::steady_clock::now();
const std::chrono::steady_clock::time_point newNow = std::chrono::steady_clock::now();

// We use the _pollStartIndex to start the polling at a different index each time. Do some
// sanity check first to handle the case where we removed one or several sockets last time.
Expand All @@ -548,7 +624,23 @@ int SocketPoll::poll(int64_t timeoutMaxMicroS)
else if (!_pollSockets[i])
{
// removed in a callback
itemsErased++;
++itemsErased;
}
else if (!_pollSockets[i]->isOpen())
{
// closed socket ..
++itemsErased;
LOGA_TRC(Socket, '#' << _pollFds[i].fd << ": Removing socket (at " << i
<< " of " << _pollSockets.size() << ") from " << _name);
_pollSockets[i] = nullptr;
}
else if( _pollSockets[i]->checkForcedRemoval(newNow) )
{
// timed out socket ..
++itemsErased;
LOGA_TRC(Socket, '#' << _pollFds[i].fd << ": Removing socket (at " << i
<< " of " << _pollSockets.size() << ") from " << _name);
_pollSockets[i] = nullptr;
}
else if (_pollFds[i].fd == _pollSockets[i]->getFD())
{
Expand All @@ -569,9 +661,16 @@ int SocketPoll::poll(int64_t timeoutMaxMicroS)
rc = -1;
}

if (!disposition.isContinue())
if (!_pollSockets[i]->isOpen() || !disposition.isContinue())
{
itemsErased++;
// Potentially via ProtocolHandlerInterface::handlePoll()'s
// - ProtocolHandlerInterface::checkTimeout()
// - ProtocolHandlerInterface::onDisconnect())
// - disposition.setClosed()
// ProtocolHandlerInterface
// - http::Session::handlePoll() OK
// - WebSocketHandler::handlePoll() OK
++itemsErased;
LOGA_TRC(Socket, '#' << _pollFds[i].fd << ": Removing socket (at " << i
<< " of " << _pollSockets.size() << ") from " << _name);
_pollSockets[i] = nullptr;
Expand Down Expand Up @@ -625,10 +724,10 @@ void SocketPoll::closeAllSockets()
checkAndReThread();

removeFromWakeupArray();
for (auto &it : _pollSockets)
for (std::shared_ptr<Socket> &it : _pollSockets)
{
// first close the underlying socket
close(it->getFD());
::close(it->getFD());

// avoid the socketHandler' getting an onDisconnect
auto stream = dynamic_cast<StreamSocket *>(it.get());
Expand All @@ -651,7 +750,7 @@ void SocketPoll::takeSocket(const std::shared_ptr<SocketPoll> &fromPoll,
ASSERT_CORRECT_THREAD();

// hold a reference during transfer
std::shared_ptr<Socket> socket = inSocket;
std::shared_ptr<Socket> socket = inSocket; // NOLINT

SocketPoll *toPoll = this;
fromPoll->addCallback([fromPoll,socket,&mut,&cond,&transferred,toPoll](){
Expand Down Expand Up @@ -933,8 +1032,8 @@ void StreamSocket::dumpState(std::ostream& os)
const int events = getPollEvents(std::chrono::steady_clock::now(), timeoutMaxMicroS);
os << '\t' << std::setw(6) << getFD() << "\t0x" << std::hex << events << std::dec << '\t'
<< (ignoringInput() ? "ignore\t" : "process\t") << std::setw(6) << _inBuffer.size() << '\t'
<< std::setw(6) << _outBuffer.size() << '\t' << " r: " << std::setw(6) << _bytesRecvd
<< "\t w: " << std::setw(6) << _bytesSent << '\t' << clientAddress() << '\t';
<< std::setw(6) << _outBuffer.size() << '\t' << " r: " << std::setw(6) << getBytesRcvd()
<< "\t w: " << std::setw(6) << getBytesSent() << '\t' << clientAddress() << '\t';
_socketHandler->dumpState(os);
if (_inBuffer.size() > 0)
Util::dumpHex(os, _inBuffer, "\t\tinBuffer:\n", "\t\t");
Expand Down Expand Up @@ -1097,7 +1196,7 @@ std::shared_ptr<Socket> ServerSocket::accept()
std::shared_ptr<Socket> _socket = createSocketFromAccept(rc, type);

inet_ntop(clientInfo.sin6_family, inAddr, addrstr, sizeof(addrstr));
_socket->setClientAddress(addrstr);
_socket->setClientAddress(addrstr, clientInfo.sin6_port);

LOG_TRC("Accepted socket #" << _socket->getFD() << " has family "
<< clientInfo.sin6_family << " address "
Expand Down Expand Up @@ -1298,6 +1397,57 @@ LocalServerSocket::~LocalServerSocket()
# define LOG_CHUNK(X)
#endif

#endif // !MOBILEAPP

bool StreamSocket::checkForcedRemoval(std::chrono::steady_clock::time_point now) noexcept
{
if ( !isIPType() ) // forced removal on IPv[46] network connections only
{
return false;
}
const std::chrono::microseconds timeoutMax = DefaultMaxConnectionTimMicroS; // TODO Timeout: Config
const std::chrono::microseconds timeoutLast = SocketPoll::DefaultPollTimeoutMicroS; // TODO Timeout: Config
const float minBytesPerSec = DefaultMinBytesPerSec; // TODO Timeout: Config

const auto durTotal =
std::chrono::duration_cast<std::chrono::milliseconds>(now - getCreationTime());
const auto durLast =
std::chrono::duration_cast<std::chrono::milliseconds>(now - getLastSeenTime());
const float bytesPerSecIn = durTotal.count() > 0 ? (float)getBytesRcvd() / ((float)durTotal.count() / 1000.0f) : (float)getBytesRcvd();
if (now < getCreationTime() || durTotal > timeoutMax || durLast > timeoutLast ||
(bytesPerSecIn > 0.0f && minBytesPerSec > 1.0f && bytesPerSecIn < minBytesPerSec) ||
SigUtil::getTerminationFlag())
{
LOG_WRN("Timed out socket after " << durTotal << ", " << getStatsString(now));

if (_socketHandler)
{
_socketHandler->onDisconnect();
if( isOpen() ) {
// FIXME: Ensure proper semantics of onDisconnect()

Check notice

Code scanning / CodeQL

FIXME comment Note

FIXME comment: Ensure proper semantics of onDisconnect()
LOG_WRN("Socket still open post onDisconnect(), forced shutdown.");
shutdown(); // signal
closeConnection(); // real -> setClosed()
assert(isOpen() == false); // should have issued shutdown
}
}
else
{
shutdown(); // signal
closeConnection(); // real -> setClosed()
assert(isOpen() == false); // should have issued shutdown
}
return true;
}
else
{
LOG_WRN("Timeout check socket after " << durTotal << ", " << getStatsString(now));
}
return false;
}

#if !MOBILEAPP

bool StreamSocket::parseHeader(const char *clientName,
Poco::MemoryInputStream &message,
Poco::Net::HTTPRequest &request,
Expand Down Expand Up @@ -1366,7 +1516,7 @@ bool StreamSocket::parseHeader(const char *clientName,
if (request.getChunkedTransferEncoding())
{
// keep the header
map._spans.push_back(std::pair<size_t, size_t>(0, itBody - _inBuffer.begin()));
map._spans.emplace_back(0, itBody - _inBuffer.begin());

int chunk = 0;
while (itBody != _inBuffer.end())
Expand Down Expand Up @@ -1416,7 +1566,7 @@ bool StreamSocket::parseHeader(const char *clientName,
}
itBody += chunkLen;

map._spans.push_back(std::pair<size_t,size_t>(chunkOffset, chunkLen));
map._spans.emplace_back(chunkOffset, chunkLen);

if (*itBody != '\r' || *(itBody + 1) != '\n')
{
Expand Down Expand Up @@ -1539,7 +1689,7 @@ bool StreamSocket::compactChunks(MessageMap& map)
bool StreamSocket::sniffSSL() const
{
// Only sniffing the first bytes of a socket.
if (_bytesSent > 0 || _bytesRecvd != _inBuffer.size() || _bytesRecvd < 6)
if (getBytesSent() > 0 || getBytesRcvd() != _inBuffer.size() || getBytesRcvd() < 6)
return false;

// 0x0000 16 03 01 02 00 01 00 01
Expand Down
Loading

0 comments on commit cb801db

Please sign in to comment.