Files
yourpart-daemon/src/websocket_server.rs
Torsten Schulz (local) d0ec363f09 Initial commit: Rust YpDaemon
2025-11-21 23:05:34 +01:00

464 lines
16 KiB
Rust

use crate::db::ConnectionPool;
use crate::message_broker::MessageBroker;
use crate::worker::Worker;
use futures_util::{FutureExt, SinkExt, StreamExt};
use serde::Deserialize;
use serde_json::Value as Json;
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::{broadcast, mpsc, Mutex};
use tokio_rustls::rustls::{self, ServerConfig};
use tokio_rustls::TlsAcceptor;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::accept_async;
use rustls_pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer};
/// WebSocket server built on Tokio and tokio-tungstenite, optionally wrapped in TLS.
///
/// Client protocol:
/// - a `setUserId` event (`{"event":"setUserId","data":{"userId":"..."}}`) binds the
///   connection to a user,
/// - broker messages carrying a `user_id` field are intended for the matching connections,
/// - broker messages without `user_id` are intended for every connection.
pub struct WebSocketServer {
    // Listener configuration.
    /// TCP port the accept loop binds to.
    port: u16,
    /// Whether connections should be wrapped in TLS.
    use_ssl: bool,
    /// Path of the PEM certificate chain; only consulted when `use_ssl` is set.
    cert_path: Option<String>,
    /// Path of the PEM private key; only consulted when `use_ssl` is set.
    key_path: Option<String>,

    // Collaborators handed in by the caller.
    /// Database connection pool passed on to the accept loop.
    pool: ConnectionPool,
    /// Broker whose published messages are forwarded to WebSocket clients.
    broker: MessageBroker,
    /// Raw addresses of the workers registered via `set_workers`; the boxed workers must
    /// outlive any use of these pointers.
    workers: Vec<*const dyn Worker>,

