use futures_util::TryStreamExt; use http_body_util::Empty; use hyper::body::Incoming; use hyper::header::{ CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE, }; use hyper::http::HeaderValue; use hyper::upgrade::Upgraded; use hyper::{body::Bytes, Request, Response}; use hyper::{StatusCode, Version}; use std::convert::Infallible; use tokio_tungstenite::tungstenite::handshake::derive_accept_key; use tokio_tungstenite::tungstenite::protocol::Role; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::WebSocketStream; use futures_util::sink::SinkExt; use futures_util::stream::StreamExt; use crate::chat::Chats; async fn handle_connection(ws_stream: WebSocketStream, chats: Chats) { tracing::info!("WebSocket connection established"); let (mut outgoing, incoming) = ws_stream.split(); let (sub_id, mut sub) = chats.new_sub(); tracing::info!("New conn id: {sub_id}"); let _ = chats.broadcast(format!("{sub_id} joined").as_str()).await; let broadcast_incoming = async { tracing::info!("Started incoming stream for {sub_id}"); let res = incoming.try_for_each(|msg| { let txt = msg.to_text().unwrap().to_string(); let chats = chats.clone(); async move { tracing::info!("Received a message: {}, sub_id={}", txt, sub_id); match chats.broadcast(format!("{sub_id}: {txt}").as_str()).await { Ok(_) => {}, Err(err) => { tracing::error!("Failed to broadcast a message from sub_id={sub_id}: {err}"); }, } Ok(()) } }).await; tracing::info!("Stopped incoming stream for {}", sub_id); res }; outgoing .send(Message::Text("Started a connection!".into())) .await .unwrap(); let outgoing = async { tracing::info!("Started outgoing stream for {sub_id}"); while let Some(msg) = sub.recv().await { match outgoing.send(Message::Text(msg)).await { Ok(_) => {}, Err(err) => { tracing::warn!("Failed to send msg, sub_id={sub_id}: {err}"); break; }, } } tracing::info!("Stopped outgoing stream for {}", sub_id); }; let (broadcast_incoming, _) = tokio::join!(broadcast_incoming, outgoing); let _ = chats.broadcast(format!("{sub_id} left").as_str()).await; match broadcast_incoming { Ok(_) => tracing::info!("Disconnected"), Err(err) => tracing::warn!("Socket failed: {err}"), } tracing::info!("Terminating WS connection {sub_id}"); chats.remove_sub(sub_id); } pub async fn handle_request( mut req: Request, chats: Chats, ) -> std::result::Result>, Infallible> { tracing::info!("Received a new WS request"); let upgrade = HeaderValue::from_static("Upgrade"); let websocket = HeaderValue::from_static("websocket"); let headers = req.headers(); let key = headers.get(SEC_WEBSOCKET_KEY); let derived = key.map(|k| derive_accept_key(k.as_bytes())); if req.version() < Version::HTTP_11 || !headers .get(CONNECTION) .and_then(|h| h.to_str().ok()) .map(|h| { h.split(|c| c == ' ' || c == ',') .any(|p| p.eq_ignore_ascii_case(upgrade.to_str().unwrap())) }) .unwrap_or(false) || !headers .get(UPGRADE) .and_then(|h| h.to_str().ok()) .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) || !headers .get(SEC_WEBSOCKET_VERSION) .map(|h| h == "13") .unwrap_or(false) || key.is_none() { tracing::info!("Malformed request"); let mut resp = Response::new(Empty::new()); *resp.status_mut() = StatusCode::BAD_REQUEST; return Ok(resp); } let ver = req.version(); let chats = chats.clone(); tokio::task::spawn(async move { match hyper::upgrade::on(&mut req).await { Ok(upgraded) => { handle_connection( WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await, chats, ) .await; } Err(err) => tracing::error!("upgrade error: {err}"), } }); let mut res = Response::new(Empty::new()); *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; *res.version_mut() = ver; res.headers_mut().append(CONNECTION, upgrade); res.headers_mut().append(UPGRADE, websocket); res.headers_mut() .append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap()); Ok(res) }