diff --git a/.gitea/workflows/test-pr.yml b/.gitea/workflows/test-pr.yml index 080b65d..cb0b163 100644 --- a/.gitea/workflows/test-pr.yml +++ b/.gitea/workflows/test-pr.yml @@ -12,7 +12,7 @@ jobs: uses: https://github.com/actions-rs/cargo@v1 with: command: fmt - args: "--check -p mgmt-api -p lavina-core -p projection-irc" + args: "--check -p mgmt-api -p lavina-core -p projection-irc -p projection-xmpp -p sasl" - name: cargo check uses: https://github.com/actions-rs/cargo@v1 with: diff --git a/Cargo.lock b/Cargo.lock index f838968..8461513 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1287,6 +1287,7 @@ dependencies = [ "proto-xmpp", "quick-xml", "rustls-pemfile", + "sasl", "serde", "tokio", "tokio-rustls", @@ -1327,7 +1328,6 @@ version = "0.0.2-dev" dependencies = [ "anyhow", "assert_matches", - "base64", "derive_more", "lazy_static", "quick-xml", @@ -1559,6 +1559,14 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +[[package]] +name = "sasl" +version = "0.0.2-dev" +dependencies = [ + "anyhow", + "base64", +] + [[package]] name = "scopeguard" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index 07a6784..3412188 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "crates/projection-irc", "crates/proto-xmpp", "crates/mgmt-api", + "crates/sasl", ] [workspace.package] @@ -29,6 +30,7 @@ prometheus = { version = "0.13.3", default-features = false } base64 = "0.21.3" lavina-core = { path = "crates/lavina-core" } tracing-subscriber = "0.3.16" +sasl = { path = "crates/sasl" } [package] name = "lavina" diff --git a/crates/projection-xmpp/Cargo.toml b/crates/projection-xmpp/Cargo.toml index 6395b29..ed6502f 100644 --- a/crates/projection-xmpp/Cargo.toml +++ b/crates/projection-xmpp/Cargo.toml @@ -13,6 +13,7 @@ prometheus.workspace = true futures-util.workspace = true quick-xml.workspace = true +sasl.workspace = true proto-xmpp = { path = "../proto-xmpp" } uuid = { version = "1.3.0", features = ["v4"] } tokio-rustls = { version = "0.24.1", features = ["dangerous_configuration"] } diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index e4da283..9192092 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -28,9 +28,9 @@ use lavina_core::repo::Storage; use lavina_core::room::RoomRegistry; use lavina_core::terminator::Terminator; use proto_xmpp::bind::{Name, Resource}; -use proto_xmpp::sasl::AuthBody; use proto_xmpp::stream::*; use proto_xmpp::xml::{Continuation, FromXml, Parser, ToXml}; +use sasl::AuthBody; use self::proto::ClientPacket; @@ -249,9 +249,7 @@ async fn socket_auth( read_xml_header(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; - xml_writer - .write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))) - .await?; + xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?; ServerStreamStart { from: "localhost".into(), lang: "en".into(), @@ -317,9 +315,7 @@ async fn socket_final( read_xml_header(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; - xml_writer - .write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))) - .await?; + xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?; ServerStreamStart { from: "localhost".into(), lang: "en".into(), diff --git a/crates/projection-xmpp/src/message.rs b/crates/projection-xmpp/src/message.rs index 9369076..44aab05 100644 --- a/crates/projection-xmpp/src/message.rs +++ b/crates/projection-xmpp/src/message.rs @@ -19,9 +19,7 @@ impl<'a> XmppConnection<'a> { }) = m.to { if server.0.as_ref() == "rooms.localhost" && m.r#type == MessageType::Groupchat { - self.user_handle - .send_message(RoomId::from(name.0.clone())?, m.body.clone().into()) - .await?; + self.user_handle.send_message(RoomId::from(name.0.clone())?, m.body.clone().into()).await?; Message::<()> { to: Some(Jid { name: Some(self.user.xmpp_name.clone()), diff --git a/crates/proto-xmpp/Cargo.toml b/crates/proto-xmpp/Cargo.toml index df0f03c..01fec80 100644 --- a/crates/proto-xmpp/Cargo.toml +++ b/crates/proto-xmpp/Cargo.toml @@ -10,7 +10,6 @@ regex.workspace = true anyhow.workspace = true tokio.workspace = true derive_more.workspace = true -base64.workspace = true [dev-dependencies] assert_matches.workspace = true diff --git a/crates/proto-xmpp/src/sasl.rs b/crates/proto-xmpp/src/sasl.rs index 71f4945..e147962 100644 --- a/crates/proto-xmpp/src/sasl.rs +++ b/crates/proto-xmpp/src/sasl.rs @@ -1,14 +1,11 @@ use std::borrow::Borrow; -use quick_xml::{ - events::{BytesStart, Event}, - NsReader, Writer, -}; +use anyhow::{anyhow, Result}; +use quick_xml::events::{BytesStart, Event}; +use quick_xml::{NsReader, Writer}; use tokio::io::{AsyncBufRead, AsyncWrite}; -use base64::{Engine as _, engine::general_purpose}; use super::skip_text; -use anyhow::{anyhow, Result}; pub enum Mechanism { Plain, @@ -29,98 +26,13 @@ impl Mechanism { } } -#[derive(PartialEq, Debug)] -pub struct AuthBody { - pub login: String, - pub password: String, -} - -impl AuthBody { - pub fn from_str(input: &[u8]) -> Result { - match general_purpose::STANDARD.decode(input){ - Ok(decoded_body) => { - match String::from_utf8(decoded_body) { - Ok(parsed_to_string) => { - let separated_words: Vec<&str> = parsed_to_string.split("\x00").collect::>().clone(); - if separated_words.len() == 3 { - // first segment ignored (might be needed in the future) - Ok(AuthBody { login: separated_words[1].to_string(), password: separated_words[2].to_string() }) - } else { return Err(anyhow!("Incorrect auth format")) } - }, - Err(e) => return Err(anyhow!(e)) - } - }, - Err(e) => return Err(anyhow!(e)) - } - } -} - -#[cfg(test)] -mod test { - use super::*; - #[test] - fn test_returning_auth_body() { - let orig = b"\x00login\x00pass"; - let encoded = general_purpose::STANDARD.encode(orig); - let expected = AuthBody {login: "login".to_string(), password: "pass".to_string()}; - let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); - - assert_eq!(expected, result); - } - - #[test] - fn test_ignoring_first_segment() { - let orig = b"ignored\x00login\x00pass"; - let encoded = general_purpose::STANDARD.encode(orig); - let expected = AuthBody {login: "login".to_string(), password: "pass".to_string()}; - let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); - - assert_eq!(expected, result); - } - - #[test] - fn test_returning_auth_body_with_empty_strings() { - let orig = b"\x00\x00"; - let encoded = general_purpose::STANDARD.encode(orig); - let expected = AuthBody {login: "".to_string(), password: "".to_string()}; - let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); - - assert_eq!(expected, result); - } - - - #[test] - fn test_fail_if_size_less_then_3() { - let orig = b"login\x00pass"; - let encoded = general_purpose::STANDARD.encode(orig); - let expected = AuthBody {login: "login".to_string(), password: "pass".to_string()}; - let result = AuthBody::from_str(encoded.as_bytes()); - - assert!(result.is_err()); - } - - #[test] - fn test_fail_if_size_greater_then_3() { - let orig = b"first\x00login\x00pass\x00other"; - let encoded = general_purpose::STANDARD.encode(orig); - let expected = AuthBody {login: "login".to_string(), password: "pass".to_string()}; - let result = AuthBody::from_str(encoded.as_bytes()); - - assert!(result.is_err()); - } - -} - pub struct Auth { pub mechanism: Mechanism, pub body: Vec, } impl Auth { - pub async fn parse( - reader: &mut NsReader, - buf: &mut Vec, - ) -> Result { + pub async fn parse(reader: &mut NsReader, buf: &mut Vec) -> Result { let event = skip_text!(reader, buf); let mechanism = if let Event::Start(bytes) = event { let mut mechanism = None; diff --git a/crates/sasl/Cargo.toml b/crates/sasl/Cargo.toml new file mode 100644 index 0000000..59504ef --- /dev/null +++ b/crates/sasl/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "sasl" +edition = "2021" +version.workspace = true + +[dependencies] +anyhow.workspace = true +base64.workspace = true diff --git a/crates/sasl/src/lib.rs b/crates/sasl/src/lib.rs new file mode 100644 index 0000000..e00d67b --- /dev/null +++ b/crates/sasl/src/lib.rs @@ -0,0 +1,103 @@ +use anyhow::{anyhow, Result}; +use base64::engine::general_purpose; +use base64::Engine; + +#[derive(PartialEq, Debug)] +pub struct AuthBody { + pub login: String, + pub password: String, +} + +impl AuthBody { + pub fn from_str(input: &[u8]) -> Result { + match general_purpose::STANDARD.decode(input) { + Ok(decoded_body) => { + match String::from_utf8(decoded_body) { + Ok(parsed_to_string) => { + let separated_words: Vec<&str> = parsed_to_string.split("\x00").collect::>().clone(); + if separated_words.len() == 3 { + // first segment ignored (might be needed in the future) + Ok(AuthBody { + login: separated_words[1].to_string(), + password: separated_words[2].to_string(), + }) + } else { + return Err(anyhow!("Incorrect auth format")); + } + } + Err(e) => return Err(e.into()), + } + } + Err(e) => return Err(e.into()), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_returning_auth_body() { + let orig = b"\x00login\x00pass"; + let encoded = general_purpose::STANDARD.encode(orig); + let expected = AuthBody { + login: "login".to_string(), + password: "pass".to_string(), + }; + let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); + + assert_eq!(expected, result); + } + + #[test] + fn test_ignoring_first_segment() { + let orig = b"ignored\x00login\x00pass"; + let encoded = general_purpose::STANDARD.encode(orig); + let expected = AuthBody { + login: "login".to_string(), + password: "pass".to_string(), + }; + let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); + + assert_eq!(expected, result); + } + + #[test] + fn test_returning_auth_body_with_empty_strings() { + let orig = b"\x00\x00"; + let encoded = general_purpose::STANDARD.encode(orig); + let expected = AuthBody { + login: "".to_string(), + password: "".to_string(), + }; + let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); + + assert_eq!(expected, result); + } + + #[test] + fn test_fail_if_size_less_then_3() { + let orig = b"login\x00pass"; + let encoded = general_purpose::STANDARD.encode(orig); + let expected = AuthBody { + login: "login".to_string(), + password: "pass".to_string(), + }; + let result = AuthBody::from_str(encoded.as_bytes()); + + assert!(result.is_err()); + } + + #[test] + fn test_fail_if_size_greater_then_3() { + let orig = b"first\x00login\x00pass\x00other"; + let encoded = general_purpose::STANDARD.encode(orig); + let expected = AuthBody { + login: "login".to_string(), + password: "pass".to_string(), + }; + let result = AuthBody::from_str(encoded.as_bytes()); + + assert!(result.is_err()); + } +} diff --git a/rust-toolchain b/rust-toolchain index e1143b5..4538c78 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly-2023-11-20 +nightly-2023-12-07