How to iterate over a list of lists for a scatter plot and create a legend of unique elements

耗尽温柔 submitted on 2020-03-04 21:34:30

Question


Background:

I have a list_of_x_and_y_list that contains x and y values and looks like this:

[[(44800, 14888), (132000, 12500), (40554, 12900)], [(None, 193788), (101653, 78880), (3866, 160000)]]

I have another data_name_list ["data_a","data_b"] so that

  • "data_a" = [(44800, 14888), (132000, 12500), (40554, 12900)]

  • "data_b" = [(None, 193788), (101653, 78880), (3866, 160000)]

The length of list_of_x_and_y_list (and of data_name_list) is greater than 20.
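Because data_name_list and list_of_x_and_y_list are parallel lists, the name-to-series correspondence described above can be sketched with zip(). This is only an illustration using the two example series; the dict is an assumption for readability, since the question keeps two separate lists:

```python
# The data from the question: one list of (x, y) tuples per named series.
list_of_x_and_y_list = [
    [(44800, 14888), (132000, 12500), (40554, 12900)],
    [(None, 193788), (101653, 78880), (3866, 160000)],
]
data_name_list = ["data_a", "data_b"]

# zip() pairs names and series by position and silently stops at the shorter
# input, so check that both lists have the same length first.
assert len(list_of_x_and_y_list) == len(data_name_list)

# Hypothetical helper structure: name -> list of (x, y) tuples.
series_by_name = dict(zip(data_name_list, list_of_x_and_y_list))
print(series_by_name["data_a"])  # [(44800, 14888), (132000, 12500), (40554, 12900)]
```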

Question:

How can I create a scatter plot in which all the points of each item in data_name_list share the same colour, with one legend entry per item?

What I have tried:

    import matplotlib.pyplot as plt

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_facecolor('#FFFFFF')
    prop_cycle = plt.rcParams['axes.prop_cycle']
    colors = prop_cycle.by_key()['color']

    print(list_of_x_and_y_list)
    for x_and_y_list, data_name, color in zip(list_of_x_and_y_list, data_name_list, colors):
        for x_and_y in x_and_y_list:
            print(x_and_y)
            x, y = x_and_y
            # "label=data_name" creates a huge list as a legend! :(
            ax.scatter(x, y, label=data_name, color=color)

    plt.title('Matplot scatter plot')
    plt.legend(loc=2)
    file_name = "3kstc.png"
    fig.savefig(file_name, dpi=fig.dpi)
    print("Generated: {}".format(file_name))

The Problem:

The legend becomes a very long list, with one entry per point, and I don't know how to fix it.

Relevant Research:

  • Matplotlib scatterplot
  • Scatter Plot
  • Scatter plot in Python using matplotlib

Answer 1:


The reason you get a long, repeated list as a legend is that you are providing each point as a separate series; matplotlib does not automatically group your data based on the labels.
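A minimal sketch reproducing the problem, assuming the Agg backend so it runs without a display: three scatter() calls that share the same label create three separate collections, and each one contributes its own legend entry.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend, no display needed
import matplotlib.pyplot as plt

points = [(44800, 14888), (132000, 12500), (40554, 12900)]

fig, ax = plt.subplots()
# One scatter() call per point: every call creates its own PathCollection,
# and every collection is listed in the legend, even with a repeated label.
for x, y in points:
    ax.scatter(x, y, label="data_a", color="C0")

handles, labels = ax.get_legend_handles_labels()
print(labels)               # ['data_a', 'data_a', 'data_a']
print(len(ax.collections))  # 3
```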

A quick fix is to iterate over the list and use zip(*data) to split each series into two tuples, so that the x tuple contains all the x-values and the y tuple all the y-values.

Then you can feed these tuples to ax.scatter together with the labels, so that each series is plotted in a single call and gets exactly one legend entry.
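The transposition trick on its own, applied to the data_a series from the question: zip(*data) passes each (x, y) tuple as a separate argument to zip(), which then groups all first elements together and all second elements together.

```python
data = [(44800, 14888), (132000, 12500), (40554, 12900)]

# Unpack the list into zip(): zip((44800, 14888), (132000, 12500), (40554, 12900))
x, y = zip(*data)
print(x)  # (44800, 132000, 40554)
print(y)  # (14888, 12500, 12900)
```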

I felt that names like list_of_x_and_y_list were unnecessarily long and complicated, so in my code I've used shorter names.

import matplotlib.pyplot as plt

data_series = [[(44800, 14888), (132000, 12500), (40554, 12900)],
               [(None, 193788), (101653, 78880), (3866, 160000)]]
data_names = ["data_a","data_b"]

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.set_facecolor('#FFFFFF')
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

for data, data_name, color in zip(data_series, data_names, colors):
    x, y = zip(*data)  # transpose [(x, y), ...] into (x1, x2, ...), (y1, y2, ...)
    ax.scatter(x, y, label=data_name, color=color)

plt.title('Matplot scatter plot')
plt.legend(loc=1)
plt.show()
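One caveat for the question's data: with the default style the colour cycle has only 10 entries, and zip() stops at its shortest argument, so with more than 20 series only the first 10 would be plotted. A sketch of one workaround is to sample one colour per series from a colormap. The 25 synthetic series, the viridis colormap and the Agg backend are all assumptions for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend, no display needed
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in data: 25 series of three points each.
data_series = [[(i, i + j) for j in range(3)] for i in range(25)]
data_names = ["data_{}".format(i) for i in range(25)]

# One colour per series, sampled evenly from the viridis colormap, instead of
# the 10-entry default cycle (which would make zip() stop after 10 series).
colors = [plt.cm.viridis(v) for v in np.linspace(0, 1, len(data_series))]

fig, ax = plt.subplots()
for data, data_name, color in zip(data_series, data_names, colors):
    x, y = zip(*data)
    ax.scatter(x, y, label=data_name, color=color)

ax.set_title('Matplot scatter plot')
ax.legend(loc=1, fontsize='small')
```

A sequential colormap makes neighbouring series look alike; a qualitative map such as tab20 gives more contrast but has only 20 colours, so the choice is a judgement call.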




Answer 2:


To get only one entry per data_name, you should pass data_name as a label only once; the remaining calls should use label=None. The simplest way to achieve this with the current code is to set data_name to None right after the scatter call in the inner loop:

from matplotlib import pyplot as plt
from random import randint

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.set_facecolor('#FFFFFF')
# create some random data, suppose the sublists have different lengths
list_of_x_and_y_list = [[(randint(1000, 4000), randint(2000, 5000)) for col in range(randint(2, 10))]
                        for row in range(10)]
data_name_list = list('abcdefghij')
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
for x_and_y_list, data_name, color in zip(list_of_x_and_y_list, data_name_list, colors):
    for x_and_y in x_and_y_list:
        x, y = x_and_y
        ax.scatter(x, y, label=data_name, color=color)
        data_name = None
plt.legend(loc=2)
plt.show()
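An alternative that leaves the per-point loop unchanged (not part of either answer) is to deduplicate the legend entries afterwards using ax.get_legend_handles_labels(). A sketch, using the question's data with the None point left out and the Agg backend assumed:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend, no display needed
import matplotlib.pyplot as plt

list_of_x_and_y_list = [[(44800, 14888), (132000, 12500), (40554, 12900)],
                        [(101653, 78880), (3866, 160000)]]
data_name_list = ["data_a", "data_b"]
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

fig, ax = plt.subplots()
for x_and_y_list, data_name, color in zip(list_of_x_and_y_list, data_name_list, colors):
    for x, y in x_and_y_list:
        # Every point keeps its label, so the raw legend would repeat names.
        ax.scatter(x, y, label=data_name, color=color)

# A dict keeps one handle per label (the last one seen) and, on Python 3.7+,
# orders the keys by first appearance.
handles, labels = ax.get_legend_handles_labels()
unique = dict(zip(labels, handles))
legend = ax.legend(list(unique.values()), list(unique.keys()), loc=2)
```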

Some things can be simplified to make the code more Pythonic, for example:

for x_and_y in x_and_y_list:
    x, y = x_and_y

can be written as:

for x, y in x_and_y_list:

Another issue is that, with a lot of data, calling scatter for every point can be rather slow. All the x and y values belonging to the same list can be plotted together in a single call, for example using list comprehensions:

for x_and_y_list, data_name, color in zip(list_of_x_and_y_list, data_name_list, colors):
    xs = [x for x, y in x_and_y_list]
    ys = [y for x, y in x_and_y_list]
    ax.scatter(xs, ys, label=data_name, color=color)
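A self-contained version of this one-call-per-series loop, run on the question's data. Skipping points that contain None is an assumption about how missing values should be handled, and the Agg backend is used so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend, no display needed
import matplotlib.pyplot as plt

list_of_x_and_y_list = [[(44800, 14888), (132000, 12500), (40554, 12900)],
                        [(None, 193788), (101653, 78880), (3866, 160000)]]
data_name_list = ["data_a", "data_b"]
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

fig, ax = plt.subplots()
for x_and_y_list, data_name, color in zip(list_of_x_and_y_list, data_name_list, colors):
    # Assumption: points with a missing coordinate are simply skipped.
    points = [(x, y) for x, y in x_and_y_list if x is not None and y is not None]
    xs = [x for x, y in points]
    ys = [y for x, y in points]
    # One scatter() call per series gives one collection and one legend entry.
    ax.scatter(xs, ys, label=data_name, color=color)

ax.legend(loc=2)
```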

scatter can even take a list of colors, one per point, but plotting all the points in a single call would not give a separate legend label per data_name.
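If a single scatter call is still preferred, one common workaround (not from the original answer) is to build the legend from proxy artists: empty Line2D objects that only carry a marker, a colour and a label. A sketch, with the None point dropped from the data and the Agg backend assumed:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend, no display needed
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

list_of_x_and_y_list = [[(44800, 14888), (132000, 12500), (40554, 12900)],
                        [(101653, 78880), (3866, 160000)]]
data_name_list = ["data_a", "data_b"]
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# Flatten everything into one list of points plus a parallel list of colours.
xs, ys, point_colors = [], [], []
for x_and_y_list, color in zip(list_of_x_and_y_list, colors):
    for x, y in x_and_y_list:
        xs.append(x)
        ys.append(y)
        point_colors.append(color)

fig, ax = plt.subplots()
ax.scatter(xs, ys, c=point_colors)  # a single call, so no per-series labels

# Proxy artists: empty Line2D objects used only to draw the legend entries.
handles = [Line2D([], [], marker='o', linestyle='', color=color, label=name)
           for name, color in zip(data_name_list, colors)]
legend = ax.legend(handles=handles, loc=2)
```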

Very often, numpy is used to store numerical data. This has some advantages, such as vectorization for quick calculations. With numpy the code would look like:

import numpy as np

for x_and_y_list, data_name, color in zip(list_of_x_and_y_list, data_name_list, colors):
    xys = np.array(x_and_y_list, dtype=float)  # dtype=float turns None into NaN
    ax.scatter(xys[:,0], xys[:,1], label=data_name, color=color)
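The question's data_b contains None. Passing dtype=float makes NumPy convert it to NaN, so the array stays numeric instead of becoming an object array; rows containing a NaN can then be dropped with a boolean mask before plotting. A minimal sketch:

```python
import numpy as np

x_and_y_list = [(None, 193788), (101653, 78880), (3866, 160000)]

# dtype=float turns None into NaN instead of producing an object array.
xys = np.array(x_and_y_list, dtype=float)

# Keep only the rows where neither coordinate is NaN.
complete = xys[~np.isnan(xys).any(axis=1)]

# Column slices give the x and y values for ax.scatter.
xs, ys = complete[:, 0], complete[:, 1]
```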



Source: https://stackoverflow.com/questions/60199943/how-to-iterate-a-list-of-list-for-a-scatter-plot-and-create-a-legend-of-unique-e
