From aa8ef5e6c02954d63e3ec022d0503bffb963f91e Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Wed, 4 Oct 2023 17:38:59 +0200 Subject: [PATCH] sanitize IRC parsing --- crates/projection-irc/src/lib.rs | 28 ++++++++++++---- crates/proto-irc/src/client.rs | 56 ++++++++++++-------------------- crates/proto-irc/src/lib.rs | 5 ++- 3 files changed, 45 insertions(+), 44 deletions(-) diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index 42a4126..71b0b23 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -7,6 +7,7 @@ use nonempty::nonempty; use nonempty::NonEmpty; use prometheus::{IntCounter, IntGauge, Registry as MetricsRegistry}; use serde::Deserialize; +use tokio::io::AsyncReadExt; use tokio::io::{AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::{TcpListener, TcpStream}; @@ -96,14 +97,14 @@ async fn handle_registration<'a>( let mut pass: Option = None; let user = loop { - let res = reader.read_until(b'\n', &mut buffer).await; + let res = read_irc_message(reader, &mut buffer).await; let res = match res { Ok(len) => { if len == 0 { log::info!("Terminating socket"); break Err(anyhow::Error::msg("EOF")); } - match std::str::from_utf8(&buffer[..len]) { + match std::str::from_utf8(&buffer[..len-2]) { Ok(res) => res, Err(e) => break Err(e.into()), } @@ -116,7 +117,7 @@ async fn handle_registration<'a>( log::debug!("Incoming raw IRC message: '{res}'"); let parsed = client_message(res); match parsed { - Ok((_, msg)) => { + Ok(msg) => { log::debug!("Incoming IRC message: {msg:?}"); match msg { ClientMessage::Pass { password } => { @@ -252,7 +253,7 @@ async fn handle_registered_socket<'a>( loop { select! { biased; - len = reader.read_until(b'\n', &mut buffer) => { + len = read_irc_message(reader, &mut buffer) => { let len = len?; let len = if len == 0 { log::info!("EOF, Terminating socket"); @@ -260,7 +261,7 @@ async fn handle_registered_socket<'a>( } else { len }; - let incoming = std::str::from_utf8(&buffer[0..len])?; + let incoming = std::str::from_utf8(&buffer[0..len-2])?; if let HandleResult::Leave = handle_incoming_message(incoming, &config, &user, &rooms, &mut connection, writer).await? { break; } @@ -291,6 +292,21 @@ async fn handle_registered_socket<'a>( Ok(()) } +async fn read_irc_message(reader: &mut BufReader>, buf: &mut Vec) -> Result { + let mut size = 0; + 'outer: loop { + let res = reader.read_until(b'\r', buf).await?; + size += res; + let next = reader.read_u8().await?; + buf.push(next); + size += 1; + if next != b'\n' { + continue 'outer; + } + return Ok(size); + } +} + async fn handle_update( config: &ServerConfig, user: &RegisteredUser, @@ -398,7 +414,7 @@ async fn handle_incoming_message( let parsed = client_message(buffer); log::debug!("Incoming IRC message: {parsed:?}"); match parsed { - Ok((_, msg)) => match msg { + Ok(msg) => match msg { ClientMessage::Ping { token } => { ServerMessage { tags: vec![], diff --git a/crates/proto-irc/src/client.rs b/crates/proto-irc/src/client.rs index 827fe74..1585622 100644 --- a/crates/proto-irc/src/client.rs +++ b/crates/proto-irc/src/client.rs @@ -1,6 +1,7 @@ use super::*; -use nom::combinator::opt; +use anyhow::{anyhow, Result}; +use nom::combinator::{all_consuming, opt}; /// Client-to-server command. #[derive(Clone, Debug, PartialEq, Eq)] @@ -60,8 +61,8 @@ pub enum ClientMessage { }, } -pub fn client_message(input: &str) -> IResult<&str, ClientMessage> { - alt(( +pub fn client_message(input: &str) -> Result { + let res = all_consuming(alt(( client_message_capability, client_message_ping, client_message_pong, @@ -75,7 +76,11 @@ pub fn client_message(input: &str) -> IResult<&str, ClientMessage> { client_message_part, client_message_privmsg, client_message_quit, - ))(input) + )))(input); + match res { + Ok((_, e)) => Ok(e), + Err(e) => Err(anyhow!("Parsing failed: {e}")), + } } fn client_message_capability(input: &str) -> IResult<&str, ClientMessage> { @@ -89,24 +94,14 @@ fn client_message_ping(input: &str) -> IResult<&str, ClientMessage> { let (input, _) = tag("PING ")(input)?; let (input, token) = token(input)?; - Ok(( - input, - ClientMessage::Ping { - token: token.into(), - }, - )) + Ok((input, ClientMessage::Ping { token: token.into() })) } fn client_message_pong(input: &str) -> IResult<&str, ClientMessage> { let (input, _) = tag("PONG ")(input)?; let (input, token) = token(input)?; - Ok(( - input, - ClientMessage::Pong { - token: token.into(), - }, - )) + Ok((input, ClientMessage::Pong { token: token.into() })) } fn client_message_nick(input: &str) -> IResult<&str, ClientMessage> { @@ -225,12 +220,7 @@ fn client_message_quit(input: &str) -> IResult<&str, ClientMessage> { let (input, _) = tag("QUIT :")(input)?; let (input, reason) = token(input)?; - Ok(( - input, - ClientMessage::Quit { - reason: reason.into(), - }, - )) + Ok((input, ClientMessage::Quit { reason: reason.into() })) } #[derive(Clone, Debug, PartialEq, Eq)] @@ -275,7 +265,7 @@ mod test { }; let result = client_message(input); - assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); + assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } #[test] @@ -286,28 +276,24 @@ mod test { }; let result = client_message(input); - assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); + assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } #[test] fn test_client_message_ping() { let input = "PING 1337"; - let expected = ClientMessage::Ping { - token: "1337".into(), - }; + let expected = ClientMessage::Ping { token: "1337".into() }; let result = client_message(input); - assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); + assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } #[test] fn test_client_message_pong() { let input = "PONG 1337"; - let expected = ClientMessage::Pong { - token: "1337".into(), - }; + let expected = ClientMessage::Pong { token: "1337".into() }; let result = client_message(input); - assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); + assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } #[test] fn test_client_message_nick() { @@ -317,7 +303,7 @@ mod test { }; let result = client_message(input); - assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); + assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } #[test] fn test_client_message_user() { @@ -328,7 +314,7 @@ mod test { }; let result = client_message(input); - assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); + assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } #[test] fn test_client_message_part() { @@ -339,6 +325,6 @@ mod test { }; let result = client_message(input); - assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); + assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } } diff --git a/crates/proto-irc/src/lib.rs b/crates/proto-irc/src/lib.rs index cd1f64f..59a4314 100644 --- a/crates/proto-irc/src/lib.rs +++ b/crates/proto-irc/src/lib.rs @@ -7,7 +7,6 @@ mod testkit; pub mod user; use crate::prelude::Str; -use std::io::Result; use nom::{ branch::alt, @@ -39,7 +38,7 @@ pub enum Chan { Local(Str), } impl Chan { - pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> Result<()> { + pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> { match self { Chan::Global(name) => { writer.write_all(b"#").await?; @@ -76,7 +75,7 @@ pub enum Recipient { Chan(Chan), } impl Recipient { - pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> Result<()> { + pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> { match self { Recipient::Nick(nick) => writer.write_all(nick.as_bytes()).await?, Recipient::Chan(chan) => chan.write_async(writer).await?,