diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index ee7fbb8..82c3d29 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -180,16 +180,22 @@ async fn handle_registration<'a>( writer.flush().await?; } CapabilitySubcommand::End => { - let Some((username, realname)) = future_username else { - todo!() + let Some((ref username, ref realname)) = future_username else { + sasl_fail_message(config.server_name.clone()).write_async(writer).await?; + writer.flush().await?; + continue; }; let Some(nickname) = future_nickname.clone() else { - todo!() + sasl_fail_message(config.server_name.clone()).write_async(writer).await?; + writer.flush().await?; + continue; }; + let username = username.clone(); + let realname = realname.clone(); let candidate_user = RegisteredUser { nickname, username, - realname, + realname }; if enabled_capabilities.contains(Capabilities::Sasl) && validated_user.as_ref() == Some(&candidate_user.nickname) @@ -197,7 +203,9 @@ async fn handle_registration<'a>( break Ok(candidate_user); } else { let Some(candidate_password) = pass else { - todo!(); + sasl_fail_message(config.server_name.clone()).write_async(writer).await?; + writer.flush().await?; + continue; }; auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; break Ok(candidate_user); @@ -214,7 +222,9 @@ async fn handle_registration<'a>( realname, }; let Some(candidate_password) = pass else { - todo!(); + sasl_fail_message(config.server_name.clone()).write_async(writer).await?; + writer.flush().await?; + continue; }; auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; break Ok(candidate_user); @@ -232,7 +242,9 @@ async fn handle_registration<'a>( realname, }; let Some(candidate_password) = pass else { - todo!(); + sasl_fail_message(config.server_name.clone()).write_async(writer).await?; + writer.flush().await?; + continue; }; auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; break Ok(candidate_user); @@ -255,37 +267,45 @@ async fn handle_registration<'a>( .await?; writer.flush().await?; } else { - // TODO respond with 904 - todo!(); + sasl_fail_message(config.server_name.clone()).write_async(writer).await?; + writer.flush().await?; } } else { let body = AuthBody::from_str(body.as_bytes())?; - auth_user(storage, &body.login, &body.password).await?; - let login: Str = body.login.into(); - validated_user = Some(login.clone()); - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::N900LoggedIn { - nick: login.clone(), - address: login.clone(), - account: login.clone(), - message: format!("You are now logged in as {}", login).into(), - }, + match auth_user(storage, &body.login, &body.password).await { + Err(e) => { + tracing::warn!("Authentication failed: {:?}", e); + sasl_fail_message(config.server_name.clone()).write_async(writer).await?; + writer.flush().await?; + } + Ok(_) => { + let login: Str = body.login.into(); + validated_user = Some(login.clone()); + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::N900LoggedIn { + nick: login.clone(), + address: login.clone(), + account: login.clone(), + message: format!("You are now logged in as {}", login).into(), + }, + } + .write_async(writer) + .await?; + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::N903SaslSuccess { + nick: login.clone(), + message: "SASL authentication successful".into(), + }, + } + .write_async(writer) + .await?; + writer.flush().await?; + } } - .write_async(writer) - .await?; - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::N903SaslSuccess { - nick: login.clone(), - message: "SASL authentication successful".into(), - }, - } - .write_async(writer) - .await?; - writer.flush().await?; } // TODO handle abortion of authentication } @@ -297,6 +317,14 @@ async fn handle_registration<'a>( Ok(user) } +fn sasl_fail_message(sender: Str) -> ServerMessage { + ServerMessage { + tags: vec![], + sender: Some(sender), + body: ServerMessageBody::N904SaslFail + } +} + async fn auth_user(storage: &mut Storage, login: &str, plain_password: &str) -> Result<()> { let stored_user = storage.retrieve_user_by_name(login).await?; diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index c102143..7ddf427 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -218,3 +218,49 @@ async fn scenario_cap_short_negotiation() -> Result<()> { server.server.terminate().await?; Ok(()) } + +#[tokio::test] +async fn scenario_cap_sasl_fail() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + s.send("CAP LS 302").await?; + s.send("NICK tester").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.expect(":testserver CAP * LS :sasl=PLAIN").await?; + s.send("CAP REQ :sasl").await?; + s.expect(":testserver CAP tester ACK :sasl").await?; + s.send("AUTHENTICATE SHA256").await?; + s.expect(":testserver 904").await?; + s.send("AUTHENTICATE PLAIN").await?; + s.expect(":testserver AUTHENTICATE +").await?; + s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password' + s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?; + s.expect(":testserver 903 tester :SASL authentication successful").await?; + + s.send("CAP END").await?; + + s.expect(":testserver 001 tester :Welcome to Kek Server").await?; + s.expect(":testserver 002 tester :Welcome to Kek Server").await?; + s.expect(":testserver 003 tester :Welcome to Kek Server").await?; + s.expect(":testserver 004 tester testserver kek-0.1.alpha.3 r CFILPQbcefgijklmnopqrstvz").await?; + s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?; + s.expect_nothing().await?; + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + + stream.shutdown().await?; + + // wrap up + + server.server.terminate().await?; + Ok(()) +} diff --git a/crates/proto-irc/src/lib.rs b/crates/proto-irc/src/lib.rs index 59a4314..b56b6e3 100644 --- a/crates/proto-irc/src/lib.rs +++ b/crates/proto-irc/src/lib.rs @@ -30,6 +30,10 @@ fn token(input: &str) -> IResult<&str, &str> { take_while(|i| i != '\n' && i != '\r')(input) } +fn params(input: &str) -> IResult<&str, &str> { + take_while(|i| i != '\n' && i != '\r' && i != ':')(input) +} + #[derive(Clone, Debug, PartialEq, Eq)] pub enum Chan { /// # — network-global channel, available from any server in the network. diff --git a/crates/proto-irc/src/server.rs b/crates/proto-irc/src/server.rs index fdcccb1..aff864d 100644 --- a/crates/proto-irc/src/server.rs +++ b/crates/proto-irc/src/server.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use nonempty::NonEmpty; use tokio::io::AsyncWrite; use tokio::io::AsyncWriteExt; @@ -50,6 +52,12 @@ pub enum ServerMessageBody { rest_targets: Vec, text: Str, }, + Fail { + command: Str, + code: Str, + params: Option, + text: Str + }, Ping { token: Str, }, @@ -152,7 +160,8 @@ pub enum ServerMessageBody { N903SaslSuccess { nick: Str, message: Str, - } + }, + N904SaslFail } impl ServerMessageBody { @@ -168,6 +177,21 @@ impl ServerMessageBody { writer.write_all(b" :").await?; writer.write_all(text.as_bytes()).await?; } + ServerMessageBody::Fail { command, code, params, text } => + { + writer.write_all(b"FAIL ").await?; + writer.write_all(command.as_bytes()).await?; + writer.write_all(b" ").await?; + writer.write_all(code.as_bytes()).await?; + if let Some(p) = params.clone() { + writer.write_all(b" ").await?; + writer.write_all(p.as_bytes()).await?; + } else { + (); + } + writer.write_all(b" :").await?; + writer.write_all(text.as_bytes()).await?; + } ServerMessageBody::Ping { token } => { writer.write_all(b"PING ").await?; writer.write_all(token.as_bytes()).await?; @@ -369,6 +393,9 @@ impl ServerMessageBody { writer.write_all(b" :").await?; writer.write_all(message.as_bytes()).await?; } + ServerMessageBody::N904SaslFail => { + writer.write_all(b"904").await?; + } } Ok(()) } @@ -392,6 +419,7 @@ fn server_message_body(input: &str) -> IResult<&str, ServerMessageBody> { server_message_body_ping, server_message_body_pong, server_message_body_cap, + server_message_body_fail, ))(input) } @@ -413,6 +441,28 @@ fn server_message_body_notice(input: &str) -> IResult<&str, ServerMessageBody> { )) } +fn server_message_body_fail(input: &str) -> IResult<&str, ServerMessageBody> { + let (input, _) = tag("FAIL ")(input)?; + let (input, command) = receiver(input)?; + let (input, _) = tag(" ")(input)?; + let (input, code) = receiver(input)?; + let (input, _) = tag(" ")(input)?; + let (input, params) = params(input)?; + let (input, _) = tag(":")(input)?; + let (input, text) = token(input)?; + + let command = command.into(); + let code = code.into(); + let params: Arc = params.trim().into(); + let params = if params.is_empty() { None } else { Some(params) }; + let text = text.into(); + + Ok(( + input, + ServerMessageBody::Fail { command, code, params, text } + )) +} + fn server_message_body_ping(input: &str) -> IResult<&str, ServerMessageBody> { let (input, _) = tag("PING ")(input)?; let (input, token) = token(input)?; @@ -517,4 +567,48 @@ mod test { sync_future(expected.write_async(&mut bytes)).unwrap().unwrap(); assert_eq!(bytes, input.as_bytes()); } + + #[test] + fn test_server_message_fail() { + let input = "FAIL BOX BOXES_INVALID STACK CLOCKWISE :Given boxes are not supported\r\n"; + let expected = ServerMessage { + tags: vec![], + sender: None, + body: ServerMessageBody::Fail { + command: "BOX".into(), + code: "BOXES_INVALID".into(), + params: Some("STACK CLOCKWISE".into()), + text: "Given boxes are not supported".into(), + }, + }; + + let result = server_message(input); + assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); + + let mut bytes = vec![]; + sync_future(expected.write_async(&mut bytes)).unwrap().unwrap(); + assert_eq!(bytes, input.as_bytes()); + } + + #[test] + fn test_server_short_message_fail() { + let input = "FAIL REHASH CONFIG_BAD :Could not reload config from disk\r\n"; + let expected = ServerMessage { + tags: vec![], + sender: None, + body: ServerMessageBody::Fail { + command: "REHASH".into(), + code: "CONFIG_BAD".into(), + params: None, + text: "Could not reload config from disk".into(), + }, + }; + + let result = server_message(input); + assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); + + let mut bytes = vec![]; + sync_future(expected.write_async(&mut bytes)).unwrap().unwrap(); + assert_eq!(bytes, input.as_bytes()); + } }