Files
yourpart3/src/websocket_server.cpp
Torsten Schulz (local) 00a5f47cae Refactor WebSocket server connection management and message handling
- Update WebSocketUserData to use a message queue for handling outgoing messages, improving concurrency and message delivery.
- Modify pingClients method to handle multiple connections per user and implement timeout logic for ping responses.
- Enhance addConnection and removeConnection methods to manage multiple connections for each user, including detailed logging of connection states.
- Update handleBrokerMessage to send messages to all active connections for a user, ensuring proper queue management and callback invocation.
2026-01-14 14:38:42 +01:00

443 lines
18 KiB
C++

#include "websocket_server.h"
#include "connection_guard.h"
#include "worker.h"
#include <algorithm>
#include <chrono>
#include <cstring>
#include <future>
#include <iostream>
#include <new>
#include <vector>
using json = nlohmann::json;
// libwebsockets protocol table. Every entry uses the same callback and reserves
// sizeof(WebSocketUserData) bytes of per-session storage; the 4th field is the
// receive buffer size. The table ends with an all-null sentinel entry.
struct lws_protocols WebSocketServer::protocols[] = {
    // Empty name: plain WebSocket connections that request no subprotocol.
    { "", WebSocketServer::wsCallback, sizeof(WebSocketUserData), 4096 },
    // Named subprotocol "yourpart-protocol".
    { "yourpart-protocol", WebSocketServer::wsCallback, sizeof(WebSocketUserData), 4096 },
    // Terminator.
    { nullptr, nullptr, 0, 0 }
};
// Object the static lws callback (wsCallback) forwards events to; set by the constructor.
WebSocketServer* WebSocketServer::instance{nullptr};
/**
 * Stores the configuration; no threads or lws context are created here (see run()).
 * Registers this object as the target of the static lws callback.
 *
 * @param port     TCP port to listen on
 * @param pool     database connection pool used for user-id lookups
 * @param broker   message broker whose messages are forwarded to clients
 * @param useSSL   enable TLS; requires certPath and keyPath
 * @param certPath TLS certificate file path
 * @param keyPath  TLS private key file path
 */
WebSocketServer::WebSocketServer(int port, ConnectionPool &pool, MessageBroker &broker,
                                 bool useSSL, const std::string& certPath, const std::string& keyPath)
    : port(port),
      pool(pool),
      broker(broker),
      useSSL(useSSL),
      certPath(certPath),
      keyPath(keyPath) {
    WebSocketServer::instance = this;
}
/**
 * Stops all server threads and the lws context, then unregisters this object
 * as the static callback target.
 */
