use std::borrow::Borrow; use quick_xml::{ events::{BytesStart, Event}, NsReader, Writer, }; use tokio::io::{AsyncBufRead, AsyncWrite}; use base64::{Engine as _, engine::general_purpose}; use super::skip_text; use anyhow::{anyhow, Result}; pub enum Mechanism { Plain, } impl Mechanism { pub fn to_str(&self) -> &'static str { match self { Mechanism::Plain => "PLAIN", } } pub fn from_str(input: &[u8]) -> Result { match input { b"PLAIN" => Ok(Mechanism::Plain), _ => Err(anyhow!("unknown auth mechanism: {input:?}")), } } } #[derive(PartialEq, Debug)] pub struct AuthBody { pub login: String, pub password: String, } impl AuthBody { pub fn from_str(input: &[u8]) -> Result { match general_purpose::STANDARD.decode(input){ Ok(decoded_body) => { match String::from_utf8(decoded_body) { Ok(parsed_to_string) => { let separated_words: Vec<&str> = parsed_to_string.split("\x00").collect::>().clone(); if separated_words.len() == 3 { // first segment ignored (might be needed in the future) Ok(AuthBody { login: separated_words[1].to_string(), password: separated_words[2].to_string() }) } else { return Err(anyhow!("Incorrect auth format")) } }, Err(e) => return Err(anyhow!(e)) } }, Err(e) => return Err(anyhow!(e)) } } } #[cfg(test)] mod test { use super::*; #[test] fn test_returning_auth_body() { let orig = b"\x00login\x00pass"; let encoded = general_purpose::STANDARD.encode(orig); let expected = AuthBody {login: "login".to_string(), password: "pass".to_string()}; let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); assert_eq!(expected, result); } #[test] fn test_ignoring_first_segment() { let orig = b"ignored\x00login\x00pass"; let encoded = general_purpose::STANDARD.encode(orig); let expected = AuthBody {login: "login".to_string(), password: "pass".to_string()}; let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); assert_eq!(expected, result); } #[test] fn test_returning_auth_body_with_empty_strings() { let orig = b"\x00\x00"; let encoded = general_purpose::STANDARD.encode(orig); let expected = AuthBody {login: "".to_string(), password: "".to_string()}; let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); assert_eq!(expected, result); } #[test] fn test_fail_if_size_less_then_3() { let orig = b"login\x00pass"; let encoded = general_purpose::STANDARD.encode(orig); let expected = AuthBody {login: "login".to_string(), password: "pass".to_string()}; let result = AuthBody::from_str(encoded.as_bytes()); assert!(result.is_err()); } #[test] fn test_fail_if_size_greater_then_3() { let orig = b"first\x00login\x00pass\x00other"; let encoded = general_purpose::STANDARD.encode(orig); let expected = AuthBody {login: "login".to_string(), password: "pass".to_string()}; let result = AuthBody::from_str(encoded.as_bytes()); assert!(result.is_err()); } } pub struct Auth { pub mechanism: Mechanism, pub body: Vec, } impl Auth { pub async fn parse( reader: &mut NsReader, buf: &mut Vec, ) -> Result { let event = skip_text!(reader, buf); let mechanism = if let Event::Start(bytes) = event { let mut mechanism = None; for attr in bytes.attributes() { let attr = attr?; if attr.key.0 == b"mechanism" { mechanism = Some(attr.value) } } if let Some(mechanism) = mechanism { Mechanism::from_str(mechanism.borrow())? } else { return Err(anyhow!("expected mechanism attribute in ")); } } else { return Err(anyhow!("expected start of ")); }; let body = if let Event::Text(text) = reader.read_event_into_async(buf).await? { text.into_inner().into_owned() } else { return Err(anyhow!("expected text body in ")); }; if let Event::End(_) = reader.read_event_into_async(buf).await? { //TODO } else { return Err(anyhow!("expected end of ")); }; Ok(Auth { mechanism, body }) } } pub struct Success; impl Success { pub async fn write_xml(&self, writer: &mut Writer) -> Result<()> { let event = BytesStart::new(r#"success xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#); writer.write_event_async(Event::Empty(event)).await?; Ok(()) } }