Skip to content

Commit a25f0c4

Browse files
committed
Added compute_labels option
1 parent 1ca673e commit a25f0c4

File tree

8 files changed

+90
-38
lines changed

8 files changed

+90
-38
lines changed

DESCRIPTION

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Package: mbkmeans
22
Type: Package
33
Title: Mini-batch K-means Clustering for Single-Cell RNA-seq
4-
Version: 1.3.1
4+
Version: 1.3.2
55
Authors@R:
66
c(person("Yuwei", "Ni", role = c("aut", "cph"),
77
email = "yuweini45@gmail.com"),
@@ -38,7 +38,7 @@ Suggests:
3838
License: MIT + file LICENSE
3939
Encoding: UTF-8
4040
LazyData: true
41-
RoxygenNote: 6.1.1
41+
RoxygenNote: 7.0.2
4242
LinkingTo:
4343
Rcpp,
4444
RcppArmadillo (>= 0.7.2),

R/RcppExports.R

+9-5
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ compute_wcss <- function(clusters, cent, data) {
7373
#' number between 0.0 and 1.0.
7474
#'@param initializer the method of initialization. One of \emph{kmeans++} and
7575
#' \emph{random}. See details for more information.
76+
#'@param compute_labels logical indicating whether to compute the final cluster
77+
#' labels.
7678
#'@param calc_wcss logical indicating whether the within-cluster sum of squares
77-
#' should be computed and returned.
79+
#' should be computed and returned (ignored if `compute_labels = FALSE`).
7880
#'@param early_stop_iter continue that many iterations after calculation of the
7981
#' best within-cluster-sum-of-squared-error.
8082
#'@param verbose logical indicating whether progress is printed on screen.
@@ -87,12 +89,14 @@ compute_wcss <- function(clusters, cent, data) {
8789
#'
8890
#'centroids: the final centroids;
8991
#'
90-
#'WCSS_per_cluster: within-cluster sum of squares;
92+
#'WCSS_per_cluster (optional): the final per-cluster WCSS;
9193
#'
9294
#'best_initialization: which initialization value led to the best WCSS
9395
#'solution;
9496
#'
95-
#'iters_per_initialization: number of iterations per each initialization.
97+
#'iters_per_initialization: number of iterations per each initialization;
98+
#'
99+
#'Clusters (optional): the final cluster labels.
96100
#'
97101
#'@details This function performs k-means clustering using mini batches. It was
98102
#'inspired by the implementation in https://github.com/mlampros/ClusterR.
@@ -121,7 +125,7 @@ compute_wcss <- function(clusters, cent, data) {
121125
#'mini_batch(data, 2, 10, 10)
122126
#'
123127
#' @export
124-
mini_batch <- function(data, clusters, batch_size, max_iters, num_init = 1L, init_fraction = 1.0, initializer = "kmeans++", calc_wcss = FALSE, early_stop_iter = 10L, verbose = FALSE, CENTROIDS = NULL, tol = 1e-4) {
125-
.Call(`_mbkmeans_mini_batch`, data, clusters, batch_size, max_iters, num_init, init_fraction, initializer, calc_wcss, early_stop_iter, verbose, CENTROIDS, tol)
128+
mini_batch <- function(data, clusters, batch_size, max_iters, num_init = 1L, init_fraction = 1.0, initializer = "kmeans++", compute_labels = TRUE, calc_wcss = FALSE, early_stop_iter = 10L, verbose = FALSE, CENTROIDS = NULL, tol = 1e-4) {
129+
.Call(`_mbkmeans_mini_batch`, data, clusters, batch_size, max_iters, num_init, init_fraction, initializer, compute_labels, calc_wcss, early_stop_iter, verbose, CENTROIDS, tol)
126130
}
127131

R/kmeans.R

+8-4
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,10 @@ setMethod(
122122
#' \emph{random}. See details for more information
123123
#'@param early_stop_iter continue that many iterations after calculation of the
124124
#' best within-cluster-sum-of-squared-error
125-
#'@param calc_wcss either TRUE or FALSE, indicating whether the result of WCSS
126-
#' is shown. FALSE is default
125+
#'@param compute_labels logical indicating whether to compute the final cluster
126+
#' labels.
127+
#'@param calc_wcss logical indicating whether the per-cluster WCSS
128+
#' is computed. Ignored if `compute_labels = FALSE`.
127129
#'@param verbose either TRUE or FALSE, indicating whether progress is printed
128130
#' during clustering
129131
#'@param CENTROIDS a matrix of initial cluster centroids. The rows of the
@@ -156,7 +158,7 @@ setMethod(
156158
ceiling(ncol(x)*.05), ncol(x)),
157159
max_iters =100, num_init = 1,
158160
init_fraction = ifelse(ncol(x)>100, .25, 1),
159-
initializer = "kmeans++",
161+
initializer = "kmeans++", compute_labels = TRUE,
160162
calc_wcss = FALSE, early_stop_iter = 10,
161163
verbose = FALSE,
162164
CENTROIDS = NULL, tol = 1e-4)
@@ -171,7 +173,9 @@ setMethod(
171173
batch_size = batch_size, max_iters = max_iters,
172174
num_init = num_init,
173175
init_fraction = init_fraction,
174-
initializer = initializer, calc_wcss = calc_wcss,
176+
initializer = initializer,
177+
compute_labels = compute_labels,
178+
calc_wcss = calc_wcss,
175179
early_stop_iter = early_stop_iter,
176180
verbose = verbose,
177181
CENTROIDS = CENTROIDS, tol = tol)

inst/NEWS

+6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
Version 1.3.2 (2020-02-04)
2+
=====================================
3+
4+
- Added option `compute_labels` (default `TRUE`); set `compute_labels = FALSE` to
5+
skip computing the labels and return only the centroids.
6+
17
Version 0.99.0 (2019-03-08)
28
=====================================
39

man/mbkmeans.Rd

+20-9
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mini_batch.Rd

+23-7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/RcppExports.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ BEGIN_RCPP
3232
END_RCPP
3333
}
3434
// mini_batch
35-
Rcpp::List mini_batch(SEXP data, int clusters, int batch_size, int max_iters, int num_init, double init_fraction, std::string initializer, bool calc_wcss, int early_stop_iter, bool verbose, Rcpp::Nullable<Rcpp::NumericMatrix> CENTROIDS, double tol);
36-
RcppExport SEXP _mbkmeans_mini_batch(SEXP dataSEXP, SEXP clustersSEXP, SEXP batch_sizeSEXP, SEXP max_itersSEXP, SEXP num_initSEXP, SEXP init_fractionSEXP, SEXP initializerSEXP, SEXP calc_wcssSEXP, SEXP early_stop_iterSEXP, SEXP verboseSEXP, SEXP CENTROIDSSEXP, SEXP tolSEXP) {
35+
Rcpp::List mini_batch(SEXP data, int clusters, int batch_size, int max_iters, int num_init, double init_fraction, std::string initializer, bool compute_labels, bool calc_wcss, int early_stop_iter, bool verbose, Rcpp::Nullable<Rcpp::NumericMatrix> CENTROIDS, double tol);
36+
RcppExport SEXP _mbkmeans_mini_batch(SEXP dataSEXP, SEXP clustersSEXP, SEXP batch_sizeSEXP, SEXP max_itersSEXP, SEXP num_initSEXP, SEXP init_fractionSEXP, SEXP initializerSEXP, SEXP compute_labelsSEXP, SEXP calc_wcssSEXP, SEXP early_stop_iterSEXP, SEXP verboseSEXP, SEXP CENTROIDSSEXP, SEXP tolSEXP) {
3737
BEGIN_RCPP
3838
Rcpp::RObject rcpp_result_gen;
3939
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -44,20 +44,21 @@ BEGIN_RCPP
4444
Rcpp::traits::input_parameter< int >::type num_init(num_initSEXP);
4545
Rcpp::traits::input_parameter< double >::type init_fraction(init_fractionSEXP);
4646
Rcpp::traits::input_parameter< std::string >::type initializer(initializerSEXP);
47+
Rcpp::traits::input_parameter< bool >::type compute_labels(compute_labelsSEXP);
4748
Rcpp::traits::input_parameter< bool >::type calc_wcss(calc_wcssSEXP);
4849
Rcpp::traits::input_parameter< int >::type early_stop_iter(early_stop_iterSEXP);
4950
Rcpp::traits::input_parameter< bool >::type verbose(verboseSEXP);
5051
Rcpp::traits::input_parameter< Rcpp::Nullable<Rcpp::NumericMatrix> >::type CENTROIDS(CENTROIDSSEXP);
5152
Rcpp::traits::input_parameter< double >::type tol(tolSEXP);
52-
rcpp_result_gen = Rcpp::wrap(mini_batch(data, clusters, batch_size, max_iters, num_init, init_fraction, initializer, calc_wcss, early_stop_iter, verbose, CENTROIDS, tol));
53+
rcpp_result_gen = Rcpp::wrap(mini_batch(data, clusters, batch_size, max_iters, num_init, init_fraction, initializer, compute_labels, calc_wcss, early_stop_iter, verbose, CENTROIDS, tol));
5354
return rcpp_result_gen;
5455
END_RCPP
5556
}
5657

5758
static const R_CallMethodDef CallEntries[] = {
5859
{"_mbkmeans_predict_mini_batch", (DL_FUNC) &_mbkmeans_predict_mini_batch, 2},
5960
{"_mbkmeans_compute_wcss", (DL_FUNC) &_mbkmeans_compute_wcss, 3},
60-
{"_mbkmeans_mini_batch", (DL_FUNC) &_mbkmeans_mini_batch, 12},
61+
{"_mbkmeans_mini_batch", (DL_FUNC) &_mbkmeans_mini_batch, 13},
6162
{NULL, NULL, 0}
6263
};
6364

src/mini_batch.cpp

+17-7
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,10 @@ Rcpp::NumericVector compute_wcss(Rcpp::NumericVector clusters, Rcpp::NumericMatr
338338
//' number between 0.0 and 1.0.
339339
//'@param initializer the method of initialization. One of \emph{kmeans++} and
340340
//' \emph{random}. See details for more information.
341+
//'@param compute_labels logical indicating whether to compute the final cluster
342+
//' labels.
341343
//'@param calc_wcss logical indicating whether the within-cluster sum of squares
342-
//' should be computed and returned.
344+
//' should be computed and returned (ignored if `compute_labels = FALSE`).
343345
//'@param early_stop_iter continue that many iterations after calculation of the
344346
//' best within-cluster-sum-of-squared-error.
345347
//'@param verbose logical indicating whether progress is printed on screen.
@@ -352,12 +354,14 @@ Rcpp::NumericVector compute_wcss(Rcpp::NumericVector clusters, Rcpp::NumericMatr
352354
//'
353355
//'centroids: the final centroids;
354356
//'
355-
//'WCSS_per_cluster: within-cluster sum of squares;
357+
//'WCSS_per_cluster (optional): the final per-cluster WCSS;
356358
//'
357359
//'best_initialization: which initialization value led to the best WCSS
358360
//'solution;
359361
//'
360-
//'iters_per_initialization: number of iterations per each initialization.
362+
//'iters_per_initialization: number of iterations per each initialization;
363+
//'
364+
//'Clusters (optional): the final cluster labels.
361365
//'
362366
//'@details This function performs k-means clustering using mini batches. It was
363367
//'inspired by the implementation in https://github.com/mlampros/ClusterR.
@@ -390,6 +394,7 @@ Rcpp::NumericVector compute_wcss(Rcpp::NumericVector clusters, Rcpp::NumericMatr
390394
Rcpp::List mini_batch(SEXP data, int clusters, int batch_size, int max_iters,
391395
int num_init = 1, double init_fraction = 1.0,
392396
std::string initializer = "kmeans++",
397+
bool compute_labels = true,
393398
bool calc_wcss = false, int early_stop_iter = 10,
394399
bool verbose = false,
395400
Rcpp::Nullable<Rcpp::NumericMatrix> CENTROIDS = R_NilValue,
@@ -655,14 +660,19 @@ Rcpp::List mini_batch(SEXP data, int clusters, int batch_size, int max_iters,
655660

656661
}
657662

658-
659-
Rcpp::NumericVector clusterfinal = predict_mini_batch(data, Rcpp::wrap(centers_out));
663+
Rcpp::NumericVector clusterfinal;
660664
Rcpp::NumericVector wcss_final;
665+
666+
if(compute_labels) {
661667

662-
if(calc_wcss){
663-
wcss_final = compute_wcss(clusterfinal,Rcpp::wrap(centers_out),data);
668+
clusterfinal = predict_mini_batch(data, Rcpp::wrap(centers_out));
669+
if(calc_wcss){
670+
wcss_final = compute_wcss(clusterfinal, Rcpp::wrap(centers_out), data);
671+
}
672+
664673
}
665674

675+
666676
return Rcpp::List::create(Rcpp::Named("centroids") = centers_out,
667677
Rcpp::Named("WCSS_per_cluster") = wcss_final,
668678
Rcpp::Named("best_initialization") = end_init,

0 commit comments

Comments
 (0)