Android OkHttp, refresh expired token

后端 未结 5 1140
我寻月下人不归
我寻月下人不归 2021-01-30 18:14

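Scenario: I am using OkHttp / Retrofit to access a web service, and multiple HTTP requests are sent out at the same time. At some point the auth token expires, and several of these concurrent requests receive a 401 Unauthorized response. How can the token be refreshed only once and the failed requests retried?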

5条回答
  •  不思量自难忘°
    2021-01-30 18:52

    I had the same problem and I managed to solve it using a ReentrantLock.

    import java.io.IOException;
    import java.net.HttpURLConnection;
    import java.util.concurrent.locks.Lock;
    import java.util.concurrent.locks.ReentrantLock;
    
    import okhttp3.Interceptor;
    import okhttp3.Request;
    import okhttp3.Response;
    import timber.log.Timber;
    
    /**
     * OkHttp interceptor that refreshes an expired access token when a response comes back with
     * HTTP 401 (Unauthorized) and retries the request once with the refreshed token.
     *
     * <p>When several requests receive a 401 concurrently, only the thread that wins
     * {@link Lock#tryLock()} calls {@code authenticationService.refreshTokenSync()}. The other
     * threads block on the same lock until that refresh has finished. They then retry with
     * whatever token {@code refreshTokenStorage} returns at that point.
     */
    public class RefreshTokenInterceptor implements Interceptor {

        /** Serializes token refreshes; the holder is the single thread currently refreshing. */
        private final Lock lock = new ReentrantLock();

        private final AuthenticationService authenticationService;
        private final RefreshTokenStorage refreshTokenStorage;

        /**
         * Creates the interceptor.
         *
         * @param authenticationService performs the synchronous token refresh
         * @param refreshTokenStorage   supplies the access token attached to retried requests
         */
        public RefreshTokenInterceptor(AuthenticationService authenticationService,
                                       RefreshTokenStorage refreshTokenStorage) {
            this.authenticationService = authenticationService;
            this.refreshTokenStorage = refreshTokenStorage;
        }

        /**
         * Proceeds with the request. On a 401, it refreshes the token (or waits for another
         * thread's refresh to finish) and then retries the request once with the stored token.
         *
         * @param chain the OkHttp interceptor chain
         * @return the retried response; or the original response when it was not a 401, or when
         *     the refresh failed with {@code ServiceException}
         * @throws IOException if the original call or the retry fails
         */
        @Override
        public Response intercept(Interceptor.Chain chain) throws IOException {
            Request request = chain.request();
            Response response = chain.proceed(request);

            if (response.code() != HttpURLConnection.HTTP_UNAUTHORIZED) {
                return response;
            }

            // The first thread to see a 401 acquires the lock and performs the refresh.
            if (lock.tryLock()) {
                Timber.i("refresh token thread holds the lock");
                try {
                    // Synchronous refresh; it is expected to save the new token for later use.
                    authenticationService.refreshTokenSync();
                    // The 401 is being replaced by the retry, so release its body/connection first.
                    // OkHttp (3.14+) refuses a new request while the previous response is open.
                    response.close();
                    return chain.proceed(recreateRequestWithNewAccessToken(chain));
                } catch (ServiceException exception) {
                    // The refresh failed: hand the (still open) 401 back to the caller. The caller
                    // can log the user out or handle the error in its onFailure callback.
                    return response;
                } finally {
                    Timber.i("refresh token finished. release lock");
                    lock.unlock();
                }
            }

            Timber.i("wait for token to be refreshed");
            // Blocks until the refreshing thread calls unlock(); the lock only acts as a barrier here.
            lock.lock();
            lock.unlock();
            Timber.i("token refreshed. retry request");
            // Discard the stale 401 before issuing the retry.
            response.close();
            return chain.proceed(recreateRequestWithNewAccessToken(chain));
        }

        /** Rebuilds the current request, putting the latest stored token in the access_token header. */
        private Request recreateRequestWithNewAccessToken(Chain chain) {
            String freshAccessToken = refreshTokenStorage.getAccessToken();
            // Do not log the token value itself: it is a credential.
            Timber.d("retrying request with refreshed access token");
            return chain.request().newBuilder()
                    .header("access_token", freshAccessToken)
                    .build();
        }
    }
    

    The main advantage of this solution is that you can write a unit test for it with Mockito. You will have to enable Mockito's incubating "mock-maker-inline" feature to mock final classes (such as OkHttp's Response); see Mockito's documentation on mocking final classes for details. The test looks something like this:

    /**
     * Verifies that concurrent 401 responses trigger exactly one token refresh per batch of
     * overlapping requests, and that a later batch triggers a new refresh.
     */
    @RunWith(MockitoJUnitRunner.class)
    public class RefreshTokenInterceptorTest {

        private static final String FRESH_ACCESS_TOKEN = "fresh_access_token";

        /** Upper bound on how long one batch of concurrent requests may take before the test fails. */
        private static final long BATCH_TIMEOUT_SECONDS = 5;

        @Mock
        AuthenticationService authenticationService;

        @Mock
        RefreshTokenStorage refreshTokenStorage;

        @Mock
        Interceptor.Chain chain;

        @BeforeClass
        public static void setup() {
            // Route Timber output to stdout, tagged with the thread, so interleavings are visible.
            Timber.plant(new Timber.DebugTree() {

                @Override
                protected void log(int priority, String tag, String message, Throwable t) {
                    System.out.println(Thread.currentThread() + " " + message);
                }
            });
        }

        @Test
        public void refreshTokenInterceptor_works_as_expected() throws IOException, InterruptedException {

            Response unauthorizedResponse = createUnauthorizedResponse();
            when(chain.proceed((Request) any())).thenReturn(unauthorizedResponse);
            when(authenticationService.refreshTokenSync()).thenAnswer(new Answer<Boolean>() {
                @Override
                public Boolean answer(InvocationOnMock invocation) throws Throwable {
                    // The refresh takes some time, so concurrent requests pile up on the lock.
                    Thread.sleep(10);
                    return true;
                }
            });
            when(refreshTokenStorage.getAccessToken()).thenReturn(FRESH_ACCESS_TOKEN);
            Request fakeRequest = createFakeRequest();
            when(chain.request()).thenReturn(fakeRequest);

            final Interceptor interceptor = new RefreshTokenInterceptor(authenticationService, refreshTokenStorage);

            Timber.d("5 requests try to refresh token at the same time");
            interceptConcurrently(interceptor, 5);
            verify(authenticationService, times(1)).refreshTokenSync();

            Timber.d("next time another 3 threads try to refresh the token at the same time");
            interceptConcurrently(interceptor, 3);
            verify(authenticationService, times(2)).refreshTokenSync();

            Timber.d("1 thread tries to refresh the token");
            interceptor.intercept(chain);
            verify(authenticationService, times(3)).refreshTokenSync();
        }

        /**
         * Runs {@code interceptor.intercept(chain)} on {@code threadCount} threads released at the
         * same moment. Waits (bounded) for all of them, then rethrows the first worker failure on
         * the test thread.
         */
        private void interceptConcurrently(final Interceptor interceptor, int threadCount)
                throws InterruptedException {
            final CountDownLatch startGate = new CountDownLatch(1);
            final CountDownLatch done = new CountDownLatch(threadCount);
            final java.util.concurrent.atomic.AtomicReference<Throwable> failure =
                    new java.util.concurrent.atomic.AtomicReference<Throwable>();
            for (int i = 0; i < threadCount; i++) {
                new Thread(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            // Hold every worker until all are started, to maximize overlap.
                            startGate.await();
                            interceptor.intercept(chain);
                        } catch (Throwable t) {
                            failure.compareAndSet(null, t);
                        } finally {
                            // Always count down, so a failing worker cannot hang the test.
                            done.countDown();
                        }
                    }
                }).start();
            }
            startGate.countDown();
            if (!done.await(BATCH_TIMEOUT_SECONDS, java.util.concurrent.TimeUnit.SECONDS)) {
                throw new AssertionError("concurrent intercept calls did not finish in time");
            }
            Throwable firstFailure = failure.get();
            if (firstFailure != null) {
                throw new AssertionError("intercept failed on a worker thread", firstFailure);
            }
        }

        private Response createUnauthorizedResponse() throws IOException {
            Response response = mock(Response.class);
            when(response.code()).thenReturn(401);
            return response;
        }

        private Request createFakeRequest() {
            Request request = mock(Request.class);
            Request.Builder fakeBuilder = createFakeBuilder();
            when(request.newBuilder()).thenReturn(fakeBuilder);
            return request;
        }

        private Request.Builder createFakeBuilder() {
            Request.Builder mockBuilder = mock(Request.Builder.class);
            when(mockBuilder.header("access_token", FRESH_ACCESS_TOKEN)).thenReturn(mockBuilder);
            return mockBuilder;
        }

    }
    

提交回复
热议问题