MyBatis - 自定义 BaseMapper 与 LanguageDriver(注解方式)

我的梦境 提交于 2020-03-05 07:11:47

在使用 MyBatis 的注解方式时,我们通常希望把一些基础的增删改查方法封装起来。本篇内容正是基于这一需求,完整源码见下方链接:

源码

如果觉得有帮助,欢迎点个 star。

BaseMapper如下:

/**
 * Base mapper that provides generic single-table CRUD methods for annotation-based MyBatis mappers.
 * <p>
 * The {@code ${table}}, {@code ${id}}, {@code ${values}} and {@code ${sets}} tokens in the SQL below
 * are not MyBatis parameters: {@link BaseMapperDriver} (selected through {@code @Lang}) expands them
 * while the mapper is parsed, using the generic arguments of the concrete mapper interface.
 *
 * @param <T> the model (entity) type mapped to the table
 * @param <K> the primary-key type
 */
public interface BaseMapper<T, K> {
    /**
     * Inserts {@code model} as a new row.
     * <p>
     * Every mapped field of the model becomes a column in {@code ${values}}, including fields whose
     * value is null. The generated key is read from column {@code id} and written back to the
     * model's {@code id} property; both names are hard-coded in {@code @Options} below, regardless
     * of any {@code @Id} column name on the model.
     *
     * @param model the entity to insert
     * @return the value MyBatis returns for the INSERT (presumably the affected row count, not the
     *         generated key — confirm against the MyBatis version in use)
     */
    @Lang(BaseMapperDriver.class)
    @Insert({"<script>", "INSERT INTO ${table} ${values}", "</script>"})
    @Options(useGeneratedKeys = true, keyColumn = "id", keyProperty = "id")
    Long insert(T model);

    /**
     * Updates the row whose primary-key column equals the model's {@code id} property.
     * <p>
     * {@code ${sets}} expands to a {@code <set>} element in which each non-{@code @Id} column is
     * guarded by {@code <if test="field != null">}, so null fields are left unchanged.
     * NOTE(review): {@code #{id}} assumes the key property is named {@code id} even when
     * {@code @Id} maps a different column name.
     *
     * @param model the entity carrying the id and the new values
     * @return the value MyBatis returns for the UPDATE (presumably the affected row count)
     */
    @Lang(BaseMapperDriver.class)
    @Update({"<script>", "UPDATE ${table} ${sets} WHERE ${id}=#{id}", "</script>"})
    Long updateById(T model);

    /**
     * Deletes the row with the given primary key.
     *
     * @param id the primary-key value
     * @return the value MyBatis returns for the DELETE (presumably the affected row count)
     */
    @Lang(BaseMapperDriver.class)
    @Delete("DELETE FROM ${table} WHERE ${id}=#{id}")
    Long deleteById(@Param("id") K id);

    /**
     * Loads the row with the given primary key.
     *
     * @param id the primary-key value
     * @return the mapped entity, or null when no row matches
     */
    @Lang(BaseMapperDriver.class)
    @Select("SELECT * FROM ${table} WHERE ${id}=#{id}")
    T getById(@Param("id") K id);

    /**
     * Checks whether a row with the given primary key exists.
     * <p>
     * NOTE(review): the {@code COUNT(1)} result is mapped straight to {@link Boolean}; this relies
     * on the JDBC driver converting a non-zero count to {@code true} — confirm for the database in use.
     *
     * @param id the primary-key value
     * @return whether a matching row exists
     */
    @Lang(BaseMapperDriver.class)
    @Select("SELECT COUNT(1) FROM ${table} WHERE ${id}=#{id}")
    Boolean existById(@Param("id") K id);
}

为了不影响其他的 SQL 或方法,这里自定义一个 MyBatis 语言驱动(LanguageDriver):

BaseMapperDriver

具体实现如下:

/**
 * Custom MyBatis language driver used by {@link BaseMapper}.
 * <p>
 * Before delegating to {@link XMLLanguageDriver}, it expands these placeholders in the annotation
 * SQL, using the {@code <T, K>} arguments with which the mapper currently being parsed extends
 * {@link BaseMapper}:
 * <ul>
 *   <li>{@code ${table}}  - {@code @Table(name)} of the model, else its snake_case simple name</li>
 *   <li>{@code ${id}}     - {@code @Id(name)} of the model, else {@code id}</li>
 *   <li>{@code ${values}} - {@code (col,...) VALUES (#{field},...)} for INSERT</li>
 *   <li>{@code ${sets}}   - a {@code <set>} element with one null-guarded assignment per column</li>
 *   <li>{@code ${ins}} / {@code ${ins:param}} - a parenthesised {@code <foreach>} over a collection
 *       parameter (the statement must be wrapped in {@code <script>} for the foreach to be parsed)</li>
 * </ul>
 * The current mapper comes from {@link MybatisMapperRegistry#getCurrentMapper()}, so this driver
 * only works with {@link MybatisConfig}.
 *
 * @author 大仙
 */
public class BaseMapperDriver extends XMLLanguageDriver implements LanguageDriver {

    private static final Pattern TABLE_PATTERN = Pattern.compile("\\$\\{table\\}");
    private static final Pattern ID_PATTERN = Pattern.compile("\\$\\{id\\}");
    private static final Pattern VALUES_PATTERN = Pattern.compile("\\$\\{values\\}");
    private static final Pattern SETS_PATTERN = Pattern.compile("\\$\\{sets\\}");
    /** Matches {@code ${ins}} or {@code ${ins:paramName}}; group 1 is the optional parameter name. */
    private static final Pattern IN_PATTERN = Pattern.compile("\\$\\{ins(?::(\\w+))?\\}");
    /**
     * Collection used by a bare {@code ${ins}}. NOTE(review): assumes MyBatis exposes a single
     * unnamed Collection parameter under the name "collection" — confirm for the MyBatis version in use.
     */
    private static final String DEFAULT_IN_COLLECTION = "collection";
    private static final String DEFAULT_ID_COLUMN = "id";

    @Override
    public SqlSource createSqlSource(Configuration configuration, String script, Class<?> parameterType) {
        // Only MybatisMapperRegistry (installed by MybatisConfig) publishes the mapper being parsed.
        Class<?> mapperClass = null;
        if (configuration instanceof MybatisConfig) {
            mapperClass = MybatisMapperRegistry.getCurrentMapper();
        }
        if (mapperClass == null) {
            throw new RuntimeException("解析SQL出错");
        }
        // May be null when the mapper does not directly extend BaseMapper<T, K>; the model-dependent
        // expansions below only fail if their placeholder is actually present.
        Class<?> modelClass = getMapperGenerics(mapperClass)[0];
        script = setTable(script, modelClass);
        script = setId(script, modelClass);
        script = setValues(script, modelClass);
        script = setSets(script, modelClass);
        script = setIn(script);
        script = setResultAlias(script, modelClass);
        return super.createSqlSource(configuration, script, parameterType);
    }

    /**
     * Resolves the {@code <T, K>} arguments with which {@code mapperClass} directly extends {@link BaseMapper}.
     *
     * @param mapperClass the mapper interface being parsed
     * @return a two-element array {model type, key type}; both are null when the mapper does not
     *         directly extend a parameterized BaseMapper, and an element is null when its type
     *         argument is not a concrete class (e.g. a type variable)
     */
    private Class<?>[] getMapperGenerics(Class<?> mapperClass) {
        Class<?>[] classes = new Class<?>[2];
        for (Type type : mapperClass.getGenericInterfaces()) {
            // Skip non-generic and unrelated super-interfaces instead of failing with a ClassCastException.
            if (!(type instanceof ParameterizedType)
                    || ((ParameterizedType) type).getRawType() != BaseMapper.class) {
                continue;
            }
            Type[] arguments = ((ParameterizedType) type).getActualTypeArguments();
            for (int i = 0; i < classes.length; i++) {
                classes[i] = toClass(arguments[i]);
            }
            break;
        }
        return classes;
    }

    /** Returns the class behind a type argument (raw class for a parameterized type), or null. */
    private static Class<?> toClass(Type type) {
        if (type instanceof Class) {
            return (Class<?>) type;
        }
        if (type instanceof ParameterizedType) {
            Type raw = ((ParameterizedType) type).getRawType();
            return raw instanceof Class ? (Class<?>) raw : null;
        }
        return null;
    }

    /** Fails with a descriptive message when a model-dependent placeholder is used without a model type. */
    private static Class<?> requireModel(Class<?> modelClass, String placeholder) {
        if (modelClass == null) {
            throw new IllegalStateException("Cannot expand " + placeholder
                    + ": the mapper does not directly extend BaseMapper<T, K> with a concrete model type");
        }
        return modelClass;
    }

