diff --git a/Cargo.lock b/Cargo.lock index 8fd3ff4..70f5ab7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1270,9 +1270,11 @@ dependencies = [ "nonempty", "prometheus", "proto-irc", + "sasl", "serde", "tokio", "tracing", + "tracing-subscriber", ] [[package]] diff --git a/crates/projection-irc/Cargo.toml b/crates/projection-irc/Cargo.toml index 2d638a4..3135280 100644 --- a/crates/projection-irc/Cargo.toml +++ b/crates/projection-irc/Cargo.toml @@ -11,7 +11,10 @@ serde.workspace = true tokio.workspace = true prometheus.workspace = true futures-util.workspace = true - nonempty.workspace = true -proto-irc = { path = "../proto-irc" } bitflags = "2.4.1" +proto-irc = { path = "../proto-irc" } +sasl = { path = "../sasl" } + +[dev-dependencies] +tracing-subscriber.workspace = true diff --git a/crates/projection-irc/src/cap.rs b/crates/projection-irc/src/cap.rs index aa41f42..af0e3ff 100644 --- a/crates/projection-irc/src/cap.rs +++ b/crates/projection-irc/src/cap.rs @@ -1,6 +1,7 @@ use bitflags::bitflags; bitflags! { + #[derive(Debug)] pub struct Capabilities: u32 { const None = 0; const Sasl = 1 << 0; diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index 0333252..ee7fbb8 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -2,13 +2,10 @@ use std::collections::HashMap; use std::net::SocketAddr; use anyhow::{anyhow, Result}; -use cap::Capabilities; use futures_util::future::join_all; use nonempty::nonempty; use nonempty::NonEmpty; use prometheus::{IntCounter, IntGauge, Registry as MetricsRegistry}; -use proto_irc::client::CapabilitySubcommand; -use proto_irc::server::CapSubBody; use serde::Deserialize; use tokio::io::AsyncReadExt; use tokio::io::{AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; @@ -21,13 +18,18 @@ use lavina_core::prelude::*; use lavina_core::repo::Storage; use lavina_core::room::{RoomId, RoomInfo, RoomRegistry}; use lavina_core::terminator::Terminator; +use proto_irc::client::CapabilitySubcommand; use proto_irc::client::{client_message, ClientMessage}; +use proto_irc::server::CapSubBody; use proto_irc::server::{AwayStatus, ServerMessage, ServerMessageBody}; use proto_irc::user::PrefixedNick; use proto_irc::{Chan, Recipient}; +use sasl::AuthBody; mod cap; +use crate::cap::Capabilities; + #[derive(Deserialize, Debug, Clone)] pub struct ServerConfig { pub listen_on: SocketAddr, @@ -68,8 +70,8 @@ async fn handle_socket( log::debug!("User registered"); handle_registered_socket(config, players, rooms, &mut reader, &mut writer, user).await?; } - Err(_) => { - log::debug!("Registration failed"); + Err(err) => { + log::debug!("Registration failed: {err}"); } } @@ -83,19 +85,6 @@ async fn handle_registration<'a>( storage: &mut Storage, config: &ServerConfig, ) -> Result { - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::Notice { - first_target: "*".into(), - rest_targets: vec![], - text: "Welcome to my server!".into(), - }, - } - .write_async(writer) - .await?; - writer.flush().await?; - let mut buffer = vec![]; let mut future_nickname: Option = None; @@ -104,6 +93,8 @@ async fn handle_registration<'a>( let mut cap_negotiation_in_progress = false; // if true, expect `CAP END` to complete registration let mut pass: Option = None; + let mut authentication_started = false; + let mut validated_user = None; let user = loop { let res = read_irc_message(reader, &mut buffer).await; @@ -156,28 +147,77 @@ async fn handle_registration<'a>( CapabilitySubcommand::Req(caps) => { cap_negotiation_in_progress = true; let mut acked = vec![]; + let mut naked = vec![]; for cap in caps { if &*cap.name == "sasl" { if cap.to_disable { enabled_capabilities &= !Capabilities::Sasl; } else { - enabled_capabilities &= Capabilities::Sasl; + enabled_capabilities |= Capabilities::Sasl; } acked.push(cap); + } else { + naked.push(cap); } } + let mut ack_body = String::new(); + for cap in acked { + if cap.to_disable { + ack_body.push('-'); + } + ack_body += &*cap.name; + } + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::Cap { + target: future_nickname.clone().unwrap_or_else(|| "*".into()), + subcmd: CapSubBody::Ack(ack_body.into()), + }, + } + .write_async(writer) + .await?; + writer.flush().await?; + } + CapabilitySubcommand::End => { + let Some((username, realname)) = future_username else { + todo!() + }; + let Some(nickname) = future_nickname.clone() else { + todo!() + }; + let candidate_user = RegisteredUser { + nickname, + username, + realname, + }; + if enabled_capabilities.contains(Capabilities::Sasl) + && validated_user.as_ref() == Some(&candidate_user.nickname) + { + break Ok(candidate_user); + } else { + let Some(candidate_password) = pass else { + todo!(); + }; + auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; + break Ok(candidate_user); + } } - CapabilitySubcommand::End => {} }, ClientMessage::Nick { nickname } => { if cap_negotiation_in_progress { future_nickname = Some(nickname); - } else if let Some((username, realname)) = future_username { - break Ok(RegisteredUser { + } else if let Some((username, realname)) = future_username.clone() { + let candidate_user = RegisteredUser { nickname, username, realname, - }); + }; + let Some(candidate_password) = pass else { + todo!(); + }; + auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; + break Ok(candidate_user); } else { future_nickname = Some(nickname); } @@ -186,40 +226,96 @@ async fn handle_registration<'a>( if cap_negotiation_in_progress { future_username = Some((username, realname)); } else if let Some(nickname) = future_nickname.clone() { - break Ok(RegisteredUser { + let candidate_user = RegisteredUser { nickname, username, realname, - }); + }; + let Some(candidate_password) = pass else { + todo!(); + }; + auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; + break Ok(candidate_user); } else { future_username = Some((username, realname)); } } + ClientMessage::Authenticate(body) => { + if !authentication_started { + tracing::debug!("Received authentication request"); + if &*body == "PLAIN" { + tracing::debug!("Authentication request with method PLAIN"); + authentication_started = true; + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::Authenticate("+".into()), + } + .write_async(writer) + .await?; + writer.flush().await?; + } else { + // TODO respond with 904 + todo!(); + } + } else { + let body = AuthBody::from_str(body.as_bytes())?; + auth_user(storage, &body.login, &body.password).await?; + let login: Str = body.login.into(); + validated_user = Some(login.clone()); + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::N900LoggedIn { + nick: login.clone(), + address: login.clone(), + account: login.clone(), + message: format!("You are now logged in as {}", login).into(), + }, + } + .write_async(writer) + .await?; + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::N903SaslSuccess { + nick: login.clone(), + message: "SASL authentication successful".into(), + }, + } + .write_async(writer) + .await?; + writer.flush().await?; + } + // TODO handle abortion of authentication + } _ => {} } buffer.clear(); }?; + // TODO properly implement session temination + Ok(user) +} - let stored_user = storage.retrieve_user_by_name(&*user.nickname).await?; +async fn auth_user(storage: &mut Storage, login: &str, plain_password: &str) -> Result<()> { + let stored_user = storage.retrieve_user_by_name(login).await?; let stored_user = match stored_user { Some(u) => u, None => { - log::info!("User '{}' not found", user.nickname); + log::info!("User '{}' not found", login); return Err(anyhow!("no user found")); } }; - if stored_user.password.is_none() { - log::info!("Password not defined for user '{}'", user.nickname); + let Some(expected_password) = stored_user.password else { + log::info!("Password not defined for user '{}'", login); return Err(anyhow!("password is not defined")); - } - if stored_user.password.as_deref() != pass.as_deref() { - log::info!("Incorrect password supplied for user '{}'", user.nickname); + }; + if expected_password != plain_password { + log::info!("Incorrect password supplied for user '{}'", login); return Err(anyhow!("passwords do not match")); } - // TODO properly implement session temination - - Ok(user) + Ok(()) } async fn handle_registered_socket<'a>( diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 8863cdd..c102143 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -39,6 +39,7 @@ impl<'a> TestScope<'a> { } async fn expect(&mut self, str: &str) -> Result<()> { + tracing::debug!("Expecting {}", str); let len = tokio::time::timeout(self.timeout, read_irc_message(&mut self.reader, &mut self.buffer)).await??; assert_eq!(std::str::from_utf8(&self.buffer[..len - 2])?, str); self.buffer.clear(); @@ -72,6 +73,7 @@ struct TestServer { } impl TestServer { async fn start() -> Result { + let _ = tracing_subscriber::fmt::try_init(); let config = ServerConfig { listen_on: "127.0.0.1:0".parse().unwrap(), server_name: "testserver".into(), @@ -109,7 +111,6 @@ async fn scenario_basic() -> Result<()> { s.send("PASS password").await?; s.send("NICK tester").await?; s.send("USER UserName 0 * :Real Name").await?; - s.expect(":testserver NOTICE * :Welcome to my server!").await?; s.expect(":testserver 001 tester :Welcome to Kek Server").await?; s.expect(":testserver 002 tester :Welcome to Kek Server").await?; s.expect(":testserver 003 tester :Welcome to Kek Server").await?; @@ -149,16 +150,15 @@ async fn scenario_cap_full_negotiation() -> Result<()> { s.send("USER UserName 0 * :Real Name").await?; s.expect(":testserver CAP * LS :sasl=PLAIN").await?; s.send("CAP REQ :sasl").await?; - s.expect(":testserver CAP tester ACK :sasl=PLAIN").await?; + s.expect(":testserver CAP tester ACK :sasl").await?; s.send("AUTHENTICATE PLAIN").await?; - s.expect("AUTHENTICATE +").await?; + s.expect(":testserver AUTHENTICATE +").await?; s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password' - s.expect(":testserver 900 tester ??? ??? :You are now logged in as tester").await?; + s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?; s.expect(":testserver 903 tester :SASL authentication successful").await?; s.send("CAP END").await?; - s.expect(":testserver NOTICE * :Welcome to my server!").await?; s.expect(":testserver 001 tester :Welcome to Kek Server").await?; s.expect(":testserver 002 tester :Welcome to Kek Server").await?; s.expect(":testserver 003 tester :Welcome to Kek Server").await?; @@ -189,19 +189,18 @@ async fn scenario_cap_short_negotiation() -> Result<()> { let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); - s.send("CAP REQ :sasl").await?; s.send("NICK tester").await?; + s.send("CAP REQ :sasl").await?; s.send("USER UserName 0 * :Real Name").await?; - s.expect(":testserver CAP tester ACK :sasl=PLAIN").await?; + s.expect(":testserver CAP tester ACK :sasl").await?; s.send("AUTHENTICATE PLAIN").await?; - s.expect("AUTHENTICATE +").await?; + s.expect(":testserver AUTHENTICATE +").await?; s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password' - s.expect(":testserver 900 tester ??? ??? :You are now logged in as tester").await?; + s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?; s.expect(":testserver 903 tester :SASL authentication successful").await?; s.send("CAP END").await?; - s.expect(":testserver NOTICE * :Welcome to my server!").await?; s.expect(":testserver 001 tester :Welcome to Kek Server").await?; s.expect(":testserver 002 tester :Welcome to Kek Server").await?; s.expect(":testserver 003 tester :Welcome to Kek Server").await?; diff --git a/crates/proto-irc/src/client.rs b/crates/proto-irc/src/client.rs index ac7b988..676fd40 100644 --- a/crates/proto-irc/src/client.rs +++ b/crates/proto-irc/src/client.rs @@ -60,6 +60,7 @@ pub enum ClientMessage { Quit { reason: Str, }, + Authenticate(Str), } pub fn client_message(input: &str) -> Result { @@ -77,6 +78,7 @@ pub fn client_message(input: &str) -> Result { client_message_part, client_message_privmsg, client_message_quit, + client_message_authenticate, )))(input); match res { Ok((_, e)) => Ok(e), @@ -224,6 +226,13 @@ fn client_message_quit(input: &str) -> IResult<&str, ClientMessage> { Ok((input, ClientMessage::Quit { reason: reason.into() })) } +fn client_message_authenticate(input: &str) -> IResult<&str, ClientMessage> { + let (input, _) = tag("AUTHENTICATE ")(input)?; + let (input, body) = token(input)?; + + Ok((input, ClientMessage::Authenticate(body.into()))) +} + #[derive(Clone, Debug, PartialEq, Eq)] pub enum CapabilitySubcommand { /// CAP LS {code} @@ -235,7 +244,11 @@ pub enum CapabilitySubcommand { } fn capability_subcommand(input: &str) -> IResult<&str, CapabilitySubcommand> { - alt((capability_subcommand_ls, capability_subcommand_end))(input) + alt(( + capability_subcommand_ls, + capability_subcommand_end, + capability_subcommand_req, + ))(input) } fn capability_subcommand_ls(input: &str) -> IResult<&str, CapabilitySubcommand> { @@ -250,6 +263,31 @@ fn capability_subcommand_ls(input: &str) -> IResult<&str, CapabilitySubcommand> )) } +fn capability_subcommand_req(input: &str) -> IResult<&str, CapabilitySubcommand> { + let (input, _) = tag("REQ ")(input)?; + let (input, r) = opt(tag(":"))(input)?; + let (input, body) = match r { + Some(_) => token(input)?, + None => receiver(input)?, + }; + + let caps = body + .split(' ') + .map(|cap| { + let to_disable = cap.starts_with('-'); + let name = if to_disable { &cap[1..] } else { &cap[..] }; + CapReq { + to_disable, + name: name.into(), + } + }) + .collect::>(); + + let caps = NonEmpty::from_vec(caps).ok_or_else(|| todo!())?; + + Ok((input, CapabilitySubcommand::Req(caps))) +} + fn capability_subcommand_end(input: &str) -> IResult<&str, CapabilitySubcommand> { let (input, _) = tag("END")(input)?; Ok((input, CapabilitySubcommand::End)) @@ -264,6 +302,7 @@ pub struct CapReq { #[cfg(test)] mod test { use assert_matches::*; + use nonempty::nonempty; use super::*; #[test] @@ -333,6 +372,25 @@ mod test { message: "Pokasiki !!!".into(), }; + let result = client_message(input); + assert_matches!(result, Ok(result) => assert_eq!(expected, result)); + } + #[test] + fn test_client_cap_req() { + let input = "CAP REQ :multi-prefix -sasl"; + let expected = ClientMessage::Capability { + subcommand: CapabilitySubcommand::Req(nonempty![ + CapReq { + to_disable: false, + name: "multi-prefix".into() + }, + CapReq { + to_disable: true, + name: "sasl".into() + } + ]), + }; + let result = client_message(input); assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } diff --git a/crates/proto-irc/src/server.rs b/crates/proto-irc/src/server.rs index 3e7fbd2..fdcccb1 100644 --- a/crates/proto-irc/src/server.rs +++ b/crates/proto-irc/src/server.rs @@ -70,6 +70,7 @@ pub enum ServerMessageBody { target: Str, subcmd: CapSubBody, }, + Authenticate(Str), N001Welcome { client: Str, text: Str, @@ -142,6 +143,16 @@ pub enum ServerMessageBody { client: Str, message: Str, }, + N900LoggedIn { + nick: Str, + address: Str, + account: Str, + message: Str, + }, + N903SaslSuccess { + nick: Str, + message: Str, + } } impl ServerMessageBody { @@ -193,8 +204,16 @@ impl ServerMessageBody { writer.write_all(b" LS :").await?; writer.write_all(caps.as_bytes()).await?; } + CapSubBody::Ack(caps) => { + writer.write_all(b" ACK :").await?; + writer.write_all(caps.as_bytes()).await?; + } } } + ServerMessageBody::Authenticate(body) => { + writer.write_all(b"AUTHENTICATE ").await?; + writer.write_all(body.as_bytes()).await?; + } ServerMessageBody::N001Welcome { client, text } => { writer.write_all(b"001 ").await?; writer.write_all(client.as_bytes()).await?; @@ -334,6 +353,22 @@ impl ServerMessageBody { writer.write_all(b" :").await?; writer.write_all(message.as_bytes()).await?; } + ServerMessageBody::N900LoggedIn { nick, address, account, message } => { + writer.write_all(b"900 ").await?; + writer.write_all(nick.as_bytes()).await?; + writer.write_all(b" ").await?; + writer.write_all(address.as_bytes()).await?; + writer.write_all(b" ").await?; + writer.write_all(account.as_bytes()).await?; + writer.write_all(b" :").await?; + writer.write_all(message.as_bytes()).await?; + } + ServerMessageBody::N903SaslSuccess { nick, message } => { + writer.write_all(b"903 ").await?; + writer.write_all(nick.as_bytes()).await?; + writer.write_all(b" :").await?; + writer.write_all(message.as_bytes()).await?; + } } Ok(()) } @@ -342,6 +377,7 @@ impl ServerMessageBody { #[derive(Clone, Debug, PartialEq, Eq)] pub enum CapSubBody { Ls(Str), + Ack(Str), } #[derive(Clone, Debug, PartialEq, Eq)]