Implement TLS support for WebSocket connections in yourchat2. Updated main.rs to handle secure WebSocket connections based on environment variables. Enhanced install-systemd.sh to include a template for environment configuration. Updated README to document new TLS-related environment variables and installation instructions.

This commit is contained in:
Torsten Schulz (local)
2026-03-04 17:42:47 +01:00
parent 0037ac5c28
commit aca290f1d0
6 changed files with 331 additions and 18 deletions

View File

@@ -1,12 +1,17 @@
use futures_util::{SinkExt, StreamExt};
use std::collections::HashSet;
use std::env;
use std::fs::File;
use std::io::BufReader as StdBufReader;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::net::{TcpListener, UnixListener};
use tokio::sync::{mpsc, watch, RwLock};
use tokio_rustls::TlsAcceptor;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio_rustls::rustls::ServerConfig as RustlsServerConfig;
use tokio_tungstenite::{accept_async, tungstenite::Message};
mod commands;
@@ -18,8 +23,14 @@ use types::{ChatState, ClientConn, ServerConfig};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let ws_addr = env::var("CHAT_WS_ADDR").unwrap_or_else(|_| "0.0.0.0:1235".to_string());
let ws_tls = env_bool("CHAT_WS_TLS");
let tcp_addr = env::var("CHAT_TCP_ADDR").unwrap_or_else(|_| "127.0.0.1:1236".to_string());
let unix_socket = env::var("CHAT_UNIX_SOCKET").ok().filter(|s| !s.trim().is_empty());
let tls_acceptor = if ws_tls {
Some(Arc::new(load_tls_acceptor_from_env()?))
} else {
None
};
let state = Arc::new(RwLock::new(ChatState::default()));
let db_client = db::connect_db_from_env().await?;
@@ -45,11 +56,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let (shutdown_tx, shutdown_rx) = watch::channel(false);
let ws_listener = TcpListener::bind(&ws_addr).await?;
println!("[yourchat2] listening on ws://{}", ws_addr);
if ws_tls {
println!("[yourchat2] listening on wss://{}", ws_addr);
} else {
println!("[yourchat2] listening on ws://{}", ws_addr);
}
let ws_state = Arc::clone(&state);
let ws_config = Arc::clone(&config);
let ws_next = Arc::clone(&next_client_id);
let ws_tls_acceptor = tls_acceptor.clone();
let mut ws_shutdown_rx = shutdown_rx.clone();
let ws_task = tokio::spawn(async move {
loop {
@@ -62,13 +78,29 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
accepted = ws_listener.accept() => {
match accepted {
Ok((socket, addr)) => {
println!("[yourchat2] ws client connected: {}", addr);
if ws_tls_acceptor.is_some() {
println!("[yourchat2] wss client connected: {}", addr);
} else {
println!("[yourchat2] ws client connected: {}", addr);
}
let state = Arc::clone(&ws_state);
let config = Arc::clone(&ws_config);
let next = Arc::clone(&ws_next);
let tls_acceptor = ws_tls_acceptor.clone();
let shutdown = ws_shutdown_rx.clone();
tokio::spawn(async move {
if let Err(err) = handle_ws_client(socket, state, config, next, shutdown).await {
if let Some(acceptor) = tls_acceptor {
match acceptor.accept(socket).await {
Ok(tls_stream) => {
if let Err(err) = handle_ws_stream(tls_stream, state, config, next, shutdown).await {
eprintln!("[yourchat2] wss client error: {err}");
}
}
Err(err) => {
eprintln!("[yourchat2] tls handshake error: {err}");
}
}
} else if let Err(err) = handle_ws_stream(socket, state, config, next, shutdown).await {
eprintln!("[yourchat2] ws client error: {err}");
}
});
@@ -249,13 +281,16 @@ where
Ok(())
}
async fn handle_ws_client(
socket: tokio::net::TcpStream,
async fn handle_ws_stream<S>(
socket: S,
state: Arc<RwLock<ChatState>>,
config: Arc<ServerConfig>,
next_client_id: Arc<AtomicU64>,
mut shutdown_rx: watch::Receiver<bool>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let ws_stream = accept_async(socket).await?;
let (mut ws_write, mut ws_read) = ws_stream.split();
let client_id = next_client_id.fetch_add(1, Ordering::Relaxed);
@@ -319,3 +354,33 @@ async fn handle_ws_client(
writer_task.abort();
Ok(())
}
fn env_bool(name: &str) -> bool {
matches!(
env::var(name).ok().as_deref(),
Some("1") | Some("true") | Some("TRUE") | Some("yes") | Some("YES") | Some("on") | Some("ON")
)
}
fn load_tls_acceptor_from_env() -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
let cert_path = env::var("CHAT_TLS_CERT_PATH")
.map_err(|_| "CHAT_WS_TLS=true requires CHAT_TLS_CERT_PATH")?;
let key_path = env::var("CHAT_TLS_KEY_PATH")
.map_err(|_| "CHAT_WS_TLS=true requires CHAT_TLS_KEY_PATH")?;
let mut cert_reader = StdBufReader::new(File::open(&cert_path)?);
let certs: Vec<CertificateDer<'static>> =
rustls_pemfile::certs(&mut cert_reader).collect::<Result<Vec<_>, _>>()?;
if certs.is_empty() {
return Err("No certificates found in CHAT_TLS_CERT_PATH".into());
}
let mut key_reader = StdBufReader::new(File::open(&key_path)?);
let key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut key_reader)?
.ok_or("No private key found in CHAT_TLS_KEY_PATH")?;
let config = RustlsServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)?;
Ok(TlsAcceptor::from(Arc::new(config)))
}