lavina/src/http/ws.rs

140 lines
4.8 KiB
Rust

use http_body_util::Empty;
use hyper::body::Incoming;
use hyper::header::{
CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE,
};
use hyper::http::HeaderValue;
use hyper::upgrade::Upgraded;
use hyper::{body::Bytes, Request, Response};
use hyper::{StatusCode, Version};
use std::convert::Infallible;
use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
use tokio_tungstenite::tungstenite::protocol::Role;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::WebSocketStream;
use futures_util::sink::SinkExt;
use futures_util::stream::StreamExt;
use crate::chat::Chats;
async fn handle_connection(mut ws_stream: WebSocketStream<Upgraded>, chats: Chats) {
tracing::info!("WebSocket connection established");
let (sub_id, mut sub) = chats.new_sub();
tracing::info!("New conn id: {sub_id}");
let _ = chats.broadcast(format!("{sub_id} joined").as_str()).await;
ws_stream
.send(Message::Text("Started a connection!".into()))
.await
.unwrap();
tracing::info!("Started stream for {sub_id}");
loop {
tokio::select! {
biased;
msg = ws_stream.next() => {
match msg {
Some(msg) => {
let msg = msg.unwrap();
let txt = msg.to_text().unwrap().to_string();
tracing::info!("Received a message: {txt}, sub_id={sub_id}");
match chats.broadcast(format!("{sub_id}: {txt}").as_str()).await {
Ok(_) => {},
Err(err) => {
tracing::error!("Failed to broadcast a message from sub_id={sub_id}: {err}");
},
}
},
None => {
tracing::info!("Client {sub_id} closed the socket, stopping..");
break;
},
}
},
msg = sub.recv() => {
match msg {
Some(msg) => {
match ws_stream.send(Message::Text(msg)).await {
Ok(_) => {},
Err(err) => {
tracing::warn!("Failed to send msg, sub_id={sub_id}: {err}");
break;
},
}
},
None => {
break;
}
}
}
}
}
tracing::info!("Ended stream for {sub_id}");
chats.remove_sub(sub_id);
let _ = chats.broadcast(format!("{sub_id} left").as_str()).await;
}
pub async fn handle_request(
mut req: Request<Incoming>,
chats: Chats,
) -> std::result::Result<Response<Empty<Bytes>>, Infallible> {
tracing::info!("Received a new WS request");
let upgrade = HeaderValue::from_static("Upgrade");
let websocket = HeaderValue::from_static("websocket");
let headers = req.headers();
let key = headers.get(SEC_WEBSOCKET_KEY);
let derived = key.map(|k| derive_accept_key(k.as_bytes()));
if req.version() < Version::HTTP_11
|| !headers
.get(CONNECTION)
.and_then(|h| h.to_str().ok())
.map(|h| {
h.split(|c| c == ' ' || c == ',')
.any(|p| p.eq_ignore_ascii_case(upgrade.to_str().unwrap()))
})
.unwrap_or(false)
|| !headers
.get(UPGRADE)
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
|| !headers
.get(SEC_WEBSOCKET_VERSION)
.map(|h| h == "13")
.unwrap_or(false)
|| key.is_none()
{
tracing::info!("Malformed request");
let mut resp = Response::new(Empty::new());
*resp.status_mut() = StatusCode::BAD_REQUEST;
return Ok(resp);
}
let ver = req.version();
let chats = chats.clone();
tokio::task::spawn(async move {
match hyper::upgrade::on(&mut req).await {
Ok(upgraded) => {
handle_connection(
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
chats,
)
.await;
}
Err(err) => tracing::error!("upgrade error: {err}"),
}
});
let mut res = Response::new(Empty::new());
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
*res.version_mut() = ver;
res.headers_mut().append(CONNECTION, upgrade);
res.headers_mut().append(UPGRADE, websocket);
res.headers_mut()
.append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap());
Ok(res)
}