Skip to content

Commit

Permalink
cool#9833: Implement MaxConnection Limit and complete UnitTests
Browse files Browse the repository at this point in the history
Adding connection limit where registered via SocketPoll::setLimiter(),
i.e. DocumentBrokerPoll and WebServerPoll only.

SocketPoll::poll() will drop _new_ overhead connections exceeding MaxConnections.
This has been discussed, in favor of dropping the oldest connections.

Aligned net::Config::MaxConnections w/ pre-existing
MAX_CONNECTIONS and COOLWSD::MaxConnections.

COOLWSD / net::Config
- Aligned MAX_CONNECTIONS/COOLWSD::MaxConnections with
  config "net.maxconnections", net::Config::MaxConnections,
  having a minimum of 3 - defaults to 9999.

SocketPoll::setLimiter():
- Increments given non-zero connectionLimit by one
  for WS upgrade socket.

Added http::StatusCode::None(0), allowing to use for debugging

Unit Tests:
- Using base class UnitTimeoutBase for UnitTimeoutConnections + UnitTimeoutNone,
  - Testing http, ws-ping and wsd-chat-ping on multiple parallel connections
    w/ and w/o a connection limit

- UnitTimeoutSocket tests socket max-duration on http + WS sessions

- UnitTimeoutWSPing tests WS Ping (native frame) timeout limit on WS sessions

Signed-off-by: Sven Göthel <[email protected]>
Change-Id: I7e1a9329e0848c40a210f6250e29e26950da6fbc
  • Loading branch information
