forked from lavina/lavina
simple broadcast of messages
This commit is contained in:
parent
492f415947
commit
f4dda9fb4b
|
@ -0,0 +1,55 @@
|
||||||
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
sync::{Arc, RwLock},
|
||||||
|
};
|
||||||
|
|
||||||
|
use tokio::sync::mpsc::{Sender, channel, Receiver};
|
||||||
|
|
||||||
|
use crate::prelude::*;
|
||||||
|
|
||||||
|
pub type UserId = u64;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Chats {
|
||||||
|
inner: Arc<RwLock<ChatsInner>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Chats {
|
||||||
|
pub fn new() -> Chats {
|
||||||
|
let subscriptions = HashMap::new();
|
||||||
|
let chats_inner = ChatsInner { subscriptions, next_sub: 0 };
|
||||||
|
let inner = Arc::new(RwLock::new(chats_inner));
|
||||||
|
Chats { inner }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_sub(&self) -> (UserId, Receiver<String>) {
|
||||||
|
let mut inner = self.inner.write().unwrap();
|
||||||
|
let sub_id = inner.next_sub;
|
||||||
|
inner.next_sub += 1;
|
||||||
|
let (rx,tx) = channel(32);
|
||||||
|
inner.subscriptions.insert(sub_id, rx);
|
||||||
|
(sub_id, tx)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn broadcast(&self, msg: &str) -> Result<()> {
|
||||||
|
let subs = {
|
||||||
|
let inner = self.inner.read().unwrap();
|
||||||
|
inner.subscriptions.clone()
|
||||||
|
};
|
||||||
|
for (sub_id, tx) in subs.iter() {
|
||||||
|
tx.send(msg.to_string()).await?;
|
||||||
|
tracing::info!("Sent message to {}", sub_id);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remove_sub(&self, sub_id: UserId) {
|
||||||
|
let mut inner = self.inner.write().unwrap();
|
||||||
|
inner.subscriptions.remove(&sub_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ChatsInner {
|
||||||
|
pub next_sub: u64,
|
||||||
|
pub subscriptions: HashMap<UserId, Sender<String>>,
|
||||||
|
}
|
23
src/http.rs
23
src/http.rs
|
@ -1,7 +1,7 @@
|
||||||
|
use crate::chat::Chats;
|
||||||
use crate::prelude::*;
|
use crate::prelude::*;
|
||||||
|
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use http_body_util::{BodyExt, Full};
|
use http_body_util::{BodyExt, Full};
|
||||||
use hyper::server::conn::http1;
|
use hyper::server::conn::http1;
|
||||||
|
@ -29,7 +29,7 @@ fn not_found() -> std::result::Result<Response<Full<Bytes>>, Infallible> {
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn metrics(registry: Arc<Registry>) -> std::result::Result<Response<Full<Bytes>>, Infallible> {
|
fn metrics(registry: Registry) -> std::result::Result<Response<Full<Bytes>>, Infallible> {
|
||||||
let mf = registry.gather();
|
let mf = registry.gather();
|
||||||
let mut buffer = vec![];
|
let mut buffer = vec![];
|
||||||
let encoder = TextEncoder::new();
|
let encoder = TextEncoder::new();
|
||||||
|
@ -41,12 +41,13 @@ fn metrics(registry: Arc<Registry>) -> std::result::Result<Response<Full<Bytes>>
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route(
|
async fn route(
|
||||||
registry: Arc<Registry>,
|
registry: Registry,
|
||||||
|
chats: Chats,
|
||||||
request: Request<hyper::body::Incoming>,
|
request: Request<hyper::body::Incoming>,
|
||||||
) -> std::result::Result<Response<BoxBody>, Infallible> {
|
) -> std::result::Result<Response<BoxBody>, Infallible> {
|
||||||
match (request.method(), request.uri().path()) {
|
match (request.method(), request.uri().path()) {
|
||||||
(&Method::GET, "/hello") => Ok(hello(request).await?.map(BodyExt::boxed)),
|
(&Method::GET, "/hello") => Ok(hello(request).await?.map(BodyExt::boxed)),
|
||||||
(&Method::GET, "/socket") => Ok(ws::handle_request(request).await?.map(BodyExt::boxed)),
|
(&Method::GET, "/socket") => Ok(ws::handle_request(request, chats).await?.map(BodyExt::boxed)),
|
||||||
(&Method::GET, "/metrics") => Ok(metrics(registry)?.map(BodyExt::boxed)),
|
(&Method::GET, "/metrics") => Ok(metrics(registry)?.map(BodyExt::boxed)),
|
||||||
_ => Ok(not_found()?.map(BodyExt::boxed)),
|
_ => Ok(not_found()?.map(BodyExt::boxed)),
|
||||||
}
|
}
|
||||||
|
@ -57,9 +58,9 @@ pub struct HttpServerActor {
|
||||||
fiber: JoinHandle<Result<()>>,
|
fiber: JoinHandle<Result<()>>,
|
||||||
}
|
}
|
||||||
impl HttpServerActor {
|
impl HttpServerActor {
|
||||||
pub async fn launch(listener: TcpListener, metrics: Arc<Registry>) -> Result<HttpServerActor> {
|
pub async fn launch(listener: TcpListener, metrics: Registry, chats: Chats) -> Result<HttpServerActor> {
|
||||||
let (terminator, receiver) = tokio::sync::oneshot::channel::<()>();
|
let (terminator, receiver) = tokio::sync::oneshot::channel::<()>();
|
||||||
let fiber = tokio::task::spawn(Self::main_loop(listener, receiver, metrics));
|
let fiber = tokio::task::spawn(Self::main_loop(listener, receiver, metrics, chats));
|
||||||
Ok(HttpServerActor { terminator, fiber })
|
Ok(HttpServerActor { terminator, fiber })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,7 +76,8 @@ impl HttpServerActor {
|
||||||
async fn main_loop(
|
async fn main_loop(
|
||||||
listener: TcpListener,
|
listener: TcpListener,
|
||||||
termination: impl Future,
|
termination: impl Future,
|
||||||
registry: Arc<Registry>,
|
registry: Registry,
|
||||||
|
chats: Chats,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
log::info!("Starting the http server");
|
log::info!("Starting the http server");
|
||||||
pin!(termination);
|
pin!(termination);
|
||||||
|
@ -90,13 +92,14 @@ impl HttpServerActor {
|
||||||
},
|
},
|
||||||
result = listener.accept() => {
|
result = listener.accept() => {
|
||||||
let (stream, _) = result?;
|
let (stream, _) = result?;
|
||||||
let c = registry.clone();
|
let registry = registry.clone();
|
||||||
|
let chats = chats.clone();
|
||||||
let reqs = reqs.clone();
|
let reqs = reqs.clone();
|
||||||
tokio::task::spawn(async move {
|
tokio::task::spawn(async move {
|
||||||
reqs.inc();
|
reqs.inc();
|
||||||
let c = c.clone();
|
let registry = registry.clone();
|
||||||
if let Err(err) = http1::Builder::new()
|
if let Err(err) = http1::Builder::new()
|
||||||
.serve_connection(stream, service_fn(move |r| route(c.clone(), r)))
|
.serve_connection(stream, service_fn(move |r| route(registry.clone(), chats.clone(), r)))
|
||||||
.with_upgrades()
|
.with_upgrades()
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
use futures_util::TryStreamExt;
|
use futures_util::TryStreamExt;
|
||||||
use http_body_util::Empty;
|
use http_body_util::Empty;
|
||||||
|
use hyper::body::Incoming;
|
||||||
use hyper::header::{
|
use hyper::header::{
|
||||||
CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE,
|
CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE,
|
||||||
};
|
};
|
||||||
use hyper::http::HeaderValue;
|
use hyper::http::HeaderValue;
|
||||||
use hyper::upgrade::Upgraded;
|
use hyper::upgrade::Upgraded;
|
||||||
use hyper::{body::Bytes, Request, Response};
|
use hyper::{body::Bytes, Request, Response};
|
||||||
use hyper::{Method, StatusCode, Version};
|
use hyper::{StatusCode, Version};
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
|
|
||||||
use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
|
use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
|
||||||
|
@ -17,42 +18,82 @@ use tokio_tungstenite::WebSocketStream;
|
||||||
use futures_util::sink::SinkExt;
|
use futures_util::sink::SinkExt;
|
||||||
use futures_util::stream::StreamExt;
|
use futures_util::stream::StreamExt;
|
||||||
|
|
||||||
async fn handle_connection(ws_stream: WebSocketStream<Upgraded>) {
|
use crate::chat::Chats;
|
||||||
|
|
||||||
|
async fn handle_connection(ws_stream: WebSocketStream<Upgraded>, chats: Chats) {
|
||||||
tracing::info!("WebSocket connection established");
|
tracing::info!("WebSocket connection established");
|
||||||
|
|
||||||
let (mut outgoing, incoming) = ws_stream.split();
|
let (mut outgoing, incoming) = ws_stream.split();
|
||||||
|
|
||||||
let broadcast_incoming = incoming.try_for_each(|msg| {
|
let (sub_id, mut sub) = chats.new_sub();
|
||||||
tracing::info!("Received a message: {}", msg.to_text().unwrap());
|
tracing::info!("New conn id: {sub_id}");
|
||||||
|
|
||||||
async { Ok(()) }
|
let _ = chats.broadcast(format!("{sub_id} joined").as_str()).await;
|
||||||
});
|
|
||||||
|
|
||||||
outgoing.send(Message::Text("adsads".into())).await.unwrap();
|
let broadcast_incoming = async {
|
||||||
|
tracing::info!("Started incoming stream for {sub_id}");
|
||||||
|
let res = incoming.try_for_each(|msg| {
|
||||||
|
let txt = msg.to_text().unwrap().to_string();
|
||||||
|
let chats = chats.clone();
|
||||||
|
async move {
|
||||||
|
tracing::info!("Received a message: {}, sub_id={}", txt, sub_id);
|
||||||
|
match chats.broadcast(format!("{sub_id}: {txt}").as_str()).await {
|
||||||
|
Ok(_) => {},
|
||||||
|
Err(err) => {
|
||||||
|
tracing::error!("Failed to broadcast a message from sub_id={sub_id}: {err}");
|
||||||
|
},
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}).await;
|
||||||
|
tracing::info!("Stopped incoming stream for {}", sub_id);
|
||||||
|
res
|
||||||
|
};
|
||||||
|
|
||||||
match broadcast_incoming.await {
|
|
||||||
|
outgoing
|
||||||
|
.send(Message::Text("Started a connection!".into()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let outgoing = async {
|
||||||
|
tracing::info!("Started outgoing stream for {sub_id}");
|
||||||
|
while let Some(msg) = sub.recv().await {
|
||||||
|
match outgoing.send(Message::Text(msg)).await {
|
||||||
|
Ok(_) => {},
|
||||||
|
Err(err) => {
|
||||||
|
tracing::warn!("Failed to send msg, sub_id={sub_id}: {err}");
|
||||||
|
break;
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing::info!("Stopped outgoing stream for {}", sub_id);
|
||||||
|
};
|
||||||
|
|
||||||
|
let (broadcast_incoming, _) = tokio::join!(broadcast_incoming, outgoing);
|
||||||
|
|
||||||
|
let _ = chats.broadcast(format!("{sub_id} left").as_str()).await;
|
||||||
|
|
||||||
|
|
||||||
|
match broadcast_incoming {
|
||||||
Ok(_) => tracing::info!("Disconnected"),
|
Ok(_) => tracing::info!("Disconnected"),
|
||||||
Err(e) => tracing::warn!("Socket failed: {}", e),
|
Err(err) => tracing::warn!("Socket failed: {err}"),
|
||||||
}
|
}
|
||||||
|
tracing::info!("Terminating WS connection {sub_id}");
|
||||||
|
chats.remove_sub(sub_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_request(
|
pub async fn handle_request(
|
||||||
mut req: Request<hyper::body::Incoming>,
|
mut req: Request<Incoming>,
|
||||||
|
chats: Chats,
|
||||||
) -> std::result::Result<Response<Empty<Bytes>>, Infallible> {
|
) -> std::result::Result<Response<Empty<Bytes>>, Infallible> {
|
||||||
dbg!(&req);
|
tracing::info!("Received a new WS request");
|
||||||
println!("Received a new, potentially ws handshake");
|
|
||||||
println!("The request's path is: {}", req.uri().path());
|
|
||||||
println!("The request's headers are:");
|
|
||||||
for (ref header, _value) in req.headers() {
|
|
||||||
println!("* {}", header);
|
|
||||||
}
|
|
||||||
let upgrade = HeaderValue::from_static("Upgrade");
|
let upgrade = HeaderValue::from_static("Upgrade");
|
||||||
let websocket = HeaderValue::from_static("websocket");
|
let websocket = HeaderValue::from_static("websocket");
|
||||||
let headers = req.headers();
|
let headers = req.headers();
|
||||||
let key = headers.get(SEC_WEBSOCKET_KEY);
|
let key = headers.get(SEC_WEBSOCKET_KEY);
|
||||||
let derived = key.map(|k| derive_accept_key(k.as_bytes()));
|
let derived = key.map(|k| derive_accept_key(k.as_bytes()));
|
||||||
if req.method() != Method::GET
|
if req.version() < Version::HTTP_11
|
||||||
|| req.version() < Version::HTTP_11
|
|
||||||
|| !headers
|
|| !headers
|
||||||
.get(CONNECTION)
|
.get(CONNECTION)
|
||||||
.and_then(|h| h.to_str().ok())
|
.and_then(|h| h.to_str().ok())
|
||||||
|
@ -71,24 +112,25 @@ pub async fn handle_request(
|
||||||
.map(|h| h == "13")
|
.map(|h| h == "13")
|
||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
|| key.is_none()
|
|| key.is_none()
|
||||||
|| req.uri() != "/socket"
|
|
||||||
{
|
{
|
||||||
dbg!();
|
tracing::info!("Malformed request");
|
||||||
let mut resp = Response::new(Empty::new());
|
let mut resp = Response::new(Empty::new());
|
||||||
*resp.status_mut() = StatusCode::BAD_REQUEST;
|
*resp.status_mut() = StatusCode::BAD_REQUEST;
|
||||||
return Ok(resp);
|
return Ok(resp);
|
||||||
}
|
}
|
||||||
let ver = req.version();
|
let ver = req.version();
|
||||||
|
|
||||||
|
let chats = chats.clone();
|
||||||
tokio::task::spawn(async move {
|
tokio::task::spawn(async move {
|
||||||
match hyper::upgrade::on(&mut req).await {
|
match hyper::upgrade::on(&mut req).await {
|
||||||
Ok(upgraded) => {
|
Ok(upgraded) => {
|
||||||
handle_connection(
|
handle_connection(
|
||||||
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
|
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
|
||||||
|
chats,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
Err(e) => println!("upgrade error: {}", e),
|
Err(err) => tracing::error!("upgrade error: {err}"),
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
let mut res = Response::new(Empty::new());
|
let mut res = Response::new(Empty::new());
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
|
mod chat;
|
||||||
mod http;
|
mod http;
|
||||||
mod prelude;
|
mod prelude;
|
||||||
mod tcp;
|
mod tcp;
|
||||||
|
|
||||||
|
use crate::chat::Chats;
|
||||||
use crate::prelude::*;
|
use crate::prelude::*;
|
||||||
use prometheus::{IntCounter, Opts, Registry};
|
use prometheus::{IntCounter, Opts, Registry};
|
||||||
use tcp::ClientSocketActor;
|
use tcp::ClientSocketActor;
|
||||||
|
@ -9,7 +11,6 @@ use tcp::ClientSocketActor;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use figment::providers::Format;
|
use figment::providers::Format;
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
|
@ -43,13 +44,15 @@ async fn main() -> Result<()> {
|
||||||
let config = load_config()?;
|
let config = load_config()?;
|
||||||
dbg!(config);
|
dbg!(config);
|
||||||
tracing::info!("Booting up");
|
tracing::info!("Booting up");
|
||||||
let registry = Arc::new(Registry::new());
|
let registry = Registry::new();
|
||||||
let counter = IntCounter::with_opts(Opts::new("actor_count", "Number of alive actors"))?;
|
let counter = IntCounter::with_opts(Opts::new("actor_count", "Number of alive actors"))?;
|
||||||
registry.register(Box::new(counter.clone()))?;
|
registry.register(Box::new(counter.clone()))?;
|
||||||
|
|
||||||
|
let chats = Chats::new();
|
||||||
|
|
||||||
let listener = TcpListener::bind("127.0.0.1:3721").await?;
|
let listener = TcpListener::bind("127.0.0.1:3721").await?;
|
||||||
let listener_http = TcpListener::bind("127.0.0.1:8080").await?;
|
let listener_http = TcpListener::bind("127.0.0.1:8080").await?;
|
||||||
let http_server_actor = http::HttpServerActor::launch(listener_http, registry.clone()).await?;
|
let http_server_actor = http::HttpServerActor::launch(listener_http, registry.clone(), chats.clone()).await?;
|
||||||
|
|
||||||
tracing::info!("Started");
|
tracing::info!("Started");
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue