基于自定义注解和SpringEL表达式的分布式锁实现

无人久伴 提交于 2020-08-17 21:43:40

需求

1、项目中不可避免的需要使用分布式保证幂等。所以一个简单可靠,易用的工具提上日程。

2、演进过程 

  • 最开始使用try finally 块实现。代码臃肿。还要时刻记得释放。
  • 改用回调方式封装锁的获取和释放,但是依然臃肿,需要实现成功和获取锁失败的回调方法。然而获取锁失败几乎都做的一样的事。
  • 使用注解,代价就是使用范围是整个方法。需要自己确认好了使用范围。另外第一版不支持Spring EL。想使用参数值做锁实在太麻烦。
  • 改进注解,使用spring EL引擎。提供强大的数据获取功能。并且对返回值使用调用静态方法和创建新对象十分友好。
  • 我们并没有直接使用spirng EL的所有语法。而是选择包装了一下,因为大家对Spring EL认识参差不齐。

 

demo:

@LockMethod(
        value = {
                @ExtractParam(paramName = "accountInfo", fieldName = "accountId"),
                @ExtractParam(paramName = "order", fieldName = "id"),
                @ExtractParam(paramName = "uid")
        }
        , formatter = "lockTest:%s:%s:%s"
        , failureExpression = "new java.util.ArrayList()")
public Object lockTest(TAccountInfo accountInfo, TPayOrder order, Long uid) {
    return CommonResultEntity.success();
}

代码:

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Repeatable(value = LockMethod.class)
public @interface ExtractParam {

    /**
     * 作为锁的参数名称
     */
    String paramName();

    /**
     * 参数的 属性值。可以为空
     */
    String fieldName() default "";
}
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface LockMethod {

    /**
     * 提取的参数
     */
    ExtractParam[] value();
    /**
     * 自定义
     */
    String formatter() default "";

    /**
     * 失败后返回类型
     */
    Class<?> failureType() default void.class;

    /**
     * 失败返回 表达式
     */
    String failureExpression() default "";

}
@Aspect
@Component
@Slf4j
public class LockInterceptor {

    /**
     * spring 参数名称解析器
     */
    private static final ParameterNameDiscoverer LOCAL_VARIABLE_TABLE_PARAMETER_NAME_DISCOVERER
            = new LocalVariableTableParameterNameDiscoverer();
    /**
     * spring el 表达式解析解
     */
    private static final ExpressionParser SPEL_EXPRESSION_PARSER = new SpelExpressionParser();

    /**
     * Elvis运算符 在一些编程语言中(比如C#、Kotlin等)提供该功能,语法是?:。意义是当某变量不为空的时候使用该变量,当该变量为空的时候使用指定的默认值。
     */
    @Around("@annotation(com.xxx.xxx.support.LockMethod)")
    public Object lock(ProceedingJoinPoint pjp) throws Throwable {
        String methodName = pjp.getSignature().getName();
        Object[] args = pjp.getArgs();
        Class<?> classTarget = pjp.getTarget().getClass();
        Class<?>[] par = ((MethodSignature) pjp.getSignature()).getParameterTypes();
        Method targetMethod = classTarget.getMethod(methodName, par);
        String[] parameterNames = LOCAL_VARIABLE_TABLE_PARAMETER_NAME_DISCOVERER.getParameterNames(targetMethod);
        LockMethod lockMethod = targetMethod.getAnnotation(LockMethod.class);
        final String lockName = parseLockName(args, lockMethod, parameterNames);
        log.info("lockName={} act=LockInterceptor", lockName);
        return doLock(pjp, lockMethod, lockName);
    }

    private Object doLock(ProceedingJoinPoint pjp, LockMethod lockMethod, String lockName) {
        return DistributedLock.acquireLock(MDCUtils.getLogStr(), lockName, new LockCallback<Object>() {
            @Override
            public Object onSuccess(String logStr) {
                try {
                    return pjp.proceed();
                } catch (Throwable throwable) {
                    throw new RuntimeException(throwable);
                }
            }

            @Override
            public Object onFailure(String logStr) {
                String onFailureMethodEL = lockMethod.failureExpression();
                if (StringUtils.isEmpty(onFailureMethodEL)) {
                    return null;
                }
                Class<?> onFailureCallType = lockMethod.failureType();
                if (onFailureCallType == void.class) {
                    return SPEL_EXPRESSION_PARSER.parseExpression(onFailureMethodEL).getValue();
                } else {
                    EvaluationContext context = new StandardEvaluationContext(onFailureCallType);
                    Expression expression = SPEL_EXPRESSION_PARSER.parseExpression(onFailureMethodEL);
                    return expression.getValue(context);
                }
            }
        });
    }

    private String parseLockName(Object[] args, LockMethod lockMethod, String[] parameterNames) {
        ExtractParam[] extractParams = lockMethod.value();
        if (extractParams.length == 0) {
            throw new RuntimeException("not allow no extract param");
        }
        List<String> fieldValues = new ArrayList<>();
        Map<String, Object> paramNameMap = buildParamMap(args, parameterNames);
        for (ExtractParam extractParam : extractParams) {
            String paramName = extractParam.paramName();
            Object paramValue = paramNameMap.get(paramName);
            String springEL = extractParam.fieldName();
            String paramFieldValue = "";
            if (StringUtils.isNotEmpty(springEL)) {
                Expression expression = SPEL_EXPRESSION_PARSER.parseExpression(springEL);
                paramFieldValue = expression.getValue(paramValue).toString();
            } else {
                if (isSimpleType(paramValue.getClass())) {
                    paramFieldValue = String.valueOf(paramValue);
                }
            }
            fieldValues.add(paramFieldValue);
        }
        return String.format(lockMethod.formatter(), fieldValues.toArray());
    }

    /**
     * 构建请求参数map
     * @param args 参数列表
     * @param parameterNames 参数名称列表
     * @return key:参数名称 value:参数值
     */
    private Map<String, Object> buildParamMap(Object[] args, String[] parameterNames) {
        Map<String, Object> paramNameMap = new HashMap<>();
        for (int i = 0; i < parameterNames.length; i++) {
            paramNameMap.put(parameterNames[i], args[i]);
        }
        return paramNameMap;
    }

    /**
     * 基本类型 int, double, float, long, short, boolean, byte, char, void.
     */
    private static boolean isSimpleType(Class<?> clazz) {
        return clazz.isPrimitive()
                || clazz.equals(Long.class)
                || clazz.equals(Integer.class)
                || clazz.equals(String.class)
                || clazz.equals(Double.class)
                || clazz.equals(Short.class)
                || clazz.equals(Byte.class)
                || clazz.equals(Character.class)
                || clazz.equals(Float.class)
                || clazz.equals(Boolean.class);
    }


}
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!