Sven Göthel committed Sep 11, 2024
1 parent eaf2439 commit b8ec82a
Show file tree
Hide file tree
Showing 12 changed files with 666 additions and 268 deletions.
2 changes: 2 additions & 0 deletions net/HttpRequest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ STATE_ENUM(FieldParseState,
/// See https://en.wikipedia.org/wiki/List_of_HTTP_status_codes
enum class StatusCode : unsigned
{
None = 0, // Undefined status (unknown)

// Informational
Continue = 100,
SwitchingProtocols = 101,
Expand Down
6 changes: 3 additions & 3 deletions net/NetUtil.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class Config
/// http::Session timeout in us (30s default). Zero disables metric.
std::chrono::microseconds HTTPTimeout;

/// Socket maximum connections (100000). Zero disables metric.
size_t MaxConnectionCount;
/// Maximum total connections (9999 or MAX_CONNECTIONS). Zero disables metric.
size_t MaxConnections;
/// Socket maximum duration in seconds (12h). Zero disables metric.
std::chrono::seconds MaxDuration;
/// Socket minimum bits per seconds throughput (0). Zero disables metric.
Expand All @@ -54,7 +54,7 @@ class Config
: WSPingTimeout(std::chrono::milliseconds(2000))
, WSPingPeriod(std::chrono::milliseconds(3000))
, HTTPTimeout(std::chrono::milliseconds(30000))
, MaxConnectionCount(100000)
, MaxConnections(9999)
, MaxDuration(std::chrono::seconds(43200))
, MinBytesPerSec(0.0)
, SocketPollTimeout(std::chrono::seconds(64))
Expand Down
145 changes: 131 additions & 14 deletions net/Socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "TraceEvent.hpp"
#include "Util.hpp"

#include <atomic>
#include <chrono>
#include <cstring>
#include <cctype>
Expand Down Expand Up @@ -62,6 +63,38 @@ std::atomic<bool> Socket::InhibitThreadChecks(false);

std::unique_ptr<Watchdog> SocketPoll::PollWatchdog;

std::mutex SocketPoll::StatsMutex;
std::atomic<size_t> SocketPoll::StatsConnectionCount(0);

size_t SocketPoll::StatsConnectionMod(size_t added, size_t removed) {
if( added == 0 && removed == 0 ) {
return GetStatsConnectionCount();
}
size_t res, pre;
{
std::lock_guard<std::mutex> lock(StatsMutex);
pre = GetStatsConnectionCount();
res = pre;
if( added <= std::numeric_limits<size_t>::max() - res ) {
res += added;
} else {
// overflow
LOG_WRN("SocketPoll::ConnectionCount: Overflow " << res << " + " << added);
res = std::numeric_limits<size_t>::max();
}
if( removed <= res ) {
res -= removed;
} else {
// underflow
LOG_WRN("SocketPoll::ConnectionCount: Underflow " << res << " - " << removed);
res = 0;
}
StatsConnectionCount.store(res, std::memory_order_relaxed);
}
LOG_DBG("SocketPoll::ConnectionCount: " << pre << " +" << added << " -" << removed << " = " << res);
return res;
}

#define SOCKET_ABSTRACT_UNIX_NAME "0coolwsd-"

const char* Socket::toString(Type t) noexcept
Expand Down Expand Up @@ -292,6 +325,8 @@ namespace {
SocketPoll::SocketPoll(std::string threadName)
: _name(std::move(threadName)),
_pollTimeout( net::Config::get().SocketPollTimeout ),
_limitedConnections( false ),
_connectionLimit( 0 ),
_pollStartIndex(0),
_stop(false),
_threadStarted(0),
Expand Down Expand Up @@ -489,6 +524,8 @@ int SocketPoll::poll(int64_t timeoutMaxMicroS)
// The events to poll on change each spin of the loop.
setupPollFds(now, timeoutMaxMicroS);
const size_t size = _pollSockets.size();
size_t itemsAdded = 0;
size_t itemsErased = 0;

// disable watchdog - it's good to sleep
disableWatchdog();
Expand Down Expand Up @@ -541,7 +578,29 @@ int SocketPoll::poll(int64_t timeoutMaxMicroS)
std::vector<CallbackFn> invoke;
{
std::lock_guard<std::mutex> lock(_mutex);
const size_t newConnCount = _newSockets.size();
const size_t globCount = GetStatsConnectionCount();

if( _limitedConnections &&
_connectionLimit > 0 &&
globCount + newConnCount > _connectionLimit)
{
// For now we simply drop new connections
const size_t overhead = globCount + newConnCount - _connectionLimit;
for(size_t i=0; i<overhead; ++i)
{
std::shared_ptr<Socket>& socket = _newSockets.back(); // oldest
assert(socket);

LOG_WRN("Limiter: #" << socket->fd() << ": Removing "
<< (i+1) << " / " << overhead << " new socket of (pre "
<< globCount << " + new " << newConnCount << ") / max " << _connectionLimit
<< " from " << _name);

socket->resetThreadOwner();
_newSockets.pop_back();
}
}
if (!_newSockets.empty())
{
LOGA_TRC(Socket, "Inserting " << _newSockets.size() << " new sockets after the existing "
Expand All @@ -554,6 +613,8 @@ int SocketPoll::poll(int64_t timeoutMaxMicroS)
// Copy the new sockets over and clear.
_pollSockets.insert(_pollSockets.end(), _newSockets.begin(), _newSockets.end());

itemsAdded += _newSockets.size();

_newSockets.clear();
}

Expand Down Expand Up @@ -587,10 +648,15 @@ int SocketPoll::poll(int64_t timeoutMaxMicroS)
}
}

if (_pollSockets.size() != size)
if (itemsAdded != _pollSockets.size() - size)
{
LOG_TRC("PollSocket container size has changed from " << size << " to "
<< _pollSockets.size());
// unexpected
LOG_WRN("PollSocket container size has changed from " << size
<< " + " << itemsAdded << " to " << _pollSockets.size());
} else if (itemsAdded > 0)
{
LOG_TRC("PollSocket container size increased from " << size
<< " + " << itemsAdded << " to " << _pollSockets.size());
}

// If we had sockets to process.
Expand All @@ -607,7 +673,6 @@ int SocketPoll::poll(int64_t timeoutMaxMicroS)
if (_pollStartIndex > size - 1)
_pollStartIndex = 0;

size_t itemsErased = 0;
size_t i = _pollStartIndex;
for (std::size_t j = 0; j < size; ++j)
{
Expand Down Expand Up @@ -686,15 +751,62 @@ int SocketPoll::poll(int64_t timeoutMaxMicroS)

if (itemsErased)
{
LOG_TRC("Scanning to removing " << itemsErased << " defunct sockets from "
<< _pollSockets.size() << " sockets");
const size_t itemsErasedPre = itemsErased;
itemsErased = 0; // correcting itemsErased

_pollSockets.erase(
std::remove_if(_pollSockets.begin(), _pollSockets.end(),
[](const std::shared_ptr<Socket>& s)->bool
{ return !s; }),
[&itemsErased](const std::shared_ptr<Socket>& s) -> bool
{
if (!s)
{
++itemsErased;
return true;
}
else
{
return false;
}
}),
_pollSockets.end());

LOG_TRC("Removed " << itemsErased
<< "(" << itemsErasedPre << ") defunct sockets from "
<< _pollSockets.size() << " sockets");
}
}
if( _limitedConnections )
{
#if 1
// For now we simply drop new connections (see _newSockets above)
// Simply perform bookkeeping if required
StatsConnectionMod(itemsAdded, itemsErased);
#else
// Drop oldest connections
const size_t globCount = StatsConnectionMod(itemsAdded, itemsErased);
if( _connectionLimit > 0 && globCount > _connectionLimit )
{
const size_t localCount = _pollSockets.size();
const size_t globOverhead = globCount - _connectionLimit;
const double pct = std::max(1.0, (double)localCount / (double)globCount);
// clip double-of-extending-pct to [1 .. localCount/4], i.e. be nice.
const size_t localOverhead = std::max<size_t>(1, std::min<size_t>(localCount/4, static_cast<size_t>( pct * globOverhead + 0.5 ) * 2));
for(size_t i=0; i<localOverhead; ++i)
{
std::shared_ptr<Socket>& socket = _pollSockets[i]; // oldest
assert(socket);
LOG_WRN("Limiter: #" << socket->fd() << ": Removing local "
<< localOverhead << " / global " << globOverhead << " socket (at " << i
<< " of " << localCount << ") from " << _name);
socket->resetThreadOwner();
_pollSockets[i] = nullptr; // close via dtor
}
_pollSockets.erase(_pollSockets.begin(), _pollSockets.begin()+localOverhead);
StatsConnectionMod(0, localOverhead);
}
#endif
}

return rc;
Expand Down Expand Up @@ -802,9 +914,11 @@ void SocketPoll::createWakeups()
void SocketPoll::removeSockets()
{
LOG_DBG("Removing all " << _pollSockets.size() + _newSockets.size()
<< " sockets from SocketPoll thread " << _name);
<< " sockets from SocketPoll thread " << _name
<< " of " << GetStatsConnectionCount() << " total poll sockets");
ASSERT_CORRECT_SOCKET_THREAD(this);

size_t removedPollSockets = 0;
while (!_pollSockets.empty())
{
const std::shared_ptr<Socket>& socket = _pollSockets.back();
Expand All @@ -815,6 +929,10 @@ void SocketPoll::removeSockets()
socket->resetThreadOwner();

_pollSockets.pop_back();
++removedPollSockets;
}
if( _limitedConnections ) {
StatsConnectionMod(0, removedPollSockets);
}

while (!_newSockets.empty())
Expand Down Expand Up @@ -1083,7 +1201,9 @@ void SocketPoll::dumpState(std::ostream& os) const

os << "\n SocketPoll:";
os << "\n Poll [" << name() << "] with " << pollSockets.size() << " socket"
<< (pollSockets.size() == 1 ? "" : "s") << " - wakeup rfd: " << _wakeup[0]
<< (pollSockets.size() == 1 ? "" : "s")
<< " of " << GetStatsConnectionCount()
<< " total - wakeup rfd: " << _wakeup[0]
<< " wfd: " << _wakeup[1] << '\n';
const auto callbacks = _newCallbacks.size();
if (callbacks > 0)
Expand Down Expand Up @@ -1451,15 +1571,14 @@ bool StreamSocket::checkRemoval(std::chrono::steady_clock::time_point now) noexc
LOG_WRN("Socket still open post onDisconnect(), forced shutdown.");
shutdown(); // signal
closeConnection(); // real -> setClosed()
assert(isOpen() == false); // should have issued shutdown
}
}
else
{
shutdown(); // signal
closeConnection(); // real -> setClosed()
assert(isOpen() == false); // should have issued shutdown
}
assert(isOpen() == false); // should have issued shutdown
return true;
}
else if (_socketHandler && _socketHandler->checkTimeout(now))
Expand All @@ -1468,8 +1587,6 @@ bool StreamSocket::checkRemoval(std::chrono::steady_clock::time_point now) noexc
setClosed();
LOG_WRN("CheckRemoval: Timeout: " << getStatsString(now) << ", " << *this);
return true;
} else {
LOG_DBG("CheckRemoval: Test " << getStatsString(now) << ", " << *this);
}
return false;
}
Expand Down
20 changes: 18 additions & 2 deletions net/Socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,15 @@ class SocketPoll
/// Global wakeup - signal safe: wakeup all socket polls.
static void wakeupWorld();

