Scalable clustering with Bregman divergences on Apache Spark
How to scale clustering to billions of points.
| Factor | Impact | How to Reduce |
|---|---|---|
| n (points) | Linear | Mini-batch, sampling |
| k (clusters) | Linear | Elkan/Hamerly pruning |
| d (dimensions) | Linear | Dimensionality reduction |
| iterations | Linear | Better initialization, early stopping |
Total: O(n × k × d × iterations)
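For example, two of these factors can be reduced before clustering even starts: sample the input to shrink n, and project to fewer dimensions to shrink d. A minimal sketch using Spark ML's PCA (the `features` column name, sampling fraction, and target dimensionality are illustrative):

```scala
import org.apache.spark.ml.feature.PCA

// Shrink n: cluster a 10% sample of the data (assumes a "features" vector column)
val sampled = data.sample(withReplacement = false, fraction = 0.1, seed = 42)

// Shrink d: project the feature vectors down to 20 dimensions
val pca = new PCA()
  .setInputCol("features")
  .setOutputCol("pcaFeatures")
  .setK(20)
val reduced = pca.fit(sampled).transform(sampled)
```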
```scala
// Rule of thumb: aim for partitions of roughly 100-200 MB each;
// 10x the default parallelism is a reasonable starting point
val numPartitions = spark.sparkContext.defaultParallelism * 10
val repartitionedData = data.repartition(numPartitions)

// Check how many records land in each partition
data.rdd.mapPartitions(iter => Iterator(iter.size)).collect()
  .foreach(println)
```
The library automatically chooses the best strategy:
| Strategy | When Used | Complexity |
|---|---|---|
| BroadcastUDF | k < ~1000 | O(n × k) |
| CrossJoin | k large, Squared Euclidean only | O(n × k), but faster in practice |
| Elkan | Squared Euclidean, k ≥ 5 | O(n × k) with pruning |
```scala
new GeneralizedKMeans()
  .setAssignmentStrategy("crossJoin") // or "broadcastUDF", "auto"
```
For Squared Euclidean with k ≥ 5, Elkan's algorithm can skip 50-90% of distance computations.

How it works: the algorithm maintains upper and lower bounds on each point's distances to the centers and uses the triangle inequality between centers; when the bounds prove that a point's assignment cannot change, the exact distances are never computed.

Speedup: typically 3-10x, and larger as the algorithm approaches convergence.
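The core pruning test is simple enough to sketch. The snippet below is only an illustration of the idea, not the library's internal implementation:

```scala
// Elkan-style pruning test (illustrative): if a point is at most half as far
// from its assigned center c1 as c1 is from another center c2, then by the
// triangle inequality d(x, c2) >= d(c1, c2) - d(x, c1) >= d(x, c1),
// so the exact distance to c2 never needs to be computed.
def canSkipCenter(distToAssigned: Double, gapBetweenCenters: Double): Boolean =
  distToAssigned <= 0.5 * gapBetweenCenters
```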
```scala
// Automatically enabled for Squared Euclidean with k >= 5
new GeneralizedKMeans()
  .setDivergence("squaredEuclidean")
  .setK(20)
```
For very large datasets, update centers using random samples:
```scala
new MiniBatchKMeans()
  .setK(100)
  .setBatchSize(10000) // Points per iteration
  .setMaxIter(100)
```
Trade-off: Faster convergence, slightly worse final quality
The default initialization, k-means||, is parallelizable and gives good quality:
```scala
new GeneralizedKMeans()
  .setInitMode("k-means||")
  .setInitSteps(2) // 2-5 is usually enough
```
Random initialization is faster but gives lower quality:
```scala
new GeneralizedKMeans()
  .setInitMode("random")
```
For long-running jobs, checkpoint to avoid recomputation:
```scala
spark.sparkContext.setCheckpointDir("hdfs:///checkpoints")

new GeneralizedKMeans()
  .setCheckpointInterval(10) // Every 10 iterations
```
Centers are broadcast to all executors. For very large k×d:
```scala
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "100m")
```
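As a rough sizing check (back-of-the-envelope arithmetic, not a figure from the library docs), dense double-precision centers take about k × d × 8 bytes, so a 100 MB threshold comfortably covers k = 10,000 centers in d = 1,000 dimensions:

```scala
// Approximate broadcast size for dense double-precision centers
val k = 10000
val d = 1000
val approxBytes = k.toLong * d * 8 // 80,000,000 bytes ≈ 80 MB
println(f"approx broadcast size: ${approxBytes / 1e6}%.1f MB")
```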
Cache input data for iterative algorithms:
```scala
val cachedData = data.cache()
cachedData.count() // Force materialization

val model = new GeneralizedKMeans().fit(cachedData)

cachedData.unpersist()
```
| Data Size | Recommendation |
|---|---|
| < 1M points | Standard GeneralizedKMeans |
| 1M - 100M | Enable checkpointing, optimize partitions |
| 100M - 1B | Mini-batch, consider sampling for init |
| > 1B | Mini-batch + streaming, hierarchical |
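Putting the pieces above together for the 100M-1B range, a hedged sketch (it only combines the setters shown elsewhere on this page, and assumes `MiniBatchKMeans.fit` is used the same way as the `GeneralizedKMeans` example above):

```scala
// Repartition and cache once so every iteration reuses the same materialized data
val prepared = data
  .repartition(spark.sparkContext.defaultParallelism * 10)
  .cache()
prepared.count() // force materialization

val model = new MiniBatchKMeans()
  .setK(100)
  .setBatchSize(10000)
  .setMaxIter(100)
  .fit(prepared)

prepared.unpersist()
```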
Typical performance on a 100-node cluster:

| Dataset (points × dims) | k | Time | Notes |
|---|---|---|---|
| 10M × 100 | 100 | 2 min | Standard |
| 100M × 100 | 100 | 15 min | With checkpointing |
| 1B × 100 | 100 | 2 hr | Mini-batch |
| 10M × 100 | 10000 | 10 min | Elkan acceleration |
To see what the library is doing, enable debug logging:

```scala
import org.apache.log4j.{Level, Logger}

Logger.getLogger("com.massivedatascience").setLevel(Level.DEBUG)
```
Symptom: Executor OOM during broadcast
Fix:
```scala
// Reduce broadcast size by switching assignment strategy
.setAssignmentStrategy("crossJoin")

// Or increase executor memory; spark.executor.memory cannot be changed at
// runtime, so set it when the application is launched, e.g.:
//   spark-submit --conf spark.executor.memory=8g ...
```
Symptom: Many iterations, small improvements
Fix:
```scala
// Increase tolerance
.setTol(1e-3)

// Or use mini-batch
new MiniBatchKMeans()
```
Symptom: One cluster has most points
Fix:
```scala
// Use balanced k-means
new BalancedKMeans().setBalanceMode("soft")

// Or run multiple random restarts and keep the best result
```
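A sketch of the restart idea is below; the `setSeed` parameter and the `trainingCost` accessor are hypothetical names used only for illustration, so check the API docs for the actual equivalents:

```scala
// Hypothetical sketch: run several random-init restarts and keep the model
// with the lowest cost. setSeed and trainingCost are assumed names, not
// confirmed parts of the library's API.
val candidates = (1 to 5).map { seed =>
  new GeneralizedKMeans()
    .setK(20)
    .setInitMode("random")
    .setSeed(seed) // assumed parameter
    .fit(data)
}
val best = candidates.minBy(_.trainingCost) // assumed accessor
```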