forked from lavina/lavina
1
0
Fork 0

termination usage for stopping the socket connection gracefully (#34)

Reviewed-on: lavina/lavina#34
Co-authored-by: JustTestingV <JustTestingV@gmail.com>
Co-committed-by: JustTestingV <JustTestingV@gmail.com>
This commit is contained in:
JustTestingV 2024-02-18 16:46:29 +00:00 committed by Nikita Vilunov
parent 7613055dde
commit c6fb74a848
4 changed files with 134 additions and 26 deletions

View File

@ -62,19 +62,24 @@ async fn handle_socket(
let mut reader: BufReader<ReadHalf> = BufReader::new(reader); let mut reader: BufReader<ReadHalf> = BufReader::new(reader);
let mut writer = BufWriter::new(writer); let mut writer = BufWriter::new(writer);
let registered_user: Result<RegisteredUser> = pin!(termination);
handle_registration(&mut reader, &mut writer, &mut storage, &config).await; select! {
biased;
match registered_user { _ = &mut termination =>{
Ok(user) => { log::info!("Socket handling was terminated");
log::debug!("User registered"); return Ok(())
handle_registered_socket(config, players, rooms, &mut reader, &mut writer, user).await?; },
} registered_user = handle_registration(&mut reader, &mut writer, &mut storage, &config) =>
Err(err) => { match registered_user {
log::debug!("Registration failed: {err}"); Ok(user) => {
} log::debug!("User registered");
handle_registered_socket(config, players, rooms, &mut reader, &mut writer, user).await?;
}
Err(err) => {
log::debug!("Registration failed: {err}");
}
}
} }
stream.shutdown().await?; stream.shutdown().await?;
Ok(()) Ok(())
} }

View File

@ -1,3 +1,5 @@
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::time::Duration; use std::time::Duration;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
@ -266,3 +268,29 @@ async fn scenario_cap_sasl_fail() -> Result<()> {
server.server.terminate().await?; server.server.terminate().await?;
Ok(()) Ok(())
} }
#[tokio::test]
async fn terminate_socket_scenario() -> Result<()> {
let mut server = TestServer::start().await?;
let address: SocketAddr = ("127.0.0.1:0".parse().unwrap());
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
s.send("NICK tester").await?;
s.send("CAP REQ :sasl").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver CAP tester ACK :sasl").await?;
s.send("AUTHENTICATE PLAIN").await?;
s.expect(":testserver AUTHENTICATE +").await?;
server.server.terminate().await?;
assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof);
Ok(())
}

View File

@ -187,18 +187,33 @@ async fn handle_socket(
let mut xml_reader = NsReader::from_reader(BufReader::new(a)); let mut xml_reader = NsReader::from_reader(BufReader::new(a));
let mut xml_writer = Writer::new(b); let mut xml_writer = Writer::new(b);
let authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage).await?; pin!(termination);
log::debug!("User authenticated"); select! {
let mut connection = players.connect_to_player(authenticated.player_id.clone()).await; biased;
socket_final( _ = &mut termination =>{
&mut xml_reader, log::info!("Socket handling was terminated");
&mut xml_writer, return Ok(())
&mut reader_buf, },
&authenticated, authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage) => {
&mut connection, match authenticated {
&rooms, Ok(authenticated) => {
) let mut connection = players.connect_to_player(authenticated.player_id.clone()).await;
.await?; socket_final(
&mut xml_reader,
&mut xml_writer,
&mut reader_buf,
&authenticated,
&mut connection,
&rooms,
)
.await?;
},
Err(err) => {
log::error!("Authentication error: {:?}", err);
}
}
},
}
let a = xml_reader.into_inner().into_inner(); let a = xml_reader.into_inner().into_inner();
let b = xml_writer.into_inner(); let b = xml_writer.into_inner();

View File

@ -1,3 +1,5 @@
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -6,7 +8,7 @@ use assert_matches::*;
use prometheus::Registry as MetricsRegistry; use prometheus::Registry as MetricsRegistry;
use quick_xml::events::Event; use quick_xml::events::Event;
use quick_xml::NsReader; use quick_xml::NsReader;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::io::{ReadHalf as GenericReadHalf, WriteHalf as GenericWriteHalf}; use tokio::io::{ReadHalf as GenericReadHalf, WriteHalf as GenericWriteHalf};
use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@ -122,7 +124,7 @@ impl ServerCertVerifier for IgnoreCertVerification {
#[tokio::test] #[tokio::test]
async fn scenario_basic() -> Result<()> { async fn scenario_basic() -> Result<()> {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::try_init();
let config = ServerConfig { let config = ServerConfig {
listen_on: "127.0.0.1:0".parse().unwrap(), listen_on: "127.0.0.1:0".parse().unwrap(),
cert: "tests/certs/xmpp.pem".parse().unwrap(), cert: "tests/certs/xmpp.pem".parse().unwrap(),
@ -184,3 +186,61 @@ async fn scenario_basic() -> Result<()> {
server.terminate().await?; server.terminate().await?;
Ok(()) Ok(())
} }
#[tokio::test]
async fn terminate_socket() -> Result<()> {
tracing_subscriber::fmt::try_init();
let config = ServerConfig {
listen_on: "127.0.0.1:0".parse().unwrap(),
cert: "tests/certs/xmpp.pem".parse().unwrap(),
key: "tests/certs/xmpp.key".parse().unwrap(),
};
let mut metrics = MetricsRegistry::new();
let mut storage = Storage::open(StorageConfig {
db_path: ":memory:".into(),
})
.await?;
let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap();
let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap();
let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap();
let address: SocketAddr = ("127.0.0.1:0".parse().unwrap());
// test scenario
storage.create_user("tester").await?;
storage.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.addr).await?;
let mut s = TestScope::new(&mut stream);
tracing::info!("TCP connection established");
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed"));
let buffer = s.buffer;
let connector = TlsConnector::from(Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(IgnoreCertVerification))
.with_no_client_auth(),
));
tracing::info!("Initiating TLS connection...");
let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?;
tracing::info!("TLS connection established");
server.terminate().await?;
assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof);
Ok(())
}