Python Bokeh: how to make a correlation plot?

前端 未结 3 1450
慢半拍i
慢半拍i 2021-01-21 15:45

How can I make a correlation heatmap in Bokeh?

# Asker's attempt as posted: build a sample DataFrame, compute its
# correlation matrix, then plot it with bokeh.charts.
import pandas as pd
# NOTE(review): bokeh.charts does not appear to exist in current Bokeh
# releases (the high-level charts API was deprecated/split out) — verify
# against the installed Bokeh version before running this.
import bokeh.charts

# NOTE(review): pd.util.testing is deprecated/removed in recent pandas
# versions — any DataFrame with numeric columns presumably works here.
df = pd.util.testing.makeTimeDataFrame(1000)
c = df.corr()  # correlation matrix of df's columns

# The original post is cut off on the next line; the intended call is unknown.
p = bokeh.cha


        
3条回答
  •  爱一瞬间的悲伤
    2021-01-21 16:31

    I can provide some baseline code that does what you are asking, combining the answers above with some extra pre-processing.

    Assume you already have a DataFrame df loaded (here, the UCI Adult dataset) and its matrix of correlation coefficients computed as p_corr (for example, p_corr = df.corr()).

    import bisect
    from collections import OrderedDict
    from itertools import chain
    from math import pi
    #
    from bokeh.models import ColorBar, LinearColorMapper
    from bokeh.palettes import RdBu as colors  # just make sure to import a palette that centers on white (-ish)
    from bokeh.plotting import figure, show
    from numpy import arange
    
    # Take the 9-color variant of the palette and reverse its order.
    colors = list(reversed(colors[9]))  # we want an odd number to ensure 0 correlation is a distinct color
    # Labels come from the correlation matrix itself rather than from df:
    # corr() may leave out non-numeric columns, and each label must line up
    # with the row/column of p_corr that is drawn at that position.
    labels = p_corr.columns
    nlabels = len(labels)
    
    def get_bounds(n):
        """Gets bounds for quads with n features.

        Builds an n x n grid of unit squares whose corners lie on integer
        coordinates. Squares are listed row by row: square k spans
        y from k // n to k // n + 1 and x from k % n to k % n + 1, i.e. the
        same order as the elements of a row-major flattened n x n matrix.

        Args:
            n: number of features (rows and columns of the grid).

        Returns:
            Tuple ``(top, bottom, left, right)`` of lists, each of length n*n.
        """
        # Uses the argument n (not a module-level global) so the helper
        # works for any grid size.
        cells = [(row, col) for row in range(n) for col in range(n)]
        top = [row + 1 for row, _ in cells]
        bottom = [row for row, _ in cells]
        left = [col for _, col in cells]
        right = [col + 1 for _, col in cells]
        return top, bottom, left, right
    
    def get_colors(corr_array, colors, nan_color='gray'):
        """Aligns color values from palette with the correlation coefficient values.

        The range [-1, 1] is split into ``len(colors)`` equal-width bins and
        each coefficient gets the color of the bin it falls into. This gives
        the same equal-width bands as a color bar spanning low=-1 to high=1
        with the same palette.

        Args:
            corr_array: iterable of correlation coefficients, nominally in [-1, 1].
            colors: sequence of colors ordered from the -1 end to the +1 end.
            nan_color: color used for NaN coefficients (e.g. produced by a
                constant column); defaults to 'gray'.

        Returns:
            List of colors, one per input value, in input order.

        Raises:
            ValueError: if ``colors`` is empty.
        """
        from math import isnan

        ncolors = len(colors)
        if ncolors == 0:
            raise ValueError("colors must contain at least one color")
        color = []
        for value in corr_array:
            if isnan(value):
                color.append(nan_color)
                continue
            # Scale [-1, 1] onto [0, ncolors]. +1 lands exactly on ncolors,
            # and out-of-range values fall outside; clamp them to the end bins
            # (this also gives -1 the first color instead of wrapping to the last).
            ind = int((value + 1) / 2 * ncolors)
            color.append(colors[min(max(ind, 0), ncolors - 1)])
        return color
    
    # Square figure with one data unit per feature on each axis.
    # Uses width/height: plot_width/plot_height were removed in Bokeh 3.0.
    p = figure(width=600, height=600,
               x_range=(0, nlabels), y_range=(0, nlabels),
               title="Correlation Coefficient Heatmap (lighter is worse)",
               toolbar_location=None, tools='')

    # Hide grid lines and rotate the tick labels by 45 degrees.
    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None
    p.xaxis.major_label_orientation = pi / 4
    p.yaxis.major_label_orientation = pi / 4

    # get_bounds lists squares in row-major order, matching the order of the
    # flattened correlation values, so square k receives color k.
    top, bottom, left, right = get_bounds(nlabels)  # creates squares for plot
    color_list = get_colors(p_corr.values.flatten(), colors)

    p.quad(top=top, bottom=bottom, left=left,
           right=right, line_color='white',
           fill_color=color_list)

    # One tick at the centre of each square, labelled with its feature name.
    # Labels are converted to str because column names need not be strings.
    ticks = [idx + 0.5 for idx in range(nlabels)]
    tick_dict = {tick: str(label) for tick, label in zip(ticks, labels)}
    p.xaxis.ticker = ticks
    p.yaxis.ticker = ticks
    p.xaxis.major_label_overrides = tick_dict
    p.yaxis.major_label_overrides = tick_dict

    # Color bar covering the full correlation range with the same palette.
    mapper = LinearColorMapper(palette=colors, low=-1, high=1)
    color_bar = ColorBar(color_mapper=mapper, location=(0, 0))
    p.add_layout(color_bar, 'right')

    show(p)
    

    If the categorical features are integer-encoded, this produces the following plot (admittedly a poor data example):

提交回复
热议问题