From e6633477a1f79fbb1b63650514b4d2846f534549 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Tue, 16 Apr 2024 13:30:35 +0200 Subject: [PATCH] irc: support registration with different order of NICK/USER/CAP END commands --- crates/projection-irc/src/lib.rs | 466 +++++++++++++++-------------- crates/projection-irc/tests/lib.rs | 39 +++ 2 files changed, 285 insertions(+), 220 deletions(-) diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index e52e92a..7f1b49e 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -86,6 +86,249 @@ async fn handle_socket( Ok(()) } +struct RegistrationState { + /// The last received `NICK` message. + future_nickname: Option, + /// The last received `USER` message. + future_username: Option<(Str, Str)>, + enabled_capabilities: Capabilities, + /// `CAP LS` or `CAP REQ` was received, but not `CAP END`. + cap_negotiation_in_progress: bool, + /// The last received `PASS` message. + pass: Option, + authentication_started: bool, + validated_user: Option, +} + +impl RegistrationState { + fn new() -> RegistrationState { + RegistrationState { + future_nickname: None, + future_username: None, + enabled_capabilities: Capabilities::None, + cap_negotiation_in_progress: false, + pass: None, + authentication_started: false, + validated_user: None, + } + } + + /// Handle an incoming message from the client during the registration process. + /// + /// Returns `Some` if the user is fully registered, `None` if the registration is still in progress. + async fn handle_msg( + &mut self, + msg: ClientMessage, + writer: &mut BufWriter>, + storage: &mut Storage, + config: &ServerConfig, + ) -> Result> { + match msg { + ClientMessage::Pass { password } => { + self.pass = Some(password); + Ok(None) + } + ClientMessage::Capability { subcommand } => match subcommand { + CapabilitySubcommand::List { code: _ } => { + self.cap_negotiation_in_progress = true; + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::Cap { + target: self.future_nickname.clone().unwrap_or_else(|| "*".into()), + subcmd: CapSubBody::Ls("sasl=PLAIN".into()), + }, + } + .write_async(writer) + .await?; + writer.flush().await?; + Ok(None) + } + CapabilitySubcommand::Req(caps) => { + self.cap_negotiation_in_progress = true; + let mut acked = vec![]; + let mut naked = vec![]; + for cap in caps { + if &*cap.name == "sasl" { + if cap.to_disable { + self.enabled_capabilities &= !Capabilities::Sasl; + } else { + self.enabled_capabilities |= Capabilities::Sasl; + } + acked.push(cap); + } else { + naked.push(cap); + } + } + let mut ack_body = String::new(); + for cap in acked { + if cap.to_disable { + ack_body.push('-'); + } + ack_body += &*cap.name; + } + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::Cap { + target: self.future_nickname.clone().unwrap_or_else(|| "*".into()), + subcmd: CapSubBody::Ack(ack_body.into()), + }, + } + .write_async(writer) + .await?; + writer.flush().await?; + Ok(None) + } + CapabilitySubcommand::End => { + let Some((ref username, ref realname)) = self.future_username else { + self.cap_negotiation_in_progress = false; + return Ok(None); + }; + let Some(nickname) = self.future_nickname.clone() else { + self.cap_negotiation_in_progress = false; + return Ok(None); + }; + let username = username.clone(); + let realname = realname.clone(); + let candidate_user = RegisteredUser { + nickname: nickname.clone(), + username, + realname, + }; + self.finalize_auth(candidate_user, writer, storage, config).await + } + }, + ClientMessage::Nick { nickname } => { + if self.cap_negotiation_in_progress { + self.future_nickname = Some(nickname); + Ok(None) + } else if let Some((username, realname)) = &self.future_username.clone() { + let candidate_user = RegisteredUser { + nickname: nickname.clone(), + username: username.clone(), + realname: realname.clone(), + }; + self.finalize_auth(candidate_user, writer, storage, config).await + } else { + self.future_nickname = Some(nickname); + Ok(None) + } + } + ClientMessage::User { username, realname } => { + if self.cap_negotiation_in_progress { + self.future_username = Some((username, realname)); + Ok(None) + } else if let Some(nickname) = self.future_nickname.clone() { + let candidate_user = RegisteredUser { + nickname: nickname.clone(), + username, + realname, + }; + self.finalize_auth(candidate_user, writer, storage, config).await + } else { + self.future_username = Some((username, realname)); + Ok(None) + } + } + ClientMessage::Authenticate(body) => { + if !self.authentication_started { + tracing::debug!("Received authentication request"); + if &*body == "PLAIN" { + tracing::debug!("Authentication request with method PLAIN"); + self.authentication_started = true; + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::Authenticate("+".into()), + } + .write_async(writer) + .await?; + writer.flush().await?; + Ok(None) + } else { + let target = self.future_nickname.clone().unwrap_or_else(|| "*".into()); + sasl_fail_message(config.server_name.clone(), target, "Unsupported mechanism".into()) + .write_async(writer) + .await?; + writer.flush().await?; + Ok(None) + } + } else { + let body = AuthBody::from_str(body.as_bytes())?; + if let Err(e) = auth_user(storage, &body.login, &body.password).await { + tracing::warn!("Authentication failed: {:?}", e); + let target = self.future_nickname.clone().unwrap_or_else(|| "*".into()); + sasl_fail_message(config.server_name.clone(), target, "Bad credentials".into()) + .write_async(writer) + .await?; + writer.flush().await?; + Ok(None) + } else { + let login: Str = body.login.into(); + self.validated_user = Some(login.clone()); + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::N900LoggedIn { + nick: login.clone(), + address: login.clone(), + account: login.clone(), + message: format!("You are now logged in as {}", login).into(), + }, + } + .write_async(writer) + .await?; + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::N903SaslSuccess { + nick: login.clone(), + message: "SASL authentication successful".into(), + }, + } + .write_async(writer) + .await?; + writer.flush().await?; + Ok(None) + } + } + + // TODO handle abortion of authentication + } + _ => Ok(None), + } + } + + async fn finalize_auth( + &mut self, + candidate_user: RegisteredUser, + writer: &mut BufWriter>, + storage: &mut Storage, + config: &ServerConfig, + ) -> Result> { + if self.enabled_capabilities.contains(Capabilities::Sasl) + && self.validated_user.as_ref() == Some(&candidate_user.nickname) + { + Ok(Some(candidate_user)) + } else { + let Some(candidate_password) = &self.pass else { + sasl_fail_message( + config.server_name.clone(), + candidate_user.nickname.clone(), + "User credentials was not provided".into(), + ) + .write_async(writer) + .await?; + writer.flush().await?; + return Ok(None); + }; + auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; + Ok(Some(candidate_user)) + } + } +} + async fn handle_registration<'a>( reader: &mut BufReader>, writer: &mut BufWriter>, @@ -94,14 +337,7 @@ async fn handle_registration<'a>( ) -> Result { let mut buffer = vec![]; - let mut future_nickname: Option = None; - let mut future_username: Option<(Str, Str)> = None; - let mut enabled_capabilities = Capabilities::None; - let mut cap_negotiation_in_progress = false; // if true, expect `CAP END` to complete registration - - let mut pass: Option = None; - let mut authentication_started = false; - let mut validated_user = None; + let mut state = RegistrationState::new(); let user = loop { let res = read_irc_message(reader, &mut buffer).await; @@ -132,218 +368,8 @@ async fn handle_registration<'a>( } }; tracing::debug!("Incoming IRC message: {msg:?}"); - match msg { - ClientMessage::Pass { password } => { - pass = Some(password); - } - ClientMessage::Capability { subcommand } => match subcommand { - CapabilitySubcommand::List { code: _ } => { - cap_negotiation_in_progress = true; - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::Cap { - target: future_nickname.clone().unwrap_or_else(|| "*".into()), - subcmd: CapSubBody::Ls("sasl=PLAIN".into()), - }, - } - .write_async(writer) - .await?; - writer.flush().await?; - } - CapabilitySubcommand::Req(caps) => { - cap_negotiation_in_progress = true; - let mut acked = vec![]; - let mut naked = vec![]; - for cap in caps { - if &*cap.name == "sasl" { - if cap.to_disable { - enabled_capabilities &= !Capabilities::Sasl; - } else { - enabled_capabilities |= Capabilities::Sasl; - } - acked.push(cap); - } else { - naked.push(cap); - } - } - let mut ack_body = String::new(); - for cap in acked { - if cap.to_disable { - ack_body.push('-'); - } - ack_body += &*cap.name; - } - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::Cap { - target: future_nickname.clone().unwrap_or_else(|| "*".into()), - subcmd: CapSubBody::Ack(ack_body.into()), - }, - } - .write_async(writer) - .await?; - writer.flush().await?; - } - CapabilitySubcommand::End => { - let Some((ref username, ref realname)) = future_username else { - todo!(); - }; - let Some(nickname) = future_nickname.clone() else { - todo!(); - }; - let username = username.clone(); - let realname = realname.clone(); - let candidate_user = RegisteredUser { - nickname: nickname.clone(), - username, - realname, - }; - if enabled_capabilities.contains(Capabilities::Sasl) - && validated_user.as_ref() == Some(&candidate_user.nickname) - { - break Ok(candidate_user); - } else { - let Some(candidate_password) = pass else { - sasl_fail_message( - config.server_name.clone(), - nickname.clone(), - "User credentials was not provided".into(), - ) - .write_async(writer) - .await?; - writer.flush().await?; - continue; - }; - auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; - break Ok(candidate_user); - } - } - }, - ClientMessage::Nick { nickname } => { - if cap_negotiation_in_progress { - future_nickname = Some(nickname); - } else if let Some((username, realname)) = future_username.clone() { - let candidate_user = RegisteredUser { - nickname: nickname.clone(), - username, - realname, - }; - let Some(candidate_password) = pass else { - sasl_fail_message( - config.server_name.clone(), - nickname.clone(), - "User credentials was not provided".into(), - ) - .write_async(writer) - .await?; - writer.flush().await?; - continue; - }; - auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; - break Ok(candidate_user); - } else { - future_nickname = Some(nickname); - } - } - ClientMessage::User { username, realname } => { - if cap_negotiation_in_progress { - future_username = Some((username, realname)); - } else if let Some(nickname) = future_nickname.clone() { - let candidate_user = RegisteredUser { - nickname: nickname.clone(), - username, - realname, - }; - let Some(candidate_password) = pass else { - sasl_fail_message( - config.server_name.clone(), - nickname.clone(), - "User credentials was not provided".into(), - ) - .write_async(writer) - .await?; - writer.flush().await?; - continue; - }; - auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; - break Ok(candidate_user); - } else { - future_username = Some((username, realname)); - } - } - ClientMessage::Authenticate(body) => { - if !authentication_started { - tracing::debug!("Received authentication request"); - if &*body == "PLAIN" { - tracing::debug!("Authentication request with method PLAIN"); - authentication_started = true; - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::Authenticate("+".into()), - } - .write_async(writer) - .await?; - writer.flush().await?; - } else { - if let Some(nickname) = future_nickname.clone() { - sasl_fail_message( - config.server_name.clone(), - nickname.clone(), - "Unsupported mechanism".into(), - ) - .write_async(writer) - .await?; - writer.flush().await?; - } else { - break Err(anyhow::Error::msg("Wrong authentication sequence")); - } - } - } else { - let body = AuthBody::from_str(body.as_bytes())?; - if let Err(e) = auth_user(storage, &body.login, &body.password).await { - tracing::warn!("Authentication failed: {:?}", e); - if let Some(nickname) = future_nickname.clone() { - sasl_fail_message(config.server_name.clone(), nickname.clone(), "Bad credentials".into()) - .write_async(writer) - .await?; - writer.flush().await?; - } else { - } - } else { - let login: Str = body.login.into(); - validated_user = Some(login.clone()); - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::N900LoggedIn { - nick: login.clone(), - address: login.clone(), - account: login.clone(), - message: format!("You are now logged in as {}", login).into(), - }, - } - .write_async(writer) - .await?; - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::N903SaslSuccess { - nick: login.clone(), - message: "SASL authentication successful".into(), - }, - } - .write_async(writer) - .await?; - writer.flush().await?; - } - } - - // TODO handle abortion of authentication - } - _ => {} + if let Some(user) = state.handle_msg(msg, writer, storage, config).await? { + break Ok(user); } buffer.clear(); }?; diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 3618467..145033b 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -421,6 +421,45 @@ async fn scenario_cap_full_negotiation() -> Result<()> { Ok(()) } +#[tokio::test] +async fn scenario_cap_full_negotiation_nick_last() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + s.send("CAP LS 302").await?; + s.expect(":testserver CAP * LS :sasl=PLAIN").await?; + s.send("CAP REQ :sasl").await?; + s.expect(":testserver CAP * ACK :sasl").await?; + s.send("AUTHENTICATE PLAIN").await?; + s.expect(":testserver AUTHENTICATE +").await?; + s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password' + s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?; + s.expect(":testserver 903 tester :SASL authentication successful").await?; + s.send("CAP END").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.send("NICK tester").await?; + + s.expect_server_introduction("tester").await?; + s.expect_nothing().await?; + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + + stream.shutdown().await?; + + // wrap up + + server.server.terminate().await?; + Ok(()) +} + #[tokio::test] async fn scenario_cap_short_negotiation() -> Result<()> { let mut server = TestServer::start().await?;