ThreadLocal的使用及原理

我是研究僧i 提交于 2020-02-06 17:59:10

一、概述

ThreadLocal简单理解就是针对一个线程做资源的共享,通过set()方法把某些资源放到线程中保存,然后通过get方法获取这个资源。它的出现解决了同一个线程中,不同类的方法中可以共享同一个对象或者变量问题。注意他不是为解决并发中多线程的资源共享,这种场景一般需要加锁。而是为了在多线程之间维护每个线程单独持有的资源,不需要加锁。

可以想象ThreadLocal的使用场景是在一个线程中共享数据库连接,虽然我没有看过mybatis,hibernate的源码(后面会研究),但估计这些框架中少不了ThreadLocal。另外在java读写锁中使用ThreadLocal用来保存一个线程重入读锁的次数。

下面我们从ThreadLocal的基础代码结构开始了解源码实现。

二、基础代码结构

以下代码先不用细究,等看了后面set()方法再来认真阅读也可以。从下面代码中我们要明白两个事情:

  1. ThreadLocal和对应的变量都是保存在内部类ThreadLocalMap的table数组中。
  2. ThreadLocalMap的Entry使用了弱引用,我们后面会讲解这里使用弱引用的好处。
public class ThreadLocal<T> {
    /** 生成ThreadLocal的哈希码 */
    private final int threadLocalHashCode = nextHashCode();

    /** 用AtomicInteger生成哈希码,注意这里是静态的  */
    private static AtomicInteger nextHashCode =
        new AtomicInteger();

    /** 哈希增量 */
    private static final int HASH_INCREMENT = 0x61c88647;

    /** 返回下一个哈希码 */
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }
	
	/** 构造器比较简单 */
	public ThreadLocal() {
    }

	/** ThreadLocal的主要逻辑在ThreadLocalMap里面, ThreadLocalMap是线程独立的,每个线程最多保留一份*/
	static class ThreadLocalMap {

        /** Entry使用了弱引用 */
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        /**
         * The initial capacity -- MUST be a power of two.
         */
        private static final int INITIAL_CAPACITY = 16;

        /** 单个线程table可以保留多个ThreadLocal变量 */
        private Entry[] table;

        /** table的大小 */
        private int size = 0;

        /** 扩容阈值,table大小超出会扩容 */
        private int threshold; // Default to 0

        /**
         * Set the resize threshold to maintain at worst a 2/3 load factor.
         */
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

        /**
         * Increment i modulo len.
         */
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }

        /**
         * Decrement i modulo len.
         */
        private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }

        /** ThreadLocalMap的构造器,会初始化table的第一个值,以及设置扩容阈值 */
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }
	}  
}	      

三、变量设置

ThreadLocal的变量设置在set()方法中,我们从这个方法中就能大概领悟ThreadLocal的原理了。阅读set()方法之前我们先用一个例子和示意图来简要梳理下ThreadLocal保存变量的原理。

 public static void main(String[] args) throws InterruptedException {
        ThreadLocal threadLocal = new ThreadLocal();
        String var1 = "变量1";
        String var2 = "变量2";

        Thread thread1 = new Thread(()->{
            threadLocal.set(var1);
            System.out.println("线程1的变量=" + threadLocal.get());
        });

        Thread thread2 = new Thread(()->{
            threadLocal.set(var2);
            System.out.println("线程2的变量=" + threadLocal.get());
        });

        thread1.start();
        thread2.start();

        thread1.join();
        thread2.join();
    }

运行输出:
线程2的变量=变量2
线程1的变量=变量1

