What can cause the loss from model.get_latest_training_loss() to increase on each epoch?
Code used for training:
class EpochSaver(CallbackA
Up through gensim 3.6.0, the reported loss value may not be very sensible: the tally is reset only on each call to train(), rather than on each internal epoch. Some fixes are forthcoming in this pull request:
https://github.com/RaRe-Technologies/gensim/pull/2135
In the meantime, the difference between the previous value, and the latest, may be more meaningful. In that case, your data suggest the 1st epoch had a total loss of 745896, while the last had (9676936-9280568=) 396,368 – which may indicate the kind of progress hoped-for.
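To make the subtraction concrete, here is a minimal sketch in plain Python, independent of gensim. It assumes you have recorded the running tally once per epoch; the helper name per_epoch_losses and the sample readings are hypothetical, chosen only for illustration.

```python
def per_epoch_losses(cumulative):
    """Turn a running loss tally (one reading per epoch) into per-epoch losses."""
    losses = []
    previous = 0.0
    for total in cumulative:
        # Each epoch's loss is the increase in the tally since the last reading.
        losses.append(total - previous)
        previous = total
    return losses


# Hypothetical cumulative readings, not taken from a real training run.
readings = [100.0, 180.0, 240.0]
print(per_epoch_losses(readings))  # [100.0, 80.0, 60.0]

# The last-epoch arithmetic quoted above, using the two values from the question.
print(9676936 - 9280568)  # 396368
```

A tally that keeps growing is therefore expected; what matters is whether the increments shrink from one epoch to the next.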
As proposed by gojomo, you can calculate the difference in loss inside a callback:
from gensim.models.callbacks import CallbackAny2Vec
from gensim.models import Word2Vec

# init callback class
class callback(CallbackAny2Vec):
    """
    Callback to print loss after each epoch
    """
    def __init__(self):
        self.epoch = 0

    def on_epoch_end(self, model):
        # get_latest_training_loss() returns the running tally, not the per-epoch loss
        loss = model.get_latest_training_loss()
        if self.epoch == 0:
            print('Loss after epoch {}: {}'.format(self.epoch, loss))
        else:
            # subtract the previous tally to get this epoch's loss
            print('Loss after epoch {}: {}'.format(self.epoch, loss - self.loss_previous_step))
        self.epoch += 1
        self.loss_previous_step = loss
To train your model, add compute_loss=True and callbacks=[callback()] to the Word2Vec train method:
# init word2vec class
w2v_model = Word2Vec(min_count=20,
                     window=12,
                     size=100,  # renamed to vector_size in gensim 4.x
                     workers=2)
# build vocab (sentences is your tokenized corpus: an iterable of token lists)
w2v_model.build_vocab(sentences)
# train the w2v model
w2v_model.train(sentences,
                total_examples=w2v_model.corpus_count,
                epochs=10,
                report_delay=1,
                compute_loss=True,  # set compute_loss = True
                callbacks=[callback()])  # add the callback class
# save the word2vec model
w2v_model.save('word2vec.model')
This will output something like this:
Loss after epoch 0: 4448638.5
Loss after epoch 1: 3283735.5
Loss after epoch 2: 2826198.0
Loss after epoch 3: 2680974.0
Loss after epoch 4: 2601113.0
Loss after epoch 5: 2271333.0
Loss after epoch 6: 2052050.0
Loss after epoch 7: 2011768.0
Loss after epoch 8: 1927454.0
Loss after epoch 9: 1887798.0
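If you want to check the trend rather than eyeball it, the printed values can be compared directly. The sketch below is plain Python using the ten numbers from the sample output above; the variable names are my own, and it only assumes you have collected the per-epoch values into a list.

```python
# Per-epoch losses copied from the sample output above.
epoch_losses = [4448638.5, 3283735.5, 2826198.0, 2680974.0, 2601113.0,
                2271333.0, 2052050.0, 2011768.0, 1927454.0, 1887798.0]

# True when every epoch's loss is lower than the one before it.
decreasing = all(a > b for a, b in zip(epoch_losses, epoch_losses[1:]))

# Fractional change from each epoch to the next (negative means improvement).
changes = [(b - a) / a for a, b in zip(epoch_losses, epoch_losses[1:])]

# Overall reduction from the first epoch to the last, as a fraction.
relative_drop = 1 - epoch_losses[-1] / epoch_losses[0]

print('Strictly decreasing:', decreasing)
for epoch, change in enumerate(changes, start=1):
    print('Epoch {}: {:+.1%} vs. previous epoch'.format(epoch, change))
print('Total reduction: {:.1%}'.format(relative_drop))
```

Shrinking per-epoch changes toward the end suggest the loss is levelling off, although a lower training loss does not by itself guarantee better vectors for a downstream task.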