【数学】k-means法のプログラム

k-means法のプログラムをMATLABで作成したので、参考までに。


k-means方とは何かについては、下記記事を参照。
lm4183.hateblo.jp
下記記事も本プログラムを使用して作成しました。
lm4183.hateblo.jp

■kmeans法のメイン関数
初期値は、各データに対して、ランダムにクラスタを割り当てますが、
自分で決めた初期値を引数に与えてやることもできます。
「【数学】k-means法_日本の市区町村を47都道府県に再割り当て」でも各都道府県の割り当てgを初期値として自分で決めています。

% k-means法
%% 引数
% x : データ
% k : クラスタ数
% str : (省略可)初期値 'v':クラスタ中心、'g':クラスタ番号
% var : (省略可)
%% 戻り値
% g : クラスタ番号
% v : クラスタ中心
% cnt : 実行数
% d : 距離
%% 関数
function [g,v,cnt,d] = Calc_Kmeans(x,k,str,var)
    cntMax = 1000;
    
    %% 初期化
    if nargin >= 3
        switch str
            case 'g'
                g(1,:,1) = var;
            case 'v'
                d(:,:,1) = Dist_Euclid(x,var);
                g(1,:,1) = Cluster_Argmin(d(:,:,1));
            otherwise
                g(1,:,1) = randi([1 k],1,size(x,2));
        end
    else
    	g(1,:,1) = randi([1 k],1,size(x,2));
    end
    
    %% ループ
    for cnt = 1:cntMax
       %% クラスタ中心を更新
        v(:,:,cnt) = Center_Update(x,g(1,:,cnt),k);
        
       %% 距離を演算
        d(:,:,cnt) = Dist_Euclid(x,v(:,:,cnt));

       %% 割り当て
        g(1,:,cnt+1) = Cluster_Argmin(d(:,:,cnt));
        
       %% 終了判定
        if Judge(g(1,:,cnt+1),g(1,:,cnt))
            break
        end
    end
end


クラスタ中心を更新

% クラスタ中心を更新
%% 引数
% x : データ
% g : クラスタ番号
% k : クラスタ数
%% 戻り値
% v : クラスタ中心
%% 関数
function v = Center_Update(x,g,k)
    v = zeros(size(x,1),k);
    
    for j = 1:k
        v(:,j) = mean(x(:,g==j),2);
    end
end


ユークリッド距離を演算

% ユークリッド距離を演算
%% 引数
% x : データ
% v : クラスタ中心
%% 戻り値
% d : 距離
%% 関数
function d = Dist_Euclid(x,v)
    d = zeros(size(v,2),size(x,2));
    
    for i = 1:size(x,2)
        for j = 1:size(v,2)
            d(j,i) = sqrt((x(:,i)-v(:,j))'*(x(:,i)-v(:,j)));
        end
    end
end


■割り当て(最小値)

% 割り当て(最小値)
%% 引数
% d : 距離
%% 戻り値
% g1 : 距離の最小値
% g2 : 距離の最小値となる元(arg min)
%% 関数
function g2 = Cluster_Argmin(d)
    [g1,g2] = min(d);
end


■終了判定
クラスタ数が減ることがあるため、割り当てられないと"nan"となる。
そのため、isnan関数にてnanも考慮する。

% 終了判定
%% 引数
% g1 : クラスタ番号
% g2 : クラスタ番号(old)
%% 戻り値
% flg : 終了判定結果
%    false:終了しない
%    true:終了する
%% 関数
function flg = Judge(g1,g2)
    if isnan(g1)==isnan(g2)
        flg = all(g1(not(isnan(g1)))==g2(not(isnan(g1))));
    else
        flg = false;
    end
end


■(参考)使用例

clear;

k = 5;    %クラスタ数
x = randi([1 3],2,1000)*5+randn(2,1000);

%% ----k-means法----
[g,v,cnt] = Calc_Kmeans(x,k);

%% ----プロット----
t = cnt;    % 最終結果を表示

% 色を設定
cx1 = hsv2rgb([(0:k-1)/k;ones(1,k)*0.6;ones(1,k)*1.0]');
cx2 = hsv2rgb([(0:k-1)/k;ones(1,k)*0.6;ones(1,k)*0.6]');
cv1 = hsv2rgb([(0:k-1)/k;ones(1,k)*1.0;ones(1,k)*1.0]');
cv2 = hsv2rgb([(0:k-1)/k;ones(1,k)*1.0;ones(1,k)*0.6]');

cla reset
hold on;
for j=1:k
    index = g(1,:,t)==j;
    scatter(x(1,index),x(2,index),5,...
        'MarkerFaceColor',cx1(j,:),...
        'MarkerEdgeColor',cx2(j,:));
    scatter(v(1,j,t),v(2,j,t),30,...
        'MarkerFaceColor',cv1(j,:),...
        'MarkerEdgeColor',cv2(j,:));   
end
hold off;