lavina/crates/proto-xmpp/src/sasl.rs

165 lines
5.0 KiB
Rust
Raw Normal View History

use std::borrow::Borrow;
use quick_xml::{
events::{BytesStart, Event},
NsReader, Writer,
};
use tokio::io::{AsyncBufRead, AsyncWrite};
use base64::{Engine as _, engine::general_purpose};
use super::skip_text;
use anyhow::{anyhow, Result};
/// SASL authentication mechanisms understood by this module.
///
/// Only `PLAIN` is currently recognised (see [`Mechanism::from_str`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mechanism {
    /// The `PLAIN` mechanism: credentials are sent as NUL-separated text.
    Plain,
}
impl Mechanism {
    /// Returns the wire name of the mechanism, as used in the `mechanism`
    /// attribute of `<auth>`.
    pub fn to_str(&self) -> &'static str {
        match self {
            Mechanism::Plain => "PLAIN",
        }
    }

    /// Parses a mechanism name taken from the `mechanism` attribute of `<auth>`.
    ///
    /// Matching is exact and case-sensitive (`b"plain"` is rejected).
    ///
    /// # Errors
    /// Returns an error naming the offending value when it is not a supported
    /// mechanism.
    pub fn from_str(input: &[u8]) -> Result<Mechanism> {
        match input {
            b"PLAIN" => Ok(Mechanism::Plain),
            // Render the name as (lossy) text; `{:?}` on `&[u8]` would print a
            // list of byte values, which is useless in logs.
            _ => Err(anyhow!(
                "unknown auth mechanism: {:?}",
                String::from_utf8_lossy(input)
            )),
        }
    }
}
/// Credentials extracted from a SASL `PLAIN` initial response.
///
/// `Debug` is implemented by hand so that the password never ends up in logs
/// or panic messages; equality still compares both fields.
#[derive(PartialEq)]
pub struct AuthBody {
    /// Login name (the second NUL-separated segment of the decoded response).
    pub login: String,
    /// Plain-text password (the third NUL-separated segment).
    pub password: String,
}

