diff --git a/src/projections/xmpp/mod.rs b/src/projections/xmpp/mod.rs index e4ee8dc..fe2d624 100644 --- a/src/projections/xmpp/mod.rs +++ b/src/projections/xmpp/mod.rs @@ -1,19 +1,19 @@ use std::collections::HashMap; use std::fs::File; +use std::io::BufReader as SyncBufReader; use std::net::SocketAddr; use std::path::PathBuf; -use std::io::BufReader as SyncBufReader; use std::sync::Arc; use futures_util::future::join_all; use prometheus::Registry as MetricsRegistry; use rustls_pemfile::{certs, rsa_private_keys}; use serde::Deserialize; -use tokio::io::{AsyncWriteExt, AsyncReadExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::channel; -use tokio_rustls::TlsAcceptor; use tokio_rustls::rustls::{Certificate, PrivateKey}; +use tokio_rustls::TlsAcceptor; use crate::core::player::PlayerRegistry; use crate::core::room::RoomRegistry; @@ -84,7 +84,7 @@ pub async fn launch( } stopped_tx.send(socket_addr).await?; Ok(()) - } + } }); actors.insert(socket_addr, terminator); @@ -95,15 +95,22 @@ pub async fn launch( } } log::info!("Stopping XMPP projection"); - join_all(actors.into_iter().map(|(socket_addr, terminator)| async move { - log::debug!("Stopping XMPP connection at {socket_addr}"); - match terminator.terminate().await { - Ok(_) => log::debug!("Stopped XMPP connection at {socket_addr}"), - Err(err) => { - log::warn!("XMPP connection to {socket_addr} finished with error: {err}") - } - } - })).await; + join_all( + actors + .into_iter() + .map(|(socket_addr, terminator)| async move { + log::debug!("Stopping XMPP connection at {socket_addr}"); + match terminator.terminate().await { + Ok(_) => log::debug!("Stopped XMPP connection at {socket_addr}"), + Err(err) => { + log::warn!( + "XMPP connection to {socket_addr} finished with error: {err}" + ) + } + } + }), + ) + .await; log::info!("Stopped XMPP projection"); Ok(()) }); @@ -136,16 +143,18 @@ async fn handle_socket( Ok(e) => println!("{} END", e), Err(_) => println!("{:?} END", &buf[0..i]), } - stream.write_all(br###""###).await?; + stream + .write_all(br###""###) + .await?; } let config = tokio_rustls::rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert(vec![config.cert.clone()], config.key.clone())?; - + let i = stream.read(&mut buf).await?; - + match std::str::from_utf8(&buf[0..i]) { Ok(e) => println!("{} END", e), Err(_) => println!("{:?} END", &buf[0..i]), @@ -155,11 +164,11 @@ async fn handle_socket( let mut new_stream = acceptor.accept(stream).await?; log::debug!("TLS connection established"); - - loop { let i = new_stream.read(&mut buf).await?; - if i == 0 { break; } + if i == 0 { + break; + } match std::str::from_utf8(&buf[0..i]) { Ok(e) => println!("{} END", e), Err(_) => println!("{:?} END", &buf[0..i]), diff --git a/src/protos/xmpp/client.rs b/src/protos/xmpp/client.rs new file mode 100644 index 0000000..e8f2fd5 --- /dev/null +++ b/src/protos/xmpp/client.rs @@ -0,0 +1 @@ +pub static XMLNS: &'static str = "jabber:client"; diff --git a/src/protos/xmpp/mod.rs b/src/protos/xmpp/mod.rs index baf29e0..93d19c0 100644 --- a/src/protos/xmpp/mod.rs +++ b/src/protos/xmpp/mod.rs @@ -1 +1,2 @@ +pub mod client; pub mod stream; diff --git a/src/protos/xmpp/stream.rs b/src/protos/xmpp/stream.rs index 98a12bf..5e7f244 100644 --- a/src/protos/xmpp/stream.rs +++ b/src/protos/xmpp/stream.rs @@ -1,9 +1,15 @@ -use quick_xml::name::{ResolveResult, Namespace, LocalName, QName}; -use quick_xml::{Result, NsReader}; -use quick_xml::events::Event; +use std::io::Write; + +use quick_xml::events::attributes::Attribute; +use quick_xml::events::{BytesStart, Event}; +use quick_xml::name::{Namespace, QName, ResolveResult}; +use quick_xml::writer::Writer; +use quick_xml::{NsReader, Result}; use tokio::io::AsyncBufRead; pub static XMLNS: &'static str = "http://etherx.jabber.org/streams"; +pub static PREFIX: &'static str = "stream"; +pub static XMLNS_XML: &'static str = "http://www.w3.org/XML/1998/namespace"; #[derive(Debug, PartialEq, Eq)] pub struct ClientStreamStart { @@ -12,10 +18,13 @@ pub struct ClientStreamStart { pub version: String, } impl ClientStreamStart { - pub async fn parse(reader: &mut NsReader, buf: &mut Vec) -> Result { + pub async fn parse( + reader: &mut NsReader, + buf: &mut Vec, + ) -> Result { if let Event::Start(e) = reader.read_event_into_async(buf).await? { let (ns, local) = reader.resolve_element(e.name()); - if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) { + if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) { return Err(panic!()); } if local.into_inner() != b"stream" { @@ -31,35 +40,101 @@ impl ClientStreamStart { (ResolveResult::Unbound, b"to") => { let value = attr.unescape_value()?; to = Some(value.to_string()); - }, - (ResolveResult::Bound(Namespace(b"http://www.w3.org/XML/1998/namespace")), b"lang") => { + } + ( + ResolveResult::Bound(Namespace(b"http://www.w3.org/XML/1998/namespace")), + b"lang", + ) => { let value = attr.unescape_value()?; lang = Some(value.to_string()); - }, + } (ResolveResult::Unbound, b"version") => { let value = attr.unescape_value()?; version = Some(value.to_string()); - }, - _ => {}, + } + _ => {} } } - Ok(ClientStreamStart { to: to.unwrap(), lang: lang.unwrap(), version: version.unwrap() }) + Ok(ClientStreamStart { + to: to.unwrap(), + lang: lang.unwrap(), + version: version.unwrap(), + }) } else { Err(panic!()) } } } +pub struct ServerStreamStart { + pub from: String, + pub lang: String, + pub version: String, +} +impl ServerStreamStart { + pub fn write(&self, writer: &mut Writer) -> Result<()> { + let mut event = BytesStart::new("stream:stream"); + let attributes = [ + Attribute { + key: QName(b"from"), + value: self.from.as_bytes().into(), + }, + Attribute { + key: QName(b"version"), + value: self.version.as_bytes().into(), + }, + Attribute { + key: QName(b"xmlns"), + value: super::client::XMLNS.as_bytes().into(), + }, + Attribute { + key: QName(b"xmlns:stream"), + value: XMLNS.as_bytes().into(), + }, + Attribute { + key: QName(b"xml:lang"), + value: self.lang.as_bytes().into(), + }, + ]; + event.extend_attributes(attributes.into_iter()); + writer.write_event(Event::Start(event))?; + Ok(()) + } +} + #[cfg(test)] mod test { use super::*; #[tokio::test] - async fn stream_start_correct() { + async fn client_stream_start_correct_parse() { let input = r###""###; let mut reader = NsReader::from_reader(input.as_bytes()); let mut buf = vec![]; - let res = ClientStreamStart::parse(&mut reader, &mut buf).await.unwrap(); - assert_eq!(res, ClientStreamStart { to: "vlnv.dev".to_owned(), lang: "en".to_owned(), version: "1.0".to_owned()}) + let res = ClientStreamStart::parse(&mut reader, &mut buf) + .await + .unwrap(); + assert_eq!( + res, + ClientStreamStart { + to: "vlnv.dev".to_owned(), + lang: "en".to_owned(), + version: "1.0".to_owned() + } + ) } -} \ No newline at end of file + + #[test] + fn server_stream_start_write() { + let input = ServerStreamStart { + from: "vlnv.dev".to_owned(), + lang: "en".to_owned(), + version: "1.0".to_owned(), + }; + let expected = r###""###; + let mut output: Vec = vec![]; + let mut writer = Writer::new(&mut output); + input.write(&mut writer).unwrap(); + assert_eq!(std::str::from_utf8(&output).unwrap(), expected); + } +} diff --git a/src/util/telemetry.rs b/src/util/telemetry.rs index c218e02..26387dd 100644 --- a/src/util/telemetry.rs +++ b/src/util/telemetry.rs @@ -10,7 +10,6 @@ use hyper::{Method, Request, Response}; use prometheus::{Encoder, Registry as MetricsRegistry, TextEncoder}; use serde::Deserialize; use tokio::net::TcpListener; -use tokio::sync::oneshot::channel; use crate::core::room::RoomRegistry; use crate::prelude::*;