use quick_xml::events::attributes::Attribute; use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event}; use quick_xml::name::{Namespace, QName, ResolveResult}; use quick_xml::{NsReader, Writer}; use tokio::io::{AsyncBufRead, AsyncWrite}; use super::skip_text; use crate::prelude::*; pub static XMLNS: &'static str = "http://etherx.jabber.org/streams"; pub static PREFIX: &'static str = "stream"; pub static XMLNS_XML: &'static str = "http://www.w3.org/XML/1998/namespace"; #[derive(Debug, PartialEq, Eq)] pub struct ClientStreamStart { pub to: String, pub lang: String, pub version: String, } impl ClientStreamStart { pub async fn parse( reader: &mut NsReader, buf: &mut Vec, ) -> Result { let incoming = skip_text!(reader, buf); if let Event::Start(e) = incoming { let (ns, local) = reader.resolve_element(e.name()); if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) { return Err(panic!()); } if local.into_inner() != b"stream" { return Err(panic!()); } let mut to = None; let mut lang = None; let mut version = None; for attr in e.attributes() { let attr = attr?; let (ns, name) = reader.resolve_attribute(attr.key); match (ns, name.into_inner()) { (ResolveResult::Unbound, b"to") => { let value = attr.unescape_value()?; to = Some(value.to_string()); } ( ResolveResult::Bound(Namespace(b"http://www.w3.org/XML/1998/namespace")), b"lang", ) => { let value = attr.unescape_value()?; lang = Some(value.to_string()); } (ResolveResult::Unbound, b"version") => { let value = attr.unescape_value()?; version = Some(value.to_string()); } _ => {} } } Ok(ClientStreamStart { to: to.unwrap(), lang: lang.unwrap(), version: version.unwrap(), }) } else { log::error!("WAT: {incoming:?}"); Err(panic!()) } } } pub struct ServerStreamStart { pub from: String, pub lang: String, pub version: String, } impl ServerStreamStart { pub async fn write_xml(&self, writer: &mut Writer) -> Result<()> { let mut event = BytesStart::new("stream:stream"); let attributes = [ Attribute { key: QName(b"from"), value: self.from.as_bytes().into(), }, Attribute { key: QName(b"version"), value: self.version.as_bytes().into(), }, Attribute { key: QName(b"xmlns"), value: super::client::XMLNS.as_bytes().into(), }, Attribute { key: QName(b"xmlns:stream"), value: XMLNS.as_bytes().into(), }, Attribute { key: QName(b"xml:lang"), value: self.lang.as_bytes().into(), }, ]; event.extend_attributes(attributes.into_iter()); writer.write_event_async(Event::Start(event)).await?; Ok(()) } } pub struct Features { pub start_tls: bool, pub mechanisms: bool, pub bind: bool, } impl Features { pub async fn write_xml(&self, writer: &mut Writer) -> Result<()> { writer .write_event_async(Event::Start(BytesStart::new("stream:features"))) .await?; if self.start_tls { writer .write_event_async(Event::Start(BytesStart::new( r#"starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls""#, ))) .await?; writer .write_event_async(Event::Empty(BytesStart::new("required"))) .await?; writer .write_event_async(Event::End(BytesEnd::new("starttls"))) .await?; } if self.mechanisms { writer .write_event_async(Event::Start(BytesStart::new( r#"mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#, ))) .await?; writer .write_event_async(Event::Start(BytesStart::new(r#"mechanism"#))) .await?; writer .write_event_async(Event::Text(BytesText::new("PLAIN"))) .await?; writer .write_event_async(Event::End(BytesEnd::new("mechanism"))) .await?; writer .write_event_async(Event::End(BytesEnd::new("mechanisms"))) .await?; } if self.bind { writer .write_event_async(Event::Empty(BytesStart::new( r#"bind xmlns="urn:ietf:params:xml:ns:xmpp-bind""#, ))) .await?; } writer .write_event_async(Event::End(BytesEnd::new("stream:features"))) .await?; Ok(()) } } #[cfg(test)] mod test { use super::*; #[tokio::test] async fn client_stream_start_correct_parse() { let input = r###""###; let mut reader = NsReader::from_reader(input.as_bytes()); let mut buf = vec![]; let res = ClientStreamStart::parse(&mut reader, &mut buf) .await .unwrap(); assert_eq!( res, ClientStreamStart { to: "vlnv.dev".to_owned(), lang: "en".to_owned(), version: "1.0".to_owned() } ) } #[tokio::test] async fn server_stream_start_write() { let input = ServerStreamStart { from: "vlnv.dev".to_owned(), lang: "en".to_owned(), version: "1.0".to_owned(), }; let expected = r###""###; let mut output: Vec = vec![]; let mut writer = Writer::new(&mut output); input.write_xml(&mut writer).await.unwrap(); assert_eq!(std::str::from_utf8(&output).unwrap(), expected); } }