mod proto; use std::collections::HashMap; use std::fs::File; use std::io::BufReader as SyncBufReader; use std::net::SocketAddr; use std::path::PathBuf; use std::sync::Arc; use futures_util::future::join_all; use prometheus::Registry as MetricsRegistry; use quick_xml::events::{BytesDecl, Event}; use quick_xml::{NsReader, Writer}; use rustls_pemfile::{certs, rsa_private_keys}; use serde::Deserialize; use tokio::io::{AsyncBufRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::channel; use tokio_rustls::rustls::{Certificate, PrivateKey}; use tokio_rustls::TlsAcceptor; use crate::core::player::PlayerRegistry; use crate::core::room::RoomRegistry; use crate::prelude::*; use crate::protos::xmpp; use crate::protos::xmpp::bind::{BindResponse, Jid, Name, Resource, Server}; use crate::protos::xmpp::client::{Iq, Presence}; use crate::protos::xmpp::roster::RosterQuery; use crate::protos::xmpp::session::Session; use crate::protos::xmpp::stream::*; use crate::util::xml::{Continuation, FromXml, Parser, ToXml}; use crate::util::Terminator; #[derive(Deserialize, Debug, Clone)] pub struct ServerConfig { pub listen_on: SocketAddr, pub cert: PathBuf, pub key: PathBuf, } struct LoadedConfig { cert: Certificate, key: PrivateKey, } pub async fn launch( config: ServerConfig, players: PlayerRegistry, rooms: RoomRegistry, metrics: MetricsRegistry, ) -> Result { log::info!("Starting XMPP projection"); let certs = certs(&mut SyncBufReader::new(File::open(config.cert)?))?; let certs = certs.into_iter().map(Certificate).collect::>(); let keys = rsa_private_keys(&mut SyncBufReader::new(File::open(config.key)?))?; let keys = keys.into_iter().map(PrivateKey).collect::>(); let loaded_config = Arc::new(LoadedConfig { cert: certs.into_iter().next().expect("no certs in file"), key: keys.into_iter().next().expect("no keys in file"), }); let listener = TcpListener::bind(config.listen_on).await?; let terminator = Terminator::spawn(|mut termination| async move { let (stopped_tx, mut stopped_rx) = channel(32); let mut actors = HashMap::new(); loop { select! { biased; _ = &mut termination => break, stopped = stopped_rx.recv() => match stopped { Some(stopped) => { let _ = actors.remove(&stopped); }, None => unreachable!(), }, new_conn = listener.accept() => { match new_conn { Ok((stream, socket_addr)) => { log::debug!("Incoming connection from {socket_addr}"); if actors.contains_key(&socket_addr) { log::warn!("Already contains connection form {socket_addr}"); // TODO kill the older connection and restart it continue; } let players = players.clone(); let rooms = rooms.clone(); let terminator = Terminator::spawn(|termination| { let stopped_tx = stopped_tx.clone(); let loaded_config = loaded_config.clone(); async move { match handle_socket(loaded_config, stream, &socket_addr, players, rooms, termination).await { Ok(_) => log::info!("Connection terminated"), Err(err) => log::warn!("Connection failed: {err}"), } stopped_tx.send(socket_addr).await?; Ok(()) } }); actors.insert(socket_addr, terminator); }, Err(err) => log::warn!("Failed to accept new connection: {err}"), } }, } } log::info!("Stopping XMPP projection"); join_all( actors .into_iter() .map(|(socket_addr, terminator)| async move { log::debug!("Stopping XMPP connection at {socket_addr}"); match terminator.terminate().await { Ok(_) => log::debug!("Stopped XMPP connection at {socket_addr}"), Err(err) => { log::warn!( "XMPP connection to {socket_addr} finished with error: {err}" ) } } }), ) .await; log::info!("Stopped XMPP projection"); Ok(()) }); log::info!("Started XMPP projection"); Ok(terminator) } async fn handle_socket( config: Arc, mut stream: TcpStream, socket_addr: &SocketAddr, players: PlayerRegistry, rooms: RoomRegistry, termination: Deferred<()>, // TODO use it to stop the connection gracefully ) -> Result<()> { log::debug!("Received an XMPP connection from {socket_addr}"); let mut reader_buf = vec![]; let (reader, writer) = stream.split(); let mut buf_reader = BufReader::new(reader); let mut buf_writer = BufWriter::new(writer); socket_force_tls(&mut buf_reader, &mut buf_writer, &mut reader_buf).await?; let config = tokio_rustls::rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert(vec![config.cert.clone()], config.key.clone())?; let acceptor = TlsAcceptor::from(Arc::new(config)); let new_stream = acceptor.accept(stream).await?; log::debug!("TLS connection established"); let (a, b) = tokio::io::split(new_stream); let mut xml_reader = NsReader::from_reader(BufReader::new(a)); let mut xml_writer = Writer::new(b); socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf).await?; socket_final(&mut xml_reader, &mut xml_writer, &mut reader_buf).await?; loop { let event = xml_reader.read_event_into_async(&mut reader_buf).await?; println!("EVENT: {event:?}"); if event == Event::Eof { break; } } let a = xml_reader.into_inner().into_inner(); let b = xml_writer.into_inner(); a.unsplit(b).shutdown().await?; Ok(()) } async fn socket_force_tls( reader: &mut (impl AsyncBufRead + Unpin), writer: &mut (impl AsyncWrite + Unpin), reader_buf: &mut Vec, ) -> Result<()> { use crate::protos::xmpp::tls::*; let xml_reader = &mut NsReader::from_reader(reader); let xml_writer = &mut Writer::new(writer); read_xml_header(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; let event = Event::Decl(BytesDecl::new("1.0", None, None)); xml_writer.write_event_async(event).await?; let msg = ServerStreamStart { from: "localhost".into(), lang: "en".into(), version: "1.0".into(), }; msg.write_xml(xml_writer).await?; let msg = Features { start_tls: true, mechanisms: false, bind: false, }; msg.write_xml(xml_writer).await?; xml_writer.get_mut().flush().await?; let StartTLS = StartTLS::parse(xml_reader, reader_buf).await?; ProceedTLS.write_xml(xml_writer).await?; xml_writer.get_mut().flush().await?; Ok(()) } async fn socket_auth( xml_reader: &mut NsReader<(impl AsyncBufRead + Unpin)>, xml_writer: &mut Writer<(impl AsyncWrite + Unpin)>, reader_buf: &mut Vec, ) -> Result<()> { read_xml_header(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; xml_writer .write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))) .await?; ServerStreamStart { from: "localhost".into(), lang: "en".into(), version: "1.0".into(), } .write_xml(xml_writer) .await?; Features { start_tls: false, mechanisms: true, bind: false, } .write_xml(xml_writer) .await?; xml_writer.get_mut().flush().await?; let _ = xmpp::sasl::Auth::parse(xml_reader, reader_buf).await?; xmpp::sasl::Success.write_xml(xml_writer).await?; Ok(()) } async fn socket_final( xml_reader: &mut NsReader<(impl AsyncBufRead + Unpin)>, xml_writer: &mut Writer<(impl AsyncWrite + Unpin)>, reader_buf: &mut Vec, ) -> Result<()> { read_xml_header(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; xml_writer .write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))) .await?; ServerStreamStart { from: "localhost".into(), lang: "en".into(), version: "1.0".into(), } .write_xml(xml_writer) .await?; Features { start_tls: false, mechanisms: false, bind: true, } .write_xml(xml_writer) .await?; xml_writer.get_mut().flush().await?; let mut parser = proto::ClientPacket::parse(); loop { reader_buf.clear(); let (ns, event) = xml_reader .read_resolved_event_into_async(reader_buf) .await?; if let Event::Text(ref e) = event { if e.iter().all(|x| *x == 0xA) { continue; } } match parser.consume(ns, &event) { Continuation::Final(res) => { let res = res?; dbg!(&res); match res { proto::ClientPacket::Iq(iq) => match iq.body { proto::IqClientBody::Bind(b) => { let mut events = vec![]; let req = Iq { from: None, id: iq.id, to: None, r#type: xmpp::client::IqType::Result, body: BindResponse(Jid { name: Name("darova".to_string()), server: Server("localhost".to_string()), resource: Resource("kek".to_string()), }), }; req.serialize(&mut events); for i in events { xml_writer.write_event_async(i).await?; } xml_writer.get_mut().flush().await?; } proto::IqClientBody::Session(_) => { let mut events = vec![]; let req = Iq { from: None, id: iq.id, to: None, r#type: xmpp::client::IqType::Result, body: Session, }; req.serialize(&mut events); for i in events { xml_writer.write_event_async(i).await?; } xml_writer.get_mut().flush().await?; } proto::IqClientBody::Roster(_) => { let mut events = vec![]; let req = Iq { from: None, id: iq.id, to: None, r#type: xmpp::client::IqType::Result, body: RosterQuery, }; req.serialize(&mut events); for i in events { xml_writer.write_event_async(i).await?; } xml_writer.get_mut().flush().await?; } proto::IqClientBody::Unknown(_) => { let mut events = vec![]; let req = Iq { from: None, id: iq.id, to: None, r#type: xmpp::client::IqType::Error, body: (), }; req.serialize(&mut events); for i in events { xml_writer.write_event_async(i).await?; } xml_writer.get_mut().flush().await?; }, }, proto::ClientPacket::Message(_) => todo!(), proto::ClientPacket::Presence(p) => { let mut events = vec![]; let response = Presence::<()> { to: Some("darova@localhost/kek".to_string()), from: Some("darova@localhost/kek".to_string()), ..Default::default() }; response.serialize(&mut events); for i in events { xml_writer.write_event_async(i).await?; } xml_writer.get_mut().flush().await?; } } parser = proto::ClientPacket::parse(); } Continuation::Continue(p) => parser = p, } } Ok(()) } async fn read_xml_header( xml_reader: &mut NsReader<(impl AsyncBufRead + Unpin)>, reader_buf: &mut Vec, ) -> Result<()> { if let Event::Decl(bytes) = xml_reader.read_event_into_async(reader_buf).await? { // this is header if let Some(encoding) = bytes.encoding() { let encoding = encoding?; if &*encoding == b"UTF-8" { Ok(()) } else { Err(fail(format!("Unsupported encoding: {encoding:?}").as_str())) } } else { // Err(fail("No XML encoding provided")) Ok(()) } } else { Err(fail("Expected XML header")) } }