simple broadcast of messages

This commit is contained in:
Nikita Vilunov 2023-01-31 13:55:47 +01:00
parent 492f415947
commit f4dda9fb4b
4 changed files with 138 additions and 35 deletions

55
src/chat.rs Normal file
View File

@ -0,0 +1,55 @@
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use tokio::sync::mpsc::{Sender, channel, Receiver};
use crate::prelude::*;
pub type UserId = u64;
#[derive(Clone)]
pub struct Chats {
inner: Arc<RwLock<ChatsInner>>,
}
impl Chats {
pub fn new() -> Chats {
let subscriptions = HashMap::new();
let chats_inner = ChatsInner { subscriptions, next_sub: 0 };
let inner = Arc::new(RwLock::new(chats_inner));
Chats { inner }
}
pub fn new_sub(&self) -> (UserId, Receiver<String>) {
let mut inner = self.inner.write().unwrap();
let sub_id = inner.next_sub;
inner.next_sub += 1;
let (rx,tx) = channel(32);
inner.subscriptions.insert(sub_id, rx);
(sub_id, tx)
}
pub async fn broadcast(&self, msg: &str) -> Result<()> {
let subs = {
let inner = self.inner.read().unwrap();
inner.subscriptions.clone()
};
for (sub_id, tx) in subs.iter() {
tx.send(msg.to_string()).await?;
tracing::info!("Sent message to {}", sub_id);
}
Ok(())
}
pub fn remove_sub(&self, sub_id: UserId) {
let mut inner = self.inner.write().unwrap();
inner.subscriptions.remove(&sub_id);
}
}
struct ChatsInner {
pub next_sub: u64,
pub subscriptions: HashMap<UserId, Sender<String>>,
}

View File

