From e43c6723445207e676c5b54384495da184758678 Mon Sep 17 00:00:00 2001 From: Jeroen van Erp Date: Fri, 23 Sep 2022 22:42:57 +0200 Subject: [PATCH] Retry authentication with all remaining auth methods after partial success Signed-off-by: Jeroen van Erp --- src/main/java/net/schmizz/sshj/SSHClient.java | 22 +++++++++++++++++-- .../net/schmizz/sshj/userauth/AuthResult.java | 7 ++++++ .../net/schmizz/sshj/userauth/UserAuth.java | 4 ++-- .../schmizz/sshj/userauth/UserAuthImpl.java | 16 ++++++++------ 4 files changed, 38 insertions(+), 11 deletions(-) create mode 100644 src/main/java/net/schmizz/sshj/userauth/AuthResult.java diff --git a/src/main/java/net/schmizz/sshj/SSHClient.java b/src/main/java/net/schmizz/sshj/SSHClient.java index 70b715896..e0e474c40 100644 --- a/src/main/java/net/schmizz/sshj/SSHClient.java +++ b/src/main/java/net/schmizz/sshj/SSHClient.java @@ -40,6 +40,7 @@ import net.schmizz.sshj.transport.verification.FingerprintVerifier; import net.schmizz.sshj.transport.verification.HostKeyVerifier; import net.schmizz.sshj.transport.verification.OpenSSHKnownHosts; +import net.schmizz.sshj.userauth.AuthResult; import net.schmizz.sshj.userauth.UserAuth; import net.schmizz.sshj.userauth.UserAuthException; import net.schmizz.sshj.userauth.UserAuthImpl; @@ -218,13 +219,30 @@ public void auth(String username, Iterable methods) throws UserAuthException, TransportException { checkConnected(); final Deque savedEx = new LinkedList(); - for (AuthMethod method: methods) { + final List tried = new LinkedList(); + + for (Iterator it = methods.iterator(); it.hasNext();) { + AuthMethod method = it.next(); method.setLoggerFactory(loggerFactory); + try { - if (auth.authenticate(username, (Service) conn, method, trans.getTimeoutMs())) + AuthResult result = auth.authenticate(username, (Service) conn, method, trans.getTimeoutMs()); + + if (result == AuthResult.SUCCESS) { return; + } else if (result == AuthResult.PARTIAL) { + // Put all remaining methods in the tried list, so that we can try them for the second round of authentication + while (it.hasNext()) { + tried.add(it.next()); + } + + auth(username, tried); + return; + } + tried.add(method); } catch (UserAuthException e) { savedEx.push(e); + tried.add(method); } } throw new UserAuthException("Exhausted available authentication methods", savedEx.peek()); diff --git a/src/main/java/net/schmizz/sshj/userauth/AuthResult.java b/src/main/java/net/schmizz/sshj/userauth/AuthResult.java new file mode 100644 index 000000000..1fcc3211d --- /dev/null +++ b/src/main/java/net/schmizz/sshj/userauth/AuthResult.java @@ -0,0 +1,7 @@ +package net.schmizz.sshj.userauth; + +public enum AuthResult { + SUCCESS, + FAILURE, + PARTIAL +} diff --git a/src/main/java/net/schmizz/sshj/userauth/UserAuth.java b/src/main/java/net/schmizz/sshj/userauth/UserAuth.java index dea91b35b..c41c3a48c 100644 --- a/src/main/java/net/schmizz/sshj/userauth/UserAuth.java +++ b/src/main/java/net/schmizz/sshj/userauth/UserAuth.java @@ -37,12 +37,12 @@ public interface UserAuth { * @param nextService the service to set on successful authentication * @param methods the {@link AuthMethod}'s to try * - * @return whether authentication was successful + * @return whether authentication was successful, failed, or partially successful * * @throws UserAuthException in case of authentication failure * @throws TransportException if there was a transport-layer error */ - boolean authenticate(String username, Service nextService, AuthMethod methods, int timeoutMs) + AuthResult authenticate(String username, Service nextService, AuthMethod methods, int timeoutMs) throws UserAuthException, TransportException; /** diff --git a/src/main/java/net/schmizz/sshj/userauth/UserAuthImpl.java b/src/main/java/net/schmizz/sshj/userauth/UserAuthImpl.java index 26499e8c0..bc50c08e9 100644 --- a/src/main/java/net/schmizz/sshj/userauth/UserAuthImpl.java +++ b/src/main/java/net/schmizz/sshj/userauth/UserAuthImpl.java @@ -40,7 +40,7 @@ public class UserAuthImpl extends AbstractService implements UserAuth { - private final Promise authenticated; + private final Promise authenticated; // Externally available private volatile String banner = ""; @@ -53,13 +53,13 @@ public class UserAuthImpl public UserAuthImpl(Transport trans) { super("ssh-userauth", trans); - authenticated = new Promise("authenticated", UserAuthException.chainer, trans.getConfig().getLoggerFactory()); + authenticated = new Promise("authenticated", UserAuthException.chainer, trans.getConfig().getLoggerFactory()); } @Override - public boolean authenticate(String username, Service nextService, AuthMethod method, int timeoutMs) + public AuthResult authenticate(String username, Service nextService, AuthMethod method, int timeoutMs) throws UserAuthException, TransportException { - final boolean outcome; + final AuthResult outcome; authenticated.lock(); try { @@ -73,8 +73,10 @@ public boolean authenticate(String username, Service nextService, AuthMethod met currentMethod.request(); outcome = authenticated.retrieve(timeoutMs, TimeUnit.MILLISECONDS); - if (outcome) { + if (outcome == AuthResult.SUCCESS) { log.debug("`{}` auth successful", method.getName()); + } else if (outcome == AuthResult.PARTIAL) { + log.debug("`{}` auth partially successful", method.getName()); } else { log.debug("`{}` auth failed", method.getName()); } @@ -124,7 +126,7 @@ public void handle(Message msg, SSHPacket buf) // Should fix https://github.com/hierynomus/sshj/issues/237 trans.setAuthenticated(); // So it can put delayed compression into force if applicable trans.setService(nextService); // We aren't in charge anymore, next service is - authenticated.deliver(true); + authenticated.deliver(AuthResult.SUCCESS); break; case USERAUTH_FAILURE: @@ -133,7 +135,7 @@ public void handle(Message msg, SSHPacket buf) if (allowedMethods.contains(currentMethod.getName()) && currentMethod.shouldRetry()) { currentMethod.request(); } else { - authenticated.deliver(false); + authenticated.deliver(partialSuccess ? AuthResult.PARTIAL : AuthResult.FAILURE); } break;