spring boot 基于 shiro / spring security 实现自定义登录

隐身守侯 提交于 2020-04-16 18:21:26

【推荐阅读】微服务还能火多久?>>>

shiro

shiro 配置文件

/**
 * Shiro配置
 */
@Configuration
@RequiredArgsConstructor
public class ShiroConfig {

    private final ShiroService shiroService;

    private final SysUserTokenService sysUserTokenService;

    private final SysCaptchaService sysCaptchaService;

    private final Gson gson;

    /**
     * 自定义权限管理
     *
     * @see ShiroConfiguration
     * @see DefaultSecurityManager
     */
    @Bean("securityManager")
    public DefaultWebSecurityManager defaultWebSecurityManager() {
        DefaultWebSecurityManager defaultWebSecurityManager = new DefaultWebSecurityManager();
        // realms
        ArrayList<Realm> realms = new ArrayList<>();
        realms.add(userAuthenticatingRealm());
        realms.add(tokenAuthorizingRealm());
        defaultWebSecurityManager.setRealms(realms);
        DefaultSubjectDAO subjectDAO = new DefaultSubjectDAO();
        DefaultSessionStorageEvaluator defaultSessionStorageEvaluator = new DefaultSessionStorageEvaluator();
        // disable session
        defaultSessionStorageEvaluator.setSessionStorageEnabled(false);
        subjectDAO.setSessionStorageEvaluator(defaultSessionStorageEvaluator);
        defaultWebSecurityManager.setSubjectDAO(subjectDAO);
        // cache
        defaultWebSecurityManager.setCacheManager(new MemoryConstrainedCacheManager());
        return defaultWebSecurityManager;
    }

    @Bean
    public Realm userAuthenticatingRealm() {
        UserAuthenticatingRealm userAuthenticatingRealm = new UserAuthenticatingRealm();
        userAuthenticatingRealm.setShiroService(shiroService);
        userAuthenticatingRealm.setSysCaptchaService(sysCaptchaService);
        PasswordMatcher passwordMatcher = new PasswordMatcher();
        passwordMatcher.setPasswordService(passwordService());
        userAuthenticatingRealm.setCredentialsMatcher(passwordMatcher);
        return userAuthenticatingRealm;
    }

    @Bean
    public Realm tokenAuthorizingRealm() {
        TokenAuthorizingRealm tokenAuthorizingRealm = new TokenAuthorizingRealm();
        tokenAuthorizingRealm.setShiroService(shiroService);
        return tokenAuthorizingRealm;
    }

    /**
     * @see DefaultFilter
     */
    @Bean
    public ShiroFilterChainDefinition shiroFilterChainDefinition() {
        DefaultShiroFilterChainDefinition chainDefinition = new DefaultShiroFilterChainDefinition();
        chainDefinition.addPathDefinition("/sys/login", "user_authc");
        chainDefinition.addPathDefinition("/sys/logout", "anon");
        chainDefinition.addPathDefinition("/sys/**", "token_authc");
        chainDefinition.addPathDefinition("/manage/**", "anon");
        chainDefinition.addPathDefinition("/wx/**", "anon");
        chainDefinition.addPathDefinition("/**", "anon");
        return chainDefinition;
    }

    @Bean(name = "user_authc")
    public UserFilter userAuthenticatingFilter() {
        UserFilter userFilter = new UserFilter();
        userFilter.setSysUserTokenService(sysUserTokenService);
        userFilter.setGson(gson);
        return userFilter;
    }

    @Bean(name = "token_authc")
    public TokenFilter tokenAuthenticatingFilter() {
        return new TokenFilter();
    }

    /**
     * 总过滤器
     *
     * @param securityManager            { @link #defaultWebSecurityManager }
     * @param shiroFilterChainDefinition { @link #shiroFilterChainDefinition }
     */
    @Bean(name = "shiroFilterFactoryBean")
    public ShiroFilterFactoryBean shiroFilter(SecurityManager securityManager,
                                              ShiroFilterChainDefinition shiroFilterChainDefinition,
                                              Map<String, Filter> filterMap) {
        ShiroFilterFactoryBean filterFactoryBean = new ShiroFilterFactoryBean();
        filterFactoryBean.setSecurityManager(securityManager);
        filterFactoryBean.setFilterChainDefinitionMap(shiroFilterChainDefinition.getFilterChainMap());
        filterFactoryBean.setFilters(filterMap);
        return filterFactoryBean;
    }

    /**
     * 非接口使用 cglib,防止额外的 aop 导致 @RequiresPermissions 注解失效
     */
    @ConditionalOnMissingBean
    @Bean
    @DependsOn("lifecycleBeanPostProcessor")
    public DefaultAdvisorAutoProxyCreator defaultAdvisorAutoProxyCreator() {
        DefaultAdvisorAutoProxyCreator proxyCreator = new DefaultAdvisorAutoProxyCreator();
        // 非接口使用 cglib
        proxyCreator.setProxyTargetClass(true);
        return proxyCreator;
    }

    @ConditionalOnMissingBean
    @Bean
    public PasswordService passwordService() {
        return new DefaultPasswordService();
    }

}

认证

/**
 * Shiro {@link AuthenticationToken} for the username/password login.
 * <p>
 * Besides the principal (username) and credentials (password), it carries the captcha
 * text and the captcha id ({@code uuid}) so the realm can validate the captcha. It also
 * carries a {@code refresh} flag for exchanging an old token for a new one without a
 * password check.
 */
public class UserAuthentication implements AuthenticationToken {
    private static final long serialVersionUID = 1L;

    private final Object principal;

    private final Object credentials;

    /**
     * Captcha text entered by the user; {@code null} when not supplied.
     */
    private final String captcha;

    /**
     * Identifier of the captcha being answered; {@code null} when not supplied.
     */
    private final String uuid;

    /**
     * {@code true} for the token-refresh flow.
     */
    private final boolean refresh;

    /**
     * Token that requires authentication.
     */
    public UserAuthentication(Object principal, Object credentials) {
        this(principal, credentials, null, null, false);
    }

    /**
     * Token that requires authentication, with a captcha.
     */
    public UserAuthentication(Object principal, Object credentials, String captcha) {
        this(principal, credentials, captcha, null, false);
    }

    /**
     * Token that requires authentication, with a captcha and its id.
     */
    public UserAuthentication(Object principal, Object credentials, String captcha, String uuid) {
        this(principal, credentials, captcha, uuid, false);
    }

