Scatter plots in Pandas/Pyplot: How to plot by category

孤城傲影 2020-11-22 10:53

I am trying to make a simple scatter plot in pyplot using a Pandas DataFrame object, and I want an efficient way of plotting two variables while having the symbols dictated by a t…

  • 2020-11-22 11:18

    From matplotlib 3.1 onwards you can call .legend_elements() on the object returned by scatter. An example is shown in the matplotlib gallery page "Automated legend creation". The advantage is that a single scatter call is enough.

    In this case:

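    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    
    # 10 rows x 3 columns of random data on a monthly index
    # ('MS' = month start; the old 'M' alias is deprecated in recent pandas)
    df = pd.DataFrame(np.random.normal(10, 1, 30).reshape(10, 3),
                      index=pd.date_range('2010-01-01', freq='MS', periods=10),
                      columns=('one', 'two', 'three'))
    df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)
    
    fig, ax = plt.subplots()
    # color each point by its numeric key; legend_elements() returns
    # (handles, labels) with one entry per distinct key value
    sc = ax.scatter(df['one'], df['two'], marker='o', c=df['key1'], alpha=0.8)
    ax.legend(*sc.legend_elements())
    plt.show()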
    

    If the keys are not numbers, first convert them to integer codes (here with np.unique(..., return_inverse=True)) and pass the unique keys to the legend as labels:

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    
    df = pd.DataFrame(np.random.normal(10, 1, 30).reshape(10, 3),
                      index=pd.date_range('2010-01-01', freq='MS', periods=10),
                      columns=('one', 'two', 'three'))
    df['key1'] = list("AAABBBCCCC")
    
    # labels: sorted unique keys; index: position of each row's key in labels
    labels, index = np.unique(df["key1"], return_inverse=True)
    
    fig, ax = plt.subplots()
    # color by the integer codes, then relabel the legend handles with the keys
    sc = ax.scatter(df['one'], df['two'], marker='o', c=index, alpha=0.8)
    ax.legend(sc.legend_elements()[0], labels)
    plt.show()
    

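To make the label-to-code mapping concrete, here is a small check of np.unique(..., return_inverse=True) next to pd.factorize, which I'm swapping in only as a comparison. The keys are made up and deliberately unsorted so the difference shows: np.unique numbers the labels in sorted order, while pd.factorize numbers them in order of first appearance. Either works, as long as the labels passed to ax.legend come from the same call as the codes passed to c.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical, deliberately unsorted string keys
keys = list("BBACCA")

# np.unique: labels come back sorted; index maps each row to its label's position
labels, index = np.unique(keys, return_inverse=True)
print(labels.tolist(), index.tolist())  # ['A', 'B', 'C'] [1, 1, 0, 2, 2, 0]

# pd.factorize: uniques in order of first appearance; codes index into them
codes, uniques = pd.factorize(pd.Series(keys))
print(list(uniques), codes.tolist())    # ['B', 'A', 'C'] [0, 0, 1, 2, 2, 1]

# Use the codes for c and the matching uniques as legend labels
x = np.arange(len(keys))
fig, ax = plt.subplots()
sc = ax.scatter(x, x, c=codes)
handles, _ = sc.legend_elements()  # one handle per distinct code, in code order
ax.legend(handles, list(uniques))
```

The handles from legend_elements() are ordered by code value (0, 1, 2, ...), which is exactly the order of the uniques array, so the pairing stays correct.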
  • 2020-11-22 11:20

    You can use scatter for this, but that requires having numerical values for your key1, and you won't have a legend, as you noticed.

    For discrete categories like this, it's simpler to group the data and call plot once per group, so that each category gets its own legend entry. For example:

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    np.random.seed(1974)
    
    # Generate Data
    num = 20
    x, y = np.random.random((2, num))
    labels = np.random.choice(['a', 'b', 'c'], num)
    df = pd.DataFrame(dict(x=x, y=y, label=labels))
    
    groups = df.groupby('label')
    
    # Plot
    fig, ax = plt.subplots()
    ax.margins(0.05) # Optional, just adds 5% padding to the autoscaling
    for name, group in groups:
        ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
    ax.legend()
    
    plt.show()
    

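Since the question asks for the symbols (not just the colors) to follow the category, the same groupby loop can also pick a marker per group. A sketch, assuming a hypothetical markers dict that maps each label to a matplotlib marker code:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get an interactive window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

np.random.seed(1974)

# Same kind of data as above: x, y and a categorical label
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

# Hypothetical category -> marker mapping; any valid matplotlib markers work
markers = {'a': 'o', 'b': 's', 'c': '^'}

fig, ax = plt.subplots()
ax.margins(0.05)
for name, group in df.groupby('label'):
    # one plot call per category: its own marker, its own cycle color, its own legend entry
    ax.plot(group.x, group.y, marker=markers[name], linestyle='', ms=12, label=name)
ax.legend(numpoints=1)
```

Because each group is a separate line object, the legend shows both the marker and the color of each category without extra work.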

    If you'd like things to look like the old default pandas style, apply a stylesheet to the rcParams and let its color cycle color the groups (I'm also tweaking the legend slightly). Older pandas shipped such a stylesheet as pd.tools.plotting.mpl_stylesheet, and older matplotlib had ax.set_color_cycle, but neither is available in current releases. Matplotlib's built-in 'ggplot' style gives a comparable look (grey background, white grid):

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    np.random.seed(1974)
    
    # Generate Data
    num = 20
    x, y = np.random.random((2, num))
    labels = np.random.choice(['a', 'b', 'c'], num)
    df = pd.DataFrame(dict(x=x, y=y, label=labels))
    
    groups = df.groupby('label')
    
    # Plot
    plt.style.use('ggplot')  # updates rcParams, including the color cycle
    
    fig, ax = plt.subplots()
    ax.margins(0.05)
    for name, group in groups:
        # each plot call takes the next color from the style's color cycle
        ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
    ax.legend(numpoints=1, loc='upper left')
    
    plt.show()
    

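If the goal is simply pandas-flavoured output, pandas' own DataFrame.plot.scatter can also be called once per group onto a shared Axes. A sketch, assuming a reasonably recent pandas; the per-group colors use matplotlib's 'C0', 'C1', ... color-cycle references, and the marker size s=80 is an arbitrary choice:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to get an interactive window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

np.random.seed(1974)
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

fig, ax = plt.subplots()
for i, (name, group) in enumerate(df.groupby('label')):
    # pandas' scatter wrapper drawn onto the shared Axes;
    # 'C0', 'C1', ... pick successive colors from the current color cycle
    group.plot.scatter(x='x', y='y', ax=ax, color=f'C{i}', s=80, label=name)
ax.legend(loc='upper left')  # rebuild one legend from all labelled groups
```

The explicit ax.legend() at the end collects the labelled scatter collections from every call into a single legend, rather than relying on how pandas merges legends across repeated calls.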
