forked from lavina/lavina
1
0
Fork 0

implement multiple rooms

This commit is contained in:
Nikita Vilunov 2023-02-03 23:43:59 +01:00
parent 03b0ababa7
commit b7995584f0
9 changed files with 473 additions and 101 deletions

View File

@ -20,6 +20,7 @@ tracing-subscriber = "0.3.16"
tokio-tungstenite = "0.18.0" tokio-tungstenite = "0.18.0"
futures-util = "0.3.25" futures-util = "0.3.25"
prometheus = { version = "0.13.3", default_features = false } prometheus = { version = "0.13.3", default_features = false }
regex = "1.7.1"
[dev-dependencies] [dev-dependencies]
regex = "1.7.1" regex = "1.7.1"

View File

@ -1,58 +0,0 @@
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use crate::prelude::*;
pub type UserId = u64;
#[derive(Clone)]
pub struct Chats {
inner: Arc<RwLock<ChatsInner>>,
}
impl Chats {
pub fn new() -> Chats {
let subscriptions = HashMap::new();
let chats_inner = ChatsInner {
subscriptions,
next_sub: 0,
};
let inner = Arc::new(RwLock::new(chats_inner));
Chats { inner }
}
pub fn new_sub(&self) -> (UserId, Receiver<String>) {
let mut inner = self.inner.write().unwrap();
let sub_id = inner.next_sub;
inner.next_sub += 1;
let (rx, tx) = channel(32);
inner.subscriptions.insert(sub_id, rx);
(sub_id, tx)
}
pub async fn broadcast(&self, msg: &str) -> Result<()> {
let subs = {
let inner = self.inner.read().unwrap();
inner.subscriptions.clone()
};
for (sub_id, tx) in subs.iter() {
tx.send(msg.to_string()).await?;
tracing::info!("Sent message to {}", sub_id);
}
Ok(())
}
pub fn remove_sub(&self, sub_id: UserId) {
let mut inner = self.inner.write().unwrap();
inner.subscriptions.remove(&sub_id);
}
}
struct ChatsInner {
pub next_sub: u64,
pub subscriptions: HashMap<UserId, Sender<String>>,
}

View File

