forked from lavina/lavina
1
0
Fork 0
This commit is contained in:
Nikita Vilunov 2024-05-10 15:12:33 +02:00
parent 23a59bc303
commit cb889193c7
10 changed files with 225 additions and 156 deletions

2
.gitignore vendored
View File

@ -1,4 +1,4 @@
/target
/db.sqlite
*.sqlite
.idea/
.DS_Store

View File

@ -7,6 +7,7 @@ use std::net::SocketAddr;
use std::sync::Arc;
pub mod broadcast;
pub mod room;
type Addresses = Vec<SocketAddr>;
@ -33,35 +34,6 @@ pub struct LavinaClient {
client: ClientWithMiddleware,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SendMessageReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
pub message: &'a str,
pub created_at: &'a str,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct BroadcastMessageReq<'a> {
pub room_id: &'a str,
pub author_id: &'a str,
pub message: &'a str,
pub created_at: &'a str,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SetRoomTopicReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
pub topic: &'a str,
}
pub mod paths {
pub const ADD_MESSAGE: &'static str = "/cluster/rooms/add_message";
pub const BROADCAST_MESSAGE: &'static str = "/cluster/rooms/broadcast_message";
pub const SET_TOPIC: &'static str = "/cluster/rooms/set_topic";
}
impl LavinaClient {
pub fn new(addresses: Addresses) -> Self {
let client = ClientBuilder::new(Client::new()).with(TracingMiddleware::<DefaultSpanBackend>::new()).build();
@ -71,59 +43,13 @@ impl LavinaClient {
}
}
#[tracing::instrument(skip(self, req), name = "LavinaClient::send_room_message")]
pub async fn send_room_message(&self, node_id: u32, req: SendMessageReq<'_>) -> Result<()> {
tracing::info!("Sending a message to a room on a remote node");
async fn send_request(&self, node_id: u32, path: &str, req: impl Serialize) -> Result<()> {
let Some(address) = self.addresses.get(node_id as usize) else {
tracing::error!("Failed");
return Err(anyhow!("Unknown node"));
};
match self.client.post(format!("http://{}{}", address, paths::BROADCAST_MESSAGE)).json(&req).send().await {
Ok(_) => {
tracing::info!("Message sent");
Ok(())
}
Err(e) => {
tracing::error!("Failed to send message: {e:?}");
Err(e.into())
}
}
}
#[tracing::instrument(skip(self, req), name = "LavinaClient::broadcast_room_message")]
pub async fn broadcast_room_message(&self, node_id: u32, req: BroadcastMessageReq<'_>) -> Result<()> {
tracing::info!("Broadcasting a message to a room on a remote node");
let Some(address) = self.addresses.get(node_id as usize) else {
tracing::error!("Failed");
return Err(anyhow!("Unknown node"));
};
match self.client.post(format!("http://{}{}", address, paths::BROADCAST_MESSAGE)).json(&req).send().await {
Ok(_) => {
tracing::info!("Message broadcasted");
Ok(())
}
Err(e) => {
tracing::error!("Failed to broadcast message: {e:?}");
Err(e.into())
}
}
}
pub async fn set_room_topic(&self, node_id: u32, req: SetRoomTopicReq<'_>) -> Result<()> {
tracing::info!("Setting the topic of a room on a remote node");
let Some(address) = self.addresses.get(node_id as usize) else {
tracing::error!("Failed");
return Err(anyhow!("Unknown node"));
};
match self.client.post(format!("http://{}{}", address, paths::SET_TOPIC)).json(&req).send().await {
Ok(_) => {
tracing::info!("Room topic set");
Ok(())
}
Err(e) => {
tracing::error!("Failed to set room topic: {e:?}");
Err(e.into())
}
match self.client.post(format!("http://{}{}", address, path)).json(&req).send().await {
Ok(_) => Ok(()),
Err(e) => Err(e.into()),
}
}
}

View File

@ -45,6 +45,9 @@ impl Broadcasting {
created_at: created_at.clone(),
};
for i in subscribers {
if i == &author_id {
continue;
}
let Some(player) = players.get_player(i).await else {
continue;
};

View File

@ -0,0 +1,59 @@
use serde::{Deserialize, Serialize};
use crate::clustering::LavinaClient;
pub mod paths {
pub const JOIN: &'static str = "/cluster/rooms/join";
pub const LEAVE: &'static str = "/cluster/rooms/leave";
pub const ADD_MESSAGE: &'static str = "/cluster/rooms/add_message";
pub const SET_TOPIC: &'static str = "/cluster/rooms/set_topic";
}
#[derive(Serialize, Deserialize, Debug)]
pub struct JoinRoomReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct LeaveRoomReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SendMessageReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
pub message: &'a str,
pub created_at: &'a str,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SetRoomTopicReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
pub topic: &'a str,
}
impl LavinaClient {
#[tracing::instrument(skip(self, req), name = "LavinaClient::join_room")]
pub async fn join_room(&self, node_id: u32, req: JoinRoomReq<'_>) -> anyhow::Result<()> {
self.send_request(node_id, paths::JOIN, req).await
}
#[tracing::instrument(skip(self, req), name = "LavinaClient::leave_room")]
pub async fn leave_room(&self, node_id: u32, req: LeaveRoomReq<'_>) -> anyhow::Result<()> {
self.send_request(node_id, paths::LEAVE, req).await
}
#[tracing::instrument(skip(self, req), name = "LavinaClient::send_room_message")]
pub async fn send_room_message(&self, node_id: u32, req: SendMessageReq<'_>) -> anyhow::Result<()> {
self.send_request(node_id, paths::ADD_MESSAGE, req).await
}
#[tracing::instrument(skip(self, req), name = "LavinaClient::set_room_topic")]
pub async fn set_room_topic(&self, node_id: u32, req: SetRoomTopicReq<'_>) -> anyhow::Result<()> {
self.send_request(node_id, paths::SET_TOPIC, req).await
}
}

View File

@ -19,7 +19,8 @@ use tokio::sync::RwLock;
use tracing::{Instrument, Span};
use crate::clustering::broadcast::Broadcasting;
use crate::clustering::{ClusterMetadata, LavinaClient, SendMessageReq, SetRoomTopicReq};
use crate::clustering::room::*;
use crate::clustering::{ClusterMetadata, LavinaClient};
use crate::dialog::DialogRegistry;
use crate::prelude::*;
use crate::repo::Storage;
@ -336,6 +337,7 @@ impl PlayerRegistry {
} else {
let (handle, fiber) = Player::launch(
id.clone(),
self.clone(),
inner.room_registry.clone(),
inner.dialogs.clone(),
inner.cluster_metadata.clone(),
@ -397,6 +399,7 @@ struct Player {
banned_from: HashSet<RoomId>,
rx: Receiver<(ActorCommand, Span)>,
handle: PlayerHandle,
players: PlayerRegistry,
rooms: RoomRegistry,
dialogs: DialogRegistry,
storage: Storage,
@ -407,6 +410,7 @@ struct Player {
impl Player {
async fn launch(
player_id: PlayerId,
players: PlayerRegistry,
rooms: RoomRegistry,
dialogs: DialogRegistry,
cluster_metadata: Arc<ClusterMetadata>,
@ -429,6 +433,7 @@ impl Player {
banned_from: HashSet::new(),
rx,
handle,
players,
rooms,
dialogs,
storage,
@ -582,7 +587,19 @@ impl Player {
}
if let Some(remote_node) = self.room_location(&room_id) {
todo!()
let req = JoinRoomReq {
room_id: room_id.as_inner(),
player_id: self.player_id.as_inner(),
};
self.cluster_client.join_room(remote_node, req).await.unwrap();
let room_storage_id = self.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await.unwrap();
self.storage.add_room_member(room_storage_id, self.storage_id).await.unwrap();
self.my_rooms.insert(room_id.clone(), RoomRef::Remote { node_id: remote_node });
JoinResult::Success(RoomInfo {
id: room_id,
topic: "unknown".into(),
members: vec![],
})
} else {
let room = match self.rooms.get_or_create_room(room_id.clone()).await {
Ok(room) => room,
@ -608,9 +625,22 @@ impl Player {
async fn leave_room(&mut self, connection_id: ConnectionId, room_id: RoomId) {
let room = self.my_rooms.remove(&room_id);
if let Some(room) = room {
panic!();
// room.unsubscribe(&self.player_id).await;
// room.remove_member(&self.player_id, self.storage_id).await;
match room {
RoomRef::Local(room) => {
room.unsubscribe(&self.player_id).await;
room.remove_member(&self.player_id, self.storage_id).await;
let room_storage_id =
self.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await.unwrap();
self.storage.remove_room_member(room_storage_id, self.storage_id).await.unwrap();
}
RoomRef::Remote { node_id } => {
let req = LeaveRoomReq {
room_id: room_id.as_inner(),
player_id: self.player_id.as_inner(),
};
self.cluster_client.leave_room(node_id, req).await.unwrap();
}
}
}
let update = Updates::RoomLeft {
room_id,
@ -643,6 +673,15 @@ impl Player {
created_at: &*created_at.to_rfc3339(),
};
self.cluster_client.send_room_message(*node_id, req).await.unwrap();
self.broadcasting
.broadcast(
&self.players,
room_id.clone(),
self.player_id.clone(),
body.clone(),
created_at.clone(),
)
.await;
}
}
let update = Updates::NewMessage {

View File

@ -48,4 +48,19 @@ impl Storage {
Ok(())
}
pub async fn create_or_retrieve_room_id_by_name(&self, name: &str) -> Result<u32> {
let mut executor = self.conn.lock().await;
let res: (u32,) = sqlx::query_as(
"insert into rooms(name, topic)
values (?, '')
on conflict(name) do nothing
returning id;",
)
.bind(name)
.fetch_one(&mut *executor)
.await?;
Ok(res.0)
}
}

View File

@ -14,6 +14,21 @@ impl Storage {
Ok(res.map(|(id,)| id))
}
pub async fn create_or_retrieve_user_id_by_name(&self, name: &str) -> Result<u32> {
let mut executor = self.conn.lock().await;
let res: (u32,) = sqlx::query_as(
"insert into users(name)
values (?)
on conflict(name) do update set name = excluded.name
returning id;",
)
.bind(name)
.fetch_one(&mut *executor)
.await?;
Ok(res.0)
}
pub async fn get_rooms_of_a_user(&self, user_id: u32) -> Result<Vec<RoomId>> {
let mut executor = self.conn.lock().await;
let res: Vec<(String,)> = sqlx::query_as(

View File

@ -60,7 +60,7 @@ impl RoomRegistry {
}
#[tracing::instrument(skip(self), name = "RoomRegistry::get_or_create_room")]
pub async fn get_or_create_room(&mut self, room_id: RoomId) -> Result<RoomHandle> {
pub async fn get_or_create_room(&self, room_id: RoomId) -> Result<RoomHandle> {
let mut inner = self.0.write().await;
if let Some(room_handle) = inner.get_or_load_room(&room_id).await? {
Ok(room_handle.clone())

View File

@ -13,16 +13,16 @@ use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use lavina_core::auth::UpdatePasswordResult;
use lavina_core::clustering::{BroadcastMessageReq, SendMessageReq};
use lavina_core::player::{PlayerId, PlayerRegistry, SendMessageResult};
use lavina_core::prelude::*;
use lavina_core::repo::Storage;
use lavina_core::room::{RoomId, RoomRegistry};
use lavina_core::terminator::Terminator;
use lavina_core::{clustering, LavinaCore};
use lavina_core::LavinaCore;
use mgmt_api::*;
mod clustering;
type HttpResult<T> = std::result::Result<T, Infallible>;
#[derive(Deserialize, Debug)]
@ -91,11 +91,7 @@ async fn route(
(&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, core).await.or5xx(),
(&Method::POST, rooms::paths::SEND_MESSAGE) => endpoint_send_room_message(request, core).await.or5xx(),
(&Method::POST, rooms::paths::SET_TOPIC) => endpoint_set_room_topic(request, core).await.or5xx(),
(&Method::POST, clustering::paths::ADD_MESSAGE) => endpoint_cluster_add_message(request, core).await.or5xx(),
(&Method::POST, clustering::paths::BROADCAST_MESSAGE) => {
endpoint_cluster_broadcast_message(request, core).await.or5xx()
}
_ => endpoint_not_found(),
_ => clustering::route(core, storage, request).await.unwrap_or_else(endpoint_not_found),
};
Ok(res)
}
@ -211,68 +207,6 @@ async fn endpoint_set_room_topic(
Ok(empty_204_request())
}
#[tracing::instrument(skip_all, name = "endpoint_cluster_add_message")]
async fn endpoint_cluster_add_message(
request: Request<hyper::body::Incoming>,
core: &LavinaCore,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<SendMessageReq>(&str[..]) else {
return Ok(malformed_request());
};
tracing::info!("Incoming request: {:?}", &req);
let Ok(created_at) = chrono::DateTime::parse_from_rfc3339(req.created_at) else {
dbg!(&req.created_at);
return Ok(malformed_request());
};
let Ok(room_id) = RoomId::from(req.room_id) else {
dbg!(&req.room_id);
return Ok(room_not_found());
};
let Ok(player_id) = PlayerId::from(req.player_id) else {
dbg!(&req.player_id);
return Ok(player_not_found());
};
let Some(room_handle) = core.rooms.get_room(&room_id).await else {
dbg!(&room_id);
return Ok(room_not_found());
};
room_handle.send_message(&player_id, req.message.into(), created_at.to_utc()).await;
Ok(empty_204_request())
}
#[tracing::instrument(skip_all, name = "endpoint_cluster_broadcast_message")]
async fn endpoint_cluster_broadcast_message(
request: Request<hyper::body::Incoming>,
core: &LavinaCore,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<BroadcastMessageReq>(&str[..]) else {
return Ok(malformed_request());
};
let Ok(created_at) = chrono::DateTime::parse_from_rfc3339(req.created_at) else {
return Ok(malformed_request());
};
let Ok(room_id) = RoomId::from(req.room_id) else {
return Ok(room_not_found());
};
let Ok(author_id) = PlayerId::from(req.author_id) else {
return Ok(player_not_found());
};
let broadcasting = core.broadcasting.0.lock().await;
broadcasting
.broadcast(
&core.players,
room_id,
author_id,
req.message.into(),
created_at.to_utc(),
)
.await;
drop(broadcasting);
Ok(empty_204_request())
}
fn endpoint_not_found() -> Response<Full<Bytes>> {
let payload = ErrorResponse {
code: errors::INVALID_PATH,

78
src/http/clustering.rs Normal file
View File

@ -0,0 +1,78 @@
use http_body_util::{BodyExt, Full};
use hyper::body::Bytes;
use hyper::{Method, Request, Response};
use super::Or5xx;
use crate::http::{empty_204_request, malformed_request, player_not_found, room_not_found};
use lavina_core::clustering::room::{paths, JoinRoomReq, SendMessageReq};
use lavina_core::player::PlayerId;
use lavina_core::repo::Storage;
use lavina_core::room::RoomId;
use lavina_core::LavinaCore;
pub async fn route(
core: &LavinaCore,
storage: &Storage,
request: Request<hyper::body::Incoming>,
) -> Option<Response<Full<Bytes>>> {
match (request.method(), request.uri().path()) {
(&Method::POST, paths::JOIN) => Some(endpoint_cluster_join_room(request, core, storage).await.or5xx()),
(&Method::POST, paths::ADD_MESSAGE) => Some(endpoint_cluster_add_message(request, core).await.or5xx()),
_ => None,
}
}
#[tracing::instrument(skip_all, name = "endpoint_cluster_join_room")]
async fn endpoint_cluster_join_room(
request: Request<hyper::body::Incoming>,
core: &LavinaCore,
storage: &Storage,
) -> lavina_core::prelude::Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<JoinRoomReq>(&str[..]) else {
return Ok(malformed_request());
};
tracing::info!("Incoming request: {:?}", &req);
let Ok(room_id) = RoomId::from(req.room_id) else {
dbg!(&req.room_id);
return Ok(room_not_found());
};
let Ok(player_id) = PlayerId::from(req.player_id) else {
dbg!(&req.player_id);
return Ok(player_not_found());
};
let room_handle = core.rooms.get_or_create_room(room_id).await.unwrap();
let storage_id = storage.create_or_retrieve_user_id_by_name(req.player_id).await?;
room_handle.add_member(&player_id, storage_id).await;
Ok(empty_204_request())
}
#[tracing::instrument(skip_all, name = "endpoint_cluster_add_message")]
async fn endpoint_cluster_add_message(
request: Request<hyper::body::Incoming>,
core: &LavinaCore,
) -> lavina_core::prelude::Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<SendMessageReq>(&str[..]) else {
return Ok(malformed_request());
};
tracing::info!("Incoming request: {:?}", &req);
let Ok(created_at) = chrono::DateTime::parse_from_rfc3339(req.created_at) else {
dbg!(&req.created_at);
return Ok(malformed_request());
};
let Ok(room_id) = RoomId::from(req.room_id) else {
dbg!(&req.room_id);
return Ok(room_not_found());
};
let Ok(player_id) = PlayerId::from(req.player_id) else {
dbg!(&req.player_id);
return Ok(player_not_found());
};
let Some(room_handle) = core.rooms.get_room(&room_id).await else {
dbg!(&room_id);
return Ok(room_not_found());
};
room_handle.send_message(&player_id, req.message.into(), created_at.to_utc()).await;
Ok(empty_204_request())
}