lavina/src/protos/xmpp/stream.rs

199 lines
6.7 KiB
Rust
Raw Normal View History

use quick_xml::events::attributes::Attribute;
use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event};
use quick_xml::name::{Namespace, QName, ResolveResult};
use quick_xml::{NsReader, Writer};
use tokio::io::{AsyncBufRead, AsyncWrite};
use super::skip_text;
use crate::prelude::*;
/// Namespace URI of the XMPP stream elements such as `<stream:stream>` (RFC 6120).
pub static XMLNS: &str = "http://etherx.jabber.org/streams";
/// The `stream` namespace prefix; the outgoing header in this module binds it to [`XMLNS`].
pub static PREFIX: &str = "stream";
/// Namespace URI reserved for the `xml` prefix, used to recognise the `xml:lang` attribute.
pub static XMLNS_XML: &str = "http://www.w3.org/XML/1998/namespace";
#[derive(Debug, PartialEq, Eq)]
pub struct ClientStreamStart {
pub to: String,
pub lang: String,
pub version: String,
}
impl ClientStreamStart {
pub async fn parse(
reader: &mut NsReader<impl AsyncBufRead + Unpin>,
buf: &mut Vec<u8>,
) -> Result<ClientStreamStart> {
let incoming = skip_text!(reader, buf);
if let Event::Start(e) = incoming {
let (ns, local) = reader.resolve_element(e.name());
if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) {
return Err(panic!());
}
if local.into_inner() != b"stream" {
return Err(panic!());
}
let mut to = None;
let mut lang = None;
let mut version = None;
for attr in e.attributes() {
let attr = attr?;
let (ns, name) = reader.resolve_attribute(attr.key);
match (ns, name.into_inner()) {
(ResolveResult::Unbound, b"to") => {
let value = attr.unescape_value()?;
to = Some(value.to_string());
}
(
ResolveResult::Bound(Namespace(b"http://www.w3.org/XML/1998/namespace")),
b"lang",
) => {
let value = attr.unescape_value()?;
lang = Some(value.to_string());
}
(ResolveResult::Unbound, b"version") => {
let value = attr.unescape_value()?;
version = Some(value.to_string());
}
_ => {}
}
}
Ok(ClientStreamStart {
to: to.unwrap(),
lang: lang.unwrap(),
version: version.unwrap(),
})
} else {
log::error!("WAT: {incoming:?}");
Err(panic!())
}
}
}
/// The opening `<stream:stream>` header the server sends back to a client.
#[derive(Debug, PartialEq, Eq)]
pub struct ServerStreamStart {
    /// Value written to the `from` attribute.
    pub from: String,
    /// Value written to the `xml:lang` attribute.
    pub lang: String,
    /// Value written to the `version` attribute.
    pub version: String,
}
impl ServerStreamStart {
    /// Writes the (unclosed) `<stream:stream>` start tag with its attributes to `writer`.
    ///
    /// The default namespace is set to the client namespace (`super::client::XMLNS`), and the
    /// `stream` prefix is bound to [`XMLNS`].
    ///
    /// # Errors
    ///
    /// Propagates any error from writing the event.
    pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
        let mut event = BytesStart::new("stream:stream");
        // `(&str, &str)` pairs convert into attributes with the value XML-escaped, so a quote,
        // `<` or `&` in `from`/`lang` cannot break the generated tag (raw `Attribute` values are
        // written verbatim).
        event.extend_attributes([
            ("from", self.from.as_str()),
            ("version", self.version.as_str()),
            ("xmlns", super::client::XMLNS),
            ("xmlns:stream", XMLNS),
            ("xml:lang", self.lang.as_str()),
        ]);
        writer.write_event_async(Event::Start(event)).await?;
        Ok(())
    }
}
/// Which features to advertise in the server's `<stream:features>` element.
pub struct Features {
    /// Advertise STARTTLS, with a `<required/>` child.
    pub start_tls: bool,
    /// Advertise SASL mechanisms; only `PLAIN` is listed.
    pub mechanisms: bool,
    /// Advertise resource binding with an empty `<bind/>` element.
    pub bind: bool,
}
impl Features {
    /// Writes a complete `<stream:features>...</stream:features>` element to `writer`.
    ///
    /// One child element is included for each enabled flag, in the order
    /// `starttls`, `mechanisms`, `bind`.
    ///
    /// # Errors
    ///
    /// Propagates the first write error; events after the failing one are not written.
    pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
        let mut events = vec![Event::Start(BytesStart::new("stream:features"))];

        if self.start_tls {
            events.extend([
                Event::Start(BytesStart::new(
                    r#"starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls""#,
                )),
                Event::Empty(BytesStart::new("required")),
                Event::End(BytesEnd::new("starttls")),
            ]);
        }

        if self.mechanisms {
            events.extend([
                Event::Start(BytesStart::new(
                    r#"mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#,
                )),
                Event::Start(BytesStart::new("mechanism")),
                Event::Text(BytesText::new("PLAIN")),
                Event::End(BytesEnd::new("mechanism")),
                Event::End(BytesEnd::new("mechanisms")),
            ]);
        }

        if self.bind {
            events.push(Event::Empty(BytesStart::new(
                r#"bind xmlns="urn:ietf:params:xml:ns:xmpp-bind""#,
            )));
        }

        events.push(Event::End(BytesEnd::new("stream:features")));

        for event in events {
            writer.write_event_async(event).await?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// A well-formed client header carrying `to`, `xml:lang` and `version` parses into the
    /// corresponding struct.
    #[tokio::test]
    async fn client_stream_start_correct_parse() {
        let xml = r###"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="vlnv.dev" version="1.0" xmlns="jabber:client" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace">"###;
        let expected = ClientStreamStart {
            to: String::from("vlnv.dev"),
            lang: String::from("en"),
            version: String::from("1.0"),
        };

        let mut reader = NsReader::from_reader(xml.as_bytes());
        let mut scratch = Vec::new();
        let parsed = ClientStreamStart::parse(&mut reader, &mut scratch)
            .await
            .unwrap();

        assert_eq!(parsed, expected)
    }

    /// The server header is serialized with its attributes in a fixed order.
    #[tokio::test]
    async fn server_stream_start_write() {
        let header = ServerStreamStart {
            from: String::from("vlnv.dev"),
            lang: String::from("en"),
            version: String::from("1.0"),
        };

        let mut sink: Vec<u8> = Vec::new();
        let mut writer = Writer::new(&mut sink);
        header.write_xml(&mut writer).await.unwrap();

        let written = std::str::from_utf8(&sink).unwrap();
        assert_eq!(
            written,
            r###"<stream:stream from="vlnv.dev" version="1.0" xmlns="jabber:client" xmlns:stream="http://etherx.jabber.org/streams" xml:lang="en">"###
        );
    }
}