diff --git a/src/chat.rs b/src/chat.rs new file mode 100644 index 0000000..09570a7 --- /dev/null +++ b/src/chat.rs @@ -0,0 +1,55 @@ +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; + +use tokio::sync::mpsc::{Sender, channel, Receiver}; + +use crate::prelude::*; + +pub type UserId = u64; + +#[derive(Clone)] +pub struct Chats { + inner: Arc>, +} + +impl Chats { + pub fn new() -> Chats { + let subscriptions = HashMap::new(); + let chats_inner = ChatsInner { subscriptions, next_sub: 0 }; + let inner = Arc::new(RwLock::new(chats_inner)); + Chats { inner } + } + + pub fn new_sub(&self) -> (UserId, Receiver) { + let mut inner = self.inner.write().unwrap(); + let sub_id = inner.next_sub; + inner.next_sub += 1; + let (rx,tx) = channel(32); + inner.subscriptions.insert(sub_id, rx); + (sub_id, tx) + } + + pub async fn broadcast(&self, msg: &str) -> Result<()> { + let subs = { + let inner = self.inner.read().unwrap(); + inner.subscriptions.clone() + }; + for (sub_id, tx) in subs.iter() { + tx.send(msg.to_string()).await?; + tracing::info!("Sent message to {}", sub_id); + } + Ok(()) + } + + pub fn remove_sub(&self, sub_id: UserId) { + let mut inner = self.inner.write().unwrap(); + inner.subscriptions.remove(&sub_id); + } +} + +struct ChatsInner { + pub next_sub: u64, + pub subscriptions: HashMap>, +} diff --git a/src/http.rs b/src/http.rs index 45e2a52..dc6808c 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,7 +1,7 @@ +use crate::chat::Chats; use crate::prelude::*; use std::convert::Infallible; -use std::sync::Arc; use http_body_util::{BodyExt, Full}; use hyper::server::conn::http1; @@ -29,7 +29,7 @@ fn not_found() -> std::result::Result>, Infallible> { Ok(response) } -fn metrics(registry: Arc) -> std::result::Result>, Infallible> { +fn metrics(registry: Registry) -> std::result::Result>, Infallible> { let mf = registry.gather(); let mut buffer = vec![]; let encoder = TextEncoder::new(); @@ -41,12 +41,13 @@ fn metrics(registry: Arc) -> std::result::Result> } async fn route( - registry: Arc, + registry: Registry, + chats: Chats, request: Request, ) -> std::result::Result, Infallible> { match (request.method(), request.uri().path()) { (&Method::GET, "/hello") => Ok(hello(request).await?.map(BodyExt::boxed)), - (&Method::GET, "/socket") => Ok(ws::handle_request(request).await?.map(BodyExt::boxed)), + (&Method::GET, "/socket") => Ok(ws::handle_request(request, chats).await?.map(BodyExt::boxed)), (&Method::GET, "/metrics") => Ok(metrics(registry)?.map(BodyExt::boxed)), _ => Ok(not_found()?.map(BodyExt::boxed)), } @@ -57,9 +58,9 @@ pub struct HttpServerActor { fiber: JoinHandle>, } impl HttpServerActor { - pub async fn launch(listener: TcpListener, metrics: Arc) -> Result { + pub async fn launch(listener: TcpListener, metrics: Registry, chats: Chats) -> Result { let (terminator, receiver) = tokio::sync::oneshot::channel::<()>(); - let fiber = tokio::task::spawn(Self::main_loop(listener, receiver, metrics)); + let fiber = tokio::task::spawn(Self::main_loop(listener, receiver, metrics, chats)); Ok(HttpServerActor { terminator, fiber }) } @@ -75,7 +76,8 @@ impl HttpServerActor { async fn main_loop( listener: TcpListener, termination: impl Future, - registry: Arc, + registry: Registry, + chats: Chats, ) -> Result<()> { log::info!("Starting the http server"); pin!(termination); @@ -90,13 +92,14 @@ impl HttpServerActor { }, result = listener.accept() => { let (stream, _) = result?; - let c = registry.clone(); + let registry = registry.clone(); + let chats = chats.clone(); let reqs = reqs.clone(); tokio::task::spawn(async move { reqs.inc(); - let c = c.clone(); + let registry = registry.clone(); if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(move |r| route(c.clone(), r))) + .serve_connection(stream, service_fn(move |r| route(registry.clone(), chats.clone(), r))) .with_upgrades() .await { diff --git a/src/http/ws.rs b/src/http/ws.rs index ef3fb64..7b64d94 100644 --- a/src/http/ws.rs +++ b/src/http/ws.rs @@ -1,12 +1,13 @@ use futures_util::TryStreamExt; use http_body_util::Empty; +use hyper::body::Incoming; use hyper::header::{ CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE, }; use hyper::http::HeaderValue; use hyper::upgrade::Upgraded; use hyper::{body::Bytes, Request, Response}; -use hyper::{Method, StatusCode, Version}; +use hyper::{StatusCode, Version}; use std::convert::Infallible; use tokio_tungstenite::tungstenite::handshake::derive_accept_key; @@ -17,42 +18,82 @@ use tokio_tungstenite::WebSocketStream; use futures_util::sink::SinkExt; use futures_util::stream::StreamExt; -async fn handle_connection(ws_stream: WebSocketStream) { +use crate::chat::Chats; + +async fn handle_connection(ws_stream: WebSocketStream, chats: Chats) { tracing::info!("WebSocket connection established"); let (mut outgoing, incoming) = ws_stream.split(); - let broadcast_incoming = incoming.try_for_each(|msg| { - tracing::info!("Received a message: {}", msg.to_text().unwrap()); + let (sub_id, mut sub) = chats.new_sub(); + tracing::info!("New conn id: {sub_id}"); - async { Ok(()) } - }); + let _ = chats.broadcast(format!("{sub_id} joined").as_str()).await; - outgoing.send(Message::Text("adsads".into())).await.unwrap(); + let broadcast_incoming = async { + tracing::info!("Started incoming stream for {sub_id}"); + let res = incoming.try_for_each(|msg| { + let txt = msg.to_text().unwrap().to_string(); + let chats = chats.clone(); + async move { + tracing::info!("Received a message: {}, sub_id={}", txt, sub_id); + match chats.broadcast(format!("{sub_id}: {txt}").as_str()).await { + Ok(_) => {}, + Err(err) => { + tracing::error!("Failed to broadcast a message from sub_id={sub_id}: {err}"); + }, + } + Ok(()) + } + }).await; + tracing::info!("Stopped incoming stream for {}", sub_id); + res + }; - match broadcast_incoming.await { + + outgoing + .send(Message::Text("Started a connection!".into())) + .await + .unwrap(); + + let outgoing = async { + tracing::info!("Started outgoing stream for {sub_id}"); + while let Some(msg) = sub.recv().await { + match outgoing.send(Message::Text(msg)).await { + Ok(_) => {}, + Err(err) => { + tracing::warn!("Failed to send msg, sub_id={sub_id}: {err}"); + break; + }, + } + } + tracing::info!("Stopped outgoing stream for {}", sub_id); + }; + + let (broadcast_incoming, _) = tokio::join!(broadcast_incoming, outgoing); + + let _ = chats.broadcast(format!("{sub_id} left").as_str()).await; + + + match broadcast_incoming { Ok(_) => tracing::info!("Disconnected"), - Err(e) => tracing::warn!("Socket failed: {}", e), + Err(err) => tracing::warn!("Socket failed: {err}"), } + tracing::info!("Terminating WS connection {sub_id}"); + chats.remove_sub(sub_id); } pub async fn handle_request( - mut req: Request, + mut req: Request, + chats: Chats, ) -> std::result::Result>, Infallible> { - dbg!(&req); - println!("Received a new, potentially ws handshake"); - println!("The request's path is: {}", req.uri().path()); - println!("The request's headers are:"); - for (ref header, _value) in req.headers() { - println!("* {}", header); - } + tracing::info!("Received a new WS request"); let upgrade = HeaderValue::from_static("Upgrade"); let websocket = HeaderValue::from_static("websocket"); let headers = req.headers(); let key = headers.get(SEC_WEBSOCKET_KEY); let derived = key.map(|k| derive_accept_key(k.as_bytes())); - if req.method() != Method::GET - || req.version() < Version::HTTP_11 + if req.version() < Version::HTTP_11 || !headers .get(CONNECTION) .and_then(|h| h.to_str().ok()) @@ -71,24 +112,25 @@ pub async fn handle_request( .map(|h| h == "13") .unwrap_or(false) || key.is_none() - || req.uri() != "/socket" { - dbg!(); + tracing::info!("Malformed request"); let mut resp = Response::new(Empty::new()); *resp.status_mut() = StatusCode::BAD_REQUEST; return Ok(resp); } let ver = req.version(); + let chats = chats.clone(); tokio::task::spawn(async move { match hyper::upgrade::on(&mut req).await { Ok(upgraded) => { handle_connection( WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await, + chats, ) .await; } - Err(e) => println!("upgrade error: {}", e), + Err(err) => tracing::error!("upgrade error: {err}"), } }); let mut res = Response::new(Empty::new()); diff --git a/src/main.rs b/src/main.rs index bbe83e5..bfcbb56 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,9 @@ +mod chat; mod http; mod prelude; mod tcp; +use crate::chat::Chats; use crate::prelude::*; use prometheus::{IntCounter, Opts, Registry}; use tcp::ClientSocketActor; @@ -9,7 +11,6 @@ use tcp::ClientSocketActor; use std::collections::HashMap; use std::future::Future; use std::net::SocketAddr; -use std::sync::Arc; use figment::providers::Format; use tokio::net::{TcpListener, TcpStream}; @@ -43,13 +44,15 @@ async fn main() -> Result<()> { let config = load_config()?; dbg!(config); tracing::info!("Booting up"); - let registry = Arc::new(Registry::new()); + let registry = Registry::new(); let counter = IntCounter::with_opts(Opts::new("actor_count", "Number of alive actors"))?; registry.register(Box::new(counter.clone()))?; + let chats = Chats::new(); + let listener = TcpListener::bind("127.0.0.1:3721").await?; let listener_http = TcpListener::bind("127.0.0.1:8080").await?; - let http_server_actor = http::HttpServerActor::launch(listener_http, registry.clone()).await?; + let http_server_actor = http::HttpServerActor::launch(listener_http, registry.clone(), chats.clone()).await?; tracing::info!("Started");