Implement TLS support for WebSocket connections in yourchat2. Updated main.rs to handle secure WebSocket connections based on environment variables. Enhanced install-systemd.sh to include a template for environment configuration. Updated README to document new TLS-related environment variables and installation instructions.
This commit is contained in:
77
src/main.rs
77
src/main.rs
@@ -1,12 +1,17 @@
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use std::collections::HashSet;
|
||||
use std::env;
|
||||
use std::fs::File;
|
||||
use std::io::BufReader as StdBufReader;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
|
||||
use tokio::net::{TcpListener, UnixListener};
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use tokio_rustls::rustls::ServerConfig as RustlsServerConfig;
|
||||
use tokio_tungstenite::{accept_async, tungstenite::Message};
|
||||
|
||||
mod commands;
|
||||
@@ -18,8 +23,14 @@ use types::{ChatState, ClientConn, ServerConfig};
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let ws_addr = env::var("CHAT_WS_ADDR").unwrap_or_else(|_| "0.0.0.0:1235".to_string());
|
||||
let ws_tls = env_bool("CHAT_WS_TLS");
|
||||
let tcp_addr = env::var("CHAT_TCP_ADDR").unwrap_or_else(|_| "127.0.0.1:1236".to_string());
|
||||
let unix_socket = env::var("CHAT_UNIX_SOCKET").ok().filter(|s| !s.trim().is_empty());
|
||||
let tls_acceptor = if ws_tls {
|
||||
Some(Arc::new(load_tls_acceptor_from_env()?))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let state = Arc::new(RwLock::new(ChatState::default()));
|
||||
let db_client = db::connect_db_from_env().await?;
|
||||
@@ -45,11 +56,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
|
||||
let ws_listener = TcpListener::bind(&ws_addr).await?;
|
||||
println!("[yourchat2] listening on ws://{}", ws_addr);
|
||||
if ws_tls {
|
||||
println!("[yourchat2] listening on wss://{}", ws_addr);
|
||||
} else {
|
||||
println!("[yourchat2] listening on ws://{}", ws_addr);
|
||||
}
|
||||
|
||||
let ws_state = Arc::clone(&state);
|
||||
let ws_config = Arc::clone(&config);
|
||||
let ws_next = Arc::clone(&next_client_id);
|
||||
let ws_tls_acceptor = tls_acceptor.clone();
|
||||
let mut ws_shutdown_rx = shutdown_rx.clone();
|
||||
let ws_task = tokio::spawn(async move {
|
||||
loop {
|
||||
@@ -62,13 +78,29 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
accepted = ws_listener.accept() => {
|
||||
match accepted {
|
||||
Ok((socket, addr)) => {
|
||||
println!("[yourchat2] ws client connected: {}", addr);
|
||||
if ws_tls_acceptor.is_some() {
|
||||
println!("[yourchat2] wss client connected: {}", addr);
|
||||
} else {
|
||||
println!("[yourchat2] ws client connected: {}", addr);
|
||||
}
|
||||
let state = Arc::clone(&ws_state);
|
||||
let config = Arc::clone(&ws_config);
|
||||
let next = Arc::clone(&ws_next);
|
||||
let tls_acceptor = ws_tls_acceptor.clone();
|
||||
let shutdown = ws_shutdown_rx.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = handle_ws_client(socket, state, config, next, shutdown).await {
|
||||
if let Some(acceptor) = tls_acceptor {
|
||||
match acceptor.accept(socket).await {
|
||||
Ok(tls_stream) => {
|
||||
if let Err(err) = handle_ws_stream(tls_stream, state, config, next, shutdown).await {
|
||||
eprintln!("[yourchat2] wss client error: {err}");
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("[yourchat2] tls handshake error: {err}");
|
||||
}
|
||||
}
|
||||
} else if let Err(err) = handle_ws_stream(socket, state, config, next, shutdown).await {
|
||||
eprintln!("[yourchat2] ws client error: {err}");
|
||||
}
|
||||
});
|
||||
@@ -249,13 +281,16 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_ws_client(
|
||||
socket: tokio::net::TcpStream,
|
||||
async fn handle_ws_stream<S>(
|
||||
socket: S,
|
||||
state: Arc<RwLock<ChatState>>,
|
||||
config: Arc<ServerConfig>,
|
||||
next_client_id: Arc<AtomicU64>,
|
||||
mut shutdown_rx: watch::Receiver<bool>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let ws_stream = accept_async(socket).await?;
|
||||
let (mut ws_write, mut ws_read) = ws_stream.split();
|
||||
let client_id = next_client_id.fetch_add(1, Ordering::Relaxed);
|
||||
@@ -319,3 +354,33 @@ async fn handle_ws_client(
|
||||
writer_task.abort();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn env_bool(name: &str) -> bool {
|
||||
matches!(
|
||||
env::var(name).ok().as_deref(),
|
||||
Some("1") | Some("true") | Some("TRUE") | Some("yes") | Some("YES") | Some("on") | Some("ON")
|
||||
)
|
||||
}
|
||||
|
||||
fn load_tls_acceptor_from_env() -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let cert_path = env::var("CHAT_TLS_CERT_PATH")
|
||||
.map_err(|_| "CHAT_WS_TLS=true requires CHAT_TLS_CERT_PATH")?;
|
||||
let key_path = env::var("CHAT_TLS_KEY_PATH")
|
||||
.map_err(|_| "CHAT_WS_TLS=true requires CHAT_TLS_KEY_PATH")?;
|
||||
|
||||
let mut cert_reader = StdBufReader::new(File::open(&cert_path)?);
|
||||
let certs: Vec<CertificateDer<'static>> =
|
||||
rustls_pemfile::certs(&mut cert_reader).collect::<Result<Vec<_>, _>>()?;
|
||||
if certs.is_empty() {
|
||||
return Err("No certificates found in CHAT_TLS_CERT_PATH".into());
|
||||
}
|
||||
|
||||
let mut key_reader = StdBufReader::new(File::open(&key_path)?);
|
||||
let key: PrivateKeyDer<'static> = rustls_pemfile::private_key(&mut key_reader)?
|
||||
.ok_or("No private key found in CHAT_TLS_KEY_PATH")?;
|
||||
|
||||
let config = RustlsServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?;
|
||||
Ok(TlsAcceptor::from(Arc::new(config)))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user