    /**
     * Token for exchanging an old token for a new one; no password check is intended,
     * so the credentials are {@code null}.
     */
    public UserAuthentication(Object principal, boolean refresh) {
        this(principal, null, null, null, refresh);
    }

    private UserAuthentication(Object principal, Object credentials, String captcha, String uuid,
                               boolean refresh) {
        this.principal = principal;
        this.credentials = credentials;
        this.captcha = captcha;
        this.uuid = uuid;
        this.refresh = refresh;
    }

    @Override
    public Object getPrincipal() {
        return principal;
    }

    @Override
    public Object getCredentials() {
        return credentials;
    }

    public String getCaptcha() {
        return captcha;
    }

    public String getUuid() {
        return uuid;
    }

    /**
     * Returns whether this token was created for the token-refresh flow. Previously the
     * flag was stored but not readable, so no realm could act on it.
     */
    public boolean isRefresh() {
        return refresh;
    }

}

filter

/**
 * Custom Shiro {@link AuthenticatingFilter} that handles the login request.
 * <p>
 * A POST to the login URL is authenticated from a JSON body ({@code username},
 * {@code password}, {@code captcha}, {@code uuid}). The outcome is written to the
 * response as JSON and the filter chain is not continued. Non-POST requests to the
 * login URL are let through, e.g. a login page.
 */
@Slf4j
public class UserFilter extends AuthenticatingFilter {

    /**
     * Shared Jackson mapper; an ObjectMapper is thread-safe once configured.
     */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    private SysUserTokenService sysUserTokenService;

    private Gson gson;

    /**
     * Builds the token checked by {@link UserAuthenticatingRealm#doGetAuthenticationInfo}.
     * <p>
     * An empty body, or missing or JSON-null fields, produce {@code null} values instead
     * of a NullPointerException inside the filter. The realm then rejects the attempt.
     */
    @Override
    protected AuthenticationToken createToken(ServletRequest request, ServletResponse response) throws Exception {
        HttpServletRequest req = (HttpServletRequest) request;
        // JSON text is UTF-8 (RFC 8259); the charset-less constructor would use the platform default.
        InputStreamReader reader = new InputStreamReader(req.getInputStream(), java.nio.charset.StandardCharsets.UTF_8);
        JsonObject jsonObject = getGson().fromJson(reader, JsonObject.class);
        if (jsonObject == null) {
            // Gson returns null for an empty body.
            jsonObject = new JsonObject();
        }
        String principal = optString(jsonObject, obtainUsernameParam());
        String credentials = optString(jsonObject, obtainPasswordParam());
        String captcha = optString(jsonObject, obtainCaptchaParam());
        String uuid = optString(jsonObject, obtainUuidParam());
        return new UserAuthentication(principal, credentials, captcha, uuid);
    }

    /**
     * Returns the string value of {@code key}, or {@code null} when absent or JSON null.
     */
    private static String optString(JsonObject json, String key) {
        if (!json.has(key) || json.get(key).isJsonNull()) {
            return null;
        }
        return json.get(key).getAsString();
    }

    /**
     * URL used to decide whether a request is a login request.
     */
    @Override
    public String getLoginUrl() {
        return SecurityConstant.loginUrl;
    }

    /**
     * Delegates to the default {@link AuthenticatingFilter} access check.
     */
    @Override
    protected boolean isAccessAllowed(ServletRequest request, ServletResponse response, Object mappedValue) {
        return super.isAccessAllowed(request, response, mappedValue);
    }

    /**
     * Performs the login for POST requests to the login URL.
     * <p>
     * The result of {@code executeLogin} is ignored on purpose: the success and failure
     * callbacks have already written the response, so the chain always stops (returns
     * {@code false}) after a login attempt.
     */
    @Override
    protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {
        if (isLoginRequest(request, response)) {
            if (isLoginSubmission(request, response)) {
                super.executeLogin(request, response);
            } else {
                // possibly the login page itself, so do not block it
                return true;
            }
        }
        return false;
    }

    /**
     * Login failure callback: writes a JSON error with a message for the known failure types.
     */
    @SuppressWarnings("Duplicates")
    @SneakyThrows
    @Override
    protected boolean onLoginFailure(AuthenticationToken token, AuthenticationException exp, ServletRequest request, ServletResponse response) {
        String errorMsg = "未知异常";
        int errorCode = 500;
        if (exp instanceof CaptchaInvalidException) {
            errorMsg = "验证码填写错误";
        } else if (exp instanceof UnknownAccountException) {
            errorMsg = "账号不存在";
        } else if (exp instanceof CredentialsException) {
            errorMsg = "密码错误";
        } else if (exp instanceof LockedAccountException) {
            errorMsg = "账号已锁定";
        } else {
            // trailing throwable argument makes SLF4J log the stack trace as well
            log.error("login fail {}", exp.getMessage(), exp);
        }
        Map<String, Object> map = new HashMap<>();
        map.put("code", errorCode);
        map.put("msg", errorMsg);
        map.put("success", false);
        writeJson((HttpServletResponse) response, map);
        // not used by onAccessDenied, which stops the chain after every login attempt
        return true;
    }

    /**
     * Login success callback: issues a token for the authenticated user and writes it as JSON.
     */
    @SuppressWarnings("Duplicates")
    @Override
    protected boolean onLoginSuccess(AuthenticationToken token, Subject subject, ServletRequest request, ServletResponse response) throws Exception {
        SysUserEntity userEntity = (SysUserEntity) subject.getPrincipal();
        Result r = getSysUserTokenService().createToken(userEntity.getUserId());
        writeJson((HttpServletResponse) response, r);
        return true;
    }

    /**
     * Writes {@code body} as a JSON response with HTTP status 200.
     */
    private static void writeJson(HttpServletResponse response, Object body) throws ServletException {
        response.setContentType("application/json");
        response.setStatus(HttpServletResponse.SC_OK);
        try {
            MAPPER.writeValue(response.getOutputStream(), body);
        } catch (Exception e) {
            // keep the cause so the underlying I/O or serialization failure is visible
            throw new ServletException(e);
        }
    }

    protected String obtainUsernameParam() {
        return "username";
    }

    protected String obtainPasswordParam() {
        return "password";
    }

    protected String obtainCaptchaParam() {
        return "captcha";
    }

    protected String obtainUuidParam() {
        return "uuid";
    }