/// Enable connection accounting and limiter
/// Internally we allow one extra connection for the WS upgrade
/// @param connectionLimit socket connection limit
void setLimiter(size_t connectionLimit)
{
_limitedConnections = true;
_connectionLimit = connectionLimit > 0 ? connectionLimit + 1 : 0;
}

/// Insert a new socket to be polled.
/// A socket is removed when it is closed, readIncomingData
/// returns false, or when removeSockets is called (which is
Expand Down Expand Up @@ -976,6 +985,8 @@ class SocketPoll
/// Debug name used for logging.
const std::string _name;
const std::chrono::microseconds _pollTimeout;
bool _limitedConnections;
size_t _connectionLimit;

/// main-loop wakeup pipe
int _wakeup[2];
Expand All @@ -1002,6 +1013,13 @@ class SocketPoll
/// Time-stamp for profiling
int _ownerThreadId;
std::atomic<uint64_t> _watchdogTime;

static std::mutex StatsMutex;
static std::atomic<size_t> StatsConnectionCount; // total of all _pollSockets (excluding _newSockets)
static size_t StatsConnectionMod(size_t added, size_t removed); // safe add-sub of StatsConnectionCount

public:
static int64_t GetStatsConnectionCount() { return StatsConnectionCount.load(std::memory_order_relaxed); }
};