@ -1,5 +1,6 @@
use crate::chat::Chats; use crate::player::PlayerRegistry;
use crate::prelude::*; use crate::prelude::*;
use crate::room::*;
use std::convert::Infallible; use std::convert::Infallible;
@ -8,7 +9,7 @@ use hyper::server::conn::http1;
use hyper::{body::Bytes, service::service_fn, Request, Response}; use hyper::{body::Bytes, service::service_fn, Request, Response};
use hyper::{Method, StatusCode}; use hyper::{Method, StatusCode};
use prometheus::{Encoder, IntGauge, Opts, Registry, TextEncoder}; use prometheus::{Encoder, IntGauge, Opts, Registry as MetricsRegistry, TextEncoder};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::sync::oneshot::Sender; use tokio::sync::oneshot::Sender;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
@ -29,7 +30,7 @@ fn not_found() -> std::result::Result<Response<Full<Bytes>>, Infallible> {
Ok(response) Ok(response)
} }
fn metrics(registry: Registry) -> std::result::Result<Response<Full<Bytes>>, Infallible> { fn metrics(registry: MetricsRegistry) -> std::result::Result<Response<Full<Bytes>>, Infallible> {
let mf = registry.gather(); let mf = registry.gather();
let mut buffer = vec![]; let mut buffer = vec![];
let encoder = TextEncoder::new(); let encoder = TextEncoder::new();
@ -41,8 +42,8 @@ fn metrics(registry: Registry) -> std::result::Result<Response<Full<Bytes>>, Inf
} }
async fn route( async fn route(
registry: Registry, registry: MetricsRegistry,
chats: Chats, chats: PlayerRegistry,
request: Request<hyper::body::Incoming>, request: Request<hyper::body::Incoming>,
) -> std::result::Result<Response<BoxBody>, Infallible> { ) -> std::result::Result<Response<BoxBody>, Infallible> {
match (request.method(), request.uri().path()) { match (request.method(), request.uri().path()) {
@ -62,11 +63,13 @@ pub struct HttpServerActor {
impl HttpServerActor { impl HttpServerActor {
pub async fn launch( pub async fn launch(
listener: TcpListener, listener: TcpListener,
metrics: Registry, metrics: MetricsRegistry,
chats: Chats, rooms: RoomRegistry,
players: PlayerRegistry,
) -> Result<HttpServerActor> { ) -> Result<HttpServerActor> {
let (terminator, receiver) = tokio::sync::oneshot::channel::<()>(); let (terminator, receiver) = tokio::sync::oneshot::channel::<()>();
let fiber = tokio::task::spawn(Self::main_loop(listener, receiver, metrics, chats)); let fiber =
tokio::task::spawn(Self::main_loop(listener, receiver, metrics, rooms, players));
Ok(HttpServerActor { terminator, fiber }) Ok(HttpServerActor { terminator, fiber })
} }
@ -82,8 +85,9 @@ impl HttpServerActor {
async fn main_loop( async fn main_loop(
listener: TcpListener, listener: TcpListener,
termination: impl Future, termination: impl Future,
registry: Registry, registry: MetricsRegistry,
chats: Chats, rooms: RoomRegistry,
players: PlayerRegistry,
) -> Result<()> { ) -> Result<()> {
log::info!("Starting the http server"); log::info!("Starting the http server");
pin!(termination); pin!(termination);
@ -99,13 +103,13 @@ impl HttpServerActor {
result = listener.accept() => { result = listener.accept() => {
let (stream, _) = result?; let (stream, _) = result?;
let registry = registry.clone(); let registry = registry.clone();
let chats = chats.clone(); let players = players.clone();
let reqs = reqs.clone(); let reqs = reqs.clone();
tokio::task::spawn(async move { tokio::task::spawn(async move {
reqs.inc(); reqs.inc();
let registry = registry.clone(); let registry = registry.clone();
if let Err(err) = http1::Builder::new() if let Err(err) = http1::Builder::new()
.serve_connection(stream, service_fn(move |r| route(registry.clone(), chats.clone(), r))) .serve_connection(stream, service_fn(move |r| route(registry.clone(), players.clone(), r)))
.with_upgrades() .with_upgrades()
.await .await
{ {

View File

@ -7,6 +7,7 @@ use hyper::http::HeaderValue;
use hyper::upgrade::Upgraded; use hyper::upgrade::Upgraded;
use hyper::{body::Bytes, Request, Response}; use hyper::{body::Bytes, Request, Response};
use hyper::{StatusCode, Version}; use hyper::{StatusCode, Version};
use regex::Regex;
use std::convert::Infallible; use std::convert::Infallible;
use tokio_tungstenite::tungstenite::handshake::derive_accept_key; use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
@ -17,22 +18,55 @@ use tokio_tungstenite::WebSocketStream;
use futures_util::sink::SinkExt; use futures_util::sink::SinkExt;
use futures_util::stream::StreamExt; use futures_util::stream::StreamExt;
use crate::chat::Chats; use crate::player::{PlayerRegistry, Updates};
use crate::room::RoomId;
async fn handle_connection(mut ws_stream: WebSocketStream<Upgraded>, chats: Chats) { enum WsCommand {
CreateRoom,
Send { room_id: RoomId, body: String },
Join { room_id: RoomId },
}
fn parse(str: &str) -> Option<WsCommand> {
if str == "/create\n" {
return Some(WsCommand::CreateRoom);
}
let pattern_send = Regex::new(r"^/send (\d+) (.+)\n$").unwrap();
if let Some(captures) = pattern_send.captures(str) {
if let (Some(id), Some(msg)) = (captures.get(1), captures.get(2)) {
return Some(WsCommand::Send {
room_id: RoomId(id.as_str().parse().unwrap()),
body: msg.as_str().to_owned(),
});
}
return None;
}
let pattern_join = Regex::new(r"^/join (\d+)\n$").unwrap();
if let Some(captures) = pattern_join.captures(str) {
if let Some(id) = captures.get(1) {
return Some(WsCommand::Join {
room_id: RoomId(id.as_str().parse().unwrap()),
});
}
return None;
}
None
}
async fn handle_connection(mut ws_stream: WebSocketStream<Upgraded>, mut players: PlayerRegistry) {
tracing::info!("WebSocket connection established"); tracing::info!("WebSocket connection established");
let (sub_id, mut sub) = chats.new_sub(); let (player_id, mut player_handle) = players.create_player().await;
tracing::info!("New conn id: {sub_id}"); tracing::info!("New conn id: {player_id:?}");
let _ = chats.broadcast(format!("{sub_id} joined").as_str()).await;
ws_stream ws_stream
.send(Message::Text("Started a connection!".into())) .send(Message::Text("Started a connection!".into()))
.await .await
.unwrap(); .unwrap();
tracing::info!("Started stream for {sub_id}"); let mut sub = player_handle.subscribe().await;
tracing::info!("Started stream for {player_id:?}");
loop { loop {
tokio::select! { tokio::select! {
biased; biased;
@ -40,20 +74,30 @@ async fn handle_connection(mut ws_stream: WebSocketStream<Upgraded>, chats: Chat
match msg { match msg {
Some(Ok(msg)) => { Some(Ok(msg)) => {
let txt = msg.to_text().unwrap().to_string(); let txt = msg.to_text().unwrap().to_string();
tracing::info!("Received a message: {txt}, sub_id={sub_id}"); tracing::info!("Received a message: {txt}, sub_id={player_id:?}");
match chats.broadcast(format!("{sub_id}: {txt}").as_str()).await { let text = msg.into_text().unwrap();
Ok(_) => {}, let parsed = parse(text.as_str());
Err(err) => { match parsed {
tracing::error!("Failed to broadcast a message from sub_id={sub_id}: {err}"); Some(WsCommand::CreateRoom) => {
player_handle.create_room().await
},
Some(WsCommand::Send { room_id, body }) => {
player_handle.send_message(room_id, body).await
},
Some(WsCommand::Join { room_id }) => {
player_handle.join_room(room_id).await
},
None => {
ws_stream.send(Message::Text(format!("Failed to parse: {text}"))).await;
}, },
} }
}, },
Some(Err(err)) => { Some(Err(err)) => {
tracing::warn!("Client {sub_id} failure: {err}"); tracing::warn!("Client {player_id:?} failure: {err}");
break; break;
} }
None => { None => {
tracing::info!("Client {sub_id} closed the socket, stopping.."); tracing::info!("Client {player_id:?} closed the socket, stopping..");
break; break;
}, },
} }
@ -61,12 +105,13 @@ async fn handle_connection(mut ws_stream: WebSocketStream<Upgraded>, chats: Chat
msg = sub.recv() => { msg = sub.recv() => {
match msg { match msg {
Some(msg) => { Some(msg) => {
match ws_stream.send(Message::Text(msg)).await { match msg {
Ok(_) => {}, Updates::RoomJoined { room_id } => {
Err(err) => { ws_stream.send(Message::Text(format!("Joined room {room_id:?}"))).await.unwrap();
tracing::warn!("Failed to send msg, sub_id={sub_id}: {err}");
break;
}, },
Updates::NewMessage { room_id, body } => {
ws_stream.send(Message::Text(format!("{room_id:?}: {body}"))).await.unwrap();
}
} }
}, },
None => { None => {
@ -76,14 +121,12 @@ async fn handle_connection(mut ws_stream: WebSocketStream<Upgraded>, chats: Chat
} }
} }
} }
tracing::info!("Ended stream for {sub_id}"); tracing::info!("Ended stream for {player_id:?}");
chats.remove_sub(sub_id);
let _ = chats.broadcast(format!("{sub_id} left").as_str()).await;
} }
pub async fn handle_request( pub async fn handle_request(
mut req: Request<Incoming>, mut req: Request<Incoming>,
chats: Chats, players: PlayerRegistry,
) -> std::result::Result<Response<Empty<Bytes>>, Infallible> { ) -> std::result::Result<Response<Empty<Bytes>>, Infallible> {
tracing::info!("Received a new WS request"); tracing::info!("Received a new WS request");
let upgrade = HeaderValue::from_static("Upgrade"); let upgrade = HeaderValue::from_static("Upgrade");
@ -118,13 +161,13 @@ pub async fn handle_request(
} }
let ver = req.version(); let ver = req.version();
let chats = chats.clone(); let players = players.clone();
tokio::task::spawn(async move { tokio::task::spawn(async move {
match hyper::upgrade::on(&mut req).await { match hyper::upgrade::on(&mut req).await {
Ok(upgraded) => { Ok(upgraded) => {
handle_connection( handle_connection(
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await, WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
chats, players,
) )
.await; .await;
} }

View File

@ -1,10 +1,13 @@
mod chat;
mod http; mod http;
mod player;
mod prelude; mod prelude;
mod room;
mod table;
mod tcp; mod tcp;
use crate::chat::Chats; use crate::player::PlayerRegistry;
use crate::prelude::*; use crate::prelude::*;
use crate::room::*;
use prometheus::{IntCounter, Opts, Registry}; use prometheus::{IntCounter, Opts, Registry};
use tcp::ClientSocketActor; use tcp::ClientSocketActor;
@ -48,12 +51,18 @@ async fn main() -> Result<()> {
let counter = IntCounter::with_opts(Opts::new("actor_count", "Number of alive actors"))?; let counter = IntCounter::with_opts(Opts::new("actor_count", "Number of alive actors"))?;
registry.register(Box::new(counter.clone()))?; registry.register(Box::new(counter.clone()))?;
let chats = Chats::new(); let rooms = RoomRegistry::empty();
let players = PlayerRegistry::empty(rooms.clone());
let listener = TcpListener::bind("127.0.0.1:3721").await?; let listener = TcpListener::bind("127.0.0.1:3721").await?;
let listener_http = TcpListener::bind("127.0.0.1:8080").await?; let listener_http = TcpListener::bind("127.0.0.1:8080").await?;
let http_server_actor = let http_server_actor = http::HttpServerActor::launch(
http::HttpServerActor::launch(listener_http, registry.clone(), chats.clone()).await?; listener_http,
registry.clone(),
rooms.clone(),
players.clone(),
)
.await?;
tracing::info!("Started"); tracing::info!("Started");

179
src/player.rs Normal file
View File

@ -0,0 +1,179 @@
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use tokio::{
sync::mpsc::{channel, Receiver, Sender},
task::JoinHandle,
};
use crate::{
room::{RoomId, RoomRegistry},
table::AnonTable,
};
/// Opaque player id
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PlayerId(u64);
#[derive(Clone)]
pub struct PlayerHandle {
tx: Sender<PlayerCommand>,
}
impl PlayerHandle {
pub async fn subscribe(&mut self) -> Receiver<Updates> {
let (sender, rx) = channel(32);
self.tx.send(PlayerCommand::AddSocket { sender }).await;
rx
}
pub async fn create_room(&mut self) {
match self.tx.send(PlayerCommand::CreateRoom).await {
Ok(_) => {}
Err(_) => {
panic!("unexpected err");
}
};
}
pub async fn send_message(&mut self, room_id: RoomId, body: String) {
self.tx
.send(PlayerCommand::SendMessage { room_id, body })
.await;
}
pub async fn join_room(&mut self, room_id: RoomId) {
self.tx.send(PlayerCommand::JoinRoom { room_id }).await;
}
pub async fn receive_message(&mut self, room_id: RoomId, author: PlayerId, body: String) {
self.tx
.send(PlayerCommand::IncomingMessage {
room_id,
author,
body,
})
.await;
}
}
pub enum Updates {
RoomJoined { room_id: RoomId },
NewMessage { room_id: RoomId, body: String },
}
enum PlayerCommand {
AddSocket {
sender: Sender<Updates>,
},
CreateRoom,
JoinRoom {
room_id: RoomId,
},
SendMessage {
room_id: RoomId,
body: String,
},
IncomingMessage {
room_id: RoomId,
author: PlayerId,
body: String,
},
}
#[derive(Clone)]
pub struct PlayerRegistry(Arc<RwLock<PlayerRegistryInner>>);
impl PlayerRegistry {
pub fn empty(room_registry: RoomRegistry) -> PlayerRegistry {
let inner = PlayerRegistryInner {
next_id: PlayerId(0),
room_registry,
players: HashMap::new(),
};
PlayerRegistry(Arc::new(RwLock::new(inner)))
}
pub async fn create_player(&mut self) -> (PlayerId, PlayerHandle) {
let player = Player {
sockets: AnonTable::new(),
};
let mut inner = self.0.write().unwrap();
let id = inner.next_id;
inner.next_id.0 += 1;
let (handle, fiber) = player.launch(id, inner.room_registry.clone());
inner.players.insert(id, (handle.clone(), fiber));
(id, handle)
}
}
struct PlayerRegistryInner {
next_id: PlayerId,
room_registry: RoomRegistry,
players: HashMap<PlayerId, (PlayerHandle, JoinHandle<Player>)>,
}
struct Player {
sockets: AnonTable<Sender<Updates>>,
}
impl Player {
fn launch(
mut self,
player_id: PlayerId,
mut rooms: RoomRegistry,
) -> (PlayerHandle, JoinHandle<Player>) {
let (tx, mut rx) = channel(32);
let handle = PlayerHandle { tx };
let handle_clone = handle.clone();
let fiber = tokio::task::spawn(async move {
while let Some(cmd) = rx.recv().await {
match cmd {
PlayerCommand::AddSocket { sender } => {
self.sockets.insert(sender);
}
PlayerCommand::CreateRoom => {
let (room_id, room_handle) = rooms.create_room();
}
PlayerCommand::JoinRoom { room_id } => {
let room = rooms.get_room(room_id);
match room {
Some(mut room) => {
room.subscribe(player_id, handle.clone()).await;
}
None => {
tracing::info!("no room found");
}
}
}
PlayerCommand::SendMessage { room_id, body } => {
let room = rooms.get_room(room_id);
match room {
Some(mut room) => {
room.send_message(player_id, body).await;
}
None => {
tracing::info!("no room found");
}
}
}
PlayerCommand::IncomingMessage {
room_id,
author,
body,
} => {
tracing::info!("Handling incoming message");
for socket in &self.sockets {
socket
.send(Updates::NewMessage {
room_id,
body: body.clone(),
})
.await;
}
}
}
}
self
});
(handle_clone, fiber)
}
}

117
src/room.rs Normal file
View File

@ -0,0 +1,117 @@
use std::{
collections::HashMap,
hash::Hash,
sync::{Arc, RwLock},
};
use tokio::sync::mpsc::{channel, Sender};
use crate::{
player::{PlayerHandle, PlayerId},
prelude::*,
};
/// Opaque room id
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RoomId(pub u64);
/// Shared datastructure for storing metadata about rooms.
#[derive(Clone)]
pub struct RoomRegistry(Arc<RwLock<RoomRegistryInner>>);
impl RoomRegistry {
pub fn empty() -> RoomRegistry {
let inner = RoomRegistryInner {
next_room_id: RoomId(0),
rooms: HashMap::new(),
};
RoomRegistry(Arc::new(RwLock::new(inner)))
}
pub fn create_room(&mut self) -> (RoomId, RoomHandle) {
let room = Room {
subscriptions: HashMap::new(),
};
let mut inner = self.0.write().unwrap();
let room_id = inner.next_room_id;
inner.next_room_id.0 += 1;
let (room_handle, fiber) = room.launch(room_id);
inner.rooms.insert(room_id, (room_handle.clone(), fiber));
(room_id, room_handle)
}
pub fn get_room(&self, room_id: RoomId) -> Option<RoomHandle> {
let inner = self.0.read().unwrap();
let res = inner.rooms.get(&room_id);
res.map(|r| r.0.clone())
}
}
struct RoomRegistryInner {
next_room_id: RoomId,
rooms: HashMap<RoomId, (RoomHandle, JoinHandle<Room>)>,
}
#[derive(Clone)]
pub struct RoomHandle {
tx: Sender<RoomCommand>,
}
impl RoomHandle {
pub async fn subscribe(&mut self, player_id: PlayerId, player: PlayerHandle) {
match self
.tx
.send(RoomCommand::AddSubscriber { player_id, player })
.await
{
Ok(_) => {}
Err(_) => {
tracing::error!("Room mailbox is closed unexpectedly");
}
};
}
pub async fn send_message(&mut self, player_id: PlayerId, body: String) {
self.tx
.send(RoomCommand::SendMessage { player_id, body })
.await;
}
}
enum RoomCommand {
AddSubscriber {
player_id: PlayerId,
player: PlayerHandle,
},
SendMessage {
player_id: PlayerId,
body: String,
},
}
struct Room {
subscriptions: HashMap<PlayerId, PlayerHandle>,
}
impl Room {
fn launch(mut self, room_id: RoomId) -> (RoomHandle, JoinHandle<Room>) {
let (tx, mut rx) = channel(32);
let fiber = tokio::task::spawn(async move {
tracing::info!("Starting room fiber");
while let Some(a) = rx.recv().await {
match a {
RoomCommand::AddSubscriber { player_id, player } => {
tracing::info!("Adding a subscriber to room");
self.subscriptions.insert(player_id, player);
}
RoomCommand::SendMessage { player_id, body } => {
tracing::info!("Adding a message to room");
for (_, sub) in &mut self.subscriptions {
sub.receive_message(room_id, player_id, body.clone()).await;
}
}
}
}
tracing::info!("Stopping room fiber");
self
});
(RoomHandle { tx }, fiber)
}
}

77
src/table.rs Normal file
View File

@ -0,0 +1,77 @@
use std::collections::HashMap;
pub struct Key(u32);
pub struct AnonTable<V> {
next: u32,
inner: HashMap<u32, V>,
}
impl<V> AnonTable<V> {
pub fn new() -> AnonTable<V> {
AnonTable {
next: 0,
inner: HashMap::new(),
}
}
pub fn insert(&mut self, value: V) -> Option<V> {
let id = self.next;
self.next += 1;
self.inner.insert(id, value)
}
pub fn get_mut(&mut self, key: Key) -> Option<&mut V> {
self.inner.get_mut(&key.0)
}
pub fn get(&self, key: Key) -> Option<&V> {
self.inner.get(&key.0)
}
pub fn pop(&mut self, key: Key) -> Option<V> {
self.inner.remove(&key.0)
}
pub fn len(&self) -> usize {
self.inner.len()
}
}
pub struct AnonTableIterator<'a, V>(<&'a HashMap<u32, V> as IntoIterator>::IntoIter);
impl<'a, V> Iterator for AnonTableIterator<'a, V> {
type Item = &'a V;
fn next(&mut self) -> Option<&'a V> {
self.0.next().map(|a| a.1)
}
}
impl<'a, V> IntoIterator for &'a AnonTable<V> {
type Item = &'a V;
type IntoIter = AnonTableIterator<'a, V>;
fn into_iter(self) -> Self::IntoIter {
AnonTableIterator(IntoIterator::into_iter(&self.inner))
}
}
pub struct AnonTableMutIterator<'a, V>(<&'a mut HashMap<u32, V> as IntoIterator>::IntoIter);
impl<'a, V> Iterator for AnonTableMutIterator<'a, V> {
type Item = &'a mut V;
fn next(&mut self) -> Option<&'a mut V> {
self.0.next().map(|a| a.1)
}
}
impl<'a, V> IntoIterator for &'a mut AnonTable<V> {
type Item = &'a mut V;
type IntoIter = AnonTableMutIterator<'a, V>;
fn into_iter(self) -> Self::IntoIter {
AnonTableMutIterator(IntoIterator::into_iter(&mut self.inner))
}
}

View File

@ -16,7 +16,7 @@ async fn hello_endpoint() -> Test {
#[tokio::test] #[tokio::test]
async fn websocket_connect() -> Test { async fn websocket_connect() -> Test {
let connected = Regex::new(r"^(\d+) joined$").unwrap(); let connected = Regex::new(r"^(\d+) joined$").unwrap();
let msg = Regex::new(r"^(\d+): (.*)").unwrap(); let msg = Regex::new(r"^(\d+): (.*)$").unwrap();
let (mut socket, response) = connect_async("ws://localhost:8080/socket").await?; let (mut socket, response) = connect_async("ws://localhost:8080/socket").await?;