    /**
     * A login submission is an HTTP POST.
     */
    protected boolean isLoginSubmission(ServletRequest request, ServletResponse response) {
        return (request instanceof HttpServletRequest) && WebUtils.toHttp(request).getMethod().equalsIgnoreCase(POST_METHOD);
    }

    public SysUserTokenService getSysUserTokenService() {
        Assert.notNull(this.sysUserTokenService, "sysUserTokenService is needed");
        return sysUserTokenService;
    }

    public void setSysUserTokenService(SysUserTokenService sysUserTokenService) {
        this.sysUserTokenService = sysUserTokenService;
    }

    public Gson getGson() {
        Assert.notNull(this.gson, "gson is needed");
        return gson;
    }

    public void setGson(Gson gson) {
        this.gson = gson;
    }

}

AuthenticatingRealm

/**
 * Username/password login realm.
 * <p>
 * Validates the captcha carried by {@link UserAuthentication}, then loads the user by
 * username. The password is compared afterwards by the realm's credentials matcher.
 */
public class UserAuthenticatingRealm extends AuthenticatingRealm {

    private ShiroService shiroService;

    private SysCaptchaService sysCaptchaService;

    /**
     * Returns the injected ShiroService, failing fast with a clear message when it is missing.
     */
    public ShiroService getShiroService() {
        Assert.notNull(this.shiroService, "shiroService is needed");
        return shiroService;
    }

    public void setShiroService(ShiroService shiroService) {
        this.shiroService = shiroService;
    }

    public SysCaptchaService getSysCaptchaService() {
        Assert.notNull(this.sysCaptchaService, "getSysCaptchaService is needed");
        return sysCaptchaService;
    }

    public void setSysCaptchaService(SysCaptchaService sysCaptchaService) {
        this.sysCaptchaService = sysCaptchaService;
    }

    @Override
    public boolean supports(AuthenticationToken token) {
        return token instanceof UserAuthentication;
    }

    /**
     * Authentication: checks the captcha and loads the stored user record.
     *
     * @throws CaptchaInvalidException  if the captcha does not validate
     * @throws UnknownAccountException  if no user has the given username
     * @throws LockedAccountException   if the user's status is 0
     * @see org.apache.shiro.realm.AuthenticatingRealm#getAuthenticationInfo
     * @see org.apache.shiro.realm.AuthenticatingRealm#assertCredentialsMatch
     */
    @Override
    protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken authenticationToken) throws AuthenticationException {
        // validate the captcha first
        UserAuthentication userAuthentication = (UserAuthentication) authenticationToken;
        String captcha = userAuthentication.getCaptcha();
        String uuid = userAuthentication.getUuid();
        boolean validated = getSysCaptchaService().validate(uuid, captcha);
        if (!validated) {
            throw new CaptchaInvalidException("验证码无效");
        }
        String username = (String) userAuthentication.getPrincipal();
        SysUserEntity user = getShiroService().queryByUsername(username);
        if (user == null) {
            throw new UnknownAccountException("账号不存在");
        } else if (user.getStatus() == 0) {
            throw new LockedAccountException("账号已被锁定, 请联系管理员");
        }
        // stored password; compared against the submitted one by the credentials matcher
        return new SimpleAuthenticationInfo(user, user.getPassword(), getName());
    }

}

AuthenticationException

/**
 * Authentication failure raised when the submitted captcha does not validate.
 */
public class CaptchaInvalidException extends AuthenticationException {

    private static final long serialVersionUID = 1L;

    public CaptchaInvalidException(String msg) {
        super(msg);
    }

    /**
     * Same as the single-argument constructor, but keeps the underlying cause.
     */
    public CaptchaInvalidException(String msg, Throwable cause) {
        super(msg, cause);
    }
}

鉴权与授权

/**
 * Shiro {@link AuthenticationToken} wrapping the raw access token from a request.
 * <p>
 * The same string acts as both principal and credentials; {@link TokenAuthorizingRealm}
 * is the realm that accepts this token type.
 */
public class XxAuthenticationToken implements AuthenticationToken {
    private static final long serialVersionUID = 1L;

    /**
     * Access token exactly as supplied by the client.
     */
    private final String token;

    public XxAuthenticationToken(String value) {
        this.token = value;
    }

    /**
     * Returns the raw access token.
     */
    public String getToken() {
        return this.token;
    }

    /**
     * The access token doubles as the principal.
     */
    @Override
    public Object getPrincipal() {
        return this.token;
    }

    /**
     * The access token doubles as the credentials.
     */
    @Override
    public Object getCredentials() {
        return this.token;
    }
}

filter


/**
 * Shiro filter that authenticates each request by its access token.
 * <p>
 * The token is read from the {@code token} header, falling back to the {@code token}
 * request parameter. When the token is missing or rejected, a JSON error body is
 * written and the filter chain stops.
 */
@Slf4j
public class TokenFilter extends AuthenticatingFilter {

    /**
     * Header name, also used as the fallback request-parameter name.
     */
    protected static final String AUTHORIZATION_HEADER = "token";

    /**
     * Shared Jackson mapper; an ObjectMapper is thread-safe once configured.
     */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    @Override
    protected AuthenticationToken createToken(ServletRequest request, ServletResponse response) throws Exception {
        String token = extractRequestToken(request);
        return new XxAuthenticationToken(token);
    }

    /**
     * Always denies, so every request goes through {@link #onAccessDenied}.
     * <p>
     * {@code super.isAccessAllowed} lets an already-authenticated subject through, which is
     * session-based; overriding it forces a token check per request. (Alternatively,
     * session storage can be disabled via
     * {@code DefaultSessionStorageEvaluator.setSessionStorageEnabled}.)
     */
    @Override
    protected boolean isAccessAllowed(ServletRequest request, ServletResponse response, Object mappedValue) {
        return false;
    }

    /**
     * Attempts a token login when a token is present; otherwise, or on failure, writes an
     * "invalid token" error and stops the chain.
     */
    @Override
    protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {
        boolean loggedIn = isLoginRequest(request, response) && executeLogin(request, response);
        if (!loggedIn) {
            writeInvalidToken(response);
        }
        return loggedIn;
    }

    /**
     * Writes the "invalid token" JSON error body with HTTP status 200.
     */
    private static void writeInvalidToken(ServletResponse response) throws ServletException {
        Map<String, Object> map = new HashMap<>();
        String errorMsg = "invalid token";
        int errorCode = 500;
        map.put("code", errorCode);
        map.put("msg", errorMsg);
        map.put("success", false);
        HttpServletResponse httpResponse = WebUtils.toHttp(response);
        httpResponse.setContentType("application/json");
        httpResponse.setStatus(HttpServletResponse.SC_OK);
        try {
            MAPPER.writeValue(httpResponse.getOutputStream(), map);
        } catch (Exception er) {
            // keep the cause so the underlying failure is visible
            throw new ServletException(er);
        }
    }

