diff --git a/src/protos/xmpp/client.rs b/src/protos/xmpp/client.rs index 5ce9e34..34349a4 100644 --- a/src/protos/xmpp/client.rs +++ b/src/protos/xmpp/client.rs @@ -7,15 +7,15 @@ pub static XMLNS: &'static str = "jabber:client"; #[derive(PartialEq, Eq, Debug)] pub struct Message { - from: Option, - id: Option, - to: Option, + pub from: Option, + pub id: Option, + pub to: Option, // default is Normal - r#type: MessageType, - lang: Option, + pub r#type: MessageType, + pub lang: Option, - subject: Option, - body: String, + pub subject: Option, + pub body: String, } impl Message { pub fn parse() -> impl Parser> { @@ -61,6 +61,9 @@ impl Parser for MessageParser { } else if attr.key.0 == b"to" { let value = fail_fast!(std::str::from_utf8(&*attr.value)); state.to = Some(value.to_string()) + } else if attr.key.0 == b"type" { + let value = fail_fast!(MessageType::from_str(&*attr.value)); + state.r#type = value; } } Continuation::Continue(MessageParser::Outer(state)) @@ -68,71 +71,53 @@ impl Parser for MessageParser { Continuation::Final(Err(ffail!("Expected start"))) } } - MessageParser::Outer(state) => { - match event { - Event::Start(ref bytes) => { - if bytes.name().0 == b"subject" { - Continuation::Continue(MessageParser::InSubject(state)) - } else if bytes.name().0 == b"body" { - Continuation::Continue(MessageParser::InBody(state)) - } else { - Continuation::Final(Err(ffail!("Unexpected XML tag"))) - } - } - Event::End(_) => { - if let Some(body) = state.body { - Continuation::Final(Ok(Message { - from: state.from, - id: state.id, - to: state.to, - r#type: state.r#type, - lang: state.lang, - subject: state.subject, - body, - })) - } else { - Continuation::Final(Err(ffail!("Body not found"))) - } - } - _ => { - Continuation::Final(Err(ffail!("Unexpected XML event: {event:?}"))) - } - } - }, - MessageParser::InSubject(mut state) => { - match event { - Event::Text(ref bytes) =>{ - let subject = fail_fast!(std::str::from_utf8(&*bytes)); - state.subject = Some(subject.to_string()); + MessageParser::Outer(state) => match event { + Event::Start(ref bytes) => { + if bytes.name().0 == b"subject" { Continuation::Continue(MessageParser::InSubject(state)) - } - Event::End(_) => { - Continuation::Continue(MessageParser::Outer(state)) - } - _ => { - Continuation::Final(Err(ffail!("Unexpected XML event: {event:?}"))) + } else if bytes.name().0 == b"body" { + Continuation::Continue(MessageParser::InBody(state)) + } else { + Continuation::Final(Err(ffail!("Unexpected XML tag"))) } } - } - MessageParser::InBody(mut state) => { - match event { - Event::Text(ref bytes) =>{ - match std::str::from_utf8(&*bytes) { - Ok(subject) => { - state.body = Some(subject.to_string()); - Continuation::Continue(MessageParser::InBody(state)) - } - Err(err) => Continuation::Final(Err(err.into())), - } - } - Event::End(_) => { - Continuation::Continue(MessageParser::Outer(state)) - } - _ => { - Continuation::Final(Err(ffail!("Unexpected XML event: {event:?}"))) + Event::End(_) => { + if let Some(body) = state.body { + Continuation::Final(Ok(Message { + from: state.from, + id: state.id, + to: state.to, + r#type: state.r#type, + lang: state.lang, + subject: state.subject, + body, + })) + } else { + Continuation::Final(Err(ffail!("Body not found"))) } } - } + _ => Continuation::Final(Err(ffail!("Unexpected XML event: {event:?}"))), + }, + MessageParser::InSubject(mut state) => match event { + Event::Text(ref bytes) => { + let subject = fail_fast!(std::str::from_utf8(&*bytes)); + state.subject = Some(subject.to_string()); + Continuation::Continue(MessageParser::InSubject(state)) + } + Event::End(_) => Continuation::Continue(MessageParser::Outer(state)), + _ => Continuation::Final(Err(ffail!("Unexpected XML event: {event:?}"))), + }, + MessageParser::InBody(mut state) => match event { + Event::Text(ref bytes) => match std::str::from_utf8(&*bytes) { + Ok(subject) => { + state.body = Some(subject.to_string()); + Continuation::Continue(MessageParser::InBody(state)) + } + Err(err) => Continuation::Final(Err(err.into())), + }, + Event::End(_) => Continuation::Continue(MessageParser::Outer(state)), + _ => Continuation::Final(Err(ffail!("Unexpected XML event: {event:?}"))), + }, } } } @@ -153,8 +138,9 @@ impl Default for MessageType { } impl MessageType { - pub fn from_str(s: &str) -> Result { + pub fn from_str(s: &[u8]) -> Result { use MessageType::*; + let s = std::str::from_utf8(s)?; match s { "chat" => Ok(Chat), "error" => Ok(Error), @@ -193,7 +179,7 @@ mod tests { from: None, id: Some("aacea".to_string()), to: Some("nikita@vlnv.dev".to_string()), - r#type: MessageType::Normal, + r#type: MessageType::Chat, lang: None, subject: Some("daa".to_string()), body: "bbb".to_string(), diff --git a/src/protos/xmpp/mod.rs b/src/protos/xmpp/mod.rs index 2707748..37088f6 100644 --- a/src/protos/xmpp/mod.rs +++ b/src/protos/xmpp/mod.rs @@ -1,8 +1,8 @@ pub mod client; pub mod sasl; +pub mod stanzaerror; pub mod stream; pub mod tls; -pub mod stanzaerror; // Implemented as a macro instead of a fn due to borrowck limitations macro_rules! skip_text { diff --git a/src/protos/xmpp/stream.rs b/src/protos/xmpp/stream.rs index 7da2a3a..da6465a 100644 --- a/src/protos/xmpp/stream.rs +++ b/src/protos/xmpp/stream.rs @@ -4,8 +4,10 @@ use quick_xml::name::{Namespace, QName, ResolveResult}; use quick_xml::{NsReader, Writer}; use tokio::io::{AsyncBufRead, AsyncWrite}; +use super::client::Message; use super::skip_text; use crate::prelude::*; +use crate::util::xml::{Continuation, Parser}; pub static XMLNS: &'static str = "http://etherx.jabber.org/streams"; pub static PREFIX: &'static str = "stream"; @@ -160,8 +162,43 @@ impl Features { } } +#[derive(PartialEq, Eq, Debug)] +pub enum FromClient { + Message(Message), +} +impl FromClient { + pub async fn parse( + reader: &mut NsReader, + buf: &mut Vec, + ) -> Result { + let incoming = skip_text!(reader, buf); + let start = if let Event::Start(ref bytes) = incoming { + bytes + } else { + return Err(ffail!("Unexpected XML event: {incoming:?}")); + }; + let (ns, name) = reader.resolve_element(start.name()); + if name.as_ref() == b"message" { + let mut parser = Message::parse().consume(&incoming); + let result = loop { + match parser { + Continuation::Final(res) => break res, + Continuation::Continue(next) => { + parser = next.consume(&reader.read_event_into_async(buf).await.unwrap()) + } + } + }?; + Ok(FromClient::Message(result)) + } else { + Err(ffail!("Unknown XML tag: {name:?}")) + } + } +} + #[cfg(test)] mod test { + use crate::protos::xmpp::client::MessageType; + use super::*; #[tokio::test] @@ -195,4 +232,24 @@ mod test { input.write_xml(&mut writer).await.unwrap(); assert_eq!(std::str::from_utf8(&output).unwrap(), expected); } + + #[tokio::test] + async fn client_message() { + let input = r#"daabbb"#; + let mut reader = NsReader::from_reader(input.as_bytes()); + let mut buf = vec![]; + let res = FromClient::parse(&mut reader, &mut buf).await.unwrap(); + assert_eq!( + res, + FromClient::Message(Message { + from: None, + id: Some("aacea".to_string()), + r#type: MessageType::Chat, + to: Some("nikita@vlnv.dev".to_string()), + lang: None, + subject: Some("daa".to_string()), + body: "bbb".to_string(), + }) + ) + } } diff --git a/src/util/mod.rs b/src/util/mod.rs index 5bbf59c..0593dbd 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -3,9 +3,9 @@ use crate::prelude::*; pub mod http; pub mod table; pub mod telemetry; -pub mod xml; #[cfg(test)] pub mod testkit; +pub mod xml; pub struct Terminator { signal: Promise<()>, diff --git a/src/util/xml.rs b/src/util/xml.rs index fa14873..2a9dadd 100644 --- a/src/util/xml.rs +++ b/src/util/xml.rs @@ -15,9 +15,9 @@ macro_rules! fail_fast { ($errorable: expr) => { match $errorable { Ok(i) => i, - Err(e) => return Continuation::Final(Err(e.into())) + Err(e) => return Continuation::Final(Err(e.into())), } }; } -pub(crate) use fail_fast; \ No newline at end of file +pub(crate) use fail_fast;