From d1dad72c08eb61619927ad52e14ec759d37bdd73 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Tue, 7 Mar 2023 14:56:31 +0100 Subject: [PATCH] feat(xmpp): push-based message parser --- src/protos/xmpp/client.rs | 202 +++++++++++++++++++++++++++++++++ src/protos/xmpp/mod.rs | 1 + src/protos/xmpp/stanzaerror.rs | 25 ++++ src/util/mod.rs | 1 + src/util/xml.rs | 23 ++++ 5 files changed, 252 insertions(+) create mode 100644 src/protos/xmpp/stanzaerror.rs create mode 100644 src/util/xml.rs diff --git a/src/protos/xmpp/client.rs b/src/protos/xmpp/client.rs index e8f2fd5..5ce9e34 100644 --- a/src/protos/xmpp/client.rs +++ b/src/protos/xmpp/client.rs @@ -1 +1,203 @@ +use quick_xml::events::Event; + +use crate::prelude::*; +use crate::util::xml::*; + pub static XMLNS: &'static str = "jabber:client"; + +#[derive(PartialEq, Eq, Debug)] +pub struct Message { + from: Option, + id: Option, + to: Option, + // default is Normal + r#type: MessageType, + lang: Option, + + subject: Option, + body: String, +} +impl Message { + pub fn parse() -> impl Parser> { + MessageParser::Init + } +} + +#[derive(Default)] +enum MessageParser { + #[default] + Init, + Outer(MessageParserState), + InSubject(MessageParserState), + InBody(MessageParserState), +} +#[derive(Default)] +struct MessageParserState { + from: Option, + id: Option, + to: Option, + r#type: MessageType, + lang: Option, + subject: Option, + body: Option, +} +impl Parser for MessageParser { + type Output = Result; + + fn consume<'a>(self: Self, event: &Event<'a>) -> Continuation { + // TODO validate tag name and namespace at each stage + match self { + MessageParser::Init => { + if let Event::Start(ref bytes) = event { + let mut state: MessageParserState = Default::default(); + for attr in bytes.attributes() { + let attr = fail_fast!(attr); + if attr.key.0 == b"from" { + let value = fail_fast!(std::str::from_utf8(&*attr.value)); + state.from = Some(value.to_string()) + } else if attr.key.0 == b"id" { + let value = fail_fast!(std::str::from_utf8(&*attr.value)); + state.id = Some(value.to_string()) + } else if attr.key.0 == b"to" { + let value = fail_fast!(std::str::from_utf8(&*attr.value)); + state.to = Some(value.to_string()) + } + } + Continuation::Continue(MessageParser::Outer(state)) + } else { + Continuation::Final(Err(ffail!("Expected start"))) + } + } + MessageParser::Outer(state) => { + match event { + Event::Start(ref bytes) => { + if bytes.name().0 == b"subject" { + Continuation::Continue(MessageParser::InSubject(state)) + } else if bytes.name().0 == b"body" { + Continuation::Continue(MessageParser::InBody(state)) + } else { + Continuation::Final(Err(ffail!("Unexpected XML tag"))) + } + } + Event::End(_) => { + if let Some(body) = state.body { + Continuation::Final(Ok(Message { + from: state.from, + id: state.id, + to: state.to, + r#type: state.r#type, + lang: state.lang, + subject: state.subject, + body, + })) + } else { + Continuation::Final(Err(ffail!("Body not found"))) + } + } + _ => { + Continuation::Final(Err(ffail!("Unexpected XML event: {event:?}"))) + } + } + }, + MessageParser::InSubject(mut state) => { + match event { + Event::Text(ref bytes) =>{ + let subject = fail_fast!(std::str::from_utf8(&*bytes)); + state.subject = Some(subject.to_string()); + Continuation::Continue(MessageParser::InSubject(state)) + } + Event::End(_) => { + Continuation::Continue(MessageParser::Outer(state)) + } + _ => { + Continuation::Final(Err(ffail!("Unexpected XML event: {event:?}"))) + } + } + } + MessageParser::InBody(mut state) => { + match event { + Event::Text(ref bytes) =>{ + match std::str::from_utf8(&*bytes) { + Ok(subject) => { + state.body = Some(subject.to_string()); + Continuation::Continue(MessageParser::InBody(state)) + } + Err(err) => Continuation::Final(Err(err.into())), + } + } + Event::End(_) => { + Continuation::Continue(MessageParser::Outer(state)) + } + _ => { + Continuation::Final(Err(ffail!("Unexpected XML event: {event:?}"))) + } + } + } + } + } +} + +#[derive(PartialEq, Eq, Debug)] +pub enum MessageType { + Chat, + Error, + Groupchat, + Headline, + Normal, +} + +impl Default for MessageType { + fn default() -> Self { + MessageType::Normal + } +} + +impl MessageType { + pub fn from_str(s: &str) -> Result { + use MessageType::*; + match s { + "chat" => Ok(Chat), + "error" => Ok(Error), + "groupchat" => Ok(Groupchat), + "headline" => Ok(Headline), + "normal" => Ok(Normal), + t => Err(ffail!("Unknown message type: {t}")), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use quick_xml::NsReader; + + #[tokio::test] + async fn parse_message() { + let input = r#"daabbb"#; + let mut reader = NsReader::from_reader(input.as_bytes()); + let mut buf = vec![]; + let event = reader.read_event_into_async(&mut buf).await.unwrap(); + let mut parser = Message::parse().consume(&event); + let result = loop { + match parser { + Continuation::Final(res) => break res, + Continuation::Continue(next) => { + parser = next.consume(&reader.read_event_into_async(&mut buf).await.unwrap()) + } + } + } + .unwrap(); + assert_eq!( + result, + Message { + from: None, + id: Some("aacea".to_string()), + to: Some("nikita@vlnv.dev".to_string()), + r#type: MessageType::Normal, + lang: None, + subject: Some("daa".to_string()), + body: "bbb".to_string(), + } + ) + } +} diff --git a/src/protos/xmpp/mod.rs b/src/protos/xmpp/mod.rs index 0bb3baa..2707748 100644 --- a/src/protos/xmpp/mod.rs +++ b/src/protos/xmpp/mod.rs @@ -2,6 +2,7 @@ pub mod client; pub mod sasl; pub mod stream; pub mod tls; +pub mod stanzaerror; // Implemented as a macro instead of a fn due to borrowck limitations macro_rules! skip_text { diff --git a/src/protos/xmpp/stanzaerror.rs b/src/protos/xmpp/stanzaerror.rs new file mode 100644 index 0000000..2c1aa15 --- /dev/null +++ b/src/protos/xmpp/stanzaerror.rs @@ -0,0 +1,25 @@ +pub enum StanzaError { + BadRequest, + Conflict, + FeatureNotImplemented, + Forbidden, + Gone(String), + InternalServerError, + ItemNotFound, + JidMalformed, + NotAcceptable, + NotAllowed, + NotAuthorized, + PaymentRequired, + PolicyViolation, + RecipientUnavailable, + Redirect(String), + RegistrationRequired, + RemoteServerNotFound, + RemoteServerTimeout, + ResourceConstraint, + ServiceUnavailable, + SubscriptionRequired, + UndefinedCondition, + UnexpectedRequest, +} diff --git a/src/util/mod.rs b/src/util/mod.rs index fc6553f..5bbf59c 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -3,6 +3,7 @@ use crate::prelude::*; pub mod http; pub mod table; pub mod telemetry; +pub mod xml; #[cfg(test)] pub mod testkit; diff --git a/src/util/xml.rs b/src/util/xml.rs new file mode 100644 index 0000000..fa14873 --- /dev/null +++ b/src/util/xml.rs @@ -0,0 +1,23 @@ +use quick_xml::events::Event; + +pub trait Parser: Sized { + type Output; + + fn consume<'a>(self: Self, event: &Event<'a>) -> Continuation; +} + +pub enum Continuation { + Final(Res), + Continue(Parser), +} + +macro_rules! fail_fast { + ($errorable: expr) => { + match $errorable { + Ok(i) => i, + Err(e) => return Continuation::Final(Err(e.into())) + } + }; +} + +pub(crate) use fail_fast; \ No newline at end of file