    /**
     * A request counts as a "login" (token check) request when it carries a non-blank token.
     */
    @Override
    protected final boolean isLoginRequest(ServletRequest request, ServletResponse response) {
        return StringUtils.isNotBlank(extractRequestToken(request));
    }

    /**
     * Reads the token from the header, falling back to the request parameter of the same name.
     */
    private String extractRequestToken(ServletRequest request) {
        HttpServletRequest httpRequest = WebUtils.toHttp(request);
        String token = httpRequest.getHeader(AUTHORIZATION_HEADER);
        if (StringUtils.isBlank(token)) {
            token = httpRequest.getParameter(AUTHORIZATION_HEADER);
        }
        return token;
    }

}

realm

/**
 * Token-based authentication and authorization realm.
 */
public class TokenAuthorizingRealm extends AuthorizingRealm {

    private ShiroService shiroService;

    public ShiroService getShiroService() {
        Assert.notNull(this.shiroService, "shiroService is needed");
        return shiroService;
    }

    public void setShiroService(ShiroService shiroService) {
        this.shiroService = shiroService;
    }

    @Override
    public boolean supports(AuthenticationToken token) {
        return token instanceof XxAuthenticationToken;
    }

    /**
     * Authentication: resolves the access token to its user.
     *
     * @throws IncorrectCredentialsException if the token is unknown or expired, or its
     *                                       user no longer exists
     * @throws LockedAccountException        if the user's status is 0
     * @see org.apache.shiro.realm.AuthenticatingRealm#getAuthenticationInfo
     * @see org.apache.shiro.realm.AuthenticatingRealm#assertCredentialsMatch
     */
    @Override
    protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken authenticationToken) throws AuthenticationException {
        String accessToken = (String) authenticationToken.getPrincipal();
        SysUserTokenEntity tokenEntity = getShiroService().queryByToken(accessToken);
        if (tokenEntity == null || tokenEntity.getExpireTime().getTime() < System.currentTimeMillis()) {
            throw new IncorrectCredentialsException("token失效,请重新登录");
        }
        SysUserEntity user = getShiroService().queryUser(tokenEntity.getUserId());
        if (user == null) {
            // user removed after the token was issued: the token is no longer valid
            throw new IncorrectCredentialsException("token失效,请重新登录");
        }
        if (user.getStatus() == 0) {
            throw new LockedAccountException("账号已被锁定,请联系管理员");
        }
        // the token is both principal key and credentials, so the default matcher compares it to itself
        return new SimpleAuthenticationInfo(user, accessToken, getName());
    }

    /**
     * Authorization: loads the user's permission strings.
     *
     * @see org.apache.shiro.realm.AuthorizingRealm#isPermitted
     */
    @Override
    protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principals) {
        SysUserEntity user = (SysUserEntity) principals.getPrimaryPrincipal();
        Long userId = user.getUserId();
        // user's permission list
        Set<String> permsSet = getShiroService().getUserPermissions(userId);
        SimpleAuthorizationInfo info = new SimpleAuthorizationInfo();
        info.setStringPermissions(permsSet);
        return info;
    }

}

 

Spring Security

spring security 配置文件

/**
 * Spring Security configuration.
 * <p>
 * The configuration:
 * <ul>
 *   <li>is stateless (no HTTP session) with CSRF disabled;</li>
 *   <li>puts the CORS filter before {@code UsernamePasswordAuthenticationFilter};</li>
 *   <li>permits a fixed list of public URLs and restricts {@code /actuator/**} to ADMIN;</li>
 *   <li>adds URL rules from {@link IPermissionService} and requires authentication for
 *       everything else.</li>
 * </ul>
 * The custom token login is applied through {@link TokenAuthenticationConfig}.
 */
@EnableWebSecurity
@AllArgsConstructor
@EnableGlobalMethodSecurity(prePostEnabled = true, securedEnabled = true)
public class SecurityConf extends WebSecurityConfigurerAdapter {

    private final TokenAuthenticationConfig tokenAuthenticationConfig;

    private final CorsFilter corsFilter;

    private final IPermissionService permissionService;

    @Override
    protected void configure(HttpSecurity http) throws Exception {
        ExpressionUrlAuthorizationConfigurer<HttpSecurity>.ExpressionInterceptUrlRegistry registry = http.csrf().disable()

                .addFilterBefore(corsFilter, UsernamePasswordAuthenticationFilter.class)
                .exceptionHandling()
                .authenticationEntryPoint(new Http401UnAuthEntryPoint())
                .accessDeniedHandler(new XxAccessDeniedHandler())

                .and()
                .headers()
                .frameOptions()
                .disable()

                // create no session
                .and()
                .sessionManagement()
                .sessionCreationPolicy(SessionCreationPolicy.STATELESS)

                .and()

                .authorizeRequests()
                .antMatchers("/doc.htm**", "/service-worker.js",
                        "/v2/api-docs", "/configuration/ui", "/swagger-resources/**", "/v2/api-docs-ext",
                        "/configuration/security", "/swagger-ui.html", "/webjars/**",
                        "/favicon.ico", "/static/**",
                        "/acc/**", "/third/**", "/public/**"
                ).permitAll();

        registry.antMatchers("/actuator/**")
                .hasRole("ADMIN");

        plusPermissions(registry);

        registry
                .anyRequest()
                .authenticated()
                .and()
                .apply(tokenAuthenticationConfig);
    }

    @Bean
    @Override
    public AuthenticationManager authenticationManagerBean() throws Exception {
        return super.authenticationManagerBean();
    }

