From ce3c49cff6d111435fdafcdbf445ac2a3300a70a Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Fri, 13 Oct 2023 16:54:08 +0200 Subject: [PATCH] move sasl into separate crate --- .gitea/workflows/test-pr.yml | 2 +- Cargo.lock | 10 ++- Cargo.toml | 2 + crates/projection-xmpp/Cargo.toml | 1 + crates/projection-xmpp/src/lib.rs | 8 +-- crates/proto-xmpp/Cargo.toml | 1 - crates/proto-xmpp/src/sasl.rs | 96 ++-------------------------- crates/sasl/Cargo.toml | 8 +++ crates/sasl/src/lib.rs | 103 ++++++++++++++++++++++++++++++ 9 files changed, 132 insertions(+), 99 deletions(-) create mode 100644 crates/sasl/Cargo.toml create mode 100644 crates/sasl/src/lib.rs diff --git a/.gitea/workflows/test-pr.yml b/.gitea/workflows/test-pr.yml index 080b65d..cb0b163 100644 --- a/.gitea/workflows/test-pr.yml +++ b/.gitea/workflows/test-pr.yml @@ -12,7 +12,7 @@ jobs: uses: https://github.com/actions-rs/cargo@v1 with: command: fmt - args: "--check -p mgmt-api -p lavina-core -p projection-irc" + args: "--check -p mgmt-api -p lavina-core -p projection-irc -p projection-xmpp -p sasl" - name: cargo check uses: https://github.com/actions-rs/cargo@v1 with: diff --git a/Cargo.lock b/Cargo.lock index 09a8b25..497a58b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1286,6 +1286,7 @@ dependencies = [ "proto-xmpp", "quick-xml", "rustls-pemfile", + "sasl", "serde", "tokio", "tokio-rustls", @@ -1325,7 +1326,6 @@ version = "0.0.2-dev" dependencies = [ "anyhow", "assert_matches", - "base64", "derive_more", "lazy_static", "quick-xml", @@ -1557,6 +1557,14 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +[[package]] +name = "sasl" +version = "0.0.2-dev" +dependencies = [ + "anyhow", + "base64", +] + [[package]] name = "scopeguard" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index 0953ce7..e6ee89e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "crates/projection-irc", "crates/proto-xmpp", "crates/mgmt-api", + "crates/sasl", ] [workspace.package] @@ -28,6 +29,7 @@ tracing = "0.1.37" # logging & tracing api prometheus = { version = "0.13.3", default-features = false } base64 = "0.21.3" lavina-core = { path = "crates/lavina-core" } +sasl = { path = "crates/sasl" } [package] name = "lavina" diff --git a/crates/projection-xmpp/Cargo.toml b/crates/projection-xmpp/Cargo.toml index c69cb12..aac2b71 100644 --- a/crates/projection-xmpp/Cargo.toml +++ b/crates/projection-xmpp/Cargo.toml @@ -13,6 +13,7 @@ prometheus.workspace = true futures-util.workspace = true quick-xml.workspace = true +sasl.workspace = true proto-xmpp = { path = "../proto-xmpp" } uuid = { version = "1.3.0", features = ["v4"] } tokio-rustls = "0.24.1" diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index 2a7e0ee..8855144 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -24,17 +24,17 @@ use tokio_rustls::TlsAcceptor; use lavina_core::player::{PlayerConnection, PlayerId, PlayerRegistry}; use lavina_core::prelude::*; +use lavina_core::repo::Storage; use lavina_core::room::{RoomId, RoomRegistry}; use lavina_core::terminator::Terminator; -use lavina_core::repo::Storage; use proto_xmpp::bind::{BindResponse, Jid, Name, Resource, Server}; use proto_xmpp::client::{Iq, Message, MessageType, Presence}; use proto_xmpp::disco::*; use proto_xmpp::roster::RosterQuery; -use proto_xmpp::sasl::AuthBody; use proto_xmpp::session::Session; use proto_xmpp::stream::*; use proto_xmpp::xml::{Continuation, FromXml, Parser, ToXml}; +use sasl::AuthBody; use self::proto::{ClientPacket, IqClientBody}; @@ -286,8 +286,8 @@ async fn socket_auth( xmpp_resource: Resource(name.to_string().into()), xmpp_muc_name: Resource(name.to_string().into()), }) - }, - Err(e) => return Err(e) + } + Err(e) => return Err(e), } } diff --git a/crates/proto-xmpp/Cargo.toml b/crates/proto-xmpp/Cargo.toml index df0f03c..01fec80 100644 --- a/crates/proto-xmpp/Cargo.toml +++ b/crates/proto-xmpp/Cargo.toml @@ -10,7 +10,6 @@ regex.workspace = true anyhow.workspace = true tokio.workspace = true derive_more.workspace = true -base64.workspace = true [dev-dependencies] assert_matches.workspace = true diff --git a/crates/proto-xmpp/src/sasl.rs b/crates/proto-xmpp/src/sasl.rs index 71f4945..e147962 100644 --- a/crates/proto-xmpp/src/sasl.rs +++ b/crates/proto-xmpp/src/sasl.rs @@ -1,14 +1,11 @@ use std::borrow::Borrow; -use quick_xml::{ - events::{BytesStart, Event}, - NsReader, Writer, -}; +use anyhow::{anyhow, Result}; +use quick_xml::events::{BytesStart, Event}; +use quick_xml::{NsReader, Writer}; use tokio::io::{AsyncBufRead, AsyncWrite}; -use base64::{Engine as _, engine::general_purpose}; use super::skip_text; -use anyhow::{anyhow, Result}; pub enum Mechanism { Plain, @@ -29,98 +26,13 @@ impl Mechanism { } } -#[derive(PartialEq, Debug)] -pub struct AuthBody { - pub login: String, - pub password: String, -} - -impl AuthBody { - pub fn from_str(input: &[u8]) -> Result { - match general_purpose::STANDARD.decode(input){ - Ok(decoded_body) => { - match String::from_utf8(decoded_body) { - Ok(parsed_to_string) => { - let separated_words: Vec<&str> = parsed_to_string.split("\x00").collect::>().clone(); - if separated_words.len() == 3 { - // first segment ignored (might be needed in the future) - Ok(AuthBody { login: separated_words[1].to_string(), password: separated_words[2].to_string() }) - } else { return Err(anyhow!("Incorrect auth format")) } - }, - Err(e) => return Err(anyhow!(e)) - } - }, - Err(e) => return Err(anyhow!(e)) - } - } -} - -#[cfg(test)] -mod test { - use super::*; - #[test] - fn test_returning_auth_body() { - let orig = b"\x00login\x00pass"; - let encoded = general_purpose::STANDARD.encode(orig); - let expected = AuthBody {login: "login".to_string(), password: "pass".to_string()}; - let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); - - assert_eq!(expected, result); - } - - #[test] - fn test_ignoring_first_segment() { - let orig = b"ignored\x00login\x00pass"; - let encoded = general_purpose::STANDARD.encode(orig); - let expected = AuthBody {login: "login".to_string(), password: "pass".to_string()}; - let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); - - assert_eq!(expected, result); - } - - #[test] - fn test_returning_auth_body_with_empty_strings() { - let orig = b"\x00\x00"; - let encoded = general_purpose::STANDARD.encode(orig); - let expected = AuthBody {login: "".to_string(), password: "".to_string()}; - let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); - - assert_eq!(expected, result); - } - - - #[test] - fn test_fail_if_size_less_then_3() { - let orig = b"login\x00pass"; - let encoded = general_purpose::STANDARD.encode(orig); - let expected = AuthBody {login: "login".to_string(), password: "pass".to_string()}; - let result = AuthBody::from_str(encoded.as_bytes()); - - assert!(result.is_err()); - } - - #[test] - fn test_fail_if_size_greater_then_3() { - let orig = b"first\x00login\x00pass\x00other"; - let encoded = general_purpose::STANDARD.encode(orig); - let expected = AuthBody {login: "login".to_string(), password: "pass".to_string()}; - let result = AuthBody::from_str(encoded.as_bytes()); - - assert!(result.is_err()); - } - -} - pub struct Auth { pub mechanism: Mechanism, pub body: Vec, } impl Auth { - pub async fn parse( - reader: &mut NsReader, - buf: &mut Vec, - ) -> Result { + pub async fn parse(reader: &mut NsReader, buf: &mut Vec) -> Result { let event = skip_text!(reader, buf); let mechanism = if let Event::Start(bytes) = event { let mut mechanism = None; diff --git a/crates/sasl/Cargo.toml b/crates/sasl/Cargo.toml new file mode 100644 index 0000000..59504ef --- /dev/null +++ b/crates/sasl/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "sasl" +edition = "2021" +version.workspace = true + +[dependencies] +anyhow.workspace = true +base64.workspace = true diff --git a/crates/sasl/src/lib.rs b/crates/sasl/src/lib.rs new file mode 100644 index 0000000..e00d67b --- /dev/null +++ b/crates/sasl/src/lib.rs @@ -0,0 +1,103 @@ +use anyhow::{anyhow, Result}; +use base64::engine::general_purpose; +use base64::Engine; + +#[derive(PartialEq, Debug)] +pub struct AuthBody { + pub login: String, + pub password: String, +} + +impl AuthBody { + pub fn from_str(input: &[u8]) -> Result { + match general_purpose::STANDARD.decode(input) { + Ok(decoded_body) => { + match String::from_utf8(decoded_body) { + Ok(parsed_to_string) => { + let separated_words: Vec<&str> = parsed_to_string.split("\x00").collect::>().clone(); + if separated_words.len() == 3 { + // first segment ignored (might be needed in the future) + Ok(AuthBody { + login: separated_words[1].to_string(), + password: separated_words[2].to_string(), + }) + } else { + return Err(anyhow!("Incorrect auth format")); + } + } + Err(e) => return Err(e.into()), + } + } + Err(e) => return Err(e.into()), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_returning_auth_body() { + let orig = b"\x00login\x00pass"; + let encoded = general_purpose::STANDARD.encode(orig); + let expected = AuthBody { + login: "login".to_string(), + password: "pass".to_string(), + }; + let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); + + assert_eq!(expected, result); + } + + #[test] + fn test_ignoring_first_segment() { + let orig = b"ignored\x00login\x00pass"; + let encoded = general_purpose::STANDARD.encode(orig); + let expected = AuthBody { + login: "login".to_string(), + password: "pass".to_string(), + }; + let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); + + assert_eq!(expected, result); + } + + #[test] + fn test_returning_auth_body_with_empty_strings() { + let orig = b"\x00\x00"; + let encoded = general_purpose::STANDARD.encode(orig); + let expected = AuthBody { + login: "".to_string(), + password: "".to_string(), + }; + let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); + + assert_eq!(expected, result); + } + + #[test] + fn test_fail_if_size_less_then_3() { + let orig = b"login\x00pass"; + let encoded = general_purpose::STANDARD.encode(orig); + let expected = AuthBody { + login: "login".to_string(), + password: "pass".to_string(), + }; + let result = AuthBody::from_str(encoded.as_bytes()); + + assert!(result.is_err()); + } + + #[test] + fn test_fail_if_size_greater_then_3() { + let orig = b"first\x00login\x00pass\x00other"; + let encoded = general_purpose::STANDARD.encode(orig); + let expected = AuthBody { + login: "login".to_string(), + password: "pass".to_string(), + }; + let result = AuthBody::from_str(encoded.as_bytes()); + + assert!(result.is_err()); + } +}