pca_across_time.py
import numpy as np
from sklearn.decomposition import PCA
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import patheffects
from math import gcd
from typing import List

from src import config


def equidistant_elements(l, n):
    """Return n elements from l such that they are equally far apart in l.

    Note: l is shortened in place (via pop) until its length is divisible by n.
    """
    while not gcd(n, len(l)) == n:
        l.pop()
    step = len(l) // n
    ids = np.arange(step, len(l) + step, step) - 1  # -1 for zero-based indexing
    res = np.asarray(l)[ids].tolist()
    return res
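

# A minimal, illustrative example (not part of the original script):
#   equidistant_elements(list(range(10)), 5) -> [1, 3, 5, 7, 9]
# i.e. the elements at equally spaced positions, ending at the last element
# of the (possibly shortened) list.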


def make_pca_across_time_fig(embeddings: np.ndarray,
                             words: List[str],
                             component1: int,
                             component2: int,
                             num_ticks: int,
                             ) -> plt.Figure:
    """
    Return a figure showing the evolution of word embeddings in 2D space using PCA.
    """
    assert np.ndim(embeddings) == 3  # (ticks, words, embedding dimensions)
    assert len(words) == embeddings.shape[1]

    palette = np.array(sns.color_palette("hls", embeddings.shape[1]))

    model_ticks = [n for n, _ in enumerate(embeddings)]
    equidistant_ticks = equidistant_elements(model_ticks, num_ticks)

    # fit pca model on last tick
    num_components = component2 + 1
    pca_model = PCA(n_components=num_components)
    pca_model.fit(embeddings[-1])

    # transform embeddings at requested ticks with pca model
    transformations = []
    for ei in embeddings[equidistant_ticks]:
        transformations.append(pca_model.transform(ei)[:, [component1, component2]])

    # fig
    res, ax = plt.subplots(figsize=config.Fig.fig_size, dpi=config.Fig.dpi)
    ax.set_title(f'Principal components {component1} and {component2}\nEvolution across training')
    ax.axis('off')
    ax.axhline(y=0, linestyle='--', c='grey', linewidth=1.0)
    ax.axvline(x=0, linestyle='--', c='grey', linewidth=1.0)

    # plot
    for n, word in enumerate(words):
        # trajectory of word n across the selected ticks
        x, y = zip(*[t[n] for t in transformations])
        ax.plot(x, y, c=palette[n], lw=config.Fig.line_width)
        # label the end point of the trajectory with the word
        x_pos, y_pos = transformations[-1][n, :]
        txt = ax.text(x_pos, y_pos, str(word), fontsize=8,
                      color=palette[n])
        txt.set_path_effects([
            patheffects.Stroke(linewidth=config.Fig.line_width, foreground="w"), patheffects.Normal()])

    return res
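

# demo: plot trajectories of a few random words whose embeddings shrink toward the origin over training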
NUM_TICKS = 12
NUM_WORDS = 4
EMBED_SIZE = 8

# create random words and random embeddings
words = [f'word-{n}' for n in range(NUM_WORDS)]
embeddings = np.stack([np.random.random((NUM_WORDS, EMBED_SIZE)) * (NUM_TICKS / (tick + 1))
                       for tick in range(NUM_TICKS)])

fig = make_pca_across_time_fig(embeddings, words, component1=0, component2=1, num_ticks=6)
fig.show()
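
# Optional (not part of the original script): persist the figure to disk.
# The output file name below is illustrative.
# fig.savefig('pca_across_time.png', dpi=config.Fig.dpi, bbox_inches='tight')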