Clustering + Label propagation

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline'ggplot')


# NOTE(review): RANDOM_STATE is used throughout the notebook but its original
# definition is not visible; 42 is a placeholder. The hand-assigned labels in
# y_train_representative only line up with the clustering produced by the seed
# used when they were written, so confirm this value.
RANDOM_STATE  = 42
LABELED_RATIO = 0.05  # 5%: share of the training set labelled in the random-sampling baseline
N_CLUSTERS    = 50    # if manual labeling is affordable, a larger number of clusters works better
opt_n_clusters = N_CLUSTERS  # older name for N_CLUSTERS; kept so cells that use it still run

1. Load data

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split


# Load the 8x8 digits dataset and hold out a stratified test split.
X, y = load_digits(return_X_y=True)
split = train_test_split(X, y, stratify=y, random_state=RANDOM_STATE)
X_train, X_test, y_train, y_test = split

# Untouched copies of the raw pixels (the representative images are plotted from these).
X_train_org = X_train.copy()
X_test_org = X_test.copy()
X_train.shape, y_train.shape, X_test.shape, y_test.shape
((1347, 64), (1347,), (450, 64), (450,))

2. Label propagation

2.1 Select representative images

from sklearn.cluster import KMeans

# Cluster the training images. fit_transform returns each sample's distance to
# every centroid: one row per sample, one column per cluster.
kmeans = KMeans(n_clusters=N_CLUSTERS, random_state=RANDOM_STATE)
X_train_affinity = kmeans.fit_transform(X_train)

# For every cluster (column), the sample nearest its centroid becomes that
# cluster's representative; these are the images labelled by hand below.
idxs_representative    = np.argmin(X_train_affinity, axis=0)
X_train_representative = X_train[idxs_representative]

2.2 Manual labeling (expert)

# Show the raw pixels of each cluster representative so an expert can label them.
X_train_org_representative = X_train_org[idxs_representative]

fig, axes = plt.subplots(1, N_CLUSTERS, figsize=(15, 1))
for idx_ax, ax in enumerate(axes.flat):
    # Ticks on dozens of tiny panels only add clutter; hide them.
    ax.set_axis_off()
    if idx_ax < len(X_train_org_representative):
        ax.imshow(X_train_org_representative[idx_ax].reshape(8, 8), 'binary')


# Digits assigned by hand to the representative images shown above.
# Entry i is the label of representative i, i.e. of cluster i.
y_train_representative = np.array([
    4, 8, 7, 6, 0, 3, 2, 1, 1, 3,
    5, 0, 2, 6, 3, 5, 5, 7, 9, 4,
    4, 1, 8, 1, 0, 7, 7, 1, 2, 1,
    2, 3, 2, 0, 4, 9, 5, 7, 8, 9,
    1, 5, 4, 9, 6, 6, 5, 1, 5, 8,
])

2.3 Label propagation

2.3.1 Propagation for all data

# Give every training sample the label of its cluster's representative.
# Fancy indexing keeps the integer label dtype (np.empty produced float64) and
# needs no separately configured cluster count.
y_train_propagated_all = y_train_representative[kmeans.labels_]

2.3.2 Propagation for reliable data

def get_propagated_reliable(reliable_ratio, *, X=None, affinity=None,
                            cluster_ids=None, cluster_labels=None):
    """Return the samples closest to their centroid, with propagated labels.

    For each cluster, its members are ranked by distance to that cluster's
    centroid. The nearest ``ceil(reliable_ratio * cluster_size)`` members are
    kept, and each kept sample receives the cluster's representative label.

    Args:
        reliable_ratio: Fraction of each cluster to keep (e.g. 0.2 keeps the
            nearest 20%, rounded up).
        X: Feature matrix, one row per sample. Defaults to ``X_train``.
        affinity: Sample-to-centroid distances, one column per cluster (as
            returned by ``KMeans.fit_transform``). Defaults to
            ``X_train_affinity``.
        cluster_ids: Cluster index of each sample. Defaults to
            ``kmeans.labels_``.
        cluster_labels: Label of each cluster, indexed by cluster. Defaults to
            ``y_train_representative``.

    Returns:
        Tuple ``(X_reliable, y_reliable)``, grouped cluster by cluster with
        the nearest samples first.
    """
    if X is None:
        X = X_train
    if affinity is None:
        affinity = X_train_affinity
    if cluster_ids is None:
        cluster_ids = kmeans.labels_
    if cluster_labels is None:
        cluster_labels = y_train_representative

    X_parts, y_parts = [], []
    for idx_cluster in range(affinity.shape[1]):
        in_cluster = cluster_ids == idx_cluster
        # Ascending distance to this cluster's own centroid: most reliable first.
        order = np.argsort(affinity[in_cluster, idx_cluster])
        order = order[:int(np.ceil(reliable_ratio * len(order)))]
        X_parts.append(X[in_cluster][order])
        y_parts.append(np.repeat(cluster_labels[idx_cluster], len(order)))
    return np.concatenate(X_parts), np.concatenate(y_parts)

3. Evaluation

# Baseline: label a uniformly random LABELED_RATIO share of the training set.
# Sample without replacement so the subset really contains that many distinct
# samples, and seed the draw so the comparison is reproducible.
n_labeled = int(LABELED_RATIO * len(X_train))
idxs_labeled = np.random.RandomState(RANDOM_STATE).choice(len(X_train), n_labeled, replace=False)
X_train_random, y_train_random = X_train[idxs_labeled], y_train[idxs_labeled]
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, f1_score

# Fit the same classifier on each training variant, then score it against the
# true labels of the full train split and of the held-out test split.
model = SVC(random_state=RANDOM_STATE)

# Build each (name, features, labels) triple once, so get_propagated_reliable
# runs a single time per ratio instead of once for X and again for y.
experiments = [
    ('Full (100%)', X_train, y_train),
    (f'Random ({100*LABELED_RATIO}%)', X_train_random, y_train_random),
    # One sample per cluster, so report the share actually labelled.
    (f'Representative ({100*len(X_train_representative)/len(X_train):.1f}%)',
     X_train_representative, y_train_representative),
]
for ratio in (0.2, 0.4, 0.6, 0.8, 1):
    X_prop, y_prop = get_propagated_reliable(ratio)
    experiments.append((f'Propagated ({ratio:.0%})', X_prop, y_prop))

scores = {}
for exp_name, X_fit, y_fit in experiments:
    model.fit(X_fit, y_fit)
    y_train_pred = model.predict(X_train)
    y_test_pred  = model.predict(X_test)
    scores[exp_name] = [
        accuracy_score(y_train, y_train_pred),
        f1_score(y_train, y_train_pred, average='macro'),
        accuracy_score(y_test, y_test_pred),
        f1_score(y_test, y_test_pred, average='macro'),
    ]

result = pd.DataFrame.from_dict(
    scores,
    orient='index',
    columns=['train_accuracy', 'train_f1_score', 'test_accuracy', 'test_f1_score'],
)
result.columns.name = model.__class__.__name__
result
train_accuracy train_f1_score test_accuracy test_f1_score
Full (100%) 0.998515 0.998524 0.991111 0.991071
Random (5.0%) 0.732739 0.716331 0.722222 0.698227
Representative (5.0%) 0.859688 0.854053 0.848889 0.835384
Propagated (20%) 0.947290 0.947040 0.940000 0.939079
Propagated (40%) 0.960653 0.960496 0.955556 0.954870
Propagated (60%) 0.962880 0.962813 0.957778 0.957274
Propagated (80%) 0.959911 0.959987 0.960000 0.959517
Propagated (100%) 0.956941 0.956851 0.953333 0.952689