    /**
     * Adds one access rule per permission that has URLs.
     * <p>
     * Each permission's comma-separated {@code urls} are granted to the ADMIN role or to
     * the authority {@code PERMISSION_<code>}. Entries are trimmed and empty entries
     * dropped, so {@code "/a, /b,"} yields {@code /a} and {@code /b}. Permissions are read
     * each time {@code configure} runs, not on every request.
     * <p>
     * NOTE(review): the permission code is interpolated into a SpEL expression; a code
     * containing {@code '} would break it — confirm codes are restricted.
     */
    private void plusPermissions(ExpressionUrlAuthorizationConfigurer<HttpSecurity>.ExpressionInterceptUrlRegistry registry) {
        List<Permission> permissions = permissionService.list();
        if (CollectionUtil.isNotEmpty(permissions)) {
            permissions.forEach(permission -> {
                String urls = permission.getUrls();
                if (StringUtil.isNotBlank(urls)) {
                    String[] urlArr = java.util.Arrays.stream(urls.split(","))
                            .map(String::trim)
                            .filter(url -> !url.isEmpty())
                            .toArray(String[]::new);
                    String curPerm = "PERMISSION_" + permission.getCode();
                    String access = String.format("hasRole('%s') or hasAuthority('%s')", RoleName.ADMIN.getName(), curPerm);
                    if (urlArr.length > 0) {
                        registry.antMatchers(urlArr)
                                .access(access);
                    }
                }
            });
        }
    }
}

认证

/**
 * Spring Security authentication object for the custom token login.
 * <p>
 * Follows the same pattern as Spring's {@code UsernamePasswordAuthenticationToken}:
 * <ul>
 *   <li>the two-argument constructor creates an unauthenticated request;</li>
 *   <li>the three-argument constructor creates an authenticated result;</li>
 *   <li>{@link #setAuthenticated(boolean)} refuses to mark an instance as trusted
 *       afterwards.</li>
 * </ul>
 */
public class TokenAuthentication extends AbstractAuthenticationToken {

    private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID;

    private final Object principal;

    private Object credentials;

    /**
     * Creates a not-yet-verified authentication request.
     */
    public TokenAuthentication(Object principal, Object credentials) {
        super(null);
        this.principal = principal;
        this.credentials = credentials;
        setAuthenticated(false);
    }

    /**
     * Creates a verified authentication carrying the granted authorities.
     */
    public TokenAuthentication(Object principal, Object credentials,
                               Collection<? extends GrantedAuthority> authorities) {
        super(authorities);
        this.principal = principal;
        this.credentials = credentials;
        // bypass our own override, which rejects true
        super.setAuthenticated(true);
    }

    @Override
    public Object getPrincipal() {
        return this.principal;
    }

    @Override
    public Object getCredentials() {
        return this.credentials;
    }

    /**
     * Only allows clearing the authenticated flag; use the authorities constructor to
     * create a trusted instance.
     */
    @Override
    public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException {
        if (!isAuthenticated) {
            super.setAuthenticated(false);
            return;
        }
        throw new IllegalArgumentException(
                "Cannot set this token to trusted - use constructor which takes a GrantedAuthority list instead");
    }

    /**
     * Clears the credentials after the superclass has erased its own secrets.
     */
    @Override
    public void eraseCredentials() {
        super.eraseCredentials();
        this.credentials = null;
    }
}

filter

/**
 * Login filter for the custom token authentication.
 * <ul>
 *   <li>{@link #doFilter} runs the pre-checks (parameters, captcha) for requests matching
 *       {@code POST SecurityConstant.LOGIN_URL}.</li>
 *   <li>{@link #attemptAuthentication} builds a {@link TokenAuthentication} from the
 *       {@code username} / {@code password} request parameters and passes it to the
 *       authentication manager.</li>
 * </ul>
 */
@Slf4j
public class TokenAuthenticationFilter extends AbstractAuthenticationProcessingFilter {

    private TokenService tokenService;

    public TokenAuthenticationFilter() {
        super(new AntPathRequestMatcher(SecurityConstant.LOGIN_URL, "POST"));
    }

    public void setTokenService(TokenService tokenService) {
        this.tokenService = tokenService;
    }

    public TokenService getTokenService() {
        Assert.notNull(this.tokenService, "tokenService is needed");
        return this.tokenService;
    }

    /**
     * Pre-authentication.
     * <ul>
     *   <li>Requests that do not match the login URL pass straight through.</li>
     *   <li>Otherwise the parameters and captcha (SMS code or Aliyun human verification)
     *       are checked via {@link TokenService}, and a failure is reported through the
     *       failure handler.</li>
     * </ul>
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse res = (HttpServletResponse) response;
        // not a login request (LOGIN_URL): continue the chain
        if (!requiresAuthentication(req, res)) {
            chain.doFilter(request, response);
            return;
        }
        try {
            // validate parameters and captcha
            getTokenService().checkParams(req, res);
        } catch (AuthenticationException error) {
            unsuccessfulAuthentication(req, res,  error);
            return;
        }
        super.doFilter(req, res, chain);
    }

    /**
     * Returns the token authenticated by the provider ({@link TokenAuthentication}).
     *
     * @throws AuthenticationServiceException for non-POST requests
     */
    @Override
    public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException, IOException, ServletException {
        // reject non-POST requests
        if (!request.getMethod().equals(HttpMethod.POST.name())) {
            throw new AuthenticationServiceException(
                    "Authentication method not supported: " + request.getMethod());
        }
        TokenAuthentication tokenAuthentication = new TokenAuthentication(obtainUsername(request), obtainPassword(request));
        return this.getAuthenticationManager().authenticate(tokenAuthentication);
    }

    /**
     * Returns the trimmed {@code username} parameter, or {@code null} when absent.
     * <p>
     * Trimming makes {@code " alice "} and {@code "alice"} the same principal, so the
     * blacklist counter in TokenAuthenticationProvider (keyed by principal) cannot be
     * sidestepped with whitespace.
     */
    @Nullable
    protected String obtainUsername(HttpServletRequest request) {
        String username = request.getParameter("username");
        return username == null ? null : username.trim();
    }

    @Nullable
    protected String obtainPassword(HttpServletRequest request) {
        return request.getParameter("password");
    }

}

provider

/**
 * Authentication provider for {@link TokenAuthentication} requests.
 * <p>
 * The flow:
 * <ol>
 *   <li>Reject usernames on the blacklist ({@link BlackTool}) before any lookup.</li>
 *   <li>Load the user via {@link UserDetailsService} and run the account status
 *       checks.</li>
 *   <li>Verify the password with {@link PasswordEncoder}. Failures increment the
 *       username's blacklist counter; a success resets it.</li>
 *   <li>Return an authenticated {@link TokenAuthentication}.</li>
 * </ol>
 * <p>
 * NOTE(review): unlike Spring's {@code AbstractUserDetailsAuthenticationProvider}, there
 * is no post-authentication {@code isCredentialsNonExpired()} check — confirm intended.
 */
public class TokenAuthenticationProvider implements AuthenticationProvider {