impl std::fmt::Debug for AuthBody {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthBody")
            .field("login", &self.login)
            .field("password", &"<redacted>")
            .finish()
    }
}
impl AuthBody {
    /// Decodes a SASL `PLAIN` initial response into an [`AuthBody`].
    ///
    /// `input` is base64 text (standard alphabet, with padding). After decoding it
    /// must be UTF-8 and split on NUL into exactly three segments:
    /// `authzid NUL login NUL password`. The first segment is currently ignored
    /// (it might be needed in the future); empty login/password are accepted.
    ///
    /// # Errors
    /// Returns an error if `input` is not valid base64, the decoded bytes are not
    /// valid UTF-8, or the decoded text does not have exactly three segments.
    pub fn from_str(input: &[u8]) -> Result<AuthBody> {
        let decoded = general_purpose::STANDARD.decode(input)?;
        let text = String::from_utf8(decoded)?;
        let segments: Vec<&str> = text.split('\x00').collect();
        match segments.as_slice() {
            &[_authzid, login, password] => Ok(AuthBody {
                login: login.to_owned(),
                password: password.to_owned(),
            }),
            _ => Err(anyhow!("Incorrect auth format")),
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_returning_auth_body() {
        let orig = b"\x00login\x00pass";
        let encoded = general_purpose::STANDARD.encode(orig);
        let expected = AuthBody { login: "login".to_string(), password: "pass".to_string() };
        let result = AuthBody::from_str(encoded.as_bytes()).unwrap();
        assert_eq!(expected, result);
    }

    #[test]
    fn test_ignoring_first_segment() {
        let orig = b"ignored\x00login\x00pass";
        let encoded = general_purpose::STANDARD.encode(orig);
        let expected = AuthBody { login: "login".to_string(), password: "pass".to_string() };
        let result = AuthBody::from_str(encoded.as_bytes()).unwrap();
        assert_eq!(expected, result);
    }

    #[test]
    fn test_returning_auth_body_with_empty_strings() {
        let orig = b"\x00\x00";
        let encoded = general_purpose::STANDARD.encode(orig);
        let expected = AuthBody { login: "".to_string(), password: "".to_string() };
        let result = AuthBody::from_str(encoded.as_bytes()).unwrap();
        assert_eq!(expected, result);
    }

    #[test]
    fn test_fail_if_size_less_than_3() {
        // Only two NUL-separated segments: the authzid slot is missing.
        let orig = b"login\x00pass";
        let encoded = general_purpose::STANDARD.encode(orig);
        let result = AuthBody::from_str(encoded.as_bytes());
        assert!(result.is_err());
    }

    #[test]
    fn test_fail_if_size_greater_than_3() {
        let orig = b"first\x00login\x00pass\x00other";
        let encoded = general_purpose::STANDARD.encode(orig);
        let result = AuthBody::from_str(encoded.as_bytes());
        assert!(result.is_err());
    }

    #[test]
    fn test_fail_on_invalid_base64() {
        // Spaces and '!' are outside the base64 alphabet.
        let result = AuthBody::from_str(b"not base64!");
        assert!(result.is_err());
    }

    #[test]
    fn test_fail_on_invalid_utf8() {
        let orig = b"\x00\xff\x00pass";
        let encoded = general_purpose::STANDARD.encode(orig);
        let result = AuthBody::from_str(encoded.as_bytes());
        assert!(result.is_err());
    }

    #[test]
    fn test_mechanism_round_trip() {
        let mechanism = Mechanism::from_str(b"PLAIN").unwrap();
        assert!(matches!(mechanism, Mechanism::Plain));
        assert_eq!(mechanism.to_str(), "PLAIN");
    }

    #[test]
    fn test_mechanism_rejects_unknown_names() {
        // Matching is case-sensitive and only PLAIN is supported.
        assert!(Mechanism::from_str(b"plain").is_err());
        assert!(Mechanism::from_str(b"SCRAM-SHA-1").is_err());
    }
}
/// A SASL `<auth>` request as read by [`Auth::parse`].
pub struct Auth {
    /// Mechanism named by the element's `mechanism` attribute.
    pub mechanism: Mechanism,
    /// Raw character data of the element, copied verbatim (not unescaped and not
    /// decoded). For `PLAIN` this is presumably the base64 text that
    /// `AuthBody::from_str` expects — confirm at call sites.
    pub body: Vec<u8>,
}
impl Auth {
    /// Reads one `<auth mechanism="...">TEXT</auth>` element from `reader`.
    ///
    /// The first event comes from the parent module's `skip_text!` macro
    /// (presumably skipping leading text/whitespace). It is followed by exactly
    /// one text event and one end event, read with `buf` as scratch space. The
    /// text is returned as-is in [`Auth::body`].
    ///
    /// NOTE(review): neither the start tag's name nor the end tag's name is
    /// compared against `auth` here. Any start tag carrying a `mechanism`
    /// attribute is accepted.
    ///
    /// # Errors
    /// Fails when the reader errors, an attribute is malformed, the `mechanism`
    /// attribute is missing or names an unsupported mechanism, or the events do
    /// not arrive as start / text / end.
    pub async fn parse(
        reader: &mut NsReader<impl AsyncBufRead + Unpin>,
        buf: &mut Vec<u8>,
    ) -> Result<Auth> {
        let opening = skip_text!(reader, buf);
        let mechanism = match opening {
            Event::Start(tag) => {
                // Every attribute is validated; if `mechanism` appears more than
                // once, the last occurrence wins.
                let mut requested = None;
                for attribute in tag.attributes() {
                    let attribute = attribute?;
                    if attribute.key.0 == b"mechanism" {
                        requested = Some(attribute.value);
                    }
                }
                match requested {
                    Some(name) => Mechanism::from_str(name.borrow())?,
                    None => return Err(anyhow!("expected mechanism attribute in <auth>")),
                }
            }
            _ => return Err(anyhow!("expected start of <auth>")),
        };

        let body = match reader.read_event_into_async(buf).await? {
            Event::Text(text) => text.into_inner().into_owned(),
            _ => return Err(anyhow!("expected text body in <auth>")),
        };

        match reader.read_event_into_async(buf).await? {
            Event::End(_) => Ok(Auth { mechanism, body }),
            _ => Err(anyhow!("expected end of <auth>")),
        }
    }
}
/// The SASL `<success/>` element in the `urn:ietf:params:xml:ns:xmpp-sasl`
/// namespace.
pub struct Success;
impl Success {
    /// Writes `<success xmlns="urn:ietf:params:xml:ns:xmpp-sasl"/>` to `writer`.
    ///
    /// # Errors
    /// Propagates any error returned by the underlying writer.
    pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
        // Attach the namespace as a real attribute rather than embedding it in the
        // tag name. The serialized bytes are the same, but the event's name is now
        // just `success`.
        let event = BytesStart::new("success")
            .with_attributes([("xmlns", "urn:ietf:params:xml:ns:xmpp-sasl")]);
        writer.write_event_async(Event::Empty(event)).await?;
        Ok(())
    }
}