617 lines
23 KiB
Rust
617 lines
23 KiB
Rust
use crate::db::ConnectionPool;
|
||
use crate::message_broker::MessageBroker;
|
||
use crate::worker::Worker;
|
||
use futures_util::{FutureExt, SinkExt, StreamExt};
|
||
use serde::{Deserialize, Serialize};
|
||
use serde_json::Value as Json;
|
||
use std::collections::HashMap;
|
||
use std::fs::File;
|
||
use std::io::BufReader;
|
||
use std::net::SocketAddr;
|
||
use std::sync::atomic::{AtomicBool, Ordering};
|
||
use std::sync::Arc;
|
||
use std::time::{SystemTime, UNIX_EPOCH};
|
||
use tokio::io::{AsyncRead, AsyncWrite};
|
||
use tokio::net::TcpListener;
|
||
use tokio::runtime::{Builder, Runtime};
|
||
use tokio::sync::{broadcast, mpsc, Mutex};
|
||
use tokio::time::{interval, Duration as TokioDuration};
|
||
use tokio_rustls::rustls::{self, ServerConfig};
|
||
use tokio_rustls::TlsAcceptor;
|
||
use tokio_tungstenite::tungstenite::Message;
|
||
use tokio_tungstenite::accept_async;
|
||
use rustls_pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
|
||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer};
|
||
|
||
/// Einfacher WebSocket-Server auf Basis von Tokio + tokio-tungstenite.
|
||
///
|
||
/// Unterstützt:
|
||
/// - `setUserId`-Event vom Client (`{"event":"setUserId","data":{"userId":"..."}}`)
|
||
/// - Versenden von Broker-Nachrichten mit `user_id`-Feld an passende Verbindungen
|
||
/// - Broadcasting von Nachrichten ohne `user_id` an alle
|
||
pub struct WebSocketServer {
|
||
port: u16,
|
||
pool: ConnectionPool,
|
||
broker: MessageBroker,
|
||
use_ssl: bool,
|
||
cert_path: Option<String>,
|
||
key_path: Option<String>,
|
||
workers: Vec<*const dyn Worker>,
|
||
running: Arc<AtomicBool>,
|
||
runtime: Option<Runtime>,
|
||
}
|
||
|
||
/// Connection counters reported by the `getConnections` event.
#[derive(Default)]
struct ConnectionRegistry {
    /// Number of currently open WebSocket connections.
    total: usize,
    /// Open connections that have not registered a user ID yet.
    unauthenticated: usize,
    /// Open connections per registered user ID.
    by_user: HashMap<String, usize>,
}
/// Eintrag für das WebSocket-Log, abrufbar über ein eigenes Event.
|
||
#[derive(Debug, Clone, Serialize)]
|
||
struct WebSocketLogEntry {
|
||
timestamp: u64, // Sekunden seit UNIX_EPOCH
|
||
direction: String, // z.B. "broker->client"
|
||
peer: String, // "ip:port"
|
||
conn_user: Option<String>, // per setUserId gesetzte User-ID
|
||
target_user: Option<String>, // user_id aus der Nachricht (falls vorhanden)
|
||
event: Option<String>, // event-Feld aus der Nachricht (falls JSON)
|
||
}
|
||
|
||
#[derive(Default)]
|
||
struct WebSocketLog {
|
||
entries: Vec<WebSocketLogEntry>,
|
||
}
|
||
|
||
const WS_LOG_MAX_ENTRIES: usize = 50_000;
|
||
|
||
async fn append_ws_log(
|
||
log: &Arc<Mutex<WebSocketLog>>,
|
||
direction: &str,
|
||
peer_addr: &SocketAddr,
|
||
conn_user: &Option<String>,
|
||
msg: &str,
|
||
) {
|
||
let now = SystemTime::now()
|
||
.duration_since(UNIX_EPOCH)
|
||
.unwrap_or_default()
|
||
.as_secs();
|
||
|
||
let (event, target_user) = if let Ok(json) = serde_json::from_str::<Json>(msg) {
|
||
let event = json
|
||
.get("event")
|
||
.and_then(|v| v.as_str())
|
||
.map(|s| s.to_string());
|
||
let target_user = json
|
||
.get("user_id")
|
||
.and_then(|v| {
|
||
if let Some(s) = v.as_str() {
|
||
Some(s.to_string())
|
||
} else if let Some(n) = v.as_i64() {
|
||
Some(n.to_string())
|
||
} else {
|
||
None
|
||
}
|
||
});
|
||
(event, target_user)
|
||
} else {
|
||
(None, None)
|
||
};
|
||
|
||
let entry = WebSocketLogEntry {
|
||
timestamp: now,
|
||
direction: direction.to_string(),
|
||
peer: peer_addr.to_string(),
|
||
conn_user: conn_user.clone(),
|
||
target_user,
|
||
event,
|
||
};
|
||
|
||
let mut guard = log.lock().await;
|
||
guard.entries.push(entry);
|
||
if guard.entries.len() > WS_LOG_MAX_ENTRIES {
|
||
let overflow = guard.entries.len() - WS_LOG_MAX_ENTRIES;
|
||
guard.entries.drain(0..overflow);
|
||
}
|
||
}
|
||
|
||
fn create_tls_acceptor(
|
||
cert_path: Option<&str>,
|
||
key_path: Option<&str>,
|
||
) -> Result<TlsAcceptor, Box<dyn std::error::Error>> {
|
||
let cert_path = cert_path.ok_or("SSL aktiviert, aber kein Zertifikatspfad gesetzt")?;
|
||
let key_path = key_path.ok_or("SSL aktiviert, aber kein Key-Pfad gesetzt")?;
|
||
|
||
let cert_file = File::open(cert_path)?;
|
||
let mut cert_reader = BufReader::new(cert_file);
|
||
|
||
let mut cert_chain: Vec<CertificateDer<'static>> = Vec::new();
|
||
for cert_result in certs(&mut cert_reader) {
|
||
let cert: CertificateDer<'static> = cert_result?;
|
||
cert_chain.push(cert);
|
||
}
|
||
|
||
if cert_chain.is_empty() {
|
||
return Err("Zertifikatsdatei enthält keine Zertifikate".into());
|
||
}
|
||
|
||
let key_file = File::open(key_path)?;
|
||
let mut key_reader = BufReader::new(key_file);
|
||
|
||
// Versuche zuerst PKCS8, dann ggf. RSA-Key
|
||
let mut keys: Vec<PrivateKeyDer<'static>> = pkcs8_private_keys(&mut key_reader)
|
||
.map(|res: Result<PrivatePkcs8KeyDer<'static>, _>| res.map(PrivateKeyDer::Pkcs8))
|
||
.collect::<Result<_, _>>()?;
|
||
|
||
if keys.is_empty() {
|
||
// Leser zurücksetzen und RSA-Keys versuchen
|
||
let key_file = File::open(key_path)?;
|
||
let mut key_reader = BufReader::new(key_file);
|
||
keys = rsa_private_keys(&mut key_reader)
|
||
.map(|res: Result<PrivatePkcs1KeyDer<'static>, _>| res.map(PrivateKeyDer::Pkcs1))
|
||
.collect::<Result<_, _>>()?;
|
||
}
|
||
|
||
if keys.is_empty() {
|
||
return Err("Key-Datei enthält keinen privaten Schlüssel (PKCS8 oder RSA)".into());
|
||
}
|
||
|
||
let private_key = keys.remove(0);
|
||
|
||
let config = ServerConfig::builder()
|
||
.with_no_client_auth()
|
||
.with_single_cert(cert_chain, private_key)?;
|
||
|
||
Ok(TlsAcceptor::from(Arc::new(config)))
|
||
}
|
||
|
||
impl WebSocketServer {
|
||
pub fn new(
|
||
port: u16,
|
||
pool: ConnectionPool,
|
||
broker: MessageBroker,
|
||
use_ssl: bool,
|
||
cert_path: Option<String>,
|
||
key_path: Option<String>,
|
||
) -> Self {
|
||
Self {
|
||
port,
|
||
pool,
|
||
broker,
|
||
use_ssl,
|
||
cert_path,
|
||
key_path,
|
||
workers: Vec::new(),
|
||
running: Arc::new(AtomicBool::new(false)),
|
||
runtime: None,
|
||
}
|
||
}
|
||
|
||
pub fn set_workers(&mut self, workers: &[Box<dyn Worker>]) {
|
||
self.workers.clear();
|
||
for w in workers {
|
||
self.workers.push(&**w as *const dyn Worker);
|
||
}
|
||
}
|
||
|
||
pub fn run(&mut self) {
|
||
if self.running.swap(true, Ordering::SeqCst) {
|
||
eprintln!("[WebSocketServer] Läuft bereits.");
|
||
return;
|
||
}
|
||
|
||
if self.use_ssl {
|
||
println!(
|
||
"Starte WebSocket-Server auf Port {} mit SSL (cert: {:?}, key: {:?})",
|
||
self.port, self.cert_path, self.key_path
|
||
);
|
||
// Hinweis: SSL-Unterstützung ist noch nicht implementiert.
|
||
} else {
|
||
println!("Starte WebSocket-Server auf Port {} (ohne SSL)", self.port);
|
||
}
|
||
|
||
let addr = format!("0.0.0.0:{}", self.port);
|
||
let running_flag = self.running.clone();
|
||
let broker = self.broker.clone();
|
||
|
||
// Gemeinsame Registry für alle Verbindungen
|
||
let registry = Arc::new(Mutex::new(ConnectionRegistry::default()));
|
||
// Gemeinsames WebSocket-Log (für max. 24h, begrenzt über WS_LOG_MAX_ENTRIES)
|
||
let ws_log = Arc::new(Mutex::new(WebSocketLog::default()));
|
||
|
||
// Broadcast-Kanal für Broker-Nachrichten
|
||
let (tx, _) = broadcast::channel::<String>(1024);
|
||
let tx_clone = tx.clone();
|
||
|
||
// Broker-Subscription: jede gepublishte Nachricht geht in den Broadcast-Kanal
|
||
broker.subscribe(move |msg: String| {
|
||
let _ = tx_clone.send(msg);
|
||
});
|
||
|
||
// Optionalen TLS-Akzeptor laden, falls SSL aktiviert ist
|
||
let tls_acceptor = if self.use_ssl {
|
||
match create_tls_acceptor(
|
||
self.cert_path.as_deref(),
|
||
self.key_path.as_deref(),
|
||
) {
|
||
Ok(acc) => Some(acc),
|
||
Err(err) => {
|
||
eprintln!(
|
||
"[WebSocketServer] TLS-Initialisierung fehlgeschlagen, starte ohne SSL: {err}"
|
||
);
|
||
None
|
||
}
|
||
}
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let rt = Builder::new_multi_thread()
|
||
.enable_all()
|
||
.build()
|
||
.expect("Tokio Runtime konnte nicht erstellt werden");
|
||
|
||
rt.spawn(run_accept_loop(
|
||
addr,
|
||
running_flag,
|
||
tx,
|
||
self.pool.clone(),
|
||
registry,
|
||
ws_log,
|
||
tls_acceptor,
|
||
));
|
||
|
||
self.runtime = Some(rt);
|
||
}
|
||
|
||
pub fn stop(&mut self) {
|
||
if !self.running.swap(false, Ordering::SeqCst) {
|
||
return;
|
||
}
|
||
println!("WebSocket-Server wird gestoppt.");
|
||
if let Some(rt) = self.runtime.take() {
|
||
rt.shutdown_background();
|
||
}
|
||
}
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct IncomingMessage {
|
||
#[serde(default)]
|
||
event: String,
|
||
#[serde(default)]
|
||
data: Json,
|
||
}
|
||
|
||
async fn run_accept_loop(
|
||
addr: String,
|
||
running: Arc<AtomicBool>,
|
||
tx: broadcast::Sender<String>,
|
||
_pool: ConnectionPool,
|
||
registry: Arc<Mutex<ConnectionRegistry>>,
|
||
ws_log: Arc<Mutex<WebSocketLog>>,
|
||
tls_acceptor: Option<TlsAcceptor>,
|
||
) {
|
||
let listener = match TcpListener::bind(&addr).await {
|
||
Ok(l) => l,
|
||
Err(e) => {
|
||
eprintln!("[WebSocketServer] Fehler beim Binden an {}: {}", addr, e);
|
||
running.store(false, Ordering::SeqCst);
|
||
return;
|
||
}
|
||
};
|
||
|
||
println!("[WebSocketServer] Lauscht auf {}", addr);
|
||
|
||
while running.load(Ordering::SeqCst) {
|
||
let (stream, peer) = match listener.accept().await {
|
||
Ok(v) => v,
|
||
Err(e) => {
|
||
eprintln!("[WebSocketServer] accept() fehlgeschlagen: {}", e);
|
||
continue;
|
||
}
|
||
};
|
||
|
||
let peer_addr = peer;
|
||
let rx = tx.subscribe();
|
||
let registry_clone = registry.clone();
|
||
let ws_log_clone = ws_log.clone();
|
||
let tls_acceptor_clone = tls_acceptor.clone();
|
||
|
||
tokio::spawn(async move {
|
||
if let Some(acc) = tls_acceptor_clone {
|
||
match acc.accept(stream).await {
|
||
Ok(tls_stream) => {
|
||
handle_connection(tls_stream, peer_addr, rx, registry_clone, ws_log_clone)
|
||
.await
|
||
}
|
||
Err(err) => {
|
||
eprintln!(
|
||
"[WebSocketServer] TLS-Handshake fehlgeschlagen ({peer_addr}): {err}"
|
||
);
|
||
}
|
||
}
|
||
} else {
|
||
handle_connection(stream, peer_addr, rx, registry_clone, ws_log_clone).await;
|
||
}
|
||
});
|
||
}
|
||
}
|
||
|
||
async fn handle_connection<S>(
|
||
stream: S,
|
||
peer_addr: SocketAddr,
|
||
mut broker_rx: broadcast::Receiver<String>,
|
||
registry: Arc<Mutex<ConnectionRegistry>>,
|
||
ws_log: Arc<Mutex<WebSocketLog>>,
|
||
) where
|
||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||
{
|
||
let ws_stream = match accept_async(stream).await {
|
||
Ok(ws) => ws,
|
||
Err(e) => {
|
||
eprintln!("[WebSocketServer] WebSocket-Handshake fehlgeschlagen ({peer_addr}): {e}");
|
||
return;
|
||
}
|
||
};
|
||
|
||
println!("[WebSocketServer] Neue Verbindung von {}", peer_addr);
|
||
|
||
let (mut ws_sender, mut ws_receiver) = ws_stream.split();
|
||
|
||
// Kanal für Antworten direkt an diesen Client (z.B. getConnections)
|
||
let (client_tx, mut client_rx) = mpsc::channel::<String>(32);
|
||
|
||
// Neue Verbindung in der Registry zählen (zunächst als unauthentifiziert)
|
||
{
|
||
let mut reg = registry.lock().await;
|
||
reg.total += 1;
|
||
reg.unauthenticated += 1;
|
||
}
|
||
|
||
// user_id der Verbindung (nach setUserId)
|
||
let user_id = Arc::new(tokio::sync::Mutex::new(Option::<String>::None));
|
||
let user_id_for_incoming = user_id.clone();
|
||
let user_id_for_broker = user_id.clone();
|
||
let registry_for_incoming = registry.clone();
|
||
let client_tx_incoming = client_tx.clone();
|
||
let ws_log_for_incoming = ws_log.clone();
|
||
let ws_log_for_outgoing = ws_log.clone();
|
||
|
||
// Eingehende Nachrichten vom Client
|
||
let incoming = async move {
|
||
while let Some(msg) = ws_receiver.next().await {
|
||
match msg {
|
||
Ok(Message::Text(txt)) => {
|
||
if let Ok(parsed) = serde_json::from_str::<IncomingMessage>(&txt) {
|
||
match parsed.event.as_str() {
|
||
"setUserId" => {
|
||
if let Some(uid) =
|
||
parsed.data.get("userId").and_then(|v| v.as_str())
|
||
{
|
||
{
|
||
// Registry aktualisieren: von unauthentifiziert -> Nutzer
|
||
let mut reg = registry_for_incoming.lock().await;
|
||
if reg.unauthenticated > 0 {
|
||
reg.unauthenticated -= 1;
|
||
}
|
||
*reg.by_user.entry(uid.to_string()).or_insert(0) += 1;
|
||
}
|
||
|
||
let mut guard = user_id_for_incoming.lock().await;
|
||
*guard = Some(uid.to_string());
|
||
println!(
|
||
"[WebSocketServer] User-ID gesetzt für {}: {}",
|
||
peer_addr, uid
|
||
);
|
||
}
|
||
}
|
||
"getConnections" => {
|
||
// Einfache Übersicht über aktuelle Verbindungen zurückgeben.
|
||
let snapshot = {
|
||
let reg = registry_for_incoming.lock().await;
|
||
serde_json::json!({
|
||
"event": "getConnectionsResponse",
|
||
"total": reg.total,
|
||
"unauthenticated": reg.unauthenticated,
|
||
"users": reg.by_user,
|
||
})
|
||
.to_string()
|
||
};
|
||
let _ = client_tx_incoming.send(snapshot).await;
|
||
}
|
||
"getWebsocketLog" => {
|
||
// Liefert die letzten 24h (oder weniger) aus dem In-Memory-Log.
|
||
let entries = {
|
||
let guard = ws_log_for_incoming.lock().await;
|
||
let now = SystemTime::now()
|
||
.duration_since(UNIX_EPOCH)
|
||
.unwrap_or_default()
|
||
.as_secs();
|
||
let cutoff = now.saturating_sub(24 * 3600);
|
||
|
||
let mut filtered: Vec<WebSocketLogEntry> = guard
|
||
.entries
|
||
.iter()
|
||
.filter(|e| e.timestamp >= cutoff)
|
||
.cloned()
|
||
.collect();
|
||
|
||
// Zur Sicherheit begrenzen wir die Antwortgröße
|
||
if filtered.len() > 1000 {
|
||
let len = filtered.len();
|
||
filtered = filtered[len - 1000..].to_vec();
|
||
}
|
||
filtered
|
||
};
|
||
|
||
let payload = serde_json::json!({
|
||
"event": "getWebsocketLogResponse",
|
||
"entries": entries,
|
||
})
|
||
.to_string();
|
||
let _ = client_tx_incoming.send(payload).await;
|
||
}
|
||
_ => {
|
||
// Unbekannte Events ignorieren
|
||
}
|
||
}
|
||
}
|
||
}
|
||
Ok(Message::Ping(_)) => {
|
||
// Ping wird aktuell nur geloggt/ignoriert; optional könnte man hier ein eigenes
|
||
// Ping/Pong-Handling ergänzen.
|
||
}
|
||
Ok(Message::Close(_)) => break,
|
||
Err(e) => {
|
||
eprintln!("[WebSocketServer] Fehler bei Nachricht von {peer_addr}: {e}");
|
||
break;
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
};
|
||
|
||
// Broker-Nachrichten an den Client + periodische Ping-Frames als Keepalive
|
||
let outgoing = async move {
|
||
// Regelmäßiges Ping, um inaktiven Verbindungen ein Lebenszeichen zu senden
|
||
// und Timeouts auf dem Weg (Proxy/Loadbalancer) zu vermeiden.
|
||
let mut ping_interval = interval(TokioDuration::from_secs(240));
|
||
|
||
loop {
|
||
tokio::select! {
|
||
// Nachrichten aus dem MessageBroker
|
||
broker_msg = broker_rx.recv() => {
|
||
let msg = match broker_msg {
|
||
Ok(m) => m,
|
||
Err(_) => break,
|
||
};
|
||
|
||
// Filter nach user_id, falls gesetzt und numerisch interpretierbar.
|
||
// Historisch wurde hier der Falukant-User (numerisch) verwendet.
|
||
// Wenn die gesetzte User-ID kein Integer ist (z.B. Benutzername),
|
||
// wird *nicht* gefiltert und alle Nachrichten durchgelassen –
|
||
// das Frontend muss dann selbst selektieren.
|
||
let target_user = {
|
||
let guard = user_id_for_broker.lock().await;
|
||
guard.clone()
|
||
};
|
||
|
||
if let Some(uid) = target_user.clone() {
|
||
// Versuche, die user_id als numerisch zu interpretieren
|
||
match uid.parse::<i64>() {
|
||
Ok(numeric_uid) => {
|
||
// Numerische user_id: Filtere explizit nach dieser ID
|
||
if let Ok(json) = serde_json::from_str::<Json>(&msg) {
|
||
let matches_user = json
|
||
.get("user_id")
|
||
.and_then(|v| {
|
||
if let Some(s) = v.as_str() {
|
||
s.parse::<i64>().ok()
|
||
} else if let Some(n) = v.as_i64() {
|
||
Some(n)
|
||
} else {
|
||
None
|
||
}
|
||
})
|
||
.map(|v| v == numeric_uid)
|
||
.unwrap_or(false);
|
||
|
||
if !matches_user {
|
||
continue;
|
||
}
|
||
}
|
||
}
|
||
Err(_) => {
|
||
// Nicht-numerische user_id: Explizit alle Nachrichten durchlassen
|
||
// (keine Filterung, wie im Kommentar dokumentiert)
|
||
// Dies ermöglicht es dem Frontend, selbst zu filtern
|
||
}
|
||
}
|
||
}
|
||
|
||
// Logging für den 24h-Überblick, was an welchen User/Peer geht
|
||
let conn_user = {
|
||
let guard = user_id_for_broker.lock().await;
|
||
guard.clone()
|
||
};
|
||
append_ws_log(&ws_log_for_outgoing, "broker->client", &peer_addr, &conn_user, &msg).await;
|
||
|
||
if let Err(e) = ws_sender.send(Message::Text(msg)).await {
|
||
eprintln!(
|
||
"[WebSocketServer] Fehler beim Senden an {}: {}",
|
||
peer_addr, e
|
||
);
|
||
break;
|
||
}
|
||
}
|
||
// Antworten aus der Verbindung selbst (z.B. getConnections)
|
||
client_msg = client_rx.recv() => {
|
||
match client_msg {
|
||
Some(msg) => {
|
||
let conn_user = {
|
||
let guard = user_id_for_broker.lock().await;
|
||
guard.clone()
|
||
};
|
||
append_ws_log(&ws_log_for_outgoing, "local->client", &peer_addr, &conn_user, &msg).await;
|
||
|
||
if let Err(e) = ws_sender.send(Message::Text(msg)).await {
|
||
eprintln!(
|
||
"[WebSocketServer] Fehler beim Senden an {}: {}",
|
||
peer_addr, e
|
||
);
|
||
break;
|
||
}
|
||
}
|
||
None => {
|
||
// Kanal wurde geschlossen
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
// Periodisches Ping an den Client
|
||
_ = ping_interval.tick() => {
|
||
if let Err(e) = ws_sender.send(Message::Ping(Vec::new())).await {
|
||
eprintln!(
|
||
"[WebSocketServer] Fehler beim Senden von Ping an {}: {}",
|
||
peer_addr, e
|
||
);
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
};
|
||
|
||
futures_util::future::select(incoming.boxed(), outgoing.boxed()).await;
|
||
|
||
// Verbindung aus der Registry entfernen
|
||
let final_uid = {
|
||
let guard = user_id.lock().await;
|
||
guard.clone()
|
||
};
|
||
{
|
||
let mut reg = registry.lock().await;
|
||
if reg.total > 0 {
|
||
reg.total -= 1;
|
||
}
|
||
if let Some(uid) = final_uid {
|
||
if let Some(count) = reg.by_user.get_mut(&uid) {
|
||
if *count > 0 {
|
||
*count -= 1;
|
||
}
|
||
if *count == 0 {
|
||
reg.by_user.remove(&uid);
|
||
}
|
||
}
|
||
} else if reg.unauthenticated > 0 {
|
||
reg.unauthenticated -= 1;
|
||
}
|
||
}
|
||
|
||
println!("[WebSocketServer] Verbindung geschlossen: {}", peer_addr);
|
||
}
|
||
|