forked from lavina/lavina
1
0
Fork 0

feat(xmpp): extract tls xml defns into a separate module

This commit is contained in:
Nikita Vilunov 2023-03-06 12:49:51 +01:00
parent f1eff730a2
commit dc788a89c4
5 changed files with 53 additions and 38 deletions

View File

@ -16,3 +16,11 @@ pub type ByteVec = Vec<u8>;
pub fn fail(msg: &str) -> anyhow::Error {
anyhow::Error::msg(msg.to_owned())
}
macro_rules! ffail {
($($arg:tt)*) => {
fail(&format!($($arg)*))
};
}
pub(crate) use ffail;

View File

@ -173,8 +173,9 @@ async fn socket_force_tls(
writer: &mut (impl AsyncWrite + Unpin),
reader_buf: &mut Vec<u8>,
) -> Result<()> {
let mut xml_reader = &mut NsReader::from_reader(reader);
let mut xml_writer = &mut Writer::new(writer);
use crate::protos::xmpp::tls::*;
let xml_reader = &mut NsReader::from_reader(reader);
let xml_writer = &mut Writer::new(writer);
read_xml_header(xml_reader, reader_buf).await?;
let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?;
@ -195,7 +196,7 @@ async fn socket_force_tls(
xml_writer.get_mut().flush().await?;
let StartTLS = StartTLS::parse(xml_reader, reader_buf).await?;
Proceed.write_xml(xml_writer).await?;
ProceedTLS.write_xml(xml_writer).await?;
xml_writer.get_mut().flush().await?;
Ok(())
}

View File

@ -1,6 +1,7 @@
pub mod client;
pub mod sasl;
pub mod stream;
pub mod tls;
// Implemented as a macro instead of a fn due to borrowck limitations
macro_rules! skip_text {

View File

@ -10,7 +10,6 @@ use crate::prelude::*;
pub static XMLNS: &'static str = "http://etherx.jabber.org/streams";
pub static PREFIX: &'static str = "stream";
pub static XMLNS_XML: &'static str = "http://www.w3.org/XML/1998/namespace";
pub static XMLNS_TLS: &'static str = "urn:ietf:params:xml:ns:xmpp-tls";
#[derive(Debug, PartialEq, Eq)]
pub struct ClientStreamStart {
@ -161,40 +160,6 @@ impl Features {
}
}
pub struct StartTLS;
impl StartTLS {
pub async fn parse(
reader: &mut NsReader<impl AsyncBufRead + Unpin>,
buf: &mut Vec<u8>,
) -> Result<StartTLS> {
let incoming = skip_text!(reader, buf);
if let Event::Empty(e) = incoming {
if e.name().0 == b"starttls" {
Ok(StartTLS)
} else {
Err(fail("starttls expected"))
}
} else {
log::error!("WAT: {incoming:?}");
Err(panic!())
}
}
}
pub struct Proceed;
impl Proceed {
pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
let mut event = BytesStart::new("proceed");
let attributes = [Attribute {
key: QName(b"xmlns"),
value: XMLNS_TLS.as_bytes().into(),
}];
event.extend_attributes(attributes.into_iter());
writer.write_event_async(Event::Empty(event)).await?;
Ok(())
}
}
#[cfg(test)]
mod test {
use super::*;

40
src/protos/xmpp/tls.rs Normal file
View File

@ -0,0 +1,40 @@
use quick_xml::events::attributes::Attribute;
use quick_xml::events::{BytesStart, Event};
use quick_xml::name::QName;
use quick_xml::{NsReader, Writer};
use tokio::io::{AsyncBufRead, AsyncWrite};
use super::skip_text;
use crate::prelude::*;
pub static XMLNS: &'static str = "urn:ietf:params:xml:ns:xmpp-tls";
pub struct StartTLS;
impl StartTLS {
pub async fn parse(
reader: &mut NsReader<impl AsyncBufRead + Unpin>,
buf: &mut Vec<u8>,
) -> Result<StartTLS> {
let incoming = skip_text!(reader, buf);
if let Event::Empty(ref e) = incoming {
if e.name().0 == b"starttls" {
return Ok(StartTLS);
}
}
Err(ffail!("XML tag starttls expected, received: {incoming:?}"))
}
}
pub struct ProceedTLS;
impl ProceedTLS {
pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
let mut event = BytesStart::new("proceed");
let attributes = [Attribute {
key: QName(b"xmlns"),
value: XMLNS.as_bytes().into(),
}];
event.extend_attributes(attributes.into_iter());
writer.write_event_async(Event::Empty(event)).await?;
Ok(())
}
}