Skip to content

Commit

Permalink
EIT-437 Add retries to reading from vault (#30)
Browse files Browse the repository at this point in the history
* Add retries to reading from vault

We occasionally see transient errors talking to Vault. This results in
the occasional `E_AUTH_FAILED` error being logged in our applications
and via `nsqd`. Because these are transient, we should just pause for a
moment and retry.

* only retry on socket timeouts

* java style sniping

* java style sniping

---------

Co-authored-by: Jack Sadanowicz <[email protected]>
  • Loading branch information
danrjohnson and jsadn authored Jul 27, 2024
1 parent d089667 commit 7beb7d5
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
.idea/
target/
config.yml
dependency-reduced-pom.xml
dependency-reduced-pom.xml
.DS_Store
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
import com.bettercloud.vault.Vault;
import com.bettercloud.vault.VaultException;
import com.bettercloud.vault.response.LogicalResponse;
import com.bettercloud.vault.rest.RestException;
import com.codahale.metrics.Counter;
import com.codahale.metrics.MetricRegistry;
import com.sproutsocial.nsqauthj.NsqAuthJConfiguration;
import com.sproutsocial.nsqauthj.configuration.TokenValidationFactory;
import com.sproutsocial.nsqauthj.tokens.NsqToken;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.SocketTimeoutException;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

public class VaultTokenValidator {
private static final Logger logger = LoggerFactory.getLogger(VaultTokenValidator.class);
Expand All @@ -28,6 +29,9 @@ public class VaultTokenValidator {
private final Counter userCounter;
private final Counter publishOnlyCounter;

private static final int MAX_RETRIES = 3;
private static final int RETRY_DELAY = 200; // time before retrying in ms

public VaultTokenValidator(Vault vault, String userTokenPath, String serviceTokenPath,
int ttl, Boolean failOpen, MetricRegistry metricRegistry) {
this.vault = vault;
Expand All @@ -42,14 +46,21 @@ public VaultTokenValidator(Vault vault, String userTokenPath, String serviceToke


public Optional<NsqToken> validateTokenAtPath(String token, String path, NsqToken.TYPE type, String remoteAddr) {
LogicalResponse response = null;
try {
response = this.vault.logical().read(path + token);
} catch (VaultException e) {
e.printStackTrace();
return Optional.empty();
for (int i = 1; i <= MAX_RETRIES; i++) {
try {
LogicalResponse response = this.vault.logical().read(path + token);
return NsqToken.fromVaultResponse(response, token, type, ttl, remoteAddr);
} catch (VaultException e) {
if (isTimeout(e)) {
logger.warn("Timed out reading from Vault, retrying...", e);
wait(200, TimeUnit.MILLISECONDS);
continue;
}
logger.warn(e.getMessage(), e);
break;
}
}
return NsqToken.fromVaultResponse(response, token, type, ttl, remoteAddr);
return Optional.empty();
}

public Optional<NsqToken> validateUserToken(String token, String remoteAddr) {
Expand Down Expand Up @@ -89,4 +100,21 @@ public Optional<NsqToken> validateToken(String token, String remoteAddr) {
public Boolean getFailOpen() {
return failOpen;
}

private static boolean isTimeout(VaultException e) {
return e.getCause() != null
&& e.getCause() instanceof RestException
&& e.getCause().getCause() != null
&& e.getCause().getCause() instanceof SocketTimeoutException;
}



private void wait(int value, TimeUnit unit) {
try {
Thread.sleep(unit.toMillis(value));
} catch (InterruptedException e) {
logger.error("Thread interrupted while sleeping");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.bettercloud.vault.VaultException;
import com.bettercloud.vault.api.Logical;
import com.bettercloud.vault.response.LogicalResponse;
import com.bettercloud.vault.rest.RestException;
import com.codahale.metrics.Counter;
import com.codahale.metrics.MetricRegistry;
import com.sproutsocial.nsqauthj.tokens.NsqToken;
Expand All @@ -12,6 +13,7 @@
import org.mockito.Mockito;
import org.mockito.internal.matchers.Any;

import java.net.SocketTimeoutException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -59,6 +61,56 @@ public void validateTokenAtPathError() throws VaultException {
assertFalse(optionalNsqToken.isPresent());
}

@Test
public void testVaultErrorFollowedBySuccess() throws VaultException {
Logical logicalMock = mock(Logical.class);
when(mockVault.logical()).thenReturn(logicalMock);

LogicalResponse logicalResponseMock = mock(LogicalResponse.class);
Map<String, String> responseData = new HashMap<>();
responseData.put("username", "some.developer");
responseData.put("topics", "tw_engagement,fb_post");
when(logicalResponseMock.getData()).thenReturn(responseData);


RestException restException = mock(RestException.class);
when(restException.getCause()).thenReturn(new SocketTimeoutException());

VaultException vaultException = new VaultException(restException);

given(mockVault.logical().read(userTokenPath + token)).willAnswer(invocationOnMock -> { throw vaultException; }).willReturn(logicalResponseMock);

Optional<NsqToken> optionalNsqToken = vaultTokenValidator.validateTokenAtPath(token, userTokenPath, NsqToken.TYPE.USER, ip);

verify(mockVault.logical(), times(2)).read(userTokenPath + token);
assertTrue(optionalNsqToken.isPresent());
}

@Test
public void testVaultErrorNotRetried() throws VaultException {
Logical logicalMock = mock(Logical.class);
when(mockVault.logical()).thenReturn(logicalMock);

LogicalResponse logicalResponseMock = mock(LogicalResponse.class);
Map<String, String> responseData = new HashMap<>();
responseData.put("username", "some.developer");
responseData.put("topics", "tw_engagement,fb_post");
when(logicalResponseMock.getData()).thenReturn(responseData);


RestException restException = mock(RestException.class);
when(restException.getCause()).thenReturn(new InterruptedException());

VaultException vaultException = new VaultException(restException);

given(mockVault.logical().read(userTokenPath + token)).willAnswer(invocationOnMock -> { throw vaultException; }).willReturn(logicalResponseMock);

Optional<NsqToken> optionalNsqToken = vaultTokenValidator.validateTokenAtPath(token, userTokenPath, NsqToken.TYPE.USER, ip);

verify(mockVault.logical(), times(1)).read(userTokenPath + token);
assertFalse(optionalNsqToken.isPresent());
}

@Test
public void validateTokenAtPathValid() throws VaultException {
LogicalResponse logicalResponseMock = mock(LogicalResponse.class);
Expand Down

0 comments on commit 7beb7d5

Please sign in to comment.