forked from lavina/lavina
202 lines
7.3 KiB
Rust
202 lines
7.3 KiB
Rust
use quick_xml::events::attributes::Attribute;
|
|
use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event};
|
|
use quick_xml::name::{Namespace, QName, ResolveResult};
|
|
use quick_xml::{NsReader, Writer};
|
|
use tokio::io::{AsyncBufRead, AsyncWrite};
|
|
|
|
use super::skip_text;
|
|
|
|
use crate::xml::ToXml;
|
|
use anyhow::{anyhow, Result};
|
|
|
|
/// Namespace URI of the XMPP stream elements (`stream:stream`, `stream:features`).
pub static XMLNS: &str = "http://etherx.jabber.org/streams";

/// Namespace prefix used for [`XMLNS`] in the stream headers written by this module
/// (as in `xmlns:stream` / `stream:stream`).
pub static PREFIX: &str = "stream";

/// Namespace URI bound to the reserved `xml` prefix; `xml:lang` lives in this namespace.
pub static XMLNS_XML: &str = "http://www.w3.org/XML/1998/namespace";
|
|
|
|
/// Attributes of the opening `<stream:stream>` element received from a client.
#[derive(PartialEq, Eq, Debug)]
pub struct ClientStreamStart {
    /// Value of the unprefixed `to` attribute.
    pub to: String,
    /// Value of the `xml:lang` attribute, when the client supplied one.
    pub lang: Option<String>,
    /// Value of the unprefixed `version` attribute.
    pub version: String,
}
|
|
impl ClientStreamStart {
|
|
pub async fn parse(
|
|
reader: &mut NsReader<impl AsyncBufRead + Unpin>,
|
|
buf: &mut Vec<u8>,
|
|
) -> Result<ClientStreamStart> {
|
|
let mut incoming = skip_text!(reader, buf);
|
|
if let Event::Decl(bytes) = incoming {
|
|
// this is <?xml ...> header
|
|
if let Some(encoding) = bytes.encoding() {
|
|
let encoding = encoding?;
|
|
if &*encoding != b"UTF-8" {
|
|
return Err(anyhow!("Unsupported encoding: {encoding:?}"));
|
|
}
|
|
}
|
|
incoming = skip_text!(reader, buf);
|
|
}
|
|
if let Event::Start(e) = incoming {
|
|
let (ns, local) = reader.resolve_element(e.name());
|
|
if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) {
|
|
return Err(anyhow!("Invalid namespace for stream element"));
|
|
}
|
|
if local.into_inner() != b"stream" {
|
|
return Err(anyhow!("Invalid local name for stream element"));
|
|
}
|
|
let mut to = None;
|
|
let mut lang = None;
|
|
let mut version = None;
|
|
for attr in e.attributes() {
|
|
let attr = attr?;
|
|
let (ns, name) = reader.resolve_attribute(attr.key);
|
|
match (ns, name.into_inner()) {
|
|
(ResolveResult::Unbound, b"to") => {
|
|
let value = attr.unescape_value()?;
|
|
to = Some(value.to_string());
|
|
}
|
|
(ResolveResult::Bound(Namespace(b"http://www.w3.org/XML/1998/namespace")), b"lang") => {
|
|
let value = attr.unescape_value()?;
|
|
lang = Some(value.to_string());
|
|
}
|
|
(ResolveResult::Unbound, b"version") => {
|
|
let value = attr.unescape_value()?;
|
|
version = Some(value.to_string());
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
Ok(ClientStreamStart {
|
|
to: to.unwrap(),
|
|
lang: lang,
|
|
version: version.unwrap(),
|
|
})
|
|
} else {
|
|
Err(anyhow!("Incoming message does not belong XML Start Event"))
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Values for the opening `<stream:stream>` tag the server writes with
/// [`ServerStreamStart::write_xml`].
pub struct ServerStreamStart {
    /// Written as the `from` attribute.
    pub from: String,
    /// Written as the `xml:lang` attribute.
    pub lang: String,
    /// Written as the `id` attribute.
    pub id: String,
    /// Written as the `version` attribute.
    pub version: String,
}
|
|
impl ServerStreamStart {
|
|
pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
|
|
let mut event = BytesStart::new("stream:stream");
|
|
let attributes = [
|
|
Attribute {
|
|
key: QName(b"from"),
|
|
value: self.from.as_bytes().into(),
|
|
},
|
|
Attribute {
|
|
key: QName(b"version"),
|
|
value: self.version.as_bytes().into(),
|
|
},
|
|
Attribute {
|
|
key: QName(b"xmlns"),
|
|
value: super::client::XMLNS.as_bytes().into(),
|
|
},
|
|
Attribute {
|
|
key: QName(b"xmlns:stream"),
|
|
value: XMLNS.as_bytes().into(),
|
|
},
|
|
Attribute {
|
|
key: QName(b"xml:lang"),
|
|
value: self.lang.as_bytes().into(),
|
|
},
|
|
Attribute {
|
|
key: QName(b"id"),
|
|
value: self.id.as_bytes().into(),
|
|
},
|
|
];
|
|
event.extend_attributes(attributes.into_iter());
|
|
writer.write_event_async(Event::Start(event)).await?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
pub struct ServerStreamEnd;
|
|
impl ToXml for ServerStreamEnd {
|
|
fn serialize(&self, events: &mut Vec<Event<'static>>) {
|
|
events.push(Event::End(BytesEnd::new("stream:stream")));
|
|
}
|
|
}
|
|
|
|
/// Selects which entries [`Features::write_xml`] puts inside `<stream:features>`.
pub struct Features {
    /// Advertise `<starttls>` in the TLS namespace, marked `<required/>`.
    pub start_tls: bool,
    /// Advertise `<mechanisms>` in the SASL namespace, listing only `PLAIN`.
    pub mechanisms: bool,
    /// Advertise an empty `<bind/>` in the binding namespace.
    pub bind: bool,
}
|
|
impl Features {
|
|
pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
|
|
writer.write_event_async(Event::Start(BytesStart::new("stream:features"))).await?;
|
|
if self.start_tls {
|
|
writer
|
|
.write_event_async(Event::Start(BytesStart::new(
|
|
r#"starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls""#,
|
|
)))
|
|
.await?;
|
|
writer.write_event_async(Event::Empty(BytesStart::new("required"))).await?;
|
|
writer.write_event_async(Event::End(BytesEnd::new("starttls"))).await?;
|
|
}
|
|
if self.mechanisms {
|
|
writer
|
|
.write_event_async(Event::Start(BytesStart::new(
|
|
r#"mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#,
|
|
)))
|
|
.await?;
|
|
writer.write_event_async(Event::Start(BytesStart::new(r#"mechanism"#))).await?;
|
|
writer.write_event_async(Event::Text(BytesText::new("PLAIN"))).await?;
|
|
writer.write_event_async(Event::End(BytesEnd::new("mechanism"))).await?;
|
|
writer.write_event_async(Event::End(BytesEnd::new("mechanisms"))).await?;
|
|
}
|
|
if self.bind {
|
|
writer
|
|
.write_event_async(Event::Empty(BytesStart::new(
|
|
r#"bind xmlns="urn:ietf:params:xml:ns:xmpp-bind""#,
|
|
)))
|
|
.await?;
|
|
}
|
|
writer.write_event_async(Event::End(BytesEnd::new("stream:features"))).await?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use super::*;

    /// A complete client header is parsed into `to`, `xml:lang` and `version`.
    #[tokio::test]
    async fn client_stream_start_correct_parse() {
        let header = r###"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="vlnv.dev" version="1.0" xmlns="jabber:client" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace">"###;
        let mut scratch = Vec::new();
        let mut reader = NsReader::from_reader(header.as_bytes());

        let parsed = ClientStreamStart::parse(&mut reader, &mut scratch).await.unwrap();

        let expected = ClientStreamStart {
            to: String::from("vlnv.dev"),
            lang: Some(String::from("en")),
            version: String::from("1.0"),
        };
        assert_eq!(parsed, expected)
    }

    /// The server header is written with attributes in the fixed order.
    #[tokio::test]
    async fn server_stream_start_write() {
        let header = ServerStreamStart {
            from: String::from("vlnv.dev"),
            lang: String::from("en"),
            id: String::from("stream_id"),
            version: String::from("1.0"),
        };
        let mut sink = Vec::<u8>::new();
        let mut writer = Writer::new(&mut sink);

        header.write_xml(&mut writer).await.unwrap();

        let written = std::str::from_utf8(&sink).unwrap();
        assert_eq!(
            written,
            r###"<stream:stream from="vlnv.dev" version="1.0" xmlns="jabber:client" xmlns:stream="http://etherx.jabber.org/streams" xml:lang="en" id="stream_id">"###
        );
    }
}
|