// Integration tests for the XMPP projection.
//
// Each test starts a `TestServer` bound to 127.0.0.1 on port 0 and connects to the
// address it reports. The test then drives the client side of the protocol by hand,
// asserting on the exact sequence of XML events the server emits.
//
// NOTE(review): in this copy of the file, generic argument lists (e.g. `Result` with no
// type parameter, `BufReader>`, `Vec,`) and the XML payloads of the `r#"..."#` literals
// appear to be missing. The tokens below are kept exactly as found; restore them from
// version control before relying on this text.
use std::io::ErrorKind;
use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use assert_matches::*;
use prometheus::Registry as MetricsRegistry;
use quick_xml::events::Event;
use quick_xml::NsReader;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::io::{ReadHalf as GenericReadHalf, WriteHalf as GenericWriteHalf};
use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use tokio_rustls::rustls::client::ServerCertVerifier;
use tokio_rustls::rustls::{ClientConfig, ServerName};
use tokio_rustls::TlsConnector;

use lavina_core::auth::Authenticator;
use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::LavinaCore;
use projection_xmpp::{launch, RunningServer, ServerConfig};
use proto_xmpp::xml::{Continuation, FromXml, Parser};

/// Reads bytes from `reader` up to and including the next `b'\n'`, appending them to `buf`.
///
/// Returns the byte count reported by `read_until`.
///
/// NOTE(review): no test in this file calls this helper, and its name suggests it was
/// copied from the IRC test suite. Consider removing it.
pub async fn read_irc_message(reader: &mut BufReader>, buf: &mut Vec) -> Result {
    let mut size = 0;
    let res = reader.read_until(b'\n', buf).await?;
    size += res;
    return Ok(size);
}

/// Client-side view of a plaintext TCP connection, used before the STARTTLS upgrade.
struct TestScope<'a> {
    /// Namespace-aware XML reader over the buffered read half of the socket.
    reader: NsReader>>,
    /// Write half of the socket; `send` writes raw stanza text here.
    writer: WriteHalf<'a>,
    /// Scratch buffer for `read_event_into_async`; cleared before every read.
    buffer: Vec,
}

/// Returns the local (prefix-less) name of a start/empty tag as a `&str`.
///
/// Panics if the name is not valid UTF-8.
fn element_name<'a>(event: &quick_xml::events::BytesStart<'a>) -> &'a str {
    std::str::from_utf8(event.local_name().into_inner()).unwrap()
}

impl<'a> TestScope<'a> {
    /// Splits `stream` into read/write halves and wraps the read half in an XML reader.
    fn new(stream: &mut TcpStream) -> TestScope<'_> {
        let (reader, writer) = stream.split();
        let reader = NsReader::from_reader(BufReader::new(reader));
        let buffer = vec![];
        TestScope { reader, writer, buffer }
    }

    /// Writes `str` verbatim to the socket and flushes it.
    async fn send(&mut self, str: &str) -> Result<()> {
        self.writer.write_all(str.as_bytes()).await?;
        self.writer.flush().await?;
        Ok(())
    }

    /// Reads the next XML event from the server.
    ///
    /// Unlike `TestScopeTls::next_xml_event`, no timeout is applied here. A server that
    /// stops responding makes the test hang rather than fail.
    async fn next_xml_event(&mut self) -> Result> {
        self.buffer.clear();
        let event = self.reader.read_event_into_async(&mut self.buffer).await?;
        Ok(event)
    }

    /// Asserts that the next five events are
    /// `<features><starttls><required/></starttls></features>`.
    async fn expect_starttls_required(&mut self) -> Result<()> {
        assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b), "features"));
        assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
        assert_matches!(self.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required"));
        assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
        assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
        Ok(())
    }
}

/// Client-side view of the connection after the TLS handshake.
struct TestScopeTls<'a> {
    /// Namespace-aware XML reader over the buffered read half of the TLS stream.
    reader: NsReader>>>,
    /// Write half of the TLS stream, obtained via `tokio::io::split`.
    writer: GenericWriteHalf<&'a mut TlsStream>,
    /// Scratch buffer for `read_event_into_async`; cleared before every read.
    buffer: Vec,
    /// Maximum time `next_xml_event` waits for a single event (`new` sets 500 ms).
    pub timeout: Duration,
}

