From fa880af66c739ac1ce0a32b44ec4600fb2cf9d7e Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Wed, 1 Nov 2023 23:04:22 +0400 Subject: [PATCH] wip --- Cargo.lock | 11 +- crates/projection-irc/Cargo.toml | 1 + crates/projection-irc/src/cap.rs | 8 ++ crates/projection-irc/src/lib.rs | 155 +++++++++++++++++++---------- crates/projection-irc/tests/lib.rs | 50 +++++++++- crates/proto-irc/src/client.rs | 9 ++ crates/proto-irc/src/server.rs | 19 ++++ 7 files changed, 191 insertions(+), 62 deletions(-) create mode 100644 crates/projection-irc/src/cap.rs diff --git a/Cargo.lock b/Cargo.lock index 8461513..8fd3ff4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -163,9 +163,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" dependencies = [ "serde", ] @@ -1264,6 +1264,7 @@ name = "projection-irc" version = "0.0.2-dev" dependencies = [ "anyhow", + "bitflags 2.4.1", "futures-util", "lavina-core", "nonempty", @@ -1515,7 +1516,7 @@ version = "0.38.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f25469e9ae0f3d0047ca8b93fc56843f38e6774f0914a107ff8b41be8be8e0b7" dependencies = [ - "bitflags 2.4.0", + "bitflags 2.4.1", "errno", "libc", "linux-raw-sys", @@ -1858,7 +1859,7 @@ checksum = "864b869fdf56263f4c95c45483191ea0af340f9f3e3e7b4d57a61c7c87a970db" dependencies = [ "atoi", "base64", - "bitflags 2.4.0", + "bitflags 2.4.1", "byteorder", "bytes", "crc", @@ -1900,7 +1901,7 @@ checksum = "eb7ae0e6a97fb3ba33b23ac2671a5ce6e3cabe003f451abd5a56e7951d975624" dependencies = [ "atoi", "base64", - "bitflags 2.4.0", + "bitflags 2.4.1", "byteorder", "crc", "dotenvy", diff --git a/crates/projection-irc/Cargo.toml b/crates/projection-irc/Cargo.toml index e8166f0..2d638a4 100644 --- a/crates/projection-irc/Cargo.toml +++ b/crates/projection-irc/Cargo.toml @@ -14,3 +14,4 @@ futures-util.workspace = true nonempty.workspace = true proto-irc = { path = "../proto-irc" } +bitflags = "2.4.1" diff --git a/crates/projection-irc/src/cap.rs b/crates/projection-irc/src/cap.rs new file mode 100644 index 0000000..aa41f42 --- /dev/null +++ b/crates/projection-irc/src/cap.rs @@ -0,0 +1,8 @@ +use bitflags::bitflags; + +bitflags! { + pub struct Capabilities: u32 { + const None = 0; + const Sasl = 1 << 0; + } +} diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index 0851c54..0333252 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -2,10 +2,13 @@ use std::collections::HashMap; use std::net::SocketAddr; use anyhow::{anyhow, Result}; +use cap::Capabilities; use futures_util::future::join_all; use nonempty::nonempty; use nonempty::NonEmpty; use prometheus::{IntCounter, IntGauge, Registry as MetricsRegistry}; +use proto_irc::client::CapabilitySubcommand; +use proto_irc::server::CapSubBody; use serde::Deserialize; use tokio::io::AsyncReadExt; use tokio::io::{AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; @@ -23,6 +26,8 @@ use proto_irc::server::{AwayStatus, ServerMessage, ServerMessageBody}; use proto_irc::user::PrefixedNick; use proto_irc::{Chan, Recipient}; +mod cap; + #[derive(Deserialize, Debug, Clone)] pub struct ServerConfig { pub listen_on: SocketAddr, @@ -55,20 +60,8 @@ async fn handle_socket( let mut reader: BufReader = BufReader::new(reader); let mut writer = BufWriter::new(writer); - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::Notice { - first_target: "*".into(), - rest_targets: vec![], - text: "Welcome to my server!".into(), - }, - } - .write_async(&mut writer) - .await?; - writer.flush().await?; - - let registered_user: Result = handle_registration(&mut reader, &mut writer, &mut storage).await; + let registered_user: Result = + handle_registration(&mut reader, &mut writer, &mut storage, &config).await; match registered_user { Ok(user) => { @@ -88,69 +81,121 @@ async fn handle_registration<'a>( reader: &mut BufReader>, writer: &mut BufWriter>, storage: &mut Storage, + config: &ServerConfig, ) -> Result { + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::Notice { + first_target: "*".into(), + rest_targets: vec![], + text: "Welcome to my server!".into(), + }, + } + .write_async(writer) + .await?; + writer.flush().await?; + let mut buffer = vec![]; let mut future_nickname: Option = None; let mut future_username: Option<(Str, Str)> = None; + let mut enabled_capabilities = Capabilities::None; + let mut cap_negotiation_in_progress = false; // if true, expect `CAP END` to complete registration let mut pass: Option = None; let user = loop { let res = read_irc_message(reader, &mut buffer).await; - let res = match res { - Ok(len) => { - if len == 0 { - log::info!("Terminating socket"); - break Err(anyhow::Error::msg("EOF")); - } - match std::str::from_utf8(&buffer[..len - 2]) { - Ok(res) => res, - Err(e) => break Err(e.into()), - } - } + tracing::trace!("Received message: {:?}", res); + let len = match res { + Ok(len) => len, Err(err) => { log::warn!("Failed to read from socket: {err}"); break Err(err.into()); } }; - log::debug!("Incoming raw IRC message: '{res}'"); + if len == 0 { + log::info!("Terminating socket"); + break Err(anyhow::Error::msg("EOF")); + } + let res = match std::str::from_utf8(&buffer[..len - 2]) { + Ok(res) => res, + Err(e) => break Err(e.into()), + }; + tracing::trace!("Incoming raw IRC message: '{res}'"); let parsed = client_message(res); - match parsed { - Ok(msg) => { - log::debug!("Incoming IRC message: {msg:?}"); - match msg { - ClientMessage::Pass { password } => { - pass = Some(password); + let msg = match parsed { + Ok(msg) => msg, + Err(err) => { + tracing::warn!("Failed to parse IRC message: {err}"); + buffer.clear(); + continue; + } + }; + tracing::debug!("Incoming IRC message: {msg:?}"); + match msg { + ClientMessage::Pass { password } => { + pass = Some(password); + } + ClientMessage::Capability { subcommand } => match subcommand { + CapabilitySubcommand::List { code: _ } => { + cap_negotiation_in_progress = true; + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::Cap { + target: future_nickname.clone().unwrap_or_else(|| "*".into()), + subcmd: CapSubBody::Ls("sasl=PLAIN".into()), + }, } - ClientMessage::Nick { nickname } => { - if let Some((username, realname)) = future_username { - break Ok(RegisteredUser { - nickname, - username, - realname, - }); - } else { - future_nickname = Some(nickname); + .write_async(writer) + .await?; + writer.flush().await?; + } + CapabilitySubcommand::Req(caps) => { + cap_negotiation_in_progress = true; + let mut acked = vec![]; + for cap in caps { + if &*cap.name == "sasl" { + if cap.to_disable { + enabled_capabilities &= !Capabilities::Sasl; + } else { + enabled_capabilities &= Capabilities::Sasl; + } + acked.push(cap); } } - ClientMessage::User { username, realname } => { - if let Some(nickname) = future_nickname { - break Ok(RegisteredUser { - nickname, - username, - realname, - }); - } else { - future_username = Some((username, realname)); - } - } - _ => {} + } + CapabilitySubcommand::End => {} + }, + ClientMessage::Nick { nickname } => { + if cap_negotiation_in_progress { + future_nickname = Some(nickname); + } else if let Some((username, realname)) = future_username { + break Ok(RegisteredUser { + nickname, + username, + realname, + }); + } else { + future_nickname = Some(nickname); } } - Err(err) => { - log::warn!("Failed to parse IRC message: {err}"); + ClientMessage::User { username, realname } => { + if cap_negotiation_in_progress { + future_username = Some((username, realname)); + } else if let Some(nickname) = future_nickname.clone() { + break Ok(RegisteredUser { + nickname, + username, + realname, + }); + } else { + future_username = Some((username, realname)); + } } + _ => {} } buffer.clear(); }?; diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index d440a6c..8863cdd 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -128,9 +128,12 @@ async fn scenario_basic() -> Result<()> { Ok(()) } +/* +IRC SASL doc: https://ircv3.net/specs/extensions/sasl-3.1.html +AUTHENTICATE doc: https://modern.ircdocs.horse/#authenticate-message +*/ #[tokio::test] -#[ignore] -async fn scenario_cap_negotiation() -> Result<()> { +async fn scenario_cap_full_negotiation() -> Result<()> { let mut server = TestServer::start().await?; // test scenario @@ -173,3 +176,46 @@ async fn scenario_cap_negotiation() -> Result<()> { server.server.terminate().await?; Ok(()) } + +#[tokio::test] +async fn scenario_cap_short_negotiation() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + s.send("CAP REQ :sasl").await?; + s.send("NICK tester").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.expect(":testserver CAP tester ACK :sasl=PLAIN").await?; + s.send("AUTHENTICATE PLAIN").await?; + s.expect("AUTHENTICATE +").await?; + s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password' + s.expect(":testserver 900 tester ??? ??? :You are now logged in as tester").await?; + s.expect(":testserver 903 tester :SASL authentication successful").await?; + + s.send("CAP END").await?; + + s.expect(":testserver NOTICE * :Welcome to my server!").await?; + s.expect(":testserver 001 tester :Welcome to Kek Server").await?; + s.expect(":testserver 002 tester :Welcome to Kek Server").await?; + s.expect(":testserver 003 tester :Welcome to Kek Server").await?; + s.expect(":testserver 004 tester testserver kek-0.1.alpha.3 r CFILPQbcefgijklmnopqrstvz").await?; + s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?; + s.expect_nothing().await?; + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + + stream.shutdown().await?; + + // wrap up + + server.server.terminate().await?; + Ok(()) +} diff --git a/crates/proto-irc/src/client.rs b/crates/proto-irc/src/client.rs index 1585622..ac7b988 100644 --- a/crates/proto-irc/src/client.rs +++ b/crates/proto-irc/src/client.rs @@ -2,6 +2,7 @@ use super::*; use anyhow::{anyhow, Result}; use nom::combinator::{all_consuming, opt}; +use nonempty::NonEmpty; /// Client-to-server command. #[derive(Clone, Debug, PartialEq, Eq)] @@ -227,6 +228,8 @@ fn client_message_quit(input: &str) -> IResult<&str, ClientMessage> { pub enum CapabilitySubcommand { /// CAP LS {code} List { code: [u8; 3] }, + /// CAP REQ :... + Req(NonEmpty), /// CAP END End, } @@ -252,6 +255,12 @@ fn capability_subcommand_end(input: &str) -> IResult<&str, CapabilitySubcommand> Ok((input, CapabilitySubcommand::End)) } +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct CapReq { + pub to_disable: bool, + pub name: Str, +} + #[cfg(test)] mod test { use assert_matches::*; diff --git a/crates/proto-irc/src/server.rs b/crates/proto-irc/src/server.rs index 1bead65..72c7d15 100644 --- a/crates/proto-irc/src/server.rs +++ b/crates/proto-irc/src/server.rs @@ -66,6 +66,10 @@ pub enum ServerMessageBody { Error { reason: Str, }, + Cap { + target: Str, + subcmd: CapSubBody, + }, N001Welcome { client: Str, text: Str, @@ -181,6 +185,16 @@ impl ServerMessageBody { writer.write_all(b"ERROR :").await?; writer.write_all(reason.as_bytes()).await?; } + ServerMessageBody::Cap { target, subcmd } => { + writer.write_all(b"CAP ").await?; + writer.write_all(target.as_bytes()).await?; + match subcmd { + CapSubBody::Ls(caps) => { + writer.write_all(b" LS :").await?; + writer.write_all(caps.as_bytes()).await?; + } + } + } ServerMessageBody::N001Welcome { client, text } => { writer.write_all(b"001 ").await?; writer.write_all(client.as_bytes()).await?; @@ -325,6 +339,11 @@ impl ServerMessageBody { } } +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum CapSubBody { + Ls(Str), +} + #[derive(Clone, Debug, PartialEq, Eq)] pub enum AwayStatus { Here,