diff --git a/src/prelude.rs b/src/prelude.rs index a7b8ddb..815a74a 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -16,3 +16,11 @@ pub type ByteVec = Vec; pub fn fail(msg: &str) -> anyhow::Error { anyhow::Error::msg(msg.to_owned()) } + +macro_rules! ffail { + ($($arg:tt)*) => { + fail(&format!($($arg)*)) + }; +} + +pub(crate) use ffail; diff --git a/src/projections/xmpp/mod.rs b/src/projections/xmpp/mod.rs index 52f156f..630861d 100644 --- a/src/projections/xmpp/mod.rs +++ b/src/projections/xmpp/mod.rs @@ -173,8 +173,9 @@ async fn socket_force_tls( writer: &mut (impl AsyncWrite + Unpin), reader_buf: &mut Vec, ) -> Result<()> { - let mut xml_reader = &mut NsReader::from_reader(reader); - let mut xml_writer = &mut Writer::new(writer); + use crate::protos::xmpp::tls::*; + let xml_reader = &mut NsReader::from_reader(reader); + let xml_writer = &mut Writer::new(writer); read_xml_header(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; @@ -195,7 +196,7 @@ async fn socket_force_tls( xml_writer.get_mut().flush().await?; let StartTLS = StartTLS::parse(xml_reader, reader_buf).await?; - Proceed.write_xml(xml_writer).await?; + ProceedTLS.write_xml(xml_writer).await?; xml_writer.get_mut().flush().await?; Ok(()) } diff --git a/src/protos/xmpp/mod.rs b/src/protos/xmpp/mod.rs index e1a8a99..0bb3baa 100644 --- a/src/protos/xmpp/mod.rs +++ b/src/protos/xmpp/mod.rs @@ -1,6 +1,7 @@ pub mod client; pub mod sasl; pub mod stream; +pub mod tls; // Implemented as a macro instead of a fn due to borrowck limitations macro_rules! skip_text { diff --git a/src/protos/xmpp/stream.rs b/src/protos/xmpp/stream.rs index 978721b..7da2a3a 100644 --- a/src/protos/xmpp/stream.rs +++ b/src/protos/xmpp/stream.rs @@ -10,7 +10,6 @@ use crate::prelude::*; pub static XMLNS: &'static str = "http://etherx.jabber.org/streams"; pub static PREFIX: &'static str = "stream"; pub static XMLNS_XML: &'static str = "http://www.w3.org/XML/1998/namespace"; -pub static XMLNS_TLS: &'static str = "urn:ietf:params:xml:ns:xmpp-tls"; #[derive(Debug, PartialEq, Eq)] pub struct ClientStreamStart { @@ -161,40 +160,6 @@ impl Features { } } -pub struct StartTLS; -impl StartTLS { - pub async fn parse( - reader: &mut NsReader, - buf: &mut Vec, - ) -> Result { - let incoming = skip_text!(reader, buf); - if let Event::Empty(e) = incoming { - if e.name().0 == b"starttls" { - Ok(StartTLS) - } else { - Err(fail("starttls expected")) - } - } else { - log::error!("WAT: {incoming:?}"); - Err(panic!()) - } - } -} - -pub struct Proceed; -impl Proceed { - pub async fn write_xml(&self, writer: &mut Writer) -> Result<()> { - let mut event = BytesStart::new("proceed"); - let attributes = [Attribute { - key: QName(b"xmlns"), - value: XMLNS_TLS.as_bytes().into(), - }]; - event.extend_attributes(attributes.into_iter()); - writer.write_event_async(Event::Empty(event)).await?; - Ok(()) - } -} - #[cfg(test)] mod test { use super::*; diff --git a/src/protos/xmpp/tls.rs b/src/protos/xmpp/tls.rs new file mode 100644 index 0000000..6fdf4ad --- /dev/null +++ b/src/protos/xmpp/tls.rs @@ -0,0 +1,40 @@ +use quick_xml::events::attributes::Attribute; +use quick_xml::events::{BytesStart, Event}; +use quick_xml::name::QName; +use quick_xml::{NsReader, Writer}; +use tokio::io::{AsyncBufRead, AsyncWrite}; + +use super::skip_text; +use crate::prelude::*; + +pub static XMLNS: &'static str = "urn:ietf:params:xml:ns:xmpp-tls"; + +pub struct StartTLS; +impl StartTLS { + pub async fn parse( + reader: &mut NsReader, + buf: &mut Vec, + ) -> Result { + let incoming = skip_text!(reader, buf); + if let Event::Empty(ref e) = incoming { + if e.name().0 == b"starttls" { + return Ok(StartTLS); + } + } + Err(ffail!("XML tag starttls expected, received: {incoming:?}")) + } +} + +pub struct ProceedTLS; +impl ProceedTLS { + pub async fn write_xml(&self, writer: &mut Writer) -> Result<()> { + let mut event = BytesStart::new("proceed"); + let attributes = [Attribute { + key: QName(b"xmlns"), + value: XMLNS.as_bytes().into(), + }]; + event.extend_attributes(attributes.into_iter()); + writer.write_event_async(Event::Empty(event)).await?; + Ok(()) + } +}