WebSocketServer::~WebSocketServer() {
    stop();
    // Only clear the static pointer if it still refers to this object; a server
    // constructed later overwrites `instance`, and nulling it here would silently
    // disable that server's callback.
    if (instance == this) {
        instance = nullptr;
    }
}
void WebSocketServer::run() {
running = true;
broker.subscribe([this](const std::string &msg) {
{
std::lock_guard<std::mutex> lock(queueMutex);
messageQueue.push(msg);
}
queueCV.notify_one();
});
serverThread = std::thread([this](){ startServer(); });
messageThread = std::thread([this](){ processMessageQueue(); });
pingThread = std::thread([this](){ pingClients(); });
// Warte kurz bis alle Threads gestartet sind
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
void WebSocketServer::stop() {
running = false;
if (context) lws_cancel_service(context);
// Stoppe Threads mit Timeout
std::vector<std::future<void>> futures;
if (serverThread.joinable()) {
futures.push_back(std::async(std::launch::async, [this]() { serverThread.join(); }));
}
if (messageThread.joinable()) {
futures.push_back(std::async(std::launch::async, [this]() { messageThread.join(); }));
}
if (pingThread.joinable()) {
futures.push_back(std::async(std::launch::async, [this]() { pingThread.join(); }));
}
// Warte auf alle Threads mit Timeout
for (auto& future : futures) {
if (future.wait_for(std::chrono::milliseconds(1000)) == std::future_status::timeout) {
std::cerr << "WebSocket-Thread beendet sich nicht, erzwinge Beendigung..." << std::endl;
}
}
// Force detach alle Threads
if (serverThread.joinable()) serverThread.detach();
if (messageThread.joinable()) messageThread.detach();
if (pingThread.joinable()) pingThread.detach();
if (context) {
lws_context_destroy(context);
context = nullptr;
}
}
/**
 * Body of the lws service thread. It creates the lws context (optionally with TLS)
 * and services it until `running` becomes false or lws_service() reports an error.
 *
 * @throws std::runtime_error if SSL is enabled without certificate/key paths, or if
 *         the lws context cannot be created. run() executes this on its own
 *         std::thread, where an escaping exception ends in std::terminate().
 */
void WebSocketServer::startServer() {
    struct lws_context_creation_info info;
    memset(&info, 0, sizeof(info));
    info.port = port;
    info.protocols = protocols;
    // Simplified server options for better compatibility.
    info.options = LWS_SERVER_OPTION_VALIDATE_UTF8;

    // SSL/TLS configuration
    if (useSSL) {
        if (certPath.empty() || keyPath.empty()) {
            throw std::runtime_error("SSL enabled but certificate or key path not provided");
        }
        info.options |= LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
        // Pointers into member strings, which outlive the context.
        info.ssl_cert_filepath = certPath.c_str();
        info.ssl_private_key_filepath = keyPath.c_str();
        std::cout << "WebSocket SSL Server starting on port " << port << " with certificates: "
                  << certPath << " / " << keyPath << std::endl;
    } else {
        std::cout << "WebSocket Server starting on port " << port << " (no SSL)" << std::endl;
    }

    // libwebsockets log verbosity is configured with lws_set_log_level(). The former
    // setenv("LWS_LOG_LEVEL", ...) is not an lws setting. It also modified the process
    // environment from this worker thread while other threads may call getenv().

    context = lws_create_context(&info);
    if (!context) {
        throw std::runtime_error("Failed to create LWS context");
    }
    std::cout << "WebSocket-Server erfolgreich gestartet auf Port " << port << std::endl;

    while (running) {
        // lws_service() blocks waiting for socket activity; depending on the lws version
        // the 50 ms timeout is honoured or ignored. stop() wakes it early with
        // lws_cancel_service(), so an extra sleep here would only add latency to every event.
        const int ret = lws_service(context, 50);
        if (ret < 0) {
            std::cerr << "WebSocket-Server Fehler: lws_service returned " << ret << std::endl;
            break;
        }
    }
    std::cout << "WebSocket-Server wird beendet..." << std::endl;
}
/**
 * Broker-message worker loop.
 *
 * Waits (at most 100 ms per round) for messages enqueued by the broker subscription.
 * Each message is passed to handleBrokerMessage() with queueMutex released, so the
 * broker callback is never blocked by message processing. Messages still queued
 * when `running` turns false are not processed.
 */
void WebSocketServer::processMessageQueue() {
    std::unique_lock<std::mutex> guard(queueMutex, std::defer_lock);
    while (running) {
        guard.lock();
        queueCV.wait_for(guard, std::chrono::milliseconds(100),
                         [this]() { return !running || !messageQueue.empty(); });
        for (;;) {
            if (messageQueue.empty() || !running) {
                break;
            }
            std::string next = std::move(messageQueue.front());
            messageQueue.pop();
            guard.unlock();
            handleBrokerMessage(next);
            guard.lock();
        }
        guard.unlock();
    }
}
void WebSocketServer::pingClients() {
while (running) {
// Kürzere Sleep-Intervalle für bessere Shutdown-Responsivität
for (int i = 0; i < WebSocketUserData::PING_INTERVAL_SECONDS * 10 && running; ++i) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
if (!running || !context) continue;
auto now = std::chrono::steady_clock::now();
std::vector<struct lws*> toDisconnect;
// Prüfe alle Verbindungen auf Timeouts
{
std::shared_lock<std::shared_mutex> lock(connectionsMutex);
for (auto& pair : connections) {
for (auto* wsi : pair.second) {
auto* ud = reinterpret_cast<WebSocketUserData*>(lws_wsi_user(wsi));
if (!ud) continue;
// Prüfe ob Pong-Timeout erreicht wurde
auto timeSincePing = std::chrono::duration_cast<std::chrono::seconds>(now - ud->lastPingTime).count();
auto timeSincePong = std::chrono::duration_cast<std::chrono::seconds>(now - ud->lastPongTime).count();
if (!ud->pongReceived && timeSincePing > WebSocketUserData::PONG_TIMEOUT_SECONDS) {
ud->pingTimeoutCount++;
std::cout << "Ping-Timeout für User " << ud->userId << " (Versuch " << ud->pingTimeoutCount << "/" << WebSocketUserData::MAX_PING_TIMEOUTS << ")" << std::endl;
if (ud->pingTimeoutCount >= WebSocketUserData::MAX_PING_TIMEOUTS) {
std::cout << "Verbindung wird getrennt: Zu viele Ping-Timeouts für User " << ud->userId << std::endl;
toDisconnect.push_back(wsi);
} else {
// Reset für nächsten Versuch
ud->pongReceived = true;
ud->lastPongTime = now;
}
}
}
}
}
// Trenne problematische Verbindungen
for (auto* wsi : toDisconnect) {
lws_close_reason(wsi, LWS_CLOSE_STATUS_POLICY_VIOLATION, (unsigned char*)"Ping timeout", 12);
}
// Sende Pings an alle aktiven Verbindungen
if (running) {
lws_callback_on_writable_all_protocol(context, &protocols[0]);
}
}
}
int WebSocketServer::wsCallback(struct lws *wsi,
enum lws_callback_reasons reason,
void *user, void *in, size_t len) {
if (!instance) return 0;
auto *ud = reinterpret_cast<WebSocketUserData*>(user);
switch (reason) {
case LWS_CALLBACK_ESTABLISHED: {
ud->pongReceived = true;
ud->lastPingTime = std::chrono::steady_clock::now();
ud->lastPongTime = std::chrono::steady_clock::now();
ud->pingTimeoutCount = 0;
const char* protocolName = lws_get_protocol(wsi)->name;
std::cout << "WebSocket-Verbindung hergestellt (Protokoll: " << (protocolName ? protocolName : "Standard") << ")" << std::endl;
char client_addr[128];
lws_get_peer_simple(wsi, client_addr, sizeof(client_addr));
std::cout << "Client-Adresse: " << client_addr << std::endl;
break;
}
case LWS_CALLBACK_RECEIVE: {
std::string msg(reinterpret_cast<char*>(in), len);
std::cout << "WebSocket-Nachricht empfangen: " << msg << std::endl;
// Pong-Antwort behandeln
if (msg == "pong") {
ud->pongReceived = true;
ud->lastPongTime = std::chrono::steady_clock::now();
ud->pingTimeoutCount = 0;
std::cout << "Pong von Client empfangen" << std::endl;
break;
}
try {
json parsed = json::parse(msg);
std::cout << "[RECEIVE] Nachricht empfangen: " << msg << std::endl;
if (parsed.contains("event") && parsed["event"] == "setUserId") {
if (parsed.contains("data") && parsed["data"].contains("userId")) {
ud->userId = parsed["data"]["userId"].get<std::string>();
std::cout << "[RECEIVE] User-ID gesetzt: " << ud->userId << std::endl;
// Verbindung in der Map speichern
instance->addConnection(ud->userId, wsi);
std::cout << "[RECEIVE] Verbindung gespeichert" << std::endl;
} else {
std::cerr << "[RECEIVE] setUserId-Event ohne data.userId-Feld" << std::endl;
}
} else {
std::cout << "[RECEIVE] Ignoriere Nachricht (kein setUserId-Event)" << std::endl;
}
} catch (const std::exception &e) {
std::cerr << "[RECEIVE] Fehler beim Parsen der WebSocket-Nachricht: " << e.what() << std::endl;
}
break;
}
case LWS_CALLBACK_SERVER_WRITEABLE: {
// Prüfe ob es eine Nachricht zum Senden gibt
std::string messageToSend;
{
std::lock_guard<std::mutex> lock(ud->messageQueueMutex);
if (!ud->messageQueue.empty()) {
messageToSend = std::move(ud->messageQueue.front());
ud->messageQueue.pop();
}
}
if (!messageToSend.empty()) {
// Nachricht senden
std::cout << "[WRITEABLE] Sende Nachricht: " << messageToSend << std::endl;
unsigned char buf[LWS_PRE + messageToSend.length()];
memcpy(buf + LWS_PRE, messageToSend.c_str(), messageToSend.length());
lws_write(wsi, buf + LWS_PRE, messageToSend.length(), LWS_WRITE_TEXT);
std::cout << "[WRITEABLE] Nachricht erfolgreich gesendet" << std::endl;
// Wenn noch weitere Nachrichten in der Queue sind, wieder schreibbereit machen
{
std::lock_guard<std::mutex> lock(ud->messageQueueMutex);
if (!ud->messageQueue.empty()) {
lws_callback_on_writable(wsi);
}
}
} else {
// Ping senden
ud->lastPingTime = std::chrono::steady_clock::now();
ud->pongReceived = false;
unsigned char buf[LWS_PRE + 4];
memcpy(buf + LWS_PRE, "ping", 4);
lws_write(wsi, buf + LWS_PRE, 4, LWS_WRITE_TEXT);
// std::cout << "Ping an Client gesendet" << std::endl;
}
break;
}
case LWS_CALLBACK_CLOSED:
// Verbindung aus der Map entfernen
if (!ud->userId.empty()) {
instance->removeConnection(ud->userId, wsi);
std::cout << "WebSocket-Verbindung geschlossen für User: " << ud->userId << std::endl;
}
break;
case LWS_CALLBACK_HTTP:
// HTTP-Anfragen ablehnen (nur WebSocket erlaubt)
return -1;
case LWS_CALLBACK_FILTER_PROTOCOL_CONNECTION:
// Protokoll-Filter für bessere Kompatibilität
return 0;
case LWS_CALLBACK_RAW_CONNECTED:
// Raw-Verbindungen behandeln
return 0;
default:
break;
}
return 0;
}
void WebSocketServer::handleBrokerMessage(const std::string &message) {
try {
std::cout << "[handleBrokerMessage] Nachricht empfangen: " << message << std::endl;
json parsed = json::parse(message);
if (parsed.contains("user_id")) {
int fid;
if (parsed["user_id"].is_string()) {
fid = std::stoi(parsed["user_id"].get<std::string>());
} else {
fid = parsed["user_id"].get<int>();
}
auto userId = getUserIdFromFalukantUserId(fid);
std::cout << "[handleBrokerMessage] Broker-Nachricht für Falukant-User " << fid << " -> User-ID " << userId << std::endl;
// Prüfe ob User-ID gefunden wurde
if (userId.empty()) {
std::cerr << "[handleBrokerMessage] WARNUNG: User-ID für Falukant-User " << fid << " nicht gefunden! Nachricht wird nicht gesendet." << std::endl;
return;
}
std::shared_lock<std::shared_mutex> lock(connectionsMutex);
std::cout << "[handleBrokerMessage] Aktive User-Verbindungen: " << connections.size() << std::endl;
auto it = connections.find(userId);
if (it != connections.end() && !it->second.empty()) {
std::cout << "[handleBrokerMessage] Sende Nachricht an User " << userId << " (" << it->second.size() << " Verbindungen): " << message << std::endl;
// Nachricht an alle Verbindungen des Users senden
for (auto* wsi : it->second) {
auto *ud = reinterpret_cast<WebSocketUserData*>(lws_wsi_user(wsi));
if (ud) {
bool wasEmpty = false;
{
std::lock_guard<std::mutex> lock(ud->messageQueueMutex);
wasEmpty = ud->messageQueue.empty();
ud->messageQueue.push(message);
std::cout << "[handleBrokerMessage] Nachricht zur Queue hinzugefügt (Queue-Größe: " << ud->messageQueue.size() << ")" << std::endl;
}
// Nur wenn die Queue leer war, den Callback aufrufen
// (sonst wird er bereits durch den WRITEABLE-Handler aufgerufen)
if (wasEmpty) {
lws_callback_on_writable(wsi);
}
} else {
std::cerr << "[handleBrokerMessage] FEHLER: ud ist nullptr für eine Verbindung!" << std::endl;
}
}
} else {
std::cout << "[handleBrokerMessage] Keine aktive Verbindung für User " << userId << " gefunden" << std::endl;
std::cout << "[handleBrokerMessage] Verfügbare User-IDs in connections:" << std::endl;
for (const auto& pair : connections) {
std::cout << " - " << pair.first << " (" << pair.second.size() << " Verbindungen)" << std::endl;
}
}
} else {
std::cout << "[handleBrokerMessage] Nachricht enthält kein user_id-Feld!" << std::endl;
}
} catch (const std::exception &e) {
std::cerr << "[handleBrokerMessage] Error processing broker message: " << e.what() << std::endl;
}
}
std::string WebSocketServer::getUserIdFromFalukantUserId(int userId) {
ConnectionGuard guard(pool);
auto &db = guard.get();
std::string sql = R"(
SELECT u.hashed_id
FROM community.user u
JOIN falukant_data.falukant_user fu ON u.id = fu.user_id
WHERE fu.id = $1
)";
db.prepare("get_user_id", sql);
auto res = db.execute("get_user_id", {std::to_string(userId)});
return (!res.empty()) ? res[0]["hashed_id"] : std::string();
}
void WebSocketServer::setWorkers(const std::vector<std::unique_ptr<Worker>> &workerList) {
workers.clear();
workers.reserve(workerList.size());
for (const auto &wptr : workerList) {
workers.push_back(wptr.get());
}
}
void WebSocketServer::addConnection(const std::string &userId, struct lws *wsi) {
std::unique_lock<std::shared_mutex> lock(connectionsMutex);
connections[userId].push_back(wsi);
size_t totalConnections = 0;
for (const auto& pair : connections) {
totalConnections += pair.second.size();
}
std::cout << "[addConnection] Verbindung für User " << userId << " gespeichert (User hat " << connections[userId].size() << " Verbindung(en), insgesamt: " << totalConnections << " Verbindungen)" << std::endl;
}
/**
 * Unregisters `wsi` from `userId` (exclusive lock on connectionsMutex). The user's
 * map entry is dropped once no connections remain. Logs the remaining counts, or a
 * warning if the user had no entry.
 */
void WebSocketServer::removeConnection(const std::string &userId, struct lws *wsi) {
    std::unique_lock<std::shared_mutex> guard(connectionsMutex);
    const auto entry = connections.find(userId);
    if (entry == connections.end()) {
        std::cout << "[removeConnection] Warnung: Keine Verbindungen für User " << userId << " gefunden" << std::endl;
        return;
    }

    auto &sockets = entry->second;
    sockets.erase(std::remove(sockets.begin(), sockets.end(), wsi), sockets.end());
    // Capture the count before the entry (and `sockets`) may be erased.
    const size_t remaining = sockets.size();
    if (remaining == 0) {
        connections.erase(entry);
    }

    size_t total = 0;
    for (const auto &kv : connections) {
        total += kv.second.size();
    }
    std::cout << "[removeConnection] Verbindung für User " << userId << " entfernt (User hat noch " << remaining << " Verbindung(en), insgesamt: " << total << " Verbindungen)" << std::endl;
}