diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index 5468c19..fe56481 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -296,20 +296,19 @@ async fn socket_auth( xml_writer.get_mut().flush().await?; let auth: proto_xmpp::sasl::Auth = proto_xmpp::sasl::Auth::parse(xml_reader, reader_buf).await?; - proto_xmpp::sasl::Success.write_xml(xml_writer).await?; - xml_writer.get_mut().flush().await?; match AuthBody::from_str(&auth.body) { Ok(logopass) => { let name = &logopass.login; let verdict = Authenticator::new(storage).authenticate(name, &logopass.password).await?; - // TODO return proper XML errors to the client match verdict { - Verdict::Authenticated => {} - Verdict::UserNotFound => { - return Err(anyhow!("no user found")); + Verdict::Authenticated => { + proto_xmpp::sasl::Success.write_xml(xml_writer).await?; + xml_writer.get_mut().flush().await?; } - Verdict::InvalidPassword => { + Verdict::UserNotFound | Verdict::InvalidPassword => { + proto_xmpp::sasl::Failure.write_xml(xml_writer).await?; + xml_writer.get_mut().flush().await?; return Err(anyhow!("incorrect credentials")); } } diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index ef8c046..8c05128 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -240,6 +240,73 @@ async fn scenario_basic() -> Result<()> { Ok(()) } +#[tokio::test] +async fn scenario_wrong_password() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + tracing::info!("TCP connection established"); + + s.send(r#""#).await?; + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); + let buffer = s.buffer; + tracing::info!("TLS feature negotiation complete"); + + let connector = TlsConnector::from(Arc::new( + ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(Arc::new(IgnoreCertVerification)) + .with_no_client_auth(), + )); + tracing::info!("Initiating TLS connection..."); + let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?; + tracing::info!("TLS connection established"); + + let mut s = TestScopeTls::new(&mut stream, buffer); + + s.send(r#""#).await?; + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"mechanisms")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"mechanism")); + assert_matches!(s.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"PLAIN")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"mechanism")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"mechanisms")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + + // base64-encoded b"\x00tester\x00password2" + s.send(r#"AHRlc3RlcgBwYXNzd29yZDI="#) + .await?; + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"failure")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"not-authorized")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"failure")); + + stream.shutdown().await?; + + // wrap up + + server.shutdown().await?; + Ok(()) +} + #[tokio::test] async fn scenario_basic_without_headers() -> Result<()> { let mut server = TestServer::start().await?; diff --git a/crates/proto-xmpp/src/sasl.rs b/crates/proto-xmpp/src/sasl.rs index e147962..b042f09 100644 --- a/crates/proto-xmpp/src/sasl.rs +++ b/crates/proto-xmpp/src/sasl.rs @@ -1,7 +1,7 @@ use std::borrow::Borrow; use anyhow::{anyhow, Result}; -use quick_xml::events::{BytesStart, Event}; +use quick_xml::events::{BytesEnd, BytesStart, Event}; use quick_xml::{NsReader, Writer}; use tokio::io::{AsyncBufRead, AsyncWrite}; @@ -74,3 +74,16 @@ impl Success { Ok(()) } } + +pub struct Failure; +impl Failure { + pub async fn write_xml(&self, writer: &mut Writer) -> Result<()> { + let event = BytesStart::new(r#"failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#); + writer.write_event_async(Event::Start(event)).await?; + let event = BytesStart::new(r#"not-authorized"#); + writer.write_event_async(Event::Empty(event)).await?; + let event = BytesEnd::new(r#"failure"#); + writer.write_event_async(Event::End(event)).await?; + Ok(()) + } +}