use futures_util::{SinkExt, StreamExt}; use std::collections::HashSet; use std::env; use std::fs::File; use std::io::BufReader as StdBufReader; use std::path::Path; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; use tokio::net::{TcpListener, UnixListener}; use tokio::sync::{mpsc, watch, RwLock}; use tokio::time::{Duration, interval}; use tokio_rustls::TlsAcceptor; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use tokio_rustls::rustls::ServerConfig as RustlsServerConfig; use tokio_tungstenite::{accept_async, tungstenite::Message}; mod commands; mod db; mod state; mod types; use types::{ChatState, ClientConn, ServerConfig}; #[tokio::main] async fn main() -> Result<(), Box> { let ws_addr = env::var("CHAT_WS_ADDR").unwrap_or_else(|_| "0.0.0.0:1235".to_string()); let ws_tls = env_bool("CHAT_WS_TLS"); let tcp_addr = env::var("CHAT_TCP_ADDR").unwrap_or_else(|_| "127.0.0.1:1236".to_string()); let unix_socket = env::var("CHAT_UNIX_SOCKET").ok().filter(|s| !s.trim().is_empty()); let tls_acceptor = if ws_tls { Some(Arc::new(load_tls_acceptor_from_env()?)) } else { None }; let state = Arc::new(RwLock::new(ChatState::default())); let db_client = db::connect_db_from_env().await?; let config = Arc::new(ServerConfig { allowed_users: db::parse_allowed_users(), db_client, }); if handle_cli_commands().as_deref() == Some("--list-rooms") { print_rooms_for_cli(&config).await?; return Ok(()); } let rooms = db::load_room_configs(&config).await.unwrap_or_else(|_| { vec![types::RoomMeta { name: "lobby".to_string(), is_public: true, ..types::RoomMeta::default() }] }); { let mut guard = state.write().await; for room in rooms { guard.rooms.entry(room.name.clone()).or_default(); guard.room_meta.insert(room.name.clone(), room); } } let next_client_id = Arc::new(AtomicU64::new(1)); let (shutdown_tx, shutdown_rx) = watch::channel(false); let cleanup_state = Arc::clone(&state); let mut cleanup_shutdown_rx = shutdown_rx.clone(); let cleanup_task = tokio::spawn(async move { let mut ticker = interval(Duration::from_secs(60)); loop { tokio::select! { changed = cleanup_shutdown_rx.changed() => { if changed.is_ok() && *cleanup_shutdown_rx.borrow() { break; } } _ = ticker.tick() => { let removed = state::cleanup_stale_temporary_rooms(Arc::clone(&cleanup_state), 15 * 60).await; if removed > 0 { println!("[yourchat2] removed {removed} stale temporary room(s)"); } } } } }); let ws_listener = TcpListener::bind(&ws_addr).await?; if ws_tls { println!("[yourchat2] listening on wss://{}", ws_addr); } else { println!("[yourchat2] listening on ws://{}", ws_addr); } let ws_state = Arc::clone(&state); let ws_config = Arc::clone(&config); let ws_next = Arc::clone(&next_client_id); let ws_tls_acceptor = tls_acceptor.clone(); let mut ws_shutdown_rx = shutdown_rx.clone(); let ws_task = tokio::spawn(async move { loop { tokio::select! { changed = ws_shutdown_rx.changed() => { if changed.is_ok() && *ws_shutdown_rx.borrow() { break; } } accepted = ws_listener.accept() => { match accepted { Ok((socket, addr)) => { if ws_tls_acceptor.is_some() { println!("[yourchat2] wss client connected: {}", addr); } else { println!("[yourchat2] ws client connected: {}", addr); } let state = Arc::clone(&ws_state); let config = Arc::clone(&ws_config); let next = Arc::clone(&ws_next); let tls_acceptor = ws_tls_acceptor.clone(); let shutdown = ws_shutdown_rx.clone(); tokio::spawn(async move { if let Some(acceptor) = tls_acceptor { match acceptor.accept(socket).await { Ok(tls_stream) => { if let Err(err) = handle_ws_stream(tls_stream, state, config, next, shutdown).await { eprintln!("[yourchat2] wss client error: {err}"); } } Err(err) => { eprintln!("[yourchat2] tls handshake error: {err}"); } } } else if let Err(err) = handle_ws_stream(socket, state, config, next, shutdown).await { eprintln!("[yourchat2] ws client error: {err}"); } }); } Err(err) => eprintln!("[yourchat2] ws accept error: {err}"), } } } } }); let tcp_listener = TcpListener::bind(&tcp_addr).await?; println!("[yourchat2] listening on tcp://{}", tcp_addr); let tcp_state = Arc::clone(&state); let tcp_config = Arc::clone(&config); let tcp_next = Arc::clone(&next_client_id); let mut tcp_shutdown_rx = shutdown_rx.clone(); let tcp_task = tokio::spawn(async move { loop { tokio::select! { changed = tcp_shutdown_rx.changed() => { if changed.is_ok() && *tcp_shutdown_rx.borrow() { break; } } accepted = tcp_listener.accept() => { match accepted { Ok((socket, addr)) => { println!("[yourchat2] tcp client connected: {}", addr); let state = Arc::clone(&tcp_state); let config = Arc::clone(&tcp_config); let next = Arc::clone(&tcp_next); let shutdown = tcp_shutdown_rx.clone(); tokio::spawn(async move { if let Err(err) = handle_client(socket, state, config, next, shutdown).await { eprintln!("[yourchat2] client error: {err}"); } }); } Err(err) => eprintln!("[yourchat2] accept error: {err}"), } } } } }); let unix_task = if let Some(socket_path) = unix_socket.clone() { let path = Path::new(&socket_path); if let Some(parent) = path.parent() { tokio::fs::create_dir_all(parent).await?; } if path.exists() { tokio::fs::remove_file(path).await?; } let listener = UnixListener::bind(path)?; println!("[yourchat2] listening on unix://{}", socket_path); let unix_state = Arc::clone(&state); let unix_config = Arc::clone(&config); let unix_next = Arc::clone(&next_client_id); let mut unix_shutdown_rx = shutdown_rx.clone(); Some(tokio::spawn(async move { loop { tokio::select! { changed = unix_shutdown_rx.changed() => { if changed.is_ok() && *unix_shutdown_rx.borrow() { break; } } accepted = listener.accept() => { match accepted { Ok((socket, _addr)) => { let state = Arc::clone(&unix_state); let config = Arc::clone(&unix_config); let next = Arc::clone(&unix_next); let shutdown = unix_shutdown_rx.clone(); tokio::spawn(async move { if let Err(err) = handle_client(socket, state, config, next, shutdown).await { eprintln!("[yourchat2] unix client error: {err}"); } }); } Err(err) => eprintln!("[yourchat2] unix accept error: {err}"), } } } } })) } else { None }; tokio::signal::ctrl_c().await?; println!("[yourchat2] shutdown requested"); let _ = shutdown_tx.send(true); let _ = ws_task.await; let _ = tcp_task.await; let _ = cleanup_task.await; if let Some(task) = unix_task { let _ = task.await; if let Some(path) = unix_socket { let _ = tokio::fs::remove_file(path).await; } } println!("[yourchat2] stopped"); Ok(()) } fn handle_cli_commands() -> Option { env::args().nth(1) } async fn print_rooms_for_cli( config: &ServerConfig, ) -> Result<(), Box> { let rooms = db::load_room_configs(config) .await .map_err(std::io::Error::other)?; println!( "yourchat2 rooms source: {}", if config.db_client.is_some() { "database" } else { "fallback" } ); println!( "{:<24} {:<8} {:<8} {:<8} {:<8} {:<10} {:<8}", "name", "public", "gender", "min_age", "max_age", "password", "right_id" ); println!("{}", "-".repeat(92)); for room in rooms { println!( "{:<24} {:<8} {:<8} {:<8} {:<8} {:<10} {:<8}", room.name, if room.is_public { "yes" } else { "no" }, room.gender_restriction_id .filter(|v| *v > 0) .map(|v| v.to_string()) .unwrap_or_else(|| "-".to_string()), room.min_age .map(|v| v.to_string()) .unwrap_or_else(|| "-".to_string()), room.max_age .map(|v| v.to_string()) .unwrap_or_else(|| "-".to_string()), if room.password.as_deref().unwrap_or("").is_empty() { "none" } else { "set" }, room.required_user_right_id .filter(|v| *v > 0) .map(|v| v.to_string()) .unwrap_or_else(|| "-".to_string()), ); } Ok(()) } async fn handle_client( stream: S, state: Arc>, config: Arc, next_client_id: Arc, mut shutdown_rx: watch::Receiver, ) -> Result<(), Box> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let client_id = next_client_id.fetch_add(1, Ordering::Relaxed); let default_name = format!("Guest-{client_id}"); let (read_half, mut write_half) = tokio::io::split(stream); let (tx, mut rx) = mpsc::unbounded_channel::(); { let mut guard = state.write().await; guard.clients.insert( client_id, ClientConn { user_name: default_name.clone(), room: String::new(), color: None, token: None, falukant_user_id: None, chat_user_id: None, gender_id: None, age: None, rights: HashSet::new(), right_type_ids: HashSet::new(), logged_in: false, tx: tx.clone(), }, ); } let writer_task = tokio::spawn(async move { while let Some(msg) = rx.recv().await { if write_half.write_all(msg.as_bytes()).await.is_err() { break; } if write_half.write_all(b"\n").await.is_err() { break; } } }); let mut lines = BufReader::new(read_half).lines(); loop { tokio::select! { changed = shutdown_rx.changed() => { if changed.is_ok() && *shutdown_rx.borrow() { break; } } line = lines.next_line() => { match line? { Some(raw) => { commands::process_text_command(client_id, &raw, Arc::clone(&state), Arc::clone(&config)).await; } None => break, } } } } state::disconnect_client(client_id, state).await; writer_task.abort(); Ok(()) } async fn handle_ws_stream( socket: S, state: Arc>, config: Arc, next_client_id: Arc, mut shutdown_rx: watch::Receiver, ) -> Result<(), Box> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let ws_stream = accept_async(socket).await?; let (mut ws_write, mut ws_read) = ws_stream.split(); let client_id = next_client_id.fetch_add(1, Ordering::Relaxed); let default_name = format!("Guest-{client_id}"); let (tx, mut rx) = mpsc::unbounded_channel::(); { let mut guard = state.write().await; guard.clients.insert( client_id, ClientConn { user_name: default_name, room: String::new(), color: None, token: None, falukant_user_id: None, chat_user_id: None, gender_id: None, age: None, rights: HashSet::new(), right_type_ids: HashSet::new(), logged_in: false, tx: tx.clone(), }, ); } let writer_task = tokio::spawn(async move { while let Some(msg) = rx.recv().await { if ws_write.send(Message::Text(msg.into())).await.is_err() { break; } } }); loop { tokio::select! { changed = shutdown_rx.changed() => { if changed.is_ok() && *shutdown_rx.borrow() { break; } } incoming = ws_read.next() => { match incoming { Some(Ok(Message::Text(text))) => { commands::process_text_command(client_id, &text, Arc::clone(&state), Arc::clone(&config)).await; } Some(Ok(Message::Binary(bin))) => { if let Ok(text) = std::str::from_utf8(&bin) { commands::process_text_command(client_id, text, Arc::clone(&state), Arc::clone(&config)).await; } } Some(Ok(Message::Ping(_))) => {} Some(Ok(Message::Close(_))) => break, Some(Ok(_)) => {} Some(Err(_)) | None => break, } } } } state::disconnect_client(client_id, state).await; writer_task.abort(); Ok(()) } fn env_bool(name: &str) -> bool { matches!( env::var(name).ok().as_deref(), Some("1") | Some("true") | Some("TRUE") | Some("yes") | Some("YES") | Some("on") | Some("ON") ) } fn load_tls_acceptor_from_env() -> Result> { let cert_path = env::var("CHAT_TLS_CERT_PATH") .map_err(|_| "CHAT_WS_TLS=true requires CHAT_TLS_CERT_PATH")?; let key_path = env::var("CHAT_TLS_KEY_PATH") .map_err(|_| "CHAT_WS_TLS=true requires CHAT_TLS_KEY_PATH")?; let mut cert_reader = StdBufReader::new(File::open(&cert_path)?); let certs: Vec> = rustls_pemfile::certs(&mut cert_reader).collect::, _>>()?; if certs.is_empty() { return Err("No certificates found in CHAT_TLS_CERT_PATH".into()); } let mut key_reader = StdBufReader::new(File::open(&key_path)?); let key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut key_reader)? .ok_or("No private key found in CHAT_TLS_KEY_PATH")?; let config = RustlsServerConfig::builder() .with_no_client_auth() .with_single_cert(certs, key)?; Ok(TlsAcceptor::from(Arc::new(config))) }