diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index ee7fbb8..4808b85 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -180,14 +180,16 @@ async fn handle_registration<'a>( writer.flush().await?; } CapabilitySubcommand::End => { - let Some((username, realname)) = future_username else { - todo!() + let Some((ref username, ref realname)) = future_username else { + todo!(); }; let Some(nickname) = future_nickname.clone() else { - todo!() + todo!(); }; + let username = username.clone(); + let realname = realname.clone(); let candidate_user = RegisteredUser { - nickname, + nickname: nickname.clone(), username, realname, }; @@ -197,7 +199,15 @@ async fn handle_registration<'a>( break Ok(candidate_user); } else { let Some(candidate_password) = pass else { - todo!(); + sasl_fail_message( + config.server_name.clone(), + nickname.clone(), + "User credentials was not provided".into(), + ) + .write_async(writer) + .await?; + writer.flush().await?; + continue; }; auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; break Ok(candidate_user); @@ -209,12 +219,20 @@ async fn handle_registration<'a>( future_nickname = Some(nickname); } else if let Some((username, realname)) = future_username.clone() { let candidate_user = RegisteredUser { - nickname, + nickname: nickname.clone(), username, realname, }; let Some(candidate_password) = pass else { - todo!(); + sasl_fail_message( + config.server_name.clone(), + nickname.clone(), + "User credentials was not provided".into(), + ) + .write_async(writer) + .await?; + writer.flush().await?; + continue; }; auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; break Ok(candidate_user); @@ -227,12 +245,20 @@ async fn handle_registration<'a>( future_username = Some((username, realname)); } else if let Some(nickname) = future_nickname.clone() { let candidate_user = RegisteredUser { - nickname, + nickname: nickname.clone(), username, realname, }; let Some(candidate_password) = pass else { - todo!(); + sasl_fail_message( + config.server_name.clone(), + nickname.clone(), + "User credentials was not provided".into(), + ) + .write_async(writer) + .await?; + writer.flush().await?; + continue; }; auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; break Ok(candidate_user); @@ -255,38 +281,59 @@ async fn handle_registration<'a>( .await?; writer.flush().await?; } else { - // TODO respond with 904 - todo!(); + if let Some(nickname) = future_nickname.clone() { + sasl_fail_message( + config.server_name.clone(), + nickname.clone(), + "Unsupported mechanism".into(), + ) + .write_async(writer) + .await?; + writer.flush().await?; + } else { + break Err(anyhow::Error::msg("Wrong authentication sequence")); + } } } else { let body = AuthBody::from_str(body.as_bytes())?; - auth_user(storage, &body.login, &body.password).await?; - let login: Str = body.login.into(); - validated_user = Some(login.clone()); - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::N900LoggedIn { - nick: login.clone(), - address: login.clone(), - account: login.clone(), - message: format!("You are now logged in as {}", login).into(), - }, + if let Err(e) = auth_user(storage, &body.login, &body.password).await { + tracing::warn!("Authentication failed: {:?}", e); + if let Some(nickname) = future_nickname.clone() { + sasl_fail_message(config.server_name.clone(), nickname.clone(), "Bad credentials".into()) + .write_async(writer) + .await?; + writer.flush().await?; + } else { + } + } else { + let login: Str = body.login.into(); + validated_user = Some(login.clone()); + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::N900LoggedIn { + nick: login.clone(), + address: login.clone(), + account: login.clone(), + message: format!("You are now logged in as {}", login).into(), + }, + } + .write_async(writer) + .await?; + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::N903SaslSuccess { + nick: login.clone(), + message: "SASL authentication successful".into(), + }, + } + .write_async(writer) + .await?; + writer.flush().await?; } - .write_async(writer) - .await?; - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::N903SaslSuccess { - nick: login.clone(), - message: "SASL authentication successful".into(), - }, - } - .write_async(writer) - .await?; - writer.flush().await?; } + // TODO handle abortion of authentication } _ => {} @@ -297,6 +344,14 @@ async fn handle_registration<'a>( Ok(user) } +fn sasl_fail_message(sender: Str, nick: Str, text: Str) -> ServerMessage { + ServerMessage { + tags: vec![], + sender: Some(sender), + body: ServerMessageBody::N904SaslFail { nick, text }, + } +} + async fn auth_user(storage: &mut Storage, login: &str, plain_password: &str) -> Result<()> { let stored_user = storage.retrieve_user_by_name(login).await?; diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index c102143..a0ee071 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -218,3 +218,51 @@ async fn scenario_cap_short_negotiation() -> Result<()> { server.server.terminate().await?; Ok(()) } + +#[tokio::test] +async fn scenario_cap_sasl_fail() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + s.send("CAP LS 302").await?; + s.send("NICK tester").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.expect(":testserver CAP * LS :sasl=PLAIN").await?; + s.send("CAP REQ :sasl").await?; + s.expect(":testserver CAP tester ACK :sasl").await?; + s.send("AUTHENTICATE SHA256").await?; + s.expect(":testserver 904 tester :Unsupported mechanism").await?; + s.send("AUTHENTICATE PLAIN").await?; + s.expect(":testserver AUTHENTICATE +").await?; + s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZDE=").await?; + s.expect(":testserver 904 tester :Bad credentials").await?; + s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password' + s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?; + s.expect(":testserver 903 tester :SASL authentication successful").await?; + + s.send("CAP END").await?; + + s.expect(":testserver 001 tester :Welcome to Kek Server").await?; + s.expect(":testserver 002 tester :Welcome to Kek Server").await?; + s.expect(":testserver 003 tester :Welcome to Kek Server").await?; + s.expect(":testserver 004 tester testserver kek-0.1.alpha.3 r CFILPQbcefgijklmnopqrstvz").await?; + s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?; + s.expect_nothing().await?; + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + + stream.shutdown().await?; + + // wrap up + + server.server.terminate().await?; + Ok(()) +} diff --git a/crates/proto-irc/src/lib.rs b/crates/proto-irc/src/lib.rs index 59a4314..b56b6e3 100644 --- a/crates/proto-irc/src/lib.rs +++ b/crates/proto-irc/src/lib.rs @@ -30,6 +30,10 @@ fn token(input: &str) -> IResult<&str, &str> { take_while(|i| i != '\n' && i != '\r')(input) } +fn params(input: &str) -> IResult<&str, &str> { + take_while(|i| i != '\n' && i != '\r' && i != ':')(input) +} + #[derive(Clone, Debug, PartialEq, Eq)] pub enum Chan { /// # — network-global channel, available from any server in the network. diff --git a/crates/proto-irc/src/server.rs b/crates/proto-irc/src/server.rs index fdcccb1..bc7844d 100644 --- a/crates/proto-irc/src/server.rs +++ b/crates/proto-irc/src/server.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use nonempty::NonEmpty; use tokio::io::AsyncWrite; use tokio::io::AsyncWriteExt; @@ -152,6 +154,10 @@ pub enum ServerMessageBody { N903SaslSuccess { nick: Str, message: Str, + }, + N904SaslFail { + nick: Str, + text: Str, } } @@ -369,6 +375,13 @@ impl ServerMessageBody { writer.write_all(b" :").await?; writer.write_all(message.as_bytes()).await?; } + ServerMessageBody::N904SaslFail { nick, text } => { + writer.write_all(b"904").await?; + writer.write_all(b" ").await?; + writer.write_all(nick.as_bytes()).await?; + writer.write_all(b" :").await?; + writer.write_all(text.as_bytes()).await?; + } } Ok(()) } @@ -391,7 +404,7 @@ fn server_message_body(input: &str) -> IResult<&str, ServerMessageBody> { server_message_body_notice, server_message_body_ping, server_message_body_pong, - server_message_body_cap, + server_message_body_cap ))(input) }