diff --git a/crates/lavina-core/src/clustering.rs b/crates/lavina-core/src/clustering.rs index 9fffcaf..a5468cb 100644 --- a/crates/lavina-core/src/clustering.rs +++ b/crates/lavina-core/src/clustering.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use std::net::SocketAddr; use std::sync::Arc; -mod broadcast; +pub mod broadcast; type Addresses = Vec; @@ -41,6 +41,14 @@ pub struct SendMessageReq<'a> { pub created_at: &'a str, } +#[derive(Serialize, Deserialize, Debug)] +pub struct BroadcastMessageReq<'a> { + pub room_id: &'a str, + pub author_id: &'a str, + pub message: &'a str, + pub created_at: &'a str, +} + #[derive(Serialize, Deserialize, Debug)] pub struct SetRoomTopicReq<'a> { pub room_id: &'a str, @@ -48,6 +56,12 @@ pub struct SetRoomTopicReq<'a> { pub topic: &'a str, } +pub mod paths { + pub const ADD_MESSAGE: &'static str = "/cluster/rooms/add_message"; + pub const BROADCAST_MESSAGE: &'static str = "/cluster/rooms/broadcast_message"; + pub const SET_TOPIC: &'static str = "/cluster/rooms/set_topic"; +} + impl LavinaClient { pub fn new(addresses: Addresses) -> Self { let client = ClientBuilder::new(Client::new()).with(TracingMiddleware::::new()).build(); @@ -64,7 +78,7 @@ impl LavinaClient { tracing::error!("Failed"); return Err(anyhow!("Unknown node")); }; - match self.client.post(format!("http://{}/cluster/rooms/add_message", address)).json(&req).send().await { + match self.client.post(format!("http://{}{}", address, paths::BROADCAST_MESSAGE)).json(&req).send().await { Ok(_) => { tracing::info!("Message sent"); Ok(()) @@ -76,13 +90,32 @@ impl LavinaClient { } } + #[tracing::instrument(skip(self, req), name = "LavinaClient::broadcast_room_message")] + pub async fn broadcast_room_message(&self, node_id: u32, req: BroadcastMessageReq<'_>) -> Result<()> { + tracing::info!("Broadcasting a message to a room on a remote node"); + let Some(address) = self.addresses.get(node_id as usize) else { + tracing::error!("Failed"); + return Err(anyhow!("Unknown node")); + }; + match self.client.post(format!("http://{}{}", address, paths::BROADCAST_MESSAGE)).json(&req).send().await { + Ok(_) => { + tracing::info!("Message broadcasted"); + Ok(()) + } + Err(e) => { + tracing::error!("Failed to broadcast message: {e:?}"); + Err(e.into()) + } + } + } + pub async fn set_room_topic(&self, node_id: u32, req: SetRoomTopicReq<'_>) -> Result<()> { tracing::info!("Setting the topic of a room on a remote node"); let Some(address) = self.addresses.get(node_id as usize) else { tracing::error!("Failed"); return Err(anyhow!("Unknown node")); }; - match self.client.post(format!("http://{}/cluster/rooms/set_topic", address)).json(&req).send().await { + match self.client.post(format!("http://{}{}", address, paths::SET_TOPIC)).json(&req).send().await { Ok(_) => { tracing::info!("Room topic set"); Ok(()) diff --git a/crates/lavina-core/src/clustering/broadcast.rs b/crates/lavina-core/src/clustering/broadcast.rs index 470db85..456f820 100644 --- a/crates/lavina-core/src/clustering/broadcast.rs +++ b/crates/lavina-core/src/clustering/broadcast.rs @@ -1,29 +1,58 @@ use std::collections::{HashMap, HashSet}; -use crate::player::PlayerId; +use std::sync::Arc; + +use chrono::{DateTime, Utc}; +use tokio::sync::Mutex; + +use crate::player::{PlayerId, PlayerRegistry, Updates}; use crate::prelude::Str; use crate::room::RoomId; /// Receives updates from other nodes and broadcasts them to local player actors. -struct Broadcasting { +struct BroadcastingInner { subscriptions: HashMap>, } +impl Broadcasting {} + +#[derive(Clone)] +pub struct Broadcasting(Arc>); impl Broadcasting { - /// Creates a new broadcasting instance. pub fn new() -> Self { - Self { + let inner = BroadcastingInner { subscriptions: HashMap::new(), + }; + Self(Arc::new(Mutex::new(inner))) + } + + /// Broadcasts the given update to subscribed player actors on local node. + pub async fn broadcast( + &self, + players: &PlayerRegistry, + room_id: RoomId, + author_id: PlayerId, + message: Str, + created_at: DateTime, + ) { + let inner = self.0.lock().await; + let Some(subscribers) = inner.subscriptions.get(&room_id) else { + return; + }; + let update = Updates::NewMessage { + room_id: room_id.clone(), + author_id: author_id.clone(), + body: message.clone(), + created_at: created_at.clone(), + }; + for i in subscribers { + let Some(player) = players.get_player(i).await else { + continue; + }; + player.update(update.clone()).await; } } - /// Broadcasts the given update to player actors. - pub fn broadcast(&self, room_id: RoomId, author_id: PlayerId, message: Str) { - self.subscriptions.get(&room_id).map(|players| { - players.iter().for_each(|player_id| { - // Send the message to the player actor. - }); - }); + pub async fn subscribe(&self, subscriber: PlayerId, room_id: RoomId) { + self.0.lock().await.subscriptions.entry(room_id).or_insert_with(HashSet::new).insert(subscriber); } - - -} \ No newline at end of file +} diff --git a/crates/lavina-core/src/lib.rs b/crates/lavina-core/src/lib.rs index 6e42ede..dfc62db 100644 --- a/crates/lavina-core/src/lib.rs +++ b/crates/lavina-core/src/lib.rs @@ -3,8 +3,10 @@ use crate::clustering::{ClusterConfig, LavinaClient}; use anyhow::Result; use prometheus::Registry as MetricsRegistry; use std::sync::Arc; +use tokio::sync::Mutex; use crate::auth::Authenticator; +use crate::clustering::broadcast::Broadcasting; use crate::dialog::DialogRegistry; use crate::player::PlayerRegistry; use crate::repo::Storage; @@ -26,6 +28,7 @@ pub struct LavinaCore { pub players: PlayerRegistry, pub rooms: RoomRegistry, pub dialogs: DialogRegistry, + pub broadcasting: Broadcasting, pub authenticator: Authenticator, } @@ -36,6 +39,7 @@ impl LavinaCore { storage: Storage, ) -> Result { // TODO shutdown all services in reverse order on error + let broadcasting = Broadcasting::new(); let client = LavinaClient::new(cluster_config.addresses.clone()); let rooms = RoomRegistry::new(&mut metrics, storage.clone())?; let dialogs = DialogRegistry::new(storage.clone()); @@ -46,6 +50,7 @@ impl LavinaCore { &mut metrics, Arc::new(cluster_config.metadata), client, + broadcasting.clone(), )?; dialogs.set_players(players.clone()).await; let authenticator = Authenticator::new(storage.clone()); @@ -53,6 +58,7 @@ impl LavinaCore { players, rooms, dialogs, + broadcasting, authenticator, }) } diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index a35a263..07fdfd5 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -18,6 +18,7 @@ use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::RwLock; use tracing::{Instrument, Span}; +use crate::clustering::broadcast::Broadcasting; use crate::clustering::{ClusterMetadata, LavinaClient, SendMessageReq, SetRoomTopicReq}; use crate::dialog::DialogRegistry; use crate::prelude::*; @@ -275,6 +276,7 @@ impl PlayerRegistry { metrics: &mut MetricsRegistry, cluster_metadata: Arc, cluster_client: LavinaClient, + broadcasting: Broadcasting, ) -> Result { let metric_active_players = IntGauge::new("chat_players_active", "Number of alive player actors")?; metrics.register(Box::new(metric_active_players.clone()))?; @@ -284,6 +286,7 @@ impl PlayerRegistry { storage, cluster_metadata, cluster_client, + broadcasting, players: HashMap::new(), metric_active_players, }; @@ -337,6 +340,7 @@ impl PlayerRegistry { inner.dialogs.clone(), inner.cluster_metadata.clone(), inner.cluster_client.clone(), + inner.broadcasting.clone(), inner.storage.clone(), ) .await; @@ -373,6 +377,7 @@ struct PlayerRegistryInner { storage: Storage, cluster_metadata: Arc, cluster_client: LavinaClient, + broadcasting: Broadcasting, /// Active player actors. players: HashMap)>, metric_active_players: IntGauge, @@ -397,6 +402,7 @@ struct Player { storage: Storage, cluster_metadata: Arc, cluster_client: LavinaClient, + broadcasting: Broadcasting, } impl Player { async fn launch( @@ -405,6 +411,7 @@ impl Player { dialogs: DialogRegistry, cluster_metadata: Arc, cluster_client: LavinaClient, + broadcasting: Broadcasting, storage: Storage, ) -> (PlayerHandle, JoinHandle) { let (tx, rx) = channel(32); @@ -427,6 +434,7 @@ impl Player { storage, cluster_metadata, cluster_client, + broadcasting, }; let fiber = tokio::task::spawn(player.main_loop()); (handle_clone, fiber) @@ -449,7 +457,8 @@ impl Player { let rooms = self.storage.get_rooms_of_a_user(self.storage_id).await.unwrap(); for room_id in rooms { if let Some(remote_node) = self.room_location(&room_id) { - self.my_rooms.insert(room_id, RoomRef::Remote { node_id: remote_node }); + self.my_rooms.insert(room_id.clone(), RoomRef::Remote { node_id: remote_node }); + self.broadcasting.subscribe(self.player_id.clone(), room_id).await; } else { let room = self.rooms.get_room(&room_id).await; if let Some(room) = room { diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 583b18f..4ac98c0 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -9,6 +9,7 @@ use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::TcpStream; use lavina_core::auth::Authenticator; +use lavina_core::clustering::{ClusterConfig, ClusterMetadata}; use lavina_core::player::{JoinResult, PlayerId, SendMessageResult}; use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::room::RoomId; @@ -118,7 +119,16 @@ impl TestServer { db_path: ":memory:".into(), }) .await?; - let core = LavinaCore::new(metrics.clone(), storage.clone()).await?; + let cluster_config = ClusterConfig { + addresses: vec![], + metadata: ClusterMetadata { + node_id: 0, + main_owner: 0, + test_owner: 0, + test2_owner: 0, + }, + }; + let core = LavinaCore::new(metrics.clone(), cluster_config, storage.clone()).await?; let server = launch(config, core.clone(), metrics.clone()).await.unwrap(); Ok(TestServer { metrics, @@ -133,6 +143,15 @@ impl TestServer { listen_on: "127.0.0.1:0".parse().unwrap(), server_name: "testserver".into(), }; + let cluster_config = ClusterConfig { + addresses: vec![], + metadata: ClusterMetadata { + node_id: 0, + main_owner: 0, + test_owner: 0, + test2_owner: 0, + }, + }; let TestServer { metrics: _, storage, @@ -142,7 +161,7 @@ impl TestServer { server.terminate().await?; core.shutdown().await?; let metrics = MetricsRegistry::new(); - let core = LavinaCore::new(metrics.clone(), storage.clone()).await?; + let core = LavinaCore::new(metrics.clone(), cluster_config, storage.clone()).await?; let server = launch(config, core.clone(), metrics.clone()).await.unwrap(); Ok(TestServer { metrics, diff --git a/src/http.rs b/src/http.rs index f367e51..633ae14 100644 --- a/src/http.rs +++ b/src/http.rs @@ -13,13 +13,13 @@ use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; use lavina_core::auth::UpdatePasswordResult; -use lavina_core::clustering::SendMessageReq; +use lavina_core::clustering::{BroadcastMessageReq, SendMessageReq}; use lavina_core::player::{PlayerId, PlayerRegistry, SendMessageResult}; use lavina_core::prelude::*; use lavina_core::repo::Storage; use lavina_core::room::{RoomId, RoomRegistry}; use lavina_core::terminator::Terminator; -use lavina_core::LavinaCore; +use lavina_core::{clustering, LavinaCore}; use mgmt_api::*; @@ -91,7 +91,10 @@ async fn route( (&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, core).await.or5xx(), (&Method::POST, rooms::paths::SEND_MESSAGE) => endpoint_send_room_message(request, core).await.or5xx(), (&Method::POST, rooms::paths::SET_TOPIC) => endpoint_set_room_topic(request, core).await.or5xx(), - (&Method::POST, "/cluster/rooms/add_message") => endpoint_cluster_add_message(request, core).await.or5xx(), + (&Method::POST, clustering::paths::ADD_MESSAGE) => endpoint_cluster_add_message(request, core).await.or5xx(), + (&Method::POST, clustering::paths::BROADCAST_MESSAGE) => { + endpoint_cluster_broadcast_message(request, core).await.or5xx() + } _ => endpoint_not_found(), }; Ok(res) @@ -238,6 +241,38 @@ async fn endpoint_cluster_add_message( Ok(empty_204_request()) } +#[tracing::instrument(skip_all, name = "endpoint_cluster_broadcast_message")] +async fn endpoint_cluster_broadcast_message( + request: Request, + core: &LavinaCore, +) -> Result>> { + let str = request.collect().await?.to_bytes(); + let Ok(req) = serde_json::from_slice::(&str[..]) else { + return Ok(malformed_request()); + }; + let Ok(created_at) = chrono::DateTime::parse_from_rfc3339(req.created_at) else { + return Ok(malformed_request()); + }; + let Ok(room_id) = RoomId::from(req.room_id) else { + return Ok(room_not_found()); + }; + let Ok(author_id) = PlayerId::from(req.author_id) else { + return Ok(player_not_found()); + }; + let broadcasting = core.broadcasting.0.lock().await; + broadcasting + .broadcast( + &core.players, + room_id, + author_id, + req.message.into(), + created_at.to_utc(), + ) + .await; + drop(broadcasting); + Ok(empty_204_request()) +} + fn endpoint_not_found() -> Response> { let payload = ErrorResponse { code: errors::INVALID_PATH,