use std::collections::HashMap; use std::fs::File; use std::net::SocketAddr; use std::path::PathBuf; use std::io::BufReader as SyncBufReader; use std::sync::Arc; use futures_util::future::join_all; use prometheus::Registry as MetricsRegistry; use rustls_pemfile::{certs, rsa_private_keys}; use serde::Deserialize; use tokio::io::{AsyncWriteExt, AsyncReadExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::channel; use tokio_rustls::TlsAcceptor; use tokio_rustls::rustls::{Certificate, PrivateKey}; use crate::core::player::PlayerRegistry; use crate::core::room::RoomRegistry; use crate::prelude::*; use crate::util::Terminator; #[derive(Deserialize, Debug, Clone)] pub struct ServerConfig { pub listen_on: SocketAddr, pub cert: PathBuf, pub key: PathBuf, } struct LoadedConfig { cert: Certificate, key: PrivateKey, } pub async fn launch( config: ServerConfig, players: PlayerRegistry, rooms: RoomRegistry, metrics: MetricsRegistry, ) -> Result { log::info!("Starting XMPP projection"); let certs = certs(&mut SyncBufReader::new(File::open(config.cert)?))?; let certs = certs.into_iter().map(Certificate).collect::>(); let keys = rsa_private_keys(&mut SyncBufReader::new(File::open(config.key)?))?; let keys = keys.into_iter().map(PrivateKey).collect::>(); let loaded_config = Arc::new(LoadedConfig { cert: certs.into_iter().next().expect("no certs in file"), key: keys.into_iter().next().expect("no keys in file"), }); let listener = TcpListener::bind(config.listen_on).await?; let terminator = Terminator::spawn(|mut termination| async move { let (stopped_tx, mut stopped_rx) = channel(32); let mut actors = HashMap::new(); loop { select! { biased; _ = &mut termination => break, stopped = stopped_rx.recv() => match stopped { Some(stopped) => { let _ = actors.remove(&stopped); }, None => unreachable!(), }, new_conn = listener.accept() => { match new_conn { Ok((stream, socket_addr)) => { log::debug!("Incoming connection from {socket_addr}"); if actors.contains_key(&socket_addr) { log::warn!("Already contains connection form {socket_addr}"); // TODO kill the older connection and restart it continue; } let players = players.clone(); let rooms = rooms.clone(); let terminator = Terminator::spawn(|termination| { let stopped_tx = stopped_tx.clone(); let loaded_config = loaded_config.clone(); async move { match handle_socket(loaded_config, stream, &socket_addr, players, rooms, termination).await { Ok(_) => log::info!("Connection terminated"), Err(err) => log::warn!("Connection failed: {err}"), } stopped_tx.send(socket_addr).await?; Ok(()) } }); actors.insert(socket_addr, terminator); }, Err(err) => log::warn!("Failed to accept new connection: {err}"), } }, } } log::info!("Stopping XMPP projection"); join_all(actors.into_iter().map(|(socket_addr, terminator)| async move { log::debug!("Stopping XMPP connection at {socket_addr}"); match terminator.terminate().await { Ok(_) => log::debug!("Stopped XMPP connection at {socket_addr}"), Err(err) => { log::warn!("XMPP connection to {socket_addr} finished with error: {err}") } } })).await; log::info!("Stopped XMPP projection"); Ok(()) }); log::info!("Started XMPP projection"); Ok(terminator) } async fn handle_socket( config: Arc, mut stream: TcpStream, socket_addr: &SocketAddr, players: PlayerRegistry, rooms: RoomRegistry, termination: Deferred<()>, // TODO use it to stop the connection gracefully ) -> Result<()> { log::debug!("Received an XMPP connection from {socket_addr}"); // writer.write_all(b"Hi!\n").await?; let mut buf = [0; 1024]; stream.write_all(br###" "###).await?; { let i = stream.read(&mut buf).await?; match std::str::from_utf8(&buf[0..i]) { Ok(e) => println!("{} END", e), Err(_) => println!("{:?} END", &buf[0..i]), } stream.write_all(br###""###).await?; } let config = tokio_rustls::rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert(vec![config.cert.clone()], config.key.clone())?; let i = stream.read(&mut buf).await?; match std::str::from_utf8(&buf[0..i]) { Ok(e) => println!("{} END", e), Err(_) => println!("{:?} END", &buf[0..i]), } let acceptor = TlsAcceptor::from(Arc::new(config)); let mut new_stream = acceptor.accept(stream).await?; log::debug!("TLS connection established"); loop { let i = new_stream.read(&mut buf).await?; if i == 0 { break; } match std::str::from_utf8(&buf[0..i]) { Ok(e) => println!("{} END", e), Err(_) => println!("{:?} END", &buf[0..i]), } } new_stream.shutdown().await?; Ok(()) }