    /**
     * Replaces {@code ${table}} with the model's table name.
     *
     * @param script     the annotation SQL
     * @param modelClass the model type, may be null if the placeholder is absent
     * @return the SQL with the placeholder expanded
     */
    private String setTable(String script, Class<?> modelClass) {
        Matcher matcher = TABLE_PATTERN.matcher(script);
        if (!matcher.find()) {
            return script;
        }
        Class<?> model = requireModel(modelClass, "${table}");
        String table = model.isAnnotationPresent(Table.class)
                ? model.getAnnotation(Table.class).name()
                : CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, model.getSimpleName());
        // quoteReplacement: a '$' or '\' in the name must not be read as a group reference.
        return matcher.replaceAll(Matcher.quoteReplacement(table));
    }

    /**
     * Replaces {@code ${id}} with the primary-key column name, defaulting to {@code id}.
     *
     * @param script     the annotation SQL
     * @param modelClass the model type, may be null if the placeholder is absent
     * @return the SQL with the placeholder expanded
     */
    private String setId(String script, Class<?> modelClass) {
        Matcher matcher = ID_PATTERN.matcher(script);
        if (!matcher.find()) {
            return script;
        }
        String idColumn = findIdColumn(requireModel(modelClass, "${id}"));
        return matcher.replaceAll(Matcher.quoteReplacement(idColumn != null ? idColumn : DEFAULT_ID_COLUMN));
    }

    /**
     * Finds the {@code @Id} column: first among the model's own fields, then (with {@code @UserParent})
     * among the public/protected fields of its direct superclass — the same parent fields that
     * {@link #setSets} and {@link #setValues} consider.
     *
     * @return the column name, or null when no {@code @Id} field exists
     */
    private static String findIdColumn(Class<?> model) {
        for (Field field : model.getDeclaredFields()) {
            if (field.isAnnotationPresent(Id.class)) {
                return field.getAnnotation(Id.class).name();
            }
        }
        if (model.isAnnotationPresent(UserParent.class)) {
            for (Field field : model.getSuperclass().getDeclaredFields()) {
                if (isVisibleToSubclass(field) && field.isAnnotationPresent(Id.class)) {
                    return field.getAnnotation(Id.class).name();
                }
            }
        }
        return null;
    }

    /**
     * Replaces {@code ${sets}} with a {@code <set>} element assigning every non-{@code @Id} column
     * whose field is non-null. If no column qualifies, an empty {@code <set></set>} is produced so
     * that parsing still succeeds; calling the statement would then fail at execution time.
     *
     * @param script     the annotation SQL
     * @param modelClass the model type, may be null if the placeholder is absent
     * @return the SQL with the placeholder expanded
     */
    private String setSets(String script, Class<?> modelClass) {
        Matcher matcher = SETS_PATTERN.matcher(script);
        if (!matcher.find()) {
            return script;
        }
        StringBuilder sets = new StringBuilder("<set>");
        for (Field field : mappedFields(requireModel(modelClass, "${sets}"), false)) {
            sets.append("<if test=\"").append(field.getName()).append(" != null\">")
                    .append(columnName(field)).append("=#{").append(field.getName()).append("},</if>");
        }
        // Drop the final comma as before; guarded because there may be no assignment at all.
        int lastComma = sets.lastIndexOf(",");
        if (lastComma >= 0) {
            sets.deleteCharAt(lastComma);
        }
        sets.append("</set>");
        return matcher.replaceAll(Matcher.quoteReplacement(sets.toString()));
    }

    /**
     * Replaces {@code ${values}} with {@code (col1,col2,...) VALUES (#{f1},#{f2},...)} over all
     * mapped fields, {@code @Id} included.
     *
     * @param script     the annotation SQL
     * @param modelClass the model type, may be null if the placeholder is absent
     * @return the SQL with the placeholder expanded
     */
    private String setValues(String script, Class<?> modelClass) {
        Matcher matcher = VALUES_PATTERN.matcher(script);
        if (!matcher.find()) {
            return script;
        }
        List<String> columns = new ArrayList<>();
        List<String> values = new ArrayList<>();
        for (Field field : mappedFields(requireModel(modelClass, "${values}"), true)) {
            columns.add(columnName(field));
            values.add("#{" + field.getName() + "}");
        }
        String clause = "(" + String.join(",", columns) + ") VALUES (" + String.join(",", values) + ")";
        return matcher.replaceAll(Matcher.quoteReplacement(clause));
    }

    /**
     * Collects the fields mapped to columns, in {@code getDeclaredFields()} order. With
     * {@code @UserParent}, the public/protected fields of the direct superclass come first, then the
     * model's own fields. Static and synthetic fields (e.g. {@code serialVersionUID}, coverage-tool
     * fields) and {@code @Invisiable} fields are skipped; {@code @Id} fields are skipped unless
     * {@code includeId} is true.
     */
    private static List<Field> mappedFields(Class<?> model, boolean includeId) {
        List<Field> fields = new ArrayList<>();
        if (model.isAnnotationPresent(UserParent.class)) {
            for (Field field : model.getSuperclass().getDeclaredFields()) {
                if (isVisibleToSubclass(field)) {
                    addIfMapped(fields, field, includeId);
                }
            }
        }
        for (Field field : model.getDeclaredFields()) {
            addIfMapped(fields, field, includeId);
        }
        return fields;
    }

    /** Adds {@code field} to {@code fields} unless it is static, synthetic, invisible or an excluded id. */
    private static void addIfMapped(List<Field> fields, Field field, boolean includeId) {
        if (Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
            return;
        }
        if (field.isAnnotationPresent(Invisiable.class)) {
            return;
        }
        if (!includeId && field.isAnnotationPresent(Id.class)) {
            return;
        }
        fields.add(field);
    }

    /** Parent-class fields are only mapped when public or protected. */
    private static boolean isVisibleToSubclass(Field field) {
        int modifiers = field.getModifiers();
        return Modifier.isPublic(modifiers) || Modifier.isProtected(modifiers);
    }

    /** Column for a field: {@code @Column(name)} if present, else the camelCase name as snake_case. */
    private static String columnName(Field field) {
        return field.isAnnotationPresent(Column.class)
                ? field.getAnnotation(Column.class).name()
                : CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, field.getName());
    }

    /**
     * Expands {@code ${ins}} / {@code ${ins:param}} into
     * {@code (<foreach collection="param" item="__item" separator="," >#{__item}</foreach>)}.
     * The previous implementation referenced {@code $1} on a pattern without a capture group and
     * therefore threw for every occurrence.
     *
     * @param script the annotation SQL
     * @return the SQL with every occurrence expanded
     */
    private String setIn(String script) {
        Matcher matcher = IN_PATTERN.matcher(script);
        if (!matcher.find()) {
            return script;
        }
        // StringBuffer: Matcher.appendReplacement(StringBuilder) needs Java 9+.
        StringBuffer out = new StringBuffer();
        do {
            String collection = matcher.group(1) != null ? matcher.group(1) : DEFAULT_IN_COLLECTION;
            String foreach = "(<foreach collection=\"" + collection
                    + "\" item=\"__item\" separator=\",\" >#{__item}</foreach>)";
            matcher.appendReplacement(out, Matcher.quoteReplacement(foreach));
        } while (matcher.find());
        matcher.appendTail(out);
        return out.toString();
    }

    /**
     * Placeholder for mapping single-table results through column aliases; currently returns
     * {@code script} unchanged.
     */
    private String setResultAlias(String script, Class<?> modelClass) {
        return script;
    }
}

