-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathwvplot.py
More file actions
110 lines (85 loc) · 3.65 KB
/
wvplot.py
File metadata and controls
110 lines (85 loc) · 3.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from sklearn.decomposition import PCA
from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D
from gensim.models import Word2Vec
from sklearn.cluster import KMeans
from nltk.tokenize import word_tokenize
# Plot the words from a word list on chosen PCA axes of a Word2Vec model.

# Trained model filename (the output of wvgen.py), loaded from the model/ directory.
modelf = "mariomodel.bin"

# Word list filename, read from the list/ directory. Words on one line are
# separated by "," and every line forms its own group.
wordf = "mariochars.txt"

# Indices of the PCA components to plot. Two entries produce a 2D plot and
# three entries a 3D plot. The most informative choices are usually
# [0, 1] or [1, 2] in 2D and [0, 1, 2] or [1, 2, 3] in 3D.
axes = [0, 1, 2]

# Number of KMeans clusters to group the words into. When this is <= 0,
# each line of wordf becomes a group instead.
clusterK = 3

# Groups are told apart by label color and font size: group 0 uses colors[0]
# and sizes[0], group 1 uses colors[1] and sizes[1], and so on. Groups past the
# end of either list fall back to defaultcolor / defaultsize.
# Dark colors read best on matplotlib's white background; hex codes or named
# colors (https://matplotlib.org/gallery/color/named_colors.html) both work.
colors = [
    "tab:red",
    "tab:blue",
    "tab:green",
    "tab:orange",
    "tab:purple",
    "tab:olive",
    "tab:pink",
    "tab:cyan",
    "tab:gray",
]
defaultcolor = "black"

# Per-group font sizes; an empty list means every group uses defaultsize.
sizes = []
defaultsize = 16
def plot2D(result, wordgroups, labels=None):
    """Scatter-plot the projected vectors on two PCA axes and label grouped words.

    Args:
        result: PCA-projected coordinates, indexed as ``result[row, component]``;
            the components drawn are ``axes[0]`` (x) and ``axes[1]`` (y).
        wordgroups: Sequence of word groups. Words of group ``g`` are labelled
            with ``colors[g]`` / ``sizes[g]``, falling back to ``defaultcolor``
            / ``defaultsize`` when the group index is past the end of a list.
        labels: The word for each row of ``result``. Defaults to the
            module-level ``words`` (set by the ``__main__`` block). Duplicates
            are ignored, keeping first occurrences, which matches the row order
            of ``model.wv[vecs]`` built from ``get_groups``.
    """
    if labels is None:
        labels = words
    # Map word -> row once. Rows follow the de-duplicated word order, so
    # words.index() (which counts duplicates) pointed at the wrong row
    # whenever a word appeared on more than one line of the word file.
    row_of = {w: i for i, w in enumerate(dict.fromkeys(labels))}
    x_axis, y_axis = axes[0], axes[1]
    pyplot.scatter(result[:, x_axis], result[:, y_axis])
    for g, group in enumerate(wordgroups):
        color = colors[g] if g < len(colors) else defaultcolor
        size = sizes[g] if g < len(sizes) else defaultsize
        for word in group:
            i = row_of.get(word)
            if i is None:
                # Word has no plotted vector (e.g. not in the model); skip it.
                continue
            coord = (result[i, x_axis], result[i, y_axis])
            pyplot.annotate(word, xy=coord, color=color, fontsize=size)
def plot3D(result, wordgroups, labels=None):
    """Scatter-plot the projected vectors on three PCA axes in a new 3D figure.

    Args:
        result: PCA-projected coordinates, indexed as ``result[row, component]``;
            the components drawn are ``axes[0]``, ``axes[1]`` and ``axes[2]``.
        wordgroups: Sequence of word groups. Words of group ``g`` are labelled
            with ``colors[g]`` / ``sizes[g]``, falling back to ``defaultcolor``
            / ``defaultsize`` when the group index is past the end of a list.
        labels: The word for each row of ``result``. Defaults to the
            module-level ``words`` (set by the ``__main__`` block). Duplicates
            are ignored, keeping first occurrences, which matches the row order
            of ``model.wv[vecs]`` built from ``get_groups``.
    """
    if labels is None:
        labels = words
    # Map word -> row once. Rows follow the de-duplicated word order, so
    # words.index() (which counts duplicates) pointed at the wrong row
    # whenever a word appeared on more than one line of the word file.
    row_of = {w: i for i, w in enumerate(dict.fromkeys(labels))}
    x_axis, y_axis, z_axis = axes[0], axes[1], axes[2]
    fig = pyplot.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(result[:, x_axis], result[:, y_axis], result[:, z_axis])
    for g, group in enumerate(wordgroups):
        color = colors[g] if g < len(colors) else defaultcolor
        size = sizes[g] if g < len(sizes) else defaultsize
        for word in group:
            i = row_of.get(word)
            if i is None:
                # Word has no plotted vector (e.g. not in the model); skip it.
                continue
            ax.text(result[i, x_axis], result[i, y_axis], result[i, z_axis],
                    word, color=color, fontsize=size)
