464 lines
16 KiB
Rust
464 lines
16 KiB
Rust
use crate::db::ConnectionPool;
|
|
use crate::message_broker::MessageBroker;
|
|
use crate::worker::Worker;
|
|
use futures_util::{FutureExt, SinkExt, StreamExt};
|
|
use serde::Deserialize;
|
|
use serde_json::Value as Json;
|
|
use std::collections::HashMap;
|
|
use std::fs::File;
|
|
use std::io::BufReader;
|
|
use std::net::SocketAddr;
|
|
use std::sync::atomic::{AtomicBool, Ordering};
|
|
use std::sync::Arc;
|
|
use tokio::io::{AsyncRead, AsyncWrite};
|
|
use tokio::net::TcpListener;
|
|
use tokio::runtime::{Builder, Runtime};
|
|
use tokio::sync::{broadcast, mpsc, Mutex};
|
|
use tokio_rustls::rustls::{self, ServerConfig};
|
|
use tokio_rustls::TlsAcceptor;
|
|
use tokio_tungstenite::tungstenite::Message;
|
|
use tokio_tungstenite::accept_async;
|
|
use rustls_pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
|
|
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer};
|
|
|
|
/// Simple WebSocket server built on Tokio + tokio-tungstenite.
///
/// Supports:
/// - a `setUserId` event from the client (`{"event":"setUserId","data":{"userId":"..."}}`)
/// - a `getConnections` event, answered with connection statistics
/// - intended routing of broker messages: messages carrying a `user_id` field go to the
///   matching connections, messages without `user_id` are broadcast to everyone
pub struct WebSocketServer {
    /// TCP port; the server binds `0.0.0.0:<port>`.
    port: u16,
    /// Database connection pool; a clone is handed to the accept loop.
    pool: ConnectionPool,
    /// Broker whose published messages are fanned out to the connected clients.
    broker: MessageBroker,
    /// Whether to serve TLS (`wss://`) using `cert_path` / `key_path`.
    use_ssl: bool,
    /// Path to the PEM certificate chain (only read when `use_ssl` is set).
    cert_path: Option<String>,
    /// Path to the PEM private key (only read when `use_ssl` is set).
    key_path: Option<String>,
    /// Addresses of the workers passed to `set_workers`. The workers are NOT owned.
    /// NOTE(review): these pointers dangle once the caller drops its boxes, and raw
    /// pointers make this struct neither `Send` nor `Sync`. Nothing in this file
    /// dereferences them; any future use needs an explicit lifetime contract.
    workers: Vec<*const dyn Worker>,
    /// Set while the server is running; shared with the accept loop, which clears it
    /// when binding the listen address fails.
    running: Arc<AtomicBool>,
    /// Tokio runtime that owns the server tasks; `Some` after a successful `run`.
    runtime: Option<Runtime>,
}
|
|
|
|
/// Simple registry that tracks connection statistics for the `getConnections` event.
///
/// Every open WebSocket connection is counted in `total`, and additionally either in
/// `unauthenticated` (no `setUserId` yet) or under its user id in `by_user`.
#[derive(Debug, Default)]
struct ConnectionRegistry {
    /// Number of currently open WebSocket connections.
    total: usize,
    /// Connections that have not (yet) sent `setUserId`.
    unauthenticated: usize,
    /// Open connections per user id; entries are removed when their count drops to zero.
    by_user: HashMap<String, usize>,
}
|
|
|
|
fn create_tls_acceptor(
|
|
cert_path: Option<&str>,
|
|
key_path: Option<&str>,
|
|
) -> Result<TlsAcceptor, Box<dyn std::error::Error>> {
|
|
let cert_path = cert_path.ok_or("SSL aktiviert, aber kein Zertifikatspfad gesetzt")?;
|
|
let key_path = key_path.ok_or("SSL aktiviert, aber kein Key-Pfad gesetzt")?;
|
|
|
|
let cert_file = File::open(cert_path)?;
|
|
let mut cert_reader = BufReader::new(cert_file);
|
|
|
|
let mut cert_chain: Vec<CertificateDer<'static>> = Vec::new();
|
|
for cert_result in certs(&mut cert_reader) {
|
|
let cert: CertificateDer<'static> = cert_result?;
|
|
cert_chain.push(cert);
|
|
}
|
|
|
|
if cert_chain.is_empty() {
|
|
return Err("Zertifikatsdatei enthält keine Zertifikate".into());
|
|
}
|
|
|
|
let key_file = File::open(key_path)?;
|
|
let mut key_reader = BufReader::new(key_file);
|
|
|
|
// Versuche zuerst PKCS8, dann ggf. RSA-Key
|
|
let mut keys: Vec<PrivateKeyDer<'static>> = pkcs8_private_keys(&mut key_reader)
|
|
.map(|res: Result<PrivatePkcs8KeyDer<'static>, _>| res.map(PrivateKeyDer::Pkcs8))
|
|
.collect::<Result<_, _>>()?;
|
|
|
|
if keys.is_empty() {
|
|
// Leser zurücksetzen und RSA-Keys versuchen
|
|
let key_file = File::open(key_path)?;
|
|
let mut key_reader = BufReader::new(key_file);
|
|
keys = rsa_private_keys(&mut key_reader)
|
|
.map(|res: Result<PrivatePkcs1KeyDer<'static>, _>| res.map(PrivateKeyDer::Pkcs1))
|
|
.collect::<Result<_, _>>()?;
|
|
}
|
|
|
|
if keys.is_empty() {
|
|
return Err("Key-Datei enthält keinen privaten Schlüssel (PKCS8 oder RSA)".into());
|
|
}
|
|
|
|
let private_key = keys.remove(0);
|
|
|
|
let config = ServerConfig::builder()
|
|
.with_no_client_auth()
|
|
.with_single_cert(cert_chain, private_key)?;
|
|
|
|
Ok(TlsAcceptor::from(Arc::new(config)))
|
|
}
|
|
|
|
impl WebSocketServer {
|
|
pub fn new(
|
|
port: u16,
|
|
pool: ConnectionPool,
|
|
broker: MessageBroker,
|
|
use_ssl: bool,
|
|
cert_path: Option<String>,
|
|
key_path: Option<String>,
|
|
) -> Self {
|
|
Self {
|
|
port,
|
|
pool,
|
|
broker,
|
|
use_ssl,
|
|
cert_path,
|
|
key_path,
|
|
workers: Vec::new(),
|
|
running: Arc::new(AtomicBool::new(false)),
|
|
runtime: None,
|
|
}
|
|
}
|
|
|
|
pub fn set_workers(&mut self, workers: &[Box<dyn Worker>]) {
|
|
self.workers.clear();
|
|
for w in workers {
|
|
self.workers.push(&**w as *const dyn Worker);
|
|
}
|
|
}
|
|
|
|
pub fn run(&mut self) {
|
|
if self.running.swap(true, Ordering::SeqCst) {
|
|
eprintln!("[WebSocketServer] Läuft bereits.");
|
|
return;
|
|
}
|
|
|
|
if self.use_ssl {
|
|
println!(
|
|
"Starte WebSocket-Server auf Port {} mit SSL (cert: {:?}, key: {:?})",
|
|
self.port, self.cert_path, self.key_path
|
|
);
|
|
// Hinweis: SSL-Unterstützung ist noch nicht implementiert.
|
|
} else {
|
|
println!("Starte WebSocket-Server auf Port {} (ohne SSL)", self.port);
|
|
}
|
|
|
|
let addr = format!("0.0.0.0:{}", self.port);
|
|
let running_flag = self.running.clone();
|
|
let broker = self.broker.clone();
|
|
|
|
// Gemeinsame Registry für alle Verbindungen
|
|
let registry = Arc::new(Mutex::new(ConnectionRegistry::default()));
|
|
|
|
// Broadcast-Kanal für Broker-Nachrichten
|
|
let (tx, _) = broadcast::channel::<String>(1024);
|
|
let tx_clone = tx.clone();
|
|
|
|
// Broker-Subscription: jede gepublishte Nachricht geht in den Broadcast-Kanal
|
|
broker.subscribe(move |msg: String| {
|
|
let _ = tx_clone.send(msg);
|
|
});
|
|
|
|
// Optionalen TLS-Akzeptor laden, falls SSL aktiviert ist
|
|
let tls_acceptor = if self.use_ssl {
|
|
match create_tls_acceptor(
|
|
self.cert_path.as_deref(),
|
|
self.key_path.as_deref(),
|
|
) {
|
|
Ok(acc) => Some(acc),
|
|
Err(err) => {
|
|
eprintln!(
|
|
"[WebSocketServer] TLS-Initialisierung fehlgeschlagen, starte ohne SSL: {err}"
|
|
);
|
|
None
|
|
}
|
|
}
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let rt = Builder::new_multi_thread()
|
|
.enable_all()
|
|
.build()
|
|
.expect("Tokio Runtime konnte nicht erstellt werden");
|
|
|
|
rt.spawn(run_accept_loop(
|
|
addr,
|
|
running_flag,
|
|
tx,
|
|
self.pool.clone(),
|
|
registry,
|
|
tls_acceptor,
|
|
));
|
|
|
|
self.runtime = Some(rt);
|
|
}
|
|
|
|
pub fn stop(&mut self) {
|
|
if !self.running.swap(false, Ordering::SeqCst) {
|
|
return;
|
|
}
|
|
println!("WebSocket-Server wird gestoppt.");
|
|
if let Some(rt) = self.runtime.take() {
|
|
rt.shutdown_background();
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct IncomingMessage {
|
|
#[serde(default)]
|
|
event: String,
|
|
#[serde(default)]
|
|
data: Json,
|
|
}
|
|
|
|
async fn run_accept_loop(
|
|
addr: String,
|
|
running: Arc<AtomicBool>,
|
|
tx: broadcast::Sender<String>,
|
|
_pool: ConnectionPool,
|
|
registry: Arc<Mutex<ConnectionRegistry>>,
|
|
tls_acceptor: Option<TlsAcceptor>,
|
|
) {
|
|
let listener = match TcpListener::bind(&addr).await {
|
|
Ok(l) => l,
|
|
Err(e) => {
|
|
eprintln!("[WebSocketServer] Fehler beim Binden an {}: {}", addr, e);
|
|
running.store(false, Ordering::SeqCst);
|
|
return;
|
|
}
|
|
};
|
|
|
|
println!("[WebSocketServer] Lauscht auf {}", addr);
|
|
|
|
while running.load(Ordering::SeqCst) {
|
|
let (stream, peer) = match listener.accept().await {
|
|
Ok(v) => v,
|
|
Err(e) => {
|
|
eprintln!("[WebSocketServer] accept() fehlgeschlagen: {}", e);
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let peer_addr = peer;
|
|
let rx = tx.subscribe();
|
|
let registry_clone = registry.clone();
|
|
let tls_acceptor_clone = tls_acceptor.clone();
|
|
|
|
tokio::spawn(async move {
|
|
if let Some(acc) = tls_acceptor_clone {
|
|
match acc.accept(stream).await {
|
|
Ok(tls_stream) => {
|
|
handle_connection(tls_stream, peer_addr, rx, registry_clone).await
|
|
}
|
|
Err(err) => {
|
|
eprintln!(
|
|
"[WebSocketServer] TLS-Handshake fehlgeschlagen ({peer_addr}): {err}"
|
|
);
|
|
}
|
|
}
|
|
} else {
|
|
handle_connection(stream, peer_addr, rx, registry_clone).await;
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
async fn handle_connection<S>(
|
|
stream: S,
|
|
peer_addr: SocketAddr,
|
|
mut broker_rx: broadcast::Receiver<String>,
|
|
registry: Arc<Mutex<ConnectionRegistry>>,
|
|
) where
|
|
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
|
{
|
|
let ws_stream = match accept_async(stream).await {
|
|
Ok(ws) => ws,
|
|
Err(e) => {
|
|
eprintln!("[WebSocketServer] WebSocket-Handshake fehlgeschlagen ({peer_addr}): {e}");
|
|
return;
|
|
}
|
|
};
|
|
|
|
println!("[WebSocketServer] Neue Verbindung von {}", peer_addr);
|
|
|
|
let (mut ws_sender, mut ws_receiver) = ws_stream.split();
|
|
|
|
// Kanal für Antworten direkt an diesen Client (z.B. getConnections)
|
|
let (client_tx, mut client_rx) = mpsc::channel::<String>(32);
|
|
|
|
// Neue Verbindung in der Registry zählen (zunächst als unauthentifiziert)
|
|
{
|
|
let mut reg = registry.lock().await;
|
|
reg.total += 1;
|
|
reg.unauthenticated += 1;
|
|
}
|
|
|
|
// user_id der Verbindung (nach setUserId)
|
|
let user_id = Arc::new(tokio::sync::Mutex::new(Option::<String>::None));
|
|
let user_id_for_incoming = user_id.clone();
|
|
let user_id_for_broker = user_id.clone();
|
|
let registry_for_incoming = registry.clone();
|
|
let client_tx_incoming = client_tx.clone();
|
|
|
|
// Eingehende Nachrichten vom Client
|
|
let incoming = async move {
|
|
while let Some(msg) = ws_receiver.next().await {
|
|
match msg {
|
|
Ok(Message::Text(txt)) => {
|
|
if let Ok(parsed) = serde_json::from_str::<IncomingMessage>(&txt) {
|
|
match parsed.event.as_str() {
|
|
"setUserId" => {
|
|
if let Some(uid) =
|
|
parsed.data.get("userId").and_then(|v| v.as_str())
|
|
{
|
|
{
|
|
// Registry aktualisieren: von unauthentifiziert -> Nutzer
|
|
let mut reg = registry_for_incoming.lock().await;
|
|
if reg.unauthenticated > 0 {
|
|
reg.unauthenticated -= 1;
|
|
}
|
|
*reg.by_user.entry(uid.to_string()).or_insert(0) += 1;
|
|
}
|
|
|
|
let mut guard = user_id_for_incoming.lock().await;
|
|
*guard = Some(uid.to_string());
|
|
println!(
|
|
"[WebSocketServer] User-ID gesetzt für {}: {}",
|
|
peer_addr, uid
|
|
);
|
|
}
|
|
}
|
|
"getConnections" => {
|
|
// Einfache Übersicht über aktuelle Verbindungen zurückgeben.
|
|
let snapshot = {
|
|
let reg = registry_for_incoming.lock().await;
|
|
serde_json::json!({
|
|
"event": "getConnectionsResponse",
|
|
"total": reg.total,
|
|
"unauthenticated": reg.unauthenticated,
|
|
"users": reg.by_user,
|
|
})
|
|
.to_string()
|
|
};
|
|
let _ = client_tx_incoming.send(snapshot).await;
|
|
}
|
|
_ => {
|
|
// Unbekannte Events ignorieren
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Ok(Message::Ping(_)) => {
|
|
// Ping wird aktuell nur geloggt/ignoriert; optional könnte man hier ein eigenes
|
|
// Ping/Pong-Handling ergänzen.
|
|
}
|
|
Ok(Message::Close(_)) => break,
|
|
Err(e) => {
|
|
eprintln!("[WebSocketServer] Fehler bei Nachricht von {peer_addr}: {e}");
|
|
break;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
};
|
|
|
|
// Broker-Nachrichten an den Client
|
|
let outgoing = async move {
|
|
loop {
|
|
tokio::select! {
|
|
// Nachrichten aus dem MessageBroker
|
|
broker_msg = broker_rx.recv() => {
|
|
let msg = match broker_msg {
|
|
Ok(m) => m,
|
|
Err(_) => break,
|
|
};
|
|
|
|
// Filter nach user_id, falls gesetzt
|
|
let target_user = {
|
|
let guard = user_id_for_broker.lock().await;
|
|
guard.clone()
|
|
};
|
|
|
|
if let Some(uid) = target_user.clone() {
|
|
if let Ok(json) = serde_json::from_str::<Json>(&msg) {
|
|
let matches_user = json
|
|
.get("user_id")
|
|
.and_then(|v| {
|
|
if let Some(s) = v.as_str() {
|
|
Some(s.to_string())
|
|
} else if let Some(n) = v.as_i64() {
|
|
Some(n.to_string())
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.map(|v| v == uid)
|
|
.unwrap_or(false);
|
|
|
|
if !matches_user {
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
|
|
if let Err(e) = ws_sender.send(Message::Text(msg)).await {
|
|
eprintln!(
|
|
"[WebSocketServer] Fehler beim Senden an {}: {}",
|
|
peer_addr, e
|
|
);
|
|
break;
|
|
}
|
|
}
|
|
// Antworten aus der Verbindung selbst (z.B. getConnections)
|
|
client_msg = client_rx.recv() => {
|
|
match client_msg {
|
|
Some(msg) => {
|
|
if let Err(e) = ws_sender.send(Message::Text(msg)).await {
|
|
eprintln!(
|
|
"[WebSocketServer] Fehler beim Senden an {}: {}",
|
|
peer_addr, e
|
|
);
|
|
break;
|
|
}
|
|
}
|
|
None => {
|
|
// Kanal wurde geschlossen
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
futures_util::future::select(incoming.boxed(), outgoing.boxed()).await;
|
|
|
|
// Verbindung aus der Registry entfernen
|
|
let final_uid = {
|
|
let guard = user_id.lock().await;
|
|
guard.clone()
|
|
};
|
|
{
|
|
let mut reg = registry.lock().await;
|
|
if reg.total > 0 {
|
|
reg.total -= 1;
|
|
}
|
|
if let Some(uid) = final_uid {
|
|
if let Some(count) = reg.by_user.get_mut(&uid) {
|
|
if *count > 0 {
|
|
*count -= 1;
|
|
}
|
|
if *count == 0 {
|
|
reg.by_user.remove(&uid);
|
|
}
|
|
}
|
|
} else if reg.unauthenticated > 0 {
|
|
reg.unauthenticated -= 1;
|
|
}
|
|
}
|
|
|
|
println!("[WebSocketServer] Verbindung geschlossen: {}", peer_addr);
|
|
}
|
|
|