Skip to content

Clustering

from sklearn.base import BaseEstimator, ClusterMixin
from sklearn.cluster import KMeans
import numpy as np

class WarmStartKMeans(BaseEstimator, ClusterMixin):
    """Fit one KMeans model per cluster count, warm-starting each from the last.

    The first model is initialised according to ``init_params`` (KMeans'
    defaults when nothing is given). Every following model is seeded with the
    centroids of the model fitted just before it, extended with rows sampled
    from the training data when more clusters are requested.
    """

    def __init__(self, n_clusters_list, **init_params):
        """
        Custom KMeans estimator with warm start for a list of n_clusters.

        Parameters:
        - n_clusters_list: List of integers specifying the number of clusters to try.
        - **init_params: Extra keyword arguments forwarded unchanged to every
          KMeans constructor call (for example max_iter, tol, random_state).
        """
        self.n_clusters_list = n_clusters_list
        self.init_params = init_params
        # Kept here (and not only in fit) so predict()/get_results() can be
        # called on an unfitted instance without an AttributeError.
        self.results_ = {}

    def fit(self, X, y=None, **fit_params):
        """
        Fit the k-means model using warm start for multiple n_clusters values.

        Parameters:
        - X: Input data (array-like or sparse matrix).
        - y: Ignored (not used in clustering).
        - **fit_params: Accepted for signature compatibility; not used.

        Returns:
        - self: Fitted estimator.
        """
        # Start from a clean slate so a refit never mixes in stale models.
        self.results_ = {}

        # Reuse the caller's random_state for the extra centroids so a seeded
        # run is reproducible end to end.
        seed = self.init_params.get("random_state")
        if isinstance(seed, np.random.RandomState):
            rng = seed
        else:
            rng = np.random.RandomState(seed)

        previous_centroids = None

        for n_clusters in self.n_clusters_list:
            params = dict(self.init_params)

            # Warm start only when the previous centroids can be reused as-is
            # (same or larger cluster count). On the first run, or when the
            # count shrinks, fall back to the user's/default initialisation.
            if previous_centroids is not None and n_clusters >= len(previous_centroids):
                n_extra = n_clusters - len(previous_centroids)
                params["init"] = self._extend_centroids(
                    X, previous_centroids, n_extra, rng
                )
                # An explicit init array makes multiple restarts pointless;
                # assigning through the dict also avoids a duplicate-keyword
                # TypeError when the user supplied init/n_init themselves.
                params["n_init"] = 1

            kmeans = KMeans(n_clusters=n_clusters, **params)

            # Fit the model and store results
            kmeans.fit(X)
            # Drop any earlier entry for the same key so the dict's insertion
            # order always ends with the most recently fitted model.
            self.results_.pop(n_clusters, None)
            self.results_[n_clusters] = {
                "model": kmeans,
                "labels": kmeans.labels_,
                "centroids": kmeans.cluster_centers_,
                "inertia": kmeans.inertia_,
            }

            # Update previous centroids for warm start
            previous_centroids = kmeans.cluster_centers_

        return self

    @staticmethod
    def _extend_centroids(X, centroids, n_extra, rng):
        """Return ``centroids`` with ``n_extra`` rows of ``X`` stacked below.

        The extra rows are drawn from the data itself so the new centroids
        live in the same region and scale as the samples.
        """
        if n_extra == 0:
            return centroids

        n_samples = X.shape[0] if hasattr(X, "shape") else len(X)
        # Sample without replacement whenever possible to avoid duplicate
        # starting centroids.
        picked = rng.choice(n_samples, size=n_extra, replace=n_extra > n_samples)

        if hasattr(X, "tocsr"):
            # Sparse input: row-index via CSR, then densify only those rows.
            extra = X.tocsr()[picked].toarray()
        else:
            extra = np.asarray(X)[picked]

        return np.vstack([centroids, extra])

    def predict(self, X):
        """
        Predict cluster labels using the last fitted model.

        Parameters:
        - X: Input data (array-like or sparse matrix).

        Returns:
        - labels: Cluster labels predicted by the model.

        Raises:
        - ValueError: If fit has not produced any model yet.
        """
        if not self.results_:
            raise ValueError("The model has not been fitted yet.")

        # fit() keeps the most recently fitted model at the end of the dict.
        last_entry = next(reversed(self.results_.values()))
        return last_entry["model"].predict(X)

    def get_results(self):
        """
        Retrieve clustering results for all n_clusters values.

        Returns:
        - Dictionary keyed by n_clusters; each value is a dict with the keys
          "model", "labels", "centroids" and "inertia".
        """
        return self.results_
Last Updated: 2025-03-13

Comments