forked from lavina/lavina
1
0
Fork 0

finish xmpp test scenario up to auth

This commit is contained in:
Nikita Vilunov 2023-10-13 00:18:36 +02:00
parent 20db4f915b
commit ea1735bd92
2 changed files with 99 additions and 51 deletions

View File

@ -177,6 +177,7 @@ async fn handle_socket(
.with_single_cert(vec![config.cert.clone()], config.key.clone())?; .with_single_cert(vec![config.cert.clone()], config.key.clone())?;
config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new()); config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new());
log::debug!("Accepting TLS connection...");
let acceptor = TlsAcceptor::from(Arc::new(config)); let acceptor = TlsAcceptor::from(Arc::new(config));
let new_stream = acceptor.accept(stream).await?; let new_stream = acceptor.accept(stream).await?;
log::debug!("TLS connection established"); log::debug!("TLS connection established");

View File

@ -4,19 +4,22 @@ use std::time::Duration;
use anyhow::Result; use anyhow::Result;
use assert_matches::*; use assert_matches::*;
use prometheus::Registry as MetricsRegistry; use prometheus::Registry as MetricsRegistry;
use proto_xmpp::xml::{Continuation, FromXml, Parser};
use quick_xml::events::Event; use quick_xml::events::Event;
use quick_xml::NsReader; use quick_xml::NsReader;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::io::{ReadHalf as GenericReadHalf, WriteHalf as GenericWriteHalf};
use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use tokio_rustls::rustls::client::ServerCertVerifier;
use tokio_rustls::rustls::{ClientConfig, ServerName};
use tokio_rustls::TlsConnector;
use lavina_core::player::PlayerRegistry; use lavina_core::player::PlayerRegistry;
use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::room::RoomRegistry; use lavina_core::room::RoomRegistry;
use projection_xmpp::{launch, ServerConfig}; use projection_xmpp::{launch, ServerConfig};
use tokio_rustls::{TlsConnector, Connect}; use proto_xmpp::xml::{Continuation, FromXml, Parser};
use tokio_rustls::rustls::ClientConfig;
pub async fn read_irc_message(reader: &mut BufReader<ReadHalf<'_>>, buf: &mut Vec<u8>) -> Result<usize> { pub async fn read_irc_message(reader: &mut BufReader<ReadHalf<'_>>, buf: &mut Vec<u8>) -> Result<usize> {
let mut size = 0; let mut size = 0;
@ -26,16 +29,9 @@ pub async fn read_irc_message(reader: &mut BufReader<ReadHalf<'_>>, buf: &mut Ve
} }
struct TestScope<'a> { struct TestScope<'a> {
socket: TestSocket<'a>,
buffer: Vec<u8>,
pub timeout: Duration,
}
enum TestSocket<'a> {
Unencrypted {
reader: NsReader<BufReader<ReadHalf<'a>>>, reader: NsReader<BufReader<ReadHalf<'a>>>,
writer: WriteHalf<'a>, writer: WriteHalf<'a>,
}, buffer: Vec<u8>,
} }
impl<'a> TestScope<'a> { impl<'a> TestScope<'a> {
@ -43,59 +39,86 @@ impl<'a> TestScope<'a> {
let (reader, writer) = stream.split(); let (reader, writer) = stream.split();
let reader = NsReader::from_reader(BufReader::new(reader)); let reader = NsReader::from_reader(BufReader::new(reader));
let buffer = vec![]; let buffer = vec![];
TestScope { reader, writer, buffer }
}
async fn send(&mut self, str: &str) -> Result<()> {
self.writer.write_all(str.as_bytes()).await?;
self.writer.write_all(b"\n").await?;
self.writer.flush().await?;
Ok(())
}
async fn next_xml_event(&mut self) -> Result<Event<'_>> {
self.buffer.clear();
let event = self.reader.read_event_into_async(&mut self.buffer).await?;
Ok(event)
}
async fn read<T: FromXml>(&mut self) -> Result<T> {
self.buffer.clear();
let (ns, event) = self.reader.read_resolved_event_into_async(&mut self.buffer).await?;
let mut parser: Continuation<_, std::result::Result<T, anyhow::Error>> = T::parse().consume(ns, &event);
loop {
match parser {
Continuation::Final(res) => return Ok(res?),
Continuation::Continue(next) => {
let (ns, event) = self.reader.read_resolved_event_into_async(&mut self.buffer).await?;
parser = next.consume(ns, &event);
}
}
}
}
}
struct TestScopeTls<'a> {
reader: NsReader<BufReader<GenericReadHalf<&'a mut TlsStream<TcpStream>>>>,
writer: GenericWriteHalf<&'a mut TlsStream<TcpStream>>,
buffer: Vec<u8>,
pub timeout: Duration,
}
impl<'a> TestScopeTls<'a> {
fn new(stream: &'a mut TlsStream<TcpStream>, buffer: Vec<u8>) -> TestScopeTls<'a> {
let (reader, writer) = tokio::io::split(stream);
let reader = NsReader::from_reader(BufReader::new(reader));
let timeout = Duration::from_millis(100); let timeout = Duration::from_millis(100);
let socket = TestSocket::Unencrypted { reader, writer };
TestScope { TestScopeTls {
socket, reader,
writer,
buffer, buffer,
timeout, timeout,
} }
} }
async fn send(&mut self, str: &str) -> Result<()> { async fn send(&mut self, str: &str) -> Result<()> {
match &mut self.socket { self.writer.write_all(str.as_bytes()).await?;
TestSocket::Unencrypted { reader: _, writer } => { self.writer.write_all(b"\n").await?;
writer.write_all(str.as_bytes()).await?; self.writer.flush().await?;
writer.write_all(b"\n").await?;
writer.flush().await?;
}
}
Ok(()) Ok(())
} }
async fn next_xml_event(&mut self) -> Result<Event<'_>> { async fn next_xml_event(&mut self) -> Result<Event<'_>> {
self.buffer.clear(); self.buffer.clear();
let event = match &mut self.socket { let event = self.reader.read_event_into_async(&mut self.buffer);
TestSocket::Unencrypted { reader, writer: _ } => { let event = tokio::time::timeout(self.timeout, event).await??;
reader.read_event_into_async(&mut self.buffer).await?
}
};
Ok(event) Ok(event)
} }
async fn read<T: FromXml>(&mut self) -> Result<T> {
self.buffer.clear();
let reader = match &mut self.socket {
TestSocket::Unencrypted { reader, writer } => reader,
};
let (ns, event) = reader.read_resolved_event_into_async(&mut self.buffer).await?;
let mut parser: Continuation<_, std::result::Result<T, anyhow::Error>> = T::parse().consume(ns, &event);
loop {
match parser {
Continuation::Final(res) => return Ok(res?),
Continuation::Continue(next) => {
let (ns, event) = reader.read_resolved_event_into_async(&mut self.buffer).await?;
parser = next.consume(ns, &event);
}
}
}
} }
async fn init_tls(&mut self) -> Result<()> { struct IgnoreCertVerification;
let mut root_store = tokio_rustls::rustls::RootCertStore::empty(); impl ServerCertVerifier for IgnoreCertVerification {
let connector = TlsConnector::from(Arc::new(ClientConfig::builder().with_safe_defaults().with_root_certificates(root_store).with_no_client_auth())); fn verify_server_cert(
connector.connect(domain, stream) &self,
todo!() _end_entity: &tokio_rustls::rustls::Certificate,
_intermediates: &[tokio_rustls::rustls::Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: std::time::SystemTime,
) -> std::result::Result<tokio_rustls::rustls::client::ServerCertVerified, tokio_rustls::rustls::Error> {
Ok(tokio_rustls::rustls::client::ServerCertVerified::assertion())
} }
} }
@ -125,7 +148,31 @@ async fn scenario_basic() -> Result<()> {
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
s.send(r#"<?xml version="1.0"?>"#).await?; s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="vilunov.me" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?; s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed"));
let buffer = s.buffer;
let connector = TlsConnector::from(Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(IgnoreCertVerification))
.with_no_client_auth(),
));
tracing::debug!("Initiating TLS connection...");
let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?;
let mut s = TestScopeTls::new(&mut stream, buffer);
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));