use http_body_util::Empty; use hyper::body::Incoming; use hyper::header::{ CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE, }; use hyper::http::HeaderValue; use hyper::upgrade::Upgraded; use hyper::{body::Bytes, Request, Response}; use hyper::{StatusCode, Version}; use std::convert::Infallible; use tokio_tungstenite::tungstenite::handshake::derive_accept_key; use tokio_tungstenite::tungstenite::protocol::Role; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::WebSocketStream; use futures_util::sink::SinkExt; use futures_util::stream::StreamExt; use crate::chat::Chats; async fn handle_connection(mut ws_stream: WebSocketStream, chats: Chats) { tracing::info!("WebSocket connection established"); let (sub_id, mut sub) = chats.new_sub(); tracing::info!("New conn id: {sub_id}"); let _ = chats.broadcast(format!("{sub_id} joined").as_str()).await; ws_stream .send(Message::Text("Started a connection!".into())) .await .unwrap(); tracing::info!("Started stream for {sub_id}"); loop { tokio::select! { biased; msg = ws_stream.next() => { match msg { Some(msg) => { let msg = msg.unwrap(); let txt = msg.to_text().unwrap().to_string(); tracing::info!("Received a message: {txt}, sub_id={sub_id}"); match chats.broadcast(format!("{sub_id}: {txt}").as_str()).await { Ok(_) => {}, Err(err) => { tracing::error!("Failed to broadcast a message from sub_id={sub_id}: {err}"); }, } }, None => { tracing::info!("Client {sub_id} closed the socket, stopping.."); break; }, } }, msg = sub.recv() => { match msg { Some(msg) => { match ws_stream.send(Message::Text(msg)).await { Ok(_) => {}, Err(err) => { tracing::warn!("Failed to send msg, sub_id={sub_id}: {err}"); break; }, } }, None => { break; } } } } } tracing::info!("Ended stream for {sub_id}"); chats.remove_sub(sub_id); let _ = chats.broadcast(format!("{sub_id} left").as_str()).await; } pub async fn handle_request( mut req: Request, chats: Chats, ) -> std::result::Result>, Infallible> { tracing::info!("Received a new WS request"); let upgrade = HeaderValue::from_static("Upgrade"); let websocket = HeaderValue::from_static("websocket"); let headers = req.headers(); let key = headers.get(SEC_WEBSOCKET_KEY); let derived = key.map(|k| derive_accept_key(k.as_bytes())); if req.version() < Version::HTTP_11 || !headers .get(CONNECTION) .and_then(|h| h.to_str().ok()) .map(|h| { h.split(|c| c == ' ' || c == ',') .any(|p| p.eq_ignore_ascii_case(upgrade.to_str().unwrap())) }) .unwrap_or(false) || !headers .get(UPGRADE) .and_then(|h| h.to_str().ok()) .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) || !headers .get(SEC_WEBSOCKET_VERSION) .map(|h| h == "13") .unwrap_or(false) || key.is_none() { tracing::info!("Malformed request"); let mut resp = Response::new(Empty::new()); *resp.status_mut() = StatusCode::BAD_REQUEST; return Ok(resp); } let ver = req.version(); let chats = chats.clone(); tokio::task::spawn(async move { match hyper::upgrade::on(&mut req).await { Ok(upgraded) => { handle_connection( WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await, chats, ) .await; } Err(err) => tracing::error!("upgrade error: {err}"), } }); let mut res = Response::new(Empty::new()); *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; *res.version_mut() = ver; res.headers_mut().append(CONNECTION, upgrade); res.headers_mut().append(UPGRADE, websocket); res.headers_mut() .append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap()); Ok(res) }