impl<'a> TestScopeTls<'a> {
    /// Wraps an established TLS stream.
    ///
    /// The tests pass in the `buffer` taken from the plaintext `TestScope`. Because
    /// `next_xml_event` clears it before every read, only its allocation carries over.
    fn new(stream: &'a mut TlsStream, buffer: Vec) -> TestScopeTls<'a> {
        let (reader, writer) = tokio::io::split(stream);
        let reader = NsReader::from_reader(BufReader::new(reader));
        let timeout = Duration::from_millis(500);
        TestScopeTls {
            reader,
            writer,
            buffer,
            timeout,
        }
    }

    /// Writes `str` verbatim to the TLS stream and flushes it.
    async fn send(&mut self, str: &str) -> Result<()> {
        self.writer.write_all(str.as_bytes()).await?;
        self.writer.flush().await?;
        Ok(())
    }

    /// Asserts that the server advertises exactly one SASL mechanism, `PLAIN`:
    /// `<features><mechanisms><mechanism>PLAIN</mechanism></mechanisms></features>`.
    async fn expect_auth_mechanisms(&mut self) -> Result<()> {
        assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features"));
        assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"mechanisms"));
        assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"mechanism"));
        assert_matches!(self.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"PLAIN"));
        assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"mechanism"));
        assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"mechanisms"));
        assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
        Ok(())
    }

    /// Asserts that the next three events are `<features><bind/></features>`.
    async fn expect_bind_feature(&mut self) -> Result<()> {
        assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features"));
        assert_matches!(self.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"bind"));
        assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
        Ok(())
    }

    /// Reads the next XML event, failing if none arrives within `self.timeout`.
    async fn next_xml_event(&mut self) -> Result> {
        self.buffer.clear();
        let event = self.reader.read_event_into_async(&mut self.buffer);
        // First `?`: the timer elapsed. Second `?`: the XML reader itself failed.
        let event = tokio::time::timeout(self.timeout, event).await??;
        Ok(event)
    }
}

/// Certificate verifier that accepts every server certificate without inspecting it.
///
/// Test-only. The server is configured with the certificate in `tests/certs/`
/// (presumably self-signed; TODO confirm), which normal verification would reject.
struct IgnoreCertVerification;

impl ServerCertVerifier for IgnoreCertVerification {
    /// Unconditionally reports the certificate as verified.
    fn verify_server_cert(
        &self,
        _end_entity: &tokio_rustls::rustls::Certificate,
        _intermediates: &[tokio_rustls::rustls::Certificate],
        _server_name: &ServerName,
        _scts: &mut dyn Iterator,
        _ocsp_response: &[u8],
        _now: std::time::SystemTime,
    ) -> std::result::Result {
        Ok(tokio_rustls::rustls::client::ServerCertVerified::assertion())
    }
}

/// A running XMPP projection together with the core and storage it was launched with.
struct TestServer {
    metrics: MetricsRegistry,
    storage: Storage,
    core: LavinaCore,
    server: RunningServer,
}

impl TestServer {
    /// Launches a fresh server on 127.0.0.1 (port 0) with hostname `localhost`.
    ///
    /// Storage is opened with `db_path: ":memory:"`, presumably an in-memory database
    /// so that each test starts empty (confirm against `Storage::open`). Panics if
    /// `launch` fails.
    async fn start() -> Result {
        // `try_init` fails when a global subscriber is already installed (e.g. by another
        // test in this binary); that error is deliberately ignored.
        let _ = tracing_subscriber::fmt::try_init();
        let config = ServerConfig {
            listen_on: "127.0.0.1:0".parse().unwrap(),
            cert: "tests/certs/xmpp.pem".parse().unwrap(),
            key: "tests/certs/xmpp.key".parse().unwrap(),
            hostname: "localhost".into(),
        };
        let metrics = MetricsRegistry::new();
        let storage = Storage::open(StorageConfig {
            db_path: ":memory:".into(),
        })
        .await?;
        let core = LavinaCore::new(metrics.clone(), storage.clone()).await?;
        let server = launch(config, core.clone(), metrics.clone(), storage.clone()).await.unwrap();
        Ok(TestServer {
            metrics,
            storage,
            core,
            server,
        })
    }

    /// Stops everything in dependency order: the server, then the core, then storage.
    async fn shutdown(self) -> Result<()> {
        self.server.terminate().await?;
        self.core.shutdown().await?;
        self.storage.close().await?;
        Ok(())
    }
}

/// Happy path through the full flow:
/// 1. Open a stream and perform the STARTTLS upgrade.
/// 2. Authenticate with SASL PLAIN.
/// 3. Restart the stream and bind a resource.
/// 4. Send a final stanza containing the text `Logged out`.
#[tokio::test]
async fn scenario_basic() -> Result<()> {
    let mut server = TestServer::start().await?;

    // test scenario

    server.storage.create_user("tester").await?;
    Authenticator::new(&server.storage).set_password("tester", "password").await?;

    let mut stream = TcpStream::connect(server.server.addr).await?;
    let mut s = TestScope::new(&mut stream);
    tracing::info!("TCP connection established");

    // Open the plaintext stream; the server must answer with a declaration, a stream
    // start element and a mandatory STARTTLS feature.
    s.send(r#""#).await?;
    s.send(r#""#).await?;
    assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
    assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
    s.expect_starttls_required().await?;
    s.send(r#""#).await?;
    assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed"));
    // Move the buffer out so the plaintext scope's borrow of `stream` ends before the
    // stream is handed to the TLS connector.
    let buffer = s.buffer;
    tracing::info!("TLS feature negotiation complete");

    let connector = TlsConnector::from(Arc::new(
        ClientConfig::builder()
            .with_safe_defaults()
            .with_custom_certificate_verifier(Arc::new(IgnoreCertVerification))
            .with_no_client_auth(),
    ));
    tracing::info!("Initiating TLS connection...");
    let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
    tracing::info!("TLS connection established");

    let mut s = TestScopeTls::new(&mut stream, buffer);

    // Re-open the stream over TLS; the server must now offer SASL PLAIN.
    s.send(r#""#).await?;
    s.send(r#""#).await?;
    assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
    assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
    s.expect_auth_mechanisms().await?;

    // base64-encoded b"\x00tester\x00password"
    s.send(r#"AHRlc3RlcgBwYXNzd29yZA=="#)
        .await?;
    assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"success"));

    // Restart the stream after authentication; the server must offer resource binding.
    s.send(r#""#).await?;
    s.send(r#""#).await?;
    assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
    assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
    s.expect_bind_feature().await?;
    // NOTE(review): the bind request literal carries `kek`, yet the asserted JID resource
    // below is `tester`. Confirm that this is the intended server behaviour.
    s.send(r#"kek"#).await?;
    assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"iq"));
    assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"bind"));
    assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"jid"));
    assert_matches!(s.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"tester@localhost/tester"));
    assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"jid"));
    assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"bind"));
    assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"iq"));
    s.send(r#"Logged out"#).await?;

    stream.shutdown().await?;

    // wrap up

    server.shutdown().await?;
    Ok(())
}