其中用到的注解(@Table、@Id、@Column 等)请自行阅读源码,这里重点讲解核心部分:如何获取当前正在解析的 mapper。我们自定义 MapperRegistry,重写它的 addMapper 方法:在解析 mapper 注解之前记录当前 mapper,解析完成之后再将其清除。

/**
 * 自定义mapperRegistry
 * @author 大仙
 */
public class MybatisMapperRegistry extends MapperRegistry {


    private final Map<Class<?>, MapperProxyFactory<?>> knownMappers = new HashMap();

    private Configuration config;

    private static Class<?> currentMapper;

    public MybatisMapperRegistry(Configuration config) {
        super(config);
        this.config = config;
    }

    @Override
    public <T> T getMapper(Class<T> type, SqlSession sqlSession) {
        MapperProxyFactory<T> mapperProxyFactory = (MapperProxyFactory)this.knownMappers.get(type);
        if (mapperProxyFactory == null) {
            throw new BindingException("Type " + type + " is not known to the MapperRegistry.");
        } else {
            try {
                return mapperProxyFactory.newInstance(sqlSession);
            } catch (Exception var5) {
                throw new BindingException("Error getting mapper instance. Cause: " + var5, var5);
            }
        }
    }

    @Override
    public <T> boolean hasMapper(Class<T> type) {
        return this.knownMappers.containsKey(type);
    }

