Problem with plotting graphs in 1 row using plot method from pandas

野性不改 2021-01-23 16:15

Suppose I want to plot three graphs in one row, each showing how cnt depends on one of three other features.

Code:

fig, axes = plt.subplots(nrows=1,          


        
2 Answers
  • 2021-01-23 16:58

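    The problem is not related to pandas. The IndexError you see comes from ax= axes[0, idx]: with nrows=1, plt.subplots returns a 1-dimensional array of Axes, so indexing it with two values fails. The [0, idx] form only works when the grid has more than one row and more than one column.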
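    A minimal reproduction of the error, assuming only matplotlib and numpy are installed. The Agg backend is selected so it runs without a display, and the [1, 2] data is placeholder input:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=1, ncols=3)
print(axes.shape)  # (3,): a single row is squeezed to a 1-D array

caught = False
try:
    axes[0, 0]  # two indices on a 1-D array
except IndexError as err:
    caught = True
    print("IndexError:", err)

axes[0].plot([1, 2], [1, 2])  # a single index is all a one-row grid needs
```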

    With a single row, drop the row index and use only the column index:

    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 10))
    for idx, feature in enumerate(min_regressors):
        # axes has shape (3,), so one index picks the column
        df_shuffled.plot(feature, "cnt", subplots=True, kind="scatter", ax=axes[idx])
    plt.show()
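    A self-contained version of the fix, for anyone without the original data. The DataFrame and the column names temp, hum and windspeed are hypothetical stand-ins for df_shuffled and min_regressors. subplots=True is left out on the assumption that it adds nothing when each call draws a single scatter onto an Axes you pass in:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line when plotting interactively
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-ins for the asker's df_shuffled and min_regressors
rng = np.random.default_rng(0)
df_shuffled = pd.DataFrame({
    "temp": rng.uniform(0, 35, 100),
    "hum": rng.uniform(20, 100, 100),
    "windspeed": rng.uniform(0, 30, 100),
})
df_shuffled["cnt"] = (50 * df_shuffled["temp"] + rng.normal(0, 100, 100)).round()
min_regressors = ["temp", "hum", "windspeed"]

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
for idx, feature in enumerate(min_regressors):
    # axes is 1-D here, so a single index selects the column
    df_shuffled.plot(x=feature, y="cnt", kind="scatter", ax=axes[idx])
fig.tight_layout()
# call plt.show() here when using an interactive backend
```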
    

    To recap:

    Correct: one row, so axes is 1-D and takes one index

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(8, 3))
    axes[0].plot([1,2], [1,2])
    

    Incorrect: two indices on a 1-D array raise IndexError

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(8, 3))
    axes[0, 0].plot([1,2], [1,2])
    

    Correct: two rows and three columns, so axes is 2-D and takes two indices

    fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(8, 3))
    axes[0,0].plot([1,2], [1,2])
    
  • 2021-01-23 17:19

    To understand what is happening, check the shape of axes (axes.shape) in both situations. When exactly one of nrows or ncols is 1, axes is a 1-dimensional array; when both are greater than 1, it is 2-dimensional; and when both are 1, plt.subplots returns a single Axes object rather than an array.
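    A quick way to check those shapes, assuming the default squeeze=True. The grid sizes below are arbitrary examples:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

shapes = {}
for nrows, ncols in [(1, 3), (3, 1), (2, 3), (1, 1)]:
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
    # A 1x1 grid returns a bare Axes object, which is not an ndarray
    shapes[(nrows, ncols)] = axes.shape if isinstance(axes, np.ndarray) else "single Axes"
    plt.close(fig)

for grid, shape in shapes.items():
    print(grid, "->", shape)
# (1, 3) -> (3,)
# (3, 1) -> (3,)
# (2, 3) -> (2, 3)
# (1, 1) -> single Axes
```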

    You cannot index a 1-dimensional array with two indices, as in ax= axes[0, idx].

    One option is numpy's np.atleast_2d, which turns the 1-D array of shape (3,) into a 2-D array of shape (1, 3), so the original axes[0, idx] indexing works again.
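    A sketch of two ways to get a 2-D array back: np.atleast_2d as suggested above, and the squeeze=False argument of plt.subplots, an alternative not mentioned in the thread. One caveat: for a single column (nrows=3, ncols=1), np.atleast_2d produces shape (1, 3) rather than (3, 1), whereas squeeze=False always returns shape (nrows, ncols):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

# Option 1: promote the 1-D array of Axes to shape (1, 3)
fig1, axes = plt.subplots(nrows=1, ncols=3)
axes2d = np.atleast_2d(axes)
for idx in range(3):
    axes2d[0, idx].plot([1, 2], [1, 2])  # the original [0, idx] indexing now works

# Option 2: ask matplotlib not to squeeze the grid in the first place
fig2, axes_grid = plt.subplots(nrows=1, ncols=3, squeeze=False)
for idx in range(3):
    axes_grid[0, idx].plot([1, 2], [1, 2])
```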

    A cleaner solution, however, is to iterate over the axes and the features together with zip:

    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 10))
    # zip pairs each Axes with one feature, so no indexing is needed
    for ax, feature in zip(axes, min_regressors):
        df_shuffled.plot(feature, "cnt", subplots=True, kind="scatter", ax=ax)
    plt.show()
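    The zip approach extends to grids with more than one row if the axes are flattened first. Below is a sketch of a reusable helper; plot_against_target, the column names and the random data are all hypothetical. squeeze=False keeps axes 2-D even for a single row, and axes.flat walks it row by row:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def plot_against_target(df, features, target="cnt", ncols=3):
    """Hypothetical helper: scatter each feature against target, one Axes per feature."""
    nrows = -(-len(features) // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(5 * ncols, 4 * nrows), squeeze=False)
    # axes is always 2-D here; .flat yields the Axes row by row
    for ax, feature in zip(axes.flat, features):
        df.plot(x=feature, y=target, kind="scatter", ax=ax)
    # Hide leftover Axes when the features do not fill the last row
    for ax in axes.flat[len(features):]:
        ax.set_visible(False)
    return fig, axes

# Usage with made-up data: 4 features on a 3-column grid gives 2 rows
rng = np.random.default_rng(1)
df = pd.DataFrame(rng.uniform(0, 1, (50, 4)), columns=["temp", "hum", "windspeed", "atemp"])
df["cnt"] = rng.integers(0, 500, 50)
fig, axes = plot_against_target(df, ["temp", "hum", "windspeed", "atemp"])
```

    Because zip stops at the shorter sequence, surplus Axes are simply left empty, which is why the helper hides them.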
    