forked from lavina/lavina
1
0
Fork 0

remove nesting level, use protocol 3.1 version format for numeric errors

This commit is contained in:
G1ng3r 2024-01-30 22:17:39 +03:00
parent 75ac0a3369
commit 7f220c620f
3 changed files with 71 additions and 48 deletions

View File

@ -181,19 +181,15 @@ async fn handle_registration<'a>(
} }
CapabilitySubcommand::End => { CapabilitySubcommand::End => {
let Some((ref username, ref realname)) = future_username else { let Some((ref username, ref realname)) = future_username else {
sasl_fail_message(config.server_name.clone()).write_async(writer).await?; break Err(anyhow::Error::msg("Protocol violated"));
writer.flush().await?;
continue;
}; };
let Some(nickname) = future_nickname.clone() else { let Some(nickname) = future_nickname.clone() else {
sasl_fail_message(config.server_name.clone()).write_async(writer).await?; break Err(anyhow::Error::msg("Protocol violated"));
writer.flush().await?;
continue;
}; };
let username = username.clone(); let username = username.clone();
let realname = realname.clone(); let realname = realname.clone();
let candidate_user = RegisteredUser { let candidate_user = RegisteredUser {
nickname, nickname: nickname.clone(),
username, username,
realname realname
}; };
@ -203,7 +199,11 @@ async fn handle_registration<'a>(
break Ok(candidate_user); break Ok(candidate_user);
} else { } else {
let Some(candidate_password) = pass else { let Some(candidate_password) = pass else {
sasl_fail_message(config.server_name.clone()).write_async(writer).await?; sasl_fail_message(
config.server_name.clone(),
nickname.clone(),
"User credentials was not provided".into()
).write_async(writer).await?;
writer.flush().await?; writer.flush().await?;
continue; continue;
}; };
@ -217,12 +217,16 @@ async fn handle_registration<'a>(
future_nickname = Some(nickname); future_nickname = Some(nickname);
} else if let Some((username, realname)) = future_username.clone() { } else if let Some((username, realname)) = future_username.clone() {
let candidate_user = RegisteredUser { let candidate_user = RegisteredUser {
nickname, nickname: nickname.clone(),
username, username,
realname, realname,
}; };
let Some(candidate_password) = pass else { let Some(candidate_password) = pass else {
sasl_fail_message(config.server_name.clone()).write_async(writer).await?; sasl_fail_message(
config.server_name.clone(),
nickname.clone(),
"User credentials was not provided".into()
).write_async(writer).await?;
writer.flush().await?; writer.flush().await?;
continue; continue;
}; };
@ -237,12 +241,16 @@ async fn handle_registration<'a>(
future_username = Some((username, realname)); future_username = Some((username, realname));
} else if let Some(nickname) = future_nickname.clone() { } else if let Some(nickname) = future_nickname.clone() {
let candidate_user = RegisteredUser { let candidate_user = RegisteredUser {
nickname, nickname: nickname.clone(),
username, username,
realname, realname,
}; };
let Some(candidate_password) = pass else { let Some(candidate_password) = pass else {
sasl_fail_message(config.server_name.clone()).write_async(writer).await?; sasl_fail_message(
config.server_name.clone(),
nickname.clone(),
"User credentials was not provided".into()
).write_async(writer).await?;
writer.flush().await?; writer.flush().await?;
continue; continue;
}; };
@ -267,18 +275,24 @@ async fn handle_registration<'a>(
.await?; .await?;
writer.flush().await?; writer.flush().await?;
} else { } else {
sasl_fail_message(config.server_name.clone()).write_async(writer).await?; if let Some(nickname) = future_nickname.clone() {
sasl_fail_message(config.server_name.clone(), nickname.clone(), "Unsupported mechanism".into()).write_async(writer).await?;
writer.flush().await?; writer.flush().await?;
} else {
break Err(anyhow::Error::msg("Wrong authentication sequence"));
}
} }
} else { } else {
let body = AuthBody::from_str(body.as_bytes())?; let body = AuthBody::from_str(body.as_bytes())?;
match auth_user(storage, &body.login, &body.password).await { if let Err(e) = auth_user(storage, &body.login, &body.password).await {
Err(e) => {
tracing::warn!("Authentication failed: {:?}", e); tracing::warn!("Authentication failed: {:?}", e);
sasl_fail_message(config.server_name.clone()).write_async(writer).await?; if let Some(nickname) = future_nickname.clone() {
sasl_fail_message(config.server_name.clone(), nickname.clone(), "Bad credentials".into()).write_async(writer).await?;
writer.flush().await?; writer.flush().await?;
} else {
} }
Ok(_) => { } else {
let login: Str = body.login.into(); let login: Str = body.login.into();
validated_user = Some(login.clone()); validated_user = Some(login.clone());
ServerMessage { ServerMessage {
@ -306,7 +320,7 @@ async fn handle_registration<'a>(
writer.flush().await?; writer.flush().await?;
} }
} }
}
// TODO handle abortion of authentication // TODO handle abortion of authentication
} }
_ => {} _ => {}
@ -317,11 +331,11 @@ async fn handle_registration<'a>(
Ok(user) Ok(user)
} }
fn sasl_fail_message(sender: Str) -> ServerMessage { fn sasl_fail_message(sender: Str, nick: Str, text: Str) -> ServerMessage {
ServerMessage { ServerMessage {
tags: vec![], tags: vec![],
sender: Some(sender), sender: Some(sender),
body: ServerMessageBody::N904SaslFail body: ServerMessageBody::N904SaslFail { nick, text }
} }
} }

View File

@ -241,6 +241,8 @@ async fn scenario_cap_sasl_fail() -> Result<()> {
s.expect(":testserver 904").await?; s.expect(":testserver 904").await?;
s.send("AUTHENTICATE PLAIN").await?; s.send("AUTHENTICATE PLAIN").await?;
s.expect(":testserver AUTHENTICATE +").await?; s.expect(":testserver AUTHENTICATE +").await?;
s.send("AUTHENTICATE wrong_password").await?;
s.expect(":testserver 904").await?;
s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password' s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password'
s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?; s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?;
s.expect(":testserver 903 tester :SASL authentication successful").await?; s.expect(":testserver 903 tester :SASL authentication successful").await?;

View File

@ -161,7 +161,10 @@ pub enum ServerMessageBody {
nick: Str, nick: Str,
message: Str, message: Str,
}, },
N904SaslFail N904SaslFail {
nick: Str,
text: Str,
}
} }
impl ServerMessageBody { impl ServerMessageBody {
@ -393,8 +396,12 @@ impl ServerMessageBody {
writer.write_all(b" :").await?; writer.write_all(b" :").await?;
writer.write_all(message.as_bytes()).await?; writer.write_all(message.as_bytes()).await?;
} }
ServerMessageBody::N904SaslFail => { ServerMessageBody::N904SaslFail { nick, text } => {
writer.write_all(b"904").await?; writer.write_all(b"904").await?;
writer.write_all(b" ").await?;
writer.write_all(nick.as_bytes()).await?;
writer.write_all(b" :").await?;
writer.write_all(text.as_bytes()).await?;
} }
} }
Ok(()) Ok(())