sanitize IRC parsing (#23)

This commit is contained in:
Nikita Vilunov 2023-10-04 18:27:43 +00:00
parent 0b98102580
commit 1a1136187e
3 changed files with 45 additions and 44 deletions

View File

@ -7,6 +7,7 @@ use nonempty::nonempty;
use nonempty::NonEmpty; use nonempty::NonEmpty;
use prometheus::{IntCounter, IntGauge, Registry as MetricsRegistry}; use prometheus::{IntCounter, IntGauge, Registry as MetricsRegistry};
use serde::Deserialize; use serde::Deserialize;
use tokio::io::AsyncReadExt;
use tokio::io::{AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; use tokio::io::{AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
@ -96,14 +97,14 @@ async fn handle_registration<'a>(
let mut pass: Option<Str> = None; let mut pass: Option<Str> = None;
let user = loop { let user = loop {
let res = reader.read_until(b'\n', &mut buffer).await; let res = read_irc_message(reader, &mut buffer).await;
let res = match res { let res = match res {
Ok(len) => { Ok(len) => {
if len == 0 { if len == 0 {
log::info!("Terminating socket"); log::info!("Terminating socket");
break Err(anyhow::Error::msg("EOF")); break Err(anyhow::Error::msg("EOF"));
} }
match std::str::from_utf8(&buffer[..len]) { match std::str::from_utf8(&buffer[..len-2]) {
Ok(res) => res, Ok(res) => res,
Err(e) => break Err(e.into()), Err(e) => break Err(e.into()),
} }
@ -116,7 +117,7 @@ async fn handle_registration<'a>(
log::debug!("Incoming raw IRC message: '{res}'"); log::debug!("Incoming raw IRC message: '{res}'");
let parsed = client_message(res); let parsed = client_message(res);
match parsed { match parsed {
Ok((_, msg)) => { Ok(msg) => {
log::debug!("Incoming IRC message: {msg:?}"); log::debug!("Incoming IRC message: {msg:?}");
match msg { match msg {
ClientMessage::Pass { password } => { ClientMessage::Pass { password } => {
@ -252,7 +253,7 @@ async fn handle_registered_socket<'a>(
loop { loop {
select! { select! {
biased; biased;
len = reader.read_until(b'\n', &mut buffer) => { len = read_irc_message(reader, &mut buffer) => {
let len = len?; let len = len?;
let len = if len == 0 { let len = if len == 0 {
log::info!("EOF, Terminating socket"); log::info!("EOF, Terminating socket");
@ -260,7 +261,7 @@ async fn handle_registered_socket<'a>(
} else { } else {
len len
}; };
let incoming = std::str::from_utf8(&buffer[0..len])?; let incoming = std::str::from_utf8(&buffer[0..len-2])?;
if let HandleResult::Leave = handle_incoming_message(incoming, &config, &user, &rooms, &mut connection, writer).await? { if let HandleResult::Leave = handle_incoming_message(incoming, &config, &user, &rooms, &mut connection, writer).await? {
break; break;
} }
@ -291,6 +292,21 @@ async fn handle_registered_socket<'a>(
Ok(()) Ok(())
} }
async fn read_irc_message(reader: &mut BufReader<ReadHalf<'_>>, buf: &mut Vec<u8>) -> Result<usize> {
let mut size = 0;
'outer: loop {
let res = reader.read_until(b'\r', buf).await?;
size += res;
let next = reader.read_u8().await?;
buf.push(next);
size += 1;
if next != b'\n' {
continue 'outer;
}
return Ok(size);
}
}
async fn handle_update( async fn handle_update(
config: &ServerConfig, config: &ServerConfig,
user: &RegisteredUser, user: &RegisteredUser,
@ -398,7 +414,7 @@ async fn handle_incoming_message(
let parsed = client_message(buffer); let parsed = client_message(buffer);
log::debug!("Incoming IRC message: {parsed:?}"); log::debug!("Incoming IRC message: {parsed:?}");
match parsed { match parsed {
Ok((_, msg)) => match msg { Ok(msg) => match msg {
ClientMessage::Ping { token } => { ClientMessage::Ping { token } => {
ServerMessage { ServerMessage {
tags: vec![], tags: vec![],

View File

@ -1,6 +1,7 @@
use super::*; use super::*;
use nom::combinator::opt; use anyhow::{anyhow, Result};
use nom::combinator::{all_consuming, opt};
/// Client-to-server command. /// Client-to-server command.
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
@ -60,8 +61,8 @@ pub enum ClientMessage {
}, },
} }
pub fn client_message(input: &str) -> IResult<&str, ClientMessage> { pub fn client_message(input: &str) -> Result<ClientMessage> {
alt(( let res = all_consuming(alt((
client_message_capability, client_message_capability,
client_message_ping, client_message_ping,
client_message_pong, client_message_pong,
@ -75,7 +76,11 @@ pub fn client_message(input: &str) -> IResult<&str, ClientMessage> {
client_message_part, client_message_part,
client_message_privmsg, client_message_privmsg,
client_message_quit, client_message_quit,
))(input) )))(input);
match res {
Ok((_, e)) => Ok(e),
Err(e) => Err(anyhow!("Parsing failed: {e}")),
}
} }
fn client_message_capability(input: &str) -> IResult<&str, ClientMessage> { fn client_message_capability(input: &str) -> IResult<&str, ClientMessage> {
@ -89,24 +94,14 @@ fn client_message_ping(input: &str) -> IResult<&str, ClientMessage> {
let (input, _) = tag("PING ")(input)?; let (input, _) = tag("PING ")(input)?;
let (input, token) = token(input)?; let (input, token) = token(input)?;
Ok(( Ok((input, ClientMessage::Ping { token: token.into() }))
input,
ClientMessage::Ping {
token: token.into(),
},
))
} }
fn client_message_pong(input: &str) -> IResult<&str, ClientMessage> { fn client_message_pong(input: &str) -> IResult<&str, ClientMessage> {
let (input, _) = tag("PONG ")(input)?; let (input, _) = tag("PONG ")(input)?;
let (input, token) = token(input)?; let (input, token) = token(input)?;
Ok(( Ok((input, ClientMessage::Pong { token: token.into() }))
input,
ClientMessage::Pong {
token: token.into(),
},
))
} }
fn client_message_nick(input: &str) -> IResult<&str, ClientMessage> { fn client_message_nick(input: &str) -> IResult<&str, ClientMessage> {
@ -225,12 +220,7 @@ fn client_message_quit(input: &str) -> IResult<&str, ClientMessage> {
let (input, _) = tag("QUIT :")(input)?; let (input, _) = tag("QUIT :")(input)?;
let (input, reason) = token(input)?; let (input, reason) = token(input)?;
Ok(( Ok((input, ClientMessage::Quit { reason: reason.into() }))
input,
ClientMessage::Quit {
reason: reason.into(),
},
))
} }
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
@ -275,7 +265,7 @@ mod test {
}; };
let result = client_message(input); let result = client_message(input);
assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); assert_matches!(result, Ok(result) => assert_eq!(expected, result));
} }
#[test] #[test]
@ -286,28 +276,24 @@ mod test {
}; };
let result = client_message(input); let result = client_message(input);
assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); assert_matches!(result, Ok(result) => assert_eq!(expected, result));
} }
#[test] #[test]
fn test_client_message_ping() { fn test_client_message_ping() {
let input = "PING 1337"; let input = "PING 1337";
let expected = ClientMessage::Ping { let expected = ClientMessage::Ping { token: "1337".into() };
token: "1337".into(),
};
let result = client_message(input); let result = client_message(input);
assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); assert_matches!(result, Ok(result) => assert_eq!(expected, result));
} }
#[test] #[test]
fn test_client_message_pong() { fn test_client_message_pong() {
let input = "PONG 1337"; let input = "PONG 1337";
let expected = ClientMessage::Pong { let expected = ClientMessage::Pong { token: "1337".into() };
token: "1337".into(),
};
let result = client_message(input); let result = client_message(input);
assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); assert_matches!(result, Ok(result) => assert_eq!(expected, result));
} }
#[test] #[test]
fn test_client_message_nick() { fn test_client_message_nick() {
@ -317,7 +303,7 @@ mod test {
}; };
let result = client_message(input); let result = client_message(input);
assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); assert_matches!(result, Ok(result) => assert_eq!(expected, result));
} }
#[test] #[test]
fn test_client_message_user() { fn test_client_message_user() {
@ -328,7 +314,7 @@ mod test {
}; };
let result = client_message(input); let result = client_message(input);
assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); assert_matches!(result, Ok(result) => assert_eq!(expected, result));
} }
#[test] #[test]
fn test_client_message_part() { fn test_client_message_part() {
@ -339,6 +325,6 @@ mod test {
}; };
let result = client_message(input); let result = client_message(input);
assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); assert_matches!(result, Ok(result) => assert_eq!(expected, result));
} }
} }

View File

@ -7,7 +7,6 @@ mod testkit;
pub mod user; pub mod user;
use crate::prelude::Str; use crate::prelude::Str;
use std::io::Result;
use nom::{ use nom::{
branch::alt, branch::alt,
@ -39,7 +38,7 @@ pub enum Chan {
Local(Str), Local(Str),
} }
impl Chan { impl Chan {
pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> Result<()> { pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
match self { match self {
Chan::Global(name) => { Chan::Global(name) => {
writer.write_all(b"#").await?; writer.write_all(b"#").await?;
@ -76,7 +75,7 @@ pub enum Recipient {
Chan(Chan), Chan(Chan),
} }
impl Recipient { impl Recipient {
pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> Result<()> { pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
match self { match self {
Recipient::Nick(nick) => writer.write_all(nick.as_bytes()).await?, Recipient::Nick(nick) => writer.write_all(nick.as_bytes()).await?,
Recipient::Chan(chan) => chan.write_async(writer).await?, Recipient::Chan(chan) => chan.write_async(writer).await?,