forked from lavina/lavina
1
0
Fork 0

Compare commits

...

7 Commits

10 changed files with 243 additions and 217 deletions

29
Cargo.lock generated
View File

@ -56,6 +56,18 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9" checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9"
[[package]]
name = "async-scoped"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7a6a57c8aeb40da1ec037f5d455836852f7a57e69e1b1ad3d8f38ac1d6cadf"
dependencies = [
"futures",
"pin-project",
"slab",
"tokio",
]
[[package]] [[package]]
name = "atoi" name = "atoi"
version = "2.0.0" version = "2.0.0"
@ -377,6 +389,21 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "futures"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]] [[package]]
name = "futures-channel" name = "futures-channel"
version = "0.3.28" version = "0.3.28"
@ -450,6 +477,7 @@ version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533"
dependencies = [ dependencies = [
"futures-channel",
"futures-core", "futures-core",
"futures-io", "futures-io",
"futures-macro", "futures-macro",
@ -754,6 +782,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"assert_matches", "assert_matches",
"async-scoped",
"derive_more", "derive_more",
"figment", "figment",
"futures-util", "futures-util",

View File

@ -25,6 +25,7 @@ quick-xml = { version = "0.30.0", features = ["async-tokio"] }
derive_more = "0.99.17" derive_more = "0.99.17"
uuid = { version = "1.3.0", features = ["v4"] } uuid = { version = "1.3.0", features = ["v4"] }
sqlx = { version = "0.7.0-alpha.2", features = ["sqlite", "runtime-tokio-rustls", "migrate"] } sqlx = { version = "0.7.0-alpha.2", features = ["sqlite", "runtime-tokio-rustls", "migrate"] }
async-scoped = { version = "0.7.1", features = ["use-tokio"] }
[dev-dependencies] [dev-dependencies]
assert_matches = "1.5.0" assert_matches = "1.5.0"

View File

@ -7,23 +7,17 @@
//! //!
//! A player actor is a serial handler of commands from a single player. It is preferable to run all per-player validations in the player actor, //! A player actor is a serial handler of commands from a single player. It is preferable to run all per-player validations in the player actor,
//! so that they don't overload the room actor. //! so that they don't overload the room actor.
use std::{ use std::collections::{HashMap, HashSet};
collections::{HashMap, HashSet}, use std::sync::RwLock;
sync::{Arc, RwLock},
};
use futures_util::FutureExt;
use prometheus::{IntGauge, Registry as MetricsRegistry}; use prometheus::{IntGauge, Registry as MetricsRegistry};
use serde::Serialize; use serde::Serialize;
use tokio::{ use tokio::sync::mpsc::{channel, Receiver, Sender};
sync::mpsc::{channel, Receiver, Sender},
task::JoinHandle,
};
use crate::{ use crate::util::table::{AnonTable, Key as AnonKey};
core::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry}, use crate::prelude::*;
prelude::*, use crate::core::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry};
util::table::{AnonTable, Key as AnonKey},
};
/// Opaque player identifier. Cannot contain spaces, must be shorter than 32. /// Opaque player identifier. Cannot contain spaces, must be shorter than 32.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)]
@ -108,7 +102,7 @@ impl PlayerConnection {
} }
/// Handle to a player actor. /// Handle to a player actor.
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct PlayerHandle { pub struct PlayerHandle {
tx: Sender<PlayerCommand>, tx: Sender<PlayerCommand>,
} }
@ -154,7 +148,9 @@ impl PlayerHandle {
} }
async fn send(&self, command: PlayerCommand) { async fn send(&self, command: PlayerCommand) {
let _ = self.tx.send(command).await; if let Err(e) = self.tx.send(command).await {
log::warn!("Failed to send command to a player: {e:?}");
}
} }
pub async fn update(&self, update: Updates) { pub async fn update(&self, update: Updates) {
@ -162,6 +158,7 @@ impl PlayerHandle {
} }
} }
#[derive(Debug)]
enum PlayerCommand { enum PlayerCommand {
/** Commands from connections */ /** Commands from connections */
AddConnection { AddConnection {
@ -174,8 +171,10 @@ enum PlayerCommand {
GetRooms(Promise<Vec<RoomInfo>>), GetRooms(Promise<Vec<RoomInfo>>),
/** Events from rooms */ /** Events from rooms */
Update(Updates), Update(Updates),
Stop,
} }
#[derive(Debug)]
pub enum Cmd { pub enum Cmd {
JoinRoom { JoinRoom {
room_id: RoomId, room_id: RoomId,
@ -197,6 +196,7 @@ pub enum Cmd {
}, },
} }
#[derive(Debug)]
pub enum JoinResult { pub enum JoinResult {
Success(RoomInfo), Success(RoomInfo),
Banned, Banned,
@ -227,74 +227,75 @@ pub enum Updates {
} }
/// Handle to a player registry — a shared data structure containing information about players. /// Handle to a player registry — a shared data structure containing information about players.
#[derive(Clone)] pub struct PlayerRegistry<'a>(RwLock<PlayerRegistryInner<'a>>);
pub struct PlayerRegistry(Arc<RwLock<PlayerRegistryInner>>); impl<'a> PlayerRegistry<'a> {
impl PlayerRegistry {
pub fn empty( pub fn empty(
room_registry: RoomRegistry, room_registry: &'a RoomRegistry<'a>,
metrics: &mut MetricsRegistry, metrics: &MetricsRegistry,
) -> Result<PlayerRegistry> { ) -> Result<PlayerRegistry<'a>> {
let metric_active_players = let metric_active_players =
IntGauge::new("chat_players_active", "Number of alive player actors")?; IntGauge::new("chat_players_active", "Number of alive player actors")?;
metrics.register(Box::new(metric_active_players.clone()))?; metrics.register(Box::new(metric_active_players.clone()))?;
let scope = unsafe { Scope::create() };
let inner = PlayerRegistryInner { let inner = PlayerRegistryInner {
room_registry, room_registry,
players: HashMap::new(), players: HashMap::new(),
metric_active_players, metric_active_players,
scope,
}; };
Ok(PlayerRegistry(Arc::new(RwLock::new(inner)))) Ok(PlayerRegistry(RwLock::new(inner)))
} }
pub async fn get_or_create_player(&mut self, id: PlayerId) -> PlayerHandle { pub async fn get_or_create_player(&self, id: PlayerId) -> PlayerHandle {
let mut inner = self.0.write().unwrap(); let mut inner = self.0.write().unwrap();
if let Some((handle, _)) = inner.players.get(&id) { if let Some(handle) = inner.players.get(&id) {
handle.clone() handle.clone()
} else { } else {
let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone()); let handle = Player::launch(id.clone(), inner.room_registry, &mut inner.scope);
inner.players.insert(id, (handle.clone(), fiber)); inner.players.insert(id, handle.clone());
inner.metric_active_players.inc(); inner.metric_active_players.inc();
handle handle
} }
} }
pub async fn connect_to_player(&mut self, id: PlayerId) -> PlayerConnection { pub async fn connect_to_player(&self, id: PlayerId) -> PlayerConnection {
let player_handle = self.get_or_create_player(id).await; let player_handle = self.get_or_create_player(id).await;
player_handle.subscribe().await player_handle.subscribe().await
} }
pub async fn shutdown_all(&mut self) -> Result<()> { pub async fn shutdown_all(self) -> Result<()> {
let mut inner = self.0.write().unwrap(); let mut inner = self.0.write().unwrap();
let mut players = HashMap::new(); for (i, k) in inner.players.drain() {
std::mem::swap(&mut players, &mut inner.players); k.send(PlayerCommand::Stop).await;
for (i, (k, j)) in inner.players.drain() {
drop(k); drop(k);
j.await?; log::debug!("Stopping player #{i:?}")
log::debug!("Player stopped #{i:?}")
} }
let _ = inner.scope.collect().await;
log::debug!("All players stopped"); log::debug!("All players stopped");
Ok(()) Ok(())
} }
} }
/// The player registry state representation. /// The player registry state representation.
struct PlayerRegistryInner { struct PlayerRegistryInner<'a> {
room_registry: RoomRegistry, room_registry: &'a RoomRegistry<'a>,
players: HashMap<PlayerId, (PlayerHandle, JoinHandle<Player>)>, players: HashMap<PlayerId, PlayerHandle>,
metric_active_players: IntGauge, metric_active_players: IntGauge,
scope: Scope<'a>,
} }
/// Player actor inner state representation. /// Player actor inner state representation.
struct Player { struct Player<'a> {
player_id: PlayerId, player_id: PlayerId,
connections: AnonTable<Sender<Updates>>, connections: AnonTable<Sender<Updates>>,
my_rooms: HashMap<RoomId, RoomHandle>, my_rooms: HashMap<RoomId, RoomHandle<'a>>,
banned_from: HashSet<RoomId>, banned_from: HashSet<RoomId>,
rx: Receiver<PlayerCommand>, rx: Receiver<PlayerCommand>,
handle: PlayerHandle, handle: PlayerHandle,
rooms: RoomRegistry, rooms: &'a RoomRegistry<'a>,
} }
impl Player { impl<'a> Player<'a> {
fn launch(player_id: PlayerId, rooms: RoomRegistry) -> (PlayerHandle, JoinHandle<Player>) { fn launch(player_id: PlayerId, rooms: &'a RoomRegistry<'a>, scope: &mut Scope<'a>) -> PlayerHandle {
let (tx, rx) = channel(32); let (tx, rx) = channel(32);
let handle = PlayerHandle { tx }; let handle = PlayerHandle { tx };
let handle_clone = handle.clone(); let handle_clone = handle.clone();
@ -307,11 +308,11 @@ impl Player {
handle, handle,
rooms, rooms,
}; };
let fiber = tokio::task::spawn(player.main_loop()); scope.spawn(player.main_loop().map(|_| ()));
(handle_clone, fiber) handle_clone
} }
async fn main_loop(mut self) -> Self { async fn main_loop(mut self) -> Player<'a> {
while let Some(cmd) = self.rx.recv().await { while let Some(cmd) = self.rx.recv().await {
match cmd { match cmd {
PlayerCommand::AddConnection { sender, promise } => { PlayerCommand::AddConnection { sender, promise } => {
@ -348,6 +349,7 @@ impl Player {
} }
} }
PlayerCommand::Cmd(cmd, connection_id) => self.handle_cmd(cmd, connection_id).await, PlayerCommand::Cmd(cmd, connection_id) => self.handle_cmd(cmd, connection_id).await,
PlayerCommand::Stop => { break; }
} }
} }
log::debug!("Shutting down player actor #{:?}", self.player_id); log::debug!("Shutting down player actor #{:?}", self.player_id);

View File

@ -1,7 +1,6 @@
//! Storage and persistence logic. //! Storage and persistence logic.
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc;
use serde::Deserialize; use serde::Deserialize;
use sqlx::sqlite::SqliteConnectOptions; use sqlx::sqlite::SqliteConnectOptions;
@ -15,9 +14,8 @@ pub struct StorageConfig {
pub db_path: String, pub db_path: String,
} }
#[derive(Clone)]
pub struct Storage { pub struct Storage {
conn: Arc<Mutex<SqliteConnection>>, conn: Mutex<SqliteConnection>,
} }
impl Storage { impl Storage {
pub async fn open(config: StorageConfig) -> Result<Storage> { pub async fn open(config: StorageConfig) -> Result<Storage> {
@ -29,11 +27,11 @@ impl Storage {
migrator.run(&mut conn).await?; migrator.run(&mut conn).await?;
log::info!("Migrations passed"); log::info!("Migrations passed");
let conn = Arc::new(Mutex::new(conn)); let conn = Mutex::new(conn);
Ok(Storage { conn }) Ok(Storage { conn })
} }
pub async fn retrieve_user_by_name(&mut self, name: &str) -> Result<Option<StoredUser>> { pub async fn retrieve_user_by_name(&self, name: &str) -> Result<Option<StoredUser>> {
let mut executor = self.conn.lock().await; let mut executor = self.conn.lock().await;
let res = sqlx::query_as( let res = sqlx::query_as(
"select u.id, u.name, c.password "select u.id, u.name, c.password
@ -47,7 +45,7 @@ impl Storage {
Ok(res) Ok(res)
} }
pub async fn retrieve_room_by_name(&mut self, name: &str) -> Result<Option<StoredRoom>> { pub async fn retrieve_room_by_name(&self, name: &str) -> Result<Option<StoredRoom>> {
let mut executor = self.conn.lock().await; let mut executor = self.conn.lock().await;
let res = sqlx::query_as( let res = sqlx::query_as(
"select id, name, topic, message_count "select id, name, topic, message_count
@ -61,7 +59,7 @@ impl Storage {
Ok(res) Ok(res)
} }
pub async fn create_new_room(&mut self, name: &str, topic: &str) -> Result<u32> { pub async fn create_new_room(&self, name: &str, topic: &str) -> Result<u32> {
let mut executor = self.conn.lock().await; let mut executor = self.conn.lock().await;
let (id,): (u32,) = sqlx::query_as( let (id,): (u32,) = sqlx::query_as(
"insert into rooms(name, topic) "insert into rooms(name, topic)
@ -76,9 +74,9 @@ impl Storage {
Ok(id) Ok(id)
} }
pub async fn insert_message(&mut self, room_id: u32, id: u32, content: &str) -> Result<()> { pub async fn insert_message(&self, room_id: u32, id: u32, content: &str) -> Result<()> {
let mut executor = self.conn.lock().await; let mut executor = self.conn.lock().await;
let (_,): (u32,) = sqlx::query_as( sqlx::query(
"insert into messages(room_id, id, content) "insert into messages(room_id, id, content)
values (?, ?, ?); values (?, ?, ?);
update rooms set message_count = message_count + 1 where id = ?;", update rooms set message_count = message_count + 1 where id = ?;",
@ -87,18 +85,14 @@ impl Storage {
.bind(id) .bind(id)
.bind(content) .bind(content)
.bind(room_id) .bind(room_id)
.fetch_one(&mut *executor) .execute(&mut *executor)
.await?; .await?;
Ok(()) Ok(())
} }
pub async fn close(mut self) -> Result<()> { pub async fn close(self) -> Result<()> {
let res = match Arc::try_unwrap(self.conn) { let res = self.conn.into_inner();
Ok(e) => e,
Err(e) => return Err(fail("failed to acquire DB ownership on shutdown")),
};
let res = res.into_inner();
res.close().await?; res.close().await?;
Ok(()) Ok(())
} }

View File

@ -38,10 +38,9 @@ impl RoomId {
} }
/// Shared datastructure for storing metadata about rooms. /// Shared datastructure for storing metadata about rooms.
#[derive(Clone)] pub struct RoomRegistry<'a>(AsyncRwLock<RoomRegistryInner<'a>>);
pub struct RoomRegistry(Arc<AsyncRwLock<RoomRegistryInner>>); impl<'a> RoomRegistry<'a> {
impl RoomRegistry { pub fn new(metrics: &mut MetricRegistry, storage: &'a Storage) -> Result<RoomRegistry<'a>> {
pub fn new(metrics: &mut MetricRegistry, storage: Storage) -> Result<RoomRegistry> {
let metric_active_rooms = let metric_active_rooms =
IntGauge::new("chat_rooms_active", "Number of alive room actors")?; IntGauge::new("chat_rooms_active", "Number of alive room actors")?;
metrics.register(Box::new(metric_active_rooms.clone()))?; metrics.register(Box::new(metric_active_rooms.clone()))?;
@ -50,10 +49,10 @@ impl RoomRegistry {
metric_active_rooms, metric_active_rooms,
storage, storage,
}; };
Ok(RoomRegistry(Arc::new(AsyncRwLock::new(inner)))) Ok(RoomRegistry(AsyncRwLock::new(inner)))
} }
pub async fn get_or_create_room(&mut self, room_id: RoomId) -> Result<RoomHandle> { pub async fn get_or_create_room(&self, room_id: RoomId) -> Result<RoomHandle<'a>> {
let mut inner = self.0.write().await; let mut inner = self.0.write().await;
if let Some(room_handle) = inner.rooms.get(&room_id) { if let Some(room_handle) = inner.rooms.get(&room_id) {
// room was already loaded into memory // room was already loaded into memory
@ -68,7 +67,7 @@ impl RoomRegistry {
subscriptions: HashMap::new(), // TODO figure out how to populate subscriptions subscriptions: HashMap::new(), // TODO figure out how to populate subscriptions
topic: stored_room.topic.into(), topic: stored_room.topic.into(),
message_count: stored_room.message_count, message_count: stored_room.message_count,
storage: inner.storage.clone(), storage: inner.storage,
}; };
let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room))); let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room)));
inner.rooms.insert(room_id, room_handle.clone()); inner.rooms.insert(room_id, room_handle.clone());
@ -85,7 +84,7 @@ impl RoomRegistry {
subscriptions: HashMap::new(), subscriptions: HashMap::new(),
topic: topic.into(), topic: topic.into(),
message_count: 0, message_count: 0,
storage: inner.storage.clone(), storage: inner.storage,
}; };
let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room))); let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room)));
inner.rooms.insert(room_id, room_handle.clone()); inner.rooms.insert(room_id, room_handle.clone());
@ -94,7 +93,7 @@ impl RoomRegistry {
} }
} }
pub async fn get_room(&self, room_id: &RoomId) -> Option<RoomHandle> { pub async fn get_room(&self, room_id: &RoomId) -> Option<RoomHandle<'a>> {
let inner = self.0.read().await; let inner = self.0.read().await;
let res = inner.rooms.get(room_id); let res = inner.rooms.get(room_id);
res.map(|r| r.clone()) res.map(|r| r.clone())
@ -114,15 +113,15 @@ impl RoomRegistry {
} }
} }
struct RoomRegistryInner { struct RoomRegistryInner<'a> {
rooms: HashMap<RoomId, RoomHandle>, rooms: HashMap<RoomId, RoomHandle<'a>>,
metric_active_rooms: IntGauge, metric_active_rooms: IntGauge,
storage: Storage, storage: &'a Storage,
} }
#[derive(Clone)] #[derive(Clone)]
pub struct RoomHandle(Arc<AsyncRwLock<Room>>); pub struct RoomHandle<'a>(Arc<AsyncRwLock<Room<'a>>>);
impl RoomHandle { impl<'a> RoomHandle<'a> {
pub async fn subscribe(&self, player_id: PlayerId, player_handle: PlayerHandle) { pub async fn subscribe(&self, player_id: PlayerId, player_handle: PlayerHandle) {
let mut lock = self.0.write().await; let mut lock = self.0.write().await;
lock.add_subscriber(player_id, player_handle).await; lock.add_subscriber(player_id, player_handle).await;
@ -140,7 +139,10 @@ impl RoomHandle {
pub async fn send_message(&self, player_id: PlayerId, body: Str) { pub async fn send_message(&self, player_id: PlayerId, body: Str) {
let mut lock = self.0.write().await; let mut lock = self.0.write().await;
lock.send_message(player_id, body).await; let res = lock.send_message(player_id, body).await;
if let Err(err) = res {
log::warn!("Failed to send message: {err:?}");
}
} }
pub async fn get_room_info(&self) -> RoomInfo { pub async fn get_room_info(&self) -> RoomInfo {
@ -167,15 +169,15 @@ impl RoomHandle {
} }
} }
struct Room { struct Room<'a> {
storage_id: u32, storage_id: u32,
room_id: RoomId, room_id: RoomId,
subscriptions: HashMap<PlayerId, PlayerHandle>, subscriptions: HashMap<PlayerId, PlayerHandle>,
message_count: u32, message_count: u32,
topic: Str, topic: Str,
storage: Storage, storage: &'a Storage,
} }
impl Room { impl<'a> Room<'a> {
async fn add_subscriber(&mut self, player_id: PlayerId, player_handle: PlayerHandle) { async fn add_subscriber(&mut self, player_id: PlayerId, player_handle: PlayerHandle) {
tracing::info!("Adding a subscriber to room"); tracing::info!("Adding a subscriber to room");
self.subscriptions.insert(player_id.clone(), player_handle); self.subscriptions.insert(player_id.clone(), player_handle);
@ -200,6 +202,7 @@ impl Room {
} }
async fn broadcast_update(&self, update: Updates, except: &PlayerId) { async fn broadcast_update(&self, update: Updates, except: &PlayerId) {
tracing::debug!("Broadcasting an update to {} subs", self.subscriptions.len());
for (player_id, sub) in &self.subscriptions { for (player_id, sub) in &self.subscriptions {
if player_id == except { if player_id == except {
continue; continue;
@ -210,7 +213,7 @@ impl Room {
} }
} }
#[derive(Serialize)] #[derive(Serialize, Debug)]
pub struct RoomInfo { pub struct RoomInfo {
pub id: RoomId, pub id: RoomId,
pub members: Vec<PlayerId>, pub members: Vec<PlayerId>,

View File

@ -53,23 +53,26 @@ async fn main() -> Result<()> {
} = config; } = config;
let mut metrics = MetricsRegistry::new(); let mut metrics = MetricsRegistry::new();
let storage = Storage::open(storage_config).await?; let storage = Storage::open(storage_config).await?;
let rooms = RoomRegistry::new(&mut metrics, storage.clone())?; let rooms = RoomRegistry::new(&mut metrics, &storage)?;
let mut players = PlayerRegistry::empty(rooms.clone(), &mut metrics)?; let players = PlayerRegistry::empty(&rooms, &metrics)?;
let telemetry_terminator =
util::telemetry::launch(telemetry_config, metrics.clone(), rooms.clone()).await?; // unsafe: outer future is never dropped, scope is joined on `scope.collect`
let irc = projections::irc::launch(irc_config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await?; let mut scope = unsafe { Scope::create() };
let xmpp = projections::xmpp::launch(xmpp_config, players.clone(), rooms.clone(), metrics.clone()).await?; let telemetry_terminator = util::telemetry::launch(telemetry_config, &metrics, &rooms, &mut scope).await?;
let irc = projections::irc::launch(&irc_config, &players, &rooms, &metrics, &storage, &mut scope).await?;
let xmpp = projections::xmpp::launch(xmpp_config, &players, &rooms, &metrics, &mut scope).await?;
tracing::info!("Started"); tracing::info!("Started");
sleep.await; sleep.await;
tracing::info!("Begin shutdown"); tracing::info!("Begin shutdown");
xmpp.terminate().await?; let _ = xmpp.send(());
irc.terminate().await?; let _ = irc.send(());
telemetry_terminator.terminate().await?; let _ = telemetry_terminator.send(());
let _ = scope.collect().await;
drop(scope);
players.shutdown_all().await?; players.shutdown_all().await?;
drop(players);
drop(rooms);
storage.close().await?; storage.close().await?;
tracing::info!("Shutdown complete"); tracing::info!("Shutdown complete");
Ok(()) Ok(())

View File

@ -23,3 +23,5 @@ macro_rules! ffail {
} }
pub(crate) use ffail; pub(crate) use ffail;
pub type Scope<'a> = async_scoped::Scope<'a, (), async_scoped::Tokio>;

View File

@ -1,7 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use futures_util::future::join_all; use futures_util::FutureExt;
use prometheus::{IntCounter, IntGauge, Registry as MetricsRegistry}; use prometheus::{IntCounter, IntGauge, Registry as MetricsRegistry};
use serde::Deserialize; use serde::Deserialize;
use tokio::io::{AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; use tokio::io::{AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
@ -16,7 +16,6 @@ use crate::prelude::*;
use crate::protos::irc::client::{client_message, ClientMessage}; use crate::protos::irc::client::{client_message, ClientMessage};
use crate::protos::irc::server::{AwayStatus, ServerMessage, ServerMessageBody}; use crate::protos::irc::server::{AwayStatus, ServerMessage, ServerMessageBody};
use crate::protos::irc::{Chan, Recipient}; use crate::protos::irc::{Chan, Recipient};
use crate::util::Terminator;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
@ -43,10 +42,10 @@ async fn handle_socket(
config: ServerConfig, config: ServerConfig,
mut stream: TcpStream, mut stream: TcpStream,
socket_addr: &SocketAddr, socket_addr: &SocketAddr,
players: PlayerRegistry, players: &PlayerRegistry<'_>,
rooms: RoomRegistry, rooms: &RoomRegistry<'_>,
termination: Deferred<()>, // TODO use it to stop the connection gracefully termination: Deferred<()>, // TODO use it to stop the connection gracefully
mut storage: Storage, storage: &Storage,
) -> Result<()> { ) -> Result<()> {
let (reader, writer) = stream.split(); let (reader, writer) = stream.split();
let mut reader: BufReader<ReadHalf> = BufReader::new(reader); let mut reader: BufReader<ReadHalf> = BufReader::new(reader);
@ -66,7 +65,7 @@ async fn handle_socket(
writer.flush().await?; writer.flush().await?;
let registered_user: Result<RegisteredUser> = let registered_user: Result<RegisteredUser> =
handle_registration(&mut reader, &mut writer, &mut storage).await; handle_registration(&mut reader, &mut writer, &storage).await;
match registered_user { match registered_user {
Ok(user) => { Ok(user) => {
@ -83,7 +82,7 @@ async fn handle_socket(
async fn handle_registration<'a>( async fn handle_registration<'a>(
reader: &mut BufReader<ReadHalf<'a>>, reader: &mut BufReader<ReadHalf<'a>>,
writer: &mut BufWriter<WriteHalf<'a>>, writer: &mut BufWriter<WriteHalf<'a>>,
storage: &mut Storage, storage: &Storage,
) -> Result<RegisteredUser> { ) -> Result<RegisteredUser> {
let mut buffer = vec![]; let mut buffer = vec![];
@ -176,8 +175,8 @@ async fn handle_registration<'a>(
async fn handle_registered_socket<'a>( async fn handle_registered_socket<'a>(
config: ServerConfig, config: ServerConfig,
mut players: PlayerRegistry, players: &PlayerRegistry<'_>,
rooms: RoomRegistry, rooms: &RoomRegistry<'_>,
reader: &mut BufReader<ReadHalf<'a>>, reader: &mut BufReader<ReadHalf<'a>>,
writer: &mut BufWriter<WriteHalf<'a>>, writer: &mut BufWriter<WriteHalf<'a>>,
user: RegisteredUser, user: RegisteredUser,
@ -301,7 +300,7 @@ async fn handle_update(
user: &RegisteredUser, user: &RegisteredUser,
player_id: &PlayerId, player_id: &PlayerId,
writer: &mut (impl AsyncWrite + Unpin), writer: &mut (impl AsyncWrite + Unpin),
rooms: &RoomRegistry, rooms: &RoomRegistry<'_>,
update: Updates, update: Updates,
) -> Result<()> { ) -> Result<()> {
log::debug!("Sending irc message to player {player_id:?} on update {update:?}"); log::debug!("Sending irc message to player {player_id:?} on update {update:?}");
@ -398,7 +397,7 @@ async fn handle_incoming_message(
buffer: &str, buffer: &str,
config: &ServerConfig, config: &ServerConfig,
user: &RegisteredUser, user: &RegisteredUser,
rooms: &RoomRegistry, rooms: &RoomRegistry<'_>,
user_handle: &mut PlayerConnection, user_handle: &mut PlayerConnection,
writer: &mut (impl AsyncWrite + Unpin), writer: &mut (impl AsyncWrite + Unpin),
) -> Result<HandleResult> { ) -> Result<HandleResult> {
@ -689,13 +688,14 @@ async fn produce_on_join_cmd_messages(
Ok(()) Ok(())
} }
pub async fn launch( pub async fn launch<'a>(
config: ServerConfig, config: &'a ServerConfig,
players: PlayerRegistry, players: &'a PlayerRegistry<'_>,
rooms: RoomRegistry, rooms: &'a RoomRegistry<'_>,
metrics: MetricsRegistry, metrics: &'a MetricsRegistry,
storage: Storage, storage: &'a Storage,
) -> Result<Terminator> { scope: &mut Scope<'a>,
) -> Result<Promise<()>> {
log::info!("Starting IRC projection"); log::info!("Starting IRC projection");
let (stopped_tx, mut stopped_rx) = channel(32); let (stopped_tx, mut stopped_rx) = channel(32);
let current_connections = let current_connections =
@ -709,10 +709,12 @@ pub async fn launch(
let listener = TcpListener::bind(config.listen_on).await?; let listener = TcpListener::bind(config.listen_on).await?;
let terminator = Terminator::spawn(|mut rx| async move { let (signal, mut rx) = oneshot();
let future = async move {
// TODO probably should separate logic for accepting new connection and storing them // TODO probably should separate logic for accepting new connection and storing them
// into two tasks so that they don't block each other // into two tasks so that they don't block each other
let mut actors = HashMap::new(); let mut actors = HashMap::new();
let mut scope = unsafe { Scope::create() };
loop { loop {
select! { select! {
biased; biased;
@ -735,23 +737,22 @@ pub async fn launch(
continue; continue;
} }
let terminator = Terminator::spawn(|termination| {
let players = players.clone();
let rooms = rooms.clone();
let current_connections_clone = current_connections.clone(); let current_connections_clone = current_connections.clone();
let stopped_tx = stopped_tx.clone(); let stopped_tx = stopped_tx.clone();
let storage = storage.clone();
async move { let (a, rx) = oneshot();
match handle_socket(config, stream, &socket_addr, players, rooms, termination, storage).await { let future = async move {
match handle_socket(config, stream, &socket_addr, players, rooms, rx, storage).await {
Ok(_) => log::info!("Connection terminated"), Ok(_) => log::info!("Connection terminated"),
Err(err) => log::warn!("Connection failed: {err}"), Err(err) => log::warn!("Connection failed: {err}"),
} }
current_connections_clone.dec(); current_connections_clone.dec();
stopped_tx.send(socket_addr).await?; stopped_tx.send(socket_addr).await?;
Ok(()) Ok(())
} };
}); scope.spawn(future.map(|_: Result<()>| ()));
actors.insert(socket_addr, terminator); actors.insert(socket_addr, a);
}, },
Err(err) => log::warn!("Failed to accept new connection: {err}"), Err(err) => log::warn!("Failed to accept new connection: {err}"),
} }
@ -760,19 +761,20 @@ pub async fn launch(
} }
log::info!("Stopping IRC projection"); log::info!("Stopping IRC projection");
join_all(actors.into_iter().map(|(socket_addr, terminator)| async move { for (socket_addr, terminator) in actors {
log::debug!("Stopping IRC connection at {socket_addr}"); match terminator.send(()) {
match terminator.terminate().await { Ok(_) => log::debug!("Stopping IRC connection at {socket_addr}"),
Ok(_) => log::debug!("Stopped IRC connection at {socket_addr}"), Err(_) => log::debug!("IRC connection at {socket_addr} already stopped")
Err(err) => {
log::warn!("IRC connection to {socket_addr} finished with error: {err}")
} }
} }
})).await; let _ = scope.collect().await;
drop(scope);
log::info!("Stopped IRC projection"); log::info!("Stopped IRC projection");
Ok(()) Ok(())
}); };
scope.spawn(future.map(|_: Result<()>| ()));
log::info!("Started IRC projection"); log::info!("Started IRC projection");
Ok(terminator) Ok(signal)
} }

View File

@ -7,7 +7,7 @@ use std::net::SocketAddr;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use futures_util::future::join_all; use futures_util::FutureExt;
use prometheus::Registry as MetricsRegistry; use prometheus::Registry as MetricsRegistry;
use quick_xml::events::{BytesDecl, Event}; use quick_xml::events::{BytesDecl, Event};
use quick_xml::{NsReader, Writer}; use quick_xml::{NsReader, Writer};
@ -30,7 +30,6 @@ use crate::protos::xmpp::roster::RosterQuery;
use crate::protos::xmpp::session::Session; use crate::protos::xmpp::session::Session;
use crate::protos::xmpp::stream::*; use crate::protos::xmpp::stream::*;
use crate::util::xml::{Continuation, FromXml, Parser, ToXml}; use crate::util::xml::{Continuation, FromXml, Parser, ToXml};
use crate::util::Terminator;
use self::proto::{ClientPacket, IqClientBody}; use self::proto::{ClientPacket, IqClientBody};
@ -53,12 +52,13 @@ struct Authenticated {
xmpp_muc_name: Resource, xmpp_muc_name: Resource,
} }
pub async fn launch( pub async fn launch<'a>(
config: ServerConfig, config: ServerConfig,
players: PlayerRegistry, players: &'a PlayerRegistry<'_>,
rooms: RoomRegistry, rooms: &'a RoomRegistry<'_>,
metrics: MetricsRegistry, metrics: &'a MetricsRegistry,
) -> Result<Terminator> { scope: &mut Scope<'a>,
) -> Result<Promise<()>> {
log::info!("Starting XMPP projection"); log::info!("Starting XMPP projection");
let certs = certs(&mut SyncBufReader::new(File::open(config.cert)?))?; let certs = certs(&mut SyncBufReader::new(File::open(config.cert)?))?;
@ -75,13 +75,16 @@ pub async fn launch(
}); });
let listener = TcpListener::bind(config.listen_on).await?; let listener = TcpListener::bind(config.listen_on).await?;
let terminator = Terminator::spawn(|mut termination| async move {
let (signal, mut rx) = oneshot();
let future = async move {
let (stopped_tx, mut stopped_rx) = channel(32); let (stopped_tx, mut stopped_rx) = channel(32);
let mut actors = HashMap::new(); let mut actors = HashMap::new();
let mut scope = unsafe { Scope::create() };
loop { loop {
select! { select! {
biased; biased;
_ = &mut termination => break, _ = &mut rx => break,
stopped = stopped_rx.recv() => match stopped { stopped = stopped_rx.recv() => match stopped {
Some(stopped) => { let _ = actors.remove(&stopped); }, Some(stopped) => { let _ = actors.remove(&stopped); },
None => unreachable!(), None => unreachable!(),
@ -95,22 +98,20 @@ pub async fn launch(
// TODO kill the older connection and restart it // TODO kill the older connection and restart it
continue; continue;
} }
let players = players.clone();
let rooms = rooms.clone();
let terminator = Terminator::spawn(|termination| {
let stopped_tx = stopped_tx.clone(); let stopped_tx = stopped_tx.clone();
let loaded_config = loaded_config.clone(); let loaded_config = loaded_config.clone();
async move { let (a, rx) = oneshot();
match handle_socket(loaded_config, stream, &socket_addr, players, rooms, termination).await { let future = async move {
match handle_socket(loaded_config, stream, &socket_addr, players, rooms, rx).await {
Ok(_) => log::info!("Connection terminated"), Ok(_) => log::info!("Connection terminated"),
Err(err) => log::warn!("Connection failed: {err}"), Err(err) => log::warn!("Connection failed: {err}"),
} }
stopped_tx.send(socket_addr).await?; stopped_tx.send(socket_addr).await?;
Ok(()) Ok(())
} };
}); scope.spawn(future.map(|_: Result<()>| ()));
actors.insert(socket_addr, terminator); actors.insert(socket_addr, a);
}, },
Err(err) => log::warn!("Failed to accept new connection: {err}"), Err(err) => log::warn!("Failed to accept new connection: {err}"),
} }
@ -118,35 +119,28 @@ pub async fn launch(
} }
} }
log::info!("Stopping XMPP projection"); log::info!("Stopping XMPP projection");
join_all( for (socket_addr, terminator) in actors {
actors match terminator.send(()) {
.into_iter() Ok(_) => log::debug!("Stopping XMPP connection at {socket_addr}"),
.map(|(socket_addr, terminator)| async move { Err(_) => log::debug!("XMPP connection at {socket_addr} already stopped")
log::debug!("Stopping XMPP connection at {socket_addr}");
match terminator.terminate().await {
Ok(_) => log::debug!("Stopped XMPP connection at {socket_addr}"),
Err(err) => {
log::warn!(
"XMPP connection to {socket_addr} finished with error: {err}"
)
} }
} }
}), let _ = scope.collect().await;
) drop(scope);
.await;
log::info!("Stopped XMPP projection"); log::info!("Stopped XMPP projection");
Ok(()) Ok(())
}); };
scope.spawn(future.map(|_: Result<()>| ()));
log::info!("Started XMPP projection"); log::info!("Started XMPP projection");
Ok(terminator) Ok(signal)
} }
async fn handle_socket( async fn handle_socket(
config: Arc<LoadedConfig>, config: Arc<LoadedConfig>,
mut stream: TcpStream, mut stream: TcpStream,
socket_addr: &SocketAddr, socket_addr: &SocketAddr,
mut players: PlayerRegistry, players: &PlayerRegistry<'_>,
rooms: RoomRegistry, rooms: &RoomRegistry<'_>,
termination: Deferred<()>, // TODO use it to stop the connection gracefully termination: Deferred<()>, // TODO use it to stop the connection gracefully
) -> Result<()> { ) -> Result<()> {
log::debug!("Received an XMPP connection from {socket_addr}"); log::debug!("Received an XMPP connection from {socket_addr}");
@ -271,7 +265,7 @@ async fn socket_final(
reader_buf: &mut Vec<u8>, reader_buf: &mut Vec<u8>,
authenticated: &Authenticated, authenticated: &Authenticated,
user_handle: &mut PlayerConnection, user_handle: &mut PlayerConnection,
rooms: &RoomRegistry, rooms: &RoomRegistry<'_>,
) -> Result<()> { ) -> Result<()> {
read_xml_header(xml_reader, reader_buf).await?; read_xml_header(xml_reader, reader_buf).await?;
let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?;
@ -381,7 +375,7 @@ async fn handle_packet(
packet: ClientPacket, packet: ClientPacket,
user: &Authenticated, user: &Authenticated,
user_handle: &mut PlayerConnection, user_handle: &mut PlayerConnection,
rooms: &RoomRegistry, rooms: &RoomRegistry<'_>,
) -> Result<bool> { ) -> Result<bool> {
Ok(match packet { Ok(match packet {
proto::ClientPacket::Iq(iq) => { proto::ClientPacket::Iq(iq) => {
@ -478,7 +472,7 @@ async fn handle_packet(
}) })
} }
async fn handle_iq(output: &mut Vec<Event<'static>>, iq: Iq<IqClientBody>, rooms: &RoomRegistry) { async fn handle_iq(output: &mut Vec<Event<'static>>, iq: Iq<IqClientBody>, rooms: &RoomRegistry<'_>) {
match iq.body { match iq.body {
proto::IqClientBody::Bind(b) => { proto::IqClientBody::Bind(b) => {
let req = Iq { let req = Iq {
@ -590,7 +584,7 @@ fn disco_info(to: Option<&str>, req: &InfoQuery) -> InfoQuery {
} }
} }
async fn disco_items(to: Option<&str>, req: &ItemQuery, rooms: &RoomRegistry) -> ItemQuery { async fn disco_items(to: Option<&str>, req: &ItemQuery, rooms: &RoomRegistry<'_>) -> ItemQuery {
let item = match to { let item = match to {
Some("localhost") => { Some("localhost") => {
vec![Item { vec![Item {

View File

@ -15,7 +15,6 @@ use crate::core::room::RoomRegistry;
use crate::prelude::*; use crate::prelude::*;
use crate::util::http::*; use crate::util::http::*;
use crate::util::Terminator;
type BoxBody = http_body_util::combinators::BoxBody<Bytes, Infallible>; type BoxBody = http_body_util::combinators::BoxBody<Bytes, Infallible>;
type HttpResult<T> = std::result::Result<T, Infallible>; type HttpResult<T> = std::result::Result<T, Infallible>;
@ -25,37 +24,28 @@ pub struct ServerConfig {
pub listen_on: SocketAddr, pub listen_on: SocketAddr,
} }
pub async fn launch( pub async fn launch<'a>(
config: ServerConfig, config: ServerConfig,
metrics: MetricsRegistry, metrics: &'a MetricsRegistry,
rooms: RoomRegistry, rooms: &'a RoomRegistry<'_>,
) -> Result<Terminator> { scope: &mut Scope<'a>,
) -> Result<Promise<()>> {
log::info!("Starting the telemetry service"); log::info!("Starting the telemetry service");
let listener = TcpListener::bind(config.listen_on).await?; let listener = TcpListener::bind(config.listen_on).await?;
log::debug!("Listener started"); log::debug!("Listener started");
let terminator = Terminator::spawn(|rx| main_loop(listener, metrics, rooms, rx.map(|_| ())));
Ok(terminator)
}
async fn main_loop( let (signal, mut rx) = oneshot();
listener: TcpListener,
metrics: MetricsRegistry, let future = async move {
rooms: RoomRegistry, let mut scope = unsafe { Scope::create() };
termination: impl Future<Output = ()>,
) -> Result<()> {
pin!(termination);
loop { loop {
select! { select! {
biased; biased;
_ = &mut termination => break, _ = &mut rx => break,
result = listener.accept() => { result = listener.accept() => {
let (stream, _) = result?; let (stream, _) = result?;
let metrics = metrics.clone(); scope.spawn(async move {
let rooms = rooms.clone(); let server = http1::Builder::new().serve_connection(stream, service_fn(move |r| route(metrics, rooms, r)));
tokio::task::spawn(async move {
let registry = metrics.clone();
let rooms = rooms.clone();
let server = http1::Builder::new().serve_connection(stream, service_fn(move |r| route(registry.clone(), rooms.clone(), r)));
if let Err(err) = server.await { if let Err(err) = server.await {
tracing::error!("Error serving connection: {:?}", err); tracing::error!("Error serving connection: {:?}", err);
} }
@ -63,13 +53,19 @@ async fn main_loop(
}, },
} }
} }
let _ = scope.collect().await;
drop(scope);
log::info!("Terminating the telemetry service"); log::info!("Terminating the telemetry service");
Ok(()) Ok(())
};
scope.spawn(future.map(|_: Result<()>| ()));
Ok(signal)
} }
async fn route( async fn route(
registry: MetricsRegistry, registry: &MetricsRegistry,
rooms: RoomRegistry, rooms: &RoomRegistry<'_>,
request: Request<hyper::body::Incoming>, request: Request<hyper::body::Incoming>,
) -> std::result::Result<Response<BoxBody>, Infallible> { ) -> std::result::Result<Response<BoxBody>, Infallible> {
match (request.method(), request.uri().path()) { match (request.method(), request.uri().path()) {
@ -79,7 +75,7 @@ async fn route(
} }
} }
fn endpoint_metrics(registry: MetricsRegistry) -> HttpResult<Response<Full<Bytes>>> { fn endpoint_metrics(registry: &MetricsRegistry) -> HttpResult<Response<Full<Bytes>>> {
let mf = registry.gather(); let mf = registry.gather();
let mut buffer = vec![]; let mut buffer = vec![];
TextEncoder TextEncoder
@ -88,7 +84,7 @@ fn endpoint_metrics(registry: MetricsRegistry) -> HttpResult<Response<Full<Bytes
Ok(Response::new(Full::new(Bytes::from(buffer)))) Ok(Response::new(Full::new(Bytes::from(buffer))))
} }
async fn endpoint_rooms(rooms: RoomRegistry) -> HttpResult<Response<Full<Bytes>>> { async fn endpoint_rooms(rooms: &RoomRegistry<'_>) -> HttpResult<Response<Full<Bytes>>> {
let room_list = rooms.get_all_rooms().await; let room_list = rooms.get_all_rooms().await;
let mut buffer = vec![]; let mut buffer = vec![];
serde_json::to_writer(&mut buffer, &room_list).expect("unexpected fail when writing to vec"); serde_json::to_writer(&mut buffer, &room_list).expect("unexpected fail when writing to vec");