diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index 0d5cd30..fb587ea 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -177,6 +177,7 @@ async fn handle_socket( .with_single_cert(vec![config.cert.clone()], config.key.clone())?; config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new()); + log::debug!("Accepting TLS connection..."); let acceptor = TlsAcceptor::from(Arc::new(config)); let new_stream = acceptor.accept(stream).await?; log::debug!("TLS connection established"); diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index 1ce2379..369749b 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -4,19 +4,22 @@ use std::time::Duration; use anyhow::Result; use assert_matches::*; use prometheus::Registry as MetricsRegistry; -use proto_xmpp::xml::{Continuation, FromXml, Parser}; use quick_xml::events::Event; use quick_xml::NsReader; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::io::{ReadHalf as GenericReadHalf, WriteHalf as GenericWriteHalf}; use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::TcpStream; +use tokio_rustls::client::TlsStream; +use tokio_rustls::rustls::client::ServerCertVerifier; +use tokio_rustls::rustls::{ClientConfig, ServerName}; +use tokio_rustls::TlsConnector; use lavina_core::player::PlayerRegistry; use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::room::RoomRegistry; use projection_xmpp::{launch, ServerConfig}; -use tokio_rustls::{TlsConnector, Connect}; -use tokio_rustls::rustls::ClientConfig; +use proto_xmpp::xml::{Continuation, FromXml, Parser}; pub async fn read_irc_message(reader: &mut BufReader>, buf: &mut Vec) -> Result { let mut size = 0; @@ -26,16 +29,9 @@ pub async fn read_irc_message(reader: &mut BufReader>, buf: &mut Ve } struct TestScope<'a> { - socket: TestSocket<'a>, + reader: NsReader>>, + writer: WriteHalf<'a>, buffer: Vec, - pub timeout: Duration, -} - -enum TestSocket<'a> { - Unencrypted { - reader: NsReader>>, - writer: WriteHalf<'a>, - }, } impl<'a> TestScope<'a> { @@ -43,59 +39,86 @@ impl<'a> TestScope<'a> { let (reader, writer) = stream.split(); let reader = NsReader::from_reader(BufReader::new(reader)); let buffer = vec![]; + TestScope { reader, writer, buffer } + } + + async fn send(&mut self, str: &str) -> Result<()> { + self.writer.write_all(str.as_bytes()).await?; + self.writer.write_all(b"\n").await?; + self.writer.flush().await?; + Ok(()) + } + + async fn next_xml_event(&mut self) -> Result> { + self.buffer.clear(); + let event = self.reader.read_event_into_async(&mut self.buffer).await?; + Ok(event) + } + + async fn read(&mut self) -> Result { + self.buffer.clear(); + let (ns, event) = self.reader.read_resolved_event_into_async(&mut self.buffer).await?; + let mut parser: Continuation<_, std::result::Result> = T::parse().consume(ns, &event); + loop { + match parser { + Continuation::Final(res) => return Ok(res?), + Continuation::Continue(next) => { + let (ns, event) = self.reader.read_resolved_event_into_async(&mut self.buffer).await?; + parser = next.consume(ns, &event); + } + } + } + } +} + +struct TestScopeTls<'a> { + reader: NsReader>>>, + writer: GenericWriteHalf<&'a mut TlsStream>, + buffer: Vec, + pub timeout: Duration, +} + +impl<'a> TestScopeTls<'a> { + fn new(stream: &'a mut TlsStream, buffer: Vec) -> TestScopeTls<'a> { + let (reader, writer) = tokio::io::split(stream); + let reader = NsReader::from_reader(BufReader::new(reader)); let timeout = Duration::from_millis(100); - let socket = TestSocket::Unencrypted { reader, writer }; - TestScope { - socket, + + TestScopeTls { + reader, + writer, buffer, timeout, } } async fn send(&mut self, str: &str) -> Result<()> { - match &mut self.socket { - TestSocket::Unencrypted { reader: _, writer } => { - writer.write_all(str.as_bytes()).await?; - writer.write_all(b"\n").await?; - writer.flush().await?; - } - } + self.writer.write_all(str.as_bytes()).await?; + self.writer.write_all(b"\n").await?; + self.writer.flush().await?; Ok(()) } async fn next_xml_event(&mut self) -> Result> { self.buffer.clear(); - let event = match &mut self.socket { - TestSocket::Unencrypted { reader, writer: _ } => { - reader.read_event_into_async(&mut self.buffer).await? - } - }; + let event = self.reader.read_event_into_async(&mut self.buffer); + let event = tokio::time::timeout(self.timeout, event).await??; Ok(event) } +} - async fn read(&mut self) -> Result { - self.buffer.clear(); - let reader = match &mut self.socket { - TestSocket::Unencrypted { reader, writer } => reader, - }; - let (ns, event) = reader.read_resolved_event_into_async(&mut self.buffer).await?; - let mut parser: Continuation<_, std::result::Result> = T::parse().consume(ns, &event); - loop { - match parser { - Continuation::Final(res) => return Ok(res?), - Continuation::Continue(next) => { - let (ns, event) = reader.read_resolved_event_into_async(&mut self.buffer).await?; - parser = next.consume(ns, &event); - } - } - } - } - - async fn init_tls(&mut self) -> Result<()> { - let mut root_store = tokio_rustls::rustls::RootCertStore::empty(); - let connector = TlsConnector::from(Arc::new(ClientConfig::builder().with_safe_defaults().with_root_certificates(root_store).with_no_client_auth())); - connector.connect(domain, stream) - todo!() +struct IgnoreCertVerification; +impl ServerCertVerifier for IgnoreCertVerification { + fn verify_server_cert( + &self, + _end_entity: &tokio_rustls::rustls::Certificate, + _intermediates: &[tokio_rustls::rustls::Certificate], + _server_name: &ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: std::time::SystemTime, + ) -> std::result::Result { + Ok(tokio_rustls::rustls::client::ServerCertVerified::assertion()) } } @@ -125,7 +148,31 @@ async fn scenario_basic() -> Result<()> { let mut s = TestScope::new(&mut stream); s.send(r#""#).await?; - s.send(r#""#).await?; + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); + let buffer = s.buffer; + + let connector = TlsConnector::from(Arc::new( + ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(Arc::new(IgnoreCertVerification)) + .with_no_client_auth(), + )); + tracing::debug!("Initiating TLS connection..."); + let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?; + + let mut s = TestScopeTls::new(&mut stream, buffer); + + s.send(r#""#).await?; + s.send(r#""#).await?; assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));