This repository was archived by the owner on Jun 12, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlatent_regularizer.py
121 lines (100 loc) · 4.86 KB
/
latent_regularizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
''' latent regularizer - this file contains our proposed regularization loss.
// Copyright (c) 2019 Robert Bosch GmbH
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
import matplotlib
import torch
from utils import compute_empirical_covariance, compute_gmm_covariance
# Module-level side effect: switch matplotlib to the 'Agg' backend as soon as
# this module is imported.
# NOTE(review): none of the functions visible in this file use matplotlib --
# confirm that this backend selection is still needed here.
matplotlib.use('Agg')
def mean_squared_kolmogorov_smirnov_distance_gmm_broadcasting(embedding_matrix, gmm_centers, gmm_std):
    """Return the mean squared distance between empirical and GMM marginal CDFs.

    The embeddings are sorted independently along the sample axis for every
    latent dimension. For sorted samples the empirical CDF at the i-th value
    is ``i / emb_num``. It is compared against the per-dimension CDF of an
    equally weighted Gaussian mixture evaluated at the same points, and the
    squared differences are averaged over all samples and dimensions.

    Because every dimension is sorted on its own, the intermediate CDF values
    are NOT in the same order as the rows of ``embedding_matrix``; only the
    aggregated scalar is returned.

    Parameters
    ----------
    embedding_matrix: torch.Tensor
        The latent representation of the batch, shape ``(..., emb_num, emb_dim)``.
    gmm_centers: torch.Tensor
        Centers of the GMM components in that space, shape
        ``(num_centers, emb_dim)``. All are assumed to have the same weight.
    gmm_std:
        All components of the GMM are assumed to share the same covariance
        matrix: C = gmm_std**2 * Identity.

    Returns
    -------
    torch.Tensor
        Scalar tensor holding the mean squared CDF difference.

    Raises
    ------
    ValueError
        If ``gmm_centers`` is not two-dimensional.
    """
    if gmm_centers.dim() != 2:
        raise ValueError(
            "gmm_centers must have shape (num_centers, emb_dim), got {}".format(tuple(gmm_centers.shape))
        )

    sorted_embeddings = torch.sort(embedding_matrix, dim=-2).values
    emb_num = sorted_embeddings.shape[-2]

    # For the sorted embeddings, the empirical CDF depends only on the "index"
    # of each embedding (the number of embeddings up to and including it).
    # The trailing singleton axis stands for the latent dimensions.
    empirical_cdf = torch.linspace(
        start=1 / emb_num,
        end=1.0,
        steps=emb_num,
        device=embedding_matrix.device,
        dtype=embedding_matrix.dtype,
    ).unsqueeze(-1)

    # (..., emb_num, 1, emb_dim) - (num_centers, emb_dim)
    #   -> (..., emb_num, num_centers, emb_dim)
    normalized_embedding_distances_to_centers = (sorted_embeddings.unsqueeze(-2) - gmm_centers) / gmm_std

    # Standard normal CDF via the error function: Phi(x) = 0.5 * (1 + erf(x / sqrt(2))).
    inv_sqrt2 = 2.0 ** -0.5
    normal_cdf_per_center = 0.5 * (1 + torch.erf(normalized_embedding_distances_to_centers * inv_sqrt2))

    # Equal component weights: the mixture CDF is the plain mean over the centers.
    normal_cdf = normal_cdf_per_center.mean(dim=-2)

    # Expand the target explicitly; relying on implicit broadcasting inside
    # mse_loss triggers a "target size differs from input size" warning.
    return torch.nn.functional.mse_loss(normal_cdf, empirical_cdf.expand_as(normal_cdf))
def mean_squared_covariance_gmm(embedding_matrix, gmm_centers, gmm_std):
    """Compute the mean squared distance between the empirical covariance
    matrix of an embedding matrix and the covariance of a GMM prior.

    The prior is described by its centers and a per-center standard deviation,
    under the assumption that different dimensions are uncorrelated on a per
    center level and that all modes are equally weighted.

    Parameters
    ----------
    embedding_matrix: torch.Tensor
        Latent Vectors.
    gmm_centers:
        Centers of the GMM components in that space. All are assumed to have
        the same weight.
    gmm_std:
        All components of the GMM are assumed to share the same covariance
        matrix: C = gmm_std**2 * Identity.

    Returns
    -------
    mean_cov: torch.Tensor
        Mean squared element-wise distance between empirical and prior
        covariance, as returned by ``torch.mean``.
    """
    # Empirical covariance of the embeddings.
    sigma = compute_empirical_covariance(embedding_matrix)

    # compute_gmm_covariance returns a pair; only the second entry (the
    # covariance of the whole mixture) is compared here.
    _, gmm_covariance = compute_gmm_covariance(gmm_centers, gmm_std)

    # Tensor.to() returns the converted tensor rather than modifying it in
    # place, so the result has to be rebound for the device move to take effect.
    gmm_covariance = gmm_covariance.to(embedding_matrix.device)

    squared_diff = torch.pow(sigma - gmm_covariance, 2)
    return torch.mean(squared_diff)
def mean_squared_individual_covariance_gmm(embedding_matrix, gmm_centers, gmm_std):
    """Compute the mean squared distance between the empirical covariance
    matrix of an embedding matrix and the covariance of a GMM prior.

    The prior is described by its centers and a per-center standard deviation,
    under the assumption that different dimensions are uncorrelated on a per
    center level and that all modes are equally weighted.

    NOTE(review): the computation is currently the same as in
    ``mean_squared_covariance_gmm``; the first value returned by
    ``compute_gmm_covariance`` (originally bound to ``comp_covariance``,
    presumably a per-component covariance) is not used. Confirm whether the
    "individual" variant was meant to compare against that value instead.

    Parameters
    ----------
    embedding_matrix: torch.Tensor
        Latent Vectors.
    gmm_centers:
        Centers of the GMM components in that space. All are assumed to have
        the same weight.
    gmm_std:
        All components of the GMM are assumed to share the same covariance
        matrix: C = gmm_std**2 * Identity.

    Returns
    -------
    mean_cov: torch.Tensor
        Mean squared element-wise distance between empirical and prior
        covariance, as returned by ``torch.mean``.
    """
    # Empirical covariance of the embeddings.
    sigma = compute_empirical_covariance(embedding_matrix)

    # Compare it with the GMM covariance matrix (second entry of the returned pair).
    _, gmm_covariance = compute_gmm_covariance(gmm_centers, gmm_std)

    # Tensor.to() is not an in-place operation: keep the returned tensor,
    # otherwise the covariance stays on its original device.
    gmm_covariance = gmm_covariance.to(embedding_matrix.device)

    squared_diff = torch.pow(sigma - gmm_covariance, 2)
    return torch.mean(squared_diff)