上面的例子可以结合下面的示意图理解,首先main方法里创建了ThreadLocal以及变量var1、var2,然后分别给线程1(thread1)和线程2(thread2)设置变量,最后再分别打印出设置的变量。从下图中看到,每个线程对象都有一个threadLocals变量(Thread类里面的),其类型为上面源码说到的ThreadLocalMap。threadLocal和var1、var2作为key-value保存在ThreadLocalMap的Entry数组里。可以看到每个线程都维护了一个Entry数组,其保证了变量是线程独立的。
在这里插入图片描述

		public void set(T value) {
	        Thread t = Thread.currentThread();
	        // 从Thread对象中获取ThreadLocalMap
	        ThreadLocalMap map = getMap(t);
	        if (map != null)
	        	// 有ThreadLocalMap,则设置变量到线程,下面讲解该方法
	            map.set(this, value);
	        else
	        	// 创建一个ThreadLocalMap 
	            createMap(t, value);
	    }
	
		/** Thread类的方法 */
		ThreadLocalMap getMap(Thread t) {
	        return t.threadLocals;
	    }

		/** ThreadLocalMap类里的set()方法 */
		private void set(ThreadLocal<?> key, Object value) {
            Entry[] tab = table;
            int len = tab.length;
            // 用哈希值跟数组大小取与运算找到插入位置
            int i = key.threadLocalHashCode & (len-1);
			
			/** 如果插入位置有值了,则取插入位置的下一个位置(这里可以理解它是环形数组),轮询直到找到插入位置为空,
			或者找到了key相等的,或者发现k为空(说明有些变量失效了,弱引用被gc设置为null,
			但是value却还有值,此时会把这个位置擦除掉)*/
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
				// key相等,则直接重新设置value
                if (k == key) {
                    e.value = value;
                    return;
                }
				// k为null,则调用replaceStaleEntry()方法,里面有替换和擦除操作
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
			// 如果插入位置为空,则直接插入一个Entry
            tab[i] = new Entry(key, value);
            // Entry数组数量加一
            int sz = ++size;
            // cleanSomeSlots()方法会再次清除key为空的Entry,如果没有需要清除的Entry,并且数组大小超过了阈值,则扩容
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
            	// 扩容方法,扩大两倍,会重新编排Entry的位置
                rehash();
        }
        
        /** 当增加新变量时发现有插入位置有key为null时会调用这个方法,它擦除或者交换Entry*/
		private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;
			// 向前找是否有Entry不为空但是key为空的Entry下标
            int slotToExpunge = staleSlot;
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            // 向后找看是有key相等,或者key为null的情况
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();

				// 有key相等,则交换Entry,把新的Entry提前,开始找到的位置向后移动
                if (k == key) {
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    // 如果相等,说明向前找没有找到key为空的Entry,那么擦除的位置从i开始即可
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    // 调用两个擦除的方法
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // 找到k为null并且slotToExpunge == staleSlot,那么擦除从i开始
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // 向后没有找到相同的可以了,那么插入位置为刚刚知道的staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            // slotToExpunge != staleSlot,说明除了staleSlot位置有key为null的Entry,其他地方还有,那么调用擦除方法
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }
        
	/** 这个方法是从指定位置找到Entry为不空,但是key为null的Entry,然后擦除 */        
	private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // staleSlot位置符合条件,首先擦除
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            // 从staleSlot位置往后找key为null的Entry,并擦除
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                	// 可能擦除了Entry,size会变化,则所有哈希映射的位置都可能会变,这里重新编排
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }
	
	/** expungeStaleEntry()方法的擦除可能有遗漏,这里再进行log(n)次的查找或者擦除 */
	private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }

上面的代码中设计很多擦除的逻辑,是因为ThreadLocal是保存在每个线程中的,加入这个ThreadLocal不需要了,我们是没办法清除每个线程的Entry的,这里可能会引起内存泄露。

但是Entry的key是ThreadLocal,它采用了弱引用。当ThreadLocal没有强引用指向的时候,GC回收期会把弱引用的置空,所以我们就能通过判断Entry的key是否为null,来清除没用的Entry,这样就可以有效的避免内存泄露了。

这里简要的总结下java引用的种类和内存泄露与内存溢出的区别

1.引用的种类

  1. 软引用(SoftReference):当内存不够使用时才会被回收,可以做浏览器的后退缓存页面
  2. 弱引用(WeakReference):当对象只被弱引用指向的话,gc运行时会被回收掉,会置位null。
  3. 强引用:就是我们用new或者反射等方式创建对象的引用。

2.内存泄露与内存溢出的区别:

  1. 内存溢出 (out of memory) :指程序申请内存时,没有足够的内存供申请者使用,例如,有一块存储int类型数据的存储空间,但是程序员却用它存储long类型的数据,那么结果就是内存不够用,此时就会报错OOM,即所谓的内存溢出。
  2. 内存泄漏 (memory leak) :是指程序在申请内存后,无法释放已申请的内存空间,一次内存泄漏似乎不会有大的影响,但内存泄漏堆积后的后果就是内存溢出。

四、获取变量

获取变量直接调用ThreadLocal的get()方法

	public T get() {
        Thread t = Thread.currentThread();
        // 从Thread对象里获取到ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        if (map != null) {
        	// 后面代码分析
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        // 如果ThreadLocalMap里没有ThreadLocal对应的变量,那么就初始一个
        return setInitialValue();
    }

		private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            // 哈希对应的位置刚好找到了变量,否则调用getEntryAfterMiss()方法
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }
        
		private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

			// 如果对应位置不为空,则往数组后面轮询看是否可以找到key相等的对象
            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                // k为空则擦除
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }
    
	private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

	// 初始化的value为null
	protected T initialValue() {
        return null;
    }

五、删除变量

上面讲解设置变量方法的时候,我们看到很多擦除的逻辑,其实为了安全起见,如果线程中发现ThreadLocal不再使用的话,我们要调用remove方法删除Entry。

	public void remove() {
		 // 从当前线程中取出ThreadLocalMap
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }

	private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                	// 找到key对应的Entry后,把引用置为null,然后调用擦除方法
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }
        
	public void clear() {
        this.referent = null;
    }

六、总结

ThreadLocal是单个线程不同类或者方法之间保存和传递变量的工具,它是为了避免多线程对变量的影响而设计的。另外如果我们不再使用ThreadLocal的时候注意要调用其remove()方法,把对应的Entry清除。

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!