    // Lifecycle state.
    /// Set while the server is started; cleared by `stop` or by a failed bind.
    running: Arc<AtomicBool>,
    /// Tokio runtime driving the accept loop and the connection tasks, if started.
    runtime: Option<Runtime>,
}
/// Connection counters shared by all connection tasks; the data behind the
/// `getConnections` reply.
#[derive(Default)]
struct ConnectionRegistry {
    /// Number of currently open WebSocket connections.
    total: usize,
    /// Open connections that have not (yet) bound a user via `setUserId`.
    unauthenticated: usize,
    /// Open connections per user id (as sent with `setUserId`).
    by_user: HashMap<String, usize>,
}
fn create_tls_acceptor(
cert_path: Option<&str>,
key_path: Option<&str>,
) -> Result<TlsAcceptor, Box<dyn std::error::Error>> {
let cert_path = cert_path.ok_or("SSL aktiviert, aber kein Zertifikatspfad gesetzt")?;
let key_path = key_path.ok_or("SSL aktiviert, aber kein Key-Pfad gesetzt")?;
let cert_file = File::open(cert_path)?;
let mut cert_reader = BufReader::new(cert_file);
let mut cert_chain: Vec<CertificateDer<'static>> = Vec::new();
for cert_result in certs(&mut cert_reader) {
let cert: CertificateDer<'static> = cert_result?;
cert_chain.push(cert);
}
if cert_chain.is_empty() {
return Err("Zertifikatsdatei enthält keine Zertifikate".into());
}
let key_file = File::open(key_path)?;
let mut key_reader = BufReader::new(key_file);
// Versuche zuerst PKCS8, dann ggf. RSA-Key
let mut keys: Vec<PrivateKeyDer<'static>> = pkcs8_private_keys(&mut key_reader)
.map(|res: Result<PrivatePkcs8KeyDer<'static>, _>| res.map(PrivateKeyDer::Pkcs8))
.collect::<Result<_, _>>()?;
if keys.is_empty() {
// Leser zurücksetzen und RSA-Keys versuchen
let key_file = File::open(key_path)?;
let mut key_reader = BufReader::new(key_file);
keys = rsa_private_keys(&mut key_reader)
.map(|res: Result<PrivatePkcs1KeyDer<'static>, _>| res.map(PrivateKeyDer::Pkcs1))
.collect::<Result<_, _>>()?;
}
if keys.is_empty() {
return Err("Key-Datei enthält keinen privaten Schlüssel (PKCS8 oder RSA)".into());
}
let private_key = keys.remove(0);
let config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert_chain, private_key)?;
Ok(TlsAcceptor::from(Arc::new(config)))
}
impl WebSocketServer {
    /// Creates a server in the stopped state; nothing is bound until [`WebSocketServer::run`].
    ///
    /// `cert_path` and `key_path` are only read when `use_ssl` is `true`.
    pub fn new(
        port: u16,
        pool: ConnectionPool,
        broker: MessageBroker,
        use_ssl: bool,
        cert_path: Option<String>,
        key_path: Option<String>,
    ) -> Self {
        Self {
            port,
            pool,
            broker,
            use_ssl,
            cert_path,
            key_path,
            workers: Vec::new(),
            running: Arc::new(AtomicBool::new(false)),
            runtime: None,
        }
    }

    /// Replaces the registered workers with the addresses of `workers`.
    ///
    /// Only raw pointers are stored, so the caller must keep the boxed workers alive
    /// (not dropped) for as long as this server may use them.
    pub fn set_workers(&mut self, workers: &[Box<dyn Worker>]) {
        self.workers.clear();
        self.workers
            .extend(workers.iter().map(|w| &**w as *const dyn Worker));
    }

    /// Starts the server on a dedicated multi-threaded Tokio runtime and returns immediately.
    ///
    /// Loads the TLS configuration when SSL is enabled, subscribes to the message broker and
    /// spawns the accept loop on `0.0.0.0:<port>`. Calling this while the server is already
    /// running only logs a message. If the runtime cannot be created, the error is logged and
    /// the server stays stopped.
    pub fn run(&mut self) {
        if self.running.swap(true, Ordering::SeqCst) {
            eprintln!("[WebSocketServer] Läuft bereits.");
            return;
        }
        // An earlier start whose accept loop failed to bind resets `running` but leaves its
        // runtime behind; release it before creating a new one.
        if let Some(stale) = self.runtime.take() {
            stale.shutdown_background();
        }

        if self.use_ssl {
            println!(
                "Starte WebSocket-Server auf Port {} mit SSL (cert: {:?}, key: {:?})",
                self.port, self.cert_path, self.key_path
            );
        } else {
            println!("Starte WebSocket-Server auf Port {} (ohne SSL)", self.port);
        }

        // NOTE: if the TLS setup fails, the server deliberately falls back to plain
        // (unencrypted) WebSocket instead of refusing to start.
        let tls_acceptor = if self.use_ssl {
            match create_tls_acceptor(self.cert_path.as_deref(), self.key_path.as_deref()) {
                Ok(acc) => Some(acc),
                Err(err) => {
                    eprintln!(
                        "[WebSocketServer] TLS-Initialisierung fehlgeschlagen, starte ohne SSL: {err}"
                    );
                    None
                }
            }
        } else {
            None
        };

        // Runtime creation is an OS-level operation that can fail; treat it as a failed start
        // instead of panicking with `running` stuck at `true`.
        let rt = match Builder::new_multi_thread().enable_all().build() {
            Ok(rt) => rt,
            Err(err) => {
                eprintln!("[WebSocketServer] Tokio Runtime konnte nicht erstellt werden: {err}");
                self.running.store(false, Ordering::SeqCst);
                return;
            }
        };

        // Connection statistics shared by all connection tasks.
        let registry = Arc::new(Mutex::new(ConnectionRegistry::default()));

        // Every message published on the broker is fanned out to all connection tasks.
        // Subscribing only after the runtime exists avoids leaking a subscription on a
        // failed start.
        let (tx, _) = broadcast::channel::<String>(1024);
        let broker_tx = tx.clone();
        let broker = self.broker.clone();
        broker.subscribe(move |msg: String| {
            // `send` fails only while no connection holds a receiver; the message is dropped.
            let _ = broker_tx.send(msg);
        });

        rt.spawn(run_accept_loop(
            format!("0.0.0.0:{}", self.port),
            Arc::clone(&self.running),
            tx,
            self.pool.clone(),
            registry,
            tls_acceptor,
        ));
        self.runtime = Some(rt);
    }

    /// Stops the server: clears the running flag and shuts the runtime down in the
    /// background, without waiting for connection tasks to finish.
    ///
    /// Also releases a runtime left over from a start whose accept loop failed to bind.
    pub fn stop(&mut self) {
        let was_running = self.running.swap(false, Ordering::SeqCst);
        if was_running {
            println!("WebSocket-Server wird gestoppt.");
        }
        if let Some(rt) = self.runtime.take() {
            rt.shutdown_background();
        }
    }
}
/// A JSON frame sent by a client, shaped like `{"event": "...", "data": ...}`.
///
/// Both fields may be absent on the wire: a missing `event` becomes the empty string
/// and a missing `data` becomes JSON `null`.
#[derive(Debug, Default, Deserialize)]
#[serde(default)]
struct IncomingMessage {
    /// Event name, e.g. `setUserId` or `getConnections`.
    event: String,
    /// Event payload; its structure depends on `event`.
    data: Json,
}
/// Accepts TCP connections on `addr` while `running` is set and spawns one task per client.
///
/// Each accepted socket first goes through a TLS handshake when `tls_acceptor` is `Some`,
/// and is then handed to `handle_connection` with its own subscription to `tx`.
/// `_pool` is accepted but not used here.
///
/// If binding fails, the error is logged, `running` is reset to `false` and the function
/// returns. `running` is only re-checked after each `accept()` returns.
async fn run_accept_loop(
    addr: String,
    running: Arc<AtomicBool>,
    tx: broadcast::Sender<String>,
    _pool: ConnectionPool,
    registry: Arc<Mutex<ConnectionRegistry>>,
    tls_acceptor: Option<TlsAcceptor>,
) {
    let listener = match TcpListener::bind(&addr).await {
        Ok(l) => l,
        Err(e) => {
            eprintln!("[WebSocketServer] Fehler beim Binden an {}: {}", addr, e);
            running.store(false, Ordering::SeqCst);
            return;
        }
    };
    println!("[WebSocketServer] Lauscht auf {}", addr);

    while running.load(Ordering::SeqCst) {
        let (stream, peer_addr) = match listener.accept().await {
            Ok(v) => v,
            Err(e) => {
                eprintln!("[WebSocketServer] accept() fehlgeschlagen: {}", e);
                // Errors such as EMFILE/ENFILE (out of file descriptors) tend to persist;
                // retrying immediately would spin the CPU and flood the log.
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                continue;
            }
        };

        // Subscribe before spawning so the connection sees every message from now on.
        let broker_rx = tx.subscribe();
        let conn_registry = Arc::clone(&registry);
        let conn_tls = tls_acceptor.clone();
        tokio::spawn(async move {
            match conn_tls {
                Some(acc) => match acc.accept(stream).await {
                    Ok(tls_stream) => {
                        handle_connection(tls_stream, peer_addr, broker_rx, conn_registry).await
                    }
                    Err(err) => {
                        eprintln!(
                            "[WebSocketServer] TLS-Handshake fehlgeschlagen ({peer_addr}): {err}"
                        );
                    }
                },
                None => handle_connection(stream, peer_addr, broker_rx, conn_registry).await,
            }
        });
    }
}
async fn handle_connection<S>(
stream: S,
peer_addr: SocketAddr,
mut broker_rx: broadcast::Receiver<String>,
registry: Arc<Mutex<ConnectionRegistry>>,
) where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let ws_stream = match accept_async(stream).await {
Ok(ws) => ws,
Err(e) => {
eprintln!("[WebSocketServer] WebSocket-Handshake fehlgeschlagen ({peer_addr}): {e}");
return;
}
};
println!("[WebSocketServer] Neue Verbindung von {}", peer_addr);
let (mut ws_sender, mut ws_receiver) = ws_stream.split();
// Kanal für Antworten direkt an diesen Client (z.B. getConnections)
let (client_tx, mut client_rx) = mpsc::channel::<String>(32);
// Neue Verbindung in der Registry zählen (zunächst als unauthentifiziert)
{
let mut reg = registry.lock().await;
reg.total += 1;
reg.unauthenticated += 1;
}
// user_id der Verbindung (nach setUserId)
let user_id = Arc::new(tokio::sync::Mutex::new(Option::<String>::None));
let user_id_for_incoming = user_id.clone();
let user_id_for_broker = user_id.clone();
let registry_for_incoming = registry.clone();
let client_tx_incoming = client_tx.clone();
// Eingehende Nachrichten vom Client
let incoming = async move {
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(Message::Text(txt)) => {
if let Ok(parsed) = serde_json::from_str::<IncomingMessage>(&txt) {
match parsed.event.as_str() {
"setUserId" => {
if let Some(uid) =
parsed.data.get("userId").and_then(|v| v.as_str())
{
{
// Registry aktualisieren: von unauthentifiziert -> Nutzer
let mut reg = registry_for_incoming.lock().await;
if reg.unauthenticated > 0 {
reg.unauthenticated -= 1;
}
*reg.by_user.entry(uid.to_string()).or_insert(0) += 1;
}
let mut guard = user_id_for_incoming.lock().await;
*guard = Some(uid.to_string());
println!(
"[WebSocketServer] User-ID gesetzt für {}: {}",
peer_addr, uid
);
}
}
"getConnections" => {
// Einfache Übersicht über aktuelle Verbindungen zurückgeben.
let snapshot = {
let reg = registry_for_incoming.lock().await;
serde_json::json!({
"event": "getConnectionsResponse",
"total": reg.total,
"unauthenticated": reg.unauthenticated,
"users": reg.by_user,
})
.to_string()
};
let _ = client_tx_incoming.send(snapshot).await;
}
_ => {
// Unbekannte Events ignorieren
}
}
}
}
Ok(Message::Ping(_)) => {
// Ping wird aktuell nur geloggt/ignoriert; optional könnte man hier ein eigenes
// Ping/Pong-Handling ergänzen.
}
Ok(Message::Close(_)) => break,
Err(e) => {
eprintln!("[WebSocketServer] Fehler bei Nachricht von {peer_addr}: {e}");
break;
}
_ => {}
}
}
};
// Broker-Nachrichten an den Client
let outgoing = async move {
loop {
tokio::select! {
// Nachrichten aus dem MessageBroker
broker_msg = broker_rx.recv() => {
let msg = match broker_msg {
Ok(m) => m,
Err(_) => break,
};
// Filter nach user_id, falls gesetzt
let target_user = {
let guard = user_id_for_broker.lock().await;
guard.clone()
};
if let Some(uid) = target_user.clone() {
if let Ok(json) = serde_json::from_str::<Json>(&msg) {
let matches_user = json
.get("user_id")
.and_then(|v| {
if let Some(s) = v.as_str() {
Some(s.to_string())
} else if let Some(n) = v.as_i64() {
Some(n.to_string())
} else {
None
}
})
.map(|v| v == uid)
.unwrap_or(false);
if !matches_user {
continue;
}
}
}
if let Err(e) = ws_sender.send(Message::Text(msg)).await {
eprintln!(
"[WebSocketServer] Fehler beim Senden an {}: {}",
peer_addr, e
);
break;
}
}
// Antworten aus der Verbindung selbst (z.B. getConnections)
client_msg = client_rx.recv() => {
match client_msg {
Some(msg) => {
if let Err(e) = ws_sender.send(Message::Text(msg)).await {
eprintln!(
"[WebSocketServer] Fehler beim Senden an {}: {}",
peer_addr, e
);
break;
}
}
None => {
// Kanal wurde geschlossen
break;
}
}
}
}
}
};
futures_util::future::select(incoming.boxed(), outgoing.boxed()).await;
// Verbindung aus der Registry entfernen
let final_uid = {
let guard = user_id.lock().await;
guard.clone()
};
{
let mut reg = registry.lock().await;
if reg.total > 0 {
reg.total -= 1;
}
if let Some(uid) = final_uid {
if let Some(count) = reg.by_user.get_mut(&uid) {
if *count > 0 {
*count -= 1;
}
if *count == 0 {
reg.by_user.remove(&uid);
}
}
} else if reg.unauthenticated > 0 {
reg.unauthenticated -= 1;
}
}
println!("[WebSocketServer] Verbindung geschlossen: {}", peer_addr);
}