diff --git a/crates/lavina-core/src/dialog.rs b/crates/lavina-core/src/dialog.rs index 66fe8b5..f87294c 100644 --- a/crates/lavina-core/src/dialog.rs +++ b/crates/lavina-core/src/dialog.rs @@ -130,6 +130,21 @@ impl DialogRegistry { let mut guard = self.0.write().await; guard.players = Some(players); } + + pub async fn unset_players(&self) { + let mut guard = self.0.write().await; + guard.players = None; + } + + pub fn shutdown(self) -> Result<()> { + let res = match Arc::try_unwrap(self.0) { + Ok(e) => e, + Err(_) => return Err(fail("failed to acquire dialogs ownership on shutdown")), + }; + let res = res.into_inner(); + drop(res); + Ok(()) + } } #[cfg(test)] diff --git a/crates/lavina-core/src/lib.rs b/crates/lavina-core/src/lib.rs index 1128c61..b251ed9 100644 --- a/crates/lavina-core/src/lib.rs +++ b/crates/lavina-core/src/lib.rs @@ -40,8 +40,10 @@ impl LavinaCore { pub async fn shutdown(mut self) -> Result<()> { self.players.shutdown_all().await?; - drop(self.players); - drop(self.rooms); + self.dialogs.unset_players().await; + self.players.shutdown()?; + self.dialogs.shutdown()?; + self.rooms.shutdown()?; Ok(()) } } diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index 4d6f6cb..2f5edb4 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -258,6 +258,16 @@ impl PlayerRegistry { Ok(PlayerRegistry(Arc::new(RwLock::new(inner)))) } + pub fn shutdown(self) -> Result<()> { + let res = match Arc::try_unwrap(self.0) { + Ok(e) => e, + Err(_) => return Err(fail("failed to acquire players ownership on shutdown")), + }; + let res = res.into_inner(); + drop(res); + Ok(()) + } + pub async fn get_player(&self, id: &PlayerId) -> Option { let inner = self.0.read().await; inner.players.get(id).map(|(handle, _)| handle.clone()) diff --git a/crates/lavina-core/src/room.rs b/crates/lavina-core/src/room.rs index 52ac7c4..d50e169 100644 --- a/crates/lavina-core/src/room.rs +++ b/crates/lavina-core/src/room.rs @@ -48,6 +48,17 @@ impl RoomRegistry { Ok(RoomRegistry(Arc::new(AsyncRwLock::new(inner)))) } + pub fn shutdown(self) -> Result<()> { + let res = match Arc::try_unwrap(self.0) { + Ok(e) => e, + Err(_) => return Err(fail("failed to acquire rooms ownership on shutdown")), + }; + let res = res.into_inner(); + // TODO drop all rooms + drop(res); + Ok(()) + } + pub async fn get_or_create_room(&mut self, room_id: RoomId) -> Result { let mut inner = self.0.write().await; if let Some(room_handle) = inner.get_or_load_room(&room_id).await? { diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 2de4f9e..2b069ef 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -151,6 +151,13 @@ impl TestServer { server, }) } + + async fn shutdown(self) -> Result<()> { + self.server.terminate().await?; + self.core.shutdown().await?; + self.storage.close().await?; + Ok(()) + } } #[tokio::test] @@ -178,7 +185,7 @@ async fn scenario_basic() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -248,7 +255,7 @@ async fn scenario_join_and_reboot() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -314,7 +321,7 @@ async fn scenario_force_join_msg() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -375,6 +382,11 @@ async fn scenario_two_users() -> Result<()> { s1.expect(":tester1 PART #test").await?; // The second user should receive the PART message s2.expect(":tester1 PART #test").await?; + + stream1.shutdown().await?; + stream2.shutdown().await?; + + server.shutdown().await?; Ok(()) } @@ -418,7 +430,7 @@ async fn scenario_cap_full_negotiation() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -457,7 +469,7 @@ async fn scenario_cap_full_negotiation_nick_last() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -495,7 +507,7 @@ async fn scenario_cap_short_negotiation() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -539,7 +551,7 @@ async fn scenario_cap_sasl_fail() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -562,7 +574,7 @@ async fn terminate_socket_scenario() -> Result<()> { s.send("AUTHENTICATE PLAIN").await?; s.expect(":testserver AUTHENTICATE +").await?; - server.server.terminate().await?; + server.shutdown().await?; assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof); Ok(()) @@ -633,7 +645,7 @@ async fn server_time_capability() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -694,5 +706,10 @@ async fn scenario_two_players_dialog() -> Result<()> { s1.expect(":tester2 PRIVMSG tester1 :good").await?; s1.expect_nothing().await?; + stream1.shutdown().await?; + stream2.shutdown().await?; + + server.shutdown().await?; + Ok(()) } diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index bece5d9..88af83d 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -150,6 +150,13 @@ impl TestServer { server, }) } + + async fn shutdown(self) -> Result<()> { + self.server.terminate().await?; + self.core.shutdown().await?; + self.storage.close().await?; + Ok(()) + } } #[tokio::test] @@ -200,7 +207,7 @@ async fn scenario_basic() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -250,7 +257,7 @@ async fn scenario_basic_without_headers() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -291,7 +298,7 @@ async fn terminate_socket() -> Result<()> { let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?; tracing::info!("TLS connection established"); - server.server.terminate().await?; + server.shutdown().await?; assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof);