Initial commit: Rust YpDaemon
This commit is contained in:
463
src/websocket_server.rs
Normal file
463
src/websocket_server.rs
Normal file
@@ -0,0 +1,463 @@
|
||||
use crate::db::ConnectionPool;
|
||||
use crate::message_broker::MessageBroker;
|
||||
use crate::worker::Worker;
|
||||
use futures_util::{FutureExt, SinkExt, StreamExt};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value as Json;
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::runtime::{Builder, Runtime};
|
||||
use tokio::sync::{broadcast, mpsc, Mutex};
|
||||
use tokio_rustls::rustls::{self, ServerConfig};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::accept_async;
|
||||
use rustls_pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer};
|
||||
|
||||
/// Einfacher WebSocket-Server auf Basis von Tokio + tokio-tungstenite.
|
||||
///
|
||||
/// Unterstützt:
|
||||
/// - `setUserId`-Event vom Client (`{"event":"setUserId","data":{"userId":"..."}}`)
|
||||
/// - Versenden von Broker-Nachrichten mit `user_id`-Feld an passende Verbindungen
|
||||
/// - Broadcasting von Nachrichten ohne `user_id` an alle
|
||||
pub struct WebSocketServer {
|
||||
port: u16,
|
||||
pool: ConnectionPool,
|
||||
broker: MessageBroker,
|
||||
use_ssl: bool,
|
||||
cert_path: Option<String>,
|
||||
key_path: Option<String>,
|
||||
workers: Vec<*const dyn Worker>,
|
||||
running: Arc<AtomicBool>,
|
||||
runtime: Option<Runtime>,
|
||||
}
|
||||
|
||||
/// Einfache Registry, um Verbindungsstatistiken für `getConnections` zu liefern.
|
||||
#[derive(Default)]
|
||||
struct ConnectionRegistry {
|
||||
total: usize,
|
||||
unauthenticated: usize,
|
||||
by_user: HashMap<String, usize>,
|
||||
}
|
||||
|
||||
fn create_tls_acceptor(
|
||||
cert_path: Option<&str>,
|
||||
key_path: Option<&str>,
|
||||
) -> Result<TlsAcceptor, Box<dyn std::error::Error>> {
|
||||
let cert_path = cert_path.ok_or("SSL aktiviert, aber kein Zertifikatspfad gesetzt")?;
|
||||
let key_path = key_path.ok_or("SSL aktiviert, aber kein Key-Pfad gesetzt")?;
|
||||
|
||||
let cert_file = File::open(cert_path)?;
|
||||
let mut cert_reader = BufReader::new(cert_file);
|
||||
|
||||
let mut cert_chain: Vec<CertificateDer<'static>> = Vec::new();
|
||||
for cert_result in certs(&mut cert_reader) {
|
||||
let cert: CertificateDer<'static> = cert_result?;
|
||||
cert_chain.push(cert);
|
||||
}
|
||||
|
||||
if cert_chain.is_empty() {
|
||||
return Err("Zertifikatsdatei enthält keine Zertifikate".into());
|
||||
}
|
||||
|
||||
let key_file = File::open(key_path)?;
|
||||
let mut key_reader = BufReader::new(key_file);
|
||||
|
||||
// Versuche zuerst PKCS8, dann ggf. RSA-Key
|
||||
let mut keys: Vec<PrivateKeyDer<'static>> = pkcs8_private_keys(&mut key_reader)
|
||||
.map(|res: Result<PrivatePkcs8KeyDer<'static>, _>| res.map(PrivateKeyDer::Pkcs8))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
if keys.is_empty() {
|
||||
// Leser zurücksetzen und RSA-Keys versuchen
|
||||
let key_file = File::open(key_path)?;
|
||||
let mut key_reader = BufReader::new(key_file);
|
||||
keys = rsa_private_keys(&mut key_reader)
|
||||
.map(|res: Result<PrivatePkcs1KeyDer<'static>, _>| res.map(PrivateKeyDer::Pkcs1))
|
||||
.collect::<Result<_, _>>()?;
|
||||
}
|
||||
|
||||
if keys.is_empty() {
|
||||
return Err("Key-Datei enthält keinen privaten Schlüssel (PKCS8 oder RSA)".into());
|
||||
}
|
||||
|
||||
let private_key = keys.remove(0);
|
||||
|
||||
let config = ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert_chain, private_key)?;
|
||||
|
||||
Ok(TlsAcceptor::from(Arc::new(config)))
|
||||
}
|
||||
|
||||
impl WebSocketServer {
|
||||
pub fn new(
|
||||
port: u16,
|
||||
pool: ConnectionPool,
|
||||
broker: MessageBroker,
|
||||
use_ssl: bool,
|
||||
cert_path: Option<String>,
|
||||
key_path: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
port,
|
||||
pool,
|
||||
broker,
|
||||
use_ssl,
|
||||
cert_path,
|
||||
key_path,
|
||||
workers: Vec::new(),
|
||||
running: Arc::new(AtomicBool::new(false)),
|
||||
runtime: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_workers(&mut self, workers: &[Box<dyn Worker>]) {
|
||||
self.workers.clear();
|
||||
for w in workers {
|
||||
self.workers.push(&**w as *const dyn Worker);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(&mut self) {
|
||||
if self.running.swap(true, Ordering::SeqCst) {
|
||||
eprintln!("[WebSocketServer] Läuft bereits.");
|
||||
return;
|
||||
}
|
||||
|
||||
if self.use_ssl {
|
||||
println!(
|
||||
"Starte WebSocket-Server auf Port {} mit SSL (cert: {:?}, key: {:?})",
|
||||
self.port, self.cert_path, self.key_path
|
||||
);
|
||||
// Hinweis: SSL-Unterstützung ist noch nicht implementiert.
|
||||
} else {
|
||||
println!("Starte WebSocket-Server auf Port {} (ohne SSL)", self.port);
|
||||
}
|
||||
|
||||
let addr = format!("0.0.0.0:{}", self.port);
|
||||
let running_flag = self.running.clone();
|
||||
let broker = self.broker.clone();
|
||||
|
||||
// Gemeinsame Registry für alle Verbindungen
|
||||
let registry = Arc::new(Mutex::new(ConnectionRegistry::default()));
|
||||
|
||||
// Broadcast-Kanal für Broker-Nachrichten
|
||||
let (tx, _) = broadcast::channel::<String>(1024);
|
||||
let tx_clone = tx.clone();
|
||||
|
||||
// Broker-Subscription: jede gepublishte Nachricht geht in den Broadcast-Kanal
|
||||
broker.subscribe(move |msg: String| {
|
||||
let _ = tx_clone.send(msg);
|
||||
});
|
||||
|
||||
// Optionalen TLS-Akzeptor laden, falls SSL aktiviert ist
|
||||
let tls_acceptor = if self.use_ssl {
|
||||
match create_tls_acceptor(
|
||||
self.cert_path.as_deref(),
|
||||
self.key_path.as_deref(),
|
||||
) {
|
||||
Ok(acc) => Some(acc),
|
||||
Err(err) => {
|
||||
eprintln!(
|
||||
"[WebSocketServer] TLS-Initialisierung fehlgeschlagen, starte ohne SSL: {err}"
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let rt = Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.expect("Tokio Runtime konnte nicht erstellt werden");
|
||||
|
||||
rt.spawn(run_accept_loop(
|
||||
addr,
|
||||
running_flag,
|
||||
tx,
|
||||
self.pool.clone(),
|
||||
registry,
|
||||
tls_acceptor,
|
||||
));
|
||||
|
||||
self.runtime = Some(rt);
|
||||
}
|
||||
|
||||
pub fn stop(&mut self) {
|
||||
if !self.running.swap(false, Ordering::SeqCst) {
|
||||
return;
|
||||
}
|
||||
println!("WebSocket-Server wird gestoppt.");
|
||||
if let Some(rt) = self.runtime.take() {
|
||||
rt.shutdown_background();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct IncomingMessage {
|
||||
#[serde(default)]
|
||||
event: String,
|
||||
#[serde(default)]
|
||||
data: Json,
|
||||
}
|
||||
|
||||
async fn run_accept_loop(
|
||||
addr: String,
|
||||
running: Arc<AtomicBool>,
|
||||
tx: broadcast::Sender<String>,
|
||||
_pool: ConnectionPool,
|
||||
registry: Arc<Mutex<ConnectionRegistry>>,
|
||||
tls_acceptor: Option<TlsAcceptor>,
|
||||
) {
|
||||
let listener = match TcpListener::bind(&addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
eprintln!("[WebSocketServer] Fehler beim Binden an {}: {}", addr, e);
|
||||
running.store(false, Ordering::SeqCst);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
println!("[WebSocketServer] Lauscht auf {}", addr);
|
||||
|
||||
while running.load(Ordering::SeqCst) {
|
||||
let (stream, peer) = match listener.accept().await {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
eprintln!("[WebSocketServer] accept() fehlgeschlagen: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let peer_addr = peer;
|
||||
let rx = tx.subscribe();
|
||||
let registry_clone = registry.clone();
|
||||
let tls_acceptor_clone = tls_acceptor.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Some(acc) = tls_acceptor_clone {
|
||||
match acc.accept(stream).await {
|
||||
Ok(tls_stream) => {
|
||||
handle_connection(tls_stream, peer_addr, rx, registry_clone).await
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!(
|
||||
"[WebSocketServer] TLS-Handshake fehlgeschlagen ({peer_addr}): {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
handle_connection(stream, peer_addr, rx, registry_clone).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_connection<S>(
|
||||
stream: S,
|
||||
peer_addr: SocketAddr,
|
||||
mut broker_rx: broadcast::Receiver<String>,
|
||||
registry: Arc<Mutex<ConnectionRegistry>>,
|
||||
) where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let ws_stream = match accept_async(stream).await {
|
||||
Ok(ws) => ws,
|
||||
Err(e) => {
|
||||
eprintln!("[WebSocketServer] WebSocket-Handshake fehlgeschlagen ({peer_addr}): {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
println!("[WebSocketServer] Neue Verbindung von {}", peer_addr);
|
||||
|
||||
let (mut ws_sender, mut ws_receiver) = ws_stream.split();
|
||||
|
||||
// Kanal für Antworten direkt an diesen Client (z.B. getConnections)
|
||||
let (client_tx, mut client_rx) = mpsc::channel::<String>(32);
|
||||
|
||||
// Neue Verbindung in der Registry zählen (zunächst als unauthentifiziert)
|
||||
{
|
||||
let mut reg = registry.lock().await;
|
||||
reg.total += 1;
|
||||
reg.unauthenticated += 1;
|
||||
}
|
||||
|
||||
// user_id der Verbindung (nach setUserId)
|
||||
let user_id = Arc::new(tokio::sync::Mutex::new(Option::<String>::None));
|
||||
let user_id_for_incoming = user_id.clone();
|
||||
let user_id_for_broker = user_id.clone();
|
||||
let registry_for_incoming = registry.clone();
|
||||
let client_tx_incoming = client_tx.clone();
|
||||
|
||||
// Eingehende Nachrichten vom Client
|
||||
let incoming = async move {
|
||||
while let Some(msg) = ws_receiver.next().await {
|
||||
match msg {
|
||||
Ok(Message::Text(txt)) => {
|
||||
if let Ok(parsed) = serde_json::from_str::<IncomingMessage>(&txt) {
|
||||
match parsed.event.as_str() {
|
||||
"setUserId" => {
|
||||
if let Some(uid) =
|
||||
parsed.data.get("userId").and_then(|v| v.as_str())
|
||||
{
|
||||
{
|
||||
// Registry aktualisieren: von unauthentifiziert -> Nutzer
|
||||
let mut reg = registry_for_incoming.lock().await;
|
||||
if reg.unauthenticated > 0 {
|
||||
reg.unauthenticated -= 1;
|
||||
}
|
||||
*reg.by_user.entry(uid.to_string()).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
let mut guard = user_id_for_incoming.lock().await;
|
||||
*guard = Some(uid.to_string());
|
||||
println!(
|
||||
"[WebSocketServer] User-ID gesetzt für {}: {}",
|
||||
peer_addr, uid
|
||||
);
|
||||
}
|
||||
}
|
||||
"getConnections" => {
|
||||
// Einfache Übersicht über aktuelle Verbindungen zurückgeben.
|
||||
let snapshot = {
|
||||
let reg = registry_for_incoming.lock().await;
|
||||
serde_json::json!({
|
||||
"event": "getConnectionsResponse",
|
||||
"total": reg.total,
|
||||
"unauthenticated": reg.unauthenticated,
|
||||
"users": reg.by_user,
|
||||
})
|
||||
.to_string()
|
||||
};
|
||||
let _ = client_tx_incoming.send(snapshot).await;
|
||||
}
|
||||
_ => {
|
||||
// Unbekannte Events ignorieren
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Message::Ping(_)) => {
|
||||
// Ping wird aktuell nur geloggt/ignoriert; optional könnte man hier ein eigenes
|
||||
// Ping/Pong-Handling ergänzen.
|
||||
}
|
||||
Ok(Message::Close(_)) => break,
|
||||
Err(e) => {
|
||||
eprintln!("[WebSocketServer] Fehler bei Nachricht von {peer_addr}: {e}");
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Broker-Nachrichten an den Client
|
||||
let outgoing = async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Nachrichten aus dem MessageBroker
|
||||
broker_msg = broker_rx.recv() => {
|
||||
let msg = match broker_msg {
|
||||
Ok(m) => m,
|
||||
Err(_) => break,
|
||||
};
|
||||
|
||||
// Filter nach user_id, falls gesetzt
|
||||
let target_user = {
|
||||
let guard = user_id_for_broker.lock().await;
|
||||
guard.clone()
|
||||
};
|
||||
|
||||
if let Some(uid) = target_user.clone() {
|
||||
if let Ok(json) = serde_json::from_str::<Json>(&msg) {
|
||||
let matches_user = json
|
||||
.get("user_id")
|
||||
.and_then(|v| {
|
||||
if let Some(s) = v.as_str() {
|
||||
Some(s.to_string())
|
||||
} else if let Some(n) = v.as_i64() {
|
||||
Some(n.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.map(|v| v == uid)
|
||||
.unwrap_or(false);
|
||||
|
||||
if !matches_user {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(e) = ws_sender.send(Message::Text(msg)).await {
|
||||
eprintln!(
|
||||
"[WebSocketServer] Fehler beim Senden an {}: {}",
|
||||
peer_addr, e
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Antworten aus der Verbindung selbst (z.B. getConnections)
|
||||
client_msg = client_rx.recv() => {
|
||||
match client_msg {
|
||||
Some(msg) => {
|
||||
if let Err(e) = ws_sender.send(Message::Text(msg)).await {
|
||||
eprintln!(
|
||||
"[WebSocketServer] Fehler beim Senden an {}: {}",
|
||||
peer_addr, e
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// Kanal wurde geschlossen
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
futures_util::future::select(incoming.boxed(), outgoing.boxed()).await;
|
||||
|
||||
// Verbindung aus der Registry entfernen
|
||||
let final_uid = {
|
||||
let guard = user_id.lock().await;
|
||||
guard.clone()
|
||||
};
|
||||
{
|
||||
let mut reg = registry.lock().await;
|
||||
if reg.total > 0 {
|
||||
reg.total -= 1;
|
||||
}
|
||||
if let Some(uid) = final_uid {
|
||||
if let Some(count) = reg.by_user.get_mut(&uid) {
|
||||
if *count > 0 {
|
||||
*count -= 1;
|
||||
}
|
||||
if *count == 0 {
|
||||
reg.by_user.remove(&uid);
|
||||
}
|
||||
}
|
||||
} else if reg.unauthenticated > 0 {
|
||||
reg.unauthenticated -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
println!("[WebSocketServer] Verbindung geschlossen: {}", peer_addr);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user