forked from lavina/lavina
1
0
Fork 0

implement persistent memberships

This commit is contained in:
Nikita Vilunov 2024-04-13 02:27:35 +02:00
parent fd694cd75c
commit 041c1b5fe2
8 changed files with 252 additions and 96 deletions

View File

@ -7,16 +7,16 @@
//!
//! A player actor is a serial handler of commands from a single player. It is preferable to run all per-player validations in the player actor,
//! so that they don't overload the room actor.
use std::{
collections::{HashMap, HashSet},
sync::{Arc, RwLock},
};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use prometheus::{IntGauge, Registry as MetricsRegistry};
use serde::Serialize;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::RwLock;
use crate::prelude::*;
use crate::repo::Storage;
use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry};
use crate::table::{AnonTable, Key as AnonKey};
@ -208,23 +208,28 @@ pub enum Updates {
#[derive(Clone)]
pub struct PlayerRegistry(Arc<RwLock<PlayerRegistryInner>>);
impl PlayerRegistry {
pub fn empty(room_registry: RoomRegistry, metrics: &mut MetricsRegistry) -> Result<PlayerRegistry> {
pub fn empty(
room_registry: RoomRegistry,
storage: Storage,
metrics: &mut MetricsRegistry,
) -> Result<PlayerRegistry> {
let metric_active_players = IntGauge::new("chat_players_active", "Number of alive player actors")?;
metrics.register(Box::new(metric_active_players.clone()))?;
let inner = PlayerRegistryInner {
room_registry,
storage,
players: HashMap::new(),
metric_active_players,
};
Ok(PlayerRegistry(Arc::new(RwLock::new(inner))))
}
pub async fn get_or_create_player(&mut self, id: PlayerId) -> PlayerHandle {
let mut inner = self.0.write().unwrap();
pub async fn get_or_launch_player(&mut self, id: PlayerId) -> PlayerHandle {
let mut inner = self.0.write().await;
if let Some((handle, _)) = inner.players.get(&id) {
handle.clone()
} else {
let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone());
let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone(), inner.storage.clone()).await;
inner.players.insert(id, (handle.clone(), fiber));
inner.metric_active_players.inc();
handle
@ -232,12 +237,12 @@ impl PlayerRegistry {
}
pub async fn connect_to_player(&mut self, id: PlayerId) -> PlayerConnection {
let player_handle = self.get_or_create_player(id).await;
let player_handle = self.get_or_launch_player(id).await;
player_handle.subscribe().await
}
pub async fn shutdown_all(&mut self) -> Result<()> {
let mut inner = self.0.write().unwrap();
let mut inner = self.0.write().await;
for (i, (k, j)) in inner.players.drain() {
k.send(ActorCommand::Stop).await;
drop(k);
@ -252,6 +257,8 @@ impl PlayerRegistry {
/// The player registry state representation.
struct PlayerRegistryInner {
room_registry: RoomRegistry,
storage: Storage,
/// Active player actors.
players: HashMap<PlayerId, (PlayerHandle, JoinHandle<Player>)>,
metric_active_players: IntGauge,
}
@ -259,26 +266,31 @@ struct PlayerRegistryInner {
/// Player actor inner state representation.
struct Player {
player_id: PlayerId,
storage_id: u32,
connections: AnonTable<Sender<Updates>>,
my_rooms: HashMap<RoomId, RoomHandle>,
banned_from: HashSet<RoomId>,
rx: Receiver<ActorCommand>,
handle: PlayerHandle,
rooms: RoomRegistry,
storage: Storage,
}
impl Player {
fn launch(player_id: PlayerId, rooms: RoomRegistry) -> (PlayerHandle, JoinHandle<Player>) {
async fn launch(player_id: PlayerId, rooms: RoomRegistry, storage: Storage) -> (PlayerHandle, JoinHandle<Player>) {
let (tx, rx) = channel(32);
let handle = PlayerHandle { tx };
let handle_clone = handle.clone();
let storage_id = storage.retrieve_user_id_by_name(player_id.as_inner()).await.unwrap().unwrap();
let player = Player {
player_id,
storage_id,
connections: AnonTable::new(),
my_rooms: HashMap::new(),
banned_from: HashSet::from([RoomId::from("Empty").unwrap()]),
banned_from: HashSet::new(),
rx,
handle,
rooms,
storage,
};
let fiber = tokio::task::spawn(player.main_loop());
(handle_clone, fiber)
@ -372,6 +384,7 @@ impl Player {
todo!();
}
};
room.add_member(self.player_id.clone(), self.storage_id).await;
room.subscribe(self.player_id.clone(), self.handle.clone()).await;
self.my_rooms.insert(room_id.clone(), room.clone());
let room_info = room.get_room_info().await;
@ -387,6 +400,7 @@ impl Player {
let room = self.my_rooms.remove(&room_id);
if let Some(room) = room {
room.unsubscribe(&self.player_id).await;
room.remove_member(&self.player_id).await;
}
let update = Updates::RoomLeft {
room_id,
@ -396,12 +410,11 @@ impl Player {
}
async fn send_message(&mut self, connection_id: ConnectionId, room_id: RoomId, body: Str) {
let room = self.rooms.get_room(&room_id).await;
if let Some(room) = room {
room.send_message(self.player_id.clone(), body.clone()).await;
} else {
let Some(room) = self.my_rooms.get(&room_id) else {
tracing::info!("no room found");
}
return;
};
room.send_message(self.player_id.clone(), body.clone()).await;
let update = Updates::NewMessage {
room_id,
author_id: self.player_id.clone(),
@ -411,12 +424,11 @@ impl Player {
}
async fn change_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) {
let room = self.rooms.get_room(&room_id).await;
if let Some(mut room) = room {
room.set_topic(self.player_id.clone(), new_topic.clone()).await;
} else {
let Some(room) = self.my_rooms.get(&room_id) else {
tracing::info!("no room found");
}
return;
};
room.set_topic(self.player_id.clone(), new_topic.clone()).await;
let update = Updates::RoomTopicChanged { room_id, new_topic };
self.broadcast_update(update, connection_id).await;
}

View File

@ -11,6 +11,9 @@ use tokio::sync::Mutex;
use crate::prelude::*;
mod user;
mod room;
#[derive(Deserialize, Debug, Clone)]
pub struct StorageConfig {
pub db_path: String,

View File

@ -0,0 +1,19 @@
use anyhow::Result;
use crate::repo::Storage;
impl Storage {
pub async fn add_room_member(&mut self, room_id: &u32, player_id: &u32) -> Result<()> {
let mut executor = self.conn.lock().await;
sqlx::query(
"insert into memberships(user_id, room_id, status)
values (?, ?, 1);",
)
.bind(player_id)
.bind(room_id)
.execute(&mut *executor)
.await?;
Ok(())
}
}

View File

@ -0,0 +1,15 @@
use anyhow::Result;
use crate::repo::Storage;
impl Storage {
pub async fn retrieve_user_id_by_name(&self, name: &str) -> Result<Option<u32>> {
let mut executor = self.conn.lock().await;
let res: Option<(u32,)> = sqlx::query_as("select u.id from users u where u.name = ?;")
.bind(name)
.fetch_optional(&mut *executor)
.await?;
Ok(res.map(|(id,)| id))
}
}

View File

@ -1,4 +1,5 @@
//! Domain of rooms — chats with multiple participants.
use std::collections::HashSet;
use std::{collections::HashMap, hash::Hash, sync::Arc};
use prometheus::{IntGauge, Registry as MetricRegistry};
@ -59,6 +60,7 @@ impl RoomRegistry {
storage_id: stored_room.id,
room_id: room_id.clone(),
subscriptions: HashMap::new(), // TODO figure out how to populate subscriptions
members: HashSet::new(), // TODO load members from storage
topic: stored_room.topic.into(),
message_count: stored_room.message_count,
storage: inner.storage.clone(),
@ -76,6 +78,7 @@ impl RoomRegistry {
storage_id: id,
room_id: room_id.clone(),
subscriptions: HashMap::new(),
members: HashSet::new(),
topic: topic.into(),
message_count: 0,
storage: inner.storage.clone(),
@ -118,12 +121,32 @@ pub struct RoomHandle(Arc<AsyncRwLock<Room>>);
impl RoomHandle {
pub async fn subscribe(&self, player_id: PlayerId, player_handle: PlayerHandle) {
let mut lock = self.0.write().await;
lock.add_subscriber(player_id, player_handle).await;
tracing::info!("Adding a subscriber to a room");
lock.subscriptions.insert(player_id.clone(), player_handle);
}
pub async fn add_member(&self, player_id: PlayerId, player_storage_id: u32) {
let mut lock = self.0.write().await;
tracing::info!("Adding a new member to a room");
let storage_id = lock.storage_id;
lock.members.insert(player_id.clone());
lock.storage.add_room_member(&storage_id, &player_storage_id).await.unwrap();
let update = Updates::RoomJoined {
room_id: lock.room_id.clone(),
new_member_id: player_id.clone(),
};
lock.broadcast_update(update, &player_id).await;
}
pub async fn unsubscribe(&self, player_id: &PlayerId) {
let mut lock = self.0.write().await;
lock.subscriptions.remove(player_id);
}
pub async fn remove_member(&self, player_id: &PlayerId) {
let mut lock = self.0.write().await;
tracing::info!("Removing a member from a room");
lock.members.remove(player_id);
let update = Updates::RoomLeft {
room_id: lock.room_id.clone(),
former_member_id: player_id.clone(),
@ -148,7 +171,7 @@ impl RoomHandle {
}
}
pub async fn set_topic(&mut self, changer_id: PlayerId, new_topic: Str) {
pub async fn set_topic(&self, changer_id: PlayerId, new_topic: Str) {
let mut lock = self.0.write().await;
lock.topic = new_topic.clone();
let update = Updates::RoomTopicChanged {
@ -166,22 +189,14 @@ struct Room {
room_id: RoomId,
/// Player actors on the local node which are subscribed to this room's updates.
subscriptions: HashMap<PlayerId, PlayerHandle>,
/// Members of the room.
members: HashSet<PlayerId>,
/// The total number of messages. Used to calculate the id of the new message.
message_count: u32,
topic: Str,
storage: Storage,
}
impl Room {
async fn add_subscriber(&mut self, player_id: PlayerId, player_handle: PlayerHandle) {
tracing::info!("Adding a subscriber to room");
self.subscriptions.insert(player_id.clone(), player_handle);
let update = Updates::RoomJoined {
room_id: self.room_id.clone(),
new_member_id: player_id.clone(),
};
self.broadcast_update(update, &player_id).await;
}
async fn send_message(&mut self, author_id: PlayerId, body: Str) -> Result<()> {
tracing::info!("Adding a message to room");
self.storage.insert_message(self.storage_id, self.message_count, &body, &*author_id.as_inner()).await?;

View File

@ -111,7 +111,36 @@ impl TestServer {
})
.await?;
let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap();
let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap();
let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap();
let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap();
Ok(TestServer {
metrics,
storage,
rooms,
players,
server,
})
}
async fn reboot(mut self) -> Result<TestServer> {
let config = ServerConfig {
listen_on: "127.0.0.1:0".parse().unwrap(),
server_name: "testserver".into(),
};
let TestServer {
mut metrics,
mut storage,
rooms,
mut players,
server,
} = self;
server.terminate().await?;
players.shutdown_all().await.unwrap();
drop(players);
drop(rooms);
let mut metrics = MetricsRegistry::new();
let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap();
let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap();
let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap();
Ok(TestServer {
metrics,
@ -152,6 +181,76 @@ async fn scenario_basic() -> Result<()> {
Ok(())
}
#[tokio::test]
async fn scenario_join_and_reboot() -> Result<()> {
let mut server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
// Open a connection and join a channel
s.send("PASS password").await?;
s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect_server_introduction("tester").await?;
s.expect_nothing().await?;
s.send("JOIN #test").await?;
s.expect(":tester JOIN #test").await?;
s.expect(":testserver 332 tester #test :New room").await?;
s.expect(":testserver 353 tester = #test :tester").await?;
s.expect(":testserver 366 tester #test :End of /NAMES list").await?;
s.send("PRIVMSG #test :Hello").await?;
s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?;
s.expect_eof().await?;
stream.shutdown().await?;
// Open a new connection and expect to be force-joined to the channel
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
async fn test(s: &mut TestScope<'_>) -> Result<()> {
s.send("PASS password").await?;
s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect_server_introduction("tester").await?;
s.expect(":tester JOIN #test").await?;
s.expect(":testserver 332 tester #test :New room").await?;
s.expect(":testserver 353 tester = #test :tester").await?;
s.expect(":testserver 366 tester #test :End of /NAMES list").await?;
s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?;
s.expect_eof().await?;
Ok(())
}
test(&mut s).await?;
stream.shutdown().await?;
// Reboot the server
let server = server.reboot().await?;
// Open a new connection and expect to be force-joined to the channel
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
test(&mut s).await?;
stream.shutdown().await?;
// wrap up
server.server.terminate().await?;
Ok(())
}
#[tokio::test]
async fn scenario_force_join_msg() -> Result<()> {
let mut server = TestServer::start().await?;
@ -407,7 +506,6 @@ async fn scenario_cap_sasl_fail() -> Result<()> {
#[tokio::test]
async fn terminate_socket_scenario() -> Result<()> {
let mut server = TestServer::start().await?;
let address: SocketAddr = ("127.0.0.1:0".parse().unwrap());
// test scenario

View File

@ -20,7 +20,7 @@ use tokio_rustls::TlsConnector;
use lavina_core::player::PlayerRegistry;
use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::room::RoomRegistry;
use projection_xmpp::{launch, ServerConfig};
use projection_xmpp::{launch, RunningServer, ServerConfig};
use proto_xmpp::xml::{Continuation, FromXml, Parser};
pub async fn read_irc_message(reader: &mut BufReader<ReadHalf<'_>>, buf: &mut Vec<u8>) -> Result<usize> {
@ -122,9 +122,16 @@ impl ServerCertVerifier for IgnoreCertVerification {
}
}
#[tokio::test]
async fn scenario_basic() -> Result<()> {
tracing_subscriber::fmt::try_init();
struct TestServer {
metrics: MetricsRegistry,
storage: Storage,
rooms: RoomRegistry,
players: PlayerRegistry,
server: RunningServer,
}
impl TestServer {
async fn start() -> Result<TestServer> {
let _ = tracing_subscriber::fmt::try_init();
let config = ServerConfig {
listen_on: "127.0.0.1:0".parse().unwrap(),
cert: "tests/certs/xmpp.pem".parse().unwrap(),
@ -136,15 +143,28 @@ async fn scenario_basic() -> Result<()> {
})
.await?;
let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap();
let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap();
let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap();
let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap();
let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap();
Ok(TestServer {
metrics,
storage,
rooms,
players,
server,
})
}
}
#[tokio::test]
async fn scenario_basic() -> Result<()> {
let mut server = TestServer::start().await?;
// test scenario
storage.create_user("tester").await?;
storage.set_password("tester", "password").await?;
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.addr).await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
tracing::info!("TCP connection established");
@ -169,7 +189,7 @@ async fn scenario_basic() -> Result<()> {
.with_no_client_auth(),
));
tracing::info!("Initiating TLS connection...");
let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?;
let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
tracing::info!("TLS connection established");
let mut s = TestScopeTls::new(&mut stream, buffer);
@ -183,33 +203,20 @@ async fn scenario_basic() -> Result<()> {
// wrap up
server.terminate().await?;
server.server.terminate().await?;
Ok(())
}
#[tokio::test]
async fn scenario_basic_without_headers() -> Result<()> {
tracing_subscriber::fmt::try_init();
let config = ServerConfig {
listen_on: "127.0.0.1:0".parse().unwrap(),
cert: "tests/certs/xmpp.pem".parse().unwrap(),
key: "tests/certs/xmpp.key".parse().unwrap(),
};
let mut metrics = MetricsRegistry::new();
let mut storage = Storage::open(StorageConfig {
db_path: ":memory:".into(),
})
.await?;
let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap();
let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap();
let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap();
let mut server = TestServer::start().await?;
// test scenario
storage.create_user("tester").await?;
storage.set_password("tester", "password").await?;
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.addr).await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
tracing::info!("TCP connection established");
@ -233,7 +240,7 @@ async fn scenario_basic_without_headers() -> Result<()> {
.with_no_client_auth(),
));
tracing::info!("Initiating TLS connection...");
let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?;
let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
tracing::info!("TLS connection established");
let mut s = TestScopeTls::new(&mut stream, buffer);
@ -246,33 +253,20 @@ async fn scenario_basic_without_headers() -> Result<()> {
// wrap up
server.terminate().await?;
server.server.terminate().await?;
Ok(())
}
#[tokio::test]
async fn terminate_socket() -> Result<()> {
tracing_subscriber::fmt::try_init();
let config = ServerConfig {
listen_on: "127.0.0.1:0".parse().unwrap(),
cert: "tests/certs/xmpp.pem".parse().unwrap(),
key: "tests/certs/xmpp.key".parse().unwrap(),
};
let mut metrics = MetricsRegistry::new();
let mut storage = Storage::open(StorageConfig {
db_path: ":memory:".into(),
})
.await?;
let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap();
let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap();
let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap();
let address: SocketAddr = ("127.0.0.1:0".parse().unwrap());
let mut server = TestServer::start().await?;
// test scenario
storage.create_user("tester").await?;
storage.set_password("tester", "password").await?;
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.addr).await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
tracing::info!("TCP connection established");
@ -298,10 +292,10 @@ async fn terminate_socket() -> Result<()> {
));
tracing::info!("Initiating TLS connection...");
let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?;
let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
tracing::info!("TLS connection established");
server.terminate().await?;
server.server.terminate().await?;
assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof);

View File

@ -52,7 +52,7 @@ async fn main() -> Result<()> {
let mut metrics = MetricsRegistry::new();
let storage = Storage::open(storage_config).await?;
let rooms = RoomRegistry::new(&mut metrics, storage.clone())?;
let mut players = PlayerRegistry::empty(rooms.clone(), &mut metrics)?;
let mut players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics)?;
let telemetry_terminator = http::launch(telemetry_config, metrics.clone(), rooms.clone(), storage.clone()).await?;
let irc = projection_irc::launch(
irc_config,