skip-gram 模型根据中心词预测其周边词：窗口大小为 C（中心词左右各取 C 个词），每个周边词对应采样 K 个负例。
import torch
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import numpy as np
# ---- Hyperparameters ----
K: int = 100  # negative samples drawn per context (positive) word
C: int = 3  # context radius: C words are taken on each side of the center word
MAX_VOCAB_SIZE: int = 30_000  # max vocabulary entries, "<unk>" included
BATCH_SIZE: int = 128  # examples per DataLoader batch
# Read the whole training corpus into memory. The encoding is given explicitly
# so decoding does not depend on the platform's default locale.
with open('data/text8.train.txt', 'r', encoding='utf-8') as fin:
    text = fin.read()
# Split on whitespace into a list of word tokens (str.split already returns a
# fresh list, so no extra copy is needed).
text = text.split()
# {word: count} for the MAX_VOCAB_SIZE - 1 most frequent words, most frequent
# first (dicts preserve insertion order).
vocab = dict(Counter(text).most_common(MAX_VOCAB_SIZE - 1))
# Add "<unk>" holding the count of every token that did not make the cut. It is
# appended last unless the corpus itself already contains the token "<unk>".
# The built-in sum keeps this an exact Python int; np.sum would accumulate in
# NumPy's default fixed-width integer, which is platform-dependent.
vocab["<unk>"] = len(text) - sum(vocab.values())
# list: index -> word
idx_to_word = list(vocab)
# dict: word -> index
word_to_idx = {word: i for i, word in enumerate(idx_to_word)}
"""
统计词典中词出现的频率
"""
# 获取单词出现的个数
word_counts = np.array([count for count in vocab.values()], dtype=np.float32)
# 计算频率
word_freqs = word_counts / np.sum(word_counts)
# 0.75 次幂
word_freqs = word_freqs ** (3./4.)
# 归一化
word_freqs = word_freqs / np.sum(word_freqs) # 用来做 negative sampling
VOCAB_SIZE = len(idx_to_word)
class WordEmbeddingDataset(torch.utils.data.Dataset):
    """Skip-gram examples: (center word, context words, negative samples).

    Item ``idx`` uses the token at corpus position ``idx`` as the center word
    and the ``window_size`` tokens on each side of it as positive context
    words. It also draws ``num_neg_samples`` negative words per context word,
    with replacement, using ``word_freqs`` as the sampling weights.

    Args:
        text: Corpus as a sequence of word tokens.
        word_to_idx: Mapping word -> vocabulary index. Tokens missing from it
            are encoded with the "<unk>" index.
        idx_to_word: List mapping index -> word. When ``word_to_idx`` has no
            "<unk>" key, its last index is used for unknown tokens.
        word_freqs: Non-negative sampling weight per vocabulary index, passed
            to ``torch.multinomial`` for negative sampling.
        word_counts: Occurrence count per vocabulary index. It is stored as a
            float tensor but not otherwise used by this class.
        window_size: Context radius (words taken on each side of the center
            word). Defaults to the module-level ``C``.
        num_neg_samples: Negative samples per context word. Defaults to the
            module-level ``K``.
    """

    def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts,
                 *, window_size=None, num_neg_samples=None):
        super().__init__()
        # Use the module-level hyperparameters only when no explicit value is
        # given, so the class can also be used with other settings.
        self.window_size = C if window_size is None else window_size
        self.num_neg_samples = K if num_neg_samples is None else num_neg_samples
        # Take the unknown-word index from the vocabulary actually passed in
        # rather than from the global VOCAB_SIZE. This way the encoding cannot
        # silently disagree with word_to_idx / idx_to_word.
        self.unk_idx = word_to_idx.get("<unk>", len(idx_to_word) - 1)
        # Encode the corpus directly as an int64 tensor (no float32 round trip).
        self.text_encoded = torch.tensor(
            [word_to_idx.get(t, self.unk_idx) for t in text], dtype=torch.long
        )
        # dict: word -> index
        self.word_to_idx = word_to_idx
        # list: index -> word
        self.idx_to_word = idx_to_word
        # Negative-sampling weights, one per vocabulary index.
        self.word_freqs = torch.tensor(word_freqs, dtype=torch.float32)
        # Occurrence count per vocabulary index.
        self.word_counts = torch.tensor(word_counts, dtype=torch.float32)

    def __len__(self):
        """Return the number of corpus tokens (one example per token)."""
        return len(self.text_encoded)

    def __getitem__(self, idx):
        """Return ``(center_word, pos_words, neg_words)`` for position ``idx``.

        ``pos_words`` holds the ``2 * window_size`` neighbours of ``idx``.
        ``neg_words`` holds ``num_neg_samples * 2 * window_size`` indices
        sampled with replacement from ``word_freqs``. All three are int64
        tensors.
        """
        center_word = self.text_encoded[idx]
        # The window_size positions before and after idx. Positions past either
        # end of the corpus wrap around to the other end (modulo its length).
        corpus_len = len(self.text_encoded)
        pos_indices = (list(range(idx - self.window_size, idx))
                       + list(range(idx + 1, idx + self.window_size + 1)))
        pos_indices = [i % corpus_len for i in pos_indices]
        pos_words = self.text_encoded[pos_indices]
        # num_neg_samples negatives per context word, drawn with replacement.
        # NOTE: nothing excludes the center or context words from the draw.
        neg_words = torch.multinomial(
            self.word_freqs,
            self.num_neg_samples * pos_words.shape[0],
            replacement=True,
        )
        return center_word, pos_words, neg_words
# Skip-gram dataset over the tokenized corpus and its vocabulary statistics.
dataset = WordEmbeddingDataset(
    text, word_to_idx, idx_to_word, word_freqs, word_counts
)
# Reshuffle the examples every epoch; num_workers=0 loads batches in the main
# process.
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=0,
)
来源:oschina
链接:https://my.oschina.net/u/4228078/blog/4405730