forked from lavina/lavina
simple broadcast of messages
This commit is contained in:
parent
492f415947
commit
f4dda9fb4b
|
@ -0,0 +1,55 @@
|
|||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use tokio::sync::mpsc::{Sender, channel, Receiver};
|
||||
|
||||
use crate::prelude::*;
|
||||
|
||||
pub type UserId = u64;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Chats {
|
||||
inner: Arc<RwLock<ChatsInner>>,
|
||||
}
|
||||
|
||||
impl Chats {
|
||||
pub fn new() -> Chats {
|
||||
let subscriptions = HashMap::new();
|
||||
let chats_inner = ChatsInner { subscriptions, next_sub: 0 };
|
||||
let inner = Arc::new(RwLock::new(chats_inner));
|
||||
Chats { inner }
|
||||
}
|
||||
|
||||
pub fn new_sub(&self) -> (UserId, Receiver<String>) {
|
||||
let mut inner = self.inner.write().unwrap();
|
||||
let sub_id = inner.next_sub;
|
||||
inner.next_sub += 1;
|
||||
let (rx,tx) = channel(32);
|
||||
inner.subscriptions.insert(sub_id, rx);
|
||||
(sub_id, tx)
|
||||
}
|
||||
|
||||
pub async fn broadcast(&self, msg: &str) -> Result<()> {
|
||||
let subs = {
|
||||
let inner = self.inner.read().unwrap();
|
||||
inner.subscriptions.clone()
|
||||
};
|
||||
for (sub_id, tx) in subs.iter() {
|
||||
tx.send(msg.to_string()).await?;
|
||||
tracing::info!("Sent message to {}", sub_id);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn remove_sub(&self, sub_id: UserId) {
|
||||
let mut inner = self.inner.write().unwrap();
|
||||
inner.subscriptions.remove(&sub_id);
|
||||
}
|
||||
}
|
||||
|
||||
struct ChatsInner {
|
||||
pub next_sub: u64,
|
||||
pub subscriptions: HashMap<UserId, Sender<String>>,
|
||||
}
|
23
src/http.rs
23
src/http.rs
|
@ -1,7 +1,7 @@
|
|||
use crate::chat::Chats;
|
||||
use crate::prelude::*;
|
||||
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use hyper::server::conn::http1;
|
||||
|
@ -29,7 +29,7 @@ fn not_found() -> std::result::Result<Response<Full<Bytes>>, Infallible> {
|
|||
Ok(response)
|
||||
}
|
||||
|
||||
fn metrics(registry: Arc<Registry>) -> std::result::Result<Response<Full<Bytes>>, Infallible> {
|
||||
fn metrics(registry: Registry) -> std::result::Result<Response<Full<Bytes>>, Infallible> {
|
||||
let mf = registry.gather();
|
||||
let mut buffer = vec![];
|
||||
let encoder = TextEncoder::new();
|
||||
|
@ -41,12 +41,13 @@ fn metrics(registry: Arc<Registry>) -> std::result::Result<Response<Full<Bytes>>
|
|||
}
|
||||
|
||||
async fn route(
|
||||
registry: Arc<Registry>,
|
||||
registry: Registry,
|
||||
chats: Chats,
|
||||
request: Request<hyper::body::Incoming>,
|
||||
) -> std::result::Result<Response<BoxBody>, Infallible> {
|
||||
match (request.method(), request.uri().path()) {
|
||||
(&Method::GET, "/hello") => Ok(hello(request).await?.map(BodyExt::boxed)),
|
||||
(&Method::GET, "/socket") => Ok(ws::handle_request(request).await?.map(BodyExt::boxed)),
|
||||
(&Method::GET, "/socket") => Ok(ws::handle_request(request, chats).await?.map(BodyExt::boxed)),
|
||||
(&Method::GET, "/metrics") => Ok(metrics(registry)?.map(BodyExt::boxed)),
|
||||
_ => Ok(not_found()?.map(BodyExt::boxed)),
|
||||
}
|
||||
|
@ -57,9 +58,9 @@ pub struct HttpServerActor {
|
|||
fiber: JoinHandle<Result<()>>,
|
||||
}
|
||||
impl HttpServerActor {
|
||||
pub async fn launch(listener: TcpListener, metrics: Arc<Registry>) -> Result<HttpServerActor> {
|
||||
pub async fn launch(listener: TcpListener, metrics: Registry, chats: Chats) -> Result<HttpServerActor> {
|
||||
let (terminator, receiver) = tokio::sync::oneshot::channel::<()>();
|
||||
let fiber = tokio::task::spawn(Self::main_loop(listener, receiver, metrics));
|
||||
let fiber = tokio::task::spawn(Self::main_loop(listener, receiver, metrics, chats));
|
||||
Ok(HttpServerActor { terminator, fiber })
|
||||
}
|
||||
|
||||
|
@ -75,7 +76,8 @@ impl HttpServerActor {
|
|||
async fn main_loop(
|
||||
listener: TcpListener,
|
||||
termination: impl Future,
|
||||
registry: Arc<Registry>,
|
||||
registry: Registry,
|
||||
chats: Chats,
|
||||
) -> Result<()> {
|
||||
log::info!("Starting the http server");
|
||||
pin!(termination);
|
||||
|
@ -90,13 +92,14 @@ impl HttpServerActor {
|
|||
},
|
||||
result = listener.accept() => {
|
||||
let (stream, _) = result?;
|
||||
let c = registry.clone();
|
||||
let registry = registry.clone();
|
||||
let chats = chats.clone();
|
||||
let reqs = reqs.clone();
|
||||
tokio::task::spawn(async move {
|
||||
reqs.inc();
|
||||
let c = c.clone();
|
||||
let registry = registry.clone();
|
||||
if let Err(err) = http1::Builder::new()
|
||||
.serve_connection(stream, service_fn(move |r| route(c.clone(), r)))
|
||||
.serve_connection(stream, service_fn(move |r| route(registry.clone(), chats.clone(), r)))
|
||||
.with_upgrades()
|
||||
.await
|
||||
{
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
use futures_util::TryStreamExt;
|
||||
use http_body_util::Empty;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::header::{
|
||||
CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE,
|
||||
};
|
||||
use hyper::http::HeaderValue;
|
||||
use hyper::upgrade::Upgraded;
|
||||
use hyper::{body::Bytes, Request, Response};
|
||||
use hyper::{Method, StatusCode, Version};
|
||||
use hyper::{StatusCode, Version};
|
||||
use std::convert::Infallible;
|
||||
|
||||
use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
|
||||
|
@ -17,42 +18,82 @@ use tokio_tungstenite::WebSocketStream;
|
|||
use futures_util::sink::SinkExt;
|
||||
use futures_util::stream::StreamExt;
|
||||
|
||||
async fn handle_connection(ws_stream: WebSocketStream<Upgraded>) {
|
||||
use crate::chat::Chats;
|
||||
|
||||
async fn handle_connection(ws_stream: WebSocketStream<Upgraded>, chats: Chats) {
|
||||
tracing::info!("WebSocket connection established");
|
||||
|
||||
let (mut outgoing, incoming) = ws_stream.split();
|
||||
|
||||
let broadcast_incoming = incoming.try_for_each(|msg| {
|
||||
tracing::info!("Received a message: {}", msg.to_text().unwrap());
|
||||
let (sub_id, mut sub) = chats.new_sub();
|
||||
tracing::info!("New conn id: {sub_id}");
|
||||
|
||||
async { Ok(()) }
|
||||
});
|
||||
let _ = chats.broadcast(format!("{sub_id} joined").as_str()).await;
|
||||
|
||||
outgoing.send(Message::Text("adsads".into())).await.unwrap();
|
||||
|
||||
match broadcast_incoming.await {
|
||||
Ok(_) => tracing::info!("Disconnected"),
|
||||
Err(e) => tracing::warn!("Socket failed: {}", e),
|
||||
let broadcast_incoming = async {
|
||||
tracing::info!("Started incoming stream for {sub_id}");
|
||||
let res = incoming.try_for_each(|msg| {
|
||||
let txt = msg.to_text().unwrap().to_string();
|
||||
let chats = chats.clone();
|
||||
async move {
|
||||
tracing::info!("Received a message: {}, sub_id={}", txt, sub_id);
|
||||
match chats.broadcast(format!("{sub_id}: {txt}").as_str()).await {
|
||||
Ok(_) => {},
|
||||
Err(err) => {
|
||||
tracing::error!("Failed to broadcast a message from sub_id={sub_id}: {err}");
|
||||
},
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}).await;
|
||||
tracing::info!("Stopped incoming stream for {}", sub_id);
|
||||
res
|
||||
};
|
||||
|
||||
|
||||
outgoing
|
||||
.send(Message::Text("Started a connection!".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let outgoing = async {
|
||||
tracing::info!("Started outgoing stream for {sub_id}");
|
||||
while let Some(msg) = sub.recv().await {
|
||||
match outgoing.send(Message::Text(msg)).await {
|
||||
Ok(_) => {},
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to send msg, sub_id={sub_id}: {err}");
|
||||
break;
|
||||
},
|
||||
}
|
||||
}
|
||||
tracing::info!("Stopped outgoing stream for {}", sub_id);
|
||||
};
|
||||
|
||||
let (broadcast_incoming, _) = tokio::join!(broadcast_incoming, outgoing);
|
||||
|
||||
let _ = chats.broadcast(format!("{sub_id} left").as_str()).await;
|
||||
|
||||
|
||||
match broadcast_incoming {
|
||||
Ok(_) => tracing::info!("Disconnected"),
|
||||
Err(err) => tracing::warn!("Socket failed: {err}"),
|
||||
}
|
||||
tracing::info!("Terminating WS connection {sub_id}");
|
||||
chats.remove_sub(sub_id);
|
||||
}
|
||||
|
||||
pub async fn handle_request(
|
||||
mut req: Request<hyper::body::Incoming>,
|
||||
mut req: Request<Incoming>,
|
||||
chats: Chats,
|
||||
) -> std::result::Result<Response<Empty<Bytes>>, Infallible> {
|
||||
dbg!(&req);
|
||||
println!("Received a new, potentially ws handshake");
|
||||
println!("The request's path is: {}", req.uri().path());
|
||||
println!("The request's headers are:");
|
||||
for (ref header, _value) in req.headers() {
|
||||
println!("* {}", header);
|
||||
}
|
||||
tracing::info!("Received a new WS request");
|
||||
let upgrade = HeaderValue::from_static("Upgrade");
|
||||
let websocket = HeaderValue::from_static("websocket");
|
||||
let headers = req.headers();
|
||||
let key = headers.get(SEC_WEBSOCKET_KEY);
|
||||
let derived = key.map(|k| derive_accept_key(k.as_bytes()));
|
||||
if req.method() != Method::GET
|
||||
|| req.version() < Version::HTTP_11
|
||||
if req.version() < Version::HTTP_11
|
||||
|| !headers
|
||||
.get(CONNECTION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
|
@ -71,24 +112,25 @@ pub async fn handle_request(
|
|||
.map(|h| h == "13")
|
||||
.unwrap_or(false)
|
||||
|| key.is_none()
|
||||
|| req.uri() != "/socket"
|
||||
{
|
||||
dbg!();
|
||||
tracing::info!("Malformed request");
|
||||
let mut resp = Response::new(Empty::new());
|
||||
*resp.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(resp);
|
||||
}
|
||||
let ver = req.version();
|
||||
|
||||
let chats = chats.clone();
|
||||
tokio::task::spawn(async move {
|
||||
match hyper::upgrade::on(&mut req).await {
|
||||
Ok(upgraded) => {
|
||||
handle_connection(
|
||||
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
|
||||
chats,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Err(e) => println!("upgrade error: {}", e),
|
||||
Err(err) => tracing::error!("upgrade error: {err}"),
|
||||
}
|
||||
});
|
||||
let mut res = Response::new(Empty::new());
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
mod chat;
|
||||
mod http;
|
||||
mod prelude;
|
||||
mod tcp;
|
||||
|
||||
use crate::chat::Chats;
|
||||
use crate::prelude::*;
|
||||
use prometheus::{IntCounter, Opts, Registry};
|
||||
use tcp::ClientSocketActor;
|
||||
|
@ -9,7 +11,6 @@ use tcp::ClientSocketActor;
|
|||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use figment::providers::Format;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
@ -43,13 +44,15 @@ async fn main() -> Result<()> {
|
|||
let config = load_config()?;
|
||||
dbg!(config);
|
||||
tracing::info!("Booting up");
|
||||
let registry = Arc::new(Registry::new());
|
||||
let registry = Registry::new();
|
||||
let counter = IntCounter::with_opts(Opts::new("actor_count", "Number of alive actors"))?;
|
||||
registry.register(Box::new(counter.clone()))?;
|
||||
|
||||
let chats = Chats::new();
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:3721").await?;
|
||||
let listener_http = TcpListener::bind("127.0.0.1:8080").await?;
|
||||
let http_server_actor = http::HttpServerActor::launch(listener_http, registry.clone()).await?;
|
||||
let http_server_actor = http::HttpServerActor::launch(listener_http, registry.clone(), chats.clone()).await?;
|
||||
|
||||
tracing::info!("Started");
|
||||
|
||||
|
|
Loading…
Reference in New Issue