【数学】k-means法のプログラム
k-means法のプログラムをMATLABで作成したので、参考までに。
k-means方とは何かについては、下記記事を参照。
lm4183.hateblo.jp
下記記事も本プログラムを使用して作成しました。
lm4183.hateblo.jp
■kmeans法のメイン関数
初期値は、各データに対して、ランダムにクラスタを割り当てますが、
自分で決めた初期値を引数に与えてやることもできます。
「【数学】k-means法_日本の市区町村を47都道府県に再割り当て」でも各都道府県の割り当てgを初期値として自分で決めています。
% k-means法 %% 引数 % x : データ % k : クラスタ数 % str : (省略可)初期値 'v':クラスタ中心、'g':クラスタ番号 % var : (省略可) %% 戻り値 % g : クラスタ番号 % v : クラスタ中心 % cnt : 実行数 % d : 距離 %% 関数 function [g,v,cnt,d] = Calc_Kmeans(x,k,str,var) cntMax = 1000; %% 初期化 if nargin >= 3 switch str case 'g' g(1,:,1) = var; case 'v' d(:,:,1) = Dist_Euclid(x,var); g(1,:,1) = Cluster_Argmin(d(:,:,1)); otherwise g(1,:,1) = randi([1 k],1,size(x,2)); end else g(1,:,1) = randi([1 k],1,size(x,2)); end %% ループ for cnt = 1:cntMax %% クラスタ中心を更新 v(:,:,cnt) = Center_Update(x,g(1,:,cnt),k); %% 距離を演算 d(:,:,cnt) = Dist_Euclid(x,v(:,:,cnt)); %% 割り当て g(1,:,cnt+1) = Cluster_Argmin(d(:,:,cnt)); %% 終了判定 if Judge(g(1,:,cnt+1),g(1,:,cnt)) break end end end
■クラスタ中心を更新
% クラスタ中心を更新 %% 引数 % x : データ % g : クラスタ番号 % k : クラスタ数 %% 戻り値 % v : クラスタ中心 %% 関数 function v = Center_Update(x,g,k) v = zeros(size(x,1),k); for j = 1:k v(:,j) = mean(x(:,g==j),2); end end
■ユークリッド距離を演算
% ユークリッド距離を演算 %% 引数 % x : データ % v : クラスタ中心 %% 戻り値 % d : 距離 %% 関数 function d = Dist_Euclid(x,v) d = zeros(size(v,2),size(x,2)); for i = 1:size(x,2) for j = 1:size(v,2) d(j,i) = sqrt((x(:,i)-v(:,j))'*(x(:,i)-v(:,j))); end end end
■割り当て(最小値)
% 割り当て(最小値) %% 引数 % d : 距離 %% 戻り値 % g1 : 距離の最小値 % g2 : 距離の最小値となる元(arg min) %% 関数 function g2 = Cluster_Argmin(d) [g1,g2] = min(d); end
■終了判定
クラスタ数が減ることがあるため、割り当てられないと"nan"となる。
そのため、isnan関数にてnanも考慮する。
% 終了判定 %% 引数 % g1 : クラスタ番号 % g2 : クラスタ番号(old) %% 戻り値 % flg : 終了判定結果 % false:終了しない % true:終了する %% 関数 function flg = Judge(g1,g2) if isnan(g1)==isnan(g2) flg = all(g1(not(isnan(g1)))==g2(not(isnan(g1)))); else flg = false; end end
■(参考)使用例
clear; k = 5; %クラスタ数 x = randi([1 3],2,1000)*5+randn(2,1000); %% ----k-means法---- [g,v,cnt] = Calc_Kmeans(x,k); %% ----プロット---- t = cnt; % 最終結果を表示 % 色を設定 cx1 = hsv2rgb([(0:k-1)/k;ones(1,k)*0.6;ones(1,k)*1.0]'); cx2 = hsv2rgb([(0:k-1)/k;ones(1,k)*0.6;ones(1,k)*0.6]'); cv1 = hsv2rgb([(0:k-1)/k;ones(1,k)*1.0;ones(1,k)*1.0]'); cv2 = hsv2rgb([(0:k-1)/k;ones(1,k)*1.0;ones(1,k)*0.6]'); cla reset hold on; for j=1:k index = g(1,:,t)==j; scatter(x(1,index),x(2,index),5,... 'MarkerFaceColor',cx1(j,:),... 'MarkerEdgeColor',cx2(j,:)); scatter(v(1,j,t),v(2,j,t),30,... 'MarkerFaceColor',cv1(j,:),... 'MarkerEdgeColor',cv2(j,:)); end hold off;