    @Override
    public <T> void addMapper(Class<T> type) {
        if (type.isInterface()) {
            if (this.hasMapper(type)) {
                throw new BindingException("Type " + type + " is already known to the MapperRegistry.");
            }

            boolean loadCompleted = false;

            try {
                this.knownMappers.put(type, new MapperProxyFactory(type));
                MapperAnnotationBuilder parser = new MapperAnnotationBuilder(this.config, type);
                currentMapper = type;
                parser.parse();
                currentMapper=null;
                loadCompleted = true;
            } finally {
                if (!loadCompleted) {
                    this.knownMappers.remove(type);
                }

            }
        }

    }

    @Override
    public Collection<Class<?>> getMappers() {
        return Collections.unmodifiableCollection(this.knownMappers.keySet());
    }


    @Override
    public void addMappers(String packageName, Class<?> superType) {
        ResolverUtil<Class<?>> resolverUtil = new ResolverUtil();
        resolverUtil.find(new ResolverUtil.IsA(superType), packageName);
        Set<Class<? extends Class<?>>> mapperSet = resolverUtil.getClasses();
        Iterator var5 = mapperSet.iterator();

        while(var5.hasNext()) {
            Class<?> mapperClass = (Class)var5.next();
            this.addMapper(mapperClass);
        }

    }

    @Override
    public void addMappers(String packageName) {
        this.addMappers(packageName, Object.class);
    }


    public static Class<?> getCurrentMapper() {
        return currentMapper;
    }
}

将该类配置到自定义的 Configuration 类中:

/**
 * MyBatis {@link Configuration} that installs {@link MybatisMapperRegistry}, which
 * {@link BaseMapperDriver} depends on, and turns on map-underscore-to-camel-case.
 * <p>
 * All mapper-related methods are overridden to route through the registry declared here.
 * NOTE(review): the field presumably hides Configuration's own {@code mapperRegistry}; any
 * superclass code that reads that field directly would bypass this registry.
 *
 * @author 大仙
 */
public class MybatisConfig extends Configuration {

    /** Registry that publishes the mapper currently being parsed. */
    protected final MapperRegistry mapperRegistry;

    public MybatisConfig() {
        super();
        mapperRegistry = new MybatisMapperRegistry(this);
        mapUnderscoreToCamelCase = true;
    }

    @Override
    public MapperRegistry getMapperRegistry() {
        return mapperRegistry;
    }

    @Override
    public boolean hasMapper(Class<?> type) {
        return mapperRegistry.hasMapper(type);
    }

    @Override
    public <T> T getMapper(Class<T> type, SqlSession sqlSession) {
        return mapperRegistry.getMapper(type, sqlSession);
    }

    @Override
    public <T> void addMapper(Class<T> type) {
        mapperRegistry.addMapper(type);
    }

    @Override
    public void addMappers(String packageName) {
        mapperRegistry.addMappers(packageName);
    }

    @Override
    public void addMappers(String packageName, Class<?> superType) {
        mapperRegistry.addMappers(packageName, superType);
    }
}

为了让自定义的 Configuration 类生效,只需在初始化 SqlSessionFactory 时指定该 Configuration 类即可。我这里用 sharding-jdbc 做了读写分离,不使用 sharding-jdbc 时配置方式也完全相同:

    /**
     * Builds the primary {@link SqlSessionFactory} on the master/slave data source.
     * The factory uses {@link MybatisConfig}, which {@link BaseMapperDriver} requires in order to
     * expand BaseMapper's placeholder SQL.
     *
     * @return the configured session factory
     * @throws Exception if the factory bean fails to build
     */
    @Bean
    @Primary
    public SqlSessionFactory sqlSessionFactory() throws Exception {
        SqlSessionFactoryBean factoryBean = new SqlSessionFactoryBean();
        factoryBean.setDataSource(masterSlaveDataSource());
        // Must be the custom configuration; BaseMapperDriver rejects any other Configuration.
        factoryBean.setConfiguration(new MybatisConfig());
        return factoryBean.getObject();
    }

 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!