// File: yourpart3/src/websocket_server.cpp
// Size: 174 lines, 5.6 KiB (C++)

#include "websocket_server.h"
#include "connection_guard.h"
#include "worker.h"
#include <iostream>
#include <chrono>
#include <cstring>
using json = nlohmann::json;
// libwebsockets protocol table for this server. It has a single protocol,
// "yourpart-protocol", whose callbacks all go to WebSocketServer::wsCallback.
// The list ends with an all-null entry, which lws uses as the end marker.
struct lws_protocols WebSocketServer::protocols[] = {
{
// protocol name
"yourpart-protocol",
// callback invoked by lws for every event on this protocol
WebSocketServer::wsCallback,
// per-connection user data size; wsCallback receives this block as `user`
// and treats it as a WebSocketUserData
sizeof(WebSocketUserData),
// presumably the rx buffer size (lws_protocols field order) — confirm
// against the lws version in use
4096
},
// terminating entry
{ nullptr, nullptr, 0, 0 }
};
/// Stores the server configuration only. No threads and no lws context
/// exist until run() is called.
/// @param port      port handed to lws as the listen port
/// @param pool      DB connection pool used for user-id lookups
/// @param broker    message broker that run() subscribes to
/// @param useSSL    enable TLS; certPath and keyPath must then be non-empty
/// @param certPath  TLS certificate file (only used when useSSL is true)
/// @param keyPath   TLS private key file (only used when useSSL is true)
WebSocketServer::WebSocketServer(int port, ConnectionPool &pool, MessageBroker &broker,
                                 bool useSSL, const std::string& certPath, const std::string& keyPath)
    : port(port),
      pool(pool),
      broker(broker),
      useSSL(useSSL),
      certPath(certPath),
      keyPath(keyPath) {
}
/// Shuts everything down through stop(). stop() only joins threads that are
/// joinable and only destroys a non-null context, so this is harmless if
/// run() was never called or stop() already ran.
WebSocketServer::~WebSocketServer() {
    this->stop();
}
void WebSocketServer::run() {
running = true;
broker.subscribe([this](const std::string &msg) {
{
std::lock_guard<std::mutex> lock(queueMutex);
messageQueue.push(msg);
}
queueCV.notify_one();
});
serverThread = std::thread([this](){ startServer(); });
messageThread = std::thread([this](){ processMessageQueue(); });
pingThread = std::thread([this](){ pingClients(); });
}
void WebSocketServer::stop() {
running = false;
if (context) lws_cancel_service(context);
if (serverThread.joinable()) serverThread.join();
if (messageThread.joinable()) messageThread.join();
if (pingThread.joinable()) pingThread.join();
if (context) {
lws_context_destroy(context);
context = nullptr;
}
}
void WebSocketServer::startServer() {
struct lws_context_creation_info info;
memset(&info, 0, sizeof(info));
info.port = port;
info.protocols = protocols;
// SSL/TLS Konfiguration
if (useSSL) {
if (certPath.empty() || keyPath.empty()) {
throw std::runtime_error("SSL enabled but certificate or key path not provided");
}
info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
info.ssl_cert_filepath = certPath.c_str();
info.ssl_private_key_filepath = keyPath.c_str();
std::cout << "WebSocket SSL Server starting on port " << port << " with certificates: "
<< certPath << " / " << keyPath << std::endl;
} else {
std::cout << "WebSocket Server starting on port " << port << " (no SSL)" << std::endl;
}
// Reduziere Log-Level um weniger Debug-Ausgaben zu haben
setenv("LWS_LOG_LEVEL", "0", 1); // 0 = nur Fehler
context = lws_create_context(&info);
if (!context) {
throw std::runtime_error("Failed to create LWS context");
}
while (running) {
lws_service(context, 50);
}
}
/// Body of messageThread. Waits on queueCV for broker messages and passes
/// each one to handleBrokerMessage(). queueMutex is released while a message
/// is being handled, so the broker callback can keep enqueueing.
/// Returns once `running` is observed false.
void WebSocketServer::processMessageQueue() {
    while (running) {
        std::unique_lock<std::mutex> guard(queueMutex);
        queueCV.wait(guard, [this]() { return !(running && messageQueue.empty()); });
        // Drain everything queued so far. The loop step re-acquires the lock
        // before the next emptiness check.
        for (; !messageQueue.empty(); guard.lock()) {
            std::string pending = std::move(messageQueue.front());
            messageQueue.pop();
            guard.unlock();
            handleBrokerMessage(pending);
        }
    }
}
/// Body of pingThread. Every 30 seconds it asks lws to raise a writable
/// callback for every connection on protocols[0]. wsCallback() answers that
/// by writing a "ping" text frame.
void WebSocketServer::pingClients() {
    const auto pingInterval = std::chrono::seconds(30);
    const auto pollStep = std::chrono::milliseconds(100);
    while (running) {
        // Sleep in short slices. Otherwise stop() must wait up to a full
        // interval before this thread notices `running` went false.
        auto waited = std::chrono::milliseconds::zero();
        while (running && waited < pingInterval) {
            std::this_thread::sleep_for(pollStep);
            waited += pollStep;
        }
        // Skip the wake-up when stopping, or when the service thread has not
        // created a context (e.g. startup failed).
        if (!running || !context) {
            continue;
        }
        // NOTE(review): libwebsockets documents lws_cancel_service() as the only
        // API safe to call from a thread other than the service thread. Calling
        // this from pingThread may race with lws_service(). Consider requesting
        // writability from the service thread instead (e.g. on
        // LWS_CALLBACK_EVENT_WAIT_CANCELLED after lws_cancel_service()).
        lws_callback_on_writable_all_protocol(context, &protocols[0]);
    }
}
/// libwebsockets callback for "yourpart-protocol", invoked from lws_service()
/// on serverThread. `user` points at this connection's WebSocketUserData
/// block, sized via the protocols table.
/// @return 0 to keep the connection open, -1 to make lws close it.
int WebSocketServer::wsCallback(struct lws *wsi,
                                enum lws_callback_reasons reason,
                                void *user, void *in, size_t len) {
    auto *ud = static_cast<WebSocketUserData*>(user);
    switch (reason) {
    case LWS_CALLBACK_ESTABLISHED:
        if (ud) {
            ud->pongReceived = true;
        }
        break;
    case LWS_CALLBACK_RECEIVE:
        // TODO: dispatch the received payload (`len` bytes at `in`) to
        // handleBrokerMessage or a dedicated handler. It is not copied until
        // there is a consumer.
        (void)in;
        (void)len;
        break;
    case LWS_CALLBACK_SERVER_WRITEABLE: {
        // NOTE(review): this always sends "ping", including when the writable
        // request came from handleBrokerMessage(); the broker payload itself is
        // never written here — confirm the intended delivery path.
        // lws requires LWS_PRE bytes of headroom in front of the payload.
        unsigned char buf[LWS_PRE + 4];
        memcpy(buf + LWS_PRE, "ping", 4);
        const int written = lws_write(wsi, buf + LWS_PRE, 4, LWS_WRITE_TEXT);
        if (written < 4) {
            // Failed or partial write: return -1 so lws closes the connection.
            return -1;
        }
        break;
    }
    case LWS_CALLBACK_CLOSED:
        // TODO: remove the closed connection from `connections` if stored.
        break;
    default:
        break;
    }
    return 0;
}
void WebSocketServer::handleBrokerMessage(const std::string &message) {
try {
json parsed = json::parse(message);
if (parsed.contains("user_id")) {
int fid = parsed["user_id"].get<int>();
auto userId = getUserIdFromFalukantUserId(fid);
std::shared_lock<std::shared_mutex> lock(connectionsMutex);
auto it = connections.find(userId);
if (it != connections.end()) {
lws_callback_on_writable(it->second);
}
}
} catch (const std::exception &e) {
std::cerr << "Error processing broker message: " << e.what() << std::endl;
}
}
std::string WebSocketServer::getUserIdFromFalukantUserId(int userId) {
ConnectionGuard guard(pool);
auto &db = guard.get();
std::string sql = R"(
SELECT u.hashed_id
FROM community.user u
JOIN falukant_data.falukant_user fu ON u.id = fu.user_id
WHERE fu.id = $1
)";
db.prepare("get_user_id", sql);
auto res = db.execute("get_user_id", {std::to_string(userId)});
return (!res.empty()) ? res[0]["hashed_id"] : std::string();
}
void WebSocketServer::setWorkers(const std::vector<std::unique_ptr<Worker>> &workerList) {
workers.clear();
workers.reserve(workerList.size());
for (const auto &wptr : workerList) {
workers.push_back(wptr.get());
}
}