use std::collections::HashMap; use std::fs::File; use std::io::BufReader as SyncBufReader; use std::net::SocketAddr; use std::path::PathBuf; use std::sync::Arc; use futures_util::future::join_all; use prometheus::Registry as MetricsRegistry; use quick_xml::events::{BytesDecl, Event}; use quick_xml::{NsReader, Writer}; use rustls_pemfile::{certs, rsa_private_keys}; use serde::Deserialize; use tokio::io::{AsyncWriteExt, BufReader, BufWriter}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::channel; use tokio_rustls::rustls::{Certificate, PrivateKey}; use tokio_rustls::TlsAcceptor; use crate::core::player::PlayerRegistry; use crate::core::room::RoomRegistry; use crate::prelude::*; use crate::protos::xmpp; use crate::util::Terminator; #[derive(Deserialize, Debug, Clone)] pub struct ServerConfig { pub listen_on: SocketAddr, pub cert: PathBuf, pub key: PathBuf, } struct LoadedConfig { cert: Certificate, key: PrivateKey, } pub async fn launch( config: ServerConfig, players: PlayerRegistry, rooms: RoomRegistry, metrics: MetricsRegistry, ) -> Result { log::info!("Starting XMPP projection"); let certs = certs(&mut SyncBufReader::new(File::open(config.cert)?))?; let certs = certs.into_iter().map(Certificate).collect::>(); let keys = rsa_private_keys(&mut SyncBufReader::new(File::open(config.key)?))?; let keys = keys.into_iter().map(PrivateKey).collect::>(); let loaded_config = Arc::new(LoadedConfig { cert: certs.into_iter().next().expect("no certs in file"), key: keys.into_iter().next().expect("no keys in file"), }); let listener = TcpListener::bind(config.listen_on).await?; let terminator = Terminator::spawn(|mut termination| async move { let (stopped_tx, mut stopped_rx) = channel(32); let mut actors = HashMap::new(); loop { select! { biased; _ = &mut termination => break, stopped = stopped_rx.recv() => match stopped { Some(stopped) => { let _ = actors.remove(&stopped); }, None => unreachable!(), }, new_conn = listener.accept() => { match new_conn { Ok((stream, socket_addr)) => { log::debug!("Incoming connection from {socket_addr}"); if actors.contains_key(&socket_addr) { log::warn!("Already contains connection form {socket_addr}"); // TODO kill the older connection and restart it continue; } let players = players.clone(); let rooms = rooms.clone(); let terminator = Terminator::spawn(|termination| { let stopped_tx = stopped_tx.clone(); let loaded_config = loaded_config.clone(); async move { match handle_socket(loaded_config, stream, &socket_addr, players, rooms, termination).await { Ok(_) => log::info!("Connection terminated"), Err(err) => log::warn!("Connection failed: {err}"), } stopped_tx.send(socket_addr).await?; Ok(()) } }); actors.insert(socket_addr, terminator); }, Err(err) => log::warn!("Failed to accept new connection: {err}"), } }, } } log::info!("Stopping XMPP projection"); join_all( actors .into_iter() .map(|(socket_addr, terminator)| async move { log::debug!("Stopping XMPP connection at {socket_addr}"); match terminator.terminate().await { Ok(_) => log::debug!("Stopped XMPP connection at {socket_addr}"), Err(err) => { log::warn!( "XMPP connection to {socket_addr} finished with error: {err}" ) } } }), ) .await; log::info!("Stopped XMPP projection"); Ok(()) }); log::info!("Started XMPP projection"); Ok(terminator) } async fn handle_socket( config: Arc, mut stream: TcpStream, socket_addr: &SocketAddr, players: PlayerRegistry, rooms: RoomRegistry, termination: Deferred<()>, // TODO use it to stop the connection gracefully ) -> Result<()> { use xmpp::stream::*; log::debug!("Received an XMPP connection from {socket_addr}"); let mut reader_buf = vec![]; let (reader, writer) = stream.split(); let mut buf_reader = BufReader::new(reader); let mut buf_writer = BufWriter::new(writer); { let mut xml_reader = NsReader::from_reader(&mut buf_reader); let mut xml_writer = Writer::new(&mut buf_writer); let aaa = xml_reader.read_event_into_async(&mut reader_buf).await?; if let Event::Decl(_) = aaa { // this is header } else { return Err(fail("expected XML header")); } let _ = ClientStreamStart::parse(&mut xml_reader, &mut reader_buf).await?; xml_writer .write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))) .await?; xmpp::stream::ServerStreamStart { from: "localhost".into(), lang: "en".into(), version: "1.0".into(), } .write_xml(&mut xml_writer) .await?; xmpp::stream::Features { start_tls: true, mechanisms: false, bind: false, } .write_xml(&mut xml_writer) .await?; xml_writer.get_mut().flush().await?; let StartTLS = StartTLS::parse(&mut xml_reader, &mut reader_buf).await?; // TODO read xmpp::stream::Proceed.write_xml(&mut xml_writer).await?; xml_writer.get_mut().flush().await?; } let config = tokio_rustls::rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert(vec![config.cert.clone()], config.key.clone())?; let acceptor = TlsAcceptor::from(Arc::new(config)); let new_stream = acceptor.accept(stream).await?; log::debug!("TLS connection established"); let (a, b) = tokio::io::split(new_stream); let buf_reader = BufReader::new(a); let mut xml_reader = NsReader::from_reader(buf_reader); let mut xml_writer = Writer::new(b); { if let Event::Decl(_) = xml_reader.read_event_into_async(&mut reader_buf).await? { // this is header } else { return Err(fail("expected XML header")); } let _ = ClientStreamStart::parse(&mut xml_reader, &mut reader_buf).await?; xml_writer .write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))) .await?; xmpp::stream::ServerStreamStart { from: "localhost".into(), lang: "en".into(), version: "1.0".into(), } .write_xml(&mut xml_writer) .await?; xmpp::stream::Features { start_tls: false, mechanisms: true, bind: false, } .write_xml(&mut xml_writer) .await?; xml_writer.get_mut().flush().await?; let _ = xmpp::sasl::Auth::parse(&mut xml_reader, &mut reader_buf).await?; xmpp::sasl::Success.write_xml(&mut xml_writer).await?; } { if let Event::Decl(_) = xml_reader.read_event_into_async(&mut reader_buf).await? { // this is header } else { return Err(fail("expected XML header")); } let _ = ClientStreamStart::parse(&mut xml_reader, &mut reader_buf).await?; xml_writer .write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))) .await?; xmpp::stream::ServerStreamStart { from: "localhost".into(), lang: "en".into(), version: "1.0".into(), } .write_xml(&mut xml_writer) .await?; xmpp::stream::Features { start_tls: false, mechanisms: false, bind: true, } .write_xml(&mut xml_writer) .await?; xml_writer.get_mut().flush().await?; } loop { let event = xml_reader.read_event_into_async(&mut reader_buf).await?; println!("EVENT: {event:?}"); if event == Event::Eof { break; } } let a = xml_reader.into_inner().into_inner(); let b = xml_writer.into_inner(); a.unsplit(b).shutdown().await?; Ok(()) }