def get_groups(wordf, model, n_clusters=None):
    """Read the word list file and split its words into plot groups.

    Each non-blank line of ``list/<wordf>`` is one group, with entries
    separated by ",". Every entry is passed through ``word_tokenize`` and its
    tokens are re-joined with single spaces. Entries that are not in the
    model's vocabulary are dropped.

    When the cluster count is positive, the line-based grouping is replaced by
    KMeans clusters computed over the word vectors.

    Args:
        wordf: Name of the word list file inside the ``list/`` directory.
        model: Loaded gensim ``Word2Vec`` model (its ``wv`` is queried).
        n_clusters: Number of KMeans clusters. Defaults to the module-level
            ``clusterK``; values <= 0 keep one group per file line. The count
            is capped at the number of distinct words, which KMeans requires.

    Returns:
        ``(words, groups, vecs)``:
        - ``words``: every in-vocabulary entry in file order, duplicates kept.
        - ``groups``: a list of word lists.
        - ``vecs``: a dict mapping each distinct word, in first-seen order, to
          its vector from ``model.wv``.
    """
    if n_clusters is None:
        n_clusters = clusterK
    words = []
    groups = []
    # Extract words to plot from file; `with` closes the handle afterwards.
    with open("list/" + wordf, "r", encoding="utf-8") as fh:
        lines = fh.read().split("\n")
    for line in lines:
        if not line.strip():
            # Blank lines (e.g. after the final newline) would become empty groups.
            continue
        entries = (' '.join(word_tokenize(x)) for x in line.split(","))
        # Build a real list: a lazy filter() object would be consumed by
        # `words += ...` below and leave an empty group in `groups`.
        # `in model.wv` works on both gensim 3 and gensim 4 (where `.vocab` was removed).
        group = [x for x in entries if x in model.wv]
        groups.append(group)
        words += group
    # Word vectors keyed by word; dict insertion order = first occurrence in `words`.
    vecs = {w: model.wv[w] for w in words}
    # Assign groups if using clustering (skipped when no words were found).
    if n_clusters > 0 and vecs:
        k = min(n_clusters, len(vecs))
        estimator = KMeans(init='k-means++', n_clusters=k, n_init=10)
        cluster_ids = estimator.fit_predict(model.wv[list(vecs)])
        groups = [[] for _ in range(k)]
        for w, cluster_id in zip(vecs, cluster_ids):
            groups[cluster_id].append(w)
    return words, groups, vecs
if __name__ == '__main__':
    model = Word2Vec.load("model/" + modelf)
    # Get groups from file or by clustering. `words` is deliberately left as a
    # module-level name: plot2D/plot3D read it to find each word's row in
    # the PCA result.
    words, groups, vecs = get_groups(wordf, model)
    # PCA cannot extract more components than there are samples (distinct
    # words). Fail with a readable message instead of a scikit-learn traceback.
    n_components = max(axes) + 1
    if len(vecs) < n_components:
        raise SystemExit(
            "Need at least {} words from list/{} in the model vocabulary to "
            "plot PCA axes {}, found {}".format(n_components, wordf, axes, len(vecs)))
    coords = model.wv[vecs]
    # Create axes to plot on
    pca = PCA(n_components=n_components)
    result = pca.fit_transform(coords)
    # Plot vectors on axes: more than two axes -> 3D plot, otherwise 2D
    if len(axes) > 2:
        plot3D(result, groups)
    else:
        plot2D(result, groups)
    pyplot.show()