I'm working on multi-class segmentation using Keras and U-Net.
My network outputs 12 classes through a softmax activation function. The shape of my output is (N
You are misusing `sample_weight`. As its name clearly implies, it assigns a weight to each sample; so, although you have only 481 samples, you pass something of length 82944 (and, additionally, with 2 dimensions), hence the expected error:

```
ValueError: Found a sample_weight array with shape (82944, 12) for an input with shape (481, 288, 288). sample_weight cannot be broadcast.
```
So what you actually need is a 1D `sample_weight` array whose length equals the number of your training samples, with each element being the weight of the corresponding sample; that weight, in turn, should be the same for all samples of the same class, as you show.
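As a quick sanity check, the shape rule can be written out directly. This is a minimal sketch: `check_sample_weight` is a hypothetical helper (not part of Keras), and the shapes are the ones reported in the error message.

```python
import numpy as np

# Shape of the training input, as reported in the error message
x_shape = (481, 288, 288)

def check_sample_weight(x_shape, w):
    """Hypothetical helper: True if w is 1D with exactly one weight per training sample."""
    w = np.asarray(w)
    return w.ndim == 1 and w.shape[0] == x_shape[0]

bad_weights = np.ones((82944, 12))  # the (pixels, classes) shape that raised the ValueError
good_weights = np.ones(481)         # one weight per sample

print(check_sample_weight(x_shape, bad_weights))   # False
print(check_sample_weight(x_shape, good_weights))  # True
```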
Here is how you can do it using some dummy data `y` with 12 classes and only 30 samples:
```python
import numpy as np

y = np.random.randint(12, size=30)  # dummy data, 12 classes (no seed, so values vary per run)
y
# array([ 8, 0, 6, 8, 9, 9, 7, 11, 6, 4, 6, 3, 10, 8, 7, 7, 11,
#         2, 5, 8, 8, 1, 7, 2, 7, 9, 5, 2, 0, 0])

sample_weights = np.zeros(len(y))
# your own per-class weights here:
sample_weights[y==0] = 7
sample_weights[y==1] = 10
sample_weights[y==2] = 2
sample_weights[y==3] = 3
sample_weights[y==4] = 4
sample_weights[y==5] = 5
sample_weights[y==6] = 6
sample_weights[y==7] = 50
sample_weights[y==8] = 8
sample_weights[y==9] = 9
sample_weights[y==10] = 50
sample_weights[y==11] = 11
sample_weights
# result:
# array([ 8., 7., 6., 8., 9., 9., 50., 11., 6., 4., 6., 3., 50.,
#         8., 50., 50., 11., 2., 5., 8., 8., 10., 50., 2., 50., 9.,
#         5., 2., 7., 7.])
```
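The twelve per-class assignments can also be collapsed into a single lookup: keep the per-class weights in an array indexed by class label, then index it with `y`. This is a sketch using NumPy integer-array indexing; `class_weight_table` is an illustrative name, the values are the same example weights, and the `y` values printed above are hard-coded so the result is reproducible.

```python
import numpy as np

# Entry k holds the weight for class k (same example values as the per-class assignments)
class_weight_table = np.array([7, 10, 2, 3, 4, 5, 6, 50, 8, 9, 50, 11], dtype=float)

# The dummy labels printed above, fixed here instead of drawn at random
y = np.array([8, 0, 6, 8, 9, 9, 7, 11, 6, 4, 6, 3, 10, 8, 7, 7, 11,
              2, 5, 8, 8, 1, 7, 2, 7, 9, 5, 2, 0, 0])

# Integer-array indexing picks one weight per sample, giving a 1D array of shape (len(y),)
sample_weights = class_weight_table[y]
print(sample_weights.shape)  # (30,)
```

A side benefit: if a class is accidentally left out of the table, an out-of-range label raises an `IndexError`, whereas with the `np.zeros` version that class would silently get weight 0.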
Let's put them in a DataFrame for easier viewing:

```python
import pandas as pd

d = {'y': y, 'weight': sample_weights}
df = pd.DataFrame(d)
print(df.to_string(index=False))
```

Result:

```
 y  weight
 8     8.0
 0     7.0
 6     6.0
 8     8.0
 9     9.0
 9     9.0
 7    50.0
11    11.0
 6     6.0
 4     4.0
 6     6.0
 3     3.0
10    50.0
 8     8.0
 7    50.0
 7    50.0
11    11.0
 2     2.0
 5     5.0
 8     8.0
 8     8.0
 1    10.0
 7    50.0
 2     2.0
 7    50.0
 9     9.0
 5     5.0
 2     2.0
 0     7.0
 0     7.0
```
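With the data in a DataFrame, it is easy to confirm that every class maps to exactly one weight. A minimal sketch using pandas `groupby`; to run on its own it rebuilds `df` from the labels and weights printed above, and `per_class` and `summary` are illustrative names.

```python
import numpy as np
import pandas as pd

# Labels and weights as printed above
y = np.array([8, 0, 6, 8, 9, 9, 7, 11, 6, 4, 6, 3, 10, 8, 7, 7, 11,
              2, 5, 8, 8, 1, 7, 2, 7, 9, 5, 2, 0, 0])
sample_weights = np.array([8, 7, 6, 8, 9, 9, 50, 11, 6, 4, 6, 3, 50, 8, 50, 50, 11,
                           2, 5, 8, 8, 10, 50, 2, 50, 9, 5, 2, 7, 7], dtype=float)
df = pd.DataFrame({'y': y, 'weight': sample_weights})

# Number of distinct weights per class: should be 1 for every class
per_class = df.groupby('y')['weight'].nunique()
print(per_class.eq(1).all())

# Compact view: one row per class with its weight
summary = df.drop_duplicates().sort_values('y')
print(summary.to_string(index=False))
```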
Then, of course, you should replace `sample_weight=class_weights` in your `model.fit` call with `sample_weight=sample_weights`.
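To see what these weights do during training, here is a simplified NumPy sketch of a sample-weighted categorical cross-entropy: each sample's loss is multiplied by its weight, and the weighted losses are summed and divided by the batch size. This is meant to approximate Keras's default "sum over batch size" loss reduction; it is an illustration under that assumption, not Keras's actual implementation, and `weighted_categorical_crossentropy` is a hypothetical name.

```python
import numpy as np

def weighted_categorical_crossentropy(y_true, y_pred, sample_weight, eps=1e-7):
    """Hypothetical helper: batch mean of per-sample cross-entropy scaled by sample_weight."""
    y_pred = np.clip(y_pred, eps, 1.0)                      # avoid log(0)
    per_sample = -np.sum(y_true * np.log(y_pred), axis=-1)  # one loss per sample, shape (batch,)
    return np.sum(per_sample * sample_weight) / len(per_sample)

# Three samples and 3 classes, kept small for illustration
y_true = np.eye(3)[[0, 1, 2]]  # one-hot labels for classes 0, 1, 2
y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.3, 0.3, 0.4]])
w = np.array([7.0, 10.0, 2.0])  # per-sample weights, e.g. looked up from each sample's class

loss = weighted_categorical_crossentropy(y_true, y_pred, w)
print(loss)
```

A sample with weight 50 therefore contributes 50 times as much to the loss (and its gradient) as an identical sample with weight 1.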