@ -1,7 +1,7 @@
use crate::chat::Chats;
use crate::prelude::*;
use std::convert::Infallible;
use std::sync::Arc;
use http_body_util::{BodyExt, Full};
use hyper::server::conn::http1;
@ -29,7 +29,7 @@ fn not_found() -> std::result::Result<Response<Full<Bytes>>, Infallible> {
Ok(response)
}
fn metrics(registry: Arc<Registry>) -> std::result::Result<Response<Full<Bytes>>, Infallible> {
fn metrics(registry: Registry) -> std::result::Result<Response<Full<Bytes>>, Infallible> {
let mf = registry.gather();
let mut buffer = vec![];
let encoder = TextEncoder::new();
@ -41,12 +41,13 @@ fn metrics(registry: Arc<Registry>) -> std::result::Result<Response<Full<Bytes>>
}
async fn route(
registry: Arc<Registry>,
registry: Registry,
chats: Chats,
request: Request<hyper::body::Incoming>,
) -> std::result::Result<Response<BoxBody>, Infallible> {
match (request.method(), request.uri().path()) {
(&Method::GET, "/hello") => Ok(hello(request).await?.map(BodyExt::boxed)),
(&Method::GET, "/socket") => Ok(ws::handle_request(request).await?.map(BodyExt::boxed)),
(&Method::GET, "/socket") => Ok(ws::handle_request(request, chats).await?.map(BodyExt::boxed)),
(&Method::GET, "/metrics") => Ok(metrics(registry)?.map(BodyExt::boxed)),
_ => Ok(not_found()?.map(BodyExt::boxed)),
}
@ -57,9 +58,9 @@ pub struct HttpServerActor {
fiber: JoinHandle<Result<()>>,
}
impl HttpServerActor {
pub async fn launch(listener: TcpListener, metrics: Arc<Registry>) -> Result<HttpServerActor> {
pub async fn launch(listener: TcpListener, metrics: Registry, chats: Chats) -> Result<HttpServerActor> {
let (terminator, receiver) = tokio::sync::oneshot::channel::<()>();
let fiber = tokio::task::spawn(Self::main_loop(listener, receiver, metrics));
let fiber = tokio::task::spawn(Self::main_loop(listener, receiver, metrics, chats));
Ok(HttpServerActor { terminator, fiber })
}
@ -75,7 +76,8 @@ impl HttpServerActor {
async fn main_loop(
listener: TcpListener,
termination: impl Future,
registry: Arc<Registry>,
registry: Registry,
chats: Chats,
) -> Result<()> {
log::info!("Starting the http server");
pin!(termination);
@ -90,13 +92,14 @@ impl HttpServerActor {
},
result = listener.accept() => {
let (stream, _) = result?;
let c = registry.clone();
let registry = registry.clone();
let chats = chats.clone();
let reqs = reqs.clone();
tokio::task::spawn(async move {
reqs.inc();
let c = c.clone();
let registry = registry.clone();
if let Err(err) = http1::Builder::new()
.serve_connection(stream, service_fn(move |r| route(c.clone(), r)))
.serve_connection(stream, service_fn(move |r| route(registry.clone(), chats.clone(), r)))
.with_upgrades()
.await
{

View File

@ -1,12 +1,13 @@
use futures_util::TryStreamExt;
use http_body_util::Empty;
use hyper::body::Incoming;
use hyper::header::{
CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE,
};
use hyper::http::HeaderValue;
use hyper::upgrade::Upgraded;
use hyper::{body::Bytes, Request, Response};
use hyper::{Method, StatusCode, Version};
use hyper::{StatusCode, Version};
use std::convert::Infallible;
use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
@ -17,42 +18,82 @@ use tokio_tungstenite::WebSocketStream;
use futures_util::sink::SinkExt;
use futures_util::stream::StreamExt;
async fn handle_connection(ws_stream: WebSocketStream<Upgraded>) {
use crate::chat::Chats;
async fn handle_connection(ws_stream: WebSocketStream<Upgraded>, chats: Chats) {
tracing::info!("WebSocket connection established");
let (mut outgoing, incoming) = ws_stream.split();
let broadcast_incoming = incoming.try_for_each(|msg| {
tracing::info!("Received a message: {}", msg.to_text().unwrap());
let (sub_id, mut sub) = chats.new_sub();
tracing::info!("New conn id: {sub_id}");
async { Ok(()) }
});
let _ = chats.broadcast(format!("{sub_id} joined").as_str()).await;
outgoing.send(Message::Text("adsads".into())).await.unwrap();
let broadcast_incoming = async {
tracing::info!("Started incoming stream for {sub_id}");
let res = incoming.try_for_each(|msg| {
let txt = msg.to_text().unwrap().to_string();
let chats = chats.clone();
async move {
tracing::info!("Received a message: {}, sub_id={}", txt, sub_id);
match chats.broadcast(format!("{sub_id}: {txt}").as_str()).await {
Ok(_) => {},
Err(err) => {
tracing::error!("Failed to broadcast a message from sub_id={sub_id}: {err}");
},
}
Ok(())
}
}).await;
tracing::info!("Stopped incoming stream for {}", sub_id);
res
};
match broadcast_incoming.await {
outgoing
.send(Message::Text("Started a connection!".into()))
.await
.unwrap();
let outgoing = async {
tracing::info!("Started outgoing stream for {sub_id}");
while let Some(msg) = sub.recv().await {
match outgoing.send(Message::Text(msg)).await {
Ok(_) => {},
Err(err) => {
tracing::warn!("Failed to send msg, sub_id={sub_id}: {err}");
break;
},
}
}
tracing::info!("Stopped outgoing stream for {}", sub_id);
};
let (broadcast_incoming, _) = tokio::join!(broadcast_incoming, outgoing);
let _ = chats.broadcast(format!("{sub_id} left").as_str()).await;
match broadcast_incoming {
Ok(_) => tracing::info!("Disconnected"),
Err(e) => tracing::warn!("Socket failed: {}", e),
Err(err) => tracing::warn!("Socket failed: {err}"),
}
tracing::info!("Terminating WS connection {sub_id}");
chats.remove_sub(sub_id);
}
pub async fn handle_request(
mut req: Request<hyper::body::Incoming>,
mut req: Request<Incoming>,
chats: Chats,
) -> std::result::Result<Response<Empty<Bytes>>, Infallible> {
dbg!(&req);
println!("Received a new, potentially ws handshake");
println!("The request's path is: {}", req.uri().path());
println!("The request's headers are:");
for (ref header, _value) in req.headers() {
println!("* {}", header);
}
tracing::info!("Received a new WS request");
let upgrade = HeaderValue::from_static("Upgrade");
let websocket = HeaderValue::from_static("websocket");
let headers = req.headers();
let key = headers.get(SEC_WEBSOCKET_KEY);
let derived = key.map(|k| derive_accept_key(k.as_bytes()));
if req.method() != Method::GET
|| req.version() < Version::HTTP_11
if req.version() < Version::HTTP_11
|| !headers
.get(CONNECTION)
.and_then(|h| h.to_str().ok())
@ -71,24 +112,25 @@ pub async fn handle_request(
.map(|h| h == "13")
.unwrap_or(false)
|| key.is_none()
|| req.uri() != "/socket"
{
dbg!();
tracing::info!("Malformed request");
let mut resp = Response::new(Empty::new());
*resp.status_mut() = StatusCode::BAD_REQUEST;
return Ok(resp);
}
let ver = req.version();
let chats = chats.clone();
tokio::task::spawn(async move {
match hyper::upgrade::on(&mut req).await {
Ok(upgraded) => {
handle_connection(
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
chats,
)
.await;
}
Err(e) => println!("upgrade error: {}", e),
Err(err) => tracing::error!("upgrade error: {err}"),
}
});
let mut res = Response::new(Empty::new());

View File

@ -1,7 +1,9 @@
mod chat;
mod http;
mod prelude;
mod tcp;
use crate::chat::Chats;
use crate::prelude::*;
use prometheus::{IntCounter, Opts, Registry};
use tcp::ClientSocketActor;
@ -9,7 +11,6 @@ use tcp::ClientSocketActor;
use std::collections::HashMap;
use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use figment::providers::Format;
use tokio::net::{TcpListener, TcpStream};
@ -43,13 +44,15 @@ async fn main() -> Result<()> {
let config = load_config()?;
dbg!(config);
tracing::info!("Booting up");
let registry = Arc::new(Registry::new());
let registry = Registry::new();
let counter = IntCounter::with_opts(Opts::new("actor_count", "Number of alive actors"))?;
registry.register(Box::new(counter.clone()))?;
let chats = Chats::new();
let listener = TcpListener::bind("127.0.0.1:3721").await?;
let listener_http = TcpListener::bind("127.0.0.1:8080").await?;
let http_server_actor = http::HttpServerActor::launch(listener_http, registry.clone()).await?;
let http_server_actor = http::HttpServerActor::launch(listener_http, registry.clone(), chats.clone()).await?;
tracing::info!("Started");