/// A SocketPoll that will stop polling and
Expand Down Expand Up @@ -1040,7 +1058,6 @@ class StreamSocket : public Socket,
_pollTimeout( net::Config::get().SocketPollTimeout ),
_httpTimeout( net::Config::get().HTTPTimeout ),
_minBytesPerSec( net::Config::get().MinBytesPerSec ),
_maxConnectionCount( net::Config::get().MaxConnectionCount ),
_hostname(std::move(host)),
_wsState(WSState::HTTP),
_sentHTTPContinue(false),
Expand Down Expand Up @@ -1704,7 +1721,6 @@ class StreamSocket : public Socket,
const std::chrono::microseconds _pollTimeout;
const std::chrono::microseconds _httpTimeout;
const double _minBytesPerSec;
const size_t _maxConnectionCount;

/// The hostname (or IP) of the peer we are connecting to.
const std::string _hostname;
Expand Down
3 changes: 3 additions & 0 deletions test/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ all_la_unit_tests = \
unit-timeout.la \
unit-timeout_socket.la \
unit-timeout_wsping.la \
unit-timeout_conn.la \
unit-timeout_none.la \
unit-base.la
# unit-admin.la
Expand Down Expand Up @@ -230,6 +231,8 @@ unit_timeout_socket_la_SOURCES = UnitTimeoutSocket.cpp
unit_timeout_socket_la_LIBADD = $(CPPUNIT_LIBS)
unit_timeout_wsping_la_SOURCES = UnitTimeoutWSPing.cpp
unit_timeout_wsping_la_LIBADD = $(CPPUNIT_LIBS)
unit_timeout_conn_la_SOURCES = UnitTimeoutConnections.cpp
unit_timeout_conn_la_LIBADD = $(CPPUNIT_LIBS)
unit_timeout_none_la_SOURCES = UnitTimeoutNone.cpp
unit_timeout_none_la_LIBADD = $(CPPUNIT_LIBS)
unit_prefork_la_SOURCES = UnitPrefork.cpp
Expand Down
Loading

0 comments on commit b8ec82a

Please sign in to comment.