/// SASL PLAIN with a wrong password must be answered with
/// `<failure><not-authorized/></failure>`.
#[tokio::test]
async fn scenario_wrong_password() -> Result<()> {
    let mut server = TestServer::start().await?;

    // test scenario

    server.storage.create_user("tester").await?;
    Authenticator::new(&server.storage).set_password("tester", "password").await?;

    let mut stream = TcpStream::connect(server.server.addr).await?;
    let mut s = TestScope::new(&mut stream);
    tracing::info!("TCP connection established");

    s.send(r#""#).await?;
    s.send(r#""#).await?;
    assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
    assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
    s.expect_starttls_required().await?;
    s.send(r#""#).await?;
    assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed"));
    let buffer = s.buffer;
    tracing::info!("TLS feature negotiation complete");

    let connector = TlsConnector::from(Arc::new(
        ClientConfig::builder()
            .with_safe_defaults()
            .with_custom_certificate_verifier(Arc::new(IgnoreCertVerification))
            .with_no_client_auth(),
    ));
    tracing::info!("Initiating TLS connection...");
    let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
    tracing::info!("TLS connection established");
    let mut s = TestScopeTls::new(&mut stream, buffer);

    s.send(r#""#).await?;
    s.send(r#""#).await?;
    assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
    assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
    s.expect_auth_mechanisms().await?;
    // base64-encoded b"\x00tester\x00password2"
    s.send(r#"AHRlc3RlcgBwYXNzd29yZDI="#)
        .await?;
    assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"failure"));
    assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"not-authorized"));
    assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"failure"));

    // Unlike `scenario_basic`, the shutdown result is ignored. Presumably the server may
    // already have closed the connection after the auth failure (TODO confirm).
    let _ = stream.shutdown().await;

    // wrap up

    server.shutdown().await?;
    Ok(())
}

/// Same opening as `scenario_basic`, but the client sends a single literal (instead of
/// two) before each stream start. The server must still reply with an XML declaration
/// and a `stream` start element, both before and after the TLS upgrade.
///
/// NOTE(review): the literal contents are not visible in this copy. The test name
/// suggests the client omits the XML declaration header; confirm.
#[tokio::test]
async fn scenario_basic_without_headers() -> Result<()> {
    let mut server = TestServer::start().await?;

    // test scenario

    server.storage.create_user("tester").await?;
    Authenticator::new(&server.storage).set_password("tester", "password").await?;

    let mut stream = TcpStream::connect(server.server.addr).await?;
    let mut s = TestScope::new(&mut stream);
    tracing::info!("TCP connection established");

    s.send(r#""#).await?;
    assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
    assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
    s.expect_starttls_required().await?;
    s.send(r#""#).await?;
    assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed"));
    let buffer = s.buffer;
    tracing::info!("TLS feature negotiation complete");

    let connector = TlsConnector::from(Arc::new(
        ClientConfig::builder()
            .with_safe_defaults()
            .with_custom_certificate_verifier(Arc::new(IgnoreCertVerification))
            .with_no_client_auth(),
    ));
    tracing::info!("Initiating TLS connection...");
    let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
    tracing::info!("TLS connection established");

    let mut s = TestScopeTls::new(&mut stream, buffer);

    s.send(r#""#).await?;
    assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
    assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));

    stream.shutdown().await?;

    // wrap up

    server.shutdown().await?;
    Ok(())
}

/// Shutting the server down while a TLS client is connected must make the client's next
/// read fail with `ErrorKind::UnexpectedEof`.
#[tokio::test]
async fn terminate_socket() -> Result<()> {
    let mut server = TestServer::start().await?;

    // test scenario

    server.storage.create_user("tester").await?;
    Authenticator::new(&server.storage).set_password("tester", "password").await?;

    let mut stream = TcpStream::connect(server.server.addr).await?;
    let mut s = TestScope::new(&mut stream);
    tracing::info!("TCP connection established");

    s.send(r#""#).await?;
    s.send(r#""#).await?;
    assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
    assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
    s.expect_starttls_required().await?;
    s.send(r#""#).await?;
    assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed"));

    let connector = TlsConnector::from(Arc::new(
        ClientConfig::builder()
            .with_safe_defaults()
            .with_custom_certificate_verifier(Arc::new(IgnoreCertVerification))
            .with_no_client_auth(),
    ));
    tracing::info!("Initiating TLS connection...");
    let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
    tracing::info!("TLS connection established");

    server.shutdown().await?;

    assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof);

    Ok(())
}