Commit 1e6682c

Merge pull request #178 from pavlin-policar/mnt
Maintenance
2 parents: 43507dd + 0c6e960

7 files changed (+149 -259)

examples/01_simple_usage.ipynb (+52 -36)

Large diffs are not rendered by default.

examples/02_advanced_usage.ipynb (+58 -49)

Large diffs are not rendered by default.

openTSNE/callbacks.py (-36)

@@ -39,42 +39,6 @@ def __call__(self, iteration, error, embedding):
     """


-class ErrorLogger(Callback):
-    """Basic error logger.
-
-    This logger prints out basic information about the optimization. These
-    include the iteration number, error and how much time has elapsed from the
-    previous callback invocation.
-
-    """
-
-    def __init__(self):
-        warnings.warn(
-            "`ErrorLogger` will be removed in upcoming version. Please use the "
-            "`verbose` flag instead.",
-            category=FutureWarning,
-        )
-        self.iter_count = 0
-        self.last_log_time = None
-
-    def optimization_about_to_start(self):
-        self.last_log_time = time.time()
-        self.iter_count = 0
-
-    def __call__(self, iteration, error, embedding):
-        now = time.time()
-        duration = now - self.last_log_time
-        self.last_log_time = now
-
-        n_iters = iteration - self.iter_count
-        self.iter_count = iteration
-
-        print(
-            "Iteration % 4d, KL divergence % 6.4f, %d iterations in %.4f sec"
-            % (iteration, error, n_iters, duration)
-        )
-
-
 class VerifyExaggerationError(Callback):
     """Used to verify that the exaggeration correction implemented in
     `gradient_descent` is correct."""
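
The removed `ErrorLogger` had been emitting a `FutureWarning` pointing users at the `verbose` flag; this commit completes that deprecation. A minimal migration sketch (the data matrix `X` is a placeholder):

    import numpy as np
    from openTSNE import TSNE

    X = np.random.randn(1000, 50)  # placeholder data

    # Before: TSNE(callbacks=[ErrorLogger()], callbacks_every_iters=50).fit(X)
    # After: the verbose flag logs equivalent progress information
    # (iteration number, KL divergence, elapsed time)
    embedding = TSNE(verbose=True).fit(X)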

openTSNE/initialization.py (+6 -4)

@@ -28,13 +28,13 @@ def rescale(x, inplace=False):
     return x


-def random(X, n_components=2, random_state=None, verbose=False):
+def random(n_samples, n_components=2, random_state=None, verbose=False):
     """Initialize an embedding using samples from an isotropic Gaussian.

     Parameters
     ----------
-    X: np.ndarray
-        The data matrix.
+    n_samples: Union[int, np.ndarray]
+        The number of samples. Also accepts a data matrix.

     n_components: int
         The dimension of the embedding space.
@@ -53,7 +53,9 @@ def random(X, n_components=2, random_state=None, verbose=False):

     """
     random_state = check_random_state(random_state)
-    embedding = random_state.normal(0, 1e-4, (X.shape[0], n_components))
+    if isinstance(n_samples, np.ndarray):
+        n_samples = n_samples.shape[0]
+    embedding = random_state.normal(0, 1e-4, (n_samples, n_components))
     return np.ascontiguousarray(embedding)

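
With this change, `random` needs only a sample count, while a data matrix is still accepted for backwards compatibility. A small sketch of both call styles:

    import numpy as np
    from openTSNE import initialization

    X = np.random.randn(500, 30)  # placeholder data

    # Both calls produce the same (500, 2) Gaussian-initialized embedding,
    # since only X's row count is used
    emb_from_count = initialization.random(500, n_components=2, random_state=42)
    emb_from_data = initialization.random(X, n_components=2, random_state=42)
    np.testing.assert_array_equal(emb_from_count, emb_from_data)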

openTSNE/nearest_neighbors.py (+1 -105)

@@ -108,7 +108,7 @@ class Sklearn(KNNIndex):
         "sokalmichener",
         "sokalsneath",
         "wminkowski",
-    ]
+    ] + ["cosine"]  # our own workaround implementation

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -205,110 +205,6 @@ def query(self, query, k):
         return indices, distances


-class BallTree(KNNIndex):
-    VALID_METRICS = neighbors.BallTree.valid_metrics + ["cosine"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.__data = None
-
-        warnings.warn(
-            f"`nearest_neighbors.BallTree` has been superseeded by "
-            f"`nearest_neighbors.Sklearn` and will be removed from future versions",
-            category=FutureWarning,
-        )
-
-    def build(self):
-        data, k = self.data, self.k
-
-        timer = utils.Timer(
-            f"Finding {k} nearest neighbors using exact search using "
-            f"{self.metric} distance...",
-            verbose=self.verbose,
-        )
-        timer.__enter__()
-
-        if self.metric == "cosine":
-            # The nearest neighbor ranking for cosine distance is the same as
-            # for euclidean distance on normalized data
-            effective_metric = "euclidean"
-            effective_data = data.copy()
-            effective_data = (
-                effective_data / np.linalg.norm(effective_data, axis=1)[:, None]
-            )
-            # In order to properly compute cosine distances when querying the
-            # index, we need to store the original data
-            self.__data = data
-        else:
-            effective_metric = self.metric
-            effective_data = data
-
-        self.index = neighbors.NearestNeighbors(
-            algorithm="ball_tree",
-            metric=effective_metric,
-            metric_params=self.metric_params,
-            n_jobs=self.n_jobs,
-        )
-        self.index.fit(effective_data)
-
-        # Return the nearest neighbors in the training set
-        distances, indices = self.index.kneighbors(n_neighbors=k)
-
-        # If using cosine distance, the computed distances will be wrong and
-        # need to be recomputed
-        if self.metric == "cosine":
-            distances = np.vstack(
-                [
-                    cdist(np.atleast_2d(x), data[idx], metric="cosine")
-                    for x, idx in zip(data, indices)
-                ]
-            )
-
-        timer.__exit__()
-
-        return indices, distances
-
-    def query(self, query, k):
-        timer = utils.Timer(
-            f"Finding {k} nearest neighbors in existing embedding using exact search...",
-            self.verbose,
-        )
-        timer.__enter__()
-
-        # The nearest neighbor ranking for cosine distance is the same as for
-        # euclidean distance on normalized data
-        if self.metric == "cosine":
-            effective_data = query.copy()
-            effective_data = (
-                effective_data / np.linalg.norm(effective_data, axis=1)[:, None]
-            )
-        else:
-            effective_data = query
-
-        distances, indices = self.index.kneighbors(effective_data, n_neighbors=k)
-
-        # If using cosine distance, the computed distances will be wrong and
-        # need to be recomputed
-        if self.metric == "cosine":
-            if self.__data is None:
-                raise RuntimeError(
-                    "The original data was unavailable when querying cosine "
-                    "distance. Did you change the distance metric after "
-                    "building the index? Please rebuild the index using cosine "
-                    "similarity."
-                )
-            distances = np.vstack(
-                [
-                    cdist(np.atleast_2d(x), self.__data[idx], metric="cosine")
-                    for x, idx in zip(query, indices)
-                ]
-            )
-
-        timer.__exit__()
-
-        return indices, distances
-
-
 class Annoy(KNNIndex):
     """Annoy KNN Index.
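
The deleted `BallTree` index encoded a trick that the `Sklearn` index (note the appended `"cosine"` in `VALID_METRICS` above) now carries: the cosine nearest-neighbor ranking is identical to the euclidean ranking on L2-normalized rows, so a euclidean ball tree can answer cosine queries as long as the true cosine distances are recomputed afterwards. A standalone sketch of that equivalence, independent of openTSNE:

    import numpy as np
    from scipy.spatial.distance import cdist
    from sklearn.neighbors import NearestNeighbors

    data = np.random.RandomState(0).randn(200, 16)

    # A euclidean ball tree over unit-normalized rows...
    normed = data / np.linalg.norm(data, axis=1)[:, None]
    index = NearestNeighbors(algorithm="ball_tree").fit(normed)
    _, indices = index.kneighbors(normed, n_neighbors=6)
    indices = indices[:, 1:]  # drop each point itself

    # ...ranks neighbors exactly like the cosine distance
    cosine_order = np.argsort(cdist(data, data, metric="cosine"), axis=1)[:, 1:6]
    np.testing.assert_array_equal(indices, cosine_order)

    # The ball tree's euclidean distances are not cosine distances, though,
    # so they must be recomputed for the reported neighbors
    distances = np.vstack(
        [cdist(np.atleast_2d(x), data[idx], metric="cosine")
         for x, idx in zip(data, indices)]
    )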

openTSNE/tsne.py (+28 -21)

@@ -37,12 +37,23 @@ def _check_callbacks(callbacks):
 def _handle_nice_params(embedding: np.ndarray, optim_params: dict) -> None:
     """Convert the user friendly params into something the optimizer can
     understand."""
+    n_samples = embedding.shape[0]
     # Handle callbacks
     optim_params["callbacks"] = _check_callbacks(optim_params.get("callbacks"))
     optim_params["use_callbacks"] = optim_params["callbacks"] is not None

     # Handle negative gradient method
     negative_gradient_method = optim_params.pop("negative_gradient_method")
+    # Handle `auto` negative gradient method
+    if isinstance(negative_gradient_method, str) and negative_gradient_method == "auto":
+        if n_samples < 10_000:
+            negative_gradient_method = "bh"
+        else:
+            negative_gradient_method = "fft"
+        log.info(
+            f"Automatically determined negative gradient method `{negative_gradient_method}`"
+        )
+
     if callable(negative_gradient_method):
         negative_gradient_method = negative_gradient_method
     elif negative_gradient_method in {"bh", "BH", "barnes-hut"}:
@@ -78,7 +89,7 @@ def _handle_nice_params(embedding: np.ndarray, optim_params: dict) -> None:

     # Determine learning rate if requested
     if optim_params.get("learning_rate", "auto") == "auto":
-        optim_params["learning_rate"] = max(200, embedding.shape[0] / 12)
+        optim_params["learning_rate"] = max(200, n_samples / 12)


 def __check_init_num_samples(num_samples, required_num_samples):
@@ -169,7 +180,8 @@ class PartialTSNEEmbedding(np.ndarray):
         using one of the following aliases: ``bh``, ``BH`` or ``barnes-hut``.
         For larger data sets, the FFT accelerated interpolation method is more
         appropriate and can be set using one of the following aliases: ``fft``,
-        ``FFT`` or ``ìnterpolation``.
+        ``FFT`` or ``ìnterpolation``. Alternatively, you can use ``auto`` to
+        approximately select the faster method.

     theta: float
         This is the trade-off parameter between speed and accuracy of the tree
@@ -290,6 +302,8 @@ def optimize(
             ``barnes-hut``. For larger data sets, the FFT accelerated
             interpolation method is more appropriate and can be set using one of
             the following aliases: ``fft``, ``FFT`` or ``ìnterpolation``.
+            Alternatively, you can use ``auto`` to approximately select the
+            faster method.

         theta: float
             This is the trade-off parameter between speed and accuracy of the
@@ -431,7 +445,8 @@ class TSNEEmbedding(np.ndarray):
         using one of the following aliases: ``bh``, ``BH`` or ``barnes-hut``.
         For larger data sets, the FFT accelerated interpolation method is more
         appropriate and can be set using one of the following aliases: ``fft``,
-        ``FFT`` or ``ìnterpolation``.
+        ``FFT`` or ``ìnterpolation``. Alternatively, you can use ``auto`` to
+        approximately select the faster method.

     theta: float
         This is the trade-off parameter between speed and accuracy of the tree
@@ -490,7 +505,7 @@ def __new__(
         n_interpolation_points=3,
         min_num_intervals=50,
         ints_in_interval=1,
-        negative_gradient_method="fft",
+        negative_gradient_method="auto",
         random_state=None,
         optimizer=None,
         **gradient_descent_params,
@@ -571,6 +586,8 @@ def optimize(
             ``barnes-hut``. For larger data sets, the FFT accelerated
             interpolation method is more appropriate and can be set using one of
             the following aliases: ``fft``, ``FFT`` or ``ìnterpolation``.
+            Alternatively, you can use ``auto`` to approximately select the
+            faster method.

         theta: float
             This is the trade-off parameter between speed and accuracy of the
@@ -1000,7 +1017,8 @@ class TSNE(BaseEstimator):
         This is the trade-off parameter between speed and accuracy of the tree
         approximation method. Typical values range from 0.2 to 0.8. The value 0
         indicates that no approximation is to be made and produces exact results
-        also producing longer runtime.
+        also producing longer runtime. Alternatively, you can use ``auto`` to
+        approximately select the faster method.

     n_interpolation_points: int
         Only used when ``negative_gradient_method="fft"`` or its other aliases.
@@ -1071,7 +1089,8 @@ class TSNE(BaseEstimator):
         using one of the following aliases: ``bh``, ``BH`` or ``barnes-hut``.
         For larger data sets, the FFT accelerated interpolation method is more
         appropriate and can be set using one of the following aliases: ``fft``,
-        ``FFT`` or ``ìnterpolation``.
+        ``FFT`` or ``ìnterpolation``. Alternatively, you can use ``auto`` to
+        approximately select the faster method.

     callbacks: Union[Callable, List[Callable]]
         Callbacks, which will be run every ``callbacks_every_iters`` iterations.
@@ -1113,7 +1132,7 @@ def __init__(
         max_step_norm=5,
         n_jobs=1,
         neighbors="auto",
-        negative_gradient_method="fft",
+        negative_gradient_method="auto",
         callbacks=None,
         callbacks_every_iters=50,
         random_state=None,
@@ -1154,18 +1173,6 @@ def __init__(
         self.random_state = random_state
         self.verbose = verbose

-    @property
-    def neighbors_method(self):
-        import warnings
-
-        warnings.warn(
-            f"The `neighbors_method` attribute has been deprecated and will be "
-            f"removed in future versions. Please use the new `neighbors` "
-            f"attribute",
-            category=FutureWarning,
-        )
-        return self.neighbors
-
     def fit(self, X=None, affinities=None, initialization=None):
         """Fit a t-SNE embedding for a given data set.
@@ -1324,7 +1331,7 @@ def prepare_initial(self, X=None, affinities=None, initialization=None):
             initialization = "spectral"

         # Same spiel for precomputed distance matrices
-        if self.metric == "precomputed" and initialization == "pca":
+        if self.metric == "precomputed" and isinstance(initialization, str) and initialization == "pca":
             log.warning(
                 "Attempting to use `pca` initalization, but using precomputed "
                 "distance matrix! Using `spectral` initilization instead, which "
@@ -1361,7 +1368,7 @@ def prepare_initial(self, X=None, affinities=None, initialization=None):
             )
         elif initialization == "random":
             embedding = initialization_scheme.random(
-                X,
+                n_samples,
                 self.n_components,
                 random_state=self.random_state,
                 verbose=self.verbose,
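
`negative_gradient_method` now defaults to `"auto"`: Barnes-Hut below 10,000 samples, the FFT-accelerated interpolation scheme at or above that. The sketch below mirrors the new branch in `_handle_nice_params` together with the existing `"auto"` learning-rate rule; it is an illustration, not the library's internal code:

    def choose_optim_params(n_samples, negative_gradient_method="auto", learning_rate="auto"):
        """Standalone mirror of the `auto` heuristics in `_handle_nice_params`."""
        if negative_gradient_method == "auto":
            # Barnes-Hut is faster on small inputs; FFT interpolation scales better
            negative_gradient_method = "bh" if n_samples < 10_000 else "fft"
        if learning_rate == "auto":
            learning_rate = max(200, n_samples / 12)
        return negative_gradient_method, learning_rate

    choose_optim_params(5_000)   # -> ("bh", 416.666...)
    choose_optim_params(60_000)  # -> ("fft", 5000.0)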

tests/test_nearest_neighbors.py (+4 -8)

@@ -136,13 +136,13 @@ def test_pickle_with_built_index(self):
         np.testing.assert_array_almost_equal(load_dist, orig_dist)


-class TestBallTree(KNNIndexTestMixin, unittest.TestCase):
-    knn_index = nearest_neighbors.BallTree
+class TestSklearn(KNNIndexTestMixin, unittest.TestCase):
+    knn_index = nearest_neighbors.Sklearn

     def test_cosine_distance(self):
         k = 15
         # Compute cosine distance nearest neighbors using ball tree
-        knn_index = nearest_neighbors.BallTree(self.x1, k, "cosine")
+        knn_index = self.knn_index(self.x1, k, "cosine")
         indices, distances = knn_index.build()

         # Compute the exact nearest neighbors as a reference
@@ -160,7 +160,7 @@ def test_cosine_distance(self):
     def test_cosine_distance_query(self):
         k = 15
         # Compute cosine distance nearest neighbors using ball tree
-        knn_index = nearest_neighbors.BallTree(self.x1, k, "cosine")
+        knn_index = self.knn_index(self.x1, k, "cosine")
         knn_index.build()

         indices, distances = knn_index.query(self.x2, k=k)
@@ -202,10 +202,6 @@ def manhattan(x, y):
         )


-class TestSklearn(TestBallTree):
-    pass
-
-
 @unittest.skipIf(not is_package_installed("hnswlib"), "`hnswlib`is not installed")
 class TestHNSW(KNNIndexTestMixin, unittest.TestCase):
     knn_index = nearest_neighbors.HNSW
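
The test classes follow a mixin pattern: `KNNIndexTestMixin` provides generic tests that read the index class from a `knn_index` attribute, which is why retargeting the suite from `BallTree` to `Sklearn` is a two-line change. A simplified sketch of the pattern (the mixin shown here is illustrative, not the project's actual mixin):

    import unittest
    import numpy as np
    from openTSNE import nearest_neighbors

    class KNNIndexTestMixinSketch:
        """Generic tests; concrete subclasses only set `knn_index`."""
        knn_index = None

        def setUp(self):
            self.x1 = np.random.RandomState(0).randn(100, 4)

        def test_build_returns_k_neighbors(self):
            index = self.knn_index(self.x1, 15, "euclidean")
            indices, distances = index.build()
            self.assertEqual(indices.shape, (100, 15))
            self.assertEqual(distances.shape, (100, 15))

    class TestSklearnSketch(KNNIndexTestMixinSketch, unittest.TestCase):
        knn_index = nearest_neighbors.Sklearn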