    protected final Log logger = LogFactory.getLog(getClass());

    private UserDetailsService userDetailsService;

    private PasswordEncoder passwordEncoder;

    private BlackTool blackTool;

    private static final String USER_NOT_FOUND_PASSWORD = "userNotFoundPassword";

    /**
     * Username used for lookup and blacklisting when the request carries no principal.
     */
    private static final String NONE_PROVIDED = "NONE_PROVIDED";

    /**
     * Lazily encoded dummy password, used to equalise timing for unknown usernames.
     */
    private volatile String userNotFoundEncodedPassword;

    private UserDetailsChecker preAuthenticationChecks = new TokenAuthenticationProvider.DefaultPreAuthenticationChecks();

    /**
     * Always a NullUserCache: there is no setter, so every lookup goes to the UserDetailsService.
     */
    private UserCache userCache = new NullUserCache();

    public void setUserDetailsService(UserDetailsService userDetailsService) {
        this.userDetailsService = userDetailsService;
    }

    public void setPasswordEncoder(PasswordEncoder passwordEncoder) {
        this.passwordEncoder = passwordEncoder;
    }

    public void setBlackTool(BlackTool blackTool) {
        this.blackTool = blackTool;
    }

    public UserDetailsService getUserDetailsService() {
        Assert.notNull(this.userDetailsService, "userDetailsService could not be null");
        return userDetailsService;
    }

    public PasswordEncoder getPasswordEncoder() {
        Assert.notNull(this.passwordEncoder, "passwordEncoder could not be null");
        return passwordEncoder;
    }

    public BlackTool getBlackTool() {
        Assert.notNull(this.blackTool, "blackTool could not be null");
        return blackTool;
    }

    /**
     * i18n messages.
     */
    protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor();

    @Override
    public Authentication authenticate(Authentication authentication) throws AuthenticationException {
        // only TokenAuthentication is handled
        Assert.isInstanceOf(TokenAuthentication.class, authentication,
                messages.getMessage(
                        "TokenAuthenticationProvider.onlySupports",
                        "Only TokenAuthentication is supported"));
        TokenAuthentication tokenAuthentication = (TokenAuthentication) authentication;

        // username used both for the user lookup and as the blacklist key
        String username = resolveUsername(tokenAuthentication);

        // reject blacklisted usernames up front
        checkBlackList(username);

        // try the cache first
        boolean cacheWasUsed = true;
        UserDetails user = this.userCache.getUserFromCache(username);

        // not cached: load from the UserDetailsService
        if (user == null) {
            cacheWasUsed = false;
            try {
                user = this.retrieveUser(username, tokenAuthentication);
            } catch (UsernameNotFoundException notFound) {
                logger.debug("User username '" + username + "' not found");
                throw notFound;
            }

            Assert.notNull(user,
                    "retrieveUser returned null - a violation of the interface contract");
        }

        // account status and password checks
        try {
            preAuthenticationChecks.check(user);
            checkPassword(user, tokenAuthentication);
        } catch (AuthenticationException exception) {
            if (cacheWasUsed) {
                // The cached copy may be stale: retry once with fresh data.
                // NOTE(review): checkPassword() counts each failure, so a wrong password on a
                // cached user would be counted twice; unreachable while userCache is a NullUserCache.
                cacheWasUsed = false;
                user = retrieveUser(username, tokenAuthentication);
                preAuthenticationChecks.check(user);
                checkPassword(user, tokenAuthentication);
            } else {
                throw exception;
            }
        }

        // cache freshly loaded users
        if (!cacheWasUsed) {
            this.userCache.putUserInCache(user);
        }
        TokenAuthentication authenticationToken = new TokenAuthentication(user, null, user.getAuthorities());
        authenticationToken.setDetails(user);
        return authenticationToken;
    }

    /**
     * This provider only handles {@link TokenAuthentication}.
     */
    @Override
    public boolean supports(Class<?> authentication) {
        return TokenAuthentication.class.isAssignableFrom(authentication);
    }

    /**
     * Returns the username/blacklist key for a request; {@code NONE_PROVIDED} when there is
     * no principal.
     */
    private static String resolveUsername(Authentication authentication) {
        return (authentication.getPrincipal() == null) ? NONE_PROVIDED : authentication.getName();
    }

    /**
     * Rejects locked, disabled and expired accounts.
     */
    private class DefaultPreAuthenticationChecks implements UserDetailsChecker {
        public void check(UserDetails user) {
            if (!user.isAccountNonLocked()) {
                logger.debug("User account is locked");

                throw new LockedException(messages.getMessage(
                        "AbstractUserDetailsAuthenticationProvider.locked",
                        "User account is locked"));
            }

            if (!user.isEnabled()) {
                logger.debug("User account is disabled");

                throw new DisabledException(messages.getMessage(
                        "AbstractUserDetailsAuthenticationProvider.disabled",
                        "User is disabled"));
            }

            if (!user.isAccountNonExpired()) {
                logger.debug("User account is expired");

                throw new AccountExpiredException(messages.getMessage(
                        "AbstractUserDetailsAuthenticationProvider.expired",
                        "User account has expired"));
            }
        }
    }

    /**
     * Rejects a username on the blacklist (too many consecutive wrong passwords).
     */
    public void checkBlackList(String blackKey) {
        if (getBlackTool().existInBlackList(blackKey)) {
            // blacklisted: deny
            throw new AuthFrequentFailException("您已连续5次密码输入错误,请15分钟后再试");
        }
    }

    /**
     * Verifies the presented password.
     * <p>
     * A mismatch increments the blacklist counter for the username; a match resets it.
     *
     * @throws BadCredentialsException if no password was given or it does not match
     */
    public void checkPassword(UserDetails userDetails,
                              TokenAuthentication authentication) throws AuthenticationException {
        if (authentication.getCredentials() == null) {
            logger.debug("Authentication failed: no credentials provided");
            throw new BadCredentialsException(messages.getMessage(
                    "AbstractUserDetailsAuthenticationProvider.badCredentials",
                    "Bad credentials"));
        }

        // same key as checkBlackList() receives; null-safe unlike getPrincipal().toString()
        String blackKey = resolveUsername(authentication);

        String presentedPassword = authentication.getCredentials().toString();

        if (!getPasswordEncoder().matches(presentedPassword, userDetails.getPassword())) {
            logger.debug("Authentication failed: password does not match stored value");
            // count the failure
            getBlackTool().incr(blackKey);
            throw new BadCredentialsException(messages.getMessage(
                    "AbstractUserDetailsAuthenticationProvider.badCredentials",
                    "Bad credentials"));
        }
        // success: reset the counter
        getBlackTool().reset(blackKey);
    }

