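Scenario: I am using OkHttp / Retrofit to access a web service, and multiple HTTP requests are sent out at the same time. At some point the auth token expires, and several requests fail with 401 Unauthorized at once. The token should be refreshed only once, and every failed request should then be retried with the new token.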
I had the same problem and I managed to solve it using a ReentrantLock.
import java.io.IOException;
import java.net.HttpURLConnection;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import timber.log.Timber;
/**
 * OkHttp interceptor that, when a request fails with {@code 401 Unauthorized}, refreshes the
 * access token and retries the request once with the new token.
 *
 * <p>When several requests fail with 401 at the same time, only the thread that wins
 * {@link Lock#tryLock()} calls {@link AuthenticationService#refreshTokenSync()}. The other
 * threads block until that refresh has finished and then retry with the token currently held
 * by {@link RefreshTokenStorage}.
 */
public class RefreshTokenInterceptor implements Interceptor {

    /** Held by the thread performing the refresh, for the duration of the refresh only. */
    private final Lock lock = new ReentrantLock();
    private final AuthenticationService authenticationService;
    private final RefreshTokenStorage refreshTokenStorage;

    /**
     * @param authenticationService performs the synchronous token refresh
     * @param refreshTokenStorage   provides the access token used when retrying a request
     */
    public RefreshTokenInterceptor(AuthenticationService authenticationService,
                                   RefreshTokenStorage refreshTokenStorage) {
        this.authenticationService = authenticationService;
        this.refreshTokenStorage = refreshTokenStorage;
    }

    /**
     * Proceeds with the request. On a 401 response, either refreshes the token (first thread)
     * or waits for the concurrent refresh to finish (other threads), then retries once.
     *
     * @return the original response if it is not a 401, or if the refresh fails with
     *     {@link ServiceException}; otherwise the response of the retried request
     * @throws IOException if the original or the retried call fails
     */
    @Override
    public Response intercept(Interceptor.Chain chain) throws IOException {
        Request request = chain.request();
        Response response = chain.proceed(request);
        if (response.code() != HttpURLConnection.HTTP_UNAUTHORIZED) {
            return response;
        }

        if (lock.tryLock()) {
            // First thread to see the 401 performs the refresh while holding the lock.
            Timber.i("refresh token thread holds the lock");
            try {
                // This sync call will refresh the token and save it for
                // later use (e.g. sharedPreferences).
                authenticationService.refreshTokenSync();
            } catch (ServiceException exception) {
                // Depending on what you need to do you can logout the user at this
                // point or return the 401 and handle it in your onFailure callback.
                return response;
            } finally {
                Timber.i("refresh token finished. release lock");
                lock.unlock();
            }
        } else {
            Timber.i("wait for token to be refreshed");
            // Blocks until the refreshing thread calls unlock(); the lock itself is not
            // needed afterwards, so release it immediately.
            lock.lock();
            lock.unlock();
            Timber.i("token refreshed. retry request");
        }

        // The 401 response must be closed before proceeding again on the same chain;
        // otherwise its connection leaks (newer OkHttp versions throw IllegalStateException).
        // The retry also runs outside the lock so waiting threads are not held up by it.
        response.close();
        return chain.proceed(recreateRequestWithNewAccessToken(chain));
    }

    /**
     * Builds a copy of the chain's current request whose {@code access_token} header is set to
     * the token currently stored in {@link RefreshTokenStorage}.
     */
    private Request recreateRequestWithNewAccessToken(Chain chain) {
        String freshAccessToken = refreshTokenStorage.getAccessToken();
        // Do not log the token value itself: it is a credential.
        Timber.d("retrying request with refreshed access token");
        return chain.request().newBuilder()
                .header("access_token", freshAccessToken)
                .build();
    }
}
The main advantage of this solution is that you can unit-test it with Mockito. You will have to enable Mockito's incubating inline mock maker to mock final classes (such as OkHttp's Response); read more about it here. The test looks something like this:
/**
 * Verifies that {@link RefreshTokenInterceptor} refreshes the token only once per batch of
 * concurrent 401 responses, and again for each later batch or single request.
 */
@RunWith(MockitoJUnitRunner.class)
public class RefreshTokenInterceptorTest {

    private static final String FRESH_ACCESS_TOKEN = "fresh_access_token";
    /** Simulated refresh duration; every thread of a batch must reach tryLock() within it. */
    private static final long REFRESH_DURATION_MS = 100;
    /** Upper bound for one batch of concurrent intercept() calls, so a deadlock fails fast. */
    private static final long BATCH_TIMEOUT_SECONDS = 5;

    @Mock
    AuthenticationService authenticationService;
    @Mock
    RefreshTokenStorage refreshTokenStorage;
    @Mock
    Interceptor.Chain chain;

    @BeforeClass
    public static void setup() {
        // Print Timber output with the current thread so the interleaving is visible.
        Timber.plant(new Timber.DebugTree() {
            @Override
            protected void log(int priority, String tag, String message, Throwable t) {
                System.out.println(Thread.currentThread() + " " + message);
            }
        });
    }

    @Test
    public void refreshTokenInterceptor_works_as_expected() throws IOException, InterruptedException {
        Response unauthorizedResponse = createUnauthorizedResponse();
        when(chain.proceed((Request) any())).thenReturn(unauthorizedResponse);
        when(authenticationService.refreshTokenSync()).thenAnswer(new Answer<Boolean>() {
            @Override
            public Boolean answer(InvocationOnMock invocation) throws Throwable {
                // The refresh takes some time, so concurrent callers queue up behind the lock.
                // NOTE: timing-based; a thread reaching tryLock() after the refresh finished
                // would trigger a second refresh and break the times(n) checks below.
                Thread.sleep(REFRESH_DURATION_MS);
                return true;
            }
        });
        when(refreshTokenStorage.getAccessToken()).thenReturn(FRESH_ACCESS_TOKEN);
        Request fakeRequest = createFakeRequest();
        when(chain.request()).thenReturn(fakeRequest);
        final Interceptor interceptor = new RefreshTokenInterceptor(authenticationService, refreshTokenStorage);

        Timber.d("5 requests try to refresh token at the same time");
        interceptConcurrently(interceptor, 5);
        verify(authenticationService, times(1)).refreshTokenSync();

        Timber.d("next time another 3 threads try to refresh the token at the same time");
        interceptConcurrently(interceptor, 3);
        verify(authenticationService, times(2)).refreshTokenSync();

        Timber.d("1 thread tries to refresh the token");
        interceptor.intercept(chain);
        verify(authenticationService, times(3)).refreshTokenSync();
    }

    /**
     * Calls {@code interceptor.intercept(chain)} from {@code threadCount} new threads and waits
     * for all of them to finish.
     *
     * @throws AssertionError if a worker throws, or the batch does not finish within
     *     {@link #BATCH_TIMEOUT_SECONDS}
     */
    private void interceptConcurrently(final Interceptor interceptor, int threadCount)
            throws InterruptedException {
        final CountDownLatch finished = new CountDownLatch(threadCount);
        final java.util.concurrent.atomic.AtomicReference<Throwable> failure =
                new java.util.concurrent.atomic.AtomicReference<>();
        for (int i = 0; i < threadCount; i++) {
            new Thread(new Runnable() {
                @Override
                public void run() {
                    try {
                        interceptor.intercept(chain);
                    } catch (Throwable t) {
                        // Exceptions on worker threads are invisible to JUnit; hand the first
                        // one back to the test thread.
                        failure.compareAndSet(null, t);
                    } finally {
                        // Always count down, so a failing worker cannot hang await().
                        finished.countDown();
                    }
                }
            }).start();
        }
        if (!finished.await(BATCH_TIMEOUT_SECONDS, java.util.concurrent.TimeUnit.SECONDS)) {
            throw new AssertionError("intercept() did not finish on all " + threadCount + " threads");
        }
        Throwable firstFailure = failure.get();
        if (firstFailure != null) {
            throw new AssertionError("intercept() failed on a worker thread", firstFailure);
        }
    }

    /** Returns a mocked {@link Response} whose status code is 401. */
    private Response createUnauthorizedResponse() {
        Response response = mock(Response.class);
        when(response.code()).thenReturn(401);
        return response;
    }

    /** Returns a mocked {@link Request} whose {@code newBuilder()} yields the fake builder. */
    private Request createFakeRequest() {
        Request request = mock(Request.class);
        Request.Builder fakeBuilder = createFakeBuilder();
        when(request.newBuilder()).thenReturn(fakeBuilder);
        return request;
    }

    /** Returns a mocked builder that accepts the fresh {@code access_token} header. */
    private Request.Builder createFakeBuilder() {
        Request.Builder mockBuilder = mock(Request.Builder.class);
        when(mockBuilder.header("access_token", FRESH_ACCESS_TOKEN)).thenReturn(mockBuilder);
        return mockBuilder;
    }
}