diff --git a/Cargo.toml b/Cargo.toml index 0fdcccf..b7b457d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ tracing-subscriber = "0.3.16" tokio-tungstenite = "0.18.0" futures-util = "0.3.25" prometheus = { version = "0.13.3", default_features = false } +regex = "1.7.1" [dev-dependencies] regex = "1.7.1" diff --git a/src/chat.rs b/src/chat.rs deleted file mode 100644 index 926f651..0000000 --- a/src/chat.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::{ - collections::HashMap, - sync::{Arc, RwLock}, -}; - -use tokio::sync::mpsc::{channel, Receiver, Sender}; - -use crate::prelude::*; - -pub type UserId = u64; - -#[derive(Clone)] -pub struct Chats { - inner: Arc>, -} - -impl Chats { - pub fn new() -> Chats { - let subscriptions = HashMap::new(); - let chats_inner = ChatsInner { - subscriptions, - next_sub: 0, - }; - let inner = Arc::new(RwLock::new(chats_inner)); - Chats { inner } - } - - pub fn new_sub(&self) -> (UserId, Receiver) { - let mut inner = self.inner.write().unwrap(); - let sub_id = inner.next_sub; - inner.next_sub += 1; - let (rx, tx) = channel(32); - inner.subscriptions.insert(sub_id, rx); - (sub_id, tx) - } - - pub async fn broadcast(&self, msg: &str) -> Result<()> { - let subs = { - let inner = self.inner.read().unwrap(); - inner.subscriptions.clone() - }; - for (sub_id, tx) in subs.iter() { - tx.send(msg.to_string()).await?; - tracing::info!("Sent message to {}", sub_id); - } - Ok(()) - } - - pub fn remove_sub(&self, sub_id: UserId) { - let mut inner = self.inner.write().unwrap(); - inner.subscriptions.remove(&sub_id); - } -} - -struct ChatsInner { - pub next_sub: u64, - pub subscriptions: HashMap>, -} diff --git a/src/http.rs b/src/http.rs index e5be24a..f733072 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,5 +1,6 @@ -use crate::chat::Chats; +use crate::player::PlayerRegistry; use crate::prelude::*; +use crate::room::*; use std::convert::Infallible; @@ -8,7 +9,7 @@ use hyper::server::conn::http1; use hyper::{body::Bytes, service::service_fn, Request, Response}; use hyper::{Method, StatusCode}; -use prometheus::{Encoder, IntGauge, Opts, Registry, TextEncoder}; +use prometheus::{Encoder, IntGauge, Opts, Registry as MetricsRegistry, TextEncoder}; use tokio::net::TcpListener; use tokio::sync::oneshot::Sender; use tokio::task::JoinHandle; @@ -29,7 +30,7 @@ fn not_found() -> std::result::Result>, Infallible> { Ok(response) } -fn metrics(registry: Registry) -> std::result::Result>, Infallible> { +fn metrics(registry: MetricsRegistry) -> std::result::Result>, Infallible> { let mf = registry.gather(); let mut buffer = vec![]; let encoder = TextEncoder::new(); @@ -41,8 +42,8 @@ fn metrics(registry: Registry) -> std::result::Result>, Inf } async fn route( - registry: Registry, - chats: Chats, + registry: MetricsRegistry, + chats: PlayerRegistry, request: Request, ) -> std::result::Result, Infallible> { match (request.method(), request.uri().path()) { @@ -62,11 +63,13 @@ pub struct HttpServerActor { impl HttpServerActor { pub async fn launch( listener: TcpListener, - metrics: Registry, - chats: Chats, + metrics: MetricsRegistry, + rooms: RoomRegistry, + players: PlayerRegistry, ) -> Result { let (terminator, receiver) = tokio::sync::oneshot::channel::<()>(); - let fiber = tokio::task::spawn(Self::main_loop(listener, receiver, metrics, chats)); + let fiber = + tokio::task::spawn(Self::main_loop(listener, receiver, metrics, rooms, players)); Ok(HttpServerActor { terminator, fiber }) } @@ -82,8 +85,9 @@ impl HttpServerActor { async fn main_loop( listener: TcpListener, termination: impl Future, - registry: Registry, - chats: Chats, + registry: MetricsRegistry, + rooms: RoomRegistry, + players: PlayerRegistry, ) -> Result<()> { log::info!("Starting the http server"); pin!(termination); @@ -99,13 +103,13 @@ impl HttpServerActor { result = listener.accept() => { let (stream, _) = result?; let registry = registry.clone(); - let chats = chats.clone(); + let players = players.clone(); let reqs = reqs.clone(); tokio::task::spawn(async move { reqs.inc(); let registry = registry.clone(); if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(move |r| route(registry.clone(), chats.clone(), r))) + .serve_connection(stream, service_fn(move |r| route(registry.clone(), players.clone(), r))) .with_upgrades() .await { diff --git a/src/http/ws.rs b/src/http/ws.rs index 9ee668a..c5f5a04 100644 --- a/src/http/ws.rs +++ b/src/http/ws.rs @@ -7,6 +7,7 @@ use hyper::http::HeaderValue; use hyper::upgrade::Upgraded; use hyper::{body::Bytes, Request, Response}; use hyper::{StatusCode, Version}; +use regex::Regex; use std::convert::Infallible; use tokio_tungstenite::tungstenite::handshake::derive_accept_key; @@ -17,22 +18,55 @@ use tokio_tungstenite::WebSocketStream; use futures_util::sink::SinkExt; use futures_util::stream::StreamExt; -use crate::chat::Chats; +use crate::player::{PlayerRegistry, Updates}; +use crate::room::RoomId; -async fn handle_connection(mut ws_stream: WebSocketStream, chats: Chats) { +enum WsCommand { + CreateRoom, + Send { room_id: RoomId, body: String }, + Join { room_id: RoomId }, +} + +fn parse(str: &str) -> Option { + if str == "/create\n" { + return Some(WsCommand::CreateRoom); + } + let pattern_send = Regex::new(r"^/send (\d+) (.+)\n$").unwrap(); + if let Some(captures) = pattern_send.captures(str) { + if let (Some(id), Some(msg)) = (captures.get(1), captures.get(2)) { + return Some(WsCommand::Send { + room_id: RoomId(id.as_str().parse().unwrap()), + body: msg.as_str().to_owned(), + }); + } + return None; + } + let pattern_join = Regex::new(r"^/join (\d+)\n$").unwrap(); + if let Some(captures) = pattern_join.captures(str) { + if let Some(id) = captures.get(1) { + return Some(WsCommand::Join { + room_id: RoomId(id.as_str().parse().unwrap()), + }); + } + return None; + } + None +} + +async fn handle_connection(mut ws_stream: WebSocketStream, mut players: PlayerRegistry) { tracing::info!("WebSocket connection established"); - let (sub_id, mut sub) = chats.new_sub(); - tracing::info!("New conn id: {sub_id}"); - - let _ = chats.broadcast(format!("{sub_id} joined").as_str()).await; + let (player_id, mut player_handle) = players.create_player().await; + tracing::info!("New conn id: {player_id:?}"); ws_stream .send(Message::Text("Started a connection!".into())) .await .unwrap(); - tracing::info!("Started stream for {sub_id}"); + let mut sub = player_handle.subscribe().await; + + tracing::info!("Started stream for {player_id:?}"); loop { tokio::select! { biased; @@ -40,20 +74,30 @@ async fn handle_connection(mut ws_stream: WebSocketStream, chats: Chat match msg { Some(Ok(msg)) => { let txt = msg.to_text().unwrap().to_string(); - tracing::info!("Received a message: {txt}, sub_id={sub_id}"); - match chats.broadcast(format!("{sub_id}: {txt}").as_str()).await { - Ok(_) => {}, - Err(err) => { - tracing::error!("Failed to broadcast a message from sub_id={sub_id}: {err}"); + tracing::info!("Received a message: {txt}, sub_id={player_id:?}"); + let text = msg.into_text().unwrap(); + let parsed = parse(text.as_str()); + match parsed { + Some(WsCommand::CreateRoom) => { + player_handle.create_room().await + }, + Some(WsCommand::Send { room_id, body }) => { + player_handle.send_message(room_id, body).await + }, + Some(WsCommand::Join { room_id }) => { + player_handle.join_room(room_id).await + }, + None => { + ws_stream.send(Message::Text(format!("Failed to parse: {text}"))).await; }, } }, Some(Err(err)) => { - tracing::warn!("Client {sub_id} failure: {err}"); + tracing::warn!("Client {player_id:?} failure: {err}"); break; } None => { - tracing::info!("Client {sub_id} closed the socket, stopping.."); + tracing::info!("Client {player_id:?} closed the socket, stopping.."); break; }, } @@ -61,12 +105,13 @@ async fn handle_connection(mut ws_stream: WebSocketStream, chats: Chat msg = sub.recv() => { match msg { Some(msg) => { - match ws_stream.send(Message::Text(msg)).await { - Ok(_) => {}, - Err(err) => { - tracing::warn!("Failed to send msg, sub_id={sub_id}: {err}"); - break; + match msg { + Updates::RoomJoined { room_id } => { + ws_stream.send(Message::Text(format!("Joined room {room_id:?}"))).await.unwrap(); }, + Updates::NewMessage { room_id, body } => { + ws_stream.send(Message::Text(format!("{room_id:?}: {body}"))).await.unwrap(); + } } }, None => { @@ -76,14 +121,12 @@ async fn handle_connection(mut ws_stream: WebSocketStream, chats: Chat } } } - tracing::info!("Ended stream for {sub_id}"); - chats.remove_sub(sub_id); - let _ = chats.broadcast(format!("{sub_id} left").as_str()).await; + tracing::info!("Ended stream for {player_id:?}"); } pub async fn handle_request( mut req: Request, - chats: Chats, + players: PlayerRegistry, ) -> std::result::Result>, Infallible> { tracing::info!("Received a new WS request"); let upgrade = HeaderValue::from_static("Upgrade"); @@ -118,13 +161,13 @@ pub async fn handle_request( } let ver = req.version(); - let chats = chats.clone(); + let players = players.clone(); tokio::task::spawn(async move { match hyper::upgrade::on(&mut req).await { Ok(upgraded) => { handle_connection( WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await, - chats, + players, ) .await; } diff --git a/src/main.rs b/src/main.rs index 49daf29..a86fb72 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,13 @@ -mod chat; mod http; +mod player; mod prelude; +mod room; +mod table; mod tcp; -use crate::chat::Chats; +use crate::player::PlayerRegistry; use crate::prelude::*; +use crate::room::*; use prometheus::{IntCounter, Opts, Registry}; use tcp::ClientSocketActor; @@ -48,12 +51,18 @@ async fn main() -> Result<()> { let counter = IntCounter::with_opts(Opts::new("actor_count", "Number of alive actors"))?; registry.register(Box::new(counter.clone()))?; - let chats = Chats::new(); + let rooms = RoomRegistry::empty(); + let players = PlayerRegistry::empty(rooms.clone()); let listener = TcpListener::bind("127.0.0.1:3721").await?; let listener_http = TcpListener::bind("127.0.0.1:8080").await?; - let http_server_actor = - http::HttpServerActor::launch(listener_http, registry.clone(), chats.clone()).await?; + let http_server_actor = http::HttpServerActor::launch( + listener_http, + registry.clone(), + rooms.clone(), + players.clone(), + ) + .await?; tracing::info!("Started"); diff --git a/src/player.rs b/src/player.rs new file mode 100644 index 0000000..250500d --- /dev/null +++ b/src/player.rs @@ -0,0 +1,179 @@ +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; + +use tokio::{ + sync::mpsc::{channel, Receiver, Sender}, + task::JoinHandle, +}; + +use crate::{ + room::{RoomId, RoomRegistry}, + table::AnonTable, +}; + +/// Opaque player id +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct PlayerId(u64); + +#[derive(Clone)] +pub struct PlayerHandle { + tx: Sender, +} +impl PlayerHandle { + pub async fn subscribe(&mut self) -> Receiver { + let (sender, rx) = channel(32); + self.tx.send(PlayerCommand::AddSocket { sender }).await; + rx + } + + pub async fn create_room(&mut self) { + match self.tx.send(PlayerCommand::CreateRoom).await { + Ok(_) => {} + Err(_) => { + panic!("unexpected err"); + } + }; + } + + pub async fn send_message(&mut self, room_id: RoomId, body: String) { + self.tx + .send(PlayerCommand::SendMessage { room_id, body }) + .await; + } + + pub async fn join_room(&mut self, room_id: RoomId) { + self.tx.send(PlayerCommand::JoinRoom { room_id }).await; + } + + pub async fn receive_message(&mut self, room_id: RoomId, author: PlayerId, body: String) { + self.tx + .send(PlayerCommand::IncomingMessage { + room_id, + author, + body, + }) + .await; + } +} + +pub enum Updates { + RoomJoined { room_id: RoomId }, + NewMessage { room_id: RoomId, body: String }, +} +enum PlayerCommand { + AddSocket { + sender: Sender, + }, + CreateRoom, + JoinRoom { + room_id: RoomId, + }, + SendMessage { + room_id: RoomId, + body: String, + }, + IncomingMessage { + room_id: RoomId, + author: PlayerId, + body: String, + }, +} + +#[derive(Clone)] +pub struct PlayerRegistry(Arc>); +impl PlayerRegistry { + pub fn empty(room_registry: RoomRegistry) -> PlayerRegistry { + let inner = PlayerRegistryInner { + next_id: PlayerId(0), + room_registry, + players: HashMap::new(), + }; + PlayerRegistry(Arc::new(RwLock::new(inner))) + } + + pub async fn create_player(&mut self) -> (PlayerId, PlayerHandle) { + let player = Player { + sockets: AnonTable::new(), + }; + let mut inner = self.0.write().unwrap(); + let id = inner.next_id; + inner.next_id.0 += 1; + let (handle, fiber) = player.launch(id, inner.room_registry.clone()); + inner.players.insert(id, (handle.clone(), fiber)); + (id, handle) + } +} + +struct PlayerRegistryInner { + next_id: PlayerId, + room_registry: RoomRegistry, + players: HashMap)>, +} + +struct Player { + sockets: AnonTable>, +} +impl Player { + fn launch( + mut self, + player_id: PlayerId, + mut rooms: RoomRegistry, + ) -> (PlayerHandle, JoinHandle) { + let (tx, mut rx) = channel(32); + let handle = PlayerHandle { tx }; + let handle_clone = handle.clone(); + let fiber = tokio::task::spawn(async move { + while let Some(cmd) = rx.recv().await { + match cmd { + PlayerCommand::AddSocket { sender } => { + self.sockets.insert(sender); + } + PlayerCommand::CreateRoom => { + let (room_id, room_handle) = rooms.create_room(); + } + PlayerCommand::JoinRoom { room_id } => { + let room = rooms.get_room(room_id); + match room { + Some(mut room) => { + room.subscribe(player_id, handle.clone()).await; + } + None => { + tracing::info!("no room found"); + } + } + } + PlayerCommand::SendMessage { room_id, body } => { + let room = rooms.get_room(room_id); + match room { + Some(mut room) => { + room.send_message(player_id, body).await; + } + None => { + tracing::info!("no room found"); + } + } + } + PlayerCommand::IncomingMessage { + room_id, + author, + body, + } => { + tracing::info!("Handling incoming message"); + for socket in &self.sockets { + socket + .send(Updates::NewMessage { + room_id, + body: body.clone(), + }) + .await; + } + } + } + } + self + }); + (handle_clone, fiber) + } +} diff --git a/src/room.rs b/src/room.rs new file mode 100644 index 0000000..da465e7 --- /dev/null +++ b/src/room.rs @@ -0,0 +1,117 @@ +use std::{ + collections::HashMap, + hash::Hash, + sync::{Arc, RwLock}, +}; + +use tokio::sync::mpsc::{channel, Sender}; + +use crate::{ + player::{PlayerHandle, PlayerId}, + prelude::*, +}; + +/// Opaque room id +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct RoomId(pub u64); + +/// Shared datastructure for storing metadata about rooms. +#[derive(Clone)] +pub struct RoomRegistry(Arc>); +impl RoomRegistry { + pub fn empty() -> RoomRegistry { + let inner = RoomRegistryInner { + next_room_id: RoomId(0), + rooms: HashMap::new(), + }; + RoomRegistry(Arc::new(RwLock::new(inner))) + } + + pub fn create_room(&mut self) -> (RoomId, RoomHandle) { + let room = Room { + subscriptions: HashMap::new(), + }; + let mut inner = self.0.write().unwrap(); + let room_id = inner.next_room_id; + inner.next_room_id.0 += 1; + let (room_handle, fiber) = room.launch(room_id); + inner.rooms.insert(room_id, (room_handle.clone(), fiber)); + (room_id, room_handle) + } + + pub fn get_room(&self, room_id: RoomId) -> Option { + let inner = self.0.read().unwrap(); + let res = inner.rooms.get(&room_id); + res.map(|r| r.0.clone()) + } +} + +struct RoomRegistryInner { + next_room_id: RoomId, + rooms: HashMap)>, +} + +#[derive(Clone)] +pub struct RoomHandle { + tx: Sender, +} +impl RoomHandle { + pub async fn subscribe(&mut self, player_id: PlayerId, player: PlayerHandle) { + match self + .tx + .send(RoomCommand::AddSubscriber { player_id, player }) + .await + { + Ok(_) => {} + Err(_) => { + tracing::error!("Room mailbox is closed unexpectedly"); + } + }; + } + + pub async fn send_message(&mut self, player_id: PlayerId, body: String) { + self.tx + .send(RoomCommand::SendMessage { player_id, body }) + .await; + } +} + +enum RoomCommand { + AddSubscriber { + player_id: PlayerId, + player: PlayerHandle, + }, + SendMessage { + player_id: PlayerId, + body: String, + }, +} + +struct Room { + subscriptions: HashMap, +} +impl Room { + fn launch(mut self, room_id: RoomId) -> (RoomHandle, JoinHandle) { + let (tx, mut rx) = channel(32); + let fiber = tokio::task::spawn(async move { + tracing::info!("Starting room fiber"); + while let Some(a) = rx.recv().await { + match a { + RoomCommand::AddSubscriber { player_id, player } => { + tracing::info!("Adding a subscriber to room"); + self.subscriptions.insert(player_id, player); + } + RoomCommand::SendMessage { player_id, body } => { + tracing::info!("Adding a message to room"); + for (_, sub) in &mut self.subscriptions { + sub.receive_message(room_id, player_id, body.clone()).await; + } + } + } + } + tracing::info!("Stopping room fiber"); + self + }); + (RoomHandle { tx }, fiber) + } +} diff --git a/src/table.rs b/src/table.rs new file mode 100644 index 0000000..9e1dbe9 --- /dev/null +++ b/src/table.rs @@ -0,0 +1,77 @@ +use std::collections::HashMap; + +pub struct Key(u32); + +pub struct AnonTable { + next: u32, + inner: HashMap, +} + +impl AnonTable { + pub fn new() -> AnonTable { + AnonTable { + next: 0, + inner: HashMap::new(), + } + } + + pub fn insert(&mut self, value: V) -> Option { + let id = self.next; + self.next += 1; + self.inner.insert(id, value) + } + + pub fn get_mut(&mut self, key: Key) -> Option<&mut V> { + self.inner.get_mut(&key.0) + } + + pub fn get(&self, key: Key) -> Option<&V> { + self.inner.get(&key.0) + } + + pub fn pop(&mut self, key: Key) -> Option { + self.inner.remove(&key.0) + } + + pub fn len(&self) -> usize { + self.inner.len() + } +} + +pub struct AnonTableIterator<'a, V>(<&'a HashMap as IntoIterator>::IntoIter); +impl<'a, V> Iterator for AnonTableIterator<'a, V> { + type Item = &'a V; + + fn next(&mut self) -> Option<&'a V> { + self.0.next().map(|a| a.1) + } +} + +impl<'a, V> IntoIterator for &'a AnonTable { + type Item = &'a V; + + type IntoIter = AnonTableIterator<'a, V>; + + fn into_iter(self) -> Self::IntoIter { + AnonTableIterator(IntoIterator::into_iter(&self.inner)) + } +} + +pub struct AnonTableMutIterator<'a, V>(<&'a mut HashMap as IntoIterator>::IntoIter); +impl<'a, V> Iterator for AnonTableMutIterator<'a, V> { + type Item = &'a mut V; + + fn next(&mut self) -> Option<&'a mut V> { + self.0.next().map(|a| a.1) + } +} + +impl<'a, V> IntoIterator for &'a mut AnonTable { + type Item = &'a mut V; + + type IntoIter = AnonTableMutIterator<'a, V>; + + fn into_iter(self) -> Self::IntoIter { + AnonTableMutIterator(IntoIterator::into_iter(&mut self.inner)) + } +} diff --git a/tests/mod.rs b/tests/mod.rs index 78fb04e..45b5d08 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -16,7 +16,7 @@ async fn hello_endpoint() -> Test { #[tokio::test] async fn websocket_connect() -> Test { let connected = Regex::new(r"^(\d+) joined$").unwrap(); - let msg = Regex::new(r"^(\d+): (.*)").unwrap(); + let msg = Regex::new(r"^(\d+): (.*)$").unwrap(); let (mut socket, response) = connect_async("ws://localhost:8080/socket").await?;