Files
yourpart-daemon/src/websocket_server.rs

617 lines
23 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use crate::db::ConnectionPool;
use crate::message_broker::MessageBroker;
use crate::worker::Worker;
use futures_util::{FutureExt, SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::Value as Json;
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::{broadcast, mpsc, Mutex};
use tokio::time::{interval, Duration as TokioDuration};
use tokio_rustls::rustls::{self, ServerConfig};
use tokio_rustls::TlsAcceptor;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::accept_async;
use rustls_pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer};
/// Simple WebSocket server built on Tokio + tokio-tungstenite.
///
/// Supports:
/// - the `setUserId` event from a client (`{"event":"setUserId","data":{"userId":"..."}}`)
/// - forwarding broker messages to connections, filtered by the message's `user_id` field
/// - `getConnections` / `getWebsocketLog` requests, answered on the requesting connection
///
/// NOTE(review): the `user_id` filter only applies when the connection's user ID parses as an
/// integer. Such connections receive only messages whose `user_id` matches, so messages
/// *without* a `user_id` are not delivered to them. Connections with no user ID, or a
/// non-numeric one, receive every broker message. Confirm this is the intended contract.
pub struct WebSocketServer {
    /// TCP port; the server binds `0.0.0.0:<port>`.
    port: u16,
    /// Database pool, handed to the accept loop (which currently does not use it).
    pool: ConnectionPool,
    /// Broker whose published messages are fanned out to the connected clients.
    broker: MessageBroker,
    /// Whether to wrap connections in TLS (falls back to plain TCP if loading fails).
    use_ssl: bool,
    /// PEM certificate chain path, used only when `use_ssl` is set.
    cert_path: Option<String>,
    /// PEM private key path, used only when `use_ssl` is set.
    key_path: Option<String>,
    /// Raw pointers recorded by `set_workers`.
    /// NOTE(review): these do not own the workers and dangle once the caller drops the boxes;
    /// raw pointers also make this struct `!Send`/`!Sync`.
    workers: Vec<*const dyn Worker>,
    /// Set by `run`, cleared by `stop` and by the accept loop when binding fails.
    running: Arc<AtomicBool>,
    /// Runtime that owns the accept loop and all connection tasks; `Some` after `run`.
    runtime: Option<Runtime>,
}
/// Simple registry of connection counters, reported by the `getConnections` event.
#[derive(Default)]
struct ConnectionRegistry {
    /// Open connections that completed the WebSocket handshake.
    total: usize,
    /// Open connections that have not set a user ID via `setUserId`.
    unauthenticated: usize,
    /// Open connections per user ID; an entry is removed when its count drops to zero.
    by_user: HashMap<String, usize>,
}
/// One entry of the WebSocket log, retrievable through the `getWebsocketLog` event.
///
/// Field order is also the key order of the serialized JSON.
#[derive(Debug, Clone, Serialize)]
struct WebSocketLogEntry {
    timestamp: u64,             // seconds since UNIX_EPOCH
    direction: String,          // e.g. "broker->client" or "local->client"
    peer: String,               // "ip:port"
    conn_user: Option<String>,  // user ID set on the connection via setUserId
    target_user: Option<String>, // user_id from the message (string or integer), if any
    event: Option<String>,      // `event` field of the message, if it is JSON
}
/// Shared in-memory log of frames sent to clients, served by `getWebsocketLog`.
#[derive(Default)]
struct WebSocketLog {
    /// Entries in append order: oldest first, newest last.
    entries: Vec<WebSocketLogEntry>,
}
/// Upper bound on the number of retained log entries; the oldest entries are dropped first.
const WS_LOG_MAX_ENTRIES: usize = 50_000;
async fn append_ws_log(
log: &Arc<Mutex<WebSocketLog>>,
direction: &str,
peer_addr: &SocketAddr,
conn_user: &Option<String>,
msg: &str,
) {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let (event, target_user) = if let Ok(json) = serde_json::from_str::<Json>(msg) {
let event = json
.get("event")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let target_user = json
.get("user_id")
.and_then(|v| {
if let Some(s) = v.as_str() {
Some(s.to_string())
} else if let Some(n) = v.as_i64() {
Some(n.to_string())
} else {
None
}
});
(event, target_user)
} else {
(None, None)
};
let entry = WebSocketLogEntry {
timestamp: now,
direction: direction.to_string(),
peer: peer_addr.to_string(),
conn_user: conn_user.clone(),
target_user,
event,
};
let mut guard = log.lock().await;
guard.entries.push(entry);
if guard.entries.len() > WS_LOG_MAX_ENTRIES {
let overflow = guard.entries.len() - WS_LOG_MAX_ENTRIES;
guard.entries.drain(0..overflow);
}
}
fn create_tls_acceptor(
cert_path: Option<&str>,
key_path: Option<&str>,
) -> Result<TlsAcceptor, Box<dyn std::error::Error>> {
let cert_path = cert_path.ok_or("SSL aktiviert, aber kein Zertifikatspfad gesetzt")?;
let key_path = key_path.ok_or("SSL aktiviert, aber kein Key-Pfad gesetzt")?;
let cert_file = File::open(cert_path)?;
let mut cert_reader = BufReader::new(cert_file);
let mut cert_chain: Vec<CertificateDer<'static>> = Vec::new();
for cert_result in certs(&mut cert_reader) {
let cert: CertificateDer<'static> = cert_result?;
cert_chain.push(cert);
}
if cert_chain.is_empty() {
return Err("Zertifikatsdatei enthält keine Zertifikate".into());
}
let key_file = File::open(key_path)?;
let mut key_reader = BufReader::new(key_file);
// Versuche zuerst PKCS8, dann ggf. RSA-Key
let mut keys: Vec<PrivateKeyDer<'static>> = pkcs8_private_keys(&mut key_reader)
.map(|res: Result<PrivatePkcs8KeyDer<'static>, _>| res.map(PrivateKeyDer::Pkcs8))
.collect::<Result<_, _>>()?;
if keys.is_empty() {
// Leser zurücksetzen und RSA-Keys versuchen
let key_file = File::open(key_path)?;
let mut key_reader = BufReader::new(key_file);
keys = rsa_private_keys(&mut key_reader)
.map(|res: Result<PrivatePkcs1KeyDer<'static>, _>| res.map(PrivateKeyDer::Pkcs1))
.collect::<Result<_, _>>()?;
}
if keys.is_empty() {
return Err("Key-Datei enthält keinen privaten Schlüssel (PKCS8 oder RSA)".into());
}
let private_key = keys.remove(0);
let config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert_chain, private_key)?;
Ok(TlsAcceptor::from(Arc::new(config)))
}
impl WebSocketServer {
pub fn new(
port: u16,
pool: ConnectionPool,
broker: MessageBroker,
use_ssl: bool,
cert_path: Option<String>,
key_path: Option<String>,
) -> Self {
Self {
port,
pool,
broker,
use_ssl,
cert_path,
key_path,
workers: Vec::new(),
running: Arc::new(AtomicBool::new(false)),
runtime: None,
}
}
pub fn set_workers(&mut self, workers: &[Box<dyn Worker>]) {
self.workers.clear();
for w in workers {
self.workers.push(&**w as *const dyn Worker);
}
}
pub fn run(&mut self) {
if self.running.swap(true, Ordering::SeqCst) {
eprintln!("[WebSocketServer] Läuft bereits.");
return;
}
if self.use_ssl {
println!(
"Starte WebSocket-Server auf Port {} mit SSL (cert: {:?}, key: {:?})",
self.port, self.cert_path, self.key_path
);
// Hinweis: SSL-Unterstützung ist noch nicht implementiert.
} else {
println!("Starte WebSocket-Server auf Port {} (ohne SSL)", self.port);
}
let addr = format!("0.0.0.0:{}", self.port);
let running_flag = self.running.clone();
let broker = self.broker.clone();
// Gemeinsame Registry für alle Verbindungen
let registry = Arc::new(Mutex::new(ConnectionRegistry::default()));
// Gemeinsames WebSocket-Log (für max. 24h, begrenzt über WS_LOG_MAX_ENTRIES)
let ws_log = Arc::new(Mutex::new(WebSocketLog::default()));
// Broadcast-Kanal für Broker-Nachrichten
let (tx, _) = broadcast::channel::<String>(1024);
let tx_clone = tx.clone();
// Broker-Subscription: jede gepublishte Nachricht geht in den Broadcast-Kanal
broker.subscribe(move |msg: String| {
let _ = tx_clone.send(msg);
});
// Optionalen TLS-Akzeptor laden, falls SSL aktiviert ist
let tls_acceptor = if self.use_ssl {
match create_tls_acceptor(
self.cert_path.as_deref(),
self.key_path.as_deref(),
) {
Ok(acc) => Some(acc),
Err(err) => {
eprintln!(
"[WebSocketServer] TLS-Initialisierung fehlgeschlagen, starte ohne SSL: {err}"
);
None
}
}
} else {
None
};
let rt = Builder::new_multi_thread()
.enable_all()
.build()
.expect("Tokio Runtime konnte nicht erstellt werden");
rt.spawn(run_accept_loop(
addr,
running_flag,
tx,
self.pool.clone(),
registry,
ws_log,
tls_acceptor,
));
self.runtime = Some(rt);
}
pub fn stop(&mut self) {
if !self.running.swap(false, Ordering::SeqCst) {
return;
}
println!("WebSocket-Server wird gestoppt.");
if let Some(rt) = self.runtime.take() {
rt.shutdown_background();
}
}
}
/// JSON request sent by a client: `{"event": "...", "data": {...}}`.
///
/// Both fields use `#[serde(default)]`. A missing `event` becomes `""`, which the dispatcher
/// ignores; a missing `data` becomes `Json::Null`.
#[derive(Debug, Deserialize)]
struct IncomingMessage {
    #[serde(default)]
    event: String,
    #[serde(default)]
    data: Json,
}
async fn run_accept_loop(
addr: String,
running: Arc<AtomicBool>,
tx: broadcast::Sender<String>,
_pool: ConnectionPool,
registry: Arc<Mutex<ConnectionRegistry>>,
ws_log: Arc<Mutex<WebSocketLog>>,
tls_acceptor: Option<TlsAcceptor>,
) {
let listener = match TcpListener::bind(&addr).await {
Ok(l) => l,
Err(e) => {
eprintln!("[WebSocketServer] Fehler beim Binden an {}: {}", addr, e);
running.store(false, Ordering::SeqCst);
return;
}
};
println!("[WebSocketServer] Lauscht auf {}", addr);
while running.load(Ordering::SeqCst) {
let (stream, peer) = match listener.accept().await {
Ok(v) => v,
Err(e) => {
eprintln!("[WebSocketServer] accept() fehlgeschlagen: {}", e);
continue;
}
};
let peer_addr = peer;
let rx = tx.subscribe();
let registry_clone = registry.clone();
let ws_log_clone = ws_log.clone();
let tls_acceptor_clone = tls_acceptor.clone();
tokio::spawn(async move {
if let Some(acc) = tls_acceptor_clone {
match acc.accept(stream).await {
Ok(tls_stream) => {
handle_connection(tls_stream, peer_addr, rx, registry_clone, ws_log_clone)
.await
}
Err(err) => {
eprintln!(
"[WebSocketServer] TLS-Handshake fehlgeschlagen ({peer_addr}): {err}"
);
}
}
} else {
handle_connection(stream, peer_addr, rx, registry_clone, ws_log_clone).await;
}
});
}
}
async fn handle_connection<S>(
stream: S,
peer_addr: SocketAddr,
mut broker_rx: broadcast::Receiver<String>,
registry: Arc<Mutex<ConnectionRegistry>>,
ws_log: Arc<Mutex<WebSocketLog>>,
) where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let ws_stream = match accept_async(stream).await {
Ok(ws) => ws,
Err(e) => {
eprintln!("[WebSocketServer] WebSocket-Handshake fehlgeschlagen ({peer_addr}): {e}");
return;
}
};
println!("[WebSocketServer] Neue Verbindung von {}", peer_addr);
let (mut ws_sender, mut ws_receiver) = ws_stream.split();
// Kanal für Antworten direkt an diesen Client (z.B. getConnections)
let (client_tx, mut client_rx) = mpsc::channel::<String>(32);
// Neue Verbindung in der Registry zählen (zunächst als unauthentifiziert)
{
let mut reg = registry.lock().await;
reg.total += 1;
reg.unauthenticated += 1;
}
// user_id der Verbindung (nach setUserId)
let user_id = Arc::new(tokio::sync::Mutex::new(Option::<String>::None));
let user_id_for_incoming = user_id.clone();
let user_id_for_broker = user_id.clone();
let registry_for_incoming = registry.clone();
let client_tx_incoming = client_tx.clone();
let ws_log_for_incoming = ws_log.clone();
let ws_log_for_outgoing = ws_log.clone();
// Eingehende Nachrichten vom Client
let incoming = async move {
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(Message::Text(txt)) => {
if let Ok(parsed) = serde_json::from_str::<IncomingMessage>(&txt) {
match parsed.event.as_str() {
"setUserId" => {
if let Some(uid) =
parsed.data.get("userId").and_then(|v| v.as_str())
{
{
// Registry aktualisieren: von unauthentifiziert -> Nutzer
let mut reg = registry_for_incoming.lock().await;
if reg.unauthenticated > 0 {
reg.unauthenticated -= 1;
}
*reg.by_user.entry(uid.to_string()).or_insert(0) += 1;
}
let mut guard = user_id_for_incoming.lock().await;
*guard = Some(uid.to_string());
println!(
"[WebSocketServer] User-ID gesetzt für {}: {}",
peer_addr, uid
);
}
}
"getConnections" => {
// Einfache Übersicht über aktuelle Verbindungen zurückgeben.
let snapshot = {
let reg = registry_for_incoming.lock().await;
serde_json::json!({
"event": "getConnectionsResponse",
"total": reg.total,
"unauthenticated": reg.unauthenticated,
"users": reg.by_user,
})
.to_string()
};
let _ = client_tx_incoming.send(snapshot).await;
}
"getWebsocketLog" => {
// Liefert die letzten 24h (oder weniger) aus dem In-Memory-Log.
let entries = {
let guard = ws_log_for_incoming.lock().await;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let cutoff = now.saturating_sub(24 * 3600);
let mut filtered: Vec<WebSocketLogEntry> = guard
.entries
.iter()
.filter(|e| e.timestamp >= cutoff)
.cloned()
.collect();
// Zur Sicherheit begrenzen wir die Antwortgröße
if filtered.len() > 1000 {
let len = filtered.len();
filtered = filtered[len - 1000..].to_vec();
}
filtered
};
let payload = serde_json::json!({
"event": "getWebsocketLogResponse",
"entries": entries,
})
.to_string();
let _ = client_tx_incoming.send(payload).await;
}
_ => {
// Unbekannte Events ignorieren
}
}
}
}
Ok(Message::Ping(_)) => {
// Ping wird aktuell nur geloggt/ignoriert; optional könnte man hier ein eigenes
// Ping/Pong-Handling ergänzen.
}
Ok(Message::Close(_)) => break,
Err(e) => {
eprintln!("[WebSocketServer] Fehler bei Nachricht von {peer_addr}: {e}");
break;
}
_ => {}
}
}
};
// Broker-Nachrichten an den Client + periodische Ping-Frames als Keepalive
let outgoing = async move {
// Regelmäßiges Ping, um inaktiven Verbindungen ein Lebenszeichen zu senden
// und Timeouts auf dem Weg (Proxy/Loadbalancer) zu vermeiden.
let mut ping_interval = interval(TokioDuration::from_secs(240));
loop {
tokio::select! {
// Nachrichten aus dem MessageBroker
broker_msg = broker_rx.recv() => {
let msg = match broker_msg {
Ok(m) => m,
Err(_) => break,
};
// Filter nach user_id, falls gesetzt und numerisch interpretierbar.
// Historisch wurde hier der Falukant-User (numerisch) verwendet.
// Wenn die gesetzte User-ID kein Integer ist (z.B. Benutzername),
// wird *nicht* gefiltert und alle Nachrichten durchgelassen
// das Frontend muss dann selbst selektieren.
let target_user = {
let guard = user_id_for_broker.lock().await;
guard.clone()
};
if let Some(uid) = target_user.clone() {
// Versuche, die user_id als numerisch zu interpretieren
match uid.parse::<i64>() {
Ok(numeric_uid) => {
// Numerische user_id: Filtere explizit nach dieser ID
if let Ok(json) = serde_json::from_str::<Json>(&msg) {
let matches_user = json
.get("user_id")
.and_then(|v| {
if let Some(s) = v.as_str() {
s.parse::<i64>().ok()
} else if let Some(n) = v.as_i64() {
Some(n)
} else {
None
}
})
.map(|v| v == numeric_uid)
.unwrap_or(false);
if !matches_user {
continue;
}
}
}
Err(_) => {
// Nicht-numerische user_id: Explizit alle Nachrichten durchlassen
// (keine Filterung, wie im Kommentar dokumentiert)
// Dies ermöglicht es dem Frontend, selbst zu filtern
}
}
}
// Logging für den 24h-Überblick, was an welchen User/Peer geht
let conn_user = {
let guard = user_id_for_broker.lock().await;
guard.clone()
};
append_ws_log(&ws_log_for_outgoing, "broker->client", &peer_addr, &conn_user, &msg).await;
if let Err(e) = ws_sender.send(Message::Text(msg)).await {
eprintln!(
"[WebSocketServer] Fehler beim Senden an {}: {}",
peer_addr, e
);
break;
}
}
// Antworten aus der Verbindung selbst (z.B. getConnections)
client_msg = client_rx.recv() => {
match client_msg {
Some(msg) => {
let conn_user = {
let guard = user_id_for_broker.lock().await;
guard.clone()
};
append_ws_log(&ws_log_for_outgoing, "local->client", &peer_addr, &conn_user, &msg).await;
if let Err(e) = ws_sender.send(Message::Text(msg)).await {
eprintln!(
"[WebSocketServer] Fehler beim Senden an {}: {}",
peer_addr, e
);
break;
}
}
None => {
// Kanal wurde geschlossen
break;
}
}
}
// Periodisches Ping an den Client
_ = ping_interval.tick() => {
if let Err(e) = ws_sender.send(Message::Ping(Vec::new())).await {
eprintln!(
"[WebSocketServer] Fehler beim Senden von Ping an {}: {}",
peer_addr, e
);
break;
}
}
}
}
};
futures_util::future::select(incoming.boxed(), outgoing.boxed()).await;
// Verbindung aus der Registry entfernen
let final_uid = {
let guard = user_id.lock().await;
guard.clone()
};
{
let mut reg = registry.lock().await;
if reg.total > 0 {
reg.total -= 1;
}
if let Some(uid) = final_uid {
if let Some(count) = reg.by_user.get_mut(&uid) {
if *count > 0 {
*count -= 1;
}
if *count == 0 {
reg.by_user.remove(&uid);
}
}
} else if reg.unauthenticated > 0 {
reg.unauthenticated -= 1;
}
}
println!("[WebSocketServer] Verbindung geschlossen: {}", peer_addr);
}