    /**
     * Loads the user. Unknown usernames still pay for one password-hash comparison, so
     * timing does not reveal which usernames exist.
     */
    protected final UserDetails retrieveUser(String username,
                                             TokenAuthentication authentication)
            throws AuthenticationException {
        prepareTimingAttackProtection();
        try {
            UserDetails loadedUser = this.getUserDetailsService().loadUserByUsername(username);
            if (loadedUser == null) {
                throw new InternalAuthenticationServiceException(
                        "UserDetailsService returned null, which is an interface contract violation");
            }
            return loadedUser;
        } catch (UsernameNotFoundException ex) {
            mitigateAgainstTimingAttack(authentication);
            throw ex;
        } catch (InternalAuthenticationServiceException ex) {
            throw ex;
        } catch (Exception ex) {
            throw new InternalAuthenticationServiceException(ex.getMessage(), ex);
        }
    }

    private void prepareTimingAttackProtection() {
        if (this.userNotFoundEncodedPassword == null) {
            this.userNotFoundEncodedPassword = getPasswordEncoder().encode(USER_NOT_FOUND_PASSWORD);
        }
    }

    private void mitigateAgainstTimingAttack(TokenAuthentication authentication) {
        if (authentication.getCredentials() != null) {
            String presentedPassword = authentication.getCredentials().toString();
            getPasswordEncoder().matches(presentedPassword, this.userNotFoundEncodedPassword);
        }
    }
}

TokenAuthenticationConfig

/**
 * Plugs the token-based login flow into an {@link HttpSecurity} filter chain.
 *
 * <p>It registers a {@code TokenAuthenticationProvider} and inserts two filters around
 * {@code UsernamePasswordAuthenticationFilter}:
 * <ul>
 *   <li>{@code TokenCheckFilter} before it, to validate an incoming token;</li>
 *   <li>{@code TokenAuthenticationFilter} after it, to perform the login and hand the result
 *       to the success/failure handlers.</li>
 * </ul>
 *
 * <p>Dependencies are injected through the Lombok-generated all-args constructor and never
 * reassigned, so they are declared {@code final}.
 */
@Component
@AllArgsConstructor
public class TokenAuthenticationConfig extends SecurityConfigurerAdapter<DefaultSecurityFilterChain, HttpSecurity> {

    private final TokenService tokenService;

    private final PasswordEncoder passwordEncoder;

    private final UserDetailsService userDetailsService;

    private final BlackTool blackTool;

    @Override
    public void configure(HttpSecurity http) {
        // Filter that validates an existing token on each request.
        TokenCheckFilter checkFilter = new TokenCheckFilter(tokenService);

        // Handlers that write the JSON login result.
        TokenFailHandler failHandler = new TokenFailHandler();
        TokenSuccessHandler successHandler = new TokenSuccessHandler(tokenService);

        // Login filter. NOTE(review): this assumes the shared AuthenticationManager is already
        // available when configure() runs, i.e. the configurer is applied via http.apply(...).
        // Confirm against the security config.
        TokenAuthenticationFilter tokenFilter = new TokenAuthenticationFilter();
        tokenFilter.setTokenService(tokenService);
        tokenFilter.setAuthenticationManager(http.getSharedObject(AuthenticationManager.class));
        tokenFilter.setAuthenticationSuccessHandler(successHandler);
        tokenFilter.setAuthenticationFailureHandler(failHandler);

        // Provider that checks the password and updates the failure blacklist.
        TokenAuthenticationProvider tokenProvider = new TokenAuthenticationProvider();
        tokenProvider.setPasswordEncoder(passwordEncoder);
        tokenProvider.setUserDetailsService(userDetailsService);
        tokenProvider.setBlackTool(blackTool);

        http.authenticationProvider(tokenProvider)
                .addFilterBefore(checkFilter, UsernamePasswordAuthenticationFilter.class)
                .addFilterAfter(tokenFilter, UsernamePasswordAuthenticationFilter.class);
    }
}

401:未认证时的入口点(AuthenticationEntryPoint)

/**
 * Authentication entry point for a REST API.
 *
 * <p>There is no login page to redirect to, so an unauthenticated request for a protected
 * resource gets a plain 401 error whose reason is the authentication exception's message.
 */
public class Http401UnAuthEntryPoint implements AuthenticationEntryPoint {

    @Override
    public void commence(HttpServletRequest request,
                         HttpServletResponse response,
                         AuthenticationException authException) throws IOException {
        final int status = HttpServletResponse.SC_UNAUTHORIZED;
        final String reason = authException.getMessage();
        response.sendError(status, reason);
    }
}

403:权限不足时的处理器(AccessDeniedHandler)

@Slf4j
public class XxAccessDeniedHandler implements AccessDeniedHandler {

    @Override
    public void handle(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, AccessDeniedException e) throws IOException, ServletException {
        String requestURI = httpServletRequest.getRequestURI();
        log.error("access {} wad denied.", requestURI, e);
        httpServletResponse.setStatus(HttpServletResponse.SC_FORBIDDEN);
        httpServletResponse.getWriter().write("access denied");
    }
}

fail handler(登录失败处理器,AuthenticationFailureHandler)

/**
 * Failure handler for the token login endpoint.
 *
 * <p>It maps the {@link AuthenticationException} to a business error code and message, then
 * writes them as JSON ({@code code}, {@code msg}, {@code success=false}). The HTTP status is
 * always 200; the outcome is carried by the JSON {@code code}.
 */
@Slf4j
public class TokenFailHandler implements AuthenticationFailureHandler {

    /** Shared mapper: ObjectMapper is thread-safe for serialization and costly to create per call. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    @Override
    public void onAuthenticationFailure(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, AuthenticationException exp) throws IOException, ServletException {
        Map<String, Object> map = new HashMap<>();
        String errorMsg = "未知异常";
        int errorCode = 400;
        if (exp instanceof UsernameNotFoundException) {
            errorMsg = "账号不存在";
        } else if (exp instanceof BadCredentialsException) {
            errorMsg = "账号或密码错误";
            errorCode = 401;
        } else if (exp instanceof AuthParamFormatException || exp instanceof AuthEmptyException) {
            // Presumably these custom exceptions carry a displayable message; verify at throw sites.
            errorMsg = exp.getMessage();
        } else if (exp instanceof AuthFrequentFailException) {
            errorMsg = exp.getMessage();
            errorCode = 403;
        } else if (exp instanceof AuthAfsFailException) {
            errorMsg = exp.getMessage();
            errorCode = AfsConstant.FAIL_CODE;
        } else if (exp instanceof DisabledException) {
            errorMsg = "账号已被禁用";
        } else if (exp instanceof AccountExpiredException) {
            errorMsg = "账号过期";
        } else if (exp instanceof LockedException) {
            errorMsg = "该账号已被锁定";
        } else if (exp instanceof InsufficientAuthenticationException) {
            errorMsg = "验证失败";
        } else {
            // Unexpected failure type: keep the stack trace for diagnosis.
            log.error("auth error", exp);
        }
        map.put("code", errorCode);
        map.put("msg", errorMsg);
        map.put("success", false);
        httpServletResponse.setContentType("application/json");
        httpServletResponse.setStatus(HttpServletResponse.SC_OK);
        try {
            MAPPER.writeValue(httpServletResponse.getOutputStream(), map);
        } catch (Exception er) {
            // Preserve the cause; a bare ServletException hides why the response failed.
            throw new ServletException("failed to write authentication failure response", er);
        }
    }
}

success handler(登录成功处理器,AuthenticationSuccessHandler)

/**
 * Success handler for the token login endpoint.
 *
 * <p>It generates a token for the authenticated principal. If user details can be extracted,
 * it publishes a {@code LoginOkEvent} with the user id and client IP. It then writes
 * {@code {code:200, success:true, msg, data:<token>}} as JSON with HTTP status 200.
 */
public class TokenSuccessHandler  implements AuthenticationSuccessHandler {

    /** Shared mapper: ObjectMapper is thread-safe for serialization and costly to create per call. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    private final TokenService tokenService;

    /**
     * @param tokenService service used to issue the token returned to the client
     */
    public TokenSuccessHandler(TokenService tokenService) {
        this.tokenService = tokenService;
    }

    @Override
    public void onAuthenticationSuccess(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, Authentication authentication) throws IOException, ServletException {
        Map<String, Object> map = new HashMap<>();
        map.put("code", 200);
        map.put("success", true);
        map.put("msg", "成功");
        XxRequestToken xxRequestToken = tokenService.generateToken(authentication);
        map.put("data", xxRequestToken);
        httpServletResponse.setContentType("application/json");
        httpServletResponse.setStatus(HttpServletResponse.SC_OK);
        try {
            XxUserDetails userDetails = SecurityUtil.getUserDetails(authentication);
            if (userDetails != null) {
                String ipAddr = IPUtil.getIpAddr(httpServletRequest);
                SpringUtil.publishEvent(new LoginOkEvent(new LoginInfo(userDetails.getId(), ipAddr)));
            }
            MAPPER.writeValue(httpServletResponse.getOutputStream(), map);
        } catch (Exception e) {
            // Preserve the cause; a bare ServletException hides whether the event or the write failed.
            throw new ServletException("failed to complete login success response", e);
        }
    }

}

鉴权与授权

/**
 * Token check filter.
 *
 * <p>It reads a {@code Bearer} token from the {@code Authorization} header. When the token
 * service validates it and yields an {@link Authentication}, that authentication is stored in
 * the {@link SecurityContextHolder}.
 *
 * <p>Requests whose URI starts with {@code SecurityConstant.LOGOUT_URL} are answered directly:
 * the token is taken offline and a JSON result is written. Such requests do not continue down
 * the filter chain.
 */
@Slf4j
public class TokenCheckFilter extends GenericFilterBean {
    private static final String AUTHORIZATION_HEADER = "Authorization";

    private static final String BEARER_PREFIX = "Bearer ";

    /** Shared mapper: ObjectMapper is thread-safe for serialization and costly to create per call. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    private final TokenService tokenService;

    /**
     * @param tokenService service that validates tokens, resolves them to authentications,
     *                     and takes them offline on logout
     */
    public TokenCheckFilter(TokenService tokenService) {
        this.tokenService = tokenService;
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) servletRequest;
        HttpServletResponse res = (HttpServletResponse) servletResponse;
        String tokenKey = resolveToken(req);
        String requestURI = req.getRequestURI();
        boolean hasAuth = false;
        if (StringUtils.hasText(tokenKey) && tokenService.validateToken(tokenKey)) {
            Authentication authentication = tokenService.getAuthentication(tokenKey);
            if (authentication != null) {
                // The principal can be cast to XxUserDetails, letting SecurityUtil extract permissions.
                SecurityContextHolder.getContext().setAuthentication(authentication);
                hasAuth = true;
            }
        }
        if (requestURI.startsWith(SecurityConstant.LOGOUT_URL)) {
            // Logout is terminal: the response is written here and the chain is not continued.
            logout(hasAuth, tokenKey, res);
            return;
        }
        filterChain.doFilter(servletRequest, servletResponse);
    }

    /**
     * Extracts the raw token from an {@code Authorization: Bearer <token>} header.
     *
     * @return the token text, or {@code null} if the header is missing, blank, or not a Bearer header
     */
    private String resolveToken(HttpServletRequest request) {
        String bearerToken = request.getHeader(AUTHORIZATION_HEADER);
        if (StringUtils.hasText(bearerToken) && bearerToken.startsWith(BEARER_PREFIX)) {
            return bearerToken.substring(BEARER_PREFIX.length());
        }
        return null;
    }

    /**
     * Takes the token offline (only if the request was authenticated) and writes the JSON result.
     *
     * <p>HTTP status is always 200; success or failure is carried by the JSON {@code code}.
     */
    private void logout(boolean isAuthenticated, String tokenKey, HttpServletResponse response) throws ServletException {
        Map<String, Object> map = new HashMap<>();
        if (isAuthenticated && tokenService.offline(tokenKey)) {
            map.put("code", 200);
            map.put("success", true);
            map.put("msg", "注销成功");
        } else {
            map.put("code", 400);
            map.put("success", false);
            map.put("msg", "注销失败/已注销");
        }
        response.setContentType("application/json");
        response.setStatus(HttpServletResponse.SC_OK);
        try {
            MAPPER.writeValue(response.getOutputStream(), map);
        } catch (Exception e) {
            // Preserve the cause; a bare ServletException hides why the response failed.
            throw new ServletException("failed to write logout response", e);
        }
    }

}

 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!