feat(xmpp): serialization of stream start

This commit is contained in:
Nikita Vilunov 2023-02-28 12:12:03 +01:00
parent 494ddc7ee1
commit 435da6663a
5 changed files with 120 additions and 35 deletions

View File

@ -1,19 +1,19 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::fs::File; use std::fs::File;
use std::io::BufReader as SyncBufReader;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::PathBuf; use std::path::PathBuf;
use std::io::BufReader as SyncBufReader;
use std::sync::Arc; use std::sync::Arc;
use futures_util::future::join_all; use futures_util::future::join_all;
use prometheus::Registry as MetricsRegistry; use prometheus::Registry as MetricsRegistry;
use rustls_pemfile::{certs, rsa_private_keys}; use rustls_pemfile::{certs, rsa_private_keys};
use serde::Deserialize; use serde::Deserialize;
use tokio::io::{AsyncWriteExt, AsyncReadExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::channel; use tokio::sync::mpsc::channel;
use tokio_rustls::TlsAcceptor;
use tokio_rustls::rustls::{Certificate, PrivateKey}; use tokio_rustls::rustls::{Certificate, PrivateKey};
use tokio_rustls::TlsAcceptor;
use crate::core::player::PlayerRegistry; use crate::core::player::PlayerRegistry;
use crate::core::room::RoomRegistry; use crate::core::room::RoomRegistry;
@ -95,15 +95,22 @@ pub async fn launch(
} }
} }
log::info!("Stopping XMPP projection"); log::info!("Stopping XMPP projection");
join_all(actors.into_iter().map(|(socket_addr, terminator)| async move { join_all(
log::debug!("Stopping XMPP connection at {socket_addr}"); actors
match terminator.terminate().await { .into_iter()
Ok(_) => log::debug!("Stopped XMPP connection at {socket_addr}"), .map(|(socket_addr, terminator)| async move {
Err(err) => { log::debug!("Stopping XMPP connection at {socket_addr}");
log::warn!("XMPP connection to {socket_addr} finished with error: {err}") match terminator.terminate().await {
} Ok(_) => log::debug!("Stopped XMPP connection at {socket_addr}"),
} Err(err) => {
})).await; log::warn!(
"XMPP connection to {socket_addr} finished with error: {err}"
)
}
}
}),
)
.await;
log::info!("Stopped XMPP projection"); log::info!("Stopped XMPP projection");
Ok(()) Ok(())
}); });
@ -136,7 +143,9 @@ async fn handle_socket(
Ok(e) => println!("{} END", e), Ok(e) => println!("{} END", e),
Err(_) => println!("{:?} END", &buf[0..i]), Err(_) => println!("{:?} END", &buf[0..i]),
} }
stream.write_all(br###"<proceed xmlns="urn:ietf:params:xml:ns:xmpp-tls"/>"###).await?; stream
.write_all(br###"<proceed xmlns="urn:ietf:params:xml:ns:xmpp-tls"/>"###)
.await?;
} }
let config = tokio_rustls::rustls::ServerConfig::builder() let config = tokio_rustls::rustls::ServerConfig::builder()
@ -155,11 +164,11 @@ async fn handle_socket(
let mut new_stream = acceptor.accept(stream).await?; let mut new_stream = acceptor.accept(stream).await?;
log::debug!("TLS connection established"); log::debug!("TLS connection established");
loop { loop {
let i = new_stream.read(&mut buf).await?; let i = new_stream.read(&mut buf).await?;
if i == 0 { break; } if i == 0 {
break;
}
match std::str::from_utf8(&buf[0..i]) { match std::str::from_utf8(&buf[0..i]) {
Ok(e) => println!("{} END", e), Ok(e) => println!("{} END", e),
Err(_) => println!("{:?} END", &buf[0..i]), Err(_) => println!("{:?} END", &buf[0..i]),

View File

@ -0,0 +1 @@
pub static XMLNS: &'static str = "jabber:client";

View File

@ -1 +1,2 @@
pub mod client;
pub mod stream; pub mod stream;

View File

@ -1,9 +1,15 @@
use quick_xml::name::{ResolveResult, Namespace, LocalName, QName}; use std::io::Write;
use quick_xml::{Result, NsReader};
use quick_xml::events::Event; use quick_xml::events::attributes::Attribute;
use quick_xml::events::{BytesStart, Event};
use quick_xml::name::{Namespace, QName, ResolveResult};
use quick_xml::writer::Writer;
use quick_xml::{NsReader, Result};
use tokio::io::AsyncBufRead; use tokio::io::AsyncBufRead;
pub static XMLNS: &'static str = "http://etherx.jabber.org/streams"; pub static XMLNS: &'static str = "http://etherx.jabber.org/streams";
pub static PREFIX: &'static str = "stream";
pub static XMLNS_XML: &'static str = "http://www.w3.org/XML/1998/namespace";
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct ClientStreamStart { pub struct ClientStreamStart {
@ -12,10 +18,13 @@ pub struct ClientStreamStart {
pub version: String, pub version: String,
} }
impl ClientStreamStart { impl ClientStreamStart {
pub async fn parse(reader: &mut NsReader<impl AsyncBufRead + Unpin>, buf: &mut Vec<u8>) -> Result<ClientStreamStart> { pub async fn parse(
reader: &mut NsReader<impl AsyncBufRead + Unpin>,
buf: &mut Vec<u8>,
) -> Result<ClientStreamStart> {
if let Event::Start(e) = reader.read_event_into_async(buf).await? { if let Event::Start(e) = reader.read_event_into_async(buf).await? {
let (ns, local) = reader.resolve_element(e.name()); let (ns, local) = reader.resolve_element(e.name());
if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) { if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) {
return Err(panic!()); return Err(panic!());
} }
if local.into_inner() != b"stream" { if local.into_inner() != b"stream" {
@ -31,35 +40,101 @@ impl ClientStreamStart {
(ResolveResult::Unbound, b"to") => { (ResolveResult::Unbound, b"to") => {
let value = attr.unescape_value()?; let value = attr.unescape_value()?;
to = Some(value.to_string()); to = Some(value.to_string());
}, }
(ResolveResult::Bound(Namespace(b"http://www.w3.org/XML/1998/namespace")), b"lang") => { (
ResolveResult::Bound(Namespace(b"http://www.w3.org/XML/1998/namespace")),
b"lang",
) => {
let value = attr.unescape_value()?; let value = attr.unescape_value()?;
lang = Some(value.to_string()); lang = Some(value.to_string());
}, }
(ResolveResult::Unbound, b"version") => { (ResolveResult::Unbound, b"version") => {
let value = attr.unescape_value()?; let value = attr.unescape_value()?;
version = Some(value.to_string()); version = Some(value.to_string());
}, }
_ => {}, _ => {}
} }
} }
Ok(ClientStreamStart { to: to.unwrap(), lang: lang.unwrap(), version: version.unwrap() }) Ok(ClientStreamStart {
to: to.unwrap(),
lang: lang.unwrap(),
version: version.unwrap(),
})
} else { } else {
Err(panic!()) Err(panic!())
} }
} }
} }
pub struct ServerStreamStart {
pub from: String,
pub lang: String,
pub version: String,
}
impl ServerStreamStart {
pub fn write(&self, writer: &mut Writer<impl Write>) -> Result<()> {
let mut event = BytesStart::new("stream:stream");
let attributes = [
Attribute {
key: QName(b"from"),
value: self.from.as_bytes().into(),
},
Attribute {
key: QName(b"version"),
value: self.version.as_bytes().into(),
},
Attribute {
key: QName(b"xmlns"),
value: super::client::XMLNS.as_bytes().into(),
},
Attribute {
key: QName(b"xmlns:stream"),
value: XMLNS.as_bytes().into(),
},
Attribute {
key: QName(b"xml:lang"),
value: self.lang.as_bytes().into(),
},
];
event.extend_attributes(attributes.into_iter());
writer.write_event(Event::Start(event))?;
Ok(())
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;
#[tokio::test] #[tokio::test]
async fn stream_start_correct() { async fn client_stream_start_correct_parse() {
let input = r###"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="vlnv.dev" version="1.0" xmlns="jabber:client" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace">"###; let input = r###"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="vlnv.dev" version="1.0" xmlns="jabber:client" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace">"###;
let mut reader = NsReader::from_reader(input.as_bytes()); let mut reader = NsReader::from_reader(input.as_bytes());
let mut buf = vec![]; let mut buf = vec![];
let res = ClientStreamStart::parse(&mut reader, &mut buf).await.unwrap(); let res = ClientStreamStart::parse(&mut reader, &mut buf)
assert_eq!(res, ClientStreamStart { to: "vlnv.dev".to_owned(), lang: "en".to_owned(), version: "1.0".to_owned()}) .await
.unwrap();
assert_eq!(
res,
ClientStreamStart {
to: "vlnv.dev".to_owned(),
lang: "en".to_owned(),
version: "1.0".to_owned()
}
)
}
#[test]
fn server_stream_start_write() {
let input = ServerStreamStart {
from: "vlnv.dev".to_owned(),
lang: "en".to_owned(),
version: "1.0".to_owned(),
};
let expected = r###"<stream:stream from="vlnv.dev" version="1.0" xmlns="jabber:client" xmlns:stream="http://etherx.jabber.org/streams" xml:lang="en">"###;
let mut output: Vec<u8> = vec![];
let mut writer = Writer::new(&mut output);
input.write(&mut writer).unwrap();
assert_eq!(std::str::from_utf8(&output).unwrap(), expected);
} }
} }

View File

@ -10,7 +10,6 @@ use hyper::{Method, Request, Response};
use prometheus::{Encoder, Registry as MetricsRegistry, TextEncoder}; use prometheus::{Encoder, Registry as MetricsRegistry, TextEncoder};
use serde::Deserialize; use serde::Deserialize;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::sync::oneshot::channel;
use crate::core::room::RoomRegistry; use crate::core::room::RoomRegistry;
use crate::prelude::*; use crate::prelude::*;