use crate::db::ConnectionPool;
use crate::message_broker::MessageBroker;
use crate::worker::Worker;
use futures_util::{FutureExt, SinkExt, StreamExt};
use serde::Deserialize;
use serde_json::Value as Json;
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::{broadcast, mpsc, Mutex};
use tokio_rustls::rustls::{self, ServerConfig};
use tokio_rustls::TlsAcceptor;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::accept_async;
use rustls_pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer};

/// Simple WebSocket server built on Tokio and tokio-tungstenite.
///
/// Supported protocol:
/// - a `setUserId` event from the client (`{"event":"setUserId","data":{"userId":"..."}}`)
///   binds the connection to a user id;
/// - broker messages carrying a `user_id` field are meant for the matching connections;
/// - broker messages without a `user_id` are broadcast to every connection;
/// - a `getConnections` event is answered with connection statistics.
///
/// The server runs on its own multi-threaded Tokio runtime, which `run` creates and `stop`
/// shuts down.
pub struct WebSocketServer {
    /// Port to listen on; the server binds `0.0.0.0:<port>`.
    port: u16,
    /// Connection pool from `crate::db`; handed to the accept loop, which does not use it yet.
    pool: ConnectionPool,
    /// Broker whose published messages are forwarded to the clients.
    broker: MessageBroker,
    /// Whether incoming connections should be wrapped in TLS.
    use_ssl: bool,
    /// Path of the PEM certificate chain; required when `use_ssl` is set.
    cert_path: Option<String>,
    /// Path of the PEM private key (PKCS#8 or RSA/PKCS#1); required when `use_ssl` is set.
    key_path: Option<String>,
    /// Non-owning pointers to the workers registered via `set_workers`.
    ///
    /// NOTE(review): nothing here keeps the pointed-to workers alive, so the caller must keep
    /// the boxes alive as long as the server holds these pointers. Raw pointers are also
    /// neither `Send` nor `Sync`, which propagates to this struct.
    workers: Vec<*const dyn Worker>,
    /// `true` while the server is running; shared with the accept loop.
    running: Arc<AtomicBool>,
    /// Runtime that owns the accept loop and all connection tasks.
    runtime: Option<Runtime>,
}

/// Simple registry that tracks connection statistics for `getConnections`.
#[derive(Default)] struct ConnectionRegistry { total: usize, unauthenticated: usize, by_user: HashMap, } fn create_tls_acceptor( cert_path: Option<&str>, key_path: Option<&str>, ) -> Result> { let cert_path = cert_path.ok_or("SSL aktiviert, aber kein Zertifikatspfad gesetzt")?; let key_path = key_path.ok_or("SSL aktiviert, aber kein Key-Pfad gesetzt")?; let cert_file = File::open(cert_path)?; let mut cert_reader = BufReader::new(cert_file); let mut cert_chain: Vec> = Vec::new(); for cert_result in certs(&mut cert_reader) { let cert: CertificateDer<'static> = cert_result?; cert_chain.push(cert); } if cert_chain.is_empty() { return Err("Zertifikatsdatei enthält keine Zertifikate".into()); } let key_file = File::open(key_path)?; let mut key_reader = BufReader::new(key_file); // Versuche zuerst PKCS8, dann ggf. RSA-Key let mut keys: Vec> = pkcs8_private_keys(&mut key_reader) .map(|res: Result, _>| res.map(PrivateKeyDer::Pkcs8)) .collect::>()?; if keys.is_empty() { // Leser zurücksetzen und RSA-Keys versuchen let key_file = File::open(key_path)?; let mut key_reader = BufReader::new(key_file); keys = rsa_private_keys(&mut key_reader) .map(|res: Result, _>| res.map(PrivateKeyDer::Pkcs1)) .collect::>()?; } if keys.is_empty() { return Err("Key-Datei enthält keinen privaten Schlüssel (PKCS8 oder RSA)".into()); } let private_key = keys.remove(0); let config = ServerConfig::builder() .with_no_client_auth() .with_single_cert(cert_chain, private_key)?; Ok(TlsAcceptor::from(Arc::new(config))) } impl WebSocketServer { pub fn new( port: u16, pool: ConnectionPool, broker: MessageBroker, use_ssl: bool, cert_path: Option, key_path: Option, ) -> Self { Self { port, pool, broker, use_ssl, cert_path, key_path, workers: Vec::new(), running: Arc::new(AtomicBool::new(false)), runtime: None, } } pub fn set_workers(&mut self, workers: &[Box]) { self.workers.clear(); for w in workers { self.workers.push(&**w as *const dyn Worker); } } pub fn run(&mut self) { if self.running.swap(true, 
Ordering::SeqCst) { eprintln!("[WebSocketServer] Läuft bereits."); return; } if self.use_ssl { println!( "Starte WebSocket-Server auf Port {} mit SSL (cert: {:?}, key: {:?})", self.port, self.cert_path, self.key_path ); // Hinweis: SSL-Unterstützung ist noch nicht implementiert. } else { println!("Starte WebSocket-Server auf Port {} (ohne SSL)", self.port); } let addr = format!("0.0.0.0:{}", self.port); let running_flag = self.running.clone(); let broker = self.broker.clone(); // Gemeinsame Registry für alle Verbindungen let registry = Arc::new(Mutex::new(ConnectionRegistry::default())); // Broadcast-Kanal für Broker-Nachrichten let (tx, _) = broadcast::channel::(1024); let tx_clone = tx.clone(); // Broker-Subscription: jede gepublishte Nachricht geht in den Broadcast-Kanal broker.subscribe(move |msg: String| { let _ = tx_clone.send(msg); }); // Optionalen TLS-Akzeptor laden, falls SSL aktiviert ist let tls_acceptor = if self.use_ssl { match create_tls_acceptor( self.cert_path.as_deref(), self.key_path.as_deref(), ) { Ok(acc) => Some(acc), Err(err) => { eprintln!( "[WebSocketServer] TLS-Initialisierung fehlgeschlagen, starte ohne SSL: {err}" ); None } } } else { None }; let rt = Builder::new_multi_thread() .enable_all() .build() .expect("Tokio Runtime konnte nicht erstellt werden"); rt.spawn(run_accept_loop( addr, running_flag, tx, self.pool.clone(), registry, tls_acceptor, )); self.runtime = Some(rt); } pub fn stop(&mut self) { if !self.running.swap(false, Ordering::SeqCst) { return; } println!("WebSocket-Server wird gestoppt."); if let Some(rt) = self.runtime.take() { rt.shutdown_background(); } } } #[derive(Debug, Deserialize)] struct IncomingMessage { #[serde(default)] event: String, #[serde(default)] data: Json, } async fn run_accept_loop( addr: String, running: Arc, tx: broadcast::Sender, _pool: ConnectionPool, registry: Arc>, tls_acceptor: Option, ) { let listener = match TcpListener::bind(&addr).await { Ok(l) => l, Err(e) => { eprintln!("[WebSocketServer] 
Fehler beim Binden an {}: {}", addr, e); running.store(false, Ordering::SeqCst); return; } }; println!("[WebSocketServer] Lauscht auf {}", addr); while running.load(Ordering::SeqCst) { let (stream, peer) = match listener.accept().await { Ok(v) => v, Err(e) => { eprintln!("[WebSocketServer] accept() fehlgeschlagen: {}", e); continue; } }; let peer_addr = peer; let rx = tx.subscribe(); let registry_clone = registry.clone(); let tls_acceptor_clone = tls_acceptor.clone(); tokio::spawn(async move { if let Some(acc) = tls_acceptor_clone { match acc.accept(stream).await { Ok(tls_stream) => { handle_connection(tls_stream, peer_addr, rx, registry_clone).await } Err(err) => { eprintln!( "[WebSocketServer] TLS-Handshake fehlgeschlagen ({peer_addr}): {err}" ); } } } else { handle_connection(stream, peer_addr, rx, registry_clone).await; } }); } } async fn handle_connection( stream: S, peer_addr: SocketAddr, mut broker_rx: broadcast::Receiver, registry: Arc>, ) where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let ws_stream = match accept_async(stream).await { Ok(ws) => ws, Err(e) => { eprintln!("[WebSocketServer] WebSocket-Handshake fehlgeschlagen ({peer_addr}): {e}"); return; } }; println!("[WebSocketServer] Neue Verbindung von {}", peer_addr); let (mut ws_sender, mut ws_receiver) = ws_stream.split(); // Kanal für Antworten direkt an diesen Client (z.B. 
getConnections) let (client_tx, mut client_rx) = mpsc::channel::(32); // Neue Verbindung in der Registry zählen (zunächst als unauthentifiziert) { let mut reg = registry.lock().await; reg.total += 1; reg.unauthenticated += 1; } // user_id der Verbindung (nach setUserId) let user_id = Arc::new(tokio::sync::Mutex::new(Option::::None)); let user_id_for_incoming = user_id.clone(); let user_id_for_broker = user_id.clone(); let registry_for_incoming = registry.clone(); let client_tx_incoming = client_tx.clone(); // Eingehende Nachrichten vom Client let incoming = async move { while let Some(msg) = ws_receiver.next().await { match msg { Ok(Message::Text(txt)) => { if let Ok(parsed) = serde_json::from_str::(&txt) { match parsed.event.as_str() { "setUserId" => { if let Some(uid) = parsed.data.get("userId").and_then(|v| v.as_str()) { { // Registry aktualisieren: von unauthentifiziert -> Nutzer let mut reg = registry_for_incoming.lock().await; if reg.unauthenticated > 0 { reg.unauthenticated -= 1; } *reg.by_user.entry(uid.to_string()).or_insert(0) += 1; } let mut guard = user_id_for_incoming.lock().await; *guard = Some(uid.to_string()); println!( "[WebSocketServer] User-ID gesetzt für {}: {}", peer_addr, uid ); } } "getConnections" => { // Einfache Übersicht über aktuelle Verbindungen zurückgeben. let snapshot = { let reg = registry_for_incoming.lock().await; serde_json::json!({ "event": "getConnectionsResponse", "total": reg.total, "unauthenticated": reg.unauthenticated, "users": reg.by_user, }) .to_string() }; let _ = client_tx_incoming.send(snapshot).await; } _ => { // Unbekannte Events ignorieren } } } } Ok(Message::Ping(_)) => { // Ping wird aktuell nur geloggt/ignoriert; optional könnte man hier ein eigenes // Ping/Pong-Handling ergänzen. } Ok(Message::Close(_)) => break, Err(e) => { eprintln!("[WebSocketServer] Fehler bei Nachricht von {peer_addr}: {e}"); break; } _ => {} } } }; // Broker-Nachrichten an den Client let outgoing = async move { loop { tokio::select! 
{ // Nachrichten aus dem MessageBroker broker_msg = broker_rx.recv() => { let msg = match broker_msg { Ok(m) => m, Err(_) => break, }; // Filter nach user_id, falls gesetzt let target_user = { let guard = user_id_for_broker.lock().await; guard.clone() }; if let Some(uid) = target_user.clone() { if let Ok(json) = serde_json::from_str::(&msg) { let matches_user = json .get("user_id") .and_then(|v| { if let Some(s) = v.as_str() { Some(s.to_string()) } else if let Some(n) = v.as_i64() { Some(n.to_string()) } else { None } }) .map(|v| v == uid) .unwrap_or(false); if !matches_user { continue; } } } if let Err(e) = ws_sender.send(Message::Text(msg)).await { eprintln!( "[WebSocketServer] Fehler beim Senden an {}: {}", peer_addr, e ); break; } } // Antworten aus der Verbindung selbst (z.B. getConnections) client_msg = client_rx.recv() => { match client_msg { Some(msg) => { if let Err(e) = ws_sender.send(Message::Text(msg)).await { eprintln!( "[WebSocketServer] Fehler beim Senden an {}: {}", peer_addr, e ); break; } } None => { // Kanal wurde geschlossen break; } } } } } }; futures_util::future::select(incoming.boxed(), outgoing.boxed()).await; // Verbindung aus der Registry entfernen let final_uid = { let guard = user_id.lock().await; guard.clone() }; { let mut reg = registry.lock().await; if reg.total > 0 { reg.total -= 1; } if let Some(uid) = final_uid { if let Some(count) = reg.by_user.get_mut(&uid) { if *count > 0 { *count -= 1; } if *count == 0 { reg.by_user.remove(&uid); } } } else if reg.unauthenticated > 0 { reg.unauthenticated -= 1; } } println!("[WebSocketServer] Verbindung